##// END OF EJS Templates
python-zstandard: blacken at 80 characters...
Gregory Szorc -
r44605:5e84a96d default
parent child Browse files
Show More
@@ -1,15 +1,14 b''
1 [tool.black]
1 [tool.black]
2 line-length = 80
2 line-length = 80
3 exclude = '''
3 exclude = '''
4 build/
4 build/
5 | wheelhouse/
5 | wheelhouse/
6 | dist/
6 | dist/
7 | packages/
7 | packages/
8 | \.hg/
8 | \.hg/
9 | \.mypy_cache/
9 | \.mypy_cache/
10 | \.venv/
10 | \.venv/
11 | mercurial/thirdparty/
11 | mercurial/thirdparty/
12 | contrib/python-zstandard/
13 '''
12 '''
14 skip-string-normalization = true
13 skip-string-normalization = true
15 quiet = true
14 quiet = true
@@ -1,14 +1,14 b''
1 [fix]
1 [fix]
2 clang-format:command = clang-format --style file
2 clang-format:command = clang-format --style file
3 clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist"
3 clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist"
4
4
5 rustfmt:command = rustfmt +nightly
5 rustfmt:command = rustfmt +nightly
6 rustfmt:pattern = set:**.rs
6 rustfmt:pattern = set:**.rs
7
7
8 black:command = black --config=black.toml -
8 black:command = black --config=black.toml -
9 black:pattern = set:**.py - mercurial/thirdparty/** - "contrib/python-zstandard/**"
9 black:pattern = set:**.py - mercurial/thirdparty/**
10
10
11 # Mercurial doesn't have any Go code, but if we did this is how we
11 # Mercurial doesn't have any Go code, but if we did this is how we
12 # would configure `hg fix` for Go:
12 # would configure `hg fix` for Go:
13 go:command = gofmt
13 go:command = gofmt
14 go:pattern = set:**.go
14 go:pattern = set:**.go
@@ -1,225 +1,228 b''
1 # Copyright (c) 2016-present, Gregory Szorc
1 # Copyright (c) 2016-present, Gregory Szorc
2 # All rights reserved.
2 # All rights reserved.
3 #
3 #
4 # This software may be modified and distributed under the terms
4 # This software may be modified and distributed under the terms
5 # of the BSD license. See the LICENSE file for details.
5 # of the BSD license. See the LICENSE file for details.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import cffi
9 import cffi
10 import distutils.ccompiler
10 import distutils.ccompiler
11 import os
11 import os
12 import re
12 import re
13 import subprocess
13 import subprocess
14 import tempfile
14 import tempfile
15
15
16
16
17 HERE = os.path.abspath(os.path.dirname(__file__))
17 HERE = os.path.abspath(os.path.dirname(__file__))
18
18
19 SOURCES = [
19 SOURCES = [
20 "zstd/%s" % p
20 "zstd/%s" % p
21 for p in (
21 for p in (
22 "common/debug.c",
22 "common/debug.c",
23 "common/entropy_common.c",
23 "common/entropy_common.c",
24 "common/error_private.c",
24 "common/error_private.c",
25 "common/fse_decompress.c",
25 "common/fse_decompress.c",
26 "common/pool.c",
26 "common/pool.c",
27 "common/threading.c",
27 "common/threading.c",
28 "common/xxhash.c",
28 "common/xxhash.c",
29 "common/zstd_common.c",
29 "common/zstd_common.c",
30 "compress/fse_compress.c",
30 "compress/fse_compress.c",
31 "compress/hist.c",
31 "compress/hist.c",
32 "compress/huf_compress.c",
32 "compress/huf_compress.c",
33 "compress/zstd_compress.c",
33 "compress/zstd_compress.c",
34 "compress/zstd_compress_literals.c",
34 "compress/zstd_compress_literals.c",
35 "compress/zstd_compress_sequences.c",
35 "compress/zstd_compress_sequences.c",
36 "compress/zstd_double_fast.c",
36 "compress/zstd_double_fast.c",
37 "compress/zstd_fast.c",
37 "compress/zstd_fast.c",
38 "compress/zstd_lazy.c",
38 "compress/zstd_lazy.c",
39 "compress/zstd_ldm.c",
39 "compress/zstd_ldm.c",
40 "compress/zstd_opt.c",
40 "compress/zstd_opt.c",
41 "compress/zstdmt_compress.c",
41 "compress/zstdmt_compress.c",
42 "decompress/huf_decompress.c",
42 "decompress/huf_decompress.c",
43 "decompress/zstd_ddict.c",
43 "decompress/zstd_ddict.c",
44 "decompress/zstd_decompress.c",
44 "decompress/zstd_decompress.c",
45 "decompress/zstd_decompress_block.c",
45 "decompress/zstd_decompress_block.c",
46 "dictBuilder/cover.c",
46 "dictBuilder/cover.c",
47 "dictBuilder/fastcover.c",
47 "dictBuilder/fastcover.c",
48 "dictBuilder/divsufsort.c",
48 "dictBuilder/divsufsort.c",
49 "dictBuilder/zdict.c",
49 "dictBuilder/zdict.c",
50 )
50 )
51 ]
51 ]
52
52
53 # Headers whose preprocessed output will be fed into cdef().
53 # Headers whose preprocessed output will be fed into cdef().
54 HEADERS = [
54 HEADERS = [
55 os.path.join(HERE, "zstd", *p) for p in (("zstd.h",), ("dictBuilder", "zdict.h"),)
55 os.path.join(HERE, "zstd", *p)
56 for p in (("zstd.h",), ("dictBuilder", "zdict.h"),)
56 ]
57 ]
57
58
58 INCLUDE_DIRS = [
59 INCLUDE_DIRS = [
59 os.path.join(HERE, d)
60 os.path.join(HERE, d)
60 for d in (
61 for d in (
61 "zstd",
62 "zstd",
62 "zstd/common",
63 "zstd/common",
63 "zstd/compress",
64 "zstd/compress",
64 "zstd/decompress",
65 "zstd/decompress",
65 "zstd/dictBuilder",
66 "zstd/dictBuilder",
66 )
67 )
67 ]
68 ]
68
69
69 # cffi can't parse some of the primitives in zstd.h. So we invoke the
70 # cffi can't parse some of the primitives in zstd.h. So we invoke the
70 # preprocessor and feed its output into cffi.
71 # preprocessor and feed its output into cffi.
71 compiler = distutils.ccompiler.new_compiler()
72 compiler = distutils.ccompiler.new_compiler()
72
73
73 # Needed for MSVC.
74 # Needed for MSVC.
74 if hasattr(compiler, "initialize"):
75 if hasattr(compiler, "initialize"):
75 compiler.initialize()
76 compiler.initialize()
76
77
77 # Distutils doesn't set compiler.preprocessor, so invoke the preprocessor
78 # Distutils doesn't set compiler.preprocessor, so invoke the preprocessor
78 # manually.
79 # manually.
79 if compiler.compiler_type == "unix":
80 if compiler.compiler_type == "unix":
80 args = list(compiler.executables["compiler"])
81 args = list(compiler.executables["compiler"])
81 args.extend(
82 args.extend(
82 ["-E", "-DZSTD_STATIC_LINKING_ONLY", "-DZDICT_STATIC_LINKING_ONLY",]
83 ["-E", "-DZSTD_STATIC_LINKING_ONLY", "-DZDICT_STATIC_LINKING_ONLY",]
83 )
84 )
84 elif compiler.compiler_type == "msvc":
85 elif compiler.compiler_type == "msvc":
85 args = [compiler.cc]
86 args = [compiler.cc]
86 args.extend(
87 args.extend(
87 ["/EP", "/DZSTD_STATIC_LINKING_ONLY", "/DZDICT_STATIC_LINKING_ONLY",]
88 ["/EP", "/DZSTD_STATIC_LINKING_ONLY", "/DZDICT_STATIC_LINKING_ONLY",]
88 )
89 )
89 else:
90 else:
90 raise Exception("unsupported compiler type: %s" % compiler.compiler_type)
91 raise Exception("unsupported compiler type: %s" % compiler.compiler_type)
91
92
92
93
93 def preprocess(path):
94 def preprocess(path):
94 with open(path, "rb") as fh:
95 with open(path, "rb") as fh:
95 lines = []
96 lines = []
96 it = iter(fh)
97 it = iter(fh)
97
98
98 for l in it:
99 for l in it:
99 # zstd.h includes <stddef.h>, which is also included by cffi's
100 # zstd.h includes <stddef.h>, which is also included by cffi's
100 # boilerplate. This can lead to duplicate declarations. So we strip
101 # boilerplate. This can lead to duplicate declarations. So we strip
101 # this include from the preprocessor invocation.
102 # this include from the preprocessor invocation.
102 #
103 #
103 # The same things happens for including zstd.h, so give it the same
104 # The same things happens for including zstd.h, so give it the same
104 # treatment.
105 # treatment.
105 #
106 #
106 # We define ZSTD_STATIC_LINKING_ONLY, which is redundant with the inline
107 # We define ZSTD_STATIC_LINKING_ONLY, which is redundant with the inline
107 # #define in zstdmt_compress.h and results in a compiler warning. So drop
108 # #define in zstdmt_compress.h and results in a compiler warning. So drop
108 # the inline #define.
109 # the inline #define.
109 if l.startswith(
110 if l.startswith(
110 (
111 (
111 b"#include <stddef.h>",
112 b"#include <stddef.h>",
112 b'#include "zstd.h"',
113 b'#include "zstd.h"',
113 b"#define ZSTD_STATIC_LINKING_ONLY",
114 b"#define ZSTD_STATIC_LINKING_ONLY",
114 )
115 )
115 ):
116 ):
116 continue
117 continue
117
118
118 # The preprocessor environment on Windows doesn't define include
119 # The preprocessor environment on Windows doesn't define include
119 # paths, so the #include of limits.h fails. We work around this
120 # paths, so the #include of limits.h fails. We work around this
120 # by removing that import and defining INT_MAX ourselves. This is
121 # by removing that import and defining INT_MAX ourselves. This is
121 # a bit hacky. But it gets the job done.
122 # a bit hacky. But it gets the job done.
122 # TODO make limits.h work on Windows so we ensure INT_MAX is
123 # TODO make limits.h work on Windows so we ensure INT_MAX is
123 # correct.
124 # correct.
124 if l.startswith(b"#include <limits.h>"):
125 if l.startswith(b"#include <limits.h>"):
125 l = b"#define INT_MAX 2147483647\n"
126 l = b"#define INT_MAX 2147483647\n"
126
127
127 # ZSTDLIB_API may not be defined if we dropped zstd.h. It isn't
128 # ZSTDLIB_API may not be defined if we dropped zstd.h. It isn't
128 # important so just filter it out.
129 # important so just filter it out.
129 if l.startswith(b"ZSTDLIB_API"):
130 if l.startswith(b"ZSTDLIB_API"):
130 l = l[len(b"ZSTDLIB_API ") :]
131 l = l[len(b"ZSTDLIB_API ") :]
131
132
132 lines.append(l)
133 lines.append(l)
133
134
134 fd, input_file = tempfile.mkstemp(suffix=".h")
135 fd, input_file = tempfile.mkstemp(suffix=".h")
135 os.write(fd, b"".join(lines))
136 os.write(fd, b"".join(lines))
136 os.close(fd)
137 os.close(fd)
137
138
138 try:
139 try:
139 env = dict(os.environ)
140 env = dict(os.environ)
140 if getattr(compiler, "_paths", None):
141 if getattr(compiler, "_paths", None):
141 env["PATH"] = compiler._paths
142 env["PATH"] = compiler._paths
142 process = subprocess.Popen(args + [input_file], stdout=subprocess.PIPE, env=env)
143 process = subprocess.Popen(
144 args + [input_file], stdout=subprocess.PIPE, env=env
145 )
143 output = process.communicate()[0]
146 output = process.communicate()[0]
144 ret = process.poll()
147 ret = process.poll()
145 if ret:
148 if ret:
146 raise Exception("preprocessor exited with error")
149 raise Exception("preprocessor exited with error")
147
150
148 return output
151 return output
149 finally:
152 finally:
150 os.unlink(input_file)
153 os.unlink(input_file)
151
154
152
155
153 def normalize_output(output):
156 def normalize_output(output):
154 lines = []
157 lines = []
155 for line in output.splitlines():
158 for line in output.splitlines():
156 # CFFI's parser doesn't like __attribute__ on UNIX compilers.
159 # CFFI's parser doesn't like __attribute__ on UNIX compilers.
157 if line.startswith(b'__attribute__ ((visibility ("default"))) '):
160 if line.startswith(b'__attribute__ ((visibility ("default"))) '):
158 line = line[len(b'__attribute__ ((visibility ("default"))) ') :]
161 line = line[len(b'__attribute__ ((visibility ("default"))) ') :]
159
162
160 if line.startswith(b"__attribute__((deprecated("):
163 if line.startswith(b"__attribute__((deprecated("):
161 continue
164 continue
162 elif b"__declspec(deprecated(" in line:
165 elif b"__declspec(deprecated(" in line:
163 continue
166 continue
164
167
165 lines.append(line)
168 lines.append(line)
166
169
167 return b"\n".join(lines)
170 return b"\n".join(lines)
168
171
169
172
170 ffi = cffi.FFI()
173 ffi = cffi.FFI()
171 # zstd.h uses a possible undefined MIN(). Define it until
174 # zstd.h uses a possible undefined MIN(). Define it until
172 # https://github.com/facebook/zstd/issues/976 is fixed.
175 # https://github.com/facebook/zstd/issues/976 is fixed.
173 # *_DISABLE_DEPRECATE_WARNINGS prevents the compiler from emitting a warning
176 # *_DISABLE_DEPRECATE_WARNINGS prevents the compiler from emitting a warning
174 # when cffi uses the function. Since we statically link against zstd, even
177 # when cffi uses the function. Since we statically link against zstd, even
175 # if we use the deprecated functions it shouldn't be a huge problem.
178 # if we use the deprecated functions it shouldn't be a huge problem.
176 ffi.set_source(
179 ffi.set_source(
177 "_zstd_cffi",
180 "_zstd_cffi",
178 """
181 """
179 #define MIN(a,b) ((a)<(b) ? (a) : (b))
182 #define MIN(a,b) ((a)<(b) ? (a) : (b))
180 #define ZSTD_STATIC_LINKING_ONLY
183 #define ZSTD_STATIC_LINKING_ONLY
181 #include <zstd.h>
184 #include <zstd.h>
182 #define ZDICT_STATIC_LINKING_ONLY
185 #define ZDICT_STATIC_LINKING_ONLY
183 #define ZDICT_DISABLE_DEPRECATE_WARNINGS
186 #define ZDICT_DISABLE_DEPRECATE_WARNINGS
184 #include <zdict.h>
187 #include <zdict.h>
185 """,
188 """,
186 sources=SOURCES,
189 sources=SOURCES,
187 include_dirs=INCLUDE_DIRS,
190 include_dirs=INCLUDE_DIRS,
188 extra_compile_args=["-DZSTD_MULTITHREAD"],
191 extra_compile_args=["-DZSTD_MULTITHREAD"],
189 )
192 )
190
193
191 DEFINE = re.compile(b"^\\#define ([a-zA-Z0-9_]+) ")
194 DEFINE = re.compile(b"^\\#define ([a-zA-Z0-9_]+) ")
192
195
193 sources = []
196 sources = []
194
197
195 # Feed normalized preprocessor output for headers into the cdef parser.
198 # Feed normalized preprocessor output for headers into the cdef parser.
196 for header in HEADERS:
199 for header in HEADERS:
197 preprocessed = preprocess(header)
200 preprocessed = preprocess(header)
198 sources.append(normalize_output(preprocessed))
201 sources.append(normalize_output(preprocessed))
199
202
200 # #define's are effectively erased as part of going through preprocessor.
203 # #define's are effectively erased as part of going through preprocessor.
201 # So perform a manual pass to re-add those to the cdef source.
204 # So perform a manual pass to re-add those to the cdef source.
202 with open(header, "rb") as fh:
205 with open(header, "rb") as fh:
203 for line in fh:
206 for line in fh:
204 line = line.strip()
207 line = line.strip()
205 m = DEFINE.match(line)
208 m = DEFINE.match(line)
206 if not m:
209 if not m:
207 continue
210 continue
208
211
209 if m.group(1) == b"ZSTD_STATIC_LINKING_ONLY":
212 if m.group(1) == b"ZSTD_STATIC_LINKING_ONLY":
210 continue
213 continue
211
214
212 # The parser doesn't like some constants with complex values.
215 # The parser doesn't like some constants with complex values.
213 if m.group(1) in (b"ZSTD_LIB_VERSION", b"ZSTD_VERSION_STRING"):
216 if m.group(1) in (b"ZSTD_LIB_VERSION", b"ZSTD_VERSION_STRING"):
214 continue
217 continue
215
218
216 # The ... is magic syntax by the cdef parser to resolve the
219 # The ... is magic syntax by the cdef parser to resolve the
217 # value at compile time.
220 # value at compile time.
218 sources.append(m.group(0) + b" ...")
221 sources.append(m.group(0) + b" ...")
219
222
220 cdeflines = b"\n".join(sources).splitlines()
223 cdeflines = b"\n".join(sources).splitlines()
221 cdeflines = [l for l in cdeflines if l.strip()]
224 cdeflines = [l for l in cdeflines if l.strip()]
222 ffi.cdef(b"\n".join(cdeflines).decode("latin1"))
225 ffi.cdef(b"\n".join(cdeflines).decode("latin1"))
223
226
224 if __name__ == "__main__":
227 if __name__ == "__main__":
225 ffi.compile()
228 ffi.compile()
@@ -1,118 +1,120 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # Copyright (c) 2016-present, Gregory Szorc
2 # Copyright (c) 2016-present, Gregory Szorc
3 # All rights reserved.
3 # All rights reserved.
4 #
4 #
5 # This software may be modified and distributed under the terms
5 # This software may be modified and distributed under the terms
6 # of the BSD license. See the LICENSE file for details.
6 # of the BSD license. See the LICENSE file for details.
7
7
8 from __future__ import print_function
8 from __future__ import print_function
9
9
10 from distutils.version import LooseVersion
10 from distutils.version import LooseVersion
11 import os
11 import os
12 import sys
12 import sys
13 from setuptools import setup
13 from setuptools import setup
14
14
15 # Need change in 1.10 for ffi.from_buffer() to handle all buffer types
15 # Need change in 1.10 for ffi.from_buffer() to handle all buffer types
16 # (like memoryview).
16 # (like memoryview).
17 # Need feature in 1.11 for ffi.gc() to declare size of objects so we avoid
17 # Need feature in 1.11 for ffi.gc() to declare size of objects so we avoid
18 # garbage collection pitfalls.
18 # garbage collection pitfalls.
19 MINIMUM_CFFI_VERSION = "1.11"
19 MINIMUM_CFFI_VERSION = "1.11"
20
20
21 try:
21 try:
22 import cffi
22 import cffi
23
23
24 # PyPy (and possibly other distros) have CFFI distributed as part of
24 # PyPy (and possibly other distros) have CFFI distributed as part of
25 # them. The install_requires for CFFI below won't work. We need to sniff
25 # them. The install_requires for CFFI below won't work. We need to sniff
26 # out the CFFI version here and reject CFFI if it is too old.
26 # out the CFFI version here and reject CFFI if it is too old.
27 cffi_version = LooseVersion(cffi.__version__)
27 cffi_version = LooseVersion(cffi.__version__)
28 if cffi_version < LooseVersion(MINIMUM_CFFI_VERSION):
28 if cffi_version < LooseVersion(MINIMUM_CFFI_VERSION):
29 print(
29 print(
30 "CFFI 1.11 or newer required (%s found); "
30 "CFFI 1.11 or newer required (%s found); "
31 "not building CFFI backend" % cffi_version,
31 "not building CFFI backend" % cffi_version,
32 file=sys.stderr,
32 file=sys.stderr,
33 )
33 )
34 cffi = None
34 cffi = None
35
35
36 except ImportError:
36 except ImportError:
37 cffi = None
37 cffi = None
38
38
39 import setup_zstd
39 import setup_zstd
40
40
41 SUPPORT_LEGACY = False
41 SUPPORT_LEGACY = False
42 SYSTEM_ZSTD = False
42 SYSTEM_ZSTD = False
43 WARNINGS_AS_ERRORS = False
43 WARNINGS_AS_ERRORS = False
44
44
45 if os.environ.get("ZSTD_WARNINGS_AS_ERRORS", ""):
45 if os.environ.get("ZSTD_WARNINGS_AS_ERRORS", ""):
46 WARNINGS_AS_ERRORS = True
46 WARNINGS_AS_ERRORS = True
47
47
48 if "--legacy" in sys.argv:
48 if "--legacy" in sys.argv:
49 SUPPORT_LEGACY = True
49 SUPPORT_LEGACY = True
50 sys.argv.remove("--legacy")
50 sys.argv.remove("--legacy")
51
51
52 if "--system-zstd" in sys.argv:
52 if "--system-zstd" in sys.argv:
53 SYSTEM_ZSTD = True
53 SYSTEM_ZSTD = True
54 sys.argv.remove("--system-zstd")
54 sys.argv.remove("--system-zstd")
55
55
56 if "--warnings-as-errors" in sys.argv:
56 if "--warnings-as-errors" in sys.argv:
57 WARNINGS_AS_ERRORS = True
57 WARNINGS_AS_ERRORS = True
58 sys.argv.remove("--warning-as-errors")
58 sys.argv.remove("--warning-as-errors")
59
59
60 # Code for obtaining the Extension instance is in its own module to
60 # Code for obtaining the Extension instance is in its own module to
61 # facilitate reuse in other projects.
61 # facilitate reuse in other projects.
62 extensions = [
62 extensions = [
63 setup_zstd.get_c_extension(
63 setup_zstd.get_c_extension(
64 name="zstd",
64 name="zstd",
65 support_legacy=SUPPORT_LEGACY,
65 support_legacy=SUPPORT_LEGACY,
66 system_zstd=SYSTEM_ZSTD,
66 system_zstd=SYSTEM_ZSTD,
67 warnings_as_errors=WARNINGS_AS_ERRORS,
67 warnings_as_errors=WARNINGS_AS_ERRORS,
68 ),
68 ),
69 ]
69 ]
70
70
71 install_requires = []
71 install_requires = []
72
72
73 if cffi:
73 if cffi:
74 import make_cffi
74 import make_cffi
75
75
76 extensions.append(make_cffi.ffi.distutils_extension())
76 extensions.append(make_cffi.ffi.distutils_extension())
77 install_requires.append("cffi>=%s" % MINIMUM_CFFI_VERSION)
77 install_requires.append("cffi>=%s" % MINIMUM_CFFI_VERSION)
78
78
79 version = None
79 version = None
80
80
81 with open("c-ext/python-zstandard.h", "r") as fh:
81 with open("c-ext/python-zstandard.h", "r") as fh:
82 for line in fh:
82 for line in fh:
83 if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"):
83 if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"):
84 continue
84 continue
85
85
86 version = line.split()[2][1:-1]
86 version = line.split()[2][1:-1]
87 break
87 break
88
88
89 if not version:
89 if not version:
90 raise Exception("could not resolve package version; " "this should never happen")
90 raise Exception(
91 "could not resolve package version; " "this should never happen"
92 )
91
93
92 setup(
94 setup(
93 name="zstandard",
95 name="zstandard",
94 version=version,
96 version=version,
95 description="Zstandard bindings for Python",
97 description="Zstandard bindings for Python",
96 long_description=open("README.rst", "r").read(),
98 long_description=open("README.rst", "r").read(),
97 url="https://github.com/indygreg/python-zstandard",
99 url="https://github.com/indygreg/python-zstandard",
98 author="Gregory Szorc",
100 author="Gregory Szorc",
99 author_email="gregory.szorc@gmail.com",
101 author_email="gregory.szorc@gmail.com",
100 license="BSD",
102 license="BSD",
101 classifiers=[
103 classifiers=[
102 "Development Status :: 4 - Beta",
104 "Development Status :: 4 - Beta",
103 "Intended Audience :: Developers",
105 "Intended Audience :: Developers",
104 "License :: OSI Approved :: BSD License",
106 "License :: OSI Approved :: BSD License",
105 "Programming Language :: C",
107 "Programming Language :: C",
106 "Programming Language :: Python :: 2.7",
108 "Programming Language :: Python :: 2.7",
107 "Programming Language :: Python :: 3.5",
109 "Programming Language :: Python :: 3.5",
108 "Programming Language :: Python :: 3.6",
110 "Programming Language :: Python :: 3.6",
109 "Programming Language :: Python :: 3.7",
111 "Programming Language :: Python :: 3.7",
110 "Programming Language :: Python :: 3.8",
112 "Programming Language :: Python :: 3.8",
111 ],
113 ],
112 keywords="zstandard zstd compression",
114 keywords="zstandard zstd compression",
113 packages=["zstandard"],
115 packages=["zstandard"],
114 ext_modules=extensions,
116 ext_modules=extensions,
115 test_suite="tests",
117 test_suite="tests",
116 install_requires=install_requires,
118 install_requires=install_requires,
117 tests_require=["hypothesis"],
119 tests_require=["hypothesis"],
118 )
120 )
@@ -1,206 +1,210 b''
1 # Copyright (c) 2016-present, Gregory Szorc
1 # Copyright (c) 2016-present, Gregory Szorc
2 # All rights reserved.
2 # All rights reserved.
3 #
3 #
4 # This software may be modified and distributed under the terms
4 # This software may be modified and distributed under the terms
5 # of the BSD license. See the LICENSE file for details.
5 # of the BSD license. See the LICENSE file for details.
6
6
7 import distutils.ccompiler
7 import distutils.ccompiler
8 import os
8 import os
9
9
10 from distutils.extension import Extension
10 from distutils.extension import Extension
11
11
12
12
13 zstd_sources = [
13 zstd_sources = [
14 "zstd/%s" % p
14 "zstd/%s" % p
15 for p in (
15 for p in (
16 "common/debug.c",
16 "common/debug.c",
17 "common/entropy_common.c",
17 "common/entropy_common.c",
18 "common/error_private.c",
18 "common/error_private.c",
19 "common/fse_decompress.c",
19 "common/fse_decompress.c",
20 "common/pool.c",
20 "common/pool.c",
21 "common/threading.c",
21 "common/threading.c",
22 "common/xxhash.c",
22 "common/xxhash.c",
23 "common/zstd_common.c",
23 "common/zstd_common.c",
24 "compress/fse_compress.c",
24 "compress/fse_compress.c",
25 "compress/hist.c",
25 "compress/hist.c",
26 "compress/huf_compress.c",
26 "compress/huf_compress.c",
27 "compress/zstd_compress_literals.c",
27 "compress/zstd_compress_literals.c",
28 "compress/zstd_compress_sequences.c",
28 "compress/zstd_compress_sequences.c",
29 "compress/zstd_compress.c",
29 "compress/zstd_compress.c",
30 "compress/zstd_double_fast.c",
30 "compress/zstd_double_fast.c",
31 "compress/zstd_fast.c",
31 "compress/zstd_fast.c",
32 "compress/zstd_lazy.c",
32 "compress/zstd_lazy.c",
33 "compress/zstd_ldm.c",
33 "compress/zstd_ldm.c",
34 "compress/zstd_opt.c",
34 "compress/zstd_opt.c",
35 "compress/zstdmt_compress.c",
35 "compress/zstdmt_compress.c",
36 "decompress/huf_decompress.c",
36 "decompress/huf_decompress.c",
37 "decompress/zstd_ddict.c",
37 "decompress/zstd_ddict.c",
38 "decompress/zstd_decompress.c",
38 "decompress/zstd_decompress.c",
39 "decompress/zstd_decompress_block.c",
39 "decompress/zstd_decompress_block.c",
40 "dictBuilder/cover.c",
40 "dictBuilder/cover.c",
41 "dictBuilder/divsufsort.c",
41 "dictBuilder/divsufsort.c",
42 "dictBuilder/fastcover.c",
42 "dictBuilder/fastcover.c",
43 "dictBuilder/zdict.c",
43 "dictBuilder/zdict.c",
44 )
44 )
45 ]
45 ]
46
46
47 zstd_sources_legacy = [
47 zstd_sources_legacy = [
48 "zstd/%s" % p
48 "zstd/%s" % p
49 for p in (
49 for p in (
50 "deprecated/zbuff_common.c",
50 "deprecated/zbuff_common.c",
51 "deprecated/zbuff_compress.c",
51 "deprecated/zbuff_compress.c",
52 "deprecated/zbuff_decompress.c",
52 "deprecated/zbuff_decompress.c",
53 "legacy/zstd_v01.c",
53 "legacy/zstd_v01.c",
54 "legacy/zstd_v02.c",
54 "legacy/zstd_v02.c",
55 "legacy/zstd_v03.c",
55 "legacy/zstd_v03.c",
56 "legacy/zstd_v04.c",
56 "legacy/zstd_v04.c",
57 "legacy/zstd_v05.c",
57 "legacy/zstd_v05.c",
58 "legacy/zstd_v06.c",
58 "legacy/zstd_v06.c",
59 "legacy/zstd_v07.c",
59 "legacy/zstd_v07.c",
60 )
60 )
61 ]
61 ]
62
62
63 zstd_includes = [
63 zstd_includes = [
64 "zstd",
64 "zstd",
65 "zstd/common",
65 "zstd/common",
66 "zstd/compress",
66 "zstd/compress",
67 "zstd/decompress",
67 "zstd/decompress",
68 "zstd/dictBuilder",
68 "zstd/dictBuilder",
69 ]
69 ]
70
70
71 zstd_includes_legacy = [
71 zstd_includes_legacy = [
72 "zstd/deprecated",
72 "zstd/deprecated",
73 "zstd/legacy",
73 "zstd/legacy",
74 ]
74 ]
75
75
76 ext_includes = [
76 ext_includes = [
77 "c-ext",
77 "c-ext",
78 "zstd/common",
78 "zstd/common",
79 ]
79 ]
80
80
81 ext_sources = [
81 ext_sources = [
82 "zstd/common/error_private.c",
82 "zstd/common/error_private.c",
83 "zstd/common/pool.c",
83 "zstd/common/pool.c",
84 "zstd/common/threading.c",
84 "zstd/common/threading.c",
85 "zstd/common/zstd_common.c",
85 "zstd/common/zstd_common.c",
86 "zstd.c",
86 "zstd.c",
87 "c-ext/bufferutil.c",
87 "c-ext/bufferutil.c",
88 "c-ext/compressiondict.c",
88 "c-ext/compressiondict.c",
89 "c-ext/compressobj.c",
89 "c-ext/compressobj.c",
90 "c-ext/compressor.c",
90 "c-ext/compressor.c",
91 "c-ext/compressoriterator.c",
91 "c-ext/compressoriterator.c",
92 "c-ext/compressionchunker.c",
92 "c-ext/compressionchunker.c",
93 "c-ext/compressionparams.c",
93 "c-ext/compressionparams.c",
94 "c-ext/compressionreader.c",
94 "c-ext/compressionreader.c",
95 "c-ext/compressionwriter.c",
95 "c-ext/compressionwriter.c",
96 "c-ext/constants.c",
96 "c-ext/constants.c",
97 "c-ext/decompressobj.c",
97 "c-ext/decompressobj.c",
98 "c-ext/decompressor.c",
98 "c-ext/decompressor.c",
99 "c-ext/decompressoriterator.c",
99 "c-ext/decompressoriterator.c",
100 "c-ext/decompressionreader.c",
100 "c-ext/decompressionreader.c",
101 "c-ext/decompressionwriter.c",
101 "c-ext/decompressionwriter.c",
102 "c-ext/frameparams.c",
102 "c-ext/frameparams.c",
103 ]
103 ]
104
104
105 zstd_depends = [
105 zstd_depends = [
106 "c-ext/python-zstandard.h",
106 "c-ext/python-zstandard.h",
107 ]
107 ]
108
108
109
109
110 def get_c_extension(
110 def get_c_extension(
111 support_legacy=False,
111 support_legacy=False,
112 system_zstd=False,
112 system_zstd=False,
113 name="zstd",
113 name="zstd",
114 warnings_as_errors=False,
114 warnings_as_errors=False,
115 root=None,
115 root=None,
116 ):
116 ):
117 """Obtain a distutils.extension.Extension for the C extension.
117 """Obtain a distutils.extension.Extension for the C extension.
118
118
119 ``support_legacy`` controls whether to compile in legacy zstd format support.
119 ``support_legacy`` controls whether to compile in legacy zstd format support.
120
120
121 ``system_zstd`` controls whether to compile against the system zstd library.
121 ``system_zstd`` controls whether to compile against the system zstd library.
122 For this to work, the system zstd library and headers must match what
122 For this to work, the system zstd library and headers must match what
123 python-zstandard is coded against exactly.
123 python-zstandard is coded against exactly.
124
124
125 ``name`` is the module name of the C extension to produce.
125 ``name`` is the module name of the C extension to produce.
126
126
127 ``warnings_as_errors`` controls whether compiler warnings are turned into
127 ``warnings_as_errors`` controls whether compiler warnings are turned into
128 compiler errors.
128 compiler errors.
129
129
130 ``root`` defines a root path that source should be computed as relative
130 ``root`` defines a root path that source should be computed as relative
131 to. This should be the directory with the main ``setup.py`` that is
131 to. This should be the directory with the main ``setup.py`` that is
132 being invoked. If not defined, paths will be relative to this file.
132 being invoked. If not defined, paths will be relative to this file.
133 """
133 """
134 actual_root = os.path.abspath(os.path.dirname(__file__))
134 actual_root = os.path.abspath(os.path.dirname(__file__))
135 root = root or actual_root
135 root = root or actual_root
136
136
137 sources = set([os.path.join(actual_root, p) for p in ext_sources])
137 sources = set([os.path.join(actual_root, p) for p in ext_sources])
138 if not system_zstd:
138 if not system_zstd:
139 sources.update([os.path.join(actual_root, p) for p in zstd_sources])
139 sources.update([os.path.join(actual_root, p) for p in zstd_sources])
140 if support_legacy:
140 if support_legacy:
141 sources.update([os.path.join(actual_root, p) for p in zstd_sources_legacy])
141 sources.update(
142 [os.path.join(actual_root, p) for p in zstd_sources_legacy]
143 )
142 sources = list(sources)
144 sources = list(sources)
143
145
144 include_dirs = set([os.path.join(actual_root, d) for d in ext_includes])
146 include_dirs = set([os.path.join(actual_root, d) for d in ext_includes])
145 if not system_zstd:
147 if not system_zstd:
146 include_dirs.update([os.path.join(actual_root, d) for d in zstd_includes])
148 include_dirs.update(
149 [os.path.join(actual_root, d) for d in zstd_includes]
150 )
147 if support_legacy:
151 if support_legacy:
148 include_dirs.update(
152 include_dirs.update(
149 [os.path.join(actual_root, d) for d in zstd_includes_legacy]
153 [os.path.join(actual_root, d) for d in zstd_includes_legacy]
150 )
154 )
151 include_dirs = list(include_dirs)
155 include_dirs = list(include_dirs)
152
156
153 depends = [os.path.join(actual_root, p) for p in zstd_depends]
157 depends = [os.path.join(actual_root, p) for p in zstd_depends]
154
158
155 compiler = distutils.ccompiler.new_compiler()
159 compiler = distutils.ccompiler.new_compiler()
156
160
157 # Needed for MSVC.
161 # Needed for MSVC.
158 if hasattr(compiler, "initialize"):
162 if hasattr(compiler, "initialize"):
159 compiler.initialize()
163 compiler.initialize()
160
164
161 if compiler.compiler_type == "unix":
165 if compiler.compiler_type == "unix":
162 compiler_type = "unix"
166 compiler_type = "unix"
163 elif compiler.compiler_type == "msvc":
167 elif compiler.compiler_type == "msvc":
164 compiler_type = "msvc"
168 compiler_type = "msvc"
165 elif compiler.compiler_type == "mingw32":
169 elif compiler.compiler_type == "mingw32":
166 compiler_type = "mingw32"
170 compiler_type = "mingw32"
167 else:
171 else:
168 raise Exception("unhandled compiler type: %s" % compiler.compiler_type)
172 raise Exception("unhandled compiler type: %s" % compiler.compiler_type)
169
173
170 extra_args = ["-DZSTD_MULTITHREAD"]
174 extra_args = ["-DZSTD_MULTITHREAD"]
171
175
172 if not system_zstd:
176 if not system_zstd:
173 extra_args.append("-DZSTDLIB_VISIBILITY=")
177 extra_args.append("-DZSTDLIB_VISIBILITY=")
174 extra_args.append("-DZDICTLIB_VISIBILITY=")
178 extra_args.append("-DZDICTLIB_VISIBILITY=")
175 extra_args.append("-DZSTDERRORLIB_VISIBILITY=")
179 extra_args.append("-DZSTDERRORLIB_VISIBILITY=")
176
180
177 if compiler_type == "unix":
181 if compiler_type == "unix":
178 extra_args.append("-fvisibility=hidden")
182 extra_args.append("-fvisibility=hidden")
179
183
180 if not system_zstd and support_legacy:
184 if not system_zstd and support_legacy:
181 extra_args.append("-DZSTD_LEGACY_SUPPORT=1")
185 extra_args.append("-DZSTD_LEGACY_SUPPORT=1")
182
186
183 if warnings_as_errors:
187 if warnings_as_errors:
184 if compiler_type in ("unix", "mingw32"):
188 if compiler_type in ("unix", "mingw32"):
185 extra_args.append("-Werror")
189 extra_args.append("-Werror")
186 elif compiler_type == "msvc":
190 elif compiler_type == "msvc":
187 extra_args.append("/WX")
191 extra_args.append("/WX")
188 else:
192 else:
189 assert False
193 assert False
190
194
191 libraries = ["zstd"] if system_zstd else []
195 libraries = ["zstd"] if system_zstd else []
192
196
193 # Python 3.7 doesn't like absolute paths. So normalize to relative.
197 # Python 3.7 doesn't like absolute paths. So normalize to relative.
194 sources = [os.path.relpath(p, root) for p in sources]
198 sources = [os.path.relpath(p, root) for p in sources]
195 include_dirs = [os.path.relpath(p, root) for p in include_dirs]
199 include_dirs = [os.path.relpath(p, root) for p in include_dirs]
196 depends = [os.path.relpath(p, root) for p in depends]
200 depends = [os.path.relpath(p, root) for p in depends]
197
201
198 # TODO compile with optimizations.
202 # TODO compile with optimizations.
199 return Extension(
203 return Extension(
200 name,
204 name,
201 sources,
205 sources,
202 include_dirs=include_dirs,
206 include_dirs=include_dirs,
203 depends=depends,
207 depends=depends,
204 extra_compile_args=extra_args,
208 extra_compile_args=extra_args,
205 libraries=libraries,
209 libraries=libraries,
206 )
210 )
@@ -1,197 +1,203 b''
1 import imp
1 import imp
2 import inspect
2 import inspect
3 import io
3 import io
4 import os
4 import os
5 import types
5 import types
6 import unittest
6 import unittest
7
7
8 try:
8 try:
9 import hypothesis
9 import hypothesis
10 except ImportError:
10 except ImportError:
11 hypothesis = None
11 hypothesis = None
12
12
13
13
14 class TestCase(unittest.TestCase):
14 class TestCase(unittest.TestCase):
15 if not getattr(unittest.TestCase, "assertRaisesRegex", False):
15 if not getattr(unittest.TestCase, "assertRaisesRegex", False):
16 assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
16 assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
17
17
18
18
19 def make_cffi(cls):
19 def make_cffi(cls):
20 """Decorator to add CFFI versions of each test method."""
20 """Decorator to add CFFI versions of each test method."""
21
21
22 # The module containing this class definition should
22 # The module containing this class definition should
23 # `import zstandard as zstd`. Otherwise things may blow up.
23 # `import zstandard as zstd`. Otherwise things may blow up.
24 mod = inspect.getmodule(cls)
24 mod = inspect.getmodule(cls)
25 if not hasattr(mod, "zstd"):
25 if not hasattr(mod, "zstd"):
26 raise Exception('test module does not contain "zstd" symbol')
26 raise Exception('test module does not contain "zstd" symbol')
27
27
28 if not hasattr(mod.zstd, "backend"):
28 if not hasattr(mod.zstd, "backend"):
29 raise Exception(
29 raise Exception(
30 'zstd symbol does not have "backend" attribute; did '
30 'zstd symbol does not have "backend" attribute; did '
31 "you `import zstandard as zstd`?"
31 "you `import zstandard as zstd`?"
32 )
32 )
33
33
34 # If `import zstandard` already chose the cffi backend, there is nothing
34 # If `import zstandard` already chose the cffi backend, there is nothing
35 # for us to do: we only add the cffi variation if the default backend
35 # for us to do: we only add the cffi variation if the default backend
36 # is the C extension.
36 # is the C extension.
37 if mod.zstd.backend == "cffi":
37 if mod.zstd.backend == "cffi":
38 return cls
38 return cls
39
39
40 old_env = dict(os.environ)
40 old_env = dict(os.environ)
41 os.environ["PYTHON_ZSTANDARD_IMPORT_POLICY"] = "cffi"
41 os.environ["PYTHON_ZSTANDARD_IMPORT_POLICY"] = "cffi"
42 try:
42 try:
43 try:
43 try:
44 mod_info = imp.find_module("zstandard")
44 mod_info = imp.find_module("zstandard")
45 mod = imp.load_module("zstandard_cffi", *mod_info)
45 mod = imp.load_module("zstandard_cffi", *mod_info)
46 except ImportError:
46 except ImportError:
47 return cls
47 return cls
48 finally:
48 finally:
49 os.environ.clear()
49 os.environ.clear()
50 os.environ.update(old_env)
50 os.environ.update(old_env)
51
51
52 if mod.backend != "cffi":
52 if mod.backend != "cffi":
53 raise Exception("got the zstandard %s backend instead of cffi" % mod.backend)
53 raise Exception(
54 "got the zstandard %s backend instead of cffi" % mod.backend
55 )
54
56
55 # If CFFI version is available, dynamically construct test methods
57 # If CFFI version is available, dynamically construct test methods
56 # that use it.
58 # that use it.
57
59
58 for attr in dir(cls):
60 for attr in dir(cls):
59 fn = getattr(cls, attr)
61 fn = getattr(cls, attr)
60 if not inspect.ismethod(fn) and not inspect.isfunction(fn):
62 if not inspect.ismethod(fn) and not inspect.isfunction(fn):
61 continue
63 continue
62
64
63 if not fn.__name__.startswith("test_"):
65 if not fn.__name__.startswith("test_"):
64 continue
66 continue
65
67
66 name = "%s_cffi" % fn.__name__
68 name = "%s_cffi" % fn.__name__
67
69
68 # Replace the "zstd" symbol with the CFFI module instance. Then copy
70 # Replace the "zstd" symbol with the CFFI module instance. Then copy
69 # the function object and install it in a new attribute.
71 # the function object and install it in a new attribute.
70 if isinstance(fn, types.FunctionType):
72 if isinstance(fn, types.FunctionType):
71 globs = dict(fn.__globals__)
73 globs = dict(fn.__globals__)
72 globs["zstd"] = mod
74 globs["zstd"] = mod
73 new_fn = types.FunctionType(
75 new_fn = types.FunctionType(
74 fn.__code__, globs, name, fn.__defaults__, fn.__closure__
76 fn.__code__, globs, name, fn.__defaults__, fn.__closure__
75 )
77 )
76 new_method = new_fn
78 new_method = new_fn
77 else:
79 else:
78 globs = dict(fn.__func__.func_globals)
80 globs = dict(fn.__func__.func_globals)
79 globs["zstd"] = mod
81 globs["zstd"] = mod
80 new_fn = types.FunctionType(
82 new_fn = types.FunctionType(
81 fn.__func__.func_code,
83 fn.__func__.func_code,
82 globs,
84 globs,
83 name,
85 name,
84 fn.__func__.func_defaults,
86 fn.__func__.func_defaults,
85 fn.__func__.func_closure,
87 fn.__func__.func_closure,
86 )
88 )
87 new_method = types.UnboundMethodType(new_fn, fn.im_self, fn.im_class)
89 new_method = types.UnboundMethodType(
90 new_fn, fn.im_self, fn.im_class
91 )
88
92
89 setattr(cls, name, new_method)
93 setattr(cls, name, new_method)
90
94
91 return cls
95 return cls
92
96
93
97
94 class NonClosingBytesIO(io.BytesIO):
98 class NonClosingBytesIO(io.BytesIO):
95 """BytesIO that saves the underlying buffer on close().
99 """BytesIO that saves the underlying buffer on close().
96
100
97 This allows us to access written data after close().
101 This allows us to access written data after close().
98 """
102 """
99
103
100 def __init__(self, *args, **kwargs):
104 def __init__(self, *args, **kwargs):
101 super(NonClosingBytesIO, self).__init__(*args, **kwargs)
105 super(NonClosingBytesIO, self).__init__(*args, **kwargs)
102 self._saved_buffer = None
106 self._saved_buffer = None
103
107
104 def close(self):
108 def close(self):
105 self._saved_buffer = self.getvalue()
109 self._saved_buffer = self.getvalue()
106 return super(NonClosingBytesIO, self).close()
110 return super(NonClosingBytesIO, self).close()
107
111
108 def getvalue(self):
112 def getvalue(self):
109 if self.closed:
113 if self.closed:
110 return self._saved_buffer
114 return self._saved_buffer
111 else:
115 else:
112 return super(NonClosingBytesIO, self).getvalue()
116 return super(NonClosingBytesIO, self).getvalue()
113
117
114
118
115 class OpCountingBytesIO(NonClosingBytesIO):
119 class OpCountingBytesIO(NonClosingBytesIO):
116 def __init__(self, *args, **kwargs):
120 def __init__(self, *args, **kwargs):
117 self._flush_count = 0
121 self._flush_count = 0
118 self._read_count = 0
122 self._read_count = 0
119 self._write_count = 0
123 self._write_count = 0
120 return super(OpCountingBytesIO, self).__init__(*args, **kwargs)
124 return super(OpCountingBytesIO, self).__init__(*args, **kwargs)
121
125
122 def flush(self):
126 def flush(self):
123 self._flush_count += 1
127 self._flush_count += 1
124 return super(OpCountingBytesIO, self).flush()
128 return super(OpCountingBytesIO, self).flush()
125
129
126 def read(self, *args):
130 def read(self, *args):
127 self._read_count += 1
131 self._read_count += 1
128 return super(OpCountingBytesIO, self).read(*args)
132 return super(OpCountingBytesIO, self).read(*args)
129
133
130 def write(self, data):
134 def write(self, data):
131 self._write_count += 1
135 self._write_count += 1
132 return super(OpCountingBytesIO, self).write(data)
136 return super(OpCountingBytesIO, self).write(data)
133
137
134
138
135 _source_files = []
139 _source_files = []
136
140
137
141
138 def random_input_data():
142 def random_input_data():
139 """Obtain the raw content of source files.
143 """Obtain the raw content of source files.
140
144
141 This is used for generating "random" data to feed into fuzzing, since it is
145 This is used for generating "random" data to feed into fuzzing, since it is
142 faster than random content generation.
146 faster than random content generation.
143 """
147 """
144 if _source_files:
148 if _source_files:
145 return _source_files
149 return _source_files
146
150
147 for root, dirs, files in os.walk(os.path.dirname(__file__)):
151 for root, dirs, files in os.walk(os.path.dirname(__file__)):
148 dirs[:] = list(sorted(dirs))
152 dirs[:] = list(sorted(dirs))
149 for f in sorted(files):
153 for f in sorted(files):
150 try:
154 try:
151 with open(os.path.join(root, f), "rb") as fh:
155 with open(os.path.join(root, f), "rb") as fh:
152 data = fh.read()
156 data = fh.read()
153 if data:
157 if data:
154 _source_files.append(data)
158 _source_files.append(data)
155 except OSError:
159 except OSError:
156 pass
160 pass
157
161
158 # Also add some actual random data.
162 # Also add some actual random data.
159 _source_files.append(os.urandom(100))
163 _source_files.append(os.urandom(100))
160 _source_files.append(os.urandom(1000))
164 _source_files.append(os.urandom(1000))
161 _source_files.append(os.urandom(10000))
165 _source_files.append(os.urandom(10000))
162 _source_files.append(os.urandom(100000))
166 _source_files.append(os.urandom(100000))
163 _source_files.append(os.urandom(1000000))
167 _source_files.append(os.urandom(1000000))
164
168
165 return _source_files
169 return _source_files
166
170
167
171
168 def generate_samples():
172 def generate_samples():
169 inputs = [
173 inputs = [
170 b"foo",
174 b"foo",
171 b"bar",
175 b"bar",
172 b"abcdef",
176 b"abcdef",
173 b"sometext",
177 b"sometext",
174 b"baz",
178 b"baz",
175 ]
179 ]
176
180
177 samples = []
181 samples = []
178
182
179 for i in range(128):
183 for i in range(128):
180 samples.append(inputs[i % 5])
184 samples.append(inputs[i % 5])
181 samples.append(inputs[i % 5] * (i + 3))
185 samples.append(inputs[i % 5] * (i + 3))
182 samples.append(inputs[-(i % 5)] * (i + 2))
186 samples.append(inputs[-(i % 5)] * (i + 2))
183
187
184 return samples
188 return samples
185
189
186
190
187 if hypothesis:
191 if hypothesis:
188 default_settings = hypothesis.settings(deadline=10000)
192 default_settings = hypothesis.settings(deadline=10000)
189 hypothesis.settings.register_profile("default", default_settings)
193 hypothesis.settings.register_profile("default", default_settings)
190
194
191 ci_settings = hypothesis.settings(deadline=20000, max_examples=1000)
195 ci_settings = hypothesis.settings(deadline=20000, max_examples=1000)
192 hypothesis.settings.register_profile("ci", ci_settings)
196 hypothesis.settings.register_profile("ci", ci_settings)
193
197
194 expensive_settings = hypothesis.settings(deadline=None, max_examples=10000)
198 expensive_settings = hypothesis.settings(deadline=None, max_examples=10000)
195 hypothesis.settings.register_profile("expensive", expensive_settings)
199 hypothesis.settings.register_profile("expensive", expensive_settings)
196
200
197 hypothesis.settings.load_profile(os.environ.get("HYPOTHESIS_PROFILE", "default"))
201 hypothesis.settings.load_profile(
202 os.environ.get("HYPOTHESIS_PROFILE", "default")
203 )
@@ -1,146 +1,153 b''
1 import struct
1 import struct
2 import unittest
2 import unittest
3
3
4 import zstandard as zstd
4 import zstandard as zstd
5
5
6 from .common import TestCase
6 from .common import TestCase
7
7
8 ss = struct.Struct("=QQ")
8 ss = struct.Struct("=QQ")
9
9
10
10
11 class TestBufferWithSegments(TestCase):
11 class TestBufferWithSegments(TestCase):
12 def test_arguments(self):
12 def test_arguments(self):
13 if not hasattr(zstd, "BufferWithSegments"):
13 if not hasattr(zstd, "BufferWithSegments"):
14 self.skipTest("BufferWithSegments not available")
14 self.skipTest("BufferWithSegments not available")
15
15
16 with self.assertRaises(TypeError):
16 with self.assertRaises(TypeError):
17 zstd.BufferWithSegments()
17 zstd.BufferWithSegments()
18
18
19 with self.assertRaises(TypeError):
19 with self.assertRaises(TypeError):
20 zstd.BufferWithSegments(b"foo")
20 zstd.BufferWithSegments(b"foo")
21
21
22 # Segments data should be a multiple of 16.
22 # Segments data should be a multiple of 16.
23 with self.assertRaisesRegex(
23 with self.assertRaisesRegex(
24 ValueError, "segments array size is not a multiple of 16"
24 ValueError, "segments array size is not a multiple of 16"
25 ):
25 ):
26 zstd.BufferWithSegments(b"foo", b"\x00\x00")
26 zstd.BufferWithSegments(b"foo", b"\x00\x00")
27
27
28 def test_invalid_offset(self):
28 def test_invalid_offset(self):
29 if not hasattr(zstd, "BufferWithSegments"):
29 if not hasattr(zstd, "BufferWithSegments"):
30 self.skipTest("BufferWithSegments not available")
30 self.skipTest("BufferWithSegments not available")
31
31
32 with self.assertRaisesRegex(
32 with self.assertRaisesRegex(
33 ValueError, "offset within segments array references memory"
33 ValueError, "offset within segments array references memory"
34 ):
34 ):
35 zstd.BufferWithSegments(b"foo", ss.pack(0, 4))
35 zstd.BufferWithSegments(b"foo", ss.pack(0, 4))
36
36
37 def test_invalid_getitem(self):
37 def test_invalid_getitem(self):
38 if not hasattr(zstd, "BufferWithSegments"):
38 if not hasattr(zstd, "BufferWithSegments"):
39 self.skipTest("BufferWithSegments not available")
39 self.skipTest("BufferWithSegments not available")
40
40
41 b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
41 b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
42
42
43 with self.assertRaisesRegex(IndexError, "offset must be non-negative"):
43 with self.assertRaisesRegex(IndexError, "offset must be non-negative"):
44 test = b[-10]
44 test = b[-10]
45
45
46 with self.assertRaisesRegex(IndexError, "offset must be less than 1"):
46 with self.assertRaisesRegex(IndexError, "offset must be less than 1"):
47 test = b[1]
47 test = b[1]
48
48
49 with self.assertRaisesRegex(IndexError, "offset must be less than 1"):
49 with self.assertRaisesRegex(IndexError, "offset must be less than 1"):
50 test = b[2]
50 test = b[2]
51
51
52 def test_single(self):
52 def test_single(self):
53 if not hasattr(zstd, "BufferWithSegments"):
53 if not hasattr(zstd, "BufferWithSegments"):
54 self.skipTest("BufferWithSegments not available")
54 self.skipTest("BufferWithSegments not available")
55
55
56 b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
56 b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
57 self.assertEqual(len(b), 1)
57 self.assertEqual(len(b), 1)
58 self.assertEqual(b.size, 3)
58 self.assertEqual(b.size, 3)
59 self.assertEqual(b.tobytes(), b"foo")
59 self.assertEqual(b.tobytes(), b"foo")
60
60
61 self.assertEqual(len(b[0]), 3)
61 self.assertEqual(len(b[0]), 3)
62 self.assertEqual(b[0].offset, 0)
62 self.assertEqual(b[0].offset, 0)
63 self.assertEqual(b[0].tobytes(), b"foo")
63 self.assertEqual(b[0].tobytes(), b"foo")
64
64
65 def test_multiple(self):
65 def test_multiple(self):
66 if not hasattr(zstd, "BufferWithSegments"):
66 if not hasattr(zstd, "BufferWithSegments"):
67 self.skipTest("BufferWithSegments not available")
67 self.skipTest("BufferWithSegments not available")
68
68
69 b = zstd.BufferWithSegments(
69 b = zstd.BufferWithSegments(
70 b"foofooxfooxy", b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)])
70 b"foofooxfooxy",
71 b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]),
71 )
72 )
72 self.assertEqual(len(b), 3)
73 self.assertEqual(len(b), 3)
73 self.assertEqual(b.size, 12)
74 self.assertEqual(b.size, 12)
74 self.assertEqual(b.tobytes(), b"foofooxfooxy")
75 self.assertEqual(b.tobytes(), b"foofooxfooxy")
75
76
76 self.assertEqual(b[0].tobytes(), b"foo")
77 self.assertEqual(b[0].tobytes(), b"foo")
77 self.assertEqual(b[1].tobytes(), b"foox")
78 self.assertEqual(b[1].tobytes(), b"foox")
78 self.assertEqual(b[2].tobytes(), b"fooxy")
79 self.assertEqual(b[2].tobytes(), b"fooxy")
79
80
80
81
81 class TestBufferWithSegmentsCollection(TestCase):
82 class TestBufferWithSegmentsCollection(TestCase):
82 def test_empty_constructor(self):
83 def test_empty_constructor(self):
83 if not hasattr(zstd, "BufferWithSegmentsCollection"):
84 if not hasattr(zstd, "BufferWithSegmentsCollection"):
84 self.skipTest("BufferWithSegmentsCollection not available")
85 self.skipTest("BufferWithSegmentsCollection not available")
85
86
86 with self.assertRaisesRegex(ValueError, "must pass at least 1 argument"):
87 with self.assertRaisesRegex(
88 ValueError, "must pass at least 1 argument"
89 ):
87 zstd.BufferWithSegmentsCollection()
90 zstd.BufferWithSegmentsCollection()
88
91
89 def test_argument_validation(self):
92 def test_argument_validation(self):
90 if not hasattr(zstd, "BufferWithSegmentsCollection"):
93 if not hasattr(zstd, "BufferWithSegmentsCollection"):
91 self.skipTest("BufferWithSegmentsCollection not available")
94 self.skipTest("BufferWithSegmentsCollection not available")
92
95
93 with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"):
96 with self.assertRaisesRegex(
97 TypeError, "arguments must be BufferWithSegments"
98 ):
94 zstd.BufferWithSegmentsCollection(None)
99 zstd.BufferWithSegmentsCollection(None)
95
100
96 with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"):
101 with self.assertRaisesRegex(
102 TypeError, "arguments must be BufferWithSegments"
103 ):
97 zstd.BufferWithSegmentsCollection(
104 zstd.BufferWithSegmentsCollection(
98 zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None
105 zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None
99 )
106 )
100
107
101 with self.assertRaisesRegex(
108 with self.assertRaisesRegex(
102 ValueError, "ZstdBufferWithSegments cannot be empty"
109 ValueError, "ZstdBufferWithSegments cannot be empty"
103 ):
110 ):
104 zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b""))
111 zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b""))
105
112
106 def test_length(self):
113 def test_length(self):
107 if not hasattr(zstd, "BufferWithSegmentsCollection"):
114 if not hasattr(zstd, "BufferWithSegmentsCollection"):
108 self.skipTest("BufferWithSegmentsCollection not available")
115 self.skipTest("BufferWithSegmentsCollection not available")
109
116
110 b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
117 b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
111 b2 = zstd.BufferWithSegments(
118 b2 = zstd.BufferWithSegments(
112 b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)])
119 b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)])
113 )
120 )
114
121
115 c = zstd.BufferWithSegmentsCollection(b1)
122 c = zstd.BufferWithSegmentsCollection(b1)
116 self.assertEqual(len(c), 1)
123 self.assertEqual(len(c), 1)
117 self.assertEqual(c.size(), 3)
124 self.assertEqual(c.size(), 3)
118
125
119 c = zstd.BufferWithSegmentsCollection(b2)
126 c = zstd.BufferWithSegmentsCollection(b2)
120 self.assertEqual(len(c), 2)
127 self.assertEqual(len(c), 2)
121 self.assertEqual(c.size(), 6)
128 self.assertEqual(c.size(), 6)
122
129
123 c = zstd.BufferWithSegmentsCollection(b1, b2)
130 c = zstd.BufferWithSegmentsCollection(b1, b2)
124 self.assertEqual(len(c), 3)
131 self.assertEqual(len(c), 3)
125 self.assertEqual(c.size(), 9)
132 self.assertEqual(c.size(), 9)
126
133
127 def test_getitem(self):
134 def test_getitem(self):
128 if not hasattr(zstd, "BufferWithSegmentsCollection"):
135 if not hasattr(zstd, "BufferWithSegmentsCollection"):
129 self.skipTest("BufferWithSegmentsCollection not available")
136 self.skipTest("BufferWithSegmentsCollection not available")
130
137
131 b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
138 b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
132 b2 = zstd.BufferWithSegments(
139 b2 = zstd.BufferWithSegments(
133 b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)])
140 b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)])
134 )
141 )
135
142
136 c = zstd.BufferWithSegmentsCollection(b1, b2)
143 c = zstd.BufferWithSegmentsCollection(b1, b2)
137
144
138 with self.assertRaisesRegex(IndexError, "offset must be less than 3"):
145 with self.assertRaisesRegex(IndexError, "offset must be less than 3"):
139 c[3]
146 c[3]
140
147
141 with self.assertRaisesRegex(IndexError, "offset must be less than 3"):
148 with self.assertRaisesRegex(IndexError, "offset must be less than 3"):
142 c[4]
149 c[4]
143
150
144 self.assertEqual(c[0].tobytes(), b"foo")
151 self.assertEqual(c[0].tobytes(), b"foo")
145 self.assertEqual(c[1].tobytes(), b"bar")
152 self.assertEqual(c[1].tobytes(), b"bar")
146 self.assertEqual(c[2].tobytes(), b"baz")
153 self.assertEqual(c[2].tobytes(), b"baz")
@@ -1,1770 +1,1803 b''
1 import hashlib
1 import hashlib
2 import io
2 import io
3 import os
3 import os
4 import struct
4 import struct
5 import sys
5 import sys
6 import tarfile
6 import tarfile
7 import tempfile
7 import tempfile
8 import unittest
8 import unittest
9
9
10 import zstandard as zstd
10 import zstandard as zstd
11
11
12 from .common import (
12 from .common import (
13 make_cffi,
13 make_cffi,
14 NonClosingBytesIO,
14 NonClosingBytesIO,
15 OpCountingBytesIO,
15 OpCountingBytesIO,
16 TestCase,
16 TestCase,
17 )
17 )
18
18
19
19
20 if sys.version_info[0] >= 3:
20 if sys.version_info[0] >= 3:
21 next = lambda it: it.__next__()
21 next = lambda it: it.__next__()
22 else:
22 else:
23 next = lambda it: it.next()
23 next = lambda it: it.next()
24
24
25
25
26 def multithreaded_chunk_size(level, source_size=0):
26 def multithreaded_chunk_size(level, source_size=0):
27 params = zstd.ZstdCompressionParameters.from_level(level, source_size=source_size)
27 params = zstd.ZstdCompressionParameters.from_level(
28 level, source_size=source_size
29 )
28
30
29 return 1 << (params.window_log + 2)
31 return 1 << (params.window_log + 2)
30
32
31
33
32 @make_cffi
34 @make_cffi
33 class TestCompressor(TestCase):
35 class TestCompressor(TestCase):
34 def test_level_bounds(self):
36 def test_level_bounds(self):
35 with self.assertRaises(ValueError):
37 with self.assertRaises(ValueError):
36 zstd.ZstdCompressor(level=23)
38 zstd.ZstdCompressor(level=23)
37
39
38 def test_memory_size(self):
40 def test_memory_size(self):
39 cctx = zstd.ZstdCompressor(level=1)
41 cctx = zstd.ZstdCompressor(level=1)
40 self.assertGreater(cctx.memory_size(), 100)
42 self.assertGreater(cctx.memory_size(), 100)
41
43
42
44
43 @make_cffi
45 @make_cffi
44 class TestCompressor_compress(TestCase):
46 class TestCompressor_compress(TestCase):
45 def test_compress_empty(self):
47 def test_compress_empty(self):
46 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
48 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
47 result = cctx.compress(b"")
49 result = cctx.compress(b"")
48 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
50 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
49 params = zstd.get_frame_parameters(result)
51 params = zstd.get_frame_parameters(result)
50 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
52 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
51 self.assertEqual(params.window_size, 524288)
53 self.assertEqual(params.window_size, 524288)
52 self.assertEqual(params.dict_id, 0)
54 self.assertEqual(params.dict_id, 0)
53 self.assertFalse(params.has_checksum, 0)
55 self.assertFalse(params.has_checksum, 0)
54
56
55 cctx = zstd.ZstdCompressor()
57 cctx = zstd.ZstdCompressor()
56 result = cctx.compress(b"")
58 result = cctx.compress(b"")
57 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00")
59 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00")
58 params = zstd.get_frame_parameters(result)
60 params = zstd.get_frame_parameters(result)
59 self.assertEqual(params.content_size, 0)
61 self.assertEqual(params.content_size, 0)
60
62
61 def test_input_types(self):
63 def test_input_types(self):
62 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
64 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
63 expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f"
65 expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f"
64
66
65 mutable_array = bytearray(3)
67 mutable_array = bytearray(3)
66 mutable_array[:] = b"foo"
68 mutable_array[:] = b"foo"
67
69
68 sources = [
70 sources = [
69 memoryview(b"foo"),
71 memoryview(b"foo"),
70 bytearray(b"foo"),
72 bytearray(b"foo"),
71 mutable_array,
73 mutable_array,
72 ]
74 ]
73
75
74 for source in sources:
76 for source in sources:
75 self.assertEqual(cctx.compress(source), expected)
77 self.assertEqual(cctx.compress(source), expected)
76
78
77 def test_compress_large(self):
79 def test_compress_large(self):
78 chunks = []
80 chunks = []
79 for i in range(255):
81 for i in range(255):
80 chunks.append(struct.Struct(">B").pack(i) * 16384)
82 chunks.append(struct.Struct(">B").pack(i) * 16384)
81
83
82 cctx = zstd.ZstdCompressor(level=3, write_content_size=False)
84 cctx = zstd.ZstdCompressor(level=3, write_content_size=False)
83 result = cctx.compress(b"".join(chunks))
85 result = cctx.compress(b"".join(chunks))
84 self.assertEqual(len(result), 999)
86 self.assertEqual(len(result), 999)
85 self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
87 self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
86
88
87 # This matches the test for read_to_iter() below.
89 # This matches the test for read_to_iter() below.
88 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
90 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
89 result = cctx.compress(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o")
91 result = cctx.compress(
92 b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o"
93 )
90 self.assertEqual(
94 self.assertEqual(
91 result,
95 result,
92 b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00"
96 b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00"
93 b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0"
97 b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0"
94 b"\x02\x09\x00\x00\x6f",
98 b"\x02\x09\x00\x00\x6f",
95 )
99 )
96
100
97 def test_negative_level(self):
101 def test_negative_level(self):
98 cctx = zstd.ZstdCompressor(level=-4)
102 cctx = zstd.ZstdCompressor(level=-4)
99 result = cctx.compress(b"foo" * 256)
103 result = cctx.compress(b"foo" * 256)
100
104
101 def test_no_magic(self):
105 def test_no_magic(self):
102 params = zstd.ZstdCompressionParameters.from_level(1, format=zstd.FORMAT_ZSTD1)
106 params = zstd.ZstdCompressionParameters.from_level(
107 1, format=zstd.FORMAT_ZSTD1
108 )
103 cctx = zstd.ZstdCompressor(compression_params=params)
109 cctx = zstd.ZstdCompressor(compression_params=params)
104 magic = cctx.compress(b"foobar")
110 magic = cctx.compress(b"foobar")
105
111
106 params = zstd.ZstdCompressionParameters.from_level(
112 params = zstd.ZstdCompressionParameters.from_level(
107 1, format=zstd.FORMAT_ZSTD1_MAGICLESS
113 1, format=zstd.FORMAT_ZSTD1_MAGICLESS
108 )
114 )
109 cctx = zstd.ZstdCompressor(compression_params=params)
115 cctx = zstd.ZstdCompressor(compression_params=params)
110 no_magic = cctx.compress(b"foobar")
116 no_magic = cctx.compress(b"foobar")
111
117
112 self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd")
118 self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd")
113 self.assertEqual(magic[4:], no_magic)
119 self.assertEqual(magic[4:], no_magic)
114
120
115 def test_write_checksum(self):
121 def test_write_checksum(self):
116 cctx = zstd.ZstdCompressor(level=1)
122 cctx = zstd.ZstdCompressor(level=1)
117 no_checksum = cctx.compress(b"foobar")
123 no_checksum = cctx.compress(b"foobar")
118 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
124 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
119 with_checksum = cctx.compress(b"foobar")
125 with_checksum = cctx.compress(b"foobar")
120
126
121 self.assertEqual(len(with_checksum), len(no_checksum) + 4)
127 self.assertEqual(len(with_checksum), len(no_checksum) + 4)
122
128
123 no_params = zstd.get_frame_parameters(no_checksum)
129 no_params = zstd.get_frame_parameters(no_checksum)
124 with_params = zstd.get_frame_parameters(with_checksum)
130 with_params = zstd.get_frame_parameters(with_checksum)
125
131
126 self.assertFalse(no_params.has_checksum)
132 self.assertFalse(no_params.has_checksum)
127 self.assertTrue(with_params.has_checksum)
133 self.assertTrue(with_params.has_checksum)
128
134
129 def test_write_content_size(self):
135 def test_write_content_size(self):
130 cctx = zstd.ZstdCompressor(level=1)
136 cctx = zstd.ZstdCompressor(level=1)
131 with_size = cctx.compress(b"foobar" * 256)
137 with_size = cctx.compress(b"foobar" * 256)
132 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
138 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
133 no_size = cctx.compress(b"foobar" * 256)
139 no_size = cctx.compress(b"foobar" * 256)
134
140
135 self.assertEqual(len(with_size), len(no_size) + 1)
141 self.assertEqual(len(with_size), len(no_size) + 1)
136
142
137 no_params = zstd.get_frame_parameters(no_size)
143 no_params = zstd.get_frame_parameters(no_size)
138 with_params = zstd.get_frame_parameters(with_size)
144 with_params = zstd.get_frame_parameters(with_size)
139 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
145 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
140 self.assertEqual(with_params.content_size, 1536)
146 self.assertEqual(with_params.content_size, 1536)
141
147
142 def test_no_dict_id(self):
148 def test_no_dict_id(self):
143 samples = []
149 samples = []
144 for i in range(128):
150 for i in range(128):
145 samples.append(b"foo" * 64)
151 samples.append(b"foo" * 64)
146 samples.append(b"bar" * 64)
152 samples.append(b"bar" * 64)
147 samples.append(b"foobar" * 64)
153 samples.append(b"foobar" * 64)
148
154
149 d = zstd.train_dictionary(1024, samples)
155 d = zstd.train_dictionary(1024, samples)
150
156
151 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
157 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
152 with_dict_id = cctx.compress(b"foobarfoobar")
158 with_dict_id = cctx.compress(b"foobarfoobar")
153
159
154 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
160 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
155 no_dict_id = cctx.compress(b"foobarfoobar")
161 no_dict_id = cctx.compress(b"foobarfoobar")
156
162
157 self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
163 self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
158
164
159 no_params = zstd.get_frame_parameters(no_dict_id)
165 no_params = zstd.get_frame_parameters(no_dict_id)
160 with_params = zstd.get_frame_parameters(with_dict_id)
166 with_params = zstd.get_frame_parameters(with_dict_id)
161 self.assertEqual(no_params.dict_id, 0)
167 self.assertEqual(no_params.dict_id, 0)
162 self.assertEqual(with_params.dict_id, 1880053135)
168 self.assertEqual(with_params.dict_id, 1880053135)
163
169
164 def test_compress_dict_multiple(self):
170 def test_compress_dict_multiple(self):
165 samples = []
171 samples = []
166 for i in range(128):
172 for i in range(128):
167 samples.append(b"foo" * 64)
173 samples.append(b"foo" * 64)
168 samples.append(b"bar" * 64)
174 samples.append(b"bar" * 64)
169 samples.append(b"foobar" * 64)
175 samples.append(b"foobar" * 64)
170
176
171 d = zstd.train_dictionary(8192, samples)
177 d = zstd.train_dictionary(8192, samples)
172
178
173 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
179 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
174
180
175 for i in range(32):
181 for i in range(32):
176 cctx.compress(b"foo bar foobar foo bar foobar")
182 cctx.compress(b"foo bar foobar foo bar foobar")
177
183
178 def test_dict_precompute(self):
184 def test_dict_precompute(self):
179 samples = []
185 samples = []
180 for i in range(128):
186 for i in range(128):
181 samples.append(b"foo" * 64)
187 samples.append(b"foo" * 64)
182 samples.append(b"bar" * 64)
188 samples.append(b"bar" * 64)
183 samples.append(b"foobar" * 64)
189 samples.append(b"foobar" * 64)
184
190
185 d = zstd.train_dictionary(8192, samples)
191 d = zstd.train_dictionary(8192, samples)
186 d.precompute_compress(level=1)
192 d.precompute_compress(level=1)
187
193
188 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
194 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
189
195
190 for i in range(32):
196 for i in range(32):
191 cctx.compress(b"foo bar foobar foo bar foobar")
197 cctx.compress(b"foo bar foobar foo bar foobar")
192
198
193 def test_multithreaded(self):
199 def test_multithreaded(self):
194 chunk_size = multithreaded_chunk_size(1)
200 chunk_size = multithreaded_chunk_size(1)
195 source = b"".join([b"x" * chunk_size, b"y" * chunk_size])
201 source = b"".join([b"x" * chunk_size, b"y" * chunk_size])
196
202
197 cctx = zstd.ZstdCompressor(level=1, threads=2)
203 cctx = zstd.ZstdCompressor(level=1, threads=2)
198 compressed = cctx.compress(source)
204 compressed = cctx.compress(source)
199
205
200 params = zstd.get_frame_parameters(compressed)
206 params = zstd.get_frame_parameters(compressed)
201 self.assertEqual(params.content_size, chunk_size * 2)
207 self.assertEqual(params.content_size, chunk_size * 2)
202 self.assertEqual(params.dict_id, 0)
208 self.assertEqual(params.dict_id, 0)
203 self.assertFalse(params.has_checksum)
209 self.assertFalse(params.has_checksum)
204
210
205 dctx = zstd.ZstdDecompressor()
211 dctx = zstd.ZstdDecompressor()
206 self.assertEqual(dctx.decompress(compressed), source)
212 self.assertEqual(dctx.decompress(compressed), source)
207
213
208 def test_multithreaded_dict(self):
214 def test_multithreaded_dict(self):
209 samples = []
215 samples = []
210 for i in range(128):
216 for i in range(128):
211 samples.append(b"foo" * 64)
217 samples.append(b"foo" * 64)
212 samples.append(b"bar" * 64)
218 samples.append(b"bar" * 64)
213 samples.append(b"foobar" * 64)
219 samples.append(b"foobar" * 64)
214
220
215 d = zstd.train_dictionary(1024, samples)
221 d = zstd.train_dictionary(1024, samples)
216
222
217 cctx = zstd.ZstdCompressor(dict_data=d, threads=2)
223 cctx = zstd.ZstdCompressor(dict_data=d, threads=2)
218
224
219 result = cctx.compress(b"foo")
225 result = cctx.compress(b"foo")
220 params = zstd.get_frame_parameters(result)
226 params = zstd.get_frame_parameters(result)
221 self.assertEqual(params.content_size, 3)
227 self.assertEqual(params.content_size, 3)
222 self.assertEqual(params.dict_id, d.dict_id())
228 self.assertEqual(params.dict_id, d.dict_id())
223
229
224 self.assertEqual(
230 self.assertEqual(
225 result,
231 result,
226 b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" b"\x66\x6f\x6f",
232 b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00"
233 b"\x66\x6f\x6f",
227 )
234 )
228
235
229 def test_multithreaded_compression_params(self):
236 def test_multithreaded_compression_params(self):
230 params = zstd.ZstdCompressionParameters.from_level(0, threads=2)
237 params = zstd.ZstdCompressionParameters.from_level(0, threads=2)
231 cctx = zstd.ZstdCompressor(compression_params=params)
238 cctx = zstd.ZstdCompressor(compression_params=params)
232
239
233 result = cctx.compress(b"foo")
240 result = cctx.compress(b"foo")
234 params = zstd.get_frame_parameters(result)
241 params = zstd.get_frame_parameters(result)
235 self.assertEqual(params.content_size, 3)
242 self.assertEqual(params.content_size, 3)
236
243
237 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f")
244 self.assertEqual(
245 result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f"
246 )
238
247
239
248
240 @make_cffi
249 @make_cffi
241 class TestCompressor_compressobj(TestCase):
250 class TestCompressor_compressobj(TestCase):
242 def test_compressobj_empty(self):
251 def test_compressobj_empty(self):
243 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
252 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
244 cobj = cctx.compressobj()
253 cobj = cctx.compressobj()
245 self.assertEqual(cobj.compress(b""), b"")
254 self.assertEqual(cobj.compress(b""), b"")
246 self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
255 self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
247
256
248 def test_input_types(self):
257 def test_input_types(self):
249 expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
258 expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
250 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
259 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
251
260
252 mutable_array = bytearray(3)
261 mutable_array = bytearray(3)
253 mutable_array[:] = b"foo"
262 mutable_array[:] = b"foo"
254
263
255 sources = [
264 sources = [
256 memoryview(b"foo"),
265 memoryview(b"foo"),
257 bytearray(b"foo"),
266 bytearray(b"foo"),
258 mutable_array,
267 mutable_array,
259 ]
268 ]
260
269
261 for source in sources:
270 for source in sources:
262 cobj = cctx.compressobj()
271 cobj = cctx.compressobj()
263 self.assertEqual(cobj.compress(source), b"")
272 self.assertEqual(cobj.compress(source), b"")
264 self.assertEqual(cobj.flush(), expected)
273 self.assertEqual(cobj.flush(), expected)
265
274
266 def test_compressobj_large(self):
275 def test_compressobj_large(self):
267 chunks = []
276 chunks = []
268 for i in range(255):
277 for i in range(255):
269 chunks.append(struct.Struct(">B").pack(i) * 16384)
278 chunks.append(struct.Struct(">B").pack(i) * 16384)
270
279
271 cctx = zstd.ZstdCompressor(level=3)
280 cctx = zstd.ZstdCompressor(level=3)
272 cobj = cctx.compressobj()
281 cobj = cctx.compressobj()
273
282
274 result = cobj.compress(b"".join(chunks)) + cobj.flush()
283 result = cobj.compress(b"".join(chunks)) + cobj.flush()
275 self.assertEqual(len(result), 999)
284 self.assertEqual(len(result), 999)
276 self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
285 self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
277
286
278 params = zstd.get_frame_parameters(result)
287 params = zstd.get_frame_parameters(result)
279 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
288 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
280 self.assertEqual(params.window_size, 2097152)
289 self.assertEqual(params.window_size, 2097152)
281 self.assertEqual(params.dict_id, 0)
290 self.assertEqual(params.dict_id, 0)
282 self.assertFalse(params.has_checksum)
291 self.assertFalse(params.has_checksum)
283
292
284 def test_write_checksum(self):
293 def test_write_checksum(self):
285 cctx = zstd.ZstdCompressor(level=1)
294 cctx = zstd.ZstdCompressor(level=1)
286 cobj = cctx.compressobj()
295 cobj = cctx.compressobj()
287 no_checksum = cobj.compress(b"foobar") + cobj.flush()
296 no_checksum = cobj.compress(b"foobar") + cobj.flush()
288 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
297 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
289 cobj = cctx.compressobj()
298 cobj = cctx.compressobj()
290 with_checksum = cobj.compress(b"foobar") + cobj.flush()
299 with_checksum = cobj.compress(b"foobar") + cobj.flush()
291
300
292 no_params = zstd.get_frame_parameters(no_checksum)
301 no_params = zstd.get_frame_parameters(no_checksum)
293 with_params = zstd.get_frame_parameters(with_checksum)
302 with_params = zstd.get_frame_parameters(with_checksum)
294 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
303 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
295 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
304 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
296 self.assertEqual(no_params.dict_id, 0)
305 self.assertEqual(no_params.dict_id, 0)
297 self.assertEqual(with_params.dict_id, 0)
306 self.assertEqual(with_params.dict_id, 0)
298 self.assertFalse(no_params.has_checksum)
307 self.assertFalse(no_params.has_checksum)
299 self.assertTrue(with_params.has_checksum)
308 self.assertTrue(with_params.has_checksum)
300
309
301 self.assertEqual(len(with_checksum), len(no_checksum) + 4)
310 self.assertEqual(len(with_checksum), len(no_checksum) + 4)
302
311
303 def test_write_content_size(self):
312 def test_write_content_size(self):
304 cctx = zstd.ZstdCompressor(level=1)
313 cctx = zstd.ZstdCompressor(level=1)
305 cobj = cctx.compressobj(size=len(b"foobar" * 256))
314 cobj = cctx.compressobj(size=len(b"foobar" * 256))
306 with_size = cobj.compress(b"foobar" * 256) + cobj.flush()
315 with_size = cobj.compress(b"foobar" * 256) + cobj.flush()
307 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
316 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
308 cobj = cctx.compressobj(size=len(b"foobar" * 256))
317 cobj = cctx.compressobj(size=len(b"foobar" * 256))
309 no_size = cobj.compress(b"foobar" * 256) + cobj.flush()
318 no_size = cobj.compress(b"foobar" * 256) + cobj.flush()
310
319
311 no_params = zstd.get_frame_parameters(no_size)
320 no_params = zstd.get_frame_parameters(no_size)
312 with_params = zstd.get_frame_parameters(with_size)
321 with_params = zstd.get_frame_parameters(with_size)
313 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
322 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
314 self.assertEqual(with_params.content_size, 1536)
323 self.assertEqual(with_params.content_size, 1536)
315 self.assertEqual(no_params.dict_id, 0)
324 self.assertEqual(no_params.dict_id, 0)
316 self.assertEqual(with_params.dict_id, 0)
325 self.assertEqual(with_params.dict_id, 0)
317 self.assertFalse(no_params.has_checksum)
326 self.assertFalse(no_params.has_checksum)
318 self.assertFalse(with_params.has_checksum)
327 self.assertFalse(with_params.has_checksum)
319
328
320 self.assertEqual(len(with_size), len(no_size) + 1)
329 self.assertEqual(len(with_size), len(no_size) + 1)
321
330
322 def test_compress_after_finished(self):
331 def test_compress_after_finished(self):
323 cctx = zstd.ZstdCompressor()
332 cctx = zstd.ZstdCompressor()
324 cobj = cctx.compressobj()
333 cobj = cctx.compressobj()
325
334
326 cobj.compress(b"foo")
335 cobj.compress(b"foo")
327 cobj.flush()
336 cobj.flush()
328
337
329 with self.assertRaisesRegex(
338 with self.assertRaisesRegex(
330 zstd.ZstdError, r"cannot call compress\(\) after compressor"
339 zstd.ZstdError, r"cannot call compress\(\) after compressor"
331 ):
340 ):
332 cobj.compress(b"foo")
341 cobj.compress(b"foo")
333
342
334 with self.assertRaisesRegex(
343 with self.assertRaisesRegex(
335 zstd.ZstdError, "compressor object already finished"
344 zstd.ZstdError, "compressor object already finished"
336 ):
345 ):
337 cobj.flush()
346 cobj.flush()
338
347
339 def test_flush_block_repeated(self):
348 def test_flush_block_repeated(self):
340 cctx = zstd.ZstdCompressor(level=1)
349 cctx = zstd.ZstdCompressor(level=1)
341 cobj = cctx.compressobj()
350 cobj = cctx.compressobj()
342
351
343 self.assertEqual(cobj.compress(b"foo"), b"")
352 self.assertEqual(cobj.compress(b"foo"), b"")
344 self.assertEqual(
353 self.assertEqual(
345 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
354 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
346 b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo",
355 b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo",
347 )
356 )
348 self.assertEqual(cobj.compress(b"bar"), b"")
357 self.assertEqual(cobj.compress(b"bar"), b"")
349 # 3 byte header plus content.
358 # 3 byte header plus content.
350 self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar")
359 self.assertEqual(
360 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar"
361 )
351 self.assertEqual(cobj.flush(), b"\x01\x00\x00")
362 self.assertEqual(cobj.flush(), b"\x01\x00\x00")
352
363
353 def test_flush_empty_block(self):
364 def test_flush_empty_block(self):
354 cctx = zstd.ZstdCompressor(write_checksum=True)
365 cctx = zstd.ZstdCompressor(write_checksum=True)
355 cobj = cctx.compressobj()
366 cobj = cctx.compressobj()
356
367
357 cobj.compress(b"foobar")
368 cobj.compress(b"foobar")
358 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
369 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
359 # No-op if no block is active (this is internal to zstd).
370 # No-op if no block is active (this is internal to zstd).
360 self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"")
371 self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"")
361
372
362 trailing = cobj.flush()
373 trailing = cobj.flush()
363 # 3 bytes block header + 4 bytes frame checksum
374 # 3 bytes block header + 4 bytes frame checksum
364 self.assertEqual(len(trailing), 7)
375 self.assertEqual(len(trailing), 7)
365 header = trailing[0:3]
376 header = trailing[0:3]
366 self.assertEqual(header, b"\x01\x00\x00")
377 self.assertEqual(header, b"\x01\x00\x00")
367
378
368 def test_multithreaded(self):
379 def test_multithreaded(self):
369 source = io.BytesIO()
380 source = io.BytesIO()
370 source.write(b"a" * 1048576)
381 source.write(b"a" * 1048576)
371 source.write(b"b" * 1048576)
382 source.write(b"b" * 1048576)
372 source.write(b"c" * 1048576)
383 source.write(b"c" * 1048576)
373 source.seek(0)
384 source.seek(0)
374
385
375 cctx = zstd.ZstdCompressor(level=1, threads=2)
386 cctx = zstd.ZstdCompressor(level=1, threads=2)
376 cobj = cctx.compressobj()
387 cobj = cctx.compressobj()
377
388
378 chunks = []
389 chunks = []
379 while True:
390 while True:
380 d = source.read(8192)
391 d = source.read(8192)
381 if not d:
392 if not d:
382 break
393 break
383
394
384 chunks.append(cobj.compress(d))
395 chunks.append(cobj.compress(d))
385
396
386 chunks.append(cobj.flush())
397 chunks.append(cobj.flush())
387
398
388 compressed = b"".join(chunks)
399 compressed = b"".join(chunks)
389
400
390 self.assertEqual(len(compressed), 119)
401 self.assertEqual(len(compressed), 119)
391
402
392 def test_frame_progression(self):
403 def test_frame_progression(self):
393 cctx = zstd.ZstdCompressor()
404 cctx = zstd.ZstdCompressor()
394
405
395 self.assertEqual(cctx.frame_progression(), (0, 0, 0))
406 self.assertEqual(cctx.frame_progression(), (0, 0, 0))
396
407
397 cobj = cctx.compressobj()
408 cobj = cctx.compressobj()
398
409
399 cobj.compress(b"foobar")
410 cobj.compress(b"foobar")
400 self.assertEqual(cctx.frame_progression(), (6, 0, 0))
411 self.assertEqual(cctx.frame_progression(), (6, 0, 0))
401
412
402 cobj.flush()
413 cobj.flush()
403 self.assertEqual(cctx.frame_progression(), (6, 6, 15))
414 self.assertEqual(cctx.frame_progression(), (6, 6, 15))
404
415
405 def test_bad_size(self):
416 def test_bad_size(self):
406 cctx = zstd.ZstdCompressor()
417 cctx = zstd.ZstdCompressor()
407
418
408 cobj = cctx.compressobj(size=2)
419 cobj = cctx.compressobj(size=2)
409 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
420 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
410 cobj.compress(b"foo")
421 cobj.compress(b"foo")
411
422
412 # Try another operation on this instance.
423 # Try another operation on this instance.
413 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
424 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
414 cobj.compress(b"aa")
425 cobj.compress(b"aa")
415
426
416 # Try another operation on the compressor.
427 # Try another operation on the compressor.
417 cctx.compressobj(size=4)
428 cctx.compressobj(size=4)
418 cctx.compress(b"foobar")
429 cctx.compress(b"foobar")
419
430
420
431
421 @make_cffi
432 @make_cffi
422 class TestCompressor_copy_stream(TestCase):
433 class TestCompressor_copy_stream(TestCase):
423 def test_no_read(self):
434 def test_no_read(self):
424 source = object()
435 source = object()
425 dest = io.BytesIO()
436 dest = io.BytesIO()
426
437
427 cctx = zstd.ZstdCompressor()
438 cctx = zstd.ZstdCompressor()
428 with self.assertRaises(ValueError):
439 with self.assertRaises(ValueError):
429 cctx.copy_stream(source, dest)
440 cctx.copy_stream(source, dest)
430
441
431 def test_no_write(self):
442 def test_no_write(self):
432 source = io.BytesIO()
443 source = io.BytesIO()
433 dest = object()
444 dest = object()
434
445
435 cctx = zstd.ZstdCompressor()
446 cctx = zstd.ZstdCompressor()
436 with self.assertRaises(ValueError):
447 with self.assertRaises(ValueError):
437 cctx.copy_stream(source, dest)
448 cctx.copy_stream(source, dest)
438
449
439 def test_empty(self):
450 def test_empty(self):
440 source = io.BytesIO()
451 source = io.BytesIO()
441 dest = io.BytesIO()
452 dest = io.BytesIO()
442
453
443 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
454 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
444 r, w = cctx.copy_stream(source, dest)
455 r, w = cctx.copy_stream(source, dest)
445 self.assertEqual(int(r), 0)
456 self.assertEqual(int(r), 0)
446 self.assertEqual(w, 9)
457 self.assertEqual(w, 9)
447
458
448 self.assertEqual(dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
459 self.assertEqual(
460 dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00"
461 )
449
462
450 def test_large_data(self):
463 def test_large_data(self):
451 source = io.BytesIO()
464 source = io.BytesIO()
452 for i in range(255):
465 for i in range(255):
453 source.write(struct.Struct(">B").pack(i) * 16384)
466 source.write(struct.Struct(">B").pack(i) * 16384)
454 source.seek(0)
467 source.seek(0)
455
468
456 dest = io.BytesIO()
469 dest = io.BytesIO()
457 cctx = zstd.ZstdCompressor()
470 cctx = zstd.ZstdCompressor()
458 r, w = cctx.copy_stream(source, dest)
471 r, w = cctx.copy_stream(source, dest)
459
472
460 self.assertEqual(r, 255 * 16384)
473 self.assertEqual(r, 255 * 16384)
461 self.assertEqual(w, 999)
474 self.assertEqual(w, 999)
462
475
463 params = zstd.get_frame_parameters(dest.getvalue())
476 params = zstd.get_frame_parameters(dest.getvalue())
464 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
477 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
465 self.assertEqual(params.window_size, 2097152)
478 self.assertEqual(params.window_size, 2097152)
466 self.assertEqual(params.dict_id, 0)
479 self.assertEqual(params.dict_id, 0)
467 self.assertFalse(params.has_checksum)
480 self.assertFalse(params.has_checksum)
468
481
469 def test_write_checksum(self):
482 def test_write_checksum(self):
470 source = io.BytesIO(b"foobar")
483 source = io.BytesIO(b"foobar")
471 no_checksum = io.BytesIO()
484 no_checksum = io.BytesIO()
472
485
473 cctx = zstd.ZstdCompressor(level=1)
486 cctx = zstd.ZstdCompressor(level=1)
474 cctx.copy_stream(source, no_checksum)
487 cctx.copy_stream(source, no_checksum)
475
488
476 source.seek(0)
489 source.seek(0)
477 with_checksum = io.BytesIO()
490 with_checksum = io.BytesIO()
478 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
491 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
479 cctx.copy_stream(source, with_checksum)
492 cctx.copy_stream(source, with_checksum)
480
493
481 self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4)
494 self.assertEqual(
495 len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4
496 )
482
497
483 no_params = zstd.get_frame_parameters(no_checksum.getvalue())
498 no_params = zstd.get_frame_parameters(no_checksum.getvalue())
484 with_params = zstd.get_frame_parameters(with_checksum.getvalue())
499 with_params = zstd.get_frame_parameters(with_checksum.getvalue())
485 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
500 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
486 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
501 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
487 self.assertEqual(no_params.dict_id, 0)
502 self.assertEqual(no_params.dict_id, 0)
488 self.assertEqual(with_params.dict_id, 0)
503 self.assertEqual(with_params.dict_id, 0)
489 self.assertFalse(no_params.has_checksum)
504 self.assertFalse(no_params.has_checksum)
490 self.assertTrue(with_params.has_checksum)
505 self.assertTrue(with_params.has_checksum)
491
506
492 def test_write_content_size(self):
507 def test_write_content_size(self):
493 source = io.BytesIO(b"foobar" * 256)
508 source = io.BytesIO(b"foobar" * 256)
494 no_size = io.BytesIO()
509 no_size = io.BytesIO()
495
510
496 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
511 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
497 cctx.copy_stream(source, no_size)
512 cctx.copy_stream(source, no_size)
498
513
499 source.seek(0)
514 source.seek(0)
500 with_size = io.BytesIO()
515 with_size = io.BytesIO()
501 cctx = zstd.ZstdCompressor(level=1)
516 cctx = zstd.ZstdCompressor(level=1)
502 cctx.copy_stream(source, with_size)
517 cctx.copy_stream(source, with_size)
503
518
504 # Source content size is unknown, so no content size written.
519 # Source content size is unknown, so no content size written.
505 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
520 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
506
521
507 source.seek(0)
522 source.seek(0)
508 with_size = io.BytesIO()
523 with_size = io.BytesIO()
509 cctx.copy_stream(source, with_size, size=len(source.getvalue()))
524 cctx.copy_stream(source, with_size, size=len(source.getvalue()))
510
525
511 # We specified source size, so content size header is present.
526 # We specified source size, so content size header is present.
512 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
527 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
513
528
514 no_params = zstd.get_frame_parameters(no_size.getvalue())
529 no_params = zstd.get_frame_parameters(no_size.getvalue())
515 with_params = zstd.get_frame_parameters(with_size.getvalue())
530 with_params = zstd.get_frame_parameters(with_size.getvalue())
516 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
531 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
517 self.assertEqual(with_params.content_size, 1536)
532 self.assertEqual(with_params.content_size, 1536)
518 self.assertEqual(no_params.dict_id, 0)
533 self.assertEqual(no_params.dict_id, 0)
519 self.assertEqual(with_params.dict_id, 0)
534 self.assertEqual(with_params.dict_id, 0)
520 self.assertFalse(no_params.has_checksum)
535 self.assertFalse(no_params.has_checksum)
521 self.assertFalse(with_params.has_checksum)
536 self.assertFalse(with_params.has_checksum)
522
537
523 def test_read_write_size(self):
538 def test_read_write_size(self):
524 source = OpCountingBytesIO(b"foobarfoobar")
539 source = OpCountingBytesIO(b"foobarfoobar")
525 dest = OpCountingBytesIO()
540 dest = OpCountingBytesIO()
526 cctx = zstd.ZstdCompressor()
541 cctx = zstd.ZstdCompressor()
527 r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1)
542 r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1)
528
543
529 self.assertEqual(r, len(source.getvalue()))
544 self.assertEqual(r, len(source.getvalue()))
530 self.assertEqual(w, 21)
545 self.assertEqual(w, 21)
531 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
546 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
532 self.assertEqual(dest._write_count, len(dest.getvalue()))
547 self.assertEqual(dest._write_count, len(dest.getvalue()))
533
548
534 def test_multithreaded(self):
549 def test_multithreaded(self):
535 source = io.BytesIO()
550 source = io.BytesIO()
536 source.write(b"a" * 1048576)
551 source.write(b"a" * 1048576)
537 source.write(b"b" * 1048576)
552 source.write(b"b" * 1048576)
538 source.write(b"c" * 1048576)
553 source.write(b"c" * 1048576)
539 source.seek(0)
554 source.seek(0)
540
555
541 dest = io.BytesIO()
556 dest = io.BytesIO()
542 cctx = zstd.ZstdCompressor(threads=2, write_content_size=False)
557 cctx = zstd.ZstdCompressor(threads=2, write_content_size=False)
543 r, w = cctx.copy_stream(source, dest)
558 r, w = cctx.copy_stream(source, dest)
544 self.assertEqual(r, 3145728)
559 self.assertEqual(r, 3145728)
545 self.assertEqual(w, 111)
560 self.assertEqual(w, 111)
546
561
547 params = zstd.get_frame_parameters(dest.getvalue())
562 params = zstd.get_frame_parameters(dest.getvalue())
548 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
563 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
549 self.assertEqual(params.dict_id, 0)
564 self.assertEqual(params.dict_id, 0)
550 self.assertFalse(params.has_checksum)
565 self.assertFalse(params.has_checksum)
551
566
552 # Writing content size and checksum works.
567 # Writing content size and checksum works.
553 cctx = zstd.ZstdCompressor(threads=2, write_checksum=True)
568 cctx = zstd.ZstdCompressor(threads=2, write_checksum=True)
554 dest = io.BytesIO()
569 dest = io.BytesIO()
555 source.seek(0)
570 source.seek(0)
556 cctx.copy_stream(source, dest, size=len(source.getvalue()))
571 cctx.copy_stream(source, dest, size=len(source.getvalue()))
557
572
558 params = zstd.get_frame_parameters(dest.getvalue())
573 params = zstd.get_frame_parameters(dest.getvalue())
559 self.assertEqual(params.content_size, 3145728)
574 self.assertEqual(params.content_size, 3145728)
560 self.assertEqual(params.dict_id, 0)
575 self.assertEqual(params.dict_id, 0)
561 self.assertTrue(params.has_checksum)
576 self.assertTrue(params.has_checksum)
562
577
563 def test_bad_size(self):
578 def test_bad_size(self):
564 source = io.BytesIO()
579 source = io.BytesIO()
565 source.write(b"a" * 32768)
580 source.write(b"a" * 32768)
566 source.write(b"b" * 32768)
581 source.write(b"b" * 32768)
567 source.seek(0)
582 source.seek(0)
568
583
569 dest = io.BytesIO()
584 dest = io.BytesIO()
570
585
571 cctx = zstd.ZstdCompressor()
586 cctx = zstd.ZstdCompressor()
572
587
573 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
588 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
574 cctx.copy_stream(source, dest, size=42)
589 cctx.copy_stream(source, dest, size=42)
575
590
576 # Try another operation on this compressor.
591 # Try another operation on this compressor.
577 source.seek(0)
592 source.seek(0)
578 dest = io.BytesIO()
593 dest = io.BytesIO()
579 cctx.copy_stream(source, dest)
594 cctx.copy_stream(source, dest)
580
595
581
596
582 @make_cffi
597 @make_cffi
583 class TestCompressor_stream_reader(TestCase):
598 class TestCompressor_stream_reader(TestCase):
584 def test_context_manager(self):
599 def test_context_manager(self):
585 cctx = zstd.ZstdCompressor()
600 cctx = zstd.ZstdCompressor()
586
601
587 with cctx.stream_reader(b"foo") as reader:
602 with cctx.stream_reader(b"foo") as reader:
588 with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
603 with self.assertRaisesRegex(
604 ValueError, "cannot __enter__ multiple times"
605 ):
589 with reader as reader2:
606 with reader as reader2:
590 pass
607 pass
591
608
592 def test_no_context_manager(self):
609 def test_no_context_manager(self):
593 cctx = zstd.ZstdCompressor()
610 cctx = zstd.ZstdCompressor()
594
611
595 reader = cctx.stream_reader(b"foo")
612 reader = cctx.stream_reader(b"foo")
596 reader.read(4)
613 reader.read(4)
597 self.assertFalse(reader.closed)
614 self.assertFalse(reader.closed)
598
615
599 reader.close()
616 reader.close()
600 self.assertTrue(reader.closed)
617 self.assertTrue(reader.closed)
601 with self.assertRaisesRegex(ValueError, "stream is closed"):
618 with self.assertRaisesRegex(ValueError, "stream is closed"):
602 reader.read(1)
619 reader.read(1)
603
620
604 def test_not_implemented(self):
621 def test_not_implemented(self):
605 cctx = zstd.ZstdCompressor()
622 cctx = zstd.ZstdCompressor()
606
623
607 with cctx.stream_reader(b"foo" * 60) as reader:
624 with cctx.stream_reader(b"foo" * 60) as reader:
608 with self.assertRaises(io.UnsupportedOperation):
625 with self.assertRaises(io.UnsupportedOperation):
609 reader.readline()
626 reader.readline()
610
627
611 with self.assertRaises(io.UnsupportedOperation):
628 with self.assertRaises(io.UnsupportedOperation):
612 reader.readlines()
629 reader.readlines()
613
630
614 with self.assertRaises(io.UnsupportedOperation):
631 with self.assertRaises(io.UnsupportedOperation):
615 iter(reader)
632 iter(reader)
616
633
617 with self.assertRaises(io.UnsupportedOperation):
634 with self.assertRaises(io.UnsupportedOperation):
618 next(reader)
635 next(reader)
619
636
620 with self.assertRaises(OSError):
637 with self.assertRaises(OSError):
621 reader.writelines([])
638 reader.writelines([])
622
639
623 with self.assertRaises(OSError):
640 with self.assertRaises(OSError):
624 reader.write(b"foo")
641 reader.write(b"foo")
625
642
626 def test_constant_methods(self):
643 def test_constant_methods(self):
627 cctx = zstd.ZstdCompressor()
644 cctx = zstd.ZstdCompressor()
628
645
629 with cctx.stream_reader(b"boo") as reader:
646 with cctx.stream_reader(b"boo") as reader:
630 self.assertTrue(reader.readable())
647 self.assertTrue(reader.readable())
631 self.assertFalse(reader.writable())
648 self.assertFalse(reader.writable())
632 self.assertFalse(reader.seekable())
649 self.assertFalse(reader.seekable())
633 self.assertFalse(reader.isatty())
650 self.assertFalse(reader.isatty())
634 self.assertFalse(reader.closed)
651 self.assertFalse(reader.closed)
635 self.assertIsNone(reader.flush())
652 self.assertIsNone(reader.flush())
636 self.assertFalse(reader.closed)
653 self.assertFalse(reader.closed)
637
654
638 self.assertTrue(reader.closed)
655 self.assertTrue(reader.closed)
639
656
640 def test_read_closed(self):
657 def test_read_closed(self):
641 cctx = zstd.ZstdCompressor()
658 cctx = zstd.ZstdCompressor()
642
659
643 with cctx.stream_reader(b"foo" * 60) as reader:
660 with cctx.stream_reader(b"foo" * 60) as reader:
644 reader.close()
661 reader.close()
645 self.assertTrue(reader.closed)
662 self.assertTrue(reader.closed)
646 with self.assertRaisesRegex(ValueError, "stream is closed"):
663 with self.assertRaisesRegex(ValueError, "stream is closed"):
647 reader.read(10)
664 reader.read(10)
648
665
649 def test_read_sizes(self):
666 def test_read_sizes(self):
650 cctx = zstd.ZstdCompressor()
667 cctx = zstd.ZstdCompressor()
651 foo = cctx.compress(b"foo")
668 foo = cctx.compress(b"foo")
652
669
653 with cctx.stream_reader(b"foo") as reader:
670 with cctx.stream_reader(b"foo") as reader:
654 with self.assertRaisesRegex(
671 with self.assertRaisesRegex(
655 ValueError, "cannot read negative amounts less than -1"
672 ValueError, "cannot read negative amounts less than -1"
656 ):
673 ):
657 reader.read(-2)
674 reader.read(-2)
658
675
659 self.assertEqual(reader.read(0), b"")
676 self.assertEqual(reader.read(0), b"")
660 self.assertEqual(reader.read(), foo)
677 self.assertEqual(reader.read(), foo)
661
678
662 def test_read_buffer(self):
679 def test_read_buffer(self):
663 cctx = zstd.ZstdCompressor()
680 cctx = zstd.ZstdCompressor()
664
681
665 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
682 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
666 frame = cctx.compress(source)
683 frame = cctx.compress(source)
667
684
668 with cctx.stream_reader(source) as reader:
685 with cctx.stream_reader(source) as reader:
669 self.assertEqual(reader.tell(), 0)
686 self.assertEqual(reader.tell(), 0)
670
687
671 # We should get entire frame in one read.
688 # We should get entire frame in one read.
672 result = reader.read(8192)
689 result = reader.read(8192)
673 self.assertEqual(result, frame)
690 self.assertEqual(result, frame)
674 self.assertEqual(reader.tell(), len(result))
691 self.assertEqual(reader.tell(), len(result))
675 self.assertEqual(reader.read(), b"")
692 self.assertEqual(reader.read(), b"")
676 self.assertEqual(reader.tell(), len(result))
693 self.assertEqual(reader.tell(), len(result))
677
694
678 def test_read_buffer_small_chunks(self):
695 def test_read_buffer_small_chunks(self):
679 cctx = zstd.ZstdCompressor()
696 cctx = zstd.ZstdCompressor()
680
697
681 source = b"foo" * 60
698 source = b"foo" * 60
682 chunks = []
699 chunks = []
683
700
684 with cctx.stream_reader(source) as reader:
701 with cctx.stream_reader(source) as reader:
685 self.assertEqual(reader.tell(), 0)
702 self.assertEqual(reader.tell(), 0)
686
703
687 while True:
704 while True:
688 chunk = reader.read(1)
705 chunk = reader.read(1)
689 if not chunk:
706 if not chunk:
690 break
707 break
691
708
692 chunks.append(chunk)
709 chunks.append(chunk)
693 self.assertEqual(reader.tell(), sum(map(len, chunks)))
710 self.assertEqual(reader.tell(), sum(map(len, chunks)))
694
711
695 self.assertEqual(b"".join(chunks), cctx.compress(source))
712 self.assertEqual(b"".join(chunks), cctx.compress(source))
696
713
697 def test_read_stream(self):
714 def test_read_stream(self):
698 cctx = zstd.ZstdCompressor()
715 cctx = zstd.ZstdCompressor()
699
716
700 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
717 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
701 frame = cctx.compress(source)
718 frame = cctx.compress(source)
702
719
703 with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
720 with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
704 self.assertEqual(reader.tell(), 0)
721 self.assertEqual(reader.tell(), 0)
705
722
706 chunk = reader.read(8192)
723 chunk = reader.read(8192)
707 self.assertEqual(chunk, frame)
724 self.assertEqual(chunk, frame)
708 self.assertEqual(reader.tell(), len(chunk))
725 self.assertEqual(reader.tell(), len(chunk))
709 self.assertEqual(reader.read(), b"")
726 self.assertEqual(reader.read(), b"")
710 self.assertEqual(reader.tell(), len(chunk))
727 self.assertEqual(reader.tell(), len(chunk))
711
728
712 def test_read_stream_small_chunks(self):
729 def test_read_stream_small_chunks(self):
713 cctx = zstd.ZstdCompressor()
730 cctx = zstd.ZstdCompressor()
714
731
715 source = b"foo" * 60
732 source = b"foo" * 60
716 chunks = []
733 chunks = []
717
734
718 with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
735 with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
719 self.assertEqual(reader.tell(), 0)
736 self.assertEqual(reader.tell(), 0)
720
737
721 while True:
738 while True:
722 chunk = reader.read(1)
739 chunk = reader.read(1)
723 if not chunk:
740 if not chunk:
724 break
741 break
725
742
726 chunks.append(chunk)
743 chunks.append(chunk)
727 self.assertEqual(reader.tell(), sum(map(len, chunks)))
744 self.assertEqual(reader.tell(), sum(map(len, chunks)))
728
745
729 self.assertEqual(b"".join(chunks), cctx.compress(source))
746 self.assertEqual(b"".join(chunks), cctx.compress(source))
730
747
731 def test_read_after_exit(self):
748 def test_read_after_exit(self):
732 cctx = zstd.ZstdCompressor()
749 cctx = zstd.ZstdCompressor()
733
750
734 with cctx.stream_reader(b"foo" * 60) as reader:
751 with cctx.stream_reader(b"foo" * 60) as reader:
735 while reader.read(8192):
752 while reader.read(8192):
736 pass
753 pass
737
754
738 with self.assertRaisesRegex(ValueError, "stream is closed"):
755 with self.assertRaisesRegex(ValueError, "stream is closed"):
739 reader.read(10)
756 reader.read(10)
740
757
741 def test_bad_size(self):
758 def test_bad_size(self):
742 cctx = zstd.ZstdCompressor()
759 cctx = zstd.ZstdCompressor()
743
760
744 source = io.BytesIO(b"foobar")
761 source = io.BytesIO(b"foobar")
745
762
746 with cctx.stream_reader(source, size=2) as reader:
763 with cctx.stream_reader(source, size=2) as reader:
747 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
764 with self.assertRaisesRegex(
765 zstd.ZstdError, "Src size is incorrect"
766 ):
748 reader.read(10)
767 reader.read(10)
749
768
750 # Try another compression operation.
769 # Try another compression operation.
751 with cctx.stream_reader(source, size=42):
770 with cctx.stream_reader(source, size=42):
752 pass
771 pass
753
772
754 def test_readall(self):
773 def test_readall(self):
755 cctx = zstd.ZstdCompressor()
774 cctx = zstd.ZstdCompressor()
756 frame = cctx.compress(b"foo" * 1024)
775 frame = cctx.compress(b"foo" * 1024)
757
776
758 reader = cctx.stream_reader(b"foo" * 1024)
777 reader = cctx.stream_reader(b"foo" * 1024)
759 self.assertEqual(reader.readall(), frame)
778 self.assertEqual(reader.readall(), frame)
760
779
761 def test_readinto(self):
780 def test_readinto(self):
762 cctx = zstd.ZstdCompressor()
781 cctx = zstd.ZstdCompressor()
763 foo = cctx.compress(b"foo")
782 foo = cctx.compress(b"foo")
764
783
765 reader = cctx.stream_reader(b"foo")
784 reader = cctx.stream_reader(b"foo")
766 with self.assertRaises(Exception):
785 with self.assertRaises(Exception):
767 reader.readinto(b"foobar")
786 reader.readinto(b"foobar")
768
787
769 # readinto() with sufficiently large destination.
788 # readinto() with sufficiently large destination.
770 b = bytearray(1024)
789 b = bytearray(1024)
771 reader = cctx.stream_reader(b"foo")
790 reader = cctx.stream_reader(b"foo")
772 self.assertEqual(reader.readinto(b), len(foo))
791 self.assertEqual(reader.readinto(b), len(foo))
773 self.assertEqual(b[0 : len(foo)], foo)
792 self.assertEqual(b[0 : len(foo)], foo)
774 self.assertEqual(reader.readinto(b), 0)
793 self.assertEqual(reader.readinto(b), 0)
775 self.assertEqual(b[0 : len(foo)], foo)
794 self.assertEqual(b[0 : len(foo)], foo)
776
795
777 # readinto() with small reads.
796 # readinto() with small reads.
778 b = bytearray(1024)
797 b = bytearray(1024)
779 reader = cctx.stream_reader(b"foo", read_size=1)
798 reader = cctx.stream_reader(b"foo", read_size=1)
780 self.assertEqual(reader.readinto(b), len(foo))
799 self.assertEqual(reader.readinto(b), len(foo))
781 self.assertEqual(b[0 : len(foo)], foo)
800 self.assertEqual(b[0 : len(foo)], foo)
782
801
783 # Too small destination buffer.
802 # Too small destination buffer.
784 b = bytearray(2)
803 b = bytearray(2)
785 reader = cctx.stream_reader(b"foo")
804 reader = cctx.stream_reader(b"foo")
786 self.assertEqual(reader.readinto(b), 2)
805 self.assertEqual(reader.readinto(b), 2)
787 self.assertEqual(b[:], foo[0:2])
806 self.assertEqual(b[:], foo[0:2])
788 self.assertEqual(reader.readinto(b), 2)
807 self.assertEqual(reader.readinto(b), 2)
789 self.assertEqual(b[:], foo[2:4])
808 self.assertEqual(b[:], foo[2:4])
790 self.assertEqual(reader.readinto(b), 2)
809 self.assertEqual(reader.readinto(b), 2)
791 self.assertEqual(b[:], foo[4:6])
810 self.assertEqual(b[:], foo[4:6])
792
811
793 def test_readinto1(self):
812 def test_readinto1(self):
794 cctx = zstd.ZstdCompressor()
813 cctx = zstd.ZstdCompressor()
795 foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
814 foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
796
815
797 reader = cctx.stream_reader(b"foo")
816 reader = cctx.stream_reader(b"foo")
798 with self.assertRaises(Exception):
817 with self.assertRaises(Exception):
799 reader.readinto1(b"foobar")
818 reader.readinto1(b"foobar")
800
819
801 b = bytearray(1024)
820 b = bytearray(1024)
802 source = OpCountingBytesIO(b"foo")
821 source = OpCountingBytesIO(b"foo")
803 reader = cctx.stream_reader(source)
822 reader = cctx.stream_reader(source)
804 self.assertEqual(reader.readinto1(b), len(foo))
823 self.assertEqual(reader.readinto1(b), len(foo))
805 self.assertEqual(b[0 : len(foo)], foo)
824 self.assertEqual(b[0 : len(foo)], foo)
806 self.assertEqual(source._read_count, 2)
825 self.assertEqual(source._read_count, 2)
807
826
808 # readinto1() with small reads.
827 # readinto1() with small reads.
809 b = bytearray(1024)
828 b = bytearray(1024)
810 source = OpCountingBytesIO(b"foo")
829 source = OpCountingBytesIO(b"foo")
811 reader = cctx.stream_reader(source, read_size=1)
830 reader = cctx.stream_reader(source, read_size=1)
812 self.assertEqual(reader.readinto1(b), len(foo))
831 self.assertEqual(reader.readinto1(b), len(foo))
813 self.assertEqual(b[0 : len(foo)], foo)
832 self.assertEqual(b[0 : len(foo)], foo)
814 self.assertEqual(source._read_count, 4)
833 self.assertEqual(source._read_count, 4)
815
834
816 def test_read1(self):
835 def test_read1(self):
817 cctx = zstd.ZstdCompressor()
836 cctx = zstd.ZstdCompressor()
818 foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
837 foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
819
838
820 b = OpCountingBytesIO(b"foo")
839 b = OpCountingBytesIO(b"foo")
821 reader = cctx.stream_reader(b)
840 reader = cctx.stream_reader(b)
822
841
823 self.assertEqual(reader.read1(), foo)
842 self.assertEqual(reader.read1(), foo)
824 self.assertEqual(b._read_count, 2)
843 self.assertEqual(b._read_count, 2)
825
844
826 b = OpCountingBytesIO(b"foo")
845 b = OpCountingBytesIO(b"foo")
827 reader = cctx.stream_reader(b)
846 reader = cctx.stream_reader(b)
828
847
829 self.assertEqual(reader.read1(0), b"")
848 self.assertEqual(reader.read1(0), b"")
830 self.assertEqual(reader.read1(2), foo[0:2])
849 self.assertEqual(reader.read1(2), foo[0:2])
831 self.assertEqual(b._read_count, 2)
850 self.assertEqual(b._read_count, 2)
832 self.assertEqual(reader.read1(2), foo[2:4])
851 self.assertEqual(reader.read1(2), foo[2:4])
833 self.assertEqual(reader.read1(1024), foo[4:])
852 self.assertEqual(reader.read1(1024), foo[4:])
834
853
835
854
836 @make_cffi
855 @make_cffi
837 class TestCompressor_stream_writer(TestCase):
856 class TestCompressor_stream_writer(TestCase):
838 def test_io_api(self):
857 def test_io_api(self):
839 buffer = io.BytesIO()
858 buffer = io.BytesIO()
840 cctx = zstd.ZstdCompressor()
859 cctx = zstd.ZstdCompressor()
841 writer = cctx.stream_writer(buffer)
860 writer = cctx.stream_writer(buffer)
842
861
843 self.assertFalse(writer.isatty())
862 self.assertFalse(writer.isatty())
844 self.assertFalse(writer.readable())
863 self.assertFalse(writer.readable())
845
864
846 with self.assertRaises(io.UnsupportedOperation):
865 with self.assertRaises(io.UnsupportedOperation):
847 writer.readline()
866 writer.readline()
848
867
849 with self.assertRaises(io.UnsupportedOperation):
868 with self.assertRaises(io.UnsupportedOperation):
850 writer.readline(42)
869 writer.readline(42)
851
870
852 with self.assertRaises(io.UnsupportedOperation):
871 with self.assertRaises(io.UnsupportedOperation):
853 writer.readline(size=42)
872 writer.readline(size=42)
854
873
855 with self.assertRaises(io.UnsupportedOperation):
874 with self.assertRaises(io.UnsupportedOperation):
856 writer.readlines()
875 writer.readlines()
857
876
858 with self.assertRaises(io.UnsupportedOperation):
877 with self.assertRaises(io.UnsupportedOperation):
859 writer.readlines(42)
878 writer.readlines(42)
860
879
861 with self.assertRaises(io.UnsupportedOperation):
880 with self.assertRaises(io.UnsupportedOperation):
862 writer.readlines(hint=42)
881 writer.readlines(hint=42)
863
882
864 with self.assertRaises(io.UnsupportedOperation):
883 with self.assertRaises(io.UnsupportedOperation):
865 writer.seek(0)
884 writer.seek(0)
866
885
867 with self.assertRaises(io.UnsupportedOperation):
886 with self.assertRaises(io.UnsupportedOperation):
868 writer.seek(10, os.SEEK_SET)
887 writer.seek(10, os.SEEK_SET)
869
888
870 self.assertFalse(writer.seekable())
889 self.assertFalse(writer.seekable())
871
890
872 with self.assertRaises(io.UnsupportedOperation):
891 with self.assertRaises(io.UnsupportedOperation):
873 writer.truncate()
892 writer.truncate()
874
893
875 with self.assertRaises(io.UnsupportedOperation):
894 with self.assertRaises(io.UnsupportedOperation):
876 writer.truncate(42)
895 writer.truncate(42)
877
896
878 with self.assertRaises(io.UnsupportedOperation):
897 with self.assertRaises(io.UnsupportedOperation):
879 writer.truncate(size=42)
898 writer.truncate(size=42)
880
899
881 self.assertTrue(writer.writable())
900 self.assertTrue(writer.writable())
882
901
883 with self.assertRaises(NotImplementedError):
902 with self.assertRaises(NotImplementedError):
884 writer.writelines([])
903 writer.writelines([])
885
904
886 with self.assertRaises(io.UnsupportedOperation):
905 with self.assertRaises(io.UnsupportedOperation):
887 writer.read()
906 writer.read()
888
907
889 with self.assertRaises(io.UnsupportedOperation):
908 with self.assertRaises(io.UnsupportedOperation):
890 writer.read(42)
909 writer.read(42)
891
910
892 with self.assertRaises(io.UnsupportedOperation):
911 with self.assertRaises(io.UnsupportedOperation):
893 writer.read(size=42)
912 writer.read(size=42)
894
913
895 with self.assertRaises(io.UnsupportedOperation):
914 with self.assertRaises(io.UnsupportedOperation):
896 writer.readall()
915 writer.readall()
897
916
898 with self.assertRaises(io.UnsupportedOperation):
917 with self.assertRaises(io.UnsupportedOperation):
899 writer.readinto(None)
918 writer.readinto(None)
900
919
901 with self.assertRaises(io.UnsupportedOperation):
920 with self.assertRaises(io.UnsupportedOperation):
902 writer.fileno()
921 writer.fileno()
903
922
904 self.assertFalse(writer.closed)
923 self.assertFalse(writer.closed)
905
924
906 def test_fileno_file(self):
925 def test_fileno_file(self):
907 with tempfile.TemporaryFile("wb") as tf:
926 with tempfile.TemporaryFile("wb") as tf:
908 cctx = zstd.ZstdCompressor()
927 cctx = zstd.ZstdCompressor()
909 writer = cctx.stream_writer(tf)
928 writer = cctx.stream_writer(tf)
910
929
911 self.assertEqual(writer.fileno(), tf.fileno())
930 self.assertEqual(writer.fileno(), tf.fileno())
912
931
913 def test_close(self):
932 def test_close(self):
914 buffer = NonClosingBytesIO()
933 buffer = NonClosingBytesIO()
915 cctx = zstd.ZstdCompressor(level=1)
934 cctx = zstd.ZstdCompressor(level=1)
916 writer = cctx.stream_writer(buffer)
935 writer = cctx.stream_writer(buffer)
917
936
918 writer.write(b"foo" * 1024)
937 writer.write(b"foo" * 1024)
919 self.assertFalse(writer.closed)
938 self.assertFalse(writer.closed)
920 self.assertFalse(buffer.closed)
939 self.assertFalse(buffer.closed)
921 writer.close()
940 writer.close()
922 self.assertTrue(writer.closed)
941 self.assertTrue(writer.closed)
923 self.assertTrue(buffer.closed)
942 self.assertTrue(buffer.closed)
924
943
925 with self.assertRaisesRegex(ValueError, "stream is closed"):
944 with self.assertRaisesRegex(ValueError, "stream is closed"):
926 writer.write(b"foo")
945 writer.write(b"foo")
927
946
928 with self.assertRaisesRegex(ValueError, "stream is closed"):
947 with self.assertRaisesRegex(ValueError, "stream is closed"):
929 writer.flush()
948 writer.flush()
930
949
931 with self.assertRaisesRegex(ValueError, "stream is closed"):
950 with self.assertRaisesRegex(ValueError, "stream is closed"):
932 with writer:
951 with writer:
933 pass
952 pass
934
953
935 self.assertEqual(
954 self.assertEqual(
936 buffer.getvalue(),
955 buffer.getvalue(),
937 b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f"
956 b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f"
938 b"\x6f\x01\x00\xfa\xd3\x77\x43",
957 b"\x6f\x01\x00\xfa\xd3\x77\x43",
939 )
958 )
940
959
941 # Context manager exit should close stream.
960 # Context manager exit should close stream.
942 buffer = io.BytesIO()
961 buffer = io.BytesIO()
943 writer = cctx.stream_writer(buffer)
962 writer = cctx.stream_writer(buffer)
944
963
945 with writer:
964 with writer:
946 writer.write(b"foo")
965 writer.write(b"foo")
947
966
948 self.assertTrue(writer.closed)
967 self.assertTrue(writer.closed)
949
968
950 def test_empty(self):
969 def test_empty(self):
951 buffer = NonClosingBytesIO()
970 buffer = NonClosingBytesIO()
952 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
971 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
953 with cctx.stream_writer(buffer) as compressor:
972 with cctx.stream_writer(buffer) as compressor:
954 compressor.write(b"")
973 compressor.write(b"")
955
974
956 result = buffer.getvalue()
975 result = buffer.getvalue()
957 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
976 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
958
977
959 params = zstd.get_frame_parameters(result)
978 params = zstd.get_frame_parameters(result)
960 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
979 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
961 self.assertEqual(params.window_size, 524288)
980 self.assertEqual(params.window_size, 524288)
962 self.assertEqual(params.dict_id, 0)
981 self.assertEqual(params.dict_id, 0)
963 self.assertFalse(params.has_checksum)
982 self.assertFalse(params.has_checksum)
964
983
965 # Test without context manager.
984 # Test without context manager.
966 buffer = io.BytesIO()
985 buffer = io.BytesIO()
967 compressor = cctx.stream_writer(buffer)
986 compressor = cctx.stream_writer(buffer)
968 self.assertEqual(compressor.write(b""), 0)
987 self.assertEqual(compressor.write(b""), 0)
969 self.assertEqual(buffer.getvalue(), b"")
988 self.assertEqual(buffer.getvalue(), b"")
970 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9)
989 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9)
971 result = buffer.getvalue()
990 result = buffer.getvalue()
972 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
991 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
973
992
974 params = zstd.get_frame_parameters(result)
993 params = zstd.get_frame_parameters(result)
975 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
994 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
976 self.assertEqual(params.window_size, 524288)
995 self.assertEqual(params.window_size, 524288)
977 self.assertEqual(params.dict_id, 0)
996 self.assertEqual(params.dict_id, 0)
978 self.assertFalse(params.has_checksum)
997 self.assertFalse(params.has_checksum)
979
998
980 # Test write_return_read=True
999 # Test write_return_read=True
981 compressor = cctx.stream_writer(buffer, write_return_read=True)
1000 compressor = cctx.stream_writer(buffer, write_return_read=True)
982 self.assertEqual(compressor.write(b""), 0)
1001 self.assertEqual(compressor.write(b""), 0)
983
1002
984 def test_input_types(self):
1003 def test_input_types(self):
985 expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
1004 expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
986 cctx = zstd.ZstdCompressor(level=1)
1005 cctx = zstd.ZstdCompressor(level=1)
987
1006
988 mutable_array = bytearray(3)
1007 mutable_array = bytearray(3)
989 mutable_array[:] = b"foo"
1008 mutable_array[:] = b"foo"
990
1009
991 sources = [
1010 sources = [
992 memoryview(b"foo"),
1011 memoryview(b"foo"),
993 bytearray(b"foo"),
1012 bytearray(b"foo"),
994 mutable_array,
1013 mutable_array,
995 ]
1014 ]
996
1015
997 for source in sources:
1016 for source in sources:
998 buffer = NonClosingBytesIO()
1017 buffer = NonClosingBytesIO()
999 with cctx.stream_writer(buffer) as compressor:
1018 with cctx.stream_writer(buffer) as compressor:
1000 compressor.write(source)
1019 compressor.write(source)
1001
1020
1002 self.assertEqual(buffer.getvalue(), expected)
1021 self.assertEqual(buffer.getvalue(), expected)
1003
1022
1004 compressor = cctx.stream_writer(buffer, write_return_read=True)
1023 compressor = cctx.stream_writer(buffer, write_return_read=True)
1005 self.assertEqual(compressor.write(source), len(source))
1024 self.assertEqual(compressor.write(source), len(source))
1006
1025
1007 def test_multiple_compress(self):
1026 def test_multiple_compress(self):
1008 buffer = NonClosingBytesIO()
1027 buffer = NonClosingBytesIO()
1009 cctx = zstd.ZstdCompressor(level=5)
1028 cctx = zstd.ZstdCompressor(level=5)
1010 with cctx.stream_writer(buffer) as compressor:
1029 with cctx.stream_writer(buffer) as compressor:
1011 self.assertEqual(compressor.write(b"foo"), 0)
1030 self.assertEqual(compressor.write(b"foo"), 0)
1012 self.assertEqual(compressor.write(b"bar"), 0)
1031 self.assertEqual(compressor.write(b"bar"), 0)
1013 self.assertEqual(compressor.write(b"x" * 8192), 0)
1032 self.assertEqual(compressor.write(b"x" * 8192), 0)
1014
1033
1015 result = buffer.getvalue()
1034 result = buffer.getvalue()
1016 self.assertEqual(
1035 self.assertEqual(
1017 result,
1036 result,
1018 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
1037 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
1019 b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
1038 b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
1020 )
1039 )
1021
1040
1022 # Test without context manager.
1041 # Test without context manager.
1023 buffer = io.BytesIO()
1042 buffer = io.BytesIO()
1024 compressor = cctx.stream_writer(buffer)
1043 compressor = cctx.stream_writer(buffer)
1025 self.assertEqual(compressor.write(b"foo"), 0)
1044 self.assertEqual(compressor.write(b"foo"), 0)
1026 self.assertEqual(compressor.write(b"bar"), 0)
1045 self.assertEqual(compressor.write(b"bar"), 0)
1027 self.assertEqual(compressor.write(b"x" * 8192), 0)
1046 self.assertEqual(compressor.write(b"x" * 8192), 0)
1028 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
1047 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
1029 result = buffer.getvalue()
1048 result = buffer.getvalue()
1030 self.assertEqual(
1049 self.assertEqual(
1031 result,
1050 result,
1032 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
1051 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
1033 b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
1052 b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
1034 )
1053 )
1035
1054
1036 # Test with write_return_read=True.
1055 # Test with write_return_read=True.
1037 compressor = cctx.stream_writer(buffer, write_return_read=True)
1056 compressor = cctx.stream_writer(buffer, write_return_read=True)
1038 self.assertEqual(compressor.write(b"foo"), 3)
1057 self.assertEqual(compressor.write(b"foo"), 3)
1039 self.assertEqual(compressor.write(b"barbiz"), 6)
1058 self.assertEqual(compressor.write(b"barbiz"), 6)
1040 self.assertEqual(compressor.write(b"x" * 8192), 8192)
1059 self.assertEqual(compressor.write(b"x" * 8192), 8192)
1041
1060
1042 def test_dictionary(self):
1061 def test_dictionary(self):
1043 samples = []
1062 samples = []
1044 for i in range(128):
1063 for i in range(128):
1045 samples.append(b"foo" * 64)
1064 samples.append(b"foo" * 64)
1046 samples.append(b"bar" * 64)
1065 samples.append(b"bar" * 64)
1047 samples.append(b"foobar" * 64)
1066 samples.append(b"foobar" * 64)
1048
1067
1049 d = zstd.train_dictionary(8192, samples)
1068 d = zstd.train_dictionary(8192, samples)
1050
1069
1051 h = hashlib.sha1(d.as_bytes()).hexdigest()
1070 h = hashlib.sha1(d.as_bytes()).hexdigest()
1052 self.assertEqual(h, "7a2e59a876db958f74257141045af8f912e00d4e")
1071 self.assertEqual(h, "7a2e59a876db958f74257141045af8f912e00d4e")
1053
1072
1054 buffer = NonClosingBytesIO()
1073 buffer = NonClosingBytesIO()
1055 cctx = zstd.ZstdCompressor(level=9, dict_data=d)
1074 cctx = zstd.ZstdCompressor(level=9, dict_data=d)
1056 with cctx.stream_writer(buffer) as compressor:
1075 with cctx.stream_writer(buffer) as compressor:
1057 self.assertEqual(compressor.write(b"foo"), 0)
1076 self.assertEqual(compressor.write(b"foo"), 0)
1058 self.assertEqual(compressor.write(b"bar"), 0)
1077 self.assertEqual(compressor.write(b"bar"), 0)
1059 self.assertEqual(compressor.write(b"foo" * 16384), 0)
1078 self.assertEqual(compressor.write(b"foo" * 16384), 0)
1060
1079
1061 compressed = buffer.getvalue()
1080 compressed = buffer.getvalue()
1062
1081
1063 params = zstd.get_frame_parameters(compressed)
1082 params = zstd.get_frame_parameters(compressed)
1064 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1083 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1065 self.assertEqual(params.window_size, 2097152)
1084 self.assertEqual(params.window_size, 2097152)
1066 self.assertEqual(params.dict_id, d.dict_id())
1085 self.assertEqual(params.dict_id, d.dict_id())
1067 self.assertFalse(params.has_checksum)
1086 self.assertFalse(params.has_checksum)
1068
1087
1069 h = hashlib.sha1(compressed).hexdigest()
1088 h = hashlib.sha1(compressed).hexdigest()
1070 self.assertEqual(h, "0a7c05635061f58039727cdbe76388c6f4cfef06")
1089 self.assertEqual(h, "0a7c05635061f58039727cdbe76388c6f4cfef06")
1071
1090
1072 source = b"foo" + b"bar" + (b"foo" * 16384)
1091 source = b"foo" + b"bar" + (b"foo" * 16384)
1073
1092
1074 dctx = zstd.ZstdDecompressor(dict_data=d)
1093 dctx = zstd.ZstdDecompressor(dict_data=d)
1075
1094
1076 self.assertEqual(
1095 self.assertEqual(
1077 dctx.decompress(compressed, max_output_size=len(source)), source
1096 dctx.decompress(compressed, max_output_size=len(source)), source
1078 )
1097 )
1079
1098
1080 def test_compression_params(self):
1099 def test_compression_params(self):
1081 params = zstd.ZstdCompressionParameters(
1100 params = zstd.ZstdCompressionParameters(
1082 window_log=20,
1101 window_log=20,
1083 chain_log=6,
1102 chain_log=6,
1084 hash_log=12,
1103 hash_log=12,
1085 min_match=5,
1104 min_match=5,
1086 search_log=4,
1105 search_log=4,
1087 target_length=10,
1106 target_length=10,
1088 strategy=zstd.STRATEGY_FAST,
1107 strategy=zstd.STRATEGY_FAST,
1089 )
1108 )
1090
1109
1091 buffer = NonClosingBytesIO()
1110 buffer = NonClosingBytesIO()
1092 cctx = zstd.ZstdCompressor(compression_params=params)
1111 cctx = zstd.ZstdCompressor(compression_params=params)
1093 with cctx.stream_writer(buffer) as compressor:
1112 with cctx.stream_writer(buffer) as compressor:
1094 self.assertEqual(compressor.write(b"foo"), 0)
1113 self.assertEqual(compressor.write(b"foo"), 0)
1095 self.assertEqual(compressor.write(b"bar"), 0)
1114 self.assertEqual(compressor.write(b"bar"), 0)
1096 self.assertEqual(compressor.write(b"foobar" * 16384), 0)
1115 self.assertEqual(compressor.write(b"foobar" * 16384), 0)
1097
1116
1098 compressed = buffer.getvalue()
1117 compressed = buffer.getvalue()
1099
1118
1100 params = zstd.get_frame_parameters(compressed)
1119 params = zstd.get_frame_parameters(compressed)
1101 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1120 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1102 self.assertEqual(params.window_size, 1048576)
1121 self.assertEqual(params.window_size, 1048576)
1103 self.assertEqual(params.dict_id, 0)
1122 self.assertEqual(params.dict_id, 0)
1104 self.assertFalse(params.has_checksum)
1123 self.assertFalse(params.has_checksum)
1105
1124
1106 h = hashlib.sha1(compressed).hexdigest()
1125 h = hashlib.sha1(compressed).hexdigest()
1107 self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b")
1126 self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b")
1108
1127
1109 def test_write_checksum(self):
1128 def test_write_checksum(self):
1110 no_checksum = NonClosingBytesIO()
1129 no_checksum = NonClosingBytesIO()
1111 cctx = zstd.ZstdCompressor(level=1)
1130 cctx = zstd.ZstdCompressor(level=1)
1112 with cctx.stream_writer(no_checksum) as compressor:
1131 with cctx.stream_writer(no_checksum) as compressor:
1113 self.assertEqual(compressor.write(b"foobar"), 0)
1132 self.assertEqual(compressor.write(b"foobar"), 0)
1114
1133
1115 with_checksum = NonClosingBytesIO()
1134 with_checksum = NonClosingBytesIO()
1116 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
1135 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
1117 with cctx.stream_writer(with_checksum) as compressor:
1136 with cctx.stream_writer(with_checksum) as compressor:
1118 self.assertEqual(compressor.write(b"foobar"), 0)
1137 self.assertEqual(compressor.write(b"foobar"), 0)
1119
1138
1120 no_params = zstd.get_frame_parameters(no_checksum.getvalue())
1139 no_params = zstd.get_frame_parameters(no_checksum.getvalue())
1121 with_params = zstd.get_frame_parameters(with_checksum.getvalue())
1140 with_params = zstd.get_frame_parameters(with_checksum.getvalue())
1122 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1141 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1123 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1142 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1124 self.assertEqual(no_params.dict_id, 0)
1143 self.assertEqual(no_params.dict_id, 0)
1125 self.assertEqual(with_params.dict_id, 0)
1144 self.assertEqual(with_params.dict_id, 0)
1126 self.assertFalse(no_params.has_checksum)
1145 self.assertFalse(no_params.has_checksum)
1127 self.assertTrue(with_params.has_checksum)
1146 self.assertTrue(with_params.has_checksum)
1128
1147
1129 self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4)
1148 self.assertEqual(
1149 len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4
1150 )
1130
1151
1131 def test_write_content_size(self):
1152 def test_write_content_size(self):
1132 no_size = NonClosingBytesIO()
1153 no_size = NonClosingBytesIO()
1133 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1154 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1134 with cctx.stream_writer(no_size) as compressor:
1155 with cctx.stream_writer(no_size) as compressor:
1135 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1156 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1136
1157
1137 with_size = NonClosingBytesIO()
1158 with_size = NonClosingBytesIO()
1138 cctx = zstd.ZstdCompressor(level=1)
1159 cctx = zstd.ZstdCompressor(level=1)
1139 with cctx.stream_writer(with_size) as compressor:
1160 with cctx.stream_writer(with_size) as compressor:
1140 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1161 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1141
1162
1142 # Source size is not known in streaming mode, so header not
1163 # Source size is not known in streaming mode, so header not
1143 # written.
1164 # written.
1144 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
1165 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
1145
1166
1146 # Declaring size will write the header.
1167 # Declaring size will write the header.
1147 with_size = NonClosingBytesIO()
1168 with_size = NonClosingBytesIO()
1148 with cctx.stream_writer(with_size, size=len(b"foobar" * 256)) as compressor:
1169 with cctx.stream_writer(
1170 with_size, size=len(b"foobar" * 256)
1171 ) as compressor:
1149 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1172 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1150
1173
1151 no_params = zstd.get_frame_parameters(no_size.getvalue())
1174 no_params = zstd.get_frame_parameters(no_size.getvalue())
1152 with_params = zstd.get_frame_parameters(with_size.getvalue())
1175 with_params = zstd.get_frame_parameters(with_size.getvalue())
1153 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1176 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1154 self.assertEqual(with_params.content_size, 1536)
1177 self.assertEqual(with_params.content_size, 1536)
1155 self.assertEqual(no_params.dict_id, 0)
1178 self.assertEqual(no_params.dict_id, 0)
1156 self.assertEqual(with_params.dict_id, 0)
1179 self.assertEqual(with_params.dict_id, 0)
1157 self.assertFalse(no_params.has_checksum)
1180 self.assertFalse(no_params.has_checksum)
1158 self.assertFalse(with_params.has_checksum)
1181 self.assertFalse(with_params.has_checksum)
1159
1182
1160 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
1183 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
1161
1184
1162 def test_no_dict_id(self):
1185 def test_no_dict_id(self):
1163 samples = []
1186 samples = []
1164 for i in range(128):
1187 for i in range(128):
1165 samples.append(b"foo" * 64)
1188 samples.append(b"foo" * 64)
1166 samples.append(b"bar" * 64)
1189 samples.append(b"bar" * 64)
1167 samples.append(b"foobar" * 64)
1190 samples.append(b"foobar" * 64)
1168
1191
1169 d = zstd.train_dictionary(1024, samples)
1192 d = zstd.train_dictionary(1024, samples)
1170
1193
1171 with_dict_id = NonClosingBytesIO()
1194 with_dict_id = NonClosingBytesIO()
1172 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
1195 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
1173 with cctx.stream_writer(with_dict_id) as compressor:
1196 with cctx.stream_writer(with_dict_id) as compressor:
1174 self.assertEqual(compressor.write(b"foobarfoobar"), 0)
1197 self.assertEqual(compressor.write(b"foobarfoobar"), 0)
1175
1198
1176 self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03")
1199 self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03")
1177
1200
1178 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
1201 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
1179 no_dict_id = NonClosingBytesIO()
1202 no_dict_id = NonClosingBytesIO()
1180 with cctx.stream_writer(no_dict_id) as compressor:
1203 with cctx.stream_writer(no_dict_id) as compressor:
1181 self.assertEqual(compressor.write(b"foobarfoobar"), 0)
1204 self.assertEqual(compressor.write(b"foobarfoobar"), 0)
1182
1205
1183 self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00")
1206 self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00")
1184
1207
1185 no_params = zstd.get_frame_parameters(no_dict_id.getvalue())
1208 no_params = zstd.get_frame_parameters(no_dict_id.getvalue())
1186 with_params = zstd.get_frame_parameters(with_dict_id.getvalue())
1209 with_params = zstd.get_frame_parameters(with_dict_id.getvalue())
1187 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1210 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1188 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1211 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1189 self.assertEqual(no_params.dict_id, 0)
1212 self.assertEqual(no_params.dict_id, 0)
1190 self.assertEqual(with_params.dict_id, d.dict_id())
1213 self.assertEqual(with_params.dict_id, d.dict_id())
1191 self.assertFalse(no_params.has_checksum)
1214 self.assertFalse(no_params.has_checksum)
1192 self.assertFalse(with_params.has_checksum)
1215 self.assertFalse(with_params.has_checksum)
1193
1216
1194 self.assertEqual(len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4)
1217 self.assertEqual(
1218 len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4
1219 )
1195
1220
1196 def test_memory_size(self):
1221 def test_memory_size(self):
1197 cctx = zstd.ZstdCompressor(level=3)
1222 cctx = zstd.ZstdCompressor(level=3)
1198 buffer = io.BytesIO()
1223 buffer = io.BytesIO()
1199 with cctx.stream_writer(buffer) as compressor:
1224 with cctx.stream_writer(buffer) as compressor:
1200 compressor.write(b"foo")
1225 compressor.write(b"foo")
1201 size = compressor.memory_size()
1226 size = compressor.memory_size()
1202
1227
1203 self.assertGreater(size, 100000)
1228 self.assertGreater(size, 100000)
1204
1229
1205 def test_write_size(self):
1230 def test_write_size(self):
1206 cctx = zstd.ZstdCompressor(level=3)
1231 cctx = zstd.ZstdCompressor(level=3)
1207 dest = OpCountingBytesIO()
1232 dest = OpCountingBytesIO()
1208 with cctx.stream_writer(dest, write_size=1) as compressor:
1233 with cctx.stream_writer(dest, write_size=1) as compressor:
1209 self.assertEqual(compressor.write(b"foo"), 0)
1234 self.assertEqual(compressor.write(b"foo"), 0)
1210 self.assertEqual(compressor.write(b"bar"), 0)
1235 self.assertEqual(compressor.write(b"bar"), 0)
1211 self.assertEqual(compressor.write(b"foobar"), 0)
1236 self.assertEqual(compressor.write(b"foobar"), 0)
1212
1237
1213 self.assertEqual(len(dest.getvalue()), dest._write_count)
1238 self.assertEqual(len(dest.getvalue()), dest._write_count)
1214
1239
1215 def test_flush_repeated(self):
1240 def test_flush_repeated(self):
1216 cctx = zstd.ZstdCompressor(level=3)
1241 cctx = zstd.ZstdCompressor(level=3)
1217 dest = OpCountingBytesIO()
1242 dest = OpCountingBytesIO()
1218 with cctx.stream_writer(dest) as compressor:
1243 with cctx.stream_writer(dest) as compressor:
1219 self.assertEqual(compressor.write(b"foo"), 0)
1244 self.assertEqual(compressor.write(b"foo"), 0)
1220 self.assertEqual(dest._write_count, 0)
1245 self.assertEqual(dest._write_count, 0)
1221 self.assertEqual(compressor.flush(), 12)
1246 self.assertEqual(compressor.flush(), 12)
1222 self.assertEqual(dest._write_count, 1)
1247 self.assertEqual(dest._write_count, 1)
1223 self.assertEqual(compressor.write(b"bar"), 0)
1248 self.assertEqual(compressor.write(b"bar"), 0)
1224 self.assertEqual(dest._write_count, 1)
1249 self.assertEqual(dest._write_count, 1)
1225 self.assertEqual(compressor.flush(), 6)
1250 self.assertEqual(compressor.flush(), 6)
1226 self.assertEqual(dest._write_count, 2)
1251 self.assertEqual(dest._write_count, 2)
1227 self.assertEqual(compressor.write(b"baz"), 0)
1252 self.assertEqual(compressor.write(b"baz"), 0)
1228
1253
1229 self.assertEqual(dest._write_count, 3)
1254 self.assertEqual(dest._write_count, 3)
1230
1255
1231 def test_flush_empty_block(self):
1256 def test_flush_empty_block(self):
1232 cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
1257 cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
1233 dest = OpCountingBytesIO()
1258 dest = OpCountingBytesIO()
1234 with cctx.stream_writer(dest) as compressor:
1259 with cctx.stream_writer(dest) as compressor:
1235 self.assertEqual(compressor.write(b"foobar" * 8192), 0)
1260 self.assertEqual(compressor.write(b"foobar" * 8192), 0)
1236 count = dest._write_count
1261 count = dest._write_count
1237 offset = dest.tell()
1262 offset = dest.tell()
1238 self.assertEqual(compressor.flush(), 23)
1263 self.assertEqual(compressor.flush(), 23)
1239 self.assertGreater(dest._write_count, count)
1264 self.assertGreater(dest._write_count, count)
1240 self.assertGreater(dest.tell(), offset)
1265 self.assertGreater(dest.tell(), offset)
1241 offset = dest.tell()
1266 offset = dest.tell()
1242 # Ending the write here should cause an empty block to be written
1267 # Ending the write here should cause an empty block to be written
1243 # to denote end of frame.
1268 # to denote end of frame.
1244
1269
1245 trailing = dest.getvalue()[offset:]
1270 trailing = dest.getvalue()[offset:]
1246 # 3 bytes block header + 4 bytes frame checksum
1271 # 3 bytes block header + 4 bytes frame checksum
1247 self.assertEqual(len(trailing), 7)
1272 self.assertEqual(len(trailing), 7)
1248
1273
1249 header = trailing[0:3]
1274 header = trailing[0:3]
1250 self.assertEqual(header, b"\x01\x00\x00")
1275 self.assertEqual(header, b"\x01\x00\x00")
1251
1276
1252 def test_flush_frame(self):
1277 def test_flush_frame(self):
1253 cctx = zstd.ZstdCompressor(level=3)
1278 cctx = zstd.ZstdCompressor(level=3)
1254 dest = OpCountingBytesIO()
1279 dest = OpCountingBytesIO()
1255
1280
1256 with cctx.stream_writer(dest) as compressor:
1281 with cctx.stream_writer(dest) as compressor:
1257 self.assertEqual(compressor.write(b"foobar" * 8192), 0)
1282 self.assertEqual(compressor.write(b"foobar" * 8192), 0)
1258 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
1283 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
1259 compressor.write(b"biz" * 16384)
1284 compressor.write(b"biz" * 16384)
1260
1285
1261 self.assertEqual(
1286 self.assertEqual(
1262 dest.getvalue(),
1287 dest.getvalue(),
1263 # Frame 1.
1288 # Frame 1.
1264 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f"
1289 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f"
1265 b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08"
1290 b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08"
1266 # Frame 2.
1291 # Frame 2.
1267 b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a"
1292 b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a"
1268 b"\x01\x00\xfa\x3f\x75\x37\x04",
1293 b"\x01\x00\xfa\x3f\x75\x37\x04",
1269 )
1294 )
1270
1295
1271 def test_bad_flush_mode(self):
1296 def test_bad_flush_mode(self):
1272 cctx = zstd.ZstdCompressor()
1297 cctx = zstd.ZstdCompressor()
1273 dest = io.BytesIO()
1298 dest = io.BytesIO()
1274 with cctx.stream_writer(dest) as compressor:
1299 with cctx.stream_writer(dest) as compressor:
1275 with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"):
1300 with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"):
1276 compressor.flush(flush_mode=42)
1301 compressor.flush(flush_mode=42)
1277
1302
1278 def test_multithreaded(self):
1303 def test_multithreaded(self):
1279 dest = NonClosingBytesIO()
1304 dest = NonClosingBytesIO()
1280 cctx = zstd.ZstdCompressor(threads=2)
1305 cctx = zstd.ZstdCompressor(threads=2)
1281 with cctx.stream_writer(dest) as compressor:
1306 with cctx.stream_writer(dest) as compressor:
1282 compressor.write(b"a" * 1048576)
1307 compressor.write(b"a" * 1048576)
1283 compressor.write(b"b" * 1048576)
1308 compressor.write(b"b" * 1048576)
1284 compressor.write(b"c" * 1048576)
1309 compressor.write(b"c" * 1048576)
1285
1310
1286 self.assertEqual(len(dest.getvalue()), 111)
1311 self.assertEqual(len(dest.getvalue()), 111)
1287
1312
1288 def test_tell(self):
1313 def test_tell(self):
1289 dest = io.BytesIO()
1314 dest = io.BytesIO()
1290 cctx = zstd.ZstdCompressor()
1315 cctx = zstd.ZstdCompressor()
1291 with cctx.stream_writer(dest) as compressor:
1316 with cctx.stream_writer(dest) as compressor:
1292 self.assertEqual(compressor.tell(), 0)
1317 self.assertEqual(compressor.tell(), 0)
1293
1318
1294 for i in range(256):
1319 for i in range(256):
1295 compressor.write(b"foo" * (i + 1))
1320 compressor.write(b"foo" * (i + 1))
1296 self.assertEqual(compressor.tell(), dest.tell())
1321 self.assertEqual(compressor.tell(), dest.tell())
1297
1322
1298 def test_bad_size(self):
1323 def test_bad_size(self):
1299 cctx = zstd.ZstdCompressor()
1324 cctx = zstd.ZstdCompressor()
1300
1325
1301 dest = io.BytesIO()
1326 dest = io.BytesIO()
1302
1327
1303 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
1328 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
1304 with cctx.stream_writer(dest, size=2) as compressor:
1329 with cctx.stream_writer(dest, size=2) as compressor:
1305 compressor.write(b"foo")
1330 compressor.write(b"foo")
1306
1331
1307 # Test another operation.
1332 # Test another operation.
1308 with cctx.stream_writer(dest, size=42):
1333 with cctx.stream_writer(dest, size=42):
1309 pass
1334 pass
1310
1335
1311 def test_tarfile_compat(self):
1336 def test_tarfile_compat(self):
1312 dest = NonClosingBytesIO()
1337 dest = NonClosingBytesIO()
1313 cctx = zstd.ZstdCompressor()
1338 cctx = zstd.ZstdCompressor()
1314 with cctx.stream_writer(dest) as compressor:
1339 with cctx.stream_writer(dest) as compressor:
1315 with tarfile.open("tf", mode="w|", fileobj=compressor) as tf:
1340 with tarfile.open("tf", mode="w|", fileobj=compressor) as tf:
1316 tf.add(__file__, "test_compressor.py")
1341 tf.add(__file__, "test_compressor.py")
1317
1342
1318 dest = io.BytesIO(dest.getvalue())
1343 dest = io.BytesIO(dest.getvalue())
1319
1344
1320 dctx = zstd.ZstdDecompressor()
1345 dctx = zstd.ZstdDecompressor()
1321 with dctx.stream_reader(dest) as reader:
1346 with dctx.stream_reader(dest) as reader:
1322 with tarfile.open(mode="r|", fileobj=reader) as tf:
1347 with tarfile.open(mode="r|", fileobj=reader) as tf:
1323 for member in tf:
1348 for member in tf:
1324 self.assertEqual(member.name, "test_compressor.py")
1349 self.assertEqual(member.name, "test_compressor.py")
1325
1350
1326
1351
1327 @make_cffi
1352 @make_cffi
1328 class TestCompressor_read_to_iter(TestCase):
1353 class TestCompressor_read_to_iter(TestCase):
1329 def test_type_validation(self):
1354 def test_type_validation(self):
1330 cctx = zstd.ZstdCompressor()
1355 cctx = zstd.ZstdCompressor()
1331
1356
1332 # Object with read() works.
1357 # Object with read() works.
1333 for chunk in cctx.read_to_iter(io.BytesIO()):
1358 for chunk in cctx.read_to_iter(io.BytesIO()):
1334 pass
1359 pass
1335
1360
1336 # Buffer protocol works.
1361 # Buffer protocol works.
1337 for chunk in cctx.read_to_iter(b"foobar"):
1362 for chunk in cctx.read_to_iter(b"foobar"):
1338 pass
1363 pass
1339
1364
1340 with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
1365 with self.assertRaisesRegex(
1366 ValueError, "must pass an object with a read"
1367 ):
1341 for chunk in cctx.read_to_iter(True):
1368 for chunk in cctx.read_to_iter(True):
1342 pass
1369 pass
1343
1370
1344 def test_read_empty(self):
1371 def test_read_empty(self):
1345 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1372 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1346
1373
1347 source = io.BytesIO()
1374 source = io.BytesIO()
1348 it = cctx.read_to_iter(source)
1375 it = cctx.read_to_iter(source)
1349 chunks = list(it)
1376 chunks = list(it)
1350 self.assertEqual(len(chunks), 1)
1377 self.assertEqual(len(chunks), 1)
1351 compressed = b"".join(chunks)
1378 compressed = b"".join(chunks)
1352 self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
1379 self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
1353
1380
1354 # And again with the buffer protocol.
1381 # And again with the buffer protocol.
1355 it = cctx.read_to_iter(b"")
1382 it = cctx.read_to_iter(b"")
1356 chunks = list(it)
1383 chunks = list(it)
1357 self.assertEqual(len(chunks), 1)
1384 self.assertEqual(len(chunks), 1)
1358 compressed2 = b"".join(chunks)
1385 compressed2 = b"".join(chunks)
1359 self.assertEqual(compressed2, compressed)
1386 self.assertEqual(compressed2, compressed)
1360
1387
1361 def test_read_large(self):
1388 def test_read_large(self):
1362 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1389 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1363
1390
1364 source = io.BytesIO()
1391 source = io.BytesIO()
1365 source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
1392 source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
1366 source.write(b"o")
1393 source.write(b"o")
1367 source.seek(0)
1394 source.seek(0)
1368
1395
1369 # Creating an iterator should not perform any compression until
1396 # Creating an iterator should not perform any compression until
1370 # first read.
1397 # first read.
1371 it = cctx.read_to_iter(source, size=len(source.getvalue()))
1398 it = cctx.read_to_iter(source, size=len(source.getvalue()))
1372 self.assertEqual(source.tell(), 0)
1399 self.assertEqual(source.tell(), 0)
1373
1400
1374 # We should have exactly 2 output chunks.
1401 # We should have exactly 2 output chunks.
1375 chunks = []
1402 chunks = []
1376 chunk = next(it)
1403 chunk = next(it)
1377 self.assertIsNotNone(chunk)
1404 self.assertIsNotNone(chunk)
1378 self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
1405 self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
1379 chunks.append(chunk)
1406 chunks.append(chunk)
1380 chunk = next(it)
1407 chunk = next(it)
1381 self.assertIsNotNone(chunk)
1408 self.assertIsNotNone(chunk)
1382 chunks.append(chunk)
1409 chunks.append(chunk)
1383
1410
1384 self.assertEqual(source.tell(), len(source.getvalue()))
1411 self.assertEqual(source.tell(), len(source.getvalue()))
1385
1412
1386 with self.assertRaises(StopIteration):
1413 with self.assertRaises(StopIteration):
1387 next(it)
1414 next(it)
1388
1415
1389 # And again for good measure.
1416 # And again for good measure.
1390 with self.assertRaises(StopIteration):
1417 with self.assertRaises(StopIteration):
1391 next(it)
1418 next(it)
1392
1419
1393 # We should get the same output as the one-shot compression mechanism.
1420 # We should get the same output as the one-shot compression mechanism.
1394 self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
1421 self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
1395
1422
1396 params = zstd.get_frame_parameters(b"".join(chunks))
1423 params = zstd.get_frame_parameters(b"".join(chunks))
1397 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1424 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1398 self.assertEqual(params.window_size, 262144)
1425 self.assertEqual(params.window_size, 262144)
1399 self.assertEqual(params.dict_id, 0)
1426 self.assertEqual(params.dict_id, 0)
1400 self.assertFalse(params.has_checksum)
1427 self.assertFalse(params.has_checksum)
1401
1428
1402 # Now check the buffer protocol.
1429 # Now check the buffer protocol.
1403 it = cctx.read_to_iter(source.getvalue())
1430 it = cctx.read_to_iter(source.getvalue())
1404 chunks = list(it)
1431 chunks = list(it)
1405 self.assertEqual(len(chunks), 2)
1432 self.assertEqual(len(chunks), 2)
1406
1433
1407 params = zstd.get_frame_parameters(b"".join(chunks))
1434 params = zstd.get_frame_parameters(b"".join(chunks))
1408 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1435 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1409 # self.assertEqual(params.window_size, 262144)
1436 # self.assertEqual(params.window_size, 262144)
1410 self.assertEqual(params.dict_id, 0)
1437 self.assertEqual(params.dict_id, 0)
1411 self.assertFalse(params.has_checksum)
1438 self.assertFalse(params.has_checksum)
1412
1439
1413 self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
1440 self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
1414
1441
1415 def test_read_write_size(self):
1442 def test_read_write_size(self):
1416 source = OpCountingBytesIO(b"foobarfoobar")
1443 source = OpCountingBytesIO(b"foobarfoobar")
1417 cctx = zstd.ZstdCompressor(level=3)
1444 cctx = zstd.ZstdCompressor(level=3)
1418 for chunk in cctx.read_to_iter(source, read_size=1, write_size=1):
1445 for chunk in cctx.read_to_iter(source, read_size=1, write_size=1):
1419 self.assertEqual(len(chunk), 1)
1446 self.assertEqual(len(chunk), 1)
1420
1447
1421 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
1448 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
1422
1449
1423 def test_multithreaded(self):
1450 def test_multithreaded(self):
1424 source = io.BytesIO()
1451 source = io.BytesIO()
1425 source.write(b"a" * 1048576)
1452 source.write(b"a" * 1048576)
1426 source.write(b"b" * 1048576)
1453 source.write(b"b" * 1048576)
1427 source.write(b"c" * 1048576)
1454 source.write(b"c" * 1048576)
1428 source.seek(0)
1455 source.seek(0)
1429
1456
1430 cctx = zstd.ZstdCompressor(threads=2)
1457 cctx = zstd.ZstdCompressor(threads=2)
1431
1458
1432 compressed = b"".join(cctx.read_to_iter(source))
1459 compressed = b"".join(cctx.read_to_iter(source))
1433 self.assertEqual(len(compressed), 111)
1460 self.assertEqual(len(compressed), 111)
1434
1461
1435 def test_bad_size(self):
1462 def test_bad_size(self):
1436 cctx = zstd.ZstdCompressor()
1463 cctx = zstd.ZstdCompressor()
1437
1464
1438 source = io.BytesIO(b"a" * 42)
1465 source = io.BytesIO(b"a" * 42)
1439
1466
1440 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
1467 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
1441 b"".join(cctx.read_to_iter(source, size=2))
1468 b"".join(cctx.read_to_iter(source, size=2))
1442
1469
1443 # Test another operation on errored compressor.
1470 # Test another operation on errored compressor.
1444 b"".join(cctx.read_to_iter(source))
1471 b"".join(cctx.read_to_iter(source))
1445
1472
1446
1473
1447 @make_cffi
1474 @make_cffi
1448 class TestCompressor_chunker(TestCase):
1475 class TestCompressor_chunker(TestCase):
1449 def test_empty(self):
1476 def test_empty(self):
1450 cctx = zstd.ZstdCompressor(write_content_size=False)
1477 cctx = zstd.ZstdCompressor(write_content_size=False)
1451 chunker = cctx.chunker()
1478 chunker = cctx.chunker()
1452
1479
1453 it = chunker.compress(b"")
1480 it = chunker.compress(b"")
1454
1481
1455 with self.assertRaises(StopIteration):
1482 with self.assertRaises(StopIteration):
1456 next(it)
1483 next(it)
1457
1484
1458 it = chunker.finish()
1485 it = chunker.finish()
1459
1486
1460 self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00")
1487 self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00")
1461
1488
1462 with self.assertRaises(StopIteration):
1489 with self.assertRaises(StopIteration):
1463 next(it)
1490 next(it)
1464
1491
1465 def test_simple_input(self):
1492 def test_simple_input(self):
1466 cctx = zstd.ZstdCompressor()
1493 cctx = zstd.ZstdCompressor()
1467 chunker = cctx.chunker()
1494 chunker = cctx.chunker()
1468
1495
1469 it = chunker.compress(b"foobar")
1496 it = chunker.compress(b"foobar")
1470
1497
1471 with self.assertRaises(StopIteration):
1498 with self.assertRaises(StopIteration):
1472 next(it)
1499 next(it)
1473
1500
1474 it = chunker.compress(b"baz" * 30)
1501 it = chunker.compress(b"baz" * 30)
1475
1502
1476 with self.assertRaises(StopIteration):
1503 with self.assertRaises(StopIteration):
1477 next(it)
1504 next(it)
1478
1505
1479 it = chunker.finish()
1506 it = chunker.finish()
1480
1507
1481 self.assertEqual(
1508 self.assertEqual(
1482 next(it),
1509 next(it),
1483 b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f"
1510 b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f"
1484 b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e",
1511 b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e",
1485 )
1512 )
1486
1513
1487 with self.assertRaises(StopIteration):
1514 with self.assertRaises(StopIteration):
1488 next(it)
1515 next(it)
1489
1516
1490 def test_input_size(self):
1517 def test_input_size(self):
1491 cctx = zstd.ZstdCompressor()
1518 cctx = zstd.ZstdCompressor()
1492 chunker = cctx.chunker(size=1024)
1519 chunker = cctx.chunker(size=1024)
1493
1520
1494 it = chunker.compress(b"x" * 1000)
1521 it = chunker.compress(b"x" * 1000)
1495
1522
1496 with self.assertRaises(StopIteration):
1523 with self.assertRaises(StopIteration):
1497 next(it)
1524 next(it)
1498
1525
1499 it = chunker.compress(b"y" * 24)
1526 it = chunker.compress(b"y" * 24)
1500
1527
1501 with self.assertRaises(StopIteration):
1528 with self.assertRaises(StopIteration):
1502 next(it)
1529 next(it)
1503
1530
1504 chunks = list(chunker.finish())
1531 chunks = list(chunker.finish())
1505
1532
1506 self.assertEqual(
1533 self.assertEqual(
1507 chunks,
1534 chunks,
1508 [
1535 [
1509 b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00"
1536 b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00"
1510 b"\xa0\x16\xe3\x2b\x80\x05"
1537 b"\xa0\x16\xe3\x2b\x80\x05"
1511 ],
1538 ],
1512 )
1539 )
1513
1540
1514 dctx = zstd.ZstdDecompressor()
1541 dctx = zstd.ZstdDecompressor()
1515
1542
1516 self.assertEqual(dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24))
1543 self.assertEqual(
1544 dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24)
1545 )
1517
1546
1518 def test_small_chunk_size(self):
1547 def test_small_chunk_size(self):
1519 cctx = zstd.ZstdCompressor()
1548 cctx = zstd.ZstdCompressor()
1520 chunker = cctx.chunker(chunk_size=1)
1549 chunker = cctx.chunker(chunk_size=1)
1521
1550
1522 chunks = list(chunker.compress(b"foo" * 1024))
1551 chunks = list(chunker.compress(b"foo" * 1024))
1523 self.assertEqual(chunks, [])
1552 self.assertEqual(chunks, [])
1524
1553
1525 chunks = list(chunker.finish())
1554 chunks = list(chunker.finish())
1526 self.assertTrue(all(len(chunk) == 1 for chunk in chunks))
1555 self.assertTrue(all(len(chunk) == 1 for chunk in chunks))
1527
1556
1528 self.assertEqual(
1557 self.assertEqual(
1529 b"".join(chunks),
1558 b"".join(chunks),
1530 b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00"
1559 b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00"
1531 b"\xfa\xd3\x77\x43",
1560 b"\xfa\xd3\x77\x43",
1532 )
1561 )
1533
1562
1534 dctx = zstd.ZstdDecompressor()
1563 dctx = zstd.ZstdDecompressor()
1535 self.assertEqual(
1564 self.assertEqual(
1536 dctx.decompress(b"".join(chunks), max_output_size=10000), b"foo" * 1024
1565 dctx.decompress(b"".join(chunks), max_output_size=10000),
1566 b"foo" * 1024,
1537 )
1567 )
1538
1568
1539 def test_input_types(self):
1569 def test_input_types(self):
1540 cctx = zstd.ZstdCompressor()
1570 cctx = zstd.ZstdCompressor()
1541
1571
1542 mutable_array = bytearray(3)
1572 mutable_array = bytearray(3)
1543 mutable_array[:] = b"foo"
1573 mutable_array[:] = b"foo"
1544
1574
1545 sources = [
1575 sources = [
1546 memoryview(b"foo"),
1576 memoryview(b"foo"),
1547 bytearray(b"foo"),
1577 bytearray(b"foo"),
1548 mutable_array,
1578 mutable_array,
1549 ]
1579 ]
1550
1580
1551 for source in sources:
1581 for source in sources:
1552 chunker = cctx.chunker()
1582 chunker = cctx.chunker()
1553
1583
1554 self.assertEqual(list(chunker.compress(source)), [])
1584 self.assertEqual(list(chunker.compress(source)), [])
1555 self.assertEqual(
1585 self.assertEqual(
1556 list(chunker.finish()),
1586 list(chunker.finish()),
1557 [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"],
1587 [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"],
1558 )
1588 )
1559
1589
1560 def test_flush(self):
1590 def test_flush(self):
1561 cctx = zstd.ZstdCompressor()
1591 cctx = zstd.ZstdCompressor()
1562 chunker = cctx.chunker()
1592 chunker = cctx.chunker()
1563
1593
1564 self.assertEqual(list(chunker.compress(b"foo" * 1024)), [])
1594 self.assertEqual(list(chunker.compress(b"foo" * 1024)), [])
1565 self.assertEqual(list(chunker.compress(b"bar" * 1024)), [])
1595 self.assertEqual(list(chunker.compress(b"bar" * 1024)), [])
1566
1596
1567 chunks1 = list(chunker.flush())
1597 chunks1 = list(chunker.flush())
1568
1598
1569 self.assertEqual(
1599 self.assertEqual(
1570 chunks1,
1600 chunks1,
1571 [
1601 [
1572 b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72"
1602 b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72"
1573 b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02"
1603 b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02"
1574 ],
1604 ],
1575 )
1605 )
1576
1606
1577 self.assertEqual(list(chunker.flush()), [])
1607 self.assertEqual(list(chunker.flush()), [])
1578 self.assertEqual(list(chunker.flush()), [])
1608 self.assertEqual(list(chunker.flush()), [])
1579
1609
1580 self.assertEqual(list(chunker.compress(b"baz" * 1024)), [])
1610 self.assertEqual(list(chunker.compress(b"baz" * 1024)), [])
1581
1611
1582 chunks2 = list(chunker.flush())
1612 chunks2 = list(chunker.flush())
1583 self.assertEqual(len(chunks2), 1)
1613 self.assertEqual(len(chunks2), 1)
1584
1614
1585 chunks3 = list(chunker.finish())
1615 chunks3 = list(chunker.finish())
1586 self.assertEqual(len(chunks2), 1)
1616 self.assertEqual(len(chunks2), 1)
1587
1617
1588 dctx = zstd.ZstdDecompressor()
1618 dctx = zstd.ZstdDecompressor()
1589
1619
1590 self.assertEqual(
1620 self.assertEqual(
1591 dctx.decompress(
1621 dctx.decompress(
1592 b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000
1622 b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000
1593 ),
1623 ),
1594 (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024),
1624 (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024),
1595 )
1625 )
1596
1626
1597 def test_compress_after_finish(self):
1627 def test_compress_after_finish(self):
1598 cctx = zstd.ZstdCompressor()
1628 cctx = zstd.ZstdCompressor()
1599 chunker = cctx.chunker()
1629 chunker = cctx.chunker()
1600
1630
1601 list(chunker.compress(b"foo"))
1631 list(chunker.compress(b"foo"))
1602 list(chunker.finish())
1632 list(chunker.finish())
1603
1633
1604 with self.assertRaisesRegex(
1634 with self.assertRaisesRegex(
1605 zstd.ZstdError, r"cannot call compress\(\) after compression finished"
1635 zstd.ZstdError,
1636 r"cannot call compress\(\) after compression finished",
1606 ):
1637 ):
1607 list(chunker.compress(b"foo"))
1638 list(chunker.compress(b"foo"))
1608
1639
1609 def test_flush_after_finish(self):
1640 def test_flush_after_finish(self):
1610 cctx = zstd.ZstdCompressor()
1641 cctx = zstd.ZstdCompressor()
1611 chunker = cctx.chunker()
1642 chunker = cctx.chunker()
1612
1643
1613 list(chunker.compress(b"foo"))
1644 list(chunker.compress(b"foo"))
1614 list(chunker.finish())
1645 list(chunker.finish())
1615
1646
1616 with self.assertRaisesRegex(
1647 with self.assertRaisesRegex(
1617 zstd.ZstdError, r"cannot call flush\(\) after compression finished"
1648 zstd.ZstdError, r"cannot call flush\(\) after compression finished"
1618 ):
1649 ):
1619 list(chunker.flush())
1650 list(chunker.flush())
1620
1651
1621 def test_finish_after_finish(self):
1652 def test_finish_after_finish(self):
1622 cctx = zstd.ZstdCompressor()
1653 cctx = zstd.ZstdCompressor()
1623 chunker = cctx.chunker()
1654 chunker = cctx.chunker()
1624
1655
1625 list(chunker.compress(b"foo"))
1656 list(chunker.compress(b"foo"))
1626 list(chunker.finish())
1657 list(chunker.finish())
1627
1658
1628 with self.assertRaisesRegex(
1659 with self.assertRaisesRegex(
1629 zstd.ZstdError, r"cannot call finish\(\) after compression finished"
1660 zstd.ZstdError, r"cannot call finish\(\) after compression finished"
1630 ):
1661 ):
1631 list(chunker.finish())
1662 list(chunker.finish())
1632
1663
1633
1664
1634 class TestCompressor_multi_compress_to_buffer(TestCase):
1665 class TestCompressor_multi_compress_to_buffer(TestCase):
1635 def test_invalid_inputs(self):
1666 def test_invalid_inputs(self):
1636 cctx = zstd.ZstdCompressor()
1667 cctx = zstd.ZstdCompressor()
1637
1668
1638 if not hasattr(cctx, "multi_compress_to_buffer"):
1669 if not hasattr(cctx, "multi_compress_to_buffer"):
1639 self.skipTest("multi_compress_to_buffer not available")
1670 self.skipTest("multi_compress_to_buffer not available")
1640
1671
1641 with self.assertRaises(TypeError):
1672 with self.assertRaises(TypeError):
1642 cctx.multi_compress_to_buffer(True)
1673 cctx.multi_compress_to_buffer(True)
1643
1674
1644 with self.assertRaises(TypeError):
1675 with self.assertRaises(TypeError):
1645 cctx.multi_compress_to_buffer((1, 2))
1676 cctx.multi_compress_to_buffer((1, 2))
1646
1677
1647 with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
1678 with self.assertRaisesRegex(
1679 TypeError, "item 0 not a bytes like object"
1680 ):
1648 cctx.multi_compress_to_buffer([u"foo"])
1681 cctx.multi_compress_to_buffer([u"foo"])
1649
1682
1650 def test_empty_input(self):
1683 def test_empty_input(self):
1651 cctx = zstd.ZstdCompressor()
1684 cctx = zstd.ZstdCompressor()
1652
1685
1653 if not hasattr(cctx, "multi_compress_to_buffer"):
1686 if not hasattr(cctx, "multi_compress_to_buffer"):
1654 self.skipTest("multi_compress_to_buffer not available")
1687 self.skipTest("multi_compress_to_buffer not available")
1655
1688
1656 with self.assertRaisesRegex(ValueError, "no source elements found"):
1689 with self.assertRaisesRegex(ValueError, "no source elements found"):
1657 cctx.multi_compress_to_buffer([])
1690 cctx.multi_compress_to_buffer([])
1658
1691
1659 with self.assertRaisesRegex(ValueError, "source elements are empty"):
1692 with self.assertRaisesRegex(ValueError, "source elements are empty"):
1660 cctx.multi_compress_to_buffer([b"", b"", b""])
1693 cctx.multi_compress_to_buffer([b"", b"", b""])
1661
1694
1662 def test_list_input(self):
1695 def test_list_input(self):
1663 cctx = zstd.ZstdCompressor(write_checksum=True)
1696 cctx = zstd.ZstdCompressor(write_checksum=True)
1664
1697
1665 if not hasattr(cctx, "multi_compress_to_buffer"):
1698 if not hasattr(cctx, "multi_compress_to_buffer"):
1666 self.skipTest("multi_compress_to_buffer not available")
1699 self.skipTest("multi_compress_to_buffer not available")
1667
1700
1668 original = [b"foo" * 12, b"bar" * 6]
1701 original = [b"foo" * 12, b"bar" * 6]
1669 frames = [cctx.compress(c) for c in original]
1702 frames = [cctx.compress(c) for c in original]
1670 b = cctx.multi_compress_to_buffer(original)
1703 b = cctx.multi_compress_to_buffer(original)
1671
1704
1672 self.assertIsInstance(b, zstd.BufferWithSegmentsCollection)
1705 self.assertIsInstance(b, zstd.BufferWithSegmentsCollection)
1673
1706
1674 self.assertEqual(len(b), 2)
1707 self.assertEqual(len(b), 2)
1675 self.assertEqual(b.size(), 44)
1708 self.assertEqual(b.size(), 44)
1676
1709
1677 self.assertEqual(b[0].tobytes(), frames[0])
1710 self.assertEqual(b[0].tobytes(), frames[0])
1678 self.assertEqual(b[1].tobytes(), frames[1])
1711 self.assertEqual(b[1].tobytes(), frames[1])
1679
1712
1680 def test_buffer_with_segments_input(self):
1713 def test_buffer_with_segments_input(self):
1681 cctx = zstd.ZstdCompressor(write_checksum=True)
1714 cctx = zstd.ZstdCompressor(write_checksum=True)
1682
1715
1683 if not hasattr(cctx, "multi_compress_to_buffer"):
1716 if not hasattr(cctx, "multi_compress_to_buffer"):
1684 self.skipTest("multi_compress_to_buffer not available")
1717 self.skipTest("multi_compress_to_buffer not available")
1685
1718
1686 original = [b"foo" * 4, b"bar" * 6]
1719 original = [b"foo" * 4, b"bar" * 6]
1687 frames = [cctx.compress(c) for c in original]
1720 frames = [cctx.compress(c) for c in original]
1688
1721
1689 offsets = struct.pack(
1722 offsets = struct.pack(
1690 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
1723 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
1691 )
1724 )
1692 segments = zstd.BufferWithSegments(b"".join(original), offsets)
1725 segments = zstd.BufferWithSegments(b"".join(original), offsets)
1693
1726
1694 result = cctx.multi_compress_to_buffer(segments)
1727 result = cctx.multi_compress_to_buffer(segments)
1695
1728
1696 self.assertEqual(len(result), 2)
1729 self.assertEqual(len(result), 2)
1697 self.assertEqual(result.size(), 47)
1730 self.assertEqual(result.size(), 47)
1698
1731
1699 self.assertEqual(result[0].tobytes(), frames[0])
1732 self.assertEqual(result[0].tobytes(), frames[0])
1700 self.assertEqual(result[1].tobytes(), frames[1])
1733 self.assertEqual(result[1].tobytes(), frames[1])
1701
1734
1702 def test_buffer_with_segments_collection_input(self):
1735 def test_buffer_with_segments_collection_input(self):
1703 cctx = zstd.ZstdCompressor(write_checksum=True)
1736 cctx = zstd.ZstdCompressor(write_checksum=True)
1704
1737
1705 if not hasattr(cctx, "multi_compress_to_buffer"):
1738 if not hasattr(cctx, "multi_compress_to_buffer"):
1706 self.skipTest("multi_compress_to_buffer not available")
1739 self.skipTest("multi_compress_to_buffer not available")
1707
1740
1708 original = [
1741 original = [
1709 b"foo1",
1742 b"foo1",
1710 b"foo2" * 2,
1743 b"foo2" * 2,
1711 b"foo3" * 3,
1744 b"foo3" * 3,
1712 b"foo4" * 4,
1745 b"foo4" * 4,
1713 b"foo5" * 5,
1746 b"foo5" * 5,
1714 ]
1747 ]
1715
1748
1716 frames = [cctx.compress(c) for c in original]
1749 frames = [cctx.compress(c) for c in original]
1717
1750
1718 b = b"".join([original[0], original[1]])
1751 b = b"".join([original[0], original[1]])
1719 b1 = zstd.BufferWithSegments(
1752 b1 = zstd.BufferWithSegments(
1720 b,
1753 b,
1721 struct.pack(
1754 struct.pack(
1722 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
1755 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
1723 ),
1756 ),
1724 )
1757 )
1725 b = b"".join([original[2], original[3], original[4]])
1758 b = b"".join([original[2], original[3], original[4]])
1726 b2 = zstd.BufferWithSegments(
1759 b2 = zstd.BufferWithSegments(
1727 b,
1760 b,
1728 struct.pack(
1761 struct.pack(
1729 "=QQQQQQ",
1762 "=QQQQQQ",
1730 0,
1763 0,
1731 len(original[2]),
1764 len(original[2]),
1732 len(original[2]),
1765 len(original[2]),
1733 len(original[3]),
1766 len(original[3]),
1734 len(original[2]) + len(original[3]),
1767 len(original[2]) + len(original[3]),
1735 len(original[4]),
1768 len(original[4]),
1736 ),
1769 ),
1737 )
1770 )
1738
1771
1739 c = zstd.BufferWithSegmentsCollection(b1, b2)
1772 c = zstd.BufferWithSegmentsCollection(b1, b2)
1740
1773
1741 result = cctx.multi_compress_to_buffer(c)
1774 result = cctx.multi_compress_to_buffer(c)
1742
1775
1743 self.assertEqual(len(result), len(frames))
1776 self.assertEqual(len(result), len(frames))
1744
1777
1745 for i, frame in enumerate(frames):
1778 for i, frame in enumerate(frames):
1746 self.assertEqual(result[i].tobytes(), frame)
1779 self.assertEqual(result[i].tobytes(), frame)
1747
1780
1748 def test_multiple_threads(self):
1781 def test_multiple_threads(self):
1749 # threads argument will cause multi-threaded ZSTD APIs to be used, which will
1782 # threads argument will cause multi-threaded ZSTD APIs to be used, which will
1750 # make output different.
1783 # make output different.
1751 refcctx = zstd.ZstdCompressor(write_checksum=True)
1784 refcctx = zstd.ZstdCompressor(write_checksum=True)
1752 reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)]
1785 reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)]
1753
1786
1754 cctx = zstd.ZstdCompressor(write_checksum=True)
1787 cctx = zstd.ZstdCompressor(write_checksum=True)
1755
1788
1756 if not hasattr(cctx, "multi_compress_to_buffer"):
1789 if not hasattr(cctx, "multi_compress_to_buffer"):
1757 self.skipTest("multi_compress_to_buffer not available")
1790 self.skipTest("multi_compress_to_buffer not available")
1758
1791
1759 frames = []
1792 frames = []
1760 frames.extend(b"x" * 64 for i in range(256))
1793 frames.extend(b"x" * 64 for i in range(256))
1761 frames.extend(b"y" * 64 for i in range(256))
1794 frames.extend(b"y" * 64 for i in range(256))
1762
1795
1763 result = cctx.multi_compress_to_buffer(frames, threads=-1)
1796 result = cctx.multi_compress_to_buffer(frames, threads=-1)
1764
1797
1765 self.assertEqual(len(result), 512)
1798 self.assertEqual(len(result), 512)
1766 for i in range(512):
1799 for i in range(512):
1767 if i < 256:
1800 if i < 256:
1768 self.assertEqual(result[i].tobytes(), reference[0])
1801 self.assertEqual(result[i].tobytes(), reference[0])
1769 else:
1802 else:
1770 self.assertEqual(result[i].tobytes(), reference[1])
1803 self.assertEqual(result[i].tobytes(), reference[1])
@@ -1,836 +1,884 b''
1 import io
1 import io
2 import os
2 import os
3 import unittest
3 import unittest
4
4
5 try:
5 try:
6 import hypothesis
6 import hypothesis
7 import hypothesis.strategies as strategies
7 import hypothesis.strategies as strategies
8 except ImportError:
8 except ImportError:
9 raise unittest.SkipTest("hypothesis not available")
9 raise unittest.SkipTest("hypothesis not available")
10
10
11 import zstandard as zstd
11 import zstandard as zstd
12
12
13 from .common import (
13 from .common import (
14 make_cffi,
14 make_cffi,
15 NonClosingBytesIO,
15 NonClosingBytesIO,
16 random_input_data,
16 random_input_data,
17 TestCase,
17 TestCase,
18 )
18 )
19
19
20
20
21 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
21 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
22 @make_cffi
22 @make_cffi
23 class TestCompressor_stream_reader_fuzzing(TestCase):
23 class TestCompressor_stream_reader_fuzzing(TestCase):
24 @hypothesis.settings(
24 @hypothesis.settings(
25 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
25 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
26 )
26 )
27 @hypothesis.given(
27 @hypothesis.given(
28 original=strategies.sampled_from(random_input_data()),
28 original=strategies.sampled_from(random_input_data()),
29 level=strategies.integers(min_value=1, max_value=5),
29 level=strategies.integers(min_value=1, max_value=5),
30 source_read_size=strategies.integers(1, 16384),
30 source_read_size=strategies.integers(1, 16384),
31 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
31 read_size=strategies.integers(
32 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
33 ),
32 )
34 )
33 def test_stream_source_read(self, original, level, source_read_size, read_size):
35 def test_stream_source_read(
36 self, original, level, source_read_size, read_size
37 ):
34 if read_size == 0:
38 if read_size == 0:
35 read_size = -1
39 read_size = -1
36
40
37 refctx = zstd.ZstdCompressor(level=level)
41 refctx = zstd.ZstdCompressor(level=level)
38 ref_frame = refctx.compress(original)
42 ref_frame = refctx.compress(original)
39
43
40 cctx = zstd.ZstdCompressor(level=level)
44 cctx = zstd.ZstdCompressor(level=level)
41 with cctx.stream_reader(
45 with cctx.stream_reader(
42 io.BytesIO(original), size=len(original), read_size=source_read_size
46 io.BytesIO(original), size=len(original), read_size=source_read_size
43 ) as reader:
47 ) as reader:
44 chunks = []
48 chunks = []
45 while True:
49 while True:
46 chunk = reader.read(read_size)
50 chunk = reader.read(read_size)
47 if not chunk:
51 if not chunk:
48 break
52 break
49
53
50 chunks.append(chunk)
54 chunks.append(chunk)
51
55
52 self.assertEqual(b"".join(chunks), ref_frame)
56 self.assertEqual(b"".join(chunks), ref_frame)
53
57
54 @hypothesis.settings(
58 @hypothesis.settings(
55 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
59 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
56 )
60 )
57 @hypothesis.given(
61 @hypothesis.given(
58 original=strategies.sampled_from(random_input_data()),
62 original=strategies.sampled_from(random_input_data()),
59 level=strategies.integers(min_value=1, max_value=5),
63 level=strategies.integers(min_value=1, max_value=5),
60 source_read_size=strategies.integers(1, 16384),
64 source_read_size=strategies.integers(1, 16384),
61 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
65 read_size=strategies.integers(
66 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
67 ),
62 )
68 )
63 def test_buffer_source_read(self, original, level, source_read_size, read_size):
69 def test_buffer_source_read(
70 self, original, level, source_read_size, read_size
71 ):
64 if read_size == 0:
72 if read_size == 0:
65 read_size = -1
73 read_size = -1
66
74
67 refctx = zstd.ZstdCompressor(level=level)
75 refctx = zstd.ZstdCompressor(level=level)
68 ref_frame = refctx.compress(original)
76 ref_frame = refctx.compress(original)
69
77
70 cctx = zstd.ZstdCompressor(level=level)
78 cctx = zstd.ZstdCompressor(level=level)
71 with cctx.stream_reader(
79 with cctx.stream_reader(
72 original, size=len(original), read_size=source_read_size
80 original, size=len(original), read_size=source_read_size
73 ) as reader:
81 ) as reader:
74 chunks = []
82 chunks = []
75 while True:
83 while True:
76 chunk = reader.read(read_size)
84 chunk = reader.read(read_size)
77 if not chunk:
85 if not chunk:
78 break
86 break
79
87
80 chunks.append(chunk)
88 chunks.append(chunk)
81
89
82 self.assertEqual(b"".join(chunks), ref_frame)
90 self.assertEqual(b"".join(chunks), ref_frame)
83
91
84 @hypothesis.settings(
92 @hypothesis.settings(
85 suppress_health_check=[
93 suppress_health_check=[
86 hypothesis.HealthCheck.large_base_example,
94 hypothesis.HealthCheck.large_base_example,
87 hypothesis.HealthCheck.too_slow,
95 hypothesis.HealthCheck.too_slow,
88 ]
96 ]
89 )
97 )
90 @hypothesis.given(
98 @hypothesis.given(
91 original=strategies.sampled_from(random_input_data()),
99 original=strategies.sampled_from(random_input_data()),
92 level=strategies.integers(min_value=1, max_value=5),
100 level=strategies.integers(min_value=1, max_value=5),
93 source_read_size=strategies.integers(1, 16384),
101 source_read_size=strategies.integers(1, 16384),
94 read_sizes=strategies.data(),
102 read_sizes=strategies.data(),
95 )
103 )
96 def test_stream_source_read_variance(
104 def test_stream_source_read_variance(
97 self, original, level, source_read_size, read_sizes
105 self, original, level, source_read_size, read_sizes
98 ):
106 ):
99 refctx = zstd.ZstdCompressor(level=level)
107 refctx = zstd.ZstdCompressor(level=level)
100 ref_frame = refctx.compress(original)
108 ref_frame = refctx.compress(original)
101
109
102 cctx = zstd.ZstdCompressor(level=level)
110 cctx = zstd.ZstdCompressor(level=level)
103 with cctx.stream_reader(
111 with cctx.stream_reader(
104 io.BytesIO(original), size=len(original), read_size=source_read_size
112 io.BytesIO(original), size=len(original), read_size=source_read_size
105 ) as reader:
113 ) as reader:
106 chunks = []
114 chunks = []
107 while True:
115 while True:
108 read_size = read_sizes.draw(strategies.integers(-1, 16384))
116 read_size = read_sizes.draw(strategies.integers(-1, 16384))
109 chunk = reader.read(read_size)
117 chunk = reader.read(read_size)
110 if not chunk and read_size:
118 if not chunk and read_size:
111 break
119 break
112
120
113 chunks.append(chunk)
121 chunks.append(chunk)
114
122
115 self.assertEqual(b"".join(chunks), ref_frame)
123 self.assertEqual(b"".join(chunks), ref_frame)
116
124
117 @hypothesis.settings(
125 @hypothesis.settings(
118 suppress_health_check=[
126 suppress_health_check=[
119 hypothesis.HealthCheck.large_base_example,
127 hypothesis.HealthCheck.large_base_example,
120 hypothesis.HealthCheck.too_slow,
128 hypothesis.HealthCheck.too_slow,
121 ]
129 ]
122 )
130 )
123 @hypothesis.given(
131 @hypothesis.given(
124 original=strategies.sampled_from(random_input_data()),
132 original=strategies.sampled_from(random_input_data()),
125 level=strategies.integers(min_value=1, max_value=5),
133 level=strategies.integers(min_value=1, max_value=5),
126 source_read_size=strategies.integers(1, 16384),
134 source_read_size=strategies.integers(1, 16384),
127 read_sizes=strategies.data(),
135 read_sizes=strategies.data(),
128 )
136 )
129 def test_buffer_source_read_variance(
137 def test_buffer_source_read_variance(
130 self, original, level, source_read_size, read_sizes
138 self, original, level, source_read_size, read_sizes
131 ):
139 ):
132
140
133 refctx = zstd.ZstdCompressor(level=level)
141 refctx = zstd.ZstdCompressor(level=level)
134 ref_frame = refctx.compress(original)
142 ref_frame = refctx.compress(original)
135
143
136 cctx = zstd.ZstdCompressor(level=level)
144 cctx = zstd.ZstdCompressor(level=level)
137 with cctx.stream_reader(
145 with cctx.stream_reader(
138 original, size=len(original), read_size=source_read_size
146 original, size=len(original), read_size=source_read_size
139 ) as reader:
147 ) as reader:
140 chunks = []
148 chunks = []
141 while True:
149 while True:
142 read_size = read_sizes.draw(strategies.integers(-1, 16384))
150 read_size = read_sizes.draw(strategies.integers(-1, 16384))
143 chunk = reader.read(read_size)
151 chunk = reader.read(read_size)
144 if not chunk and read_size:
152 if not chunk and read_size:
145 break
153 break
146
154
147 chunks.append(chunk)
155 chunks.append(chunk)
148
156
149 self.assertEqual(b"".join(chunks), ref_frame)
157 self.assertEqual(b"".join(chunks), ref_frame)
150
158
151 @hypothesis.settings(
159 @hypothesis.settings(
152 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
160 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
153 )
161 )
154 @hypothesis.given(
162 @hypothesis.given(
155 original=strategies.sampled_from(random_input_data()),
163 original=strategies.sampled_from(random_input_data()),
156 level=strategies.integers(min_value=1, max_value=5),
164 level=strategies.integers(min_value=1, max_value=5),
157 source_read_size=strategies.integers(1, 16384),
165 source_read_size=strategies.integers(1, 16384),
158 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
166 read_size=strategies.integers(
167 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
168 ),
159 )
169 )
160 def test_stream_source_readinto(self, original, level, source_read_size, read_size):
170 def test_stream_source_readinto(
171 self, original, level, source_read_size, read_size
172 ):
161 refctx = zstd.ZstdCompressor(level=level)
173 refctx = zstd.ZstdCompressor(level=level)
162 ref_frame = refctx.compress(original)
174 ref_frame = refctx.compress(original)
163
175
164 cctx = zstd.ZstdCompressor(level=level)
176 cctx = zstd.ZstdCompressor(level=level)
165 with cctx.stream_reader(
177 with cctx.stream_reader(
166 io.BytesIO(original), size=len(original), read_size=source_read_size
178 io.BytesIO(original), size=len(original), read_size=source_read_size
167 ) as reader:
179 ) as reader:
168 chunks = []
180 chunks = []
169 while True:
181 while True:
170 b = bytearray(read_size)
182 b = bytearray(read_size)
171 count = reader.readinto(b)
183 count = reader.readinto(b)
172
184
173 if not count:
185 if not count:
174 break
186 break
175
187
176 chunks.append(bytes(b[0:count]))
188 chunks.append(bytes(b[0:count]))
177
189
178 self.assertEqual(b"".join(chunks), ref_frame)
190 self.assertEqual(b"".join(chunks), ref_frame)
179
191
180 @hypothesis.settings(
192 @hypothesis.settings(
181 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
193 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
182 )
194 )
183 @hypothesis.given(
195 @hypothesis.given(
184 original=strategies.sampled_from(random_input_data()),
196 original=strategies.sampled_from(random_input_data()),
185 level=strategies.integers(min_value=1, max_value=5),
197 level=strategies.integers(min_value=1, max_value=5),
186 source_read_size=strategies.integers(1, 16384),
198 source_read_size=strategies.integers(1, 16384),
187 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
199 read_size=strategies.integers(
200 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
201 ),
188 )
202 )
189 def test_buffer_source_readinto(self, original, level, source_read_size, read_size):
203 def test_buffer_source_readinto(
204 self, original, level, source_read_size, read_size
205 ):
190
206
191 refctx = zstd.ZstdCompressor(level=level)
207 refctx = zstd.ZstdCompressor(level=level)
192 ref_frame = refctx.compress(original)
208 ref_frame = refctx.compress(original)
193
209
194 cctx = zstd.ZstdCompressor(level=level)
210 cctx = zstd.ZstdCompressor(level=level)
195 with cctx.stream_reader(
211 with cctx.stream_reader(
196 original, size=len(original), read_size=source_read_size
212 original, size=len(original), read_size=source_read_size
197 ) as reader:
213 ) as reader:
198 chunks = []
214 chunks = []
199 while True:
215 while True:
200 b = bytearray(read_size)
216 b = bytearray(read_size)
201 count = reader.readinto(b)
217 count = reader.readinto(b)
202
218
203 if not count:
219 if not count:
204 break
220 break
205
221
206 chunks.append(bytes(b[0:count]))
222 chunks.append(bytes(b[0:count]))
207
223
208 self.assertEqual(b"".join(chunks), ref_frame)
224 self.assertEqual(b"".join(chunks), ref_frame)
209
225
210 @hypothesis.settings(
226 @hypothesis.settings(
211 suppress_health_check=[
227 suppress_health_check=[
212 hypothesis.HealthCheck.large_base_example,
228 hypothesis.HealthCheck.large_base_example,
213 hypothesis.HealthCheck.too_slow,
229 hypothesis.HealthCheck.too_slow,
214 ]
230 ]
215 )
231 )
216 @hypothesis.given(
232 @hypothesis.given(
217 original=strategies.sampled_from(random_input_data()),
233 original=strategies.sampled_from(random_input_data()),
218 level=strategies.integers(min_value=1, max_value=5),
234 level=strategies.integers(min_value=1, max_value=5),
219 source_read_size=strategies.integers(1, 16384),
235 source_read_size=strategies.integers(1, 16384),
220 read_sizes=strategies.data(),
236 read_sizes=strategies.data(),
221 )
237 )
222 def test_stream_source_readinto_variance(
238 def test_stream_source_readinto_variance(
223 self, original, level, source_read_size, read_sizes
239 self, original, level, source_read_size, read_sizes
224 ):
240 ):
225 refctx = zstd.ZstdCompressor(level=level)
241 refctx = zstd.ZstdCompressor(level=level)
226 ref_frame = refctx.compress(original)
242 ref_frame = refctx.compress(original)
227
243
228 cctx = zstd.ZstdCompressor(level=level)
244 cctx = zstd.ZstdCompressor(level=level)
229 with cctx.stream_reader(
245 with cctx.stream_reader(
230 io.BytesIO(original), size=len(original), read_size=source_read_size
246 io.BytesIO(original), size=len(original), read_size=source_read_size
231 ) as reader:
247 ) as reader:
232 chunks = []
248 chunks = []
233 while True:
249 while True:
234 read_size = read_sizes.draw(strategies.integers(1, 16384))
250 read_size = read_sizes.draw(strategies.integers(1, 16384))
235 b = bytearray(read_size)
251 b = bytearray(read_size)
236 count = reader.readinto(b)
252 count = reader.readinto(b)
237
253
238 if not count:
254 if not count:
239 break
255 break
240
256
241 chunks.append(bytes(b[0:count]))
257 chunks.append(bytes(b[0:count]))
242
258
243 self.assertEqual(b"".join(chunks), ref_frame)
259 self.assertEqual(b"".join(chunks), ref_frame)
244
260
245 @hypothesis.settings(
261 @hypothesis.settings(
246 suppress_health_check=[
262 suppress_health_check=[
247 hypothesis.HealthCheck.large_base_example,
263 hypothesis.HealthCheck.large_base_example,
248 hypothesis.HealthCheck.too_slow,
264 hypothesis.HealthCheck.too_slow,
249 ]
265 ]
250 )
266 )
251 @hypothesis.given(
267 @hypothesis.given(
252 original=strategies.sampled_from(random_input_data()),
268 original=strategies.sampled_from(random_input_data()),
253 level=strategies.integers(min_value=1, max_value=5),
269 level=strategies.integers(min_value=1, max_value=5),
254 source_read_size=strategies.integers(1, 16384),
270 source_read_size=strategies.integers(1, 16384),
255 read_sizes=strategies.data(),
271 read_sizes=strategies.data(),
256 )
272 )
257 def test_buffer_source_readinto_variance(
273 def test_buffer_source_readinto_variance(
258 self, original, level, source_read_size, read_sizes
274 self, original, level, source_read_size, read_sizes
259 ):
275 ):
260
276
261 refctx = zstd.ZstdCompressor(level=level)
277 refctx = zstd.ZstdCompressor(level=level)
262 ref_frame = refctx.compress(original)
278 ref_frame = refctx.compress(original)
263
279
264 cctx = zstd.ZstdCompressor(level=level)
280 cctx = zstd.ZstdCompressor(level=level)
265 with cctx.stream_reader(
281 with cctx.stream_reader(
266 original, size=len(original), read_size=source_read_size
282 original, size=len(original), read_size=source_read_size
267 ) as reader:
283 ) as reader:
268 chunks = []
284 chunks = []
269 while True:
285 while True:
270 read_size = read_sizes.draw(strategies.integers(1, 16384))
286 read_size = read_sizes.draw(strategies.integers(1, 16384))
271 b = bytearray(read_size)
287 b = bytearray(read_size)
272 count = reader.readinto(b)
288 count = reader.readinto(b)
273
289
274 if not count:
290 if not count:
275 break
291 break
276
292
277 chunks.append(bytes(b[0:count]))
293 chunks.append(bytes(b[0:count]))
278
294
279 self.assertEqual(b"".join(chunks), ref_frame)
295 self.assertEqual(b"".join(chunks), ref_frame)
280
296
281 @hypothesis.settings(
297 @hypothesis.settings(
282 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
298 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
283 )
299 )
284 @hypothesis.given(
300 @hypothesis.given(
285 original=strategies.sampled_from(random_input_data()),
301 original=strategies.sampled_from(random_input_data()),
286 level=strategies.integers(min_value=1, max_value=5),
302 level=strategies.integers(min_value=1, max_value=5),
287 source_read_size=strategies.integers(1, 16384),
303 source_read_size=strategies.integers(1, 16384),
288 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
304 read_size=strategies.integers(
305 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
306 ),
289 )
307 )
290 def test_stream_source_read1(self, original, level, source_read_size, read_size):
308 def test_stream_source_read1(
309 self, original, level, source_read_size, read_size
310 ):
291 if read_size == 0:
311 if read_size == 0:
292 read_size = -1
312 read_size = -1
293
313
294 refctx = zstd.ZstdCompressor(level=level)
314 refctx = zstd.ZstdCompressor(level=level)
295 ref_frame = refctx.compress(original)
315 ref_frame = refctx.compress(original)
296
316
297 cctx = zstd.ZstdCompressor(level=level)
317 cctx = zstd.ZstdCompressor(level=level)
298 with cctx.stream_reader(
318 with cctx.stream_reader(
299 io.BytesIO(original), size=len(original), read_size=source_read_size
319 io.BytesIO(original), size=len(original), read_size=source_read_size
300 ) as reader:
320 ) as reader:
301 chunks = []
321 chunks = []
302 while True:
322 while True:
303 chunk = reader.read1(read_size)
323 chunk = reader.read1(read_size)
304 if not chunk:
324 if not chunk:
305 break
325 break
306
326
307 chunks.append(chunk)
327 chunks.append(chunk)
308
328
309 self.assertEqual(b"".join(chunks), ref_frame)
329 self.assertEqual(b"".join(chunks), ref_frame)
310
330
311 @hypothesis.settings(
331 @hypothesis.settings(
312 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
332 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
313 )
333 )
314 @hypothesis.given(
334 @hypothesis.given(
315 original=strategies.sampled_from(random_input_data()),
335 original=strategies.sampled_from(random_input_data()),
316 level=strategies.integers(min_value=1, max_value=5),
336 level=strategies.integers(min_value=1, max_value=5),
317 source_read_size=strategies.integers(1, 16384),
337 source_read_size=strategies.integers(1, 16384),
318 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
338 read_size=strategies.integers(
339 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
340 ),
319 )
341 )
320 def test_buffer_source_read1(self, original, level, source_read_size, read_size):
342 def test_buffer_source_read1(
343 self, original, level, source_read_size, read_size
344 ):
321 if read_size == 0:
345 if read_size == 0:
322 read_size = -1
346 read_size = -1
323
347
324 refctx = zstd.ZstdCompressor(level=level)
348 refctx = zstd.ZstdCompressor(level=level)
325 ref_frame = refctx.compress(original)
349 ref_frame = refctx.compress(original)
326
350
327 cctx = zstd.ZstdCompressor(level=level)
351 cctx = zstd.ZstdCompressor(level=level)
328 with cctx.stream_reader(
352 with cctx.stream_reader(
329 original, size=len(original), read_size=source_read_size
353 original, size=len(original), read_size=source_read_size
330 ) as reader:
354 ) as reader:
331 chunks = []
355 chunks = []
332 while True:
356 while True:
333 chunk = reader.read1(read_size)
357 chunk = reader.read1(read_size)
334 if not chunk:
358 if not chunk:
335 break
359 break
336
360
337 chunks.append(chunk)
361 chunks.append(chunk)
338
362
339 self.assertEqual(b"".join(chunks), ref_frame)
363 self.assertEqual(b"".join(chunks), ref_frame)
340
364
341 @hypothesis.settings(
365 @hypothesis.settings(
342 suppress_health_check=[
366 suppress_health_check=[
343 hypothesis.HealthCheck.large_base_example,
367 hypothesis.HealthCheck.large_base_example,
344 hypothesis.HealthCheck.too_slow,
368 hypothesis.HealthCheck.too_slow,
345 ]
369 ]
346 )
370 )
347 @hypothesis.given(
371 @hypothesis.given(
348 original=strategies.sampled_from(random_input_data()),
372 original=strategies.sampled_from(random_input_data()),
349 level=strategies.integers(min_value=1, max_value=5),
373 level=strategies.integers(min_value=1, max_value=5),
350 source_read_size=strategies.integers(1, 16384),
374 source_read_size=strategies.integers(1, 16384),
351 read_sizes=strategies.data(),
375 read_sizes=strategies.data(),
352 )
376 )
353 def test_stream_source_read1_variance(
377 def test_stream_source_read1_variance(
354 self, original, level, source_read_size, read_sizes
378 self, original, level, source_read_size, read_sizes
355 ):
379 ):
356 refctx = zstd.ZstdCompressor(level=level)
380 refctx = zstd.ZstdCompressor(level=level)
357 ref_frame = refctx.compress(original)
381 ref_frame = refctx.compress(original)
358
382
359 cctx = zstd.ZstdCompressor(level=level)
383 cctx = zstd.ZstdCompressor(level=level)
360 with cctx.stream_reader(
384 with cctx.stream_reader(
361 io.BytesIO(original), size=len(original), read_size=source_read_size
385 io.BytesIO(original), size=len(original), read_size=source_read_size
362 ) as reader:
386 ) as reader:
363 chunks = []
387 chunks = []
364 while True:
388 while True:
365 read_size = read_sizes.draw(strategies.integers(-1, 16384))
389 read_size = read_sizes.draw(strategies.integers(-1, 16384))
366 chunk = reader.read1(read_size)
390 chunk = reader.read1(read_size)
367 if not chunk and read_size:
391 if not chunk and read_size:
368 break
392 break
369
393
370 chunks.append(chunk)
394 chunks.append(chunk)
371
395
372 self.assertEqual(b"".join(chunks), ref_frame)
396 self.assertEqual(b"".join(chunks), ref_frame)
373
397
374 @hypothesis.settings(
398 @hypothesis.settings(
375 suppress_health_check=[
399 suppress_health_check=[
376 hypothesis.HealthCheck.large_base_example,
400 hypothesis.HealthCheck.large_base_example,
377 hypothesis.HealthCheck.too_slow,
401 hypothesis.HealthCheck.too_slow,
378 ]
402 ]
379 )
403 )
380 @hypothesis.given(
404 @hypothesis.given(
381 original=strategies.sampled_from(random_input_data()),
405 original=strategies.sampled_from(random_input_data()),
382 level=strategies.integers(min_value=1, max_value=5),
406 level=strategies.integers(min_value=1, max_value=5),
383 source_read_size=strategies.integers(1, 16384),
407 source_read_size=strategies.integers(1, 16384),
384 read_sizes=strategies.data(),
408 read_sizes=strategies.data(),
385 )
409 )
386 def test_buffer_source_read1_variance(
410 def test_buffer_source_read1_variance(
387 self, original, level, source_read_size, read_sizes
411 self, original, level, source_read_size, read_sizes
388 ):
412 ):
389
413
390 refctx = zstd.ZstdCompressor(level=level)
414 refctx = zstd.ZstdCompressor(level=level)
391 ref_frame = refctx.compress(original)
415 ref_frame = refctx.compress(original)
392
416
393 cctx = zstd.ZstdCompressor(level=level)
417 cctx = zstd.ZstdCompressor(level=level)
394 with cctx.stream_reader(
418 with cctx.stream_reader(
395 original, size=len(original), read_size=source_read_size
419 original, size=len(original), read_size=source_read_size
396 ) as reader:
420 ) as reader:
397 chunks = []
421 chunks = []
398 while True:
422 while True:
399 read_size = read_sizes.draw(strategies.integers(-1, 16384))
423 read_size = read_sizes.draw(strategies.integers(-1, 16384))
400 chunk = reader.read1(read_size)
424 chunk = reader.read1(read_size)
401 if not chunk and read_size:
425 if not chunk and read_size:
402 break
426 break
403
427
404 chunks.append(chunk)
428 chunks.append(chunk)
405
429
406 self.assertEqual(b"".join(chunks), ref_frame)
430 self.assertEqual(b"".join(chunks), ref_frame)
407
431
408 @hypothesis.settings(
432 @hypothesis.settings(
409 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
433 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
410 )
434 )
411 @hypothesis.given(
435 @hypothesis.given(
412 original=strategies.sampled_from(random_input_data()),
436 original=strategies.sampled_from(random_input_data()),
413 level=strategies.integers(min_value=1, max_value=5),
437 level=strategies.integers(min_value=1, max_value=5),
414 source_read_size=strategies.integers(1, 16384),
438 source_read_size=strategies.integers(1, 16384),
415 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
439 read_size=strategies.integers(
440 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
441 ),
416 )
442 )
417 def test_stream_source_readinto1(
443 def test_stream_source_readinto1(
418 self, original, level, source_read_size, read_size
444 self, original, level, source_read_size, read_size
419 ):
445 ):
420 if read_size == 0:
446 if read_size == 0:
421 read_size = -1
447 read_size = -1
422
448
423 refctx = zstd.ZstdCompressor(level=level)
449 refctx = zstd.ZstdCompressor(level=level)
424 ref_frame = refctx.compress(original)
450 ref_frame = refctx.compress(original)
425
451
426 cctx = zstd.ZstdCompressor(level=level)
452 cctx = zstd.ZstdCompressor(level=level)
427 with cctx.stream_reader(
453 with cctx.stream_reader(
428 io.BytesIO(original), size=len(original), read_size=source_read_size
454 io.BytesIO(original), size=len(original), read_size=source_read_size
429 ) as reader:
455 ) as reader:
430 chunks = []
456 chunks = []
431 while True:
457 while True:
432 b = bytearray(read_size)
458 b = bytearray(read_size)
433 count = reader.readinto1(b)
459 count = reader.readinto1(b)
434
460
435 if not count:
461 if not count:
436 break
462 break
437
463
438 chunks.append(bytes(b[0:count]))
464 chunks.append(bytes(b[0:count]))
439
465
440 self.assertEqual(b"".join(chunks), ref_frame)
466 self.assertEqual(b"".join(chunks), ref_frame)
441
467
442 @hypothesis.settings(
468 @hypothesis.settings(
443 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
469 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
444 )
470 )
445 @hypothesis.given(
471 @hypothesis.given(
446 original=strategies.sampled_from(random_input_data()),
472 original=strategies.sampled_from(random_input_data()),
447 level=strategies.integers(min_value=1, max_value=5),
473 level=strategies.integers(min_value=1, max_value=5),
448 source_read_size=strategies.integers(1, 16384),
474 source_read_size=strategies.integers(1, 16384),
449 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
475 read_size=strategies.integers(
476 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
477 ),
450 )
478 )
451 def test_buffer_source_readinto1(
479 def test_buffer_source_readinto1(
452 self, original, level, source_read_size, read_size
480 self, original, level, source_read_size, read_size
453 ):
481 ):
454 if read_size == 0:
482 if read_size == 0:
455 read_size = -1
483 read_size = -1
456
484
457 refctx = zstd.ZstdCompressor(level=level)
485 refctx = zstd.ZstdCompressor(level=level)
458 ref_frame = refctx.compress(original)
486 ref_frame = refctx.compress(original)
459
487
460 cctx = zstd.ZstdCompressor(level=level)
488 cctx = zstd.ZstdCompressor(level=level)
461 with cctx.stream_reader(
489 with cctx.stream_reader(
462 original, size=len(original), read_size=source_read_size
490 original, size=len(original), read_size=source_read_size
463 ) as reader:
491 ) as reader:
464 chunks = []
492 chunks = []
465 while True:
493 while True:
466 b = bytearray(read_size)
494 b = bytearray(read_size)
467 count = reader.readinto1(b)
495 count = reader.readinto1(b)
468
496
469 if not count:
497 if not count:
470 break
498 break
471
499
472 chunks.append(bytes(b[0:count]))
500 chunks.append(bytes(b[0:count]))
473
501
474 self.assertEqual(b"".join(chunks), ref_frame)
502 self.assertEqual(b"".join(chunks), ref_frame)
475
503
476 @hypothesis.settings(
504 @hypothesis.settings(
477 suppress_health_check=[
505 suppress_health_check=[
478 hypothesis.HealthCheck.large_base_example,
506 hypothesis.HealthCheck.large_base_example,
479 hypothesis.HealthCheck.too_slow,
507 hypothesis.HealthCheck.too_slow,
480 ]
508 ]
481 )
509 )
482 @hypothesis.given(
510 @hypothesis.given(
483 original=strategies.sampled_from(random_input_data()),
511 original=strategies.sampled_from(random_input_data()),
484 level=strategies.integers(min_value=1, max_value=5),
512 level=strategies.integers(min_value=1, max_value=5),
485 source_read_size=strategies.integers(1, 16384),
513 source_read_size=strategies.integers(1, 16384),
486 read_sizes=strategies.data(),
514 read_sizes=strategies.data(),
487 )
515 )
488 def test_stream_source_readinto1_variance(
516 def test_stream_source_readinto1_variance(
489 self, original, level, source_read_size, read_sizes
517 self, original, level, source_read_size, read_sizes
490 ):
518 ):
491 refctx = zstd.ZstdCompressor(level=level)
519 refctx = zstd.ZstdCompressor(level=level)
492 ref_frame = refctx.compress(original)
520 ref_frame = refctx.compress(original)
493
521
494 cctx = zstd.ZstdCompressor(level=level)
522 cctx = zstd.ZstdCompressor(level=level)
495 with cctx.stream_reader(
523 with cctx.stream_reader(
496 io.BytesIO(original), size=len(original), read_size=source_read_size
524 io.BytesIO(original), size=len(original), read_size=source_read_size
497 ) as reader:
525 ) as reader:
498 chunks = []
526 chunks = []
499 while True:
527 while True:
500 read_size = read_sizes.draw(strategies.integers(1, 16384))
528 read_size = read_sizes.draw(strategies.integers(1, 16384))
501 b = bytearray(read_size)
529 b = bytearray(read_size)
502 count = reader.readinto1(b)
530 count = reader.readinto1(b)
503
531
504 if not count:
532 if not count:
505 break
533 break
506
534
507 chunks.append(bytes(b[0:count]))
535 chunks.append(bytes(b[0:count]))
508
536
509 self.assertEqual(b"".join(chunks), ref_frame)
537 self.assertEqual(b"".join(chunks), ref_frame)
510
538
511 @hypothesis.settings(
539 @hypothesis.settings(
512 suppress_health_check=[
540 suppress_health_check=[
513 hypothesis.HealthCheck.large_base_example,
541 hypothesis.HealthCheck.large_base_example,
514 hypothesis.HealthCheck.too_slow,
542 hypothesis.HealthCheck.too_slow,
515 ]
543 ]
516 )
544 )
517 @hypothesis.given(
545 @hypothesis.given(
518 original=strategies.sampled_from(random_input_data()),
546 original=strategies.sampled_from(random_input_data()),
519 level=strategies.integers(min_value=1, max_value=5),
547 level=strategies.integers(min_value=1, max_value=5),
520 source_read_size=strategies.integers(1, 16384),
548 source_read_size=strategies.integers(1, 16384),
521 read_sizes=strategies.data(),
549 read_sizes=strategies.data(),
522 )
550 )
523 def test_buffer_source_readinto1_variance(
551 def test_buffer_source_readinto1_variance(
524 self, original, level, source_read_size, read_sizes
552 self, original, level, source_read_size, read_sizes
525 ):
553 ):
526
554
527 refctx = zstd.ZstdCompressor(level=level)
555 refctx = zstd.ZstdCompressor(level=level)
528 ref_frame = refctx.compress(original)
556 ref_frame = refctx.compress(original)
529
557
530 cctx = zstd.ZstdCompressor(level=level)
558 cctx = zstd.ZstdCompressor(level=level)
531 with cctx.stream_reader(
559 with cctx.stream_reader(
532 original, size=len(original), read_size=source_read_size
560 original, size=len(original), read_size=source_read_size
533 ) as reader:
561 ) as reader:
534 chunks = []
562 chunks = []
535 while True:
563 while True:
536 read_size = read_sizes.draw(strategies.integers(1, 16384))
564 read_size = read_sizes.draw(strategies.integers(1, 16384))
537 b = bytearray(read_size)
565 b = bytearray(read_size)
538 count = reader.readinto1(b)
566 count = reader.readinto1(b)
539
567
540 if not count:
568 if not count:
541 break
569 break
542
570
543 chunks.append(bytes(b[0:count]))
571 chunks.append(bytes(b[0:count]))
544
572
545 self.assertEqual(b"".join(chunks), ref_frame)
573 self.assertEqual(b"".join(chunks), ref_frame)
546
574
547
575
548 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
576 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
549 @make_cffi
577 @make_cffi
550 class TestCompressor_stream_writer_fuzzing(TestCase):
578 class TestCompressor_stream_writer_fuzzing(TestCase):
551 @hypothesis.given(
579 @hypothesis.given(
552 original=strategies.sampled_from(random_input_data()),
580 original=strategies.sampled_from(random_input_data()),
553 level=strategies.integers(min_value=1, max_value=5),
581 level=strategies.integers(min_value=1, max_value=5),
554 write_size=strategies.integers(min_value=1, max_value=1048576),
582 write_size=strategies.integers(min_value=1, max_value=1048576),
555 )
583 )
556 def test_write_size_variance(self, original, level, write_size):
584 def test_write_size_variance(self, original, level, write_size):
557 refctx = zstd.ZstdCompressor(level=level)
585 refctx = zstd.ZstdCompressor(level=level)
558 ref_frame = refctx.compress(original)
586 ref_frame = refctx.compress(original)
559
587
560 cctx = zstd.ZstdCompressor(level=level)
588 cctx = zstd.ZstdCompressor(level=level)
561 b = NonClosingBytesIO()
589 b = NonClosingBytesIO()
562 with cctx.stream_writer(
590 with cctx.stream_writer(
563 b, size=len(original), write_size=write_size
591 b, size=len(original), write_size=write_size
564 ) as compressor:
592 ) as compressor:
565 compressor.write(original)
593 compressor.write(original)
566
594
567 self.assertEqual(b.getvalue(), ref_frame)
595 self.assertEqual(b.getvalue(), ref_frame)
568
596
569
597
570 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
598 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
571 @make_cffi
599 @make_cffi
572 class TestCompressor_copy_stream_fuzzing(TestCase):
600 class TestCompressor_copy_stream_fuzzing(TestCase):
573 @hypothesis.given(
601 @hypothesis.given(
574 original=strategies.sampled_from(random_input_data()),
602 original=strategies.sampled_from(random_input_data()),
575 level=strategies.integers(min_value=1, max_value=5),
603 level=strategies.integers(min_value=1, max_value=5),
576 read_size=strategies.integers(min_value=1, max_value=1048576),
604 read_size=strategies.integers(min_value=1, max_value=1048576),
577 write_size=strategies.integers(min_value=1, max_value=1048576),
605 write_size=strategies.integers(min_value=1, max_value=1048576),
578 )
606 )
579 def test_read_write_size_variance(self, original, level, read_size, write_size):
607 def test_read_write_size_variance(
608 self, original, level, read_size, write_size
609 ):
580 refctx = zstd.ZstdCompressor(level=level)
610 refctx = zstd.ZstdCompressor(level=level)
581 ref_frame = refctx.compress(original)
611 ref_frame = refctx.compress(original)
582
612
583 cctx = zstd.ZstdCompressor(level=level)
613 cctx = zstd.ZstdCompressor(level=level)
584 source = io.BytesIO(original)
614 source = io.BytesIO(original)
585 dest = io.BytesIO()
615 dest = io.BytesIO()
586
616
587 cctx.copy_stream(
617 cctx.copy_stream(
588 source, dest, size=len(original), read_size=read_size, write_size=write_size
618 source,
619 dest,
620 size=len(original),
621 read_size=read_size,
622 write_size=write_size,
589 )
623 )
590
624
591 self.assertEqual(dest.getvalue(), ref_frame)
625 self.assertEqual(dest.getvalue(), ref_frame)
592
626
593
627
594 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
628 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
595 @make_cffi
629 @make_cffi
596 class TestCompressor_compressobj_fuzzing(TestCase):
630 class TestCompressor_compressobj_fuzzing(TestCase):
597 @hypothesis.settings(
631 @hypothesis.settings(
598 suppress_health_check=[
632 suppress_health_check=[
599 hypothesis.HealthCheck.large_base_example,
633 hypothesis.HealthCheck.large_base_example,
600 hypothesis.HealthCheck.too_slow,
634 hypothesis.HealthCheck.too_slow,
601 ]
635 ]
602 )
636 )
603 @hypothesis.given(
637 @hypothesis.given(
604 original=strategies.sampled_from(random_input_data()),
638 original=strategies.sampled_from(random_input_data()),
605 level=strategies.integers(min_value=1, max_value=5),
639 level=strategies.integers(min_value=1, max_value=5),
606 chunk_sizes=strategies.data(),
640 chunk_sizes=strategies.data(),
607 )
641 )
608 def test_random_input_sizes(self, original, level, chunk_sizes):
642 def test_random_input_sizes(self, original, level, chunk_sizes):
609 refctx = zstd.ZstdCompressor(level=level)
643 refctx = zstd.ZstdCompressor(level=level)
610 ref_frame = refctx.compress(original)
644 ref_frame = refctx.compress(original)
611
645
612 cctx = zstd.ZstdCompressor(level=level)
646 cctx = zstd.ZstdCompressor(level=level)
613 cobj = cctx.compressobj(size=len(original))
647 cobj = cctx.compressobj(size=len(original))
614
648
615 chunks = []
649 chunks = []
616 i = 0
650 i = 0
617 while True:
651 while True:
618 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
652 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
619 source = original[i : i + chunk_size]
653 source = original[i : i + chunk_size]
620 if not source:
654 if not source:
621 break
655 break
622
656
623 chunks.append(cobj.compress(source))
657 chunks.append(cobj.compress(source))
624 i += chunk_size
658 i += chunk_size
625
659
626 chunks.append(cobj.flush())
660 chunks.append(cobj.flush())
627
661
628 self.assertEqual(b"".join(chunks), ref_frame)
662 self.assertEqual(b"".join(chunks), ref_frame)
629
663
630 @hypothesis.settings(
664 @hypothesis.settings(
631 suppress_health_check=[
665 suppress_health_check=[
632 hypothesis.HealthCheck.large_base_example,
666 hypothesis.HealthCheck.large_base_example,
633 hypothesis.HealthCheck.too_slow,
667 hypothesis.HealthCheck.too_slow,
634 ]
668 ]
635 )
669 )
636 @hypothesis.given(
670 @hypothesis.given(
637 original=strategies.sampled_from(random_input_data()),
671 original=strategies.sampled_from(random_input_data()),
638 level=strategies.integers(min_value=1, max_value=5),
672 level=strategies.integers(min_value=1, max_value=5),
639 chunk_sizes=strategies.data(),
673 chunk_sizes=strategies.data(),
640 flushes=strategies.data(),
674 flushes=strategies.data(),
641 )
675 )
642 def test_flush_block(self, original, level, chunk_sizes, flushes):
676 def test_flush_block(self, original, level, chunk_sizes, flushes):
643 cctx = zstd.ZstdCompressor(level=level)
677 cctx = zstd.ZstdCompressor(level=level)
644 cobj = cctx.compressobj()
678 cobj = cctx.compressobj()
645
679
646 dctx = zstd.ZstdDecompressor()
680 dctx = zstd.ZstdDecompressor()
647 dobj = dctx.decompressobj()
681 dobj = dctx.decompressobj()
648
682
649 compressed_chunks = []
683 compressed_chunks = []
650 decompressed_chunks = []
684 decompressed_chunks = []
651 i = 0
685 i = 0
652 while True:
686 while True:
653 input_size = chunk_sizes.draw(strategies.integers(1, 4096))
687 input_size = chunk_sizes.draw(strategies.integers(1, 4096))
654 source = original[i : i + input_size]
688 source = original[i : i + input_size]
655 if not source:
689 if not source:
656 break
690 break
657
691
658 i += input_size
692 i += input_size
659
693
660 chunk = cobj.compress(source)
694 chunk = cobj.compress(source)
661 compressed_chunks.append(chunk)
695 compressed_chunks.append(chunk)
662 decompressed_chunks.append(dobj.decompress(chunk))
696 decompressed_chunks.append(dobj.decompress(chunk))
663
697
664 if not flushes.draw(strategies.booleans()):
698 if not flushes.draw(strategies.booleans()):
665 continue
699 continue
666
700
667 chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
701 chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
668 compressed_chunks.append(chunk)
702 compressed_chunks.append(chunk)
669 decompressed_chunks.append(dobj.decompress(chunk))
703 decompressed_chunks.append(dobj.decompress(chunk))
670
704
671 self.assertEqual(b"".join(decompressed_chunks), original[0:i])
705 self.assertEqual(b"".join(decompressed_chunks), original[0:i])
672
706
673 chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH)
707 chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH)
674 compressed_chunks.append(chunk)
708 compressed_chunks.append(chunk)
675 decompressed_chunks.append(dobj.decompress(chunk))
709 decompressed_chunks.append(dobj.decompress(chunk))
676
710
677 self.assertEqual(
711 self.assertEqual(
678 dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)),
712 dctx.decompress(
713 b"".join(compressed_chunks), max_output_size=len(original)
714 ),
679 original,
715 original,
680 )
716 )
681 self.assertEqual(b"".join(decompressed_chunks), original)
717 self.assertEqual(b"".join(decompressed_chunks), original)
682
718
683
719
684 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
720 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
685 @make_cffi
721 @make_cffi
686 class TestCompressor_read_to_iter_fuzzing(TestCase):
722 class TestCompressor_read_to_iter_fuzzing(TestCase):
687 @hypothesis.given(
723 @hypothesis.given(
688 original=strategies.sampled_from(random_input_data()),
724 original=strategies.sampled_from(random_input_data()),
689 level=strategies.integers(min_value=1, max_value=5),
725 level=strategies.integers(min_value=1, max_value=5),
690 read_size=strategies.integers(min_value=1, max_value=4096),
726 read_size=strategies.integers(min_value=1, max_value=4096),
691 write_size=strategies.integers(min_value=1, max_value=4096),
727 write_size=strategies.integers(min_value=1, max_value=4096),
692 )
728 )
693 def test_read_write_size_variance(self, original, level, read_size, write_size):
729 def test_read_write_size_variance(
730 self, original, level, read_size, write_size
731 ):
694 refcctx = zstd.ZstdCompressor(level=level)
732 refcctx = zstd.ZstdCompressor(level=level)
695 ref_frame = refcctx.compress(original)
733 ref_frame = refcctx.compress(original)
696
734
697 source = io.BytesIO(original)
735 source = io.BytesIO(original)
698
736
699 cctx = zstd.ZstdCompressor(level=level)
737 cctx = zstd.ZstdCompressor(level=level)
700 chunks = list(
738 chunks = list(
701 cctx.read_to_iter(
739 cctx.read_to_iter(
702 source, size=len(original), read_size=read_size, write_size=write_size
740 source,
741 size=len(original),
742 read_size=read_size,
743 write_size=write_size,
703 )
744 )
704 )
745 )
705
746
706 self.assertEqual(b"".join(chunks), ref_frame)
747 self.assertEqual(b"".join(chunks), ref_frame)
707
748
708
749
709 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
750 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
710 class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase):
751 class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase):
711 @hypothesis.given(
752 @hypothesis.given(
712 original=strategies.lists(
753 original=strategies.lists(
713 strategies.sampled_from(random_input_data()), min_size=1, max_size=1024
754 strategies.sampled_from(random_input_data()),
755 min_size=1,
756 max_size=1024,
714 ),
757 ),
715 threads=strategies.integers(min_value=1, max_value=8),
758 threads=strategies.integers(min_value=1, max_value=8),
716 use_dict=strategies.booleans(),
759 use_dict=strategies.booleans(),
717 )
760 )
718 def test_data_equivalence(self, original, threads, use_dict):
761 def test_data_equivalence(self, original, threads, use_dict):
719 kwargs = {}
762 kwargs = {}
720
763
721 # Use a content dictionary because it is cheap to create.
764 # Use a content dictionary because it is cheap to create.
722 if use_dict:
765 if use_dict:
723 kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])
766 kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])
724
767
725 cctx = zstd.ZstdCompressor(level=1, write_checksum=True, **kwargs)
768 cctx = zstd.ZstdCompressor(level=1, write_checksum=True, **kwargs)
726
769
727 if not hasattr(cctx, "multi_compress_to_buffer"):
770 if not hasattr(cctx, "multi_compress_to_buffer"):
728 self.skipTest("multi_compress_to_buffer not available")
771 self.skipTest("multi_compress_to_buffer not available")
729
772
730 result = cctx.multi_compress_to_buffer(original, threads=-1)
773 result = cctx.multi_compress_to_buffer(original, threads=-1)
731
774
732 self.assertEqual(len(result), len(original))
775 self.assertEqual(len(result), len(original))
733
776
734 # The frame produced via the batch APIs may not be bit identical to that
777 # The frame produced via the batch APIs may not be bit identical to that
735 # produced by compress() because compression parameters are adjusted
778 # produced by compress() because compression parameters are adjusted
736 # from the first input in batch mode. So the only thing we can do is
779 # from the first input in batch mode. So the only thing we can do is
737 # verify the decompressed data matches the input.
780 # verify the decompressed data matches the input.
738 dctx = zstd.ZstdDecompressor(**kwargs)
781 dctx = zstd.ZstdDecompressor(**kwargs)
739
782
740 for i, frame in enumerate(result):
783 for i, frame in enumerate(result):
741 self.assertEqual(dctx.decompress(frame), original[i])
784 self.assertEqual(dctx.decompress(frame), original[i])
742
785
743
786
744 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
787 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
745 @make_cffi
788 @make_cffi
746 class TestCompressor_chunker_fuzzing(TestCase):
789 class TestCompressor_chunker_fuzzing(TestCase):
747 @hypothesis.settings(
790 @hypothesis.settings(
748 suppress_health_check=[
791 suppress_health_check=[
749 hypothesis.HealthCheck.large_base_example,
792 hypothesis.HealthCheck.large_base_example,
750 hypothesis.HealthCheck.too_slow,
793 hypothesis.HealthCheck.too_slow,
751 ]
794 ]
752 )
795 )
753 @hypothesis.given(
796 @hypothesis.given(
754 original=strategies.sampled_from(random_input_data()),
797 original=strategies.sampled_from(random_input_data()),
755 level=strategies.integers(min_value=1, max_value=5),
798 level=strategies.integers(min_value=1, max_value=5),
756 chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576),
799 chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576),
757 input_sizes=strategies.data(),
800 input_sizes=strategies.data(),
758 )
801 )
759 def test_random_input_sizes(self, original, level, chunk_size, input_sizes):
802 def test_random_input_sizes(self, original, level, chunk_size, input_sizes):
760 cctx = zstd.ZstdCompressor(level=level)
803 cctx = zstd.ZstdCompressor(level=level)
761 chunker = cctx.chunker(chunk_size=chunk_size)
804 chunker = cctx.chunker(chunk_size=chunk_size)
762
805
763 chunks = []
806 chunks = []
764 i = 0
807 i = 0
765 while True:
808 while True:
766 input_size = input_sizes.draw(strategies.integers(1, 4096))
809 input_size = input_sizes.draw(strategies.integers(1, 4096))
767 source = original[i : i + input_size]
810 source = original[i : i + input_size]
768 if not source:
811 if not source:
769 break
812 break
770
813
771 chunks.extend(chunker.compress(source))
814 chunks.extend(chunker.compress(source))
772 i += input_size
815 i += input_size
773
816
774 chunks.extend(chunker.finish())
817 chunks.extend(chunker.finish())
775
818
776 dctx = zstd.ZstdDecompressor()
819 dctx = zstd.ZstdDecompressor()
777
820
778 self.assertEqual(
821 self.assertEqual(
779 dctx.decompress(b"".join(chunks), max_output_size=len(original)), original
822 dctx.decompress(b"".join(chunks), max_output_size=len(original)),
823 original,
780 )
824 )
781
825
782 self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1]))
826 self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1]))
783
827
784 @hypothesis.settings(
828 @hypothesis.settings(
785 suppress_health_check=[
829 suppress_health_check=[
786 hypothesis.HealthCheck.large_base_example,
830 hypothesis.HealthCheck.large_base_example,
787 hypothesis.HealthCheck.too_slow,
831 hypothesis.HealthCheck.too_slow,
788 ]
832 ]
789 )
833 )
790 @hypothesis.given(
834 @hypothesis.given(
791 original=strategies.sampled_from(random_input_data()),
835 original=strategies.sampled_from(random_input_data()),
792 level=strategies.integers(min_value=1, max_value=5),
836 level=strategies.integers(min_value=1, max_value=5),
793 chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576),
837 chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576),
794 input_sizes=strategies.data(),
838 input_sizes=strategies.data(),
795 flushes=strategies.data(),
839 flushes=strategies.data(),
796 )
840 )
797 def test_flush_block(self, original, level, chunk_size, input_sizes, flushes):
841 def test_flush_block(
842 self, original, level, chunk_size, input_sizes, flushes
843 ):
798 cctx = zstd.ZstdCompressor(level=level)
844 cctx = zstd.ZstdCompressor(level=level)
799 chunker = cctx.chunker(chunk_size=chunk_size)
845 chunker = cctx.chunker(chunk_size=chunk_size)
800
846
801 dctx = zstd.ZstdDecompressor()
847 dctx = zstd.ZstdDecompressor()
802 dobj = dctx.decompressobj()
848 dobj = dctx.decompressobj()
803
849
804 compressed_chunks = []
850 compressed_chunks = []
805 decompressed_chunks = []
851 decompressed_chunks = []
806 i = 0
852 i = 0
807 while True:
853 while True:
808 input_size = input_sizes.draw(strategies.integers(1, 4096))
854 input_size = input_sizes.draw(strategies.integers(1, 4096))
809 source = original[i : i + input_size]
855 source = original[i : i + input_size]
810 if not source:
856 if not source:
811 break
857 break
812
858
813 i += input_size
859 i += input_size
814
860
815 chunks = list(chunker.compress(source))
861 chunks = list(chunker.compress(source))
816 compressed_chunks.extend(chunks)
862 compressed_chunks.extend(chunks)
817 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
863 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
818
864
819 if not flushes.draw(strategies.booleans()):
865 if not flushes.draw(strategies.booleans()):
820 continue
866 continue
821
867
822 chunks = list(chunker.flush())
868 chunks = list(chunker.flush())
823 compressed_chunks.extend(chunks)
869 compressed_chunks.extend(chunks)
824 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
870 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
825
871
826 self.assertEqual(b"".join(decompressed_chunks), original[0:i])
872 self.assertEqual(b"".join(decompressed_chunks), original[0:i])
827
873
828 chunks = list(chunker.finish())
874 chunks = list(chunker.finish())
829 compressed_chunks.extend(chunks)
875 compressed_chunks.extend(chunks)
830 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
876 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
831
877
832 self.assertEqual(
878 self.assertEqual(
833 dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)),
879 dctx.decompress(
880 b"".join(compressed_chunks), max_output_size=len(original)
881 ),
834 original,
882 original,
835 )
883 )
836 self.assertEqual(b"".join(decompressed_chunks), original)
884 self.assertEqual(b"".join(decompressed_chunks), original)
@@ -1,241 +1,255 b''
1 import sys
1 import sys
2 import unittest
2 import unittest
3
3
4 import zstandard as zstd
4 import zstandard as zstd
5
5
6 from .common import (
6 from .common import (
7 make_cffi,
7 make_cffi,
8 TestCase,
8 TestCase,
9 )
9 )
10
10
11
11
12 @make_cffi
12 @make_cffi
13 class TestCompressionParameters(TestCase):
13 class TestCompressionParameters(TestCase):
14 def test_bounds(self):
14 def test_bounds(self):
15 zstd.ZstdCompressionParameters(
15 zstd.ZstdCompressionParameters(
16 window_log=zstd.WINDOWLOG_MIN,
16 window_log=zstd.WINDOWLOG_MIN,
17 chain_log=zstd.CHAINLOG_MIN,
17 chain_log=zstd.CHAINLOG_MIN,
18 hash_log=zstd.HASHLOG_MIN,
18 hash_log=zstd.HASHLOG_MIN,
19 search_log=zstd.SEARCHLOG_MIN,
19 search_log=zstd.SEARCHLOG_MIN,
20 min_match=zstd.MINMATCH_MIN + 1,
20 min_match=zstd.MINMATCH_MIN + 1,
21 target_length=zstd.TARGETLENGTH_MIN,
21 target_length=zstd.TARGETLENGTH_MIN,
22 strategy=zstd.STRATEGY_FAST,
22 strategy=zstd.STRATEGY_FAST,
23 )
23 )
24
24
25 zstd.ZstdCompressionParameters(
25 zstd.ZstdCompressionParameters(
26 window_log=zstd.WINDOWLOG_MAX,
26 window_log=zstd.WINDOWLOG_MAX,
27 chain_log=zstd.CHAINLOG_MAX,
27 chain_log=zstd.CHAINLOG_MAX,
28 hash_log=zstd.HASHLOG_MAX,
28 hash_log=zstd.HASHLOG_MAX,
29 search_log=zstd.SEARCHLOG_MAX,
29 search_log=zstd.SEARCHLOG_MAX,
30 min_match=zstd.MINMATCH_MAX - 1,
30 min_match=zstd.MINMATCH_MAX - 1,
31 target_length=zstd.TARGETLENGTH_MAX,
31 target_length=zstd.TARGETLENGTH_MAX,
32 strategy=zstd.STRATEGY_BTULTRA2,
32 strategy=zstd.STRATEGY_BTULTRA2,
33 )
33 )
34
34
35 def test_from_level(self):
35 def test_from_level(self):
36 p = zstd.ZstdCompressionParameters.from_level(1)
36 p = zstd.ZstdCompressionParameters.from_level(1)
37 self.assertIsInstance(p, zstd.CompressionParameters)
37 self.assertIsInstance(p, zstd.CompressionParameters)
38
38
39 self.assertEqual(p.window_log, 19)
39 self.assertEqual(p.window_log, 19)
40
40
41 p = zstd.ZstdCompressionParameters.from_level(-4)
41 p = zstd.ZstdCompressionParameters.from_level(-4)
42 self.assertEqual(p.window_log, 19)
42 self.assertEqual(p.window_log, 19)
43
43
44 def test_members(self):
44 def test_members(self):
45 p = zstd.ZstdCompressionParameters(
45 p = zstd.ZstdCompressionParameters(
46 window_log=10,
46 window_log=10,
47 chain_log=6,
47 chain_log=6,
48 hash_log=7,
48 hash_log=7,
49 search_log=4,
49 search_log=4,
50 min_match=5,
50 min_match=5,
51 target_length=8,
51 target_length=8,
52 strategy=1,
52 strategy=1,
53 )
53 )
54 self.assertEqual(p.window_log, 10)
54 self.assertEqual(p.window_log, 10)
55 self.assertEqual(p.chain_log, 6)
55 self.assertEqual(p.chain_log, 6)
56 self.assertEqual(p.hash_log, 7)
56 self.assertEqual(p.hash_log, 7)
57 self.assertEqual(p.search_log, 4)
57 self.assertEqual(p.search_log, 4)
58 self.assertEqual(p.min_match, 5)
58 self.assertEqual(p.min_match, 5)
59 self.assertEqual(p.target_length, 8)
59 self.assertEqual(p.target_length, 8)
60 self.assertEqual(p.compression_strategy, 1)
60 self.assertEqual(p.compression_strategy, 1)
61
61
62 p = zstd.ZstdCompressionParameters(compression_level=2)
62 p = zstd.ZstdCompressionParameters(compression_level=2)
63 self.assertEqual(p.compression_level, 2)
63 self.assertEqual(p.compression_level, 2)
64
64
65 p = zstd.ZstdCompressionParameters(threads=4)
65 p = zstd.ZstdCompressionParameters(threads=4)
66 self.assertEqual(p.threads, 4)
66 self.assertEqual(p.threads, 4)
67
67
68 p = zstd.ZstdCompressionParameters(threads=2, job_size=1048576, overlap_log=6)
68 p = zstd.ZstdCompressionParameters(
69 threads=2, job_size=1048576, overlap_log=6
70 )
69 self.assertEqual(p.threads, 2)
71 self.assertEqual(p.threads, 2)
70 self.assertEqual(p.job_size, 1048576)
72 self.assertEqual(p.job_size, 1048576)
71 self.assertEqual(p.overlap_log, 6)
73 self.assertEqual(p.overlap_log, 6)
72 self.assertEqual(p.overlap_size_log, 6)
74 self.assertEqual(p.overlap_size_log, 6)
73
75
74 p = zstd.ZstdCompressionParameters(compression_level=-1)
76 p = zstd.ZstdCompressionParameters(compression_level=-1)
75 self.assertEqual(p.compression_level, -1)
77 self.assertEqual(p.compression_level, -1)
76
78
77 p = zstd.ZstdCompressionParameters(compression_level=-2)
79 p = zstd.ZstdCompressionParameters(compression_level=-2)
78 self.assertEqual(p.compression_level, -2)
80 self.assertEqual(p.compression_level, -2)
79
81
80 p = zstd.ZstdCompressionParameters(force_max_window=True)
82 p = zstd.ZstdCompressionParameters(force_max_window=True)
81 self.assertEqual(p.force_max_window, 1)
83 self.assertEqual(p.force_max_window, 1)
82
84
83 p = zstd.ZstdCompressionParameters(enable_ldm=True)
85 p = zstd.ZstdCompressionParameters(enable_ldm=True)
84 self.assertEqual(p.enable_ldm, 1)
86 self.assertEqual(p.enable_ldm, 1)
85
87
86 p = zstd.ZstdCompressionParameters(ldm_hash_log=7)
88 p = zstd.ZstdCompressionParameters(ldm_hash_log=7)
87 self.assertEqual(p.ldm_hash_log, 7)
89 self.assertEqual(p.ldm_hash_log, 7)
88
90
89 p = zstd.ZstdCompressionParameters(ldm_min_match=6)
91 p = zstd.ZstdCompressionParameters(ldm_min_match=6)
90 self.assertEqual(p.ldm_min_match, 6)
92 self.assertEqual(p.ldm_min_match, 6)
91
93
92 p = zstd.ZstdCompressionParameters(ldm_bucket_size_log=7)
94 p = zstd.ZstdCompressionParameters(ldm_bucket_size_log=7)
93 self.assertEqual(p.ldm_bucket_size_log, 7)
95 self.assertEqual(p.ldm_bucket_size_log, 7)
94
96
95 p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8)
97 p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8)
96 self.assertEqual(p.ldm_hash_every_log, 8)
98 self.assertEqual(p.ldm_hash_every_log, 8)
97 self.assertEqual(p.ldm_hash_rate_log, 8)
99 self.assertEqual(p.ldm_hash_rate_log, 8)
98
100
99 def test_estimated_compression_context_size(self):
101 def test_estimated_compression_context_size(self):
100 p = zstd.ZstdCompressionParameters(
102 p = zstd.ZstdCompressionParameters(
101 window_log=20,
103 window_log=20,
102 chain_log=16,
104 chain_log=16,
103 hash_log=17,
105 hash_log=17,
104 search_log=1,
106 search_log=1,
105 min_match=5,
107 min_match=5,
106 target_length=16,
108 target_length=16,
107 strategy=zstd.STRATEGY_DFAST,
109 strategy=zstd.STRATEGY_DFAST,
108 )
110 )
109
111
110 # 32-bit has slightly different values from 64-bit.
112 # 32-bit has slightly different values from 64-bit.
111 self.assertAlmostEqual(
113 self.assertAlmostEqual(
112 p.estimated_compression_context_size(), 1294464, delta=400
114 p.estimated_compression_context_size(), 1294464, delta=400
113 )
115 )
114
116
115 def test_strategy(self):
117 def test_strategy(self):
116 with self.assertRaisesRegex(
118 with self.assertRaisesRegex(
117 ValueError, "cannot specify both compression_strategy"
119 ValueError, "cannot specify both compression_strategy"
118 ):
120 ):
119 zstd.ZstdCompressionParameters(strategy=0, compression_strategy=0)
121 zstd.ZstdCompressionParameters(strategy=0, compression_strategy=0)
120
122
121 p = zstd.ZstdCompressionParameters(strategy=2)
123 p = zstd.ZstdCompressionParameters(strategy=2)
122 self.assertEqual(p.compression_strategy, 2)
124 self.assertEqual(p.compression_strategy, 2)
123
125
124 p = zstd.ZstdCompressionParameters(strategy=3)
126 p = zstd.ZstdCompressionParameters(strategy=3)
125 self.assertEqual(p.compression_strategy, 3)
127 self.assertEqual(p.compression_strategy, 3)
126
128
127 def test_ldm_hash_rate_log(self):
129 def test_ldm_hash_rate_log(self):
128 with self.assertRaisesRegex(
130 with self.assertRaisesRegex(
129 ValueError, "cannot specify both ldm_hash_rate_log"
131 ValueError, "cannot specify both ldm_hash_rate_log"
130 ):
132 ):
131 zstd.ZstdCompressionParameters(ldm_hash_rate_log=8, ldm_hash_every_log=4)
133 zstd.ZstdCompressionParameters(
134 ldm_hash_rate_log=8, ldm_hash_every_log=4
135 )
132
136
133 p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8)
137 p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8)
134 self.assertEqual(p.ldm_hash_every_log, 8)
138 self.assertEqual(p.ldm_hash_every_log, 8)
135
139
136 p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16)
140 p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16)
137 self.assertEqual(p.ldm_hash_every_log, 16)
141 self.assertEqual(p.ldm_hash_every_log, 16)
138
142
139 def test_overlap_log(self):
143 def test_overlap_log(self):
140 with self.assertRaisesRegex(ValueError, "cannot specify both overlap_log"):
144 with self.assertRaisesRegex(
145 ValueError, "cannot specify both overlap_log"
146 ):
141 zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9)
147 zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9)
142
148
143 p = zstd.ZstdCompressionParameters(overlap_log=2)
149 p = zstd.ZstdCompressionParameters(overlap_log=2)
144 self.assertEqual(p.overlap_log, 2)
150 self.assertEqual(p.overlap_log, 2)
145 self.assertEqual(p.overlap_size_log, 2)
151 self.assertEqual(p.overlap_size_log, 2)
146
152
147 p = zstd.ZstdCompressionParameters(overlap_size_log=4)
153 p = zstd.ZstdCompressionParameters(overlap_size_log=4)
148 self.assertEqual(p.overlap_log, 4)
154 self.assertEqual(p.overlap_log, 4)
149 self.assertEqual(p.overlap_size_log, 4)
155 self.assertEqual(p.overlap_size_log, 4)
150
156
151
157
152 @make_cffi
158 @make_cffi
153 class TestFrameParameters(TestCase):
159 class TestFrameParameters(TestCase):
154 def test_invalid_type(self):
160 def test_invalid_type(self):
155 with self.assertRaises(TypeError):
161 with self.assertRaises(TypeError):
156 zstd.get_frame_parameters(None)
162 zstd.get_frame_parameters(None)
157
163
158 # Python 3 doesn't appear to convert unicode to Py_buffer.
164 # Python 3 doesn't appear to convert unicode to Py_buffer.
159 if sys.version_info[0] >= 3:
165 if sys.version_info[0] >= 3:
160 with self.assertRaises(TypeError):
166 with self.assertRaises(TypeError):
161 zstd.get_frame_parameters(u"foobarbaz")
167 zstd.get_frame_parameters(u"foobarbaz")
162 else:
168 else:
163 # CPython will convert unicode to Py_buffer. But CFFI won't.
169 # CPython will convert unicode to Py_buffer. But CFFI won't.
164 if zstd.backend == "cffi":
170 if zstd.backend == "cffi":
165 with self.assertRaises(TypeError):
171 with self.assertRaises(TypeError):
166 zstd.get_frame_parameters(u"foobarbaz")
172 zstd.get_frame_parameters(u"foobarbaz")
167 else:
173 else:
168 with self.assertRaises(zstd.ZstdError):
174 with self.assertRaises(zstd.ZstdError):
169 zstd.get_frame_parameters(u"foobarbaz")
175 zstd.get_frame_parameters(u"foobarbaz")
170
176
171 def test_invalid_input_sizes(self):
177 def test_invalid_input_sizes(self):
172 with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"):
178 with self.assertRaisesRegex(
179 zstd.ZstdError, "not enough data for frame"
180 ):
173 zstd.get_frame_parameters(b"")
181 zstd.get_frame_parameters(b"")
174
182
175 with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"):
183 with self.assertRaisesRegex(
184 zstd.ZstdError, "not enough data for frame"
185 ):
176 zstd.get_frame_parameters(zstd.FRAME_HEADER)
186 zstd.get_frame_parameters(zstd.FRAME_HEADER)
177
187
178 def test_invalid_frame(self):
188 def test_invalid_frame(self):
179 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
189 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
180 zstd.get_frame_parameters(b"foobarbaz")
190 zstd.get_frame_parameters(b"foobarbaz")
181
191
182 def test_attributes(self):
192 def test_attributes(self):
183 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00")
193 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00")
184 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
194 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
185 self.assertEqual(params.window_size, 1024)
195 self.assertEqual(params.window_size, 1024)
186 self.assertEqual(params.dict_id, 0)
196 self.assertEqual(params.dict_id, 0)
187 self.assertFalse(params.has_checksum)
197 self.assertFalse(params.has_checksum)
188
198
189 # Lowest 2 bits indicate a dictionary and length. Here, the dict id is 1 byte.
199 # Lowest 2 bits indicate a dictionary and length. Here, the dict id is 1 byte.
190 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x01\x00\xff")
200 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x01\x00\xff")
191 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
201 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
192 self.assertEqual(params.window_size, 1024)
202 self.assertEqual(params.window_size, 1024)
193 self.assertEqual(params.dict_id, 255)
203 self.assertEqual(params.dict_id, 255)
194 self.assertFalse(params.has_checksum)
204 self.assertFalse(params.has_checksum)
195
205
196 # Lowest 3rd bit indicates if checksum is present.
206 # Lowest 3rd bit indicates if checksum is present.
197 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00")
207 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00")
198 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
208 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
199 self.assertEqual(params.window_size, 1024)
209 self.assertEqual(params.window_size, 1024)
200 self.assertEqual(params.dict_id, 0)
210 self.assertEqual(params.dict_id, 0)
201 self.assertTrue(params.has_checksum)
211 self.assertTrue(params.has_checksum)
202
212
203 # Upper 2 bits indicate content size.
213 # Upper 2 bits indicate content size.
204 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x40\x00\xff\x00")
214 params = zstd.get_frame_parameters(
215 zstd.FRAME_HEADER + b"\x40\x00\xff\x00"
216 )
205 self.assertEqual(params.content_size, 511)
217 self.assertEqual(params.content_size, 511)
206 self.assertEqual(params.window_size, 1024)
218 self.assertEqual(params.window_size, 1024)
207 self.assertEqual(params.dict_id, 0)
219 self.assertEqual(params.dict_id, 0)
208 self.assertFalse(params.has_checksum)
220 self.assertFalse(params.has_checksum)
209
221
210 # Window descriptor is 2nd byte after frame header.
222 # Window descriptor is 2nd byte after frame header.
211 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40")
223 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40")
212 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
224 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
213 self.assertEqual(params.window_size, 262144)
225 self.assertEqual(params.window_size, 262144)
214 self.assertEqual(params.dict_id, 0)
226 self.assertEqual(params.dict_id, 0)
215 self.assertFalse(params.has_checksum)
227 self.assertFalse(params.has_checksum)
216
228
217 # Set multiple things.
229 # Set multiple things.
218 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00")
230 params = zstd.get_frame_parameters(
231 zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00"
232 )
219 self.assertEqual(params.content_size, 272)
233 self.assertEqual(params.content_size, 272)
220 self.assertEqual(params.window_size, 262144)
234 self.assertEqual(params.window_size, 262144)
221 self.assertEqual(params.dict_id, 15)
235 self.assertEqual(params.dict_id, 15)
222 self.assertTrue(params.has_checksum)
236 self.assertTrue(params.has_checksum)
223
237
224 def test_input_types(self):
238 def test_input_types(self):
225 v = zstd.FRAME_HEADER + b"\x00\x00"
239 v = zstd.FRAME_HEADER + b"\x00\x00"
226
240
227 mutable_array = bytearray(len(v))
241 mutable_array = bytearray(len(v))
228 mutable_array[:] = v
242 mutable_array[:] = v
229
243
230 sources = [
244 sources = [
231 memoryview(v),
245 memoryview(v),
232 bytearray(v),
246 bytearray(v),
233 mutable_array,
247 mutable_array,
234 ]
248 ]
235
249
236 for source in sources:
250 for source in sources:
237 params = zstd.get_frame_parameters(source)
251 params = zstd.get_frame_parameters(source)
238 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
252 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
239 self.assertEqual(params.window_size, 1024)
253 self.assertEqual(params.window_size, 1024)
240 self.assertEqual(params.dict_id, 0)
254 self.assertEqual(params.dict_id, 0)
241 self.assertFalse(params.has_checksum)
255 self.assertFalse(params.has_checksum)
@@ -1,105 +1,121 b''
1 import io
1 import io
2 import os
2 import os
3 import sys
3 import sys
4 import unittest
4 import unittest
5
5
6 try:
6 try:
7 import hypothesis
7 import hypothesis
8 import hypothesis.strategies as strategies
8 import hypothesis.strategies as strategies
9 except ImportError:
9 except ImportError:
10 raise unittest.SkipTest("hypothesis not available")
10 raise unittest.SkipTest("hypothesis not available")
11
11
12 import zstandard as zstd
12 import zstandard as zstd
13
13
14 from .common import (
14 from .common import (
15 make_cffi,
15 make_cffi,
16 TestCase,
16 TestCase,
17 )
17 )
18
18
19
19
20 s_windowlog = strategies.integers(
20 s_windowlog = strategies.integers(
21 min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX
21 min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX
22 )
22 )
23 s_chainlog = strategies.integers(
23 s_chainlog = strategies.integers(
24 min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX
24 min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX
25 )
25 )
26 s_hashlog = strategies.integers(min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX)
26 s_hashlog = strategies.integers(
27 min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX
28 )
27 s_searchlog = strategies.integers(
29 s_searchlog = strategies.integers(
28 min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX
30 min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX
29 )
31 )
30 s_minmatch = strategies.integers(
32 s_minmatch = strategies.integers(
31 min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX
33 min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX
32 )
34 )
33 s_targetlength = strategies.integers(
35 s_targetlength = strategies.integers(
34 min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX
36 min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX
35 )
37 )
36 s_strategy = strategies.sampled_from(
38 s_strategy = strategies.sampled_from(
37 (
39 (
38 zstd.STRATEGY_FAST,
40 zstd.STRATEGY_FAST,
39 zstd.STRATEGY_DFAST,
41 zstd.STRATEGY_DFAST,
40 zstd.STRATEGY_GREEDY,
42 zstd.STRATEGY_GREEDY,
41 zstd.STRATEGY_LAZY,
43 zstd.STRATEGY_LAZY,
42 zstd.STRATEGY_LAZY2,
44 zstd.STRATEGY_LAZY2,
43 zstd.STRATEGY_BTLAZY2,
45 zstd.STRATEGY_BTLAZY2,
44 zstd.STRATEGY_BTOPT,
46 zstd.STRATEGY_BTOPT,
45 zstd.STRATEGY_BTULTRA,
47 zstd.STRATEGY_BTULTRA,
46 zstd.STRATEGY_BTULTRA2,
48 zstd.STRATEGY_BTULTRA2,
47 )
49 )
48 )
50 )
49
51
50
52
51 @make_cffi
53 @make_cffi
52 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
54 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
53 class TestCompressionParametersHypothesis(TestCase):
55 class TestCompressionParametersHypothesis(TestCase):
54 @hypothesis.given(
56 @hypothesis.given(
55 s_windowlog,
57 s_windowlog,
56 s_chainlog,
58 s_chainlog,
57 s_hashlog,
59 s_hashlog,
58 s_searchlog,
60 s_searchlog,
59 s_minmatch,
61 s_minmatch,
60 s_targetlength,
62 s_targetlength,
61 s_strategy,
63 s_strategy,
62 )
64 )
63 def test_valid_init(
65 def test_valid_init(
64 self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy
66 self,
67 windowlog,
68 chainlog,
69 hashlog,
70 searchlog,
71 minmatch,
72 targetlength,
73 strategy,
65 ):
74 ):
66 zstd.ZstdCompressionParameters(
75 zstd.ZstdCompressionParameters(
67 window_log=windowlog,
76 window_log=windowlog,
68 chain_log=chainlog,
77 chain_log=chainlog,
69 hash_log=hashlog,
78 hash_log=hashlog,
70 search_log=searchlog,
79 search_log=searchlog,
71 min_match=minmatch,
80 min_match=minmatch,
72 target_length=targetlength,
81 target_length=targetlength,
73 strategy=strategy,
82 strategy=strategy,
74 )
83 )
75
84
76 @hypothesis.given(
85 @hypothesis.given(
77 s_windowlog,
86 s_windowlog,
78 s_chainlog,
87 s_chainlog,
79 s_hashlog,
88 s_hashlog,
80 s_searchlog,
89 s_searchlog,
81 s_minmatch,
90 s_minmatch,
82 s_targetlength,
91 s_targetlength,
83 s_strategy,
92 s_strategy,
84 )
93 )
85 def test_estimated_compression_context_size(
94 def test_estimated_compression_context_size(
86 self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy
95 self,
96 windowlog,
97 chainlog,
98 hashlog,
99 searchlog,
100 minmatch,
101 targetlength,
102 strategy,
87 ):
103 ):
88 if minmatch == zstd.MINMATCH_MIN and strategy in (
104 if minmatch == zstd.MINMATCH_MIN and strategy in (
89 zstd.STRATEGY_FAST,
105 zstd.STRATEGY_FAST,
90 zstd.STRATEGY_GREEDY,
106 zstd.STRATEGY_GREEDY,
91 ):
107 ):
92 minmatch += 1
108 minmatch += 1
93 elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST:
109 elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST:
94 minmatch -= 1
110 minmatch -= 1
95
111
96 p = zstd.ZstdCompressionParameters(
112 p = zstd.ZstdCompressionParameters(
97 window_log=windowlog,
113 window_log=windowlog,
98 chain_log=chainlog,
114 chain_log=chainlog,
99 hash_log=hashlog,
115 hash_log=hashlog,
100 search_log=searchlog,
116 search_log=searchlog,
101 min_match=minmatch,
117 min_match=minmatch,
102 target_length=targetlength,
118 target_length=targetlength,
103 strategy=strategy,
119 strategy=strategy,
104 )
120 )
105 size = p.estimated_compression_context_size()
121 size = p.estimated_compression_context_size()
@@ -1,1670 +1,1714 b''
1 import io
1 import io
2 import os
2 import os
3 import random
3 import random
4 import struct
4 import struct
5 import sys
5 import sys
6 import tempfile
6 import tempfile
7 import unittest
7 import unittest
8
8
9 import zstandard as zstd
9 import zstandard as zstd
10
10
11 from .common import (
11 from .common import (
12 generate_samples,
12 generate_samples,
13 make_cffi,
13 make_cffi,
14 NonClosingBytesIO,
14 NonClosingBytesIO,
15 OpCountingBytesIO,
15 OpCountingBytesIO,
16 TestCase,
16 TestCase,
17 )
17 )
18
18
19
19
20 if sys.version_info[0] >= 3:
20 if sys.version_info[0] >= 3:
21 next = lambda it: it.__next__()
21 next = lambda it: it.__next__()
22 else:
22 else:
23 next = lambda it: it.next()
23 next = lambda it: it.next()
24
24
25
25
26 @make_cffi
26 @make_cffi
27 class TestFrameHeaderSize(TestCase):
27 class TestFrameHeaderSize(TestCase):
28 def test_empty(self):
28 def test_empty(self):
29 with self.assertRaisesRegex(
29 with self.assertRaisesRegex(
30 zstd.ZstdError,
30 zstd.ZstdError,
31 "could not determine frame header size: Src size " "is incorrect",
31 "could not determine frame header size: Src size " "is incorrect",
32 ):
32 ):
33 zstd.frame_header_size(b"")
33 zstd.frame_header_size(b"")
34
34
35 def test_too_small(self):
35 def test_too_small(self):
36 with self.assertRaisesRegex(
36 with self.assertRaisesRegex(
37 zstd.ZstdError,
37 zstd.ZstdError,
38 "could not determine frame header size: Src size " "is incorrect",
38 "could not determine frame header size: Src size " "is incorrect",
39 ):
39 ):
40 zstd.frame_header_size(b"foob")
40 zstd.frame_header_size(b"foob")
41
41
42 def test_basic(self):
42 def test_basic(self):
43 # It doesn't matter that it isn't a valid frame.
43 # It doesn't matter that it isn't a valid frame.
44 self.assertEqual(zstd.frame_header_size(b"long enough but no magic"), 6)
44 self.assertEqual(zstd.frame_header_size(b"long enough but no magic"), 6)
45
45
46
46
47 @make_cffi
47 @make_cffi
48 class TestFrameContentSize(TestCase):
48 class TestFrameContentSize(TestCase):
49 def test_empty(self):
49 def test_empty(self):
50 with self.assertRaisesRegex(
50 with self.assertRaisesRegex(
51 zstd.ZstdError, "error when determining content size"
51 zstd.ZstdError, "error when determining content size"
52 ):
52 ):
53 zstd.frame_content_size(b"")
53 zstd.frame_content_size(b"")
54
54
55 def test_too_small(self):
55 def test_too_small(self):
56 with self.assertRaisesRegex(
56 with self.assertRaisesRegex(
57 zstd.ZstdError, "error when determining content size"
57 zstd.ZstdError, "error when determining content size"
58 ):
58 ):
59 zstd.frame_content_size(b"foob")
59 zstd.frame_content_size(b"foob")
60
60
61 def test_bad_frame(self):
61 def test_bad_frame(self):
62 with self.assertRaisesRegex(
62 with self.assertRaisesRegex(
63 zstd.ZstdError, "error when determining content size"
63 zstd.ZstdError, "error when determining content size"
64 ):
64 ):
65 zstd.frame_content_size(b"invalid frame header")
65 zstd.frame_content_size(b"invalid frame header")
66
66
67 def test_unknown(self):
67 def test_unknown(self):
68 cctx = zstd.ZstdCompressor(write_content_size=False)
68 cctx = zstd.ZstdCompressor(write_content_size=False)
69 frame = cctx.compress(b"foobar")
69 frame = cctx.compress(b"foobar")
70
70
71 self.assertEqual(zstd.frame_content_size(frame), -1)
71 self.assertEqual(zstd.frame_content_size(frame), -1)
72
72
73 def test_empty(self):
73 def test_empty(self):
74 cctx = zstd.ZstdCompressor()
74 cctx = zstd.ZstdCompressor()
75 frame = cctx.compress(b"")
75 frame = cctx.compress(b"")
76
76
77 self.assertEqual(zstd.frame_content_size(frame), 0)
77 self.assertEqual(zstd.frame_content_size(frame), 0)
78
78
79 def test_basic(self):
79 def test_basic(self):
80 cctx = zstd.ZstdCompressor()
80 cctx = zstd.ZstdCompressor()
81 frame = cctx.compress(b"foobar")
81 frame = cctx.compress(b"foobar")
82
82
83 self.assertEqual(zstd.frame_content_size(frame), 6)
83 self.assertEqual(zstd.frame_content_size(frame), 6)
84
84
85
85
86 @make_cffi
86 @make_cffi
87 class TestDecompressor(TestCase):
87 class TestDecompressor(TestCase):
88 def test_memory_size(self):
88 def test_memory_size(self):
89 dctx = zstd.ZstdDecompressor()
89 dctx = zstd.ZstdDecompressor()
90
90
91 self.assertGreater(dctx.memory_size(), 100)
91 self.assertGreater(dctx.memory_size(), 100)
92
92
93
93
94 @make_cffi
94 @make_cffi
95 class TestDecompressor_decompress(TestCase):
95 class TestDecompressor_decompress(TestCase):
96 def test_empty_input(self):
96 def test_empty_input(self):
97 dctx = zstd.ZstdDecompressor()
97 dctx = zstd.ZstdDecompressor()
98
98
99 with self.assertRaisesRegex(
99 with self.assertRaisesRegex(
100 zstd.ZstdError, "error determining content size from frame header"
100 zstd.ZstdError, "error determining content size from frame header"
101 ):
101 ):
102 dctx.decompress(b"")
102 dctx.decompress(b"")
103
103
104 def test_invalid_input(self):
104 def test_invalid_input(self):
105 dctx = zstd.ZstdDecompressor()
105 dctx = zstd.ZstdDecompressor()
106
106
107 with self.assertRaisesRegex(
107 with self.assertRaisesRegex(
108 zstd.ZstdError, "error determining content size from frame header"
108 zstd.ZstdError, "error determining content size from frame header"
109 ):
109 ):
110 dctx.decompress(b"foobar")
110 dctx.decompress(b"foobar")
111
111
112 def test_input_types(self):
112 def test_input_types(self):
113 cctx = zstd.ZstdCompressor(level=1)
113 cctx = zstd.ZstdCompressor(level=1)
114 compressed = cctx.compress(b"foo")
114 compressed = cctx.compress(b"foo")
115
115
116 mutable_array = bytearray(len(compressed))
116 mutable_array = bytearray(len(compressed))
117 mutable_array[:] = compressed
117 mutable_array[:] = compressed
118
118
119 sources = [
119 sources = [
120 memoryview(compressed),
120 memoryview(compressed),
121 bytearray(compressed),
121 bytearray(compressed),
122 mutable_array,
122 mutable_array,
123 ]
123 ]
124
124
125 dctx = zstd.ZstdDecompressor()
125 dctx = zstd.ZstdDecompressor()
126 for source in sources:
126 for source in sources:
127 self.assertEqual(dctx.decompress(source), b"foo")
127 self.assertEqual(dctx.decompress(source), b"foo")
128
128
129 def test_no_content_size_in_frame(self):
129 def test_no_content_size_in_frame(self):
130 cctx = zstd.ZstdCompressor(write_content_size=False)
130 cctx = zstd.ZstdCompressor(write_content_size=False)
131 compressed = cctx.compress(b"foobar")
131 compressed = cctx.compress(b"foobar")
132
132
133 dctx = zstd.ZstdDecompressor()
133 dctx = zstd.ZstdDecompressor()
134 with self.assertRaisesRegex(
134 with self.assertRaisesRegex(
135 zstd.ZstdError, "could not determine content size in frame header"
135 zstd.ZstdError, "could not determine content size in frame header"
136 ):
136 ):
137 dctx.decompress(compressed)
137 dctx.decompress(compressed)
138
138
139 def test_content_size_present(self):
139 def test_content_size_present(self):
140 cctx = zstd.ZstdCompressor()
140 cctx = zstd.ZstdCompressor()
141 compressed = cctx.compress(b"foobar")
141 compressed = cctx.compress(b"foobar")
142
142
143 dctx = zstd.ZstdDecompressor()
143 dctx = zstd.ZstdDecompressor()
144 decompressed = dctx.decompress(compressed)
144 decompressed = dctx.decompress(compressed)
145 self.assertEqual(decompressed, b"foobar")
145 self.assertEqual(decompressed, b"foobar")
146
146
147 def test_empty_roundtrip(self):
147 def test_empty_roundtrip(self):
148 cctx = zstd.ZstdCompressor()
148 cctx = zstd.ZstdCompressor()
149 compressed = cctx.compress(b"")
149 compressed = cctx.compress(b"")
150
150
151 dctx = zstd.ZstdDecompressor()
151 dctx = zstd.ZstdDecompressor()
152 decompressed = dctx.decompress(compressed)
152 decompressed = dctx.decompress(compressed)
153
153
154 self.assertEqual(decompressed, b"")
154 self.assertEqual(decompressed, b"")
155
155
156 def test_max_output_size(self):
156 def test_max_output_size(self):
157 cctx = zstd.ZstdCompressor(write_content_size=False)
157 cctx = zstd.ZstdCompressor(write_content_size=False)
158 source = b"foobar" * 256
158 source = b"foobar" * 256
159 compressed = cctx.compress(source)
159 compressed = cctx.compress(source)
160
160
161 dctx = zstd.ZstdDecompressor()
161 dctx = zstd.ZstdDecompressor()
162 # Will fit into buffer exactly the size of input.
162 # Will fit into buffer exactly the size of input.
163 decompressed = dctx.decompress(compressed, max_output_size=len(source))
163 decompressed = dctx.decompress(compressed, max_output_size=len(source))
164 self.assertEqual(decompressed, source)
164 self.assertEqual(decompressed, source)
165
165
166 # Input size - 1 fails
166 # Input size - 1 fails
167 with self.assertRaisesRegex(
167 with self.assertRaisesRegex(
168 zstd.ZstdError, "decompression error: did not decompress full frame"
168 zstd.ZstdError, "decompression error: did not decompress full frame"
169 ):
169 ):
170 dctx.decompress(compressed, max_output_size=len(source) - 1)
170 dctx.decompress(compressed, max_output_size=len(source) - 1)
171
171
172 # Input size + 1 works
172 # Input size + 1 works
173 decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
173 decompressed = dctx.decompress(
174 compressed, max_output_size=len(source) + 1
175 )
174 self.assertEqual(decompressed, source)
176 self.assertEqual(decompressed, source)
175
177
176 # A much larger buffer works.
178 # A much larger buffer works.
177 decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
179 decompressed = dctx.decompress(
180 compressed, max_output_size=len(source) * 64
181 )
178 self.assertEqual(decompressed, source)
182 self.assertEqual(decompressed, source)
179
183
180 def test_stupidly_large_output_buffer(self):
184 def test_stupidly_large_output_buffer(self):
181 cctx = zstd.ZstdCompressor(write_content_size=False)
185 cctx = zstd.ZstdCompressor(write_content_size=False)
182 compressed = cctx.compress(b"foobar" * 256)
186 compressed = cctx.compress(b"foobar" * 256)
183 dctx = zstd.ZstdDecompressor()
187 dctx = zstd.ZstdDecompressor()
184
188
185 # Will get OverflowError on some Python distributions that can't
189 # Will get OverflowError on some Python distributions that can't
186 # handle really large integers.
190 # handle really large integers.
187 with self.assertRaises((MemoryError, OverflowError)):
191 with self.assertRaises((MemoryError, OverflowError)):
188 dctx.decompress(compressed, max_output_size=2 ** 62)
192 dctx.decompress(compressed, max_output_size=2 ** 62)
189
193
190 def test_dictionary(self):
194 def test_dictionary(self):
191 samples = []
195 samples = []
192 for i in range(128):
196 for i in range(128):
193 samples.append(b"foo" * 64)
197 samples.append(b"foo" * 64)
194 samples.append(b"bar" * 64)
198 samples.append(b"bar" * 64)
195 samples.append(b"foobar" * 64)
199 samples.append(b"foobar" * 64)
196
200
197 d = zstd.train_dictionary(8192, samples)
201 d = zstd.train_dictionary(8192, samples)
198
202
199 orig = b"foobar" * 16384
203 orig = b"foobar" * 16384
200 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
204 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
201 compressed = cctx.compress(orig)
205 compressed = cctx.compress(orig)
202
206
203 dctx = zstd.ZstdDecompressor(dict_data=d)
207 dctx = zstd.ZstdDecompressor(dict_data=d)
204 decompressed = dctx.decompress(compressed)
208 decompressed = dctx.decompress(compressed)
205
209
206 self.assertEqual(decompressed, orig)
210 self.assertEqual(decompressed, orig)
207
211
208 def test_dictionary_multiple(self):
212 def test_dictionary_multiple(self):
209 samples = []
213 samples = []
210 for i in range(128):
214 for i in range(128):
211 samples.append(b"foo" * 64)
215 samples.append(b"foo" * 64)
212 samples.append(b"bar" * 64)
216 samples.append(b"bar" * 64)
213 samples.append(b"foobar" * 64)
217 samples.append(b"foobar" * 64)
214
218
215 d = zstd.train_dictionary(8192, samples)
219 d = zstd.train_dictionary(8192, samples)
216
220
217 sources = (b"foobar" * 8192, b"foo" * 8192, b"bar" * 8192)
221 sources = (b"foobar" * 8192, b"foo" * 8192, b"bar" * 8192)
218 compressed = []
222 compressed = []
219 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
223 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
220 for source in sources:
224 for source in sources:
221 compressed.append(cctx.compress(source))
225 compressed.append(cctx.compress(source))
222
226
223 dctx = zstd.ZstdDecompressor(dict_data=d)
227 dctx = zstd.ZstdDecompressor(dict_data=d)
224 for i in range(len(sources)):
228 for i in range(len(sources)):
225 decompressed = dctx.decompress(compressed[i])
229 decompressed = dctx.decompress(compressed[i])
226 self.assertEqual(decompressed, sources[i])
230 self.assertEqual(decompressed, sources[i])
227
231
228 def test_max_window_size(self):
232 def test_max_window_size(self):
229 with open(__file__, "rb") as fh:
233 with open(__file__, "rb") as fh:
230 source = fh.read()
234 source = fh.read()
231
235
232 # If we write a content size, the decompressor engages single pass
236 # If we write a content size, the decompressor engages single pass
233 # mode and the window size doesn't come into play.
237 # mode and the window size doesn't come into play.
234 cctx = zstd.ZstdCompressor(write_content_size=False)
238 cctx = zstd.ZstdCompressor(write_content_size=False)
235 frame = cctx.compress(source)
239 frame = cctx.compress(source)
236
240
237 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
241 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
238
242
239 with self.assertRaisesRegex(
243 with self.assertRaisesRegex(
240 zstd.ZstdError, "decompression error: Frame requires too much memory"
244 zstd.ZstdError,
245 "decompression error: Frame requires too much memory",
241 ):
246 ):
242 dctx.decompress(frame, max_output_size=len(source))
247 dctx.decompress(frame, max_output_size=len(source))
243
248
244
249
245 @make_cffi
250 @make_cffi
246 class TestDecompressor_copy_stream(TestCase):
251 class TestDecompressor_copy_stream(TestCase):
247 def test_no_read(self):
252 def test_no_read(self):
248 source = object()
253 source = object()
249 dest = io.BytesIO()
254 dest = io.BytesIO()
250
255
251 dctx = zstd.ZstdDecompressor()
256 dctx = zstd.ZstdDecompressor()
252 with self.assertRaises(ValueError):
257 with self.assertRaises(ValueError):
253 dctx.copy_stream(source, dest)
258 dctx.copy_stream(source, dest)
254
259
255 def test_no_write(self):
260 def test_no_write(self):
256 source = io.BytesIO()
261 source = io.BytesIO()
257 dest = object()
262 dest = object()
258
263
259 dctx = zstd.ZstdDecompressor()
264 dctx = zstd.ZstdDecompressor()
260 with self.assertRaises(ValueError):
265 with self.assertRaises(ValueError):
261 dctx.copy_stream(source, dest)
266 dctx.copy_stream(source, dest)
262
267
263 def test_empty(self):
268 def test_empty(self):
264 source = io.BytesIO()
269 source = io.BytesIO()
265 dest = io.BytesIO()
270 dest = io.BytesIO()
266
271
267 dctx = zstd.ZstdDecompressor()
272 dctx = zstd.ZstdDecompressor()
268 # TODO should this raise an error?
273 # TODO should this raise an error?
269 r, w = dctx.copy_stream(source, dest)
274 r, w = dctx.copy_stream(source, dest)
270
275
271 self.assertEqual(r, 0)
276 self.assertEqual(r, 0)
272 self.assertEqual(w, 0)
277 self.assertEqual(w, 0)
273 self.assertEqual(dest.getvalue(), b"")
278 self.assertEqual(dest.getvalue(), b"")
274
279
275 def test_large_data(self):
280 def test_large_data(self):
276 source = io.BytesIO()
281 source = io.BytesIO()
277 for i in range(255):
282 for i in range(255):
278 source.write(struct.Struct(">B").pack(i) * 16384)
283 source.write(struct.Struct(">B").pack(i) * 16384)
279 source.seek(0)
284 source.seek(0)
280
285
281 compressed = io.BytesIO()
286 compressed = io.BytesIO()
282 cctx = zstd.ZstdCompressor()
287 cctx = zstd.ZstdCompressor()
283 cctx.copy_stream(source, compressed)
288 cctx.copy_stream(source, compressed)
284
289
285 compressed.seek(0)
290 compressed.seek(0)
286 dest = io.BytesIO()
291 dest = io.BytesIO()
287 dctx = zstd.ZstdDecompressor()
292 dctx = zstd.ZstdDecompressor()
288 r, w = dctx.copy_stream(compressed, dest)
293 r, w = dctx.copy_stream(compressed, dest)
289
294
290 self.assertEqual(r, len(compressed.getvalue()))
295 self.assertEqual(r, len(compressed.getvalue()))
291 self.assertEqual(w, len(source.getvalue()))
296 self.assertEqual(w, len(source.getvalue()))
292
297
293 def test_read_write_size(self):
298 def test_read_write_size(self):
294 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
299 source = OpCountingBytesIO(
300 zstd.ZstdCompressor().compress(b"foobarfoobar")
301 )
295
302
296 dest = OpCountingBytesIO()
303 dest = OpCountingBytesIO()
297 dctx = zstd.ZstdDecompressor()
304 dctx = zstd.ZstdDecompressor()
298 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
305 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
299
306
300 self.assertEqual(r, len(source.getvalue()))
307 self.assertEqual(r, len(source.getvalue()))
301 self.assertEqual(w, len(b"foobarfoobar"))
308 self.assertEqual(w, len(b"foobarfoobar"))
302 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
309 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
303 self.assertEqual(dest._write_count, len(dest.getvalue()))
310 self.assertEqual(dest._write_count, len(dest.getvalue()))
304
311
305
312
306 @make_cffi
313 @make_cffi
307 class TestDecompressor_stream_reader(TestCase):
314 class TestDecompressor_stream_reader(TestCase):
308 def test_context_manager(self):
315 def test_context_manager(self):
309 dctx = zstd.ZstdDecompressor()
316 dctx = zstd.ZstdDecompressor()
310
317
311 with dctx.stream_reader(b"foo") as reader:
318 with dctx.stream_reader(b"foo") as reader:
312 with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
319 with self.assertRaisesRegex(
320 ValueError, "cannot __enter__ multiple times"
321 ):
313 with reader as reader2:
322 with reader as reader2:
314 pass
323 pass
315
324
316 def test_not_implemented(self):
325 def test_not_implemented(self):
317 dctx = zstd.ZstdDecompressor()
326 dctx = zstd.ZstdDecompressor()
318
327
319 with dctx.stream_reader(b"foo") as reader:
328 with dctx.stream_reader(b"foo") as reader:
320 with self.assertRaises(io.UnsupportedOperation):
329 with self.assertRaises(io.UnsupportedOperation):
321 reader.readline()
330 reader.readline()
322
331
323 with self.assertRaises(io.UnsupportedOperation):
332 with self.assertRaises(io.UnsupportedOperation):
324 reader.readlines()
333 reader.readlines()
325
334
326 with self.assertRaises(io.UnsupportedOperation):
335 with self.assertRaises(io.UnsupportedOperation):
327 iter(reader)
336 iter(reader)
328
337
329 with self.assertRaises(io.UnsupportedOperation):
338 with self.assertRaises(io.UnsupportedOperation):
330 next(reader)
339 next(reader)
331
340
332 with self.assertRaises(io.UnsupportedOperation):
341 with self.assertRaises(io.UnsupportedOperation):
333 reader.write(b"foo")
342 reader.write(b"foo")
334
343
335 with self.assertRaises(io.UnsupportedOperation):
344 with self.assertRaises(io.UnsupportedOperation):
336 reader.writelines([])
345 reader.writelines([])
337
346
338 def test_constant_methods(self):
347 def test_constant_methods(self):
339 dctx = zstd.ZstdDecompressor()
348 dctx = zstd.ZstdDecompressor()
340
349
341 with dctx.stream_reader(b"foo") as reader:
350 with dctx.stream_reader(b"foo") as reader:
342 self.assertFalse(reader.closed)
351 self.assertFalse(reader.closed)
343 self.assertTrue(reader.readable())
352 self.assertTrue(reader.readable())
344 self.assertFalse(reader.writable())
353 self.assertFalse(reader.writable())
345 self.assertTrue(reader.seekable())
354 self.assertTrue(reader.seekable())
346 self.assertFalse(reader.isatty())
355 self.assertFalse(reader.isatty())
347 self.assertFalse(reader.closed)
356 self.assertFalse(reader.closed)
348 self.assertIsNone(reader.flush())
357 self.assertIsNone(reader.flush())
349 self.assertFalse(reader.closed)
358 self.assertFalse(reader.closed)
350
359
351 self.assertTrue(reader.closed)
360 self.assertTrue(reader.closed)
352
361
353 def test_read_closed(self):
362 def test_read_closed(self):
354 dctx = zstd.ZstdDecompressor()
363 dctx = zstd.ZstdDecompressor()
355
364
356 with dctx.stream_reader(b"foo") as reader:
365 with dctx.stream_reader(b"foo") as reader:
357 reader.close()
366 reader.close()
358 self.assertTrue(reader.closed)
367 self.assertTrue(reader.closed)
359 with self.assertRaisesRegex(ValueError, "stream is closed"):
368 with self.assertRaisesRegex(ValueError, "stream is closed"):
360 reader.read(1)
369 reader.read(1)
361
370
362 def test_read_sizes(self):
371 def test_read_sizes(self):
363 cctx = zstd.ZstdCompressor()
372 cctx = zstd.ZstdCompressor()
364 foo = cctx.compress(b"foo")
373 foo = cctx.compress(b"foo")
365
374
366 dctx = zstd.ZstdDecompressor()
375 dctx = zstd.ZstdDecompressor()
367
376
368 with dctx.stream_reader(foo) as reader:
377 with dctx.stream_reader(foo) as reader:
369 with self.assertRaisesRegex(
378 with self.assertRaisesRegex(
370 ValueError, "cannot read negative amounts less than -1"
379 ValueError, "cannot read negative amounts less than -1"
371 ):
380 ):
372 reader.read(-2)
381 reader.read(-2)
373
382
374 self.assertEqual(reader.read(0), b"")
383 self.assertEqual(reader.read(0), b"")
375 self.assertEqual(reader.read(), b"foo")
384 self.assertEqual(reader.read(), b"foo")
376
385
377 def test_read_buffer(self):
386 def test_read_buffer(self):
378 cctx = zstd.ZstdCompressor()
387 cctx = zstd.ZstdCompressor()
379
388
380 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
389 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
381 frame = cctx.compress(source)
390 frame = cctx.compress(source)
382
391
383 dctx = zstd.ZstdDecompressor()
392 dctx = zstd.ZstdDecompressor()
384
393
385 with dctx.stream_reader(frame) as reader:
394 with dctx.stream_reader(frame) as reader:
386 self.assertEqual(reader.tell(), 0)
395 self.assertEqual(reader.tell(), 0)
387
396
388 # We should get entire frame in one read.
397 # We should get entire frame in one read.
389 result = reader.read(8192)
398 result = reader.read(8192)
390 self.assertEqual(result, source)
399 self.assertEqual(result, source)
391 self.assertEqual(reader.tell(), len(source))
400 self.assertEqual(reader.tell(), len(source))
392
401
393 # Read after EOF should return empty bytes.
402 # Read after EOF should return empty bytes.
394 self.assertEqual(reader.read(1), b"")
403 self.assertEqual(reader.read(1), b"")
395 self.assertEqual(reader.tell(), len(result))
404 self.assertEqual(reader.tell(), len(result))
396
405
397 self.assertTrue(reader.closed)
406 self.assertTrue(reader.closed)
398
407
399 def test_read_buffer_small_chunks(self):
408 def test_read_buffer_small_chunks(self):
400 cctx = zstd.ZstdCompressor()
409 cctx = zstd.ZstdCompressor()
401 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
410 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
402 frame = cctx.compress(source)
411 frame = cctx.compress(source)
403
412
404 dctx = zstd.ZstdDecompressor()
413 dctx = zstd.ZstdDecompressor()
405 chunks = []
414 chunks = []
406
415
407 with dctx.stream_reader(frame, read_size=1) as reader:
416 with dctx.stream_reader(frame, read_size=1) as reader:
408 while True:
417 while True:
409 chunk = reader.read(1)
418 chunk = reader.read(1)
410 if not chunk:
419 if not chunk:
411 break
420 break
412
421
413 chunks.append(chunk)
422 chunks.append(chunk)
414 self.assertEqual(reader.tell(), sum(map(len, chunks)))
423 self.assertEqual(reader.tell(), sum(map(len, chunks)))
415
424
416 self.assertEqual(b"".join(chunks), source)
425 self.assertEqual(b"".join(chunks), source)
417
426
418 def test_read_stream(self):
427 def test_read_stream(self):
419 cctx = zstd.ZstdCompressor()
428 cctx = zstd.ZstdCompressor()
420 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
429 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
421 frame = cctx.compress(source)
430 frame = cctx.compress(source)
422
431
423 dctx = zstd.ZstdDecompressor()
432 dctx = zstd.ZstdDecompressor()
424 with dctx.stream_reader(io.BytesIO(frame)) as reader:
433 with dctx.stream_reader(io.BytesIO(frame)) as reader:
425 self.assertEqual(reader.tell(), 0)
434 self.assertEqual(reader.tell(), 0)
426
435
427 chunk = reader.read(8192)
436 chunk = reader.read(8192)
428 self.assertEqual(chunk, source)
437 self.assertEqual(chunk, source)
429 self.assertEqual(reader.tell(), len(source))
438 self.assertEqual(reader.tell(), len(source))
430 self.assertEqual(reader.read(1), b"")
439 self.assertEqual(reader.read(1), b"")
431 self.assertEqual(reader.tell(), len(source))
440 self.assertEqual(reader.tell(), len(source))
432 self.assertFalse(reader.closed)
441 self.assertFalse(reader.closed)
433
442
434 self.assertTrue(reader.closed)
443 self.assertTrue(reader.closed)
435
444
436 def test_read_stream_small_chunks(self):
445 def test_read_stream_small_chunks(self):
437 cctx = zstd.ZstdCompressor()
446 cctx = zstd.ZstdCompressor()
438 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
447 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
439 frame = cctx.compress(source)
448 frame = cctx.compress(source)
440
449
441 dctx = zstd.ZstdDecompressor()
450 dctx = zstd.ZstdDecompressor()
442 chunks = []
451 chunks = []
443
452
444 with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader:
453 with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader:
445 while True:
454 while True:
446 chunk = reader.read(1)
455 chunk = reader.read(1)
447 if not chunk:
456 if not chunk:
448 break
457 break
449
458
450 chunks.append(chunk)
459 chunks.append(chunk)
451 self.assertEqual(reader.tell(), sum(map(len, chunks)))
460 self.assertEqual(reader.tell(), sum(map(len, chunks)))
452
461
453 self.assertEqual(b"".join(chunks), source)
462 self.assertEqual(b"".join(chunks), source)
454
463
455 def test_read_after_exit(self):
464 def test_read_after_exit(self):
456 cctx = zstd.ZstdCompressor()
465 cctx = zstd.ZstdCompressor()
457 frame = cctx.compress(b"foo" * 60)
466 frame = cctx.compress(b"foo" * 60)
458
467
459 dctx = zstd.ZstdDecompressor()
468 dctx = zstd.ZstdDecompressor()
460
469
461 with dctx.stream_reader(frame) as reader:
470 with dctx.stream_reader(frame) as reader:
462 while reader.read(16):
471 while reader.read(16):
463 pass
472 pass
464
473
465 self.assertTrue(reader.closed)
474 self.assertTrue(reader.closed)
466
475
467 with self.assertRaisesRegex(ValueError, "stream is closed"):
476 with self.assertRaisesRegex(ValueError, "stream is closed"):
468 reader.read(10)
477 reader.read(10)
469
478
470 def test_illegal_seeks(self):
479 def test_illegal_seeks(self):
471 cctx = zstd.ZstdCompressor()
480 cctx = zstd.ZstdCompressor()
472 frame = cctx.compress(b"foo" * 60)
481 frame = cctx.compress(b"foo" * 60)
473
482
474 dctx = zstd.ZstdDecompressor()
483 dctx = zstd.ZstdDecompressor()
475
484
476 with dctx.stream_reader(frame) as reader:
485 with dctx.stream_reader(frame) as reader:
477 with self.assertRaisesRegex(ValueError, "cannot seek to negative position"):
486 with self.assertRaisesRegex(
487 ValueError, "cannot seek to negative position"
488 ):
478 reader.seek(-1, os.SEEK_SET)
489 reader.seek(-1, os.SEEK_SET)
479
490
480 reader.read(1)
491 reader.read(1)
481
492
482 with self.assertRaisesRegex(
493 with self.assertRaisesRegex(
483 ValueError, "cannot seek zstd decompression stream backwards"
494 ValueError, "cannot seek zstd decompression stream backwards"
484 ):
495 ):
485 reader.seek(0, os.SEEK_SET)
496 reader.seek(0, os.SEEK_SET)
486
497
487 with self.assertRaisesRegex(
498 with self.assertRaisesRegex(
488 ValueError, "cannot seek zstd decompression stream backwards"
499 ValueError, "cannot seek zstd decompression stream backwards"
489 ):
500 ):
490 reader.seek(-1, os.SEEK_CUR)
501 reader.seek(-1, os.SEEK_CUR)
491
502
492 with self.assertRaisesRegex(
503 with self.assertRaisesRegex(
493 ValueError, "zstd decompression streams cannot be seeked with SEEK_END"
504 ValueError,
505 "zstd decompression streams cannot be seeked with SEEK_END",
494 ):
506 ):
495 reader.seek(0, os.SEEK_END)
507 reader.seek(0, os.SEEK_END)
496
508
497 reader.close()
509 reader.close()
498
510
499 with self.assertRaisesRegex(ValueError, "stream is closed"):
511 with self.assertRaisesRegex(ValueError, "stream is closed"):
500 reader.seek(4, os.SEEK_SET)
512 reader.seek(4, os.SEEK_SET)
501
513
502 with self.assertRaisesRegex(ValueError, "stream is closed"):
514 with self.assertRaisesRegex(ValueError, "stream is closed"):
503 reader.seek(0)
515 reader.seek(0)
504
516
505 def test_seek(self):
517 def test_seek(self):
506 source = b"foobar" * 60
518 source = b"foobar" * 60
507 cctx = zstd.ZstdCompressor()
519 cctx = zstd.ZstdCompressor()
508 frame = cctx.compress(source)
520 frame = cctx.compress(source)
509
521
510 dctx = zstd.ZstdDecompressor()
522 dctx = zstd.ZstdDecompressor()
511
523
512 with dctx.stream_reader(frame) as reader:
524 with dctx.stream_reader(frame) as reader:
513 reader.seek(3)
525 reader.seek(3)
514 self.assertEqual(reader.read(3), b"bar")
526 self.assertEqual(reader.read(3), b"bar")
515
527
516 reader.seek(4, os.SEEK_CUR)
528 reader.seek(4, os.SEEK_CUR)
517 self.assertEqual(reader.read(2), b"ar")
529 self.assertEqual(reader.read(2), b"ar")
518
530
519 def test_no_context_manager(self):
531 def test_no_context_manager(self):
520 source = b"foobar" * 60
532 source = b"foobar" * 60
521 cctx = zstd.ZstdCompressor()
533 cctx = zstd.ZstdCompressor()
522 frame = cctx.compress(source)
534 frame = cctx.compress(source)
523
535
524 dctx = zstd.ZstdDecompressor()
536 dctx = zstd.ZstdDecompressor()
525 reader = dctx.stream_reader(frame)
537 reader = dctx.stream_reader(frame)
526
538
527 self.assertEqual(reader.read(6), b"foobar")
539 self.assertEqual(reader.read(6), b"foobar")
528 self.assertEqual(reader.read(18), b"foobar" * 3)
540 self.assertEqual(reader.read(18), b"foobar" * 3)
529 self.assertFalse(reader.closed)
541 self.assertFalse(reader.closed)
530
542
531 # Calling close prevents subsequent use.
543 # Calling close prevents subsequent use.
532 reader.close()
544 reader.close()
533 self.assertTrue(reader.closed)
545 self.assertTrue(reader.closed)
534
546
535 with self.assertRaisesRegex(ValueError, "stream is closed"):
547 with self.assertRaisesRegex(ValueError, "stream is closed"):
536 reader.read(6)
548 reader.read(6)
537
549
538 def test_read_after_error(self):
550 def test_read_after_error(self):
539 source = io.BytesIO(b"")
551 source = io.BytesIO(b"")
540 dctx = zstd.ZstdDecompressor()
552 dctx = zstd.ZstdDecompressor()
541
553
542 reader = dctx.stream_reader(source)
554 reader = dctx.stream_reader(source)
543
555
544 with reader:
556 with reader:
545 reader.read(0)
557 reader.read(0)
546
558
547 with reader:
559 with reader:
548 with self.assertRaisesRegex(ValueError, "stream is closed"):
560 with self.assertRaisesRegex(ValueError, "stream is closed"):
549 reader.read(100)
561 reader.read(100)
550
562
551 def test_partial_read(self):
563 def test_partial_read(self):
552 # Inspired by https://github.com/indygreg/python-zstandard/issues/71.
564 # Inspired by https://github.com/indygreg/python-zstandard/issues/71.
553 buffer = io.BytesIO()
565 buffer = io.BytesIO()
554 cctx = zstd.ZstdCompressor()
566 cctx = zstd.ZstdCompressor()
555 writer = cctx.stream_writer(buffer)
567 writer = cctx.stream_writer(buffer)
556 writer.write(bytearray(os.urandom(1000000)))
568 writer.write(bytearray(os.urandom(1000000)))
557 writer.flush(zstd.FLUSH_FRAME)
569 writer.flush(zstd.FLUSH_FRAME)
558 buffer.seek(0)
570 buffer.seek(0)
559
571
560 dctx = zstd.ZstdDecompressor()
572 dctx = zstd.ZstdDecompressor()
561 reader = dctx.stream_reader(buffer)
573 reader = dctx.stream_reader(buffer)
562
574
563 while True:
575 while True:
564 chunk = reader.read(8192)
576 chunk = reader.read(8192)
565 if not chunk:
577 if not chunk:
566 break
578 break
567
579
568 def test_read_multiple_frames(self):
580 def test_read_multiple_frames(self):
569 cctx = zstd.ZstdCompressor()
581 cctx = zstd.ZstdCompressor()
570 source = io.BytesIO()
582 source = io.BytesIO()
571 writer = cctx.stream_writer(source)
583 writer = cctx.stream_writer(source)
572 writer.write(b"foo")
584 writer.write(b"foo")
573 writer.flush(zstd.FLUSH_FRAME)
585 writer.flush(zstd.FLUSH_FRAME)
574 writer.write(b"bar")
586 writer.write(b"bar")
575 writer.flush(zstd.FLUSH_FRAME)
587 writer.flush(zstd.FLUSH_FRAME)
576
588
577 dctx = zstd.ZstdDecompressor()
589 dctx = zstd.ZstdDecompressor()
578
590
579 reader = dctx.stream_reader(source.getvalue())
591 reader = dctx.stream_reader(source.getvalue())
580 self.assertEqual(reader.read(2), b"fo")
592 self.assertEqual(reader.read(2), b"fo")
581 self.assertEqual(reader.read(2), b"o")
593 self.assertEqual(reader.read(2), b"o")
582 self.assertEqual(reader.read(2), b"ba")
594 self.assertEqual(reader.read(2), b"ba")
583 self.assertEqual(reader.read(2), b"r")
595 self.assertEqual(reader.read(2), b"r")
584
596
585 source.seek(0)
597 source.seek(0)
586 reader = dctx.stream_reader(source)
598 reader = dctx.stream_reader(source)
587 self.assertEqual(reader.read(2), b"fo")
599 self.assertEqual(reader.read(2), b"fo")
588 self.assertEqual(reader.read(2), b"o")
600 self.assertEqual(reader.read(2), b"o")
589 self.assertEqual(reader.read(2), b"ba")
601 self.assertEqual(reader.read(2), b"ba")
590 self.assertEqual(reader.read(2), b"r")
602 self.assertEqual(reader.read(2), b"r")
591
603
592 reader = dctx.stream_reader(source.getvalue())
604 reader = dctx.stream_reader(source.getvalue())
593 self.assertEqual(reader.read(3), b"foo")
605 self.assertEqual(reader.read(3), b"foo")
594 self.assertEqual(reader.read(3), b"bar")
606 self.assertEqual(reader.read(3), b"bar")
595
607
596 source.seek(0)
608 source.seek(0)
597 reader = dctx.stream_reader(source)
609 reader = dctx.stream_reader(source)
598 self.assertEqual(reader.read(3), b"foo")
610 self.assertEqual(reader.read(3), b"foo")
599 self.assertEqual(reader.read(3), b"bar")
611 self.assertEqual(reader.read(3), b"bar")
600
612
601 reader = dctx.stream_reader(source.getvalue())
613 reader = dctx.stream_reader(source.getvalue())
602 self.assertEqual(reader.read(4), b"foo")
614 self.assertEqual(reader.read(4), b"foo")
603 self.assertEqual(reader.read(4), b"bar")
615 self.assertEqual(reader.read(4), b"bar")
604
616
605 source.seek(0)
617 source.seek(0)
606 reader = dctx.stream_reader(source)
618 reader = dctx.stream_reader(source)
607 self.assertEqual(reader.read(4), b"foo")
619 self.assertEqual(reader.read(4), b"foo")
608 self.assertEqual(reader.read(4), b"bar")
620 self.assertEqual(reader.read(4), b"bar")
609
621
610 reader = dctx.stream_reader(source.getvalue())
622 reader = dctx.stream_reader(source.getvalue())
611 self.assertEqual(reader.read(128), b"foo")
623 self.assertEqual(reader.read(128), b"foo")
612 self.assertEqual(reader.read(128), b"bar")
624 self.assertEqual(reader.read(128), b"bar")
613
625
614 source.seek(0)
626 source.seek(0)
615 reader = dctx.stream_reader(source)
627 reader = dctx.stream_reader(source)
616 self.assertEqual(reader.read(128), b"foo")
628 self.assertEqual(reader.read(128), b"foo")
617 self.assertEqual(reader.read(128), b"bar")
629 self.assertEqual(reader.read(128), b"bar")
618
630
619 # Now tests for reads spanning frames.
631 # Now tests for reads spanning frames.
620 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
632 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
621 self.assertEqual(reader.read(3), b"foo")
633 self.assertEqual(reader.read(3), b"foo")
622 self.assertEqual(reader.read(3), b"bar")
634 self.assertEqual(reader.read(3), b"bar")
623
635
624 source.seek(0)
636 source.seek(0)
625 reader = dctx.stream_reader(source, read_across_frames=True)
637 reader = dctx.stream_reader(source, read_across_frames=True)
626 self.assertEqual(reader.read(3), b"foo")
638 self.assertEqual(reader.read(3), b"foo")
627 self.assertEqual(reader.read(3), b"bar")
639 self.assertEqual(reader.read(3), b"bar")
628
640
629 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
641 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
630 self.assertEqual(reader.read(6), b"foobar")
642 self.assertEqual(reader.read(6), b"foobar")
631
643
632 source.seek(0)
644 source.seek(0)
633 reader = dctx.stream_reader(source, read_across_frames=True)
645 reader = dctx.stream_reader(source, read_across_frames=True)
634 self.assertEqual(reader.read(6), b"foobar")
646 self.assertEqual(reader.read(6), b"foobar")
635
647
636 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
648 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
637 self.assertEqual(reader.read(7), b"foobar")
649 self.assertEqual(reader.read(7), b"foobar")
638
650
639 source.seek(0)
651 source.seek(0)
640 reader = dctx.stream_reader(source, read_across_frames=True)
652 reader = dctx.stream_reader(source, read_across_frames=True)
641 self.assertEqual(reader.read(7), b"foobar")
653 self.assertEqual(reader.read(7), b"foobar")
642
654
643 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
655 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
644 self.assertEqual(reader.read(128), b"foobar")
656 self.assertEqual(reader.read(128), b"foobar")
645
657
646 source.seek(0)
658 source.seek(0)
647 reader = dctx.stream_reader(source, read_across_frames=True)
659 reader = dctx.stream_reader(source, read_across_frames=True)
648 self.assertEqual(reader.read(128), b"foobar")
660 self.assertEqual(reader.read(128), b"foobar")
649
661
650 def test_readinto(self):
662 def test_readinto(self):
651 cctx = zstd.ZstdCompressor()
663 cctx = zstd.ZstdCompressor()
652 foo = cctx.compress(b"foo")
664 foo = cctx.compress(b"foo")
653
665
654 dctx = zstd.ZstdDecompressor()
666 dctx = zstd.ZstdDecompressor()
655
667
656 # Attempting to readinto() a non-writable buffer fails.
668 # Attempting to readinto() a non-writable buffer fails.
657 # The exact exception varies based on the backend.
669 # The exact exception varies based on the backend.
658 reader = dctx.stream_reader(foo)
670 reader = dctx.stream_reader(foo)
659 with self.assertRaises(Exception):
671 with self.assertRaises(Exception):
660 reader.readinto(b"foobar")
672 reader.readinto(b"foobar")
661
673
662 # readinto() with sufficiently large destination.
674 # readinto() with sufficiently large destination.
663 b = bytearray(1024)
675 b = bytearray(1024)
664 reader = dctx.stream_reader(foo)
676 reader = dctx.stream_reader(foo)
665 self.assertEqual(reader.readinto(b), 3)
677 self.assertEqual(reader.readinto(b), 3)
666 self.assertEqual(b[0:3], b"foo")
678 self.assertEqual(b[0:3], b"foo")
667 self.assertEqual(reader.readinto(b), 0)
679 self.assertEqual(reader.readinto(b), 0)
668 self.assertEqual(b[0:3], b"foo")
680 self.assertEqual(b[0:3], b"foo")
669
681
670 # readinto() with small reads.
682 # readinto() with small reads.
671 b = bytearray(1024)
683 b = bytearray(1024)
672 reader = dctx.stream_reader(foo, read_size=1)
684 reader = dctx.stream_reader(foo, read_size=1)
673 self.assertEqual(reader.readinto(b), 3)
685 self.assertEqual(reader.readinto(b), 3)
674 self.assertEqual(b[0:3], b"foo")
686 self.assertEqual(b[0:3], b"foo")
675
687
676 # Too small destination buffer.
688 # Too small destination buffer.
677 b = bytearray(2)
689 b = bytearray(2)
678 reader = dctx.stream_reader(foo)
690 reader = dctx.stream_reader(foo)
679 self.assertEqual(reader.readinto(b), 2)
691 self.assertEqual(reader.readinto(b), 2)
680 self.assertEqual(b[:], b"fo")
692 self.assertEqual(b[:], b"fo")
681
693
682 def test_readinto1(self):
694 def test_readinto1(self):
683 cctx = zstd.ZstdCompressor()
695 cctx = zstd.ZstdCompressor()
684 foo = cctx.compress(b"foo")
696 foo = cctx.compress(b"foo")
685
697
686 dctx = zstd.ZstdDecompressor()
698 dctx = zstd.ZstdDecompressor()
687
699
688 reader = dctx.stream_reader(foo)
700 reader = dctx.stream_reader(foo)
689 with self.assertRaises(Exception):
701 with self.assertRaises(Exception):
690 reader.readinto1(b"foobar")
702 reader.readinto1(b"foobar")
691
703
692 # Sufficiently large destination.
704 # Sufficiently large destination.
693 b = bytearray(1024)
705 b = bytearray(1024)
694 reader = dctx.stream_reader(foo)
706 reader = dctx.stream_reader(foo)
695 self.assertEqual(reader.readinto1(b), 3)
707 self.assertEqual(reader.readinto1(b), 3)
696 self.assertEqual(b[0:3], b"foo")
708 self.assertEqual(b[0:3], b"foo")
697 self.assertEqual(reader.readinto1(b), 0)
709 self.assertEqual(reader.readinto1(b), 0)
698 self.assertEqual(b[0:3], b"foo")
710 self.assertEqual(b[0:3], b"foo")
699
711
700 # readinto() with small reads.
712 # readinto() with small reads.
701 b = bytearray(1024)
713 b = bytearray(1024)
702 reader = dctx.stream_reader(foo, read_size=1)
714 reader = dctx.stream_reader(foo, read_size=1)
703 self.assertEqual(reader.readinto1(b), 3)
715 self.assertEqual(reader.readinto1(b), 3)
704 self.assertEqual(b[0:3], b"foo")
716 self.assertEqual(b[0:3], b"foo")
705
717
706 # Too small destination buffer.
718 # Too small destination buffer.
707 b = bytearray(2)
719 b = bytearray(2)
708 reader = dctx.stream_reader(foo)
720 reader = dctx.stream_reader(foo)
709 self.assertEqual(reader.readinto1(b), 2)
721 self.assertEqual(reader.readinto1(b), 2)
710 self.assertEqual(b[:], b"fo")
722 self.assertEqual(b[:], b"fo")
711
723
712 def test_readall(self):
724 def test_readall(self):
713 cctx = zstd.ZstdCompressor()
725 cctx = zstd.ZstdCompressor()
714 foo = cctx.compress(b"foo")
726 foo = cctx.compress(b"foo")
715
727
716 dctx = zstd.ZstdDecompressor()
728 dctx = zstd.ZstdDecompressor()
717 reader = dctx.stream_reader(foo)
729 reader = dctx.stream_reader(foo)
718
730
719 self.assertEqual(reader.readall(), b"foo")
731 self.assertEqual(reader.readall(), b"foo")
720
732
721 def test_read1(self):
733 def test_read1(self):
722 cctx = zstd.ZstdCompressor()
734 cctx = zstd.ZstdCompressor()
723 foo = cctx.compress(b"foo")
735 foo = cctx.compress(b"foo")
724
736
725 dctx = zstd.ZstdDecompressor()
737 dctx = zstd.ZstdDecompressor()
726
738
727 b = OpCountingBytesIO(foo)
739 b = OpCountingBytesIO(foo)
728 reader = dctx.stream_reader(b)
740 reader = dctx.stream_reader(b)
729
741
730 self.assertEqual(reader.read1(), b"foo")
742 self.assertEqual(reader.read1(), b"foo")
731 self.assertEqual(b._read_count, 1)
743 self.assertEqual(b._read_count, 1)
732
744
733 b = OpCountingBytesIO(foo)
745 b = OpCountingBytesIO(foo)
734 reader = dctx.stream_reader(b)
746 reader = dctx.stream_reader(b)
735
747
736 self.assertEqual(reader.read1(0), b"")
748 self.assertEqual(reader.read1(0), b"")
737 self.assertEqual(reader.read1(2), b"fo")
749 self.assertEqual(reader.read1(2), b"fo")
738 self.assertEqual(b._read_count, 1)
750 self.assertEqual(b._read_count, 1)
739 self.assertEqual(reader.read1(1), b"o")
751 self.assertEqual(reader.read1(1), b"o")
740 self.assertEqual(b._read_count, 1)
752 self.assertEqual(b._read_count, 1)
741 self.assertEqual(reader.read1(1), b"")
753 self.assertEqual(reader.read1(1), b"")
742 self.assertEqual(b._read_count, 2)
754 self.assertEqual(b._read_count, 2)
743
755
744 def test_read_lines(self):
756 def test_read_lines(self):
745 cctx = zstd.ZstdCompressor()
757 cctx = zstd.ZstdCompressor()
746 source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024))
758 source = b"\n".join(
759 ("line %d" % i).encode("ascii") for i in range(1024)
760 )
747
761
748 frame = cctx.compress(source)
762 frame = cctx.compress(source)
749
763
750 dctx = zstd.ZstdDecompressor()
764 dctx = zstd.ZstdDecompressor()
751 reader = dctx.stream_reader(frame)
765 reader = dctx.stream_reader(frame)
752 tr = io.TextIOWrapper(reader, encoding="utf-8")
766 tr = io.TextIOWrapper(reader, encoding="utf-8")
753
767
754 lines = []
768 lines = []
755 for line in tr:
769 for line in tr:
756 lines.append(line.encode("utf-8"))
770 lines.append(line.encode("utf-8"))
757
771
758 self.assertEqual(len(lines), 1024)
772 self.assertEqual(len(lines), 1024)
759 self.assertEqual(b"".join(lines), source)
773 self.assertEqual(b"".join(lines), source)
760
774
761 reader = dctx.stream_reader(frame)
775 reader = dctx.stream_reader(frame)
762 tr = io.TextIOWrapper(reader, encoding="utf-8")
776 tr = io.TextIOWrapper(reader, encoding="utf-8")
763
777
764 lines = tr.readlines()
778 lines = tr.readlines()
765 self.assertEqual(len(lines), 1024)
779 self.assertEqual(len(lines), 1024)
766 self.assertEqual("".join(lines).encode("utf-8"), source)
780 self.assertEqual("".join(lines).encode("utf-8"), source)
767
781
768 reader = dctx.stream_reader(frame)
782 reader = dctx.stream_reader(frame)
769 tr = io.TextIOWrapper(reader, encoding="utf-8")
783 tr = io.TextIOWrapper(reader, encoding="utf-8")
770
784
771 lines = []
785 lines = []
772 while True:
786 while True:
773 line = tr.readline()
787 line = tr.readline()
774 if not line:
788 if not line:
775 break
789 break
776
790
777 lines.append(line.encode("utf-8"))
791 lines.append(line.encode("utf-8"))
778
792
779 self.assertEqual(len(lines), 1024)
793 self.assertEqual(len(lines), 1024)
780 self.assertEqual(b"".join(lines), source)
794 self.assertEqual(b"".join(lines), source)
781
795
782
796
783 @make_cffi
797 @make_cffi
784 class TestDecompressor_decompressobj(TestCase):
798 class TestDecompressor_decompressobj(TestCase):
785 def test_simple(self):
799 def test_simple(self):
786 data = zstd.ZstdCompressor(level=1).compress(b"foobar")
800 data = zstd.ZstdCompressor(level=1).compress(b"foobar")
787
801
788 dctx = zstd.ZstdDecompressor()
802 dctx = zstd.ZstdDecompressor()
789 dobj = dctx.decompressobj()
803 dobj = dctx.decompressobj()
790 self.assertEqual(dobj.decompress(data), b"foobar")
804 self.assertEqual(dobj.decompress(data), b"foobar")
791 self.assertIsNone(dobj.flush())
805 self.assertIsNone(dobj.flush())
792 self.assertIsNone(dobj.flush(10))
806 self.assertIsNone(dobj.flush(10))
793 self.assertIsNone(dobj.flush(length=100))
807 self.assertIsNone(dobj.flush(length=100))
794
808
795 def test_input_types(self):
809 def test_input_types(self):
796 compressed = zstd.ZstdCompressor(level=1).compress(b"foo")
810 compressed = zstd.ZstdCompressor(level=1).compress(b"foo")
797
811
798 dctx = zstd.ZstdDecompressor()
812 dctx = zstd.ZstdDecompressor()
799
813
800 mutable_array = bytearray(len(compressed))
814 mutable_array = bytearray(len(compressed))
801 mutable_array[:] = compressed
815 mutable_array[:] = compressed
802
816
803 sources = [
817 sources = [
804 memoryview(compressed),
818 memoryview(compressed),
805 bytearray(compressed),
819 bytearray(compressed),
806 mutable_array,
820 mutable_array,
807 ]
821 ]
808
822
809 for source in sources:
823 for source in sources:
810 dobj = dctx.decompressobj()
824 dobj = dctx.decompressobj()
811 self.assertIsNone(dobj.flush())
825 self.assertIsNone(dobj.flush())
812 self.assertIsNone(dobj.flush(10))
826 self.assertIsNone(dobj.flush(10))
813 self.assertIsNone(dobj.flush(length=100))
827 self.assertIsNone(dobj.flush(length=100))
814 self.assertEqual(dobj.decompress(source), b"foo")
828 self.assertEqual(dobj.decompress(source), b"foo")
815 self.assertIsNone(dobj.flush())
829 self.assertIsNone(dobj.flush())
816
830
817 def test_reuse(self):
831 def test_reuse(self):
818 data = zstd.ZstdCompressor(level=1).compress(b"foobar")
832 data = zstd.ZstdCompressor(level=1).compress(b"foobar")
819
833
820 dctx = zstd.ZstdDecompressor()
834 dctx = zstd.ZstdDecompressor()
821 dobj = dctx.decompressobj()
835 dobj = dctx.decompressobj()
822 dobj.decompress(data)
836 dobj.decompress(data)
823
837
824 with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"):
838 with self.assertRaisesRegex(
839 zstd.ZstdError, "cannot use a decompressobj"
840 ):
825 dobj.decompress(data)
841 dobj.decompress(data)
826 self.assertIsNone(dobj.flush())
842 self.assertIsNone(dobj.flush())
827
843
828 def test_bad_write_size(self):
844 def test_bad_write_size(self):
829 dctx = zstd.ZstdDecompressor()
845 dctx = zstd.ZstdDecompressor()
830
846
831 with self.assertRaisesRegex(ValueError, "write_size must be positive"):
847 with self.assertRaisesRegex(ValueError, "write_size must be positive"):
832 dctx.decompressobj(write_size=0)
848 dctx.decompressobj(write_size=0)
833
849
834 def test_write_size(self):
850 def test_write_size(self):
835 source = b"foo" * 64 + b"bar" * 128
851 source = b"foo" * 64 + b"bar" * 128
836 data = zstd.ZstdCompressor(level=1).compress(source)
852 data = zstd.ZstdCompressor(level=1).compress(source)
837
853
838 dctx = zstd.ZstdDecompressor()
854 dctx = zstd.ZstdDecompressor()
839
855
840 for i in range(128):
856 for i in range(128):
841 dobj = dctx.decompressobj(write_size=i + 1)
857 dobj = dctx.decompressobj(write_size=i + 1)
842 self.assertEqual(dobj.decompress(data), source)
858 self.assertEqual(dobj.decompress(data), source)
843
859
844
860
845 def decompress_via_writer(data):
861 def decompress_via_writer(data):
846 buffer = io.BytesIO()
862 buffer = io.BytesIO()
847 dctx = zstd.ZstdDecompressor()
863 dctx = zstd.ZstdDecompressor()
848 decompressor = dctx.stream_writer(buffer)
864 decompressor = dctx.stream_writer(buffer)
849 decompressor.write(data)
865 decompressor.write(data)
850
866
851 return buffer.getvalue()
867 return buffer.getvalue()
852
868
853
869
854 @make_cffi
870 @make_cffi
855 class TestDecompressor_stream_writer(TestCase):
871 class TestDecompressor_stream_writer(TestCase):
856 def test_io_api(self):
872 def test_io_api(self):
857 buffer = io.BytesIO()
873 buffer = io.BytesIO()
858 dctx = zstd.ZstdDecompressor()
874 dctx = zstd.ZstdDecompressor()
859 writer = dctx.stream_writer(buffer)
875 writer = dctx.stream_writer(buffer)
860
876
861 self.assertFalse(writer.closed)
877 self.assertFalse(writer.closed)
862 self.assertFalse(writer.isatty())
878 self.assertFalse(writer.isatty())
863 self.assertFalse(writer.readable())
879 self.assertFalse(writer.readable())
864
880
865 with self.assertRaises(io.UnsupportedOperation):
881 with self.assertRaises(io.UnsupportedOperation):
866 writer.readline()
882 writer.readline()
867
883
868 with self.assertRaises(io.UnsupportedOperation):
884 with self.assertRaises(io.UnsupportedOperation):
869 writer.readline(42)
885 writer.readline(42)
870
886
871 with self.assertRaises(io.UnsupportedOperation):
887 with self.assertRaises(io.UnsupportedOperation):
872 writer.readline(size=42)
888 writer.readline(size=42)
873
889
874 with self.assertRaises(io.UnsupportedOperation):
890 with self.assertRaises(io.UnsupportedOperation):
875 writer.readlines()
891 writer.readlines()
876
892
877 with self.assertRaises(io.UnsupportedOperation):
893 with self.assertRaises(io.UnsupportedOperation):
878 writer.readlines(42)
894 writer.readlines(42)
879
895
880 with self.assertRaises(io.UnsupportedOperation):
896 with self.assertRaises(io.UnsupportedOperation):
881 writer.readlines(hint=42)
897 writer.readlines(hint=42)
882
898
883 with self.assertRaises(io.UnsupportedOperation):
899 with self.assertRaises(io.UnsupportedOperation):
884 writer.seek(0)
900 writer.seek(0)
885
901
886 with self.assertRaises(io.UnsupportedOperation):
902 with self.assertRaises(io.UnsupportedOperation):
887 writer.seek(10, os.SEEK_SET)
903 writer.seek(10, os.SEEK_SET)
888
904
889 self.assertFalse(writer.seekable())
905 self.assertFalse(writer.seekable())
890
906
891 with self.assertRaises(io.UnsupportedOperation):
907 with self.assertRaises(io.UnsupportedOperation):
892 writer.tell()
908 writer.tell()
893
909
894 with self.assertRaises(io.UnsupportedOperation):
910 with self.assertRaises(io.UnsupportedOperation):
895 writer.truncate()
911 writer.truncate()
896
912
897 with self.assertRaises(io.UnsupportedOperation):
913 with self.assertRaises(io.UnsupportedOperation):
898 writer.truncate(42)
914 writer.truncate(42)
899
915
900 with self.assertRaises(io.UnsupportedOperation):
916 with self.assertRaises(io.UnsupportedOperation):
901 writer.truncate(size=42)
917 writer.truncate(size=42)
902
918
903 self.assertTrue(writer.writable())
919 self.assertTrue(writer.writable())
904
920
905 with self.assertRaises(io.UnsupportedOperation):
921 with self.assertRaises(io.UnsupportedOperation):
906 writer.writelines([])
922 writer.writelines([])
907
923
908 with self.assertRaises(io.UnsupportedOperation):
924 with self.assertRaises(io.UnsupportedOperation):
909 writer.read()
925 writer.read()
910
926
911 with self.assertRaises(io.UnsupportedOperation):
927 with self.assertRaises(io.UnsupportedOperation):
912 writer.read(42)
928 writer.read(42)
913
929
914 with self.assertRaises(io.UnsupportedOperation):
930 with self.assertRaises(io.UnsupportedOperation):
915 writer.read(size=42)
931 writer.read(size=42)
916
932
917 with self.assertRaises(io.UnsupportedOperation):
933 with self.assertRaises(io.UnsupportedOperation):
918 writer.readall()
934 writer.readall()
919
935
920 with self.assertRaises(io.UnsupportedOperation):
936 with self.assertRaises(io.UnsupportedOperation):
921 writer.readinto(None)
937 writer.readinto(None)
922
938
923 with self.assertRaises(io.UnsupportedOperation):
939 with self.assertRaises(io.UnsupportedOperation):
924 writer.fileno()
940 writer.fileno()
925
941
926 def test_fileno_file(self):
942 def test_fileno_file(self):
927 with tempfile.TemporaryFile("wb") as tf:
943 with tempfile.TemporaryFile("wb") as tf:
928 dctx = zstd.ZstdDecompressor()
944 dctx = zstd.ZstdDecompressor()
929 writer = dctx.stream_writer(tf)
945 writer = dctx.stream_writer(tf)
930
946
931 self.assertEqual(writer.fileno(), tf.fileno())
947 self.assertEqual(writer.fileno(), tf.fileno())
932
948
933 def test_close(self):
949 def test_close(self):
934 foo = zstd.ZstdCompressor().compress(b"foo")
950 foo = zstd.ZstdCompressor().compress(b"foo")
935
951
936 buffer = NonClosingBytesIO()
952 buffer = NonClosingBytesIO()
937 dctx = zstd.ZstdDecompressor()
953 dctx = zstd.ZstdDecompressor()
938 writer = dctx.stream_writer(buffer)
954 writer = dctx.stream_writer(buffer)
939
955
940 writer.write(foo)
956 writer.write(foo)
941 self.assertFalse(writer.closed)
957 self.assertFalse(writer.closed)
942 self.assertFalse(buffer.closed)
958 self.assertFalse(buffer.closed)
943 writer.close()
959 writer.close()
944 self.assertTrue(writer.closed)
960 self.assertTrue(writer.closed)
945 self.assertTrue(buffer.closed)
961 self.assertTrue(buffer.closed)
946
962
947 with self.assertRaisesRegex(ValueError, "stream is closed"):
963 with self.assertRaisesRegex(ValueError, "stream is closed"):
948 writer.write(b"")
964 writer.write(b"")
949
965
950 with self.assertRaisesRegex(ValueError, "stream is closed"):
966 with self.assertRaisesRegex(ValueError, "stream is closed"):
951 writer.flush()
967 writer.flush()
952
968
953 with self.assertRaisesRegex(ValueError, "stream is closed"):
969 with self.assertRaisesRegex(ValueError, "stream is closed"):
954 with writer:
970 with writer:
955 pass
971 pass
956
972
957 self.assertEqual(buffer.getvalue(), b"foo")
973 self.assertEqual(buffer.getvalue(), b"foo")
958
974
959 # Context manager exit should close stream.
975 # Context manager exit should close stream.
960 buffer = NonClosingBytesIO()
976 buffer = NonClosingBytesIO()
961 writer = dctx.stream_writer(buffer)
977 writer = dctx.stream_writer(buffer)
962
978
963 with writer:
979 with writer:
964 writer.write(foo)
980 writer.write(foo)
965
981
966 self.assertTrue(writer.closed)
982 self.assertTrue(writer.closed)
967 self.assertEqual(buffer.getvalue(), b"foo")
983 self.assertEqual(buffer.getvalue(), b"foo")
968
984
969 def test_flush(self):
985 def test_flush(self):
970 buffer = OpCountingBytesIO()
986 buffer = OpCountingBytesIO()
971 dctx = zstd.ZstdDecompressor()
987 dctx = zstd.ZstdDecompressor()
972 writer = dctx.stream_writer(buffer)
988 writer = dctx.stream_writer(buffer)
973
989
974 writer.flush()
990 writer.flush()
975 self.assertEqual(buffer._flush_count, 1)
991 self.assertEqual(buffer._flush_count, 1)
976 writer.flush()
992 writer.flush()
977 self.assertEqual(buffer._flush_count, 2)
993 self.assertEqual(buffer._flush_count, 2)
978
994
979 def test_empty_roundtrip(self):
995 def test_empty_roundtrip(self):
980 cctx = zstd.ZstdCompressor()
996 cctx = zstd.ZstdCompressor()
981 empty = cctx.compress(b"")
997 empty = cctx.compress(b"")
982 self.assertEqual(decompress_via_writer(empty), b"")
998 self.assertEqual(decompress_via_writer(empty), b"")
983
999
984 def test_input_types(self):
1000 def test_input_types(self):
985 cctx = zstd.ZstdCompressor(level=1)
1001 cctx = zstd.ZstdCompressor(level=1)
986 compressed = cctx.compress(b"foo")
1002 compressed = cctx.compress(b"foo")
987
1003
988 mutable_array = bytearray(len(compressed))
1004 mutable_array = bytearray(len(compressed))
989 mutable_array[:] = compressed
1005 mutable_array[:] = compressed
990
1006
991 sources = [
1007 sources = [
992 memoryview(compressed),
1008 memoryview(compressed),
993 bytearray(compressed),
1009 bytearray(compressed),
994 mutable_array,
1010 mutable_array,
995 ]
1011 ]
996
1012
997 dctx = zstd.ZstdDecompressor()
1013 dctx = zstd.ZstdDecompressor()
998 for source in sources:
1014 for source in sources:
999 buffer = io.BytesIO()
1015 buffer = io.BytesIO()
1000
1016
1001 decompressor = dctx.stream_writer(buffer)
1017 decompressor = dctx.stream_writer(buffer)
1002 decompressor.write(source)
1018 decompressor.write(source)
1003 self.assertEqual(buffer.getvalue(), b"foo")
1019 self.assertEqual(buffer.getvalue(), b"foo")
1004
1020
1005 buffer = NonClosingBytesIO()
1021 buffer = NonClosingBytesIO()
1006
1022
1007 with dctx.stream_writer(buffer) as decompressor:
1023 with dctx.stream_writer(buffer) as decompressor:
1008 self.assertEqual(decompressor.write(source), 3)
1024 self.assertEqual(decompressor.write(source), 3)
1009
1025
1010 self.assertEqual(buffer.getvalue(), b"foo")
1026 self.assertEqual(buffer.getvalue(), b"foo")
1011
1027
1012 buffer = io.BytesIO()
1028 buffer = io.BytesIO()
1013 writer = dctx.stream_writer(buffer, write_return_read=True)
1029 writer = dctx.stream_writer(buffer, write_return_read=True)
1014 self.assertEqual(writer.write(source), len(source))
1030 self.assertEqual(writer.write(source), len(source))
1015 self.assertEqual(buffer.getvalue(), b"foo")
1031 self.assertEqual(buffer.getvalue(), b"foo")
1016
1032
1017 def test_large_roundtrip(self):
1033 def test_large_roundtrip(self):
1018 chunks = []
1034 chunks = []
1019 for i in range(255):
1035 for i in range(255):
1020 chunks.append(struct.Struct(">B").pack(i) * 16384)
1036 chunks.append(struct.Struct(">B").pack(i) * 16384)
1021 orig = b"".join(chunks)
1037 orig = b"".join(chunks)
1022 cctx = zstd.ZstdCompressor()
1038 cctx = zstd.ZstdCompressor()
1023 compressed = cctx.compress(orig)
1039 compressed = cctx.compress(orig)
1024
1040
1025 self.assertEqual(decompress_via_writer(compressed), orig)
1041 self.assertEqual(decompress_via_writer(compressed), orig)
1026
1042
1027 def test_multiple_calls(self):
1043 def test_multiple_calls(self):
1028 chunks = []
1044 chunks = []
1029 for i in range(255):
1045 for i in range(255):
1030 for j in range(255):
1046 for j in range(255):
1031 chunks.append(struct.Struct(">B").pack(j) * i)
1047 chunks.append(struct.Struct(">B").pack(j) * i)
1032
1048
1033 orig = b"".join(chunks)
1049 orig = b"".join(chunks)
1034 cctx = zstd.ZstdCompressor()
1050 cctx = zstd.ZstdCompressor()
1035 compressed = cctx.compress(orig)
1051 compressed = cctx.compress(orig)
1036
1052
1037 buffer = NonClosingBytesIO()
1053 buffer = NonClosingBytesIO()
1038 dctx = zstd.ZstdDecompressor()
1054 dctx = zstd.ZstdDecompressor()
1039 with dctx.stream_writer(buffer) as decompressor:
1055 with dctx.stream_writer(buffer) as decompressor:
1040 pos = 0
1056 pos = 0
1041 while pos < len(compressed):
1057 while pos < len(compressed):
1042 pos2 = pos + 8192
1058 pos2 = pos + 8192
1043 decompressor.write(compressed[pos:pos2])
1059 decompressor.write(compressed[pos:pos2])
1044 pos += 8192
1060 pos += 8192
1045 self.assertEqual(buffer.getvalue(), orig)
1061 self.assertEqual(buffer.getvalue(), orig)
1046
1062
1047 # Again with write_return_read=True
1063 # Again with write_return_read=True
1048 buffer = io.BytesIO()
1064 buffer = io.BytesIO()
1049 writer = dctx.stream_writer(buffer, write_return_read=True)
1065 writer = dctx.stream_writer(buffer, write_return_read=True)
1050 pos = 0
1066 pos = 0
1051 while pos < len(compressed):
1067 while pos < len(compressed):
1052 pos2 = pos + 8192
1068 pos2 = pos + 8192
1053 chunk = compressed[pos:pos2]
1069 chunk = compressed[pos:pos2]
1054 self.assertEqual(writer.write(chunk), len(chunk))
1070 self.assertEqual(writer.write(chunk), len(chunk))
1055 pos += 8192
1071 pos += 8192
1056 self.assertEqual(buffer.getvalue(), orig)
1072 self.assertEqual(buffer.getvalue(), orig)
1057
1073
1058 def test_dictionary(self):
1074 def test_dictionary(self):
1059 samples = []
1075 samples = []
1060 for i in range(128):
1076 for i in range(128):
1061 samples.append(b"foo" * 64)
1077 samples.append(b"foo" * 64)
1062 samples.append(b"bar" * 64)
1078 samples.append(b"bar" * 64)
1063 samples.append(b"foobar" * 64)
1079 samples.append(b"foobar" * 64)
1064
1080
1065 d = zstd.train_dictionary(8192, samples)
1081 d = zstd.train_dictionary(8192, samples)
1066
1082
1067 orig = b"foobar" * 16384
1083 orig = b"foobar" * 16384
1068 buffer = NonClosingBytesIO()
1084 buffer = NonClosingBytesIO()
1069 cctx = zstd.ZstdCompressor(dict_data=d)
1085 cctx = zstd.ZstdCompressor(dict_data=d)
1070 with cctx.stream_writer(buffer) as compressor:
1086 with cctx.stream_writer(buffer) as compressor:
1071 self.assertEqual(compressor.write(orig), 0)
1087 self.assertEqual(compressor.write(orig), 0)
1072
1088
1073 compressed = buffer.getvalue()
1089 compressed = buffer.getvalue()
1074 buffer = io.BytesIO()
1090 buffer = io.BytesIO()
1075
1091
1076 dctx = zstd.ZstdDecompressor(dict_data=d)
1092 dctx = zstd.ZstdDecompressor(dict_data=d)
1077 decompressor = dctx.stream_writer(buffer)
1093 decompressor = dctx.stream_writer(buffer)
1078 self.assertEqual(decompressor.write(compressed), len(orig))
1094 self.assertEqual(decompressor.write(compressed), len(orig))
1079 self.assertEqual(buffer.getvalue(), orig)
1095 self.assertEqual(buffer.getvalue(), orig)
1080
1096
1081 buffer = NonClosingBytesIO()
1097 buffer = NonClosingBytesIO()
1082
1098
1083 with dctx.stream_writer(buffer) as decompressor:
1099 with dctx.stream_writer(buffer) as decompressor:
1084 self.assertEqual(decompressor.write(compressed), len(orig))
1100 self.assertEqual(decompressor.write(compressed), len(orig))
1085
1101
1086 self.assertEqual(buffer.getvalue(), orig)
1102 self.assertEqual(buffer.getvalue(), orig)
1087
1103
1088 def test_memory_size(self):
1104 def test_memory_size(self):
1089 dctx = zstd.ZstdDecompressor()
1105 dctx = zstd.ZstdDecompressor()
1090 buffer = io.BytesIO()
1106 buffer = io.BytesIO()
1091
1107
1092 decompressor = dctx.stream_writer(buffer)
1108 decompressor = dctx.stream_writer(buffer)
1093 size = decompressor.memory_size()
1109 size = decompressor.memory_size()
1094 self.assertGreater(size, 100000)
1110 self.assertGreater(size, 100000)
1095
1111
1096 with dctx.stream_writer(buffer) as decompressor:
1112 with dctx.stream_writer(buffer) as decompressor:
1097 size = decompressor.memory_size()
1113 size = decompressor.memory_size()
1098
1114
1099 self.assertGreater(size, 100000)
1115 self.assertGreater(size, 100000)
1100
1116
1101 def test_write_size(self):
1117 def test_write_size(self):
1102 source = zstd.ZstdCompressor().compress(b"foobarfoobar")
1118 source = zstd.ZstdCompressor().compress(b"foobarfoobar")
1103 dest = OpCountingBytesIO()
1119 dest = OpCountingBytesIO()
1104 dctx = zstd.ZstdDecompressor()
1120 dctx = zstd.ZstdDecompressor()
1105 with dctx.stream_writer(dest, write_size=1) as decompressor:
1121 with dctx.stream_writer(dest, write_size=1) as decompressor:
1106 s = struct.Struct(">B")
1122 s = struct.Struct(">B")
1107 for c in source:
1123 for c in source:
1108 if not isinstance(c, str):
1124 if not isinstance(c, str):
1109 c = s.pack(c)
1125 c = s.pack(c)
1110 decompressor.write(c)
1126 decompressor.write(c)
1111
1127
1112 self.assertEqual(dest.getvalue(), b"foobarfoobar")
1128 self.assertEqual(dest.getvalue(), b"foobarfoobar")
1113 self.assertEqual(dest._write_count, len(dest.getvalue()))
1129 self.assertEqual(dest._write_count, len(dest.getvalue()))
1114
1130
1115
1131
1116 @make_cffi
1132 @make_cffi
1117 class TestDecompressor_read_to_iter(TestCase):
1133 class TestDecompressor_read_to_iter(TestCase):
1118 def test_type_validation(self):
1134 def test_type_validation(self):
1119 dctx = zstd.ZstdDecompressor()
1135 dctx = zstd.ZstdDecompressor()
1120
1136
1121 # Object with read() works.
1137 # Object with read() works.
1122 dctx.read_to_iter(io.BytesIO())
1138 dctx.read_to_iter(io.BytesIO())
1123
1139
1124 # Buffer protocol works.
1140 # Buffer protocol works.
1125 dctx.read_to_iter(b"foobar")
1141 dctx.read_to_iter(b"foobar")
1126
1142
1127 with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
1143 with self.assertRaisesRegex(
1144 ValueError, "must pass an object with a read"
1145 ):
1128 b"".join(dctx.read_to_iter(True))
1146 b"".join(dctx.read_to_iter(True))
1129
1147
1130 def test_empty_input(self):
1148 def test_empty_input(self):
1131 dctx = zstd.ZstdDecompressor()
1149 dctx = zstd.ZstdDecompressor()
1132
1150
1133 source = io.BytesIO()
1151 source = io.BytesIO()
1134 it = dctx.read_to_iter(source)
1152 it = dctx.read_to_iter(source)
1135 # TODO this is arguably wrong. Should get an error about missing frame foo.
1153 # TODO this is arguably wrong. Should get an error about missing frame foo.
1136 with self.assertRaises(StopIteration):
1154 with self.assertRaises(StopIteration):
1137 next(it)
1155 next(it)
1138
1156
1139 it = dctx.read_to_iter(b"")
1157 it = dctx.read_to_iter(b"")
1140 with self.assertRaises(StopIteration):
1158 with self.assertRaises(StopIteration):
1141 next(it)
1159 next(it)
1142
1160
1143 def test_invalid_input(self):
1161 def test_invalid_input(self):
1144 dctx = zstd.ZstdDecompressor()
1162 dctx = zstd.ZstdDecompressor()
1145
1163
1146 source = io.BytesIO(b"foobar")
1164 source = io.BytesIO(b"foobar")
1147 it = dctx.read_to_iter(source)
1165 it = dctx.read_to_iter(source)
1148 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
1166 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
1149 next(it)
1167 next(it)
1150
1168
1151 it = dctx.read_to_iter(b"foobar")
1169 it = dctx.read_to_iter(b"foobar")
1152 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
1170 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
1153 next(it)
1171 next(it)
1154
1172
1155 def test_empty_roundtrip(self):
1173 def test_empty_roundtrip(self):
1156 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1174 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1157 empty = cctx.compress(b"")
1175 empty = cctx.compress(b"")
1158
1176
1159 source = io.BytesIO(empty)
1177 source = io.BytesIO(empty)
1160 source.seek(0)
1178 source.seek(0)
1161
1179
1162 dctx = zstd.ZstdDecompressor()
1180 dctx = zstd.ZstdDecompressor()
1163 it = dctx.read_to_iter(source)
1181 it = dctx.read_to_iter(source)
1164
1182
1165 # No chunks should be emitted since there is no data.
1183 # No chunks should be emitted since there is no data.
1166 with self.assertRaises(StopIteration):
1184 with self.assertRaises(StopIteration):
1167 next(it)
1185 next(it)
1168
1186
1169 # Again for good measure.
1187 # Again for good measure.
1170 with self.assertRaises(StopIteration):
1188 with self.assertRaises(StopIteration):
1171 next(it)
1189 next(it)
1172
1190
1173 def test_skip_bytes_too_large(self):
1191 def test_skip_bytes_too_large(self):
1174 dctx = zstd.ZstdDecompressor()
1192 dctx = zstd.ZstdDecompressor()
1175
1193
1176 with self.assertRaisesRegex(
1194 with self.assertRaisesRegex(
1177 ValueError, "skip_bytes must be smaller than read_size"
1195 ValueError, "skip_bytes must be smaller than read_size"
1178 ):
1196 ):
1179 b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1))
1197 b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1))
1180
1198
1181 with self.assertRaisesRegex(
1199 with self.assertRaisesRegex(
1182 ValueError, "skip_bytes larger than first input chunk"
1200 ValueError, "skip_bytes larger than first input chunk"
1183 ):
1201 ):
1184 b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10))
1202 b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10))
1185
1203
1186 def test_skip_bytes(self):
1204 def test_skip_bytes(self):
1187 cctx = zstd.ZstdCompressor(write_content_size=False)
1205 cctx = zstd.ZstdCompressor(write_content_size=False)
1188 compressed = cctx.compress(b"foobar")
1206 compressed = cctx.compress(b"foobar")
1189
1207
1190 dctx = zstd.ZstdDecompressor()
1208 dctx = zstd.ZstdDecompressor()
1191 output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3))
1209 output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3))
1192 self.assertEqual(output, b"foobar")
1210 self.assertEqual(output, b"foobar")
1193
1211
1194 def test_large_output(self):
1212 def test_large_output(self):
1195 source = io.BytesIO()
1213 source = io.BytesIO()
1196 source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
1214 source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
1197 source.write(b"o")
1215 source.write(b"o")
1198 source.seek(0)
1216 source.seek(0)
1199
1217
1200 cctx = zstd.ZstdCompressor(level=1)
1218 cctx = zstd.ZstdCompressor(level=1)
1201 compressed = io.BytesIO(cctx.compress(source.getvalue()))
1219 compressed = io.BytesIO(cctx.compress(source.getvalue()))
1202 compressed.seek(0)
1220 compressed.seek(0)
1203
1221
1204 dctx = zstd.ZstdDecompressor()
1222 dctx = zstd.ZstdDecompressor()
1205 it = dctx.read_to_iter(compressed)
1223 it = dctx.read_to_iter(compressed)
1206
1224
1207 chunks = []
1225 chunks = []
1208 chunks.append(next(it))
1226 chunks.append(next(it))
1209 chunks.append(next(it))
1227 chunks.append(next(it))
1210
1228
1211 with self.assertRaises(StopIteration):
1229 with self.assertRaises(StopIteration):
1212 next(it)
1230 next(it)
1213
1231
1214 decompressed = b"".join(chunks)
1232 decompressed = b"".join(chunks)
1215 self.assertEqual(decompressed, source.getvalue())
1233 self.assertEqual(decompressed, source.getvalue())
1216
1234
1217 # And again with buffer protocol.
1235 # And again with buffer protocol.
1218 it = dctx.read_to_iter(compressed.getvalue())
1236 it = dctx.read_to_iter(compressed.getvalue())
1219 chunks = []
1237 chunks = []
1220 chunks.append(next(it))
1238 chunks.append(next(it))
1221 chunks.append(next(it))
1239 chunks.append(next(it))
1222
1240
1223 with self.assertRaises(StopIteration):
1241 with self.assertRaises(StopIteration):
1224 next(it)
1242 next(it)
1225
1243
1226 decompressed = b"".join(chunks)
1244 decompressed = b"".join(chunks)
1227 self.assertEqual(decompressed, source.getvalue())
1245 self.assertEqual(decompressed, source.getvalue())
1228
1246
1229 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
1247 @unittest.skipUnless(
1248 "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set"
1249 )
1230 def test_large_input(self):
1250 def test_large_input(self):
1231 bytes = list(struct.Struct(">B").pack(i) for i in range(256))
1251 bytes = list(struct.Struct(">B").pack(i) for i in range(256))
1232 compressed = NonClosingBytesIO()
1252 compressed = NonClosingBytesIO()
1233 input_size = 0
1253 input_size = 0
1234 cctx = zstd.ZstdCompressor(level=1)
1254 cctx = zstd.ZstdCompressor(level=1)
1235 with cctx.stream_writer(compressed) as compressor:
1255 with cctx.stream_writer(compressed) as compressor:
1236 while True:
1256 while True:
1237 compressor.write(random.choice(bytes))
1257 compressor.write(random.choice(bytes))
1238 input_size += 1
1258 input_size += 1
1239
1259
1240 have_compressed = (
1260 have_compressed = (
1241 len(compressed.getvalue())
1261 len(compressed.getvalue())
1242 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
1262 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
1243 )
1263 )
1244 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
1264 have_raw = (
1265 input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
1266 )
1245 if have_compressed and have_raw:
1267 if have_compressed and have_raw:
1246 break
1268 break
1247
1269
1248 compressed = io.BytesIO(compressed.getvalue())
1270 compressed = io.BytesIO(compressed.getvalue())
1249 self.assertGreater(
1271 self.assertGreater(
1250 len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
1272 len(compressed.getvalue()),
1273 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
1251 )
1274 )
1252
1275
1253 dctx = zstd.ZstdDecompressor()
1276 dctx = zstd.ZstdDecompressor()
1254 it = dctx.read_to_iter(compressed)
1277 it = dctx.read_to_iter(compressed)
1255
1278
1256 chunks = []
1279 chunks = []
1257 chunks.append(next(it))
1280 chunks.append(next(it))
1258 chunks.append(next(it))
1281 chunks.append(next(it))
1259 chunks.append(next(it))
1282 chunks.append(next(it))
1260
1283
1261 with self.assertRaises(StopIteration):
1284 with self.assertRaises(StopIteration):
1262 next(it)
1285 next(it)
1263
1286
1264 decompressed = b"".join(chunks)
1287 decompressed = b"".join(chunks)
1265 self.assertEqual(len(decompressed), input_size)
1288 self.assertEqual(len(decompressed), input_size)
1266
1289
1267 # And again with buffer protocol.
1290 # And again with buffer protocol.
1268 it = dctx.read_to_iter(compressed.getvalue())
1291 it = dctx.read_to_iter(compressed.getvalue())
1269
1292
1270 chunks = []
1293 chunks = []
1271 chunks.append(next(it))
1294 chunks.append(next(it))
1272 chunks.append(next(it))
1295 chunks.append(next(it))
1273 chunks.append(next(it))
1296 chunks.append(next(it))
1274
1297
1275 with self.assertRaises(StopIteration):
1298 with self.assertRaises(StopIteration):
1276 next(it)
1299 next(it)
1277
1300
1278 decompressed = b"".join(chunks)
1301 decompressed = b"".join(chunks)
1279 self.assertEqual(len(decompressed), input_size)
1302 self.assertEqual(len(decompressed), input_size)
1280
1303
1281 def test_interesting(self):
1304 def test_interesting(self):
1282 # Found this edge case via fuzzing.
1305 # Found this edge case via fuzzing.
1283 cctx = zstd.ZstdCompressor(level=1)
1306 cctx = zstd.ZstdCompressor(level=1)
1284
1307
1285 source = io.BytesIO()
1308 source = io.BytesIO()
1286
1309
1287 compressed = NonClosingBytesIO()
1310 compressed = NonClosingBytesIO()
1288 with cctx.stream_writer(compressed) as compressor:
1311 with cctx.stream_writer(compressed) as compressor:
1289 for i in range(256):
1312 for i in range(256):
1290 chunk = b"\0" * 1024
1313 chunk = b"\0" * 1024
1291 compressor.write(chunk)
1314 compressor.write(chunk)
1292 source.write(chunk)
1315 source.write(chunk)
1293
1316
1294 dctx = zstd.ZstdDecompressor()
1317 dctx = zstd.ZstdDecompressor()
1295
1318
1296 simple = dctx.decompress(
1319 simple = dctx.decompress(
1297 compressed.getvalue(), max_output_size=len(source.getvalue())
1320 compressed.getvalue(), max_output_size=len(source.getvalue())
1298 )
1321 )
1299 self.assertEqual(simple, source.getvalue())
1322 self.assertEqual(simple, source.getvalue())
1300
1323
1301 compressed = io.BytesIO(compressed.getvalue())
1324 compressed = io.BytesIO(compressed.getvalue())
1302 streamed = b"".join(dctx.read_to_iter(compressed))
1325 streamed = b"".join(dctx.read_to_iter(compressed))
1303 self.assertEqual(streamed, source.getvalue())
1326 self.assertEqual(streamed, source.getvalue())
1304
1327
1305 def test_read_write_size(self):
1328 def test_read_write_size(self):
1306 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
1329 source = OpCountingBytesIO(
1330 zstd.ZstdCompressor().compress(b"foobarfoobar")
1331 )
1307 dctx = zstd.ZstdDecompressor()
1332 dctx = zstd.ZstdDecompressor()
1308 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
1333 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
1309 self.assertEqual(len(chunk), 1)
1334 self.assertEqual(len(chunk), 1)
1310
1335
1311 self.assertEqual(source._read_count, len(source.getvalue()))
1336 self.assertEqual(source._read_count, len(source.getvalue()))
1312
1337
1313 def test_magic_less(self):
1338 def test_magic_less(self):
1314 params = zstd.CompressionParameters.from_level(
1339 params = zstd.CompressionParameters.from_level(
1315 1, format=zstd.FORMAT_ZSTD1_MAGICLESS
1340 1, format=zstd.FORMAT_ZSTD1_MAGICLESS
1316 )
1341 )
1317 cctx = zstd.ZstdCompressor(compression_params=params)
1342 cctx = zstd.ZstdCompressor(compression_params=params)
1318 frame = cctx.compress(b"foobar")
1343 frame = cctx.compress(b"foobar")
1319
1344
1320 self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd")
1345 self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd")
1321
1346
1322 dctx = zstd.ZstdDecompressor()
1347 dctx = zstd.ZstdDecompressor()
1323 with self.assertRaisesRegex(
1348 with self.assertRaisesRegex(
1324 zstd.ZstdError, "error determining content size from frame header"
1349 zstd.ZstdError, "error determining content size from frame header"
1325 ):
1350 ):
1326 dctx.decompress(frame)
1351 dctx.decompress(frame)
1327
1352
1328 dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS)
1353 dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS)
1329 res = b"".join(dctx.read_to_iter(frame))
1354 res = b"".join(dctx.read_to_iter(frame))
1330 self.assertEqual(res, b"foobar")
1355 self.assertEqual(res, b"foobar")
1331
1356
1332
1357
1333 @make_cffi
1358 @make_cffi
1334 class TestDecompressor_content_dict_chain(TestCase):
1359 class TestDecompressor_content_dict_chain(TestCase):
1335 def test_bad_inputs_simple(self):
1360 def test_bad_inputs_simple(self):
1336 dctx = zstd.ZstdDecompressor()
1361 dctx = zstd.ZstdDecompressor()
1337
1362
1338 with self.assertRaises(TypeError):
1363 with self.assertRaises(TypeError):
1339 dctx.decompress_content_dict_chain(b"foo")
1364 dctx.decompress_content_dict_chain(b"foo")
1340
1365
1341 with self.assertRaises(TypeError):
1366 with self.assertRaises(TypeError):
1342 dctx.decompress_content_dict_chain((b"foo", b"bar"))
1367 dctx.decompress_content_dict_chain((b"foo", b"bar"))
1343
1368
1344 with self.assertRaisesRegex(ValueError, "empty input chain"):
1369 with self.assertRaisesRegex(ValueError, "empty input chain"):
1345 dctx.decompress_content_dict_chain([])
1370 dctx.decompress_content_dict_chain([])
1346
1371
1347 with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"):
1372 with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"):
1348 dctx.decompress_content_dict_chain([u"foo"])
1373 dctx.decompress_content_dict_chain([u"foo"])
1349
1374
1350 with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"):
1375 with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"):
1351 dctx.decompress_content_dict_chain([True])
1376 dctx.decompress_content_dict_chain([True])
1352
1377
1353 with self.assertRaisesRegex(
1378 with self.assertRaisesRegex(
1354 ValueError, "chunk 0 is too small to contain a zstd frame"
1379 ValueError, "chunk 0 is too small to contain a zstd frame"
1355 ):
1380 ):
1356 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
1381 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
1357
1382
1358 with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"):
1383 with self.assertRaisesRegex(
1384 ValueError, "chunk 0 is not a valid zstd frame"
1385 ):
1359 dctx.decompress_content_dict_chain([b"foo" * 8])
1386 dctx.decompress_content_dict_chain([b"foo" * 8])
1360
1387
1361 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
1388 no_size = zstd.ZstdCompressor(write_content_size=False).compress(
1389 b"foo" * 64
1390 )
1362
1391
1363 with self.assertRaisesRegex(
1392 with self.assertRaisesRegex(
1364 ValueError, "chunk 0 missing content size in frame"
1393 ValueError, "chunk 0 missing content size in frame"
1365 ):
1394 ):
1366 dctx.decompress_content_dict_chain([no_size])
1395 dctx.decompress_content_dict_chain([no_size])
1367
1396
1368 # Corrupt first frame.
1397 # Corrupt first frame.
1369 frame = zstd.ZstdCompressor().compress(b"foo" * 64)
1398 frame = zstd.ZstdCompressor().compress(b"foo" * 64)
1370 frame = frame[0:12] + frame[15:]
1399 frame = frame[0:12] + frame[15:]
1371 with self.assertRaisesRegex(
1400 with self.assertRaisesRegex(
1372 zstd.ZstdError, "chunk 0 did not decompress full frame"
1401 zstd.ZstdError, "chunk 0 did not decompress full frame"
1373 ):
1402 ):
1374 dctx.decompress_content_dict_chain([frame])
1403 dctx.decompress_content_dict_chain([frame])
1375
1404
1376 def test_bad_subsequent_input(self):
1405 def test_bad_subsequent_input(self):
1377 initial = zstd.ZstdCompressor().compress(b"foo" * 64)
1406 initial = zstd.ZstdCompressor().compress(b"foo" * 64)
1378
1407
1379 dctx = zstd.ZstdDecompressor()
1408 dctx = zstd.ZstdDecompressor()
1380
1409
1381 with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"):
1410 with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"):
1382 dctx.decompress_content_dict_chain([initial, u"foo"])
1411 dctx.decompress_content_dict_chain([initial, u"foo"])
1383
1412
1384 with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"):
1413 with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"):
1385 dctx.decompress_content_dict_chain([initial, None])
1414 dctx.decompress_content_dict_chain([initial, None])
1386
1415
1387 with self.assertRaisesRegex(
1416 with self.assertRaisesRegex(
1388 ValueError, "chunk 1 is too small to contain a zstd frame"
1417 ValueError, "chunk 1 is too small to contain a zstd frame"
1389 ):
1418 ):
1390 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
1419 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
1391
1420
1392 with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"):
1421 with self.assertRaisesRegex(
1422 ValueError, "chunk 1 is not a valid zstd frame"
1423 ):
1393 dctx.decompress_content_dict_chain([initial, b"foo" * 8])
1424 dctx.decompress_content_dict_chain([initial, b"foo" * 8])
1394
1425
1395 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
1426 no_size = zstd.ZstdCompressor(write_content_size=False).compress(
1427 b"foo" * 64
1428 )
1396
1429
1397 with self.assertRaisesRegex(
1430 with self.assertRaisesRegex(
1398 ValueError, "chunk 1 missing content size in frame"
1431 ValueError, "chunk 1 missing content size in frame"
1399 ):
1432 ):
1400 dctx.decompress_content_dict_chain([initial, no_size])
1433 dctx.decompress_content_dict_chain([initial, no_size])
1401
1434
1402 # Corrupt second frame.
1435 # Corrupt second frame.
1403 cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64))
1436 cctx = zstd.ZstdCompressor(
1437 dict_data=zstd.ZstdCompressionDict(b"foo" * 64)
1438 )
1404 frame = cctx.compress(b"bar" * 64)
1439 frame = cctx.compress(b"bar" * 64)
1405 frame = frame[0:12] + frame[15:]
1440 frame = frame[0:12] + frame[15:]
1406
1441
1407 with self.assertRaisesRegex(
1442 with self.assertRaisesRegex(
1408 zstd.ZstdError, "chunk 1 did not decompress full frame"
1443 zstd.ZstdError, "chunk 1 did not decompress full frame"
1409 ):
1444 ):
1410 dctx.decompress_content_dict_chain([initial, frame])
1445 dctx.decompress_content_dict_chain([initial, frame])
1411
1446
1412 def test_simple(self):
1447 def test_simple(self):
1413 original = [
1448 original = [
1414 b"foo" * 64,
1449 b"foo" * 64,
1415 b"foobar" * 64,
1450 b"foobar" * 64,
1416 b"baz" * 64,
1451 b"baz" * 64,
1417 b"foobaz" * 64,
1452 b"foobaz" * 64,
1418 b"foobarbaz" * 64,
1453 b"foobarbaz" * 64,
1419 ]
1454 ]
1420
1455
1421 chunks = []
1456 chunks = []
1422 chunks.append(zstd.ZstdCompressor().compress(original[0]))
1457 chunks.append(zstd.ZstdCompressor().compress(original[0]))
1423 for i, chunk in enumerate(original[1:]):
1458 for i, chunk in enumerate(original[1:]):
1424 d = zstd.ZstdCompressionDict(original[i])
1459 d = zstd.ZstdCompressionDict(original[i])
1425 cctx = zstd.ZstdCompressor(dict_data=d)
1460 cctx = zstd.ZstdCompressor(dict_data=d)
1426 chunks.append(cctx.compress(chunk))
1461 chunks.append(cctx.compress(chunk))
1427
1462
1428 for i in range(1, len(original)):
1463 for i in range(1, len(original)):
1429 chain = chunks[0:i]
1464 chain = chunks[0:i]
1430 expected = original[i - 1]
1465 expected = original[i - 1]
1431 dctx = zstd.ZstdDecompressor()
1466 dctx = zstd.ZstdDecompressor()
1432 decompressed = dctx.decompress_content_dict_chain(chain)
1467 decompressed = dctx.decompress_content_dict_chain(chain)
1433 self.assertEqual(decompressed, expected)
1468 self.assertEqual(decompressed, expected)
1434
1469
1435
1470
1436 # TODO enable for CFFI
1471 # TODO enable for CFFI
1437 class TestDecompressor_multi_decompress_to_buffer(TestCase):
1472 class TestDecompressor_multi_decompress_to_buffer(TestCase):
1438 def test_invalid_inputs(self):
1473 def test_invalid_inputs(self):
1439 dctx = zstd.ZstdDecompressor()
1474 dctx = zstd.ZstdDecompressor()
1440
1475
1441 if not hasattr(dctx, "multi_decompress_to_buffer"):
1476 if not hasattr(dctx, "multi_decompress_to_buffer"):
1442 self.skipTest("multi_decompress_to_buffer not available")
1477 self.skipTest("multi_decompress_to_buffer not available")
1443
1478
1444 with self.assertRaises(TypeError):
1479 with self.assertRaises(TypeError):
1445 dctx.multi_decompress_to_buffer(True)
1480 dctx.multi_decompress_to_buffer(True)
1446
1481
1447 with self.assertRaises(TypeError):
1482 with self.assertRaises(TypeError):
1448 dctx.multi_decompress_to_buffer((1, 2))
1483 dctx.multi_decompress_to_buffer((1, 2))
1449
1484
1450 with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
1485 with self.assertRaisesRegex(
1486 TypeError, "item 0 not a bytes like object"
1487 ):
1451 dctx.multi_decompress_to_buffer([u"foo"])
1488 dctx.multi_decompress_to_buffer([u"foo"])
1452
1489
1453 with self.assertRaisesRegex(
1490 with self.assertRaisesRegex(
1454 ValueError, "could not determine decompressed size of item 0"
1491 ValueError, "could not determine decompressed size of item 0"
1455 ):
1492 ):
1456 dctx.multi_decompress_to_buffer([b"foobarbaz"])
1493 dctx.multi_decompress_to_buffer([b"foobarbaz"])
1457
1494
1458 def test_list_input(self):
1495 def test_list_input(self):
1459 cctx = zstd.ZstdCompressor()
1496 cctx = zstd.ZstdCompressor()
1460
1497
1461 original = [b"foo" * 4, b"bar" * 6]
1498 original = [b"foo" * 4, b"bar" * 6]
1462 frames = [cctx.compress(d) for d in original]
1499 frames = [cctx.compress(d) for d in original]
1463
1500
1464 dctx = zstd.ZstdDecompressor()
1501 dctx = zstd.ZstdDecompressor()
1465
1502
1466 if not hasattr(dctx, "multi_decompress_to_buffer"):
1503 if not hasattr(dctx, "multi_decompress_to_buffer"):
1467 self.skipTest("multi_decompress_to_buffer not available")
1504 self.skipTest("multi_decompress_to_buffer not available")
1468
1505
1469 result = dctx.multi_decompress_to_buffer(frames)
1506 result = dctx.multi_decompress_to_buffer(frames)
1470
1507
1471 self.assertEqual(len(result), len(frames))
1508 self.assertEqual(len(result), len(frames))
1472 self.assertEqual(result.size(), sum(map(len, original)))
1509 self.assertEqual(result.size(), sum(map(len, original)))
1473
1510
1474 for i, data in enumerate(original):
1511 for i, data in enumerate(original):
1475 self.assertEqual(result[i].tobytes(), data)
1512 self.assertEqual(result[i].tobytes(), data)
1476
1513
1477 self.assertEqual(result[0].offset, 0)
1514 self.assertEqual(result[0].offset, 0)
1478 self.assertEqual(len(result[0]), 12)
1515 self.assertEqual(len(result[0]), 12)
1479 self.assertEqual(result[1].offset, 12)
1516 self.assertEqual(result[1].offset, 12)
1480 self.assertEqual(len(result[1]), 18)
1517 self.assertEqual(len(result[1]), 18)
1481
1518
1482 def test_list_input_frame_sizes(self):
1519 def test_list_input_frame_sizes(self):
1483 cctx = zstd.ZstdCompressor()
1520 cctx = zstd.ZstdCompressor()
1484
1521
1485 original = [b"foo" * 4, b"bar" * 6, b"baz" * 8]
1522 original = [b"foo" * 4, b"bar" * 6, b"baz" * 8]
1486 frames = [cctx.compress(d) for d in original]
1523 frames = [cctx.compress(d) for d in original]
1487 sizes = struct.pack("=" + "Q" * len(original), *map(len, original))
1524 sizes = struct.pack("=" + "Q" * len(original), *map(len, original))
1488
1525
1489 dctx = zstd.ZstdDecompressor()
1526 dctx = zstd.ZstdDecompressor()
1490
1527
1491 if not hasattr(dctx, "multi_decompress_to_buffer"):
1528 if not hasattr(dctx, "multi_decompress_to_buffer"):
1492 self.skipTest("multi_decompress_to_buffer not available")
1529 self.skipTest("multi_decompress_to_buffer not available")
1493
1530
1494 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
1531 result = dctx.multi_decompress_to_buffer(
1532 frames, decompressed_sizes=sizes
1533 )
1495
1534
1496 self.assertEqual(len(result), len(frames))
1535 self.assertEqual(len(result), len(frames))
1497 self.assertEqual(result.size(), sum(map(len, original)))
1536 self.assertEqual(result.size(), sum(map(len, original)))
1498
1537
1499 for i, data in enumerate(original):
1538 for i, data in enumerate(original):
1500 self.assertEqual(result[i].tobytes(), data)
1539 self.assertEqual(result[i].tobytes(), data)
1501
1540
1502 def test_buffer_with_segments_input(self):
1541 def test_buffer_with_segments_input(self):
1503 cctx = zstd.ZstdCompressor()
1542 cctx = zstd.ZstdCompressor()
1504
1543
1505 original = [b"foo" * 4, b"bar" * 6]
1544 original = [b"foo" * 4, b"bar" * 6]
1506 frames = [cctx.compress(d) for d in original]
1545 frames = [cctx.compress(d) for d in original]
1507
1546
1508 dctx = zstd.ZstdDecompressor()
1547 dctx = zstd.ZstdDecompressor()
1509
1548
1510 if not hasattr(dctx, "multi_decompress_to_buffer"):
1549 if not hasattr(dctx, "multi_decompress_to_buffer"):
1511 self.skipTest("multi_decompress_to_buffer not available")
1550 self.skipTest("multi_decompress_to_buffer not available")
1512
1551
1513 segments = struct.pack(
1552 segments = struct.pack(
1514 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
1553 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
1515 )
1554 )
1516 b = zstd.BufferWithSegments(b"".join(frames), segments)
1555 b = zstd.BufferWithSegments(b"".join(frames), segments)
1517
1556
1518 result = dctx.multi_decompress_to_buffer(b)
1557 result = dctx.multi_decompress_to_buffer(b)
1519
1558
1520 self.assertEqual(len(result), len(frames))
1559 self.assertEqual(len(result), len(frames))
1521 self.assertEqual(result[0].offset, 0)
1560 self.assertEqual(result[0].offset, 0)
1522 self.assertEqual(len(result[0]), 12)
1561 self.assertEqual(len(result[0]), 12)
1523 self.assertEqual(result[1].offset, 12)
1562 self.assertEqual(result[1].offset, 12)
1524 self.assertEqual(len(result[1]), 18)
1563 self.assertEqual(len(result[1]), 18)
1525
1564
1526 def test_buffer_with_segments_sizes(self):
1565 def test_buffer_with_segments_sizes(self):
1527 cctx = zstd.ZstdCompressor(write_content_size=False)
1566 cctx = zstd.ZstdCompressor(write_content_size=False)
1528 original = [b"foo" * 4, b"bar" * 6, b"baz" * 8]
1567 original = [b"foo" * 4, b"bar" * 6, b"baz" * 8]
1529 frames = [cctx.compress(d) for d in original]
1568 frames = [cctx.compress(d) for d in original]
1530 sizes = struct.pack("=" + "Q" * len(original), *map(len, original))
1569 sizes = struct.pack("=" + "Q" * len(original), *map(len, original))
1531
1570
1532 dctx = zstd.ZstdDecompressor()
1571 dctx = zstd.ZstdDecompressor()
1533
1572
1534 if not hasattr(dctx, "multi_decompress_to_buffer"):
1573 if not hasattr(dctx, "multi_decompress_to_buffer"):
1535 self.skipTest("multi_decompress_to_buffer not available")
1574 self.skipTest("multi_decompress_to_buffer not available")
1536
1575
1537 segments = struct.pack(
1576 segments = struct.pack(
1538 "=QQQQQQ",
1577 "=QQQQQQ",
1539 0,
1578 0,
1540 len(frames[0]),
1579 len(frames[0]),
1541 len(frames[0]),
1580 len(frames[0]),
1542 len(frames[1]),
1581 len(frames[1]),
1543 len(frames[0]) + len(frames[1]),
1582 len(frames[0]) + len(frames[1]),
1544 len(frames[2]),
1583 len(frames[2]),
1545 )
1584 )
1546 b = zstd.BufferWithSegments(b"".join(frames), segments)
1585 b = zstd.BufferWithSegments(b"".join(frames), segments)
1547
1586
1548 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
1587 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
1549
1588
1550 self.assertEqual(len(result), len(frames))
1589 self.assertEqual(len(result), len(frames))
1551 self.assertEqual(result.size(), sum(map(len, original)))
1590 self.assertEqual(result.size(), sum(map(len, original)))
1552
1591
1553 for i, data in enumerate(original):
1592 for i, data in enumerate(original):
1554 self.assertEqual(result[i].tobytes(), data)
1593 self.assertEqual(result[i].tobytes(), data)
1555
1594
1556 def test_buffer_with_segments_collection_input(self):
1595 def test_buffer_with_segments_collection_input(self):
1557 cctx = zstd.ZstdCompressor()
1596 cctx = zstd.ZstdCompressor()
1558
1597
1559 original = [
1598 original = [
1560 b"foo0" * 2,
1599 b"foo0" * 2,
1561 b"foo1" * 3,
1600 b"foo1" * 3,
1562 b"foo2" * 4,
1601 b"foo2" * 4,
1563 b"foo3" * 5,
1602 b"foo3" * 5,
1564 b"foo4" * 6,
1603 b"foo4" * 6,
1565 ]
1604 ]
1566
1605
1567 if not hasattr(cctx, "multi_compress_to_buffer"):
1606 if not hasattr(cctx, "multi_compress_to_buffer"):
1568 self.skipTest("multi_compress_to_buffer not available")
1607 self.skipTest("multi_compress_to_buffer not available")
1569
1608
1570 frames = cctx.multi_compress_to_buffer(original)
1609 frames = cctx.multi_compress_to_buffer(original)
1571
1610
1572 # Check round trip.
1611 # Check round trip.
1573 dctx = zstd.ZstdDecompressor()
1612 dctx = zstd.ZstdDecompressor()
1574
1613
1575 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
1614 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
1576
1615
1577 self.assertEqual(len(decompressed), len(original))
1616 self.assertEqual(len(decompressed), len(original))
1578
1617
1579 for i, data in enumerate(original):
1618 for i, data in enumerate(original):
1580 self.assertEqual(data, decompressed[i].tobytes())
1619 self.assertEqual(data, decompressed[i].tobytes())
1581
1620
1582 # And a manual mode.
1621 # And a manual mode.
1583 b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
1622 b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
1584 b1 = zstd.BufferWithSegments(
1623 b1 = zstd.BufferWithSegments(
1585 b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]))
1624 b,
1625 struct.pack(
1626 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
1627 ),
1586 )
1628 )
1587
1629
1588 b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()])
1630 b = b"".join(
1631 [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]
1632 )
1589 b2 = zstd.BufferWithSegments(
1633 b2 = zstd.BufferWithSegments(
1590 b,
1634 b,
1591 struct.pack(
1635 struct.pack(
1592 "=QQQQQQ",
1636 "=QQQQQQ",
1593 0,
1637 0,
1594 len(frames[2]),
1638 len(frames[2]),
1595 len(frames[2]),
1639 len(frames[2]),
1596 len(frames[3]),
1640 len(frames[3]),
1597 len(frames[2]) + len(frames[3]),
1641 len(frames[2]) + len(frames[3]),
1598 len(frames[4]),
1642 len(frames[4]),
1599 ),
1643 ),
1600 )
1644 )
1601
1645
1602 c = zstd.BufferWithSegmentsCollection(b1, b2)
1646 c = zstd.BufferWithSegmentsCollection(b1, b2)
1603
1647
1604 dctx = zstd.ZstdDecompressor()
1648 dctx = zstd.ZstdDecompressor()
1605 decompressed = dctx.multi_decompress_to_buffer(c)
1649 decompressed = dctx.multi_decompress_to_buffer(c)
1606
1650
1607 self.assertEqual(len(decompressed), 5)
1651 self.assertEqual(len(decompressed), 5)
1608 for i in range(5):
1652 for i in range(5):
1609 self.assertEqual(decompressed[i].tobytes(), original[i])
1653 self.assertEqual(decompressed[i].tobytes(), original[i])
1610
1654
1611 def test_dict(self):
1655 def test_dict(self):
1612 d = zstd.train_dictionary(16384, generate_samples(), k=64, d=16)
1656 d = zstd.train_dictionary(16384, generate_samples(), k=64, d=16)
1613
1657
1614 cctx = zstd.ZstdCompressor(dict_data=d, level=1)
1658 cctx = zstd.ZstdCompressor(dict_data=d, level=1)
1615 frames = [cctx.compress(s) for s in generate_samples()]
1659 frames = [cctx.compress(s) for s in generate_samples()]
1616
1660
1617 dctx = zstd.ZstdDecompressor(dict_data=d)
1661 dctx = zstd.ZstdDecompressor(dict_data=d)
1618
1662
1619 if not hasattr(dctx, "multi_decompress_to_buffer"):
1663 if not hasattr(dctx, "multi_decompress_to_buffer"):
1620 self.skipTest("multi_decompress_to_buffer not available")
1664 self.skipTest("multi_decompress_to_buffer not available")
1621
1665
1622 result = dctx.multi_decompress_to_buffer(frames)
1666 result = dctx.multi_decompress_to_buffer(frames)
1623
1667
1624 self.assertEqual([o.tobytes() for o in result], generate_samples())
1668 self.assertEqual([o.tobytes() for o in result], generate_samples())
1625
1669
1626 def test_multiple_threads(self):
1670 def test_multiple_threads(self):
1627 cctx = zstd.ZstdCompressor()
1671 cctx = zstd.ZstdCompressor()
1628
1672
1629 frames = []
1673 frames = []
1630 frames.extend(cctx.compress(b"x" * 64) for i in range(256))
1674 frames.extend(cctx.compress(b"x" * 64) for i in range(256))
1631 frames.extend(cctx.compress(b"y" * 64) for i in range(256))
1675 frames.extend(cctx.compress(b"y" * 64) for i in range(256))
1632
1676
1633 dctx = zstd.ZstdDecompressor()
1677 dctx = zstd.ZstdDecompressor()
1634
1678
1635 if not hasattr(dctx, "multi_decompress_to_buffer"):
1679 if not hasattr(dctx, "multi_decompress_to_buffer"):
1636 self.skipTest("multi_decompress_to_buffer not available")
1680 self.skipTest("multi_decompress_to_buffer not available")
1637
1681
1638 result = dctx.multi_decompress_to_buffer(frames, threads=-1)
1682 result = dctx.multi_decompress_to_buffer(frames, threads=-1)
1639
1683
1640 self.assertEqual(len(result), len(frames))
1684 self.assertEqual(len(result), len(frames))
1641 self.assertEqual(result.size(), 2 * 64 * 256)
1685 self.assertEqual(result.size(), 2 * 64 * 256)
1642 self.assertEqual(result[0].tobytes(), b"x" * 64)
1686 self.assertEqual(result[0].tobytes(), b"x" * 64)
1643 self.assertEqual(result[256].tobytes(), b"y" * 64)
1687 self.assertEqual(result[256].tobytes(), b"y" * 64)
1644
1688
1645 def test_item_failure(self):
1689 def test_item_failure(self):
1646 cctx = zstd.ZstdCompressor()
1690 cctx = zstd.ZstdCompressor()
1647 frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)]
1691 frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)]
1648
1692
1649 frames[1] = frames[1][0:15] + b"extra" + frames[1][15:]
1693 frames[1] = frames[1][0:15] + b"extra" + frames[1][15:]
1650
1694
1651 dctx = zstd.ZstdDecompressor()
1695 dctx = zstd.ZstdDecompressor()
1652
1696
1653 if not hasattr(dctx, "multi_decompress_to_buffer"):
1697 if not hasattr(dctx, "multi_decompress_to_buffer"):
1654 self.skipTest("multi_decompress_to_buffer not available")
1698 self.skipTest("multi_decompress_to_buffer not available")
1655
1699
1656 with self.assertRaisesRegex(
1700 with self.assertRaisesRegex(
1657 zstd.ZstdError,
1701 zstd.ZstdError,
1658 "error decompressing item 1: ("
1702 "error decompressing item 1: ("
1659 "Corrupted block|"
1703 "Corrupted block|"
1660 "Destination buffer is too small)",
1704 "Destination buffer is too small)",
1661 ):
1705 ):
1662 dctx.multi_decompress_to_buffer(frames)
1706 dctx.multi_decompress_to_buffer(frames)
1663
1707
1664 with self.assertRaisesRegex(
1708 with self.assertRaisesRegex(
1665 zstd.ZstdError,
1709 zstd.ZstdError,
1666 "error decompressing item 1: ("
1710 "error decompressing item 1: ("
1667 "Corrupted block|"
1711 "Corrupted block|"
1668 "Destination buffer is too small)",
1712 "Destination buffer is too small)",
1669 ):
1713 ):
1670 dctx.multi_decompress_to_buffer(frames, threads=2)
1714 dctx.multi_decompress_to_buffer(frames, threads=2)
@@ -1,576 +1,593 b''
1 import io
1 import io
2 import os
2 import os
3 import unittest
3 import unittest
4
4
5 try:
5 try:
6 import hypothesis
6 import hypothesis
7 import hypothesis.strategies as strategies
7 import hypothesis.strategies as strategies
8 except ImportError:
8 except ImportError:
9 raise unittest.SkipTest("hypothesis not available")
9 raise unittest.SkipTest("hypothesis not available")
10
10
11 import zstandard as zstd
11 import zstandard as zstd
12
12
13 from .common import (
13 from .common import (
14 make_cffi,
14 make_cffi,
15 NonClosingBytesIO,
15 NonClosingBytesIO,
16 random_input_data,
16 random_input_data,
17 TestCase,
17 TestCase,
18 )
18 )
19
19
20
20
21 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
21 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
22 @make_cffi
22 @make_cffi
23 class TestDecompressor_stream_reader_fuzzing(TestCase):
23 class TestDecompressor_stream_reader_fuzzing(TestCase):
24 @hypothesis.settings(
24 @hypothesis.settings(
25 suppress_health_check=[
25 suppress_health_check=[
26 hypothesis.HealthCheck.large_base_example,
26 hypothesis.HealthCheck.large_base_example,
27 hypothesis.HealthCheck.too_slow,
27 hypothesis.HealthCheck.too_slow,
28 ]
28 ]
29 )
29 )
30 @hypothesis.given(
30 @hypothesis.given(
31 original=strategies.sampled_from(random_input_data()),
31 original=strategies.sampled_from(random_input_data()),
32 level=strategies.integers(min_value=1, max_value=5),
32 level=strategies.integers(min_value=1, max_value=5),
33 streaming=strategies.booleans(),
33 streaming=strategies.booleans(),
34 source_read_size=strategies.integers(1, 1048576),
34 source_read_size=strategies.integers(1, 1048576),
35 read_sizes=strategies.data(),
35 read_sizes=strategies.data(),
36 )
36 )
37 def test_stream_source_read_variance(
37 def test_stream_source_read_variance(
38 self, original, level, streaming, source_read_size, read_sizes
38 self, original, level, streaming, source_read_size, read_sizes
39 ):
39 ):
40 cctx = zstd.ZstdCompressor(level=level)
40 cctx = zstd.ZstdCompressor(level=level)
41
41
42 if streaming:
42 if streaming:
43 source = io.BytesIO()
43 source = io.BytesIO()
44 writer = cctx.stream_writer(source)
44 writer = cctx.stream_writer(source)
45 writer.write(original)
45 writer.write(original)
46 writer.flush(zstd.FLUSH_FRAME)
46 writer.flush(zstd.FLUSH_FRAME)
47 source.seek(0)
47 source.seek(0)
48 else:
48 else:
49 frame = cctx.compress(original)
49 frame = cctx.compress(original)
50 source = io.BytesIO(frame)
50 source = io.BytesIO(frame)
51
51
52 dctx = zstd.ZstdDecompressor()
52 dctx = zstd.ZstdDecompressor()
53
53
54 chunks = []
54 chunks = []
55 with dctx.stream_reader(source, read_size=source_read_size) as reader:
55 with dctx.stream_reader(source, read_size=source_read_size) as reader:
56 while True:
56 while True:
57 read_size = read_sizes.draw(strategies.integers(-1, 131072))
57 read_size = read_sizes.draw(strategies.integers(-1, 131072))
58 chunk = reader.read(read_size)
58 chunk = reader.read(read_size)
59 if not chunk and read_size:
59 if not chunk and read_size:
60 break
60 break
61
61
62 chunks.append(chunk)
62 chunks.append(chunk)
63
63
64 self.assertEqual(b"".join(chunks), original)
64 self.assertEqual(b"".join(chunks), original)
65
65
66 # Similar to above except we have a constant read() size.
66 # Similar to above except we have a constant read() size.
67 @hypothesis.settings(
67 @hypothesis.settings(
68 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
68 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
69 )
69 )
70 @hypothesis.given(
70 @hypothesis.given(
71 original=strategies.sampled_from(random_input_data()),
71 original=strategies.sampled_from(random_input_data()),
72 level=strategies.integers(min_value=1, max_value=5),
72 level=strategies.integers(min_value=1, max_value=5),
73 streaming=strategies.booleans(),
73 streaming=strategies.booleans(),
74 source_read_size=strategies.integers(1, 1048576),
74 source_read_size=strategies.integers(1, 1048576),
75 read_size=strategies.integers(-1, 131072),
75 read_size=strategies.integers(-1, 131072),
76 )
76 )
77 def test_stream_source_read_size(
77 def test_stream_source_read_size(
78 self, original, level, streaming, source_read_size, read_size
78 self, original, level, streaming, source_read_size, read_size
79 ):
79 ):
80 if read_size == 0:
80 if read_size == 0:
81 read_size = 1
81 read_size = 1
82
82
83 cctx = zstd.ZstdCompressor(level=level)
83 cctx = zstd.ZstdCompressor(level=level)
84
84
85 if streaming:
85 if streaming:
86 source = io.BytesIO()
86 source = io.BytesIO()
87 writer = cctx.stream_writer(source)
87 writer = cctx.stream_writer(source)
88 writer.write(original)
88 writer.write(original)
89 writer.flush(zstd.FLUSH_FRAME)
89 writer.flush(zstd.FLUSH_FRAME)
90 source.seek(0)
90 source.seek(0)
91 else:
91 else:
92 frame = cctx.compress(original)
92 frame = cctx.compress(original)
93 source = io.BytesIO(frame)
93 source = io.BytesIO(frame)
94
94
95 dctx = zstd.ZstdDecompressor()
95 dctx = zstd.ZstdDecompressor()
96
96
97 chunks = []
97 chunks = []
98 reader = dctx.stream_reader(source, read_size=source_read_size)
98 reader = dctx.stream_reader(source, read_size=source_read_size)
99 while True:
99 while True:
100 chunk = reader.read(read_size)
100 chunk = reader.read(read_size)
101 if not chunk and read_size:
101 if not chunk and read_size:
102 break
102 break
103
103
104 chunks.append(chunk)
104 chunks.append(chunk)
105
105
106 self.assertEqual(b"".join(chunks), original)
106 self.assertEqual(b"".join(chunks), original)
107
107
108 @hypothesis.settings(
108 @hypothesis.settings(
109 suppress_health_check=[
109 suppress_health_check=[
110 hypothesis.HealthCheck.large_base_example,
110 hypothesis.HealthCheck.large_base_example,
111 hypothesis.HealthCheck.too_slow,
111 hypothesis.HealthCheck.too_slow,
112 ]
112 ]
113 )
113 )
114 @hypothesis.given(
114 @hypothesis.given(
115 original=strategies.sampled_from(random_input_data()),
115 original=strategies.sampled_from(random_input_data()),
116 level=strategies.integers(min_value=1, max_value=5),
116 level=strategies.integers(min_value=1, max_value=5),
117 streaming=strategies.booleans(),
117 streaming=strategies.booleans(),
118 source_read_size=strategies.integers(1, 1048576),
118 source_read_size=strategies.integers(1, 1048576),
119 read_sizes=strategies.data(),
119 read_sizes=strategies.data(),
120 )
120 )
121 def test_buffer_source_read_variance(
121 def test_buffer_source_read_variance(
122 self, original, level, streaming, source_read_size, read_sizes
122 self, original, level, streaming, source_read_size, read_sizes
123 ):
123 ):
124 cctx = zstd.ZstdCompressor(level=level)
124 cctx = zstd.ZstdCompressor(level=level)
125
125
126 if streaming:
126 if streaming:
127 source = io.BytesIO()
127 source = io.BytesIO()
128 writer = cctx.stream_writer(source)
128 writer = cctx.stream_writer(source)
129 writer.write(original)
129 writer.write(original)
130 writer.flush(zstd.FLUSH_FRAME)
130 writer.flush(zstd.FLUSH_FRAME)
131 frame = source.getvalue()
131 frame = source.getvalue()
132 else:
132 else:
133 frame = cctx.compress(original)
133 frame = cctx.compress(original)
134
134
135 dctx = zstd.ZstdDecompressor()
135 dctx = zstd.ZstdDecompressor()
136 chunks = []
136 chunks = []
137
137
138 with dctx.stream_reader(frame, read_size=source_read_size) as reader:
138 with dctx.stream_reader(frame, read_size=source_read_size) as reader:
139 while True:
139 while True:
140 read_size = read_sizes.draw(strategies.integers(-1, 131072))
140 read_size = read_sizes.draw(strategies.integers(-1, 131072))
141 chunk = reader.read(read_size)
141 chunk = reader.read(read_size)
142 if not chunk and read_size:
142 if not chunk and read_size:
143 break
143 break
144
144
145 chunks.append(chunk)
145 chunks.append(chunk)
146
146
147 self.assertEqual(b"".join(chunks), original)
147 self.assertEqual(b"".join(chunks), original)
148
148
149 # Similar to above except we have a constant read() size.
149 # Similar to above except we have a constant read() size.
150 @hypothesis.settings(
150 @hypothesis.settings(
151 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
151 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
152 )
152 )
153 @hypothesis.given(
153 @hypothesis.given(
154 original=strategies.sampled_from(random_input_data()),
154 original=strategies.sampled_from(random_input_data()),
155 level=strategies.integers(min_value=1, max_value=5),
155 level=strategies.integers(min_value=1, max_value=5),
156 streaming=strategies.booleans(),
156 streaming=strategies.booleans(),
157 source_read_size=strategies.integers(1, 1048576),
157 source_read_size=strategies.integers(1, 1048576),
158 read_size=strategies.integers(-1, 131072),
158 read_size=strategies.integers(-1, 131072),
159 )
159 )
160 def test_buffer_source_constant_read_size(
160 def test_buffer_source_constant_read_size(
161 self, original, level, streaming, source_read_size, read_size
161 self, original, level, streaming, source_read_size, read_size
162 ):
162 ):
163 if read_size == 0:
163 if read_size == 0:
164 read_size = -1
164 read_size = -1
165
165
166 cctx = zstd.ZstdCompressor(level=level)
166 cctx = zstd.ZstdCompressor(level=level)
167
167
168 if streaming:
168 if streaming:
169 source = io.BytesIO()
169 source = io.BytesIO()
170 writer = cctx.stream_writer(source)
170 writer = cctx.stream_writer(source)
171 writer.write(original)
171 writer.write(original)
172 writer.flush(zstd.FLUSH_FRAME)
172 writer.flush(zstd.FLUSH_FRAME)
173 frame = source.getvalue()
173 frame = source.getvalue()
174 else:
174 else:
175 frame = cctx.compress(original)
175 frame = cctx.compress(original)
176
176
177 dctx = zstd.ZstdDecompressor()
177 dctx = zstd.ZstdDecompressor()
178 chunks = []
178 chunks = []
179
179
180 reader = dctx.stream_reader(frame, read_size=source_read_size)
180 reader = dctx.stream_reader(frame, read_size=source_read_size)
181 while True:
181 while True:
182 chunk = reader.read(read_size)
182 chunk = reader.read(read_size)
183 if not chunk and read_size:
183 if not chunk and read_size:
184 break
184 break
185
185
186 chunks.append(chunk)
186 chunks.append(chunk)
187
187
188 self.assertEqual(b"".join(chunks), original)
188 self.assertEqual(b"".join(chunks), original)
189
189
190 @hypothesis.settings(
190 @hypothesis.settings(
191 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
191 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
192 )
192 )
193 @hypothesis.given(
193 @hypothesis.given(
194 original=strategies.sampled_from(random_input_data()),
194 original=strategies.sampled_from(random_input_data()),
195 level=strategies.integers(min_value=1, max_value=5),
195 level=strategies.integers(min_value=1, max_value=5),
196 streaming=strategies.booleans(),
196 streaming=strategies.booleans(),
197 source_read_size=strategies.integers(1, 1048576),
197 source_read_size=strategies.integers(1, 1048576),
198 )
198 )
199 def test_stream_source_readall(self, original, level, streaming, source_read_size):
199 def test_stream_source_readall(
200 self, original, level, streaming, source_read_size
201 ):
200 cctx = zstd.ZstdCompressor(level=level)
202 cctx = zstd.ZstdCompressor(level=level)
201
203
202 if streaming:
204 if streaming:
203 source = io.BytesIO()
205 source = io.BytesIO()
204 writer = cctx.stream_writer(source)
206 writer = cctx.stream_writer(source)
205 writer.write(original)
207 writer.write(original)
206 writer.flush(zstd.FLUSH_FRAME)
208 writer.flush(zstd.FLUSH_FRAME)
207 source.seek(0)
209 source.seek(0)
208 else:
210 else:
209 frame = cctx.compress(original)
211 frame = cctx.compress(original)
210 source = io.BytesIO(frame)
212 source = io.BytesIO(frame)
211
213
212 dctx = zstd.ZstdDecompressor()
214 dctx = zstd.ZstdDecompressor()
213
215
214 data = dctx.stream_reader(source, read_size=source_read_size).readall()
216 data = dctx.stream_reader(source, read_size=source_read_size).readall()
215 self.assertEqual(data, original)
217 self.assertEqual(data, original)
216
218
217 @hypothesis.settings(
219 @hypothesis.settings(
218 suppress_health_check=[
220 suppress_health_check=[
219 hypothesis.HealthCheck.large_base_example,
221 hypothesis.HealthCheck.large_base_example,
220 hypothesis.HealthCheck.too_slow,
222 hypothesis.HealthCheck.too_slow,
221 ]
223 ]
222 )
224 )
223 @hypothesis.given(
225 @hypothesis.given(
224 original=strategies.sampled_from(random_input_data()),
226 original=strategies.sampled_from(random_input_data()),
225 level=strategies.integers(min_value=1, max_value=5),
227 level=strategies.integers(min_value=1, max_value=5),
226 streaming=strategies.booleans(),
228 streaming=strategies.booleans(),
227 source_read_size=strategies.integers(1, 1048576),
229 source_read_size=strategies.integers(1, 1048576),
228 read_sizes=strategies.data(),
230 read_sizes=strategies.data(),
229 )
231 )
230 def test_stream_source_read1_variance(
232 def test_stream_source_read1_variance(
231 self, original, level, streaming, source_read_size, read_sizes
233 self, original, level, streaming, source_read_size, read_sizes
232 ):
234 ):
233 cctx = zstd.ZstdCompressor(level=level)
235 cctx = zstd.ZstdCompressor(level=level)
234
236
235 if streaming:
237 if streaming:
236 source = io.BytesIO()
238 source = io.BytesIO()
237 writer = cctx.stream_writer(source)
239 writer = cctx.stream_writer(source)
238 writer.write(original)
240 writer.write(original)
239 writer.flush(zstd.FLUSH_FRAME)
241 writer.flush(zstd.FLUSH_FRAME)
240 source.seek(0)
242 source.seek(0)
241 else:
243 else:
242 frame = cctx.compress(original)
244 frame = cctx.compress(original)
243 source = io.BytesIO(frame)
245 source = io.BytesIO(frame)
244
246
245 dctx = zstd.ZstdDecompressor()
247 dctx = zstd.ZstdDecompressor()
246
248
247 chunks = []
249 chunks = []
248 with dctx.stream_reader(source, read_size=source_read_size) as reader:
250 with dctx.stream_reader(source, read_size=source_read_size) as reader:
249 while True:
251 while True:
250 read_size = read_sizes.draw(strategies.integers(-1, 131072))
252 read_size = read_sizes.draw(strategies.integers(-1, 131072))
251 chunk = reader.read1(read_size)
253 chunk = reader.read1(read_size)
252 if not chunk and read_size:
254 if not chunk and read_size:
253 break
255 break
254
256
255 chunks.append(chunk)
257 chunks.append(chunk)
256
258
257 self.assertEqual(b"".join(chunks), original)
259 self.assertEqual(b"".join(chunks), original)
258
260
259 @hypothesis.settings(
261 @hypothesis.settings(
260 suppress_health_check=[
262 suppress_health_check=[
261 hypothesis.HealthCheck.large_base_example,
263 hypothesis.HealthCheck.large_base_example,
262 hypothesis.HealthCheck.too_slow,
264 hypothesis.HealthCheck.too_slow,
263 ]
265 ]
264 )
266 )
265 @hypothesis.given(
267 @hypothesis.given(
266 original=strategies.sampled_from(random_input_data()),
268 original=strategies.sampled_from(random_input_data()),
267 level=strategies.integers(min_value=1, max_value=5),
269 level=strategies.integers(min_value=1, max_value=5),
268 streaming=strategies.booleans(),
270 streaming=strategies.booleans(),
269 source_read_size=strategies.integers(1, 1048576),
271 source_read_size=strategies.integers(1, 1048576),
270 read_sizes=strategies.data(),
272 read_sizes=strategies.data(),
271 )
273 )
272 def test_stream_source_readinto1_variance(
274 def test_stream_source_readinto1_variance(
273 self, original, level, streaming, source_read_size, read_sizes
275 self, original, level, streaming, source_read_size, read_sizes
274 ):
276 ):
275 cctx = zstd.ZstdCompressor(level=level)
277 cctx = zstd.ZstdCompressor(level=level)
276
278
277 if streaming:
279 if streaming:
278 source = io.BytesIO()
280 source = io.BytesIO()
279 writer = cctx.stream_writer(source)
281 writer = cctx.stream_writer(source)
280 writer.write(original)
282 writer.write(original)
281 writer.flush(zstd.FLUSH_FRAME)
283 writer.flush(zstd.FLUSH_FRAME)
282 source.seek(0)
284 source.seek(0)
283 else:
285 else:
284 frame = cctx.compress(original)
286 frame = cctx.compress(original)
285 source = io.BytesIO(frame)
287 source = io.BytesIO(frame)
286
288
287 dctx = zstd.ZstdDecompressor()
289 dctx = zstd.ZstdDecompressor()
288
290
289 chunks = []
291 chunks = []
290 with dctx.stream_reader(source, read_size=source_read_size) as reader:
292 with dctx.stream_reader(source, read_size=source_read_size) as reader:
291 while True:
293 while True:
292 read_size = read_sizes.draw(strategies.integers(1, 131072))
294 read_size = read_sizes.draw(strategies.integers(1, 131072))
293 b = bytearray(read_size)
295 b = bytearray(read_size)
294 count = reader.readinto1(b)
296 count = reader.readinto1(b)
295
297
296 if not count:
298 if not count:
297 break
299 break
298
300
299 chunks.append(bytes(b[0:count]))
301 chunks.append(bytes(b[0:count]))
300
302
301 self.assertEqual(b"".join(chunks), original)
303 self.assertEqual(b"".join(chunks), original)
302
304
303 @hypothesis.settings(
305 @hypothesis.settings(
304 suppress_health_check=[
306 suppress_health_check=[
305 hypothesis.HealthCheck.large_base_example,
307 hypothesis.HealthCheck.large_base_example,
306 hypothesis.HealthCheck.too_slow,
308 hypothesis.HealthCheck.too_slow,
307 ]
309 ]
308 )
310 )
309 @hypothesis.given(
311 @hypothesis.given(
310 original=strategies.sampled_from(random_input_data()),
312 original=strategies.sampled_from(random_input_data()),
311 level=strategies.integers(min_value=1, max_value=5),
313 level=strategies.integers(min_value=1, max_value=5),
312 source_read_size=strategies.integers(1, 1048576),
314 source_read_size=strategies.integers(1, 1048576),
313 seek_amounts=strategies.data(),
315 seek_amounts=strategies.data(),
314 read_sizes=strategies.data(),
316 read_sizes=strategies.data(),
315 )
317 )
316 def test_relative_seeks(
318 def test_relative_seeks(
317 self, original, level, source_read_size, seek_amounts, read_sizes
319 self, original, level, source_read_size, seek_amounts, read_sizes
318 ):
320 ):
319 cctx = zstd.ZstdCompressor(level=level)
321 cctx = zstd.ZstdCompressor(level=level)
320 frame = cctx.compress(original)
322 frame = cctx.compress(original)
321
323
322 dctx = zstd.ZstdDecompressor()
324 dctx = zstd.ZstdDecompressor()
323
325
324 with dctx.stream_reader(frame, read_size=source_read_size) as reader:
326 with dctx.stream_reader(frame, read_size=source_read_size) as reader:
325 while True:
327 while True:
326 amount = seek_amounts.draw(strategies.integers(0, 16384))
328 amount = seek_amounts.draw(strategies.integers(0, 16384))
327 reader.seek(amount, os.SEEK_CUR)
329 reader.seek(amount, os.SEEK_CUR)
328
330
329 offset = reader.tell()
331 offset = reader.tell()
330 read_amount = read_sizes.draw(strategies.integers(1, 16384))
332 read_amount = read_sizes.draw(strategies.integers(1, 16384))
331 chunk = reader.read(read_amount)
333 chunk = reader.read(read_amount)
332
334
333 if not chunk:
335 if not chunk:
334 break
336 break
335
337
336 self.assertEqual(original[offset : offset + len(chunk)], chunk)
338 self.assertEqual(original[offset : offset + len(chunk)], chunk)
337
339
338 @hypothesis.settings(
340 @hypothesis.settings(
339 suppress_health_check=[
341 suppress_health_check=[
340 hypothesis.HealthCheck.large_base_example,
342 hypothesis.HealthCheck.large_base_example,
341 hypothesis.HealthCheck.too_slow,
343 hypothesis.HealthCheck.too_slow,
342 ]
344 ]
343 )
345 )
344 @hypothesis.given(
346 @hypothesis.given(
345 originals=strategies.data(),
347 originals=strategies.data(),
346 frame_count=strategies.integers(min_value=2, max_value=10),
348 frame_count=strategies.integers(min_value=2, max_value=10),
347 level=strategies.integers(min_value=1, max_value=5),
349 level=strategies.integers(min_value=1, max_value=5),
348 source_read_size=strategies.integers(1, 1048576),
350 source_read_size=strategies.integers(1, 1048576),
349 read_sizes=strategies.data(),
351 read_sizes=strategies.data(),
350 )
352 )
351 def test_multiple_frames(
353 def test_multiple_frames(
352 self, originals, frame_count, level, source_read_size, read_sizes
354 self, originals, frame_count, level, source_read_size, read_sizes
353 ):
355 ):
354
356
355 cctx = zstd.ZstdCompressor(level=level)
357 cctx = zstd.ZstdCompressor(level=level)
356 source = io.BytesIO()
358 source = io.BytesIO()
357 buffer = io.BytesIO()
359 buffer = io.BytesIO()
358 writer = cctx.stream_writer(buffer)
360 writer = cctx.stream_writer(buffer)
359
361
360 for i in range(frame_count):
362 for i in range(frame_count):
361 data = originals.draw(strategies.sampled_from(random_input_data()))
363 data = originals.draw(strategies.sampled_from(random_input_data()))
362 source.write(data)
364 source.write(data)
363 writer.write(data)
365 writer.write(data)
364 writer.flush(zstd.FLUSH_FRAME)
366 writer.flush(zstd.FLUSH_FRAME)
365
367
366 dctx = zstd.ZstdDecompressor()
368 dctx = zstd.ZstdDecompressor()
367 buffer.seek(0)
369 buffer.seek(0)
368 reader = dctx.stream_reader(
370 reader = dctx.stream_reader(
369 buffer, read_size=source_read_size, read_across_frames=True
371 buffer, read_size=source_read_size, read_across_frames=True
370 )
372 )
371
373
372 chunks = []
374 chunks = []
373
375
374 while True:
376 while True:
375 read_amount = read_sizes.draw(strategies.integers(-1, 16384))
377 read_amount = read_sizes.draw(strategies.integers(-1, 16384))
376 chunk = reader.read(read_amount)
378 chunk = reader.read(read_amount)
377
379
378 if not chunk and read_amount:
380 if not chunk and read_amount:
379 break
381 break
380
382
381 chunks.append(chunk)
383 chunks.append(chunk)
382
384
383 self.assertEqual(source.getvalue(), b"".join(chunks))
385 self.assertEqual(source.getvalue(), b"".join(chunks))
384
386
385
387
386 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
388 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
387 @make_cffi
389 @make_cffi
388 class TestDecompressor_stream_writer_fuzzing(TestCase):
390 class TestDecompressor_stream_writer_fuzzing(TestCase):
389 @hypothesis.settings(
391 @hypothesis.settings(
390 suppress_health_check=[
392 suppress_health_check=[
391 hypothesis.HealthCheck.large_base_example,
393 hypothesis.HealthCheck.large_base_example,
392 hypothesis.HealthCheck.too_slow,
394 hypothesis.HealthCheck.too_slow,
393 ]
395 ]
394 )
396 )
395 @hypothesis.given(
397 @hypothesis.given(
396 original=strategies.sampled_from(random_input_data()),
398 original=strategies.sampled_from(random_input_data()),
397 level=strategies.integers(min_value=1, max_value=5),
399 level=strategies.integers(min_value=1, max_value=5),
398 write_size=strategies.integers(min_value=1, max_value=8192),
400 write_size=strategies.integers(min_value=1, max_value=8192),
399 input_sizes=strategies.data(),
401 input_sizes=strategies.data(),
400 )
402 )
401 def test_write_size_variance(self, original, level, write_size, input_sizes):
403 def test_write_size_variance(
404 self, original, level, write_size, input_sizes
405 ):
402 cctx = zstd.ZstdCompressor(level=level)
406 cctx = zstd.ZstdCompressor(level=level)
403 frame = cctx.compress(original)
407 frame = cctx.compress(original)
404
408
405 dctx = zstd.ZstdDecompressor()
409 dctx = zstd.ZstdDecompressor()
406 source = io.BytesIO(frame)
410 source = io.BytesIO(frame)
407 dest = NonClosingBytesIO()
411 dest = NonClosingBytesIO()
408
412
409 with dctx.stream_writer(dest, write_size=write_size) as decompressor:
413 with dctx.stream_writer(dest, write_size=write_size) as decompressor:
410 while True:
414 while True:
411 input_size = input_sizes.draw(strategies.integers(1, 4096))
415 input_size = input_sizes.draw(strategies.integers(1, 4096))
412 chunk = source.read(input_size)
416 chunk = source.read(input_size)
413 if not chunk:
417 if not chunk:
414 break
418 break
415
419
416 decompressor.write(chunk)
420 decompressor.write(chunk)
417
421
418 self.assertEqual(dest.getvalue(), original)
422 self.assertEqual(dest.getvalue(), original)
419
423
420
424
421 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
425 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
422 @make_cffi
426 @make_cffi
423 class TestDecompressor_copy_stream_fuzzing(TestCase):
427 class TestDecompressor_copy_stream_fuzzing(TestCase):
424 @hypothesis.settings(
428 @hypothesis.settings(
425 suppress_health_check=[
429 suppress_health_check=[
426 hypothesis.HealthCheck.large_base_example,
430 hypothesis.HealthCheck.large_base_example,
427 hypothesis.HealthCheck.too_slow,
431 hypothesis.HealthCheck.too_slow,
428 ]
432 ]
429 )
433 )
430 @hypothesis.given(
434 @hypothesis.given(
431 original=strategies.sampled_from(random_input_data()),
435 original=strategies.sampled_from(random_input_data()),
432 level=strategies.integers(min_value=1, max_value=5),
436 level=strategies.integers(min_value=1, max_value=5),
433 read_size=strategies.integers(min_value=1, max_value=8192),
437 read_size=strategies.integers(min_value=1, max_value=8192),
434 write_size=strategies.integers(min_value=1, max_value=8192),
438 write_size=strategies.integers(min_value=1, max_value=8192),
435 )
439 )
436 def test_read_write_size_variance(self, original, level, read_size, write_size):
440 def test_read_write_size_variance(
441 self, original, level, read_size, write_size
442 ):
437 cctx = zstd.ZstdCompressor(level=level)
443 cctx = zstd.ZstdCompressor(level=level)
438 frame = cctx.compress(original)
444 frame = cctx.compress(original)
439
445
440 source = io.BytesIO(frame)
446 source = io.BytesIO(frame)
441 dest = io.BytesIO()
447 dest = io.BytesIO()
442
448
443 dctx = zstd.ZstdDecompressor()
449 dctx = zstd.ZstdDecompressor()
444 dctx.copy_stream(source, dest, read_size=read_size, write_size=write_size)
450 dctx.copy_stream(
451 source, dest, read_size=read_size, write_size=write_size
452 )
445
453
446 self.assertEqual(dest.getvalue(), original)
454 self.assertEqual(dest.getvalue(), original)
447
455
448
456
449 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
457 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
450 @make_cffi
458 @make_cffi
451 class TestDecompressor_decompressobj_fuzzing(TestCase):
459 class TestDecompressor_decompressobj_fuzzing(TestCase):
452 @hypothesis.settings(
460 @hypothesis.settings(
453 suppress_health_check=[
461 suppress_health_check=[
454 hypothesis.HealthCheck.large_base_example,
462 hypothesis.HealthCheck.large_base_example,
455 hypothesis.HealthCheck.too_slow,
463 hypothesis.HealthCheck.too_slow,
456 ]
464 ]
457 )
465 )
458 @hypothesis.given(
466 @hypothesis.given(
459 original=strategies.sampled_from(random_input_data()),
467 original=strategies.sampled_from(random_input_data()),
460 level=strategies.integers(min_value=1, max_value=5),
468 level=strategies.integers(min_value=1, max_value=5),
461 chunk_sizes=strategies.data(),
469 chunk_sizes=strategies.data(),
462 )
470 )
463 def test_random_input_sizes(self, original, level, chunk_sizes):
471 def test_random_input_sizes(self, original, level, chunk_sizes):
464 cctx = zstd.ZstdCompressor(level=level)
472 cctx = zstd.ZstdCompressor(level=level)
465 frame = cctx.compress(original)
473 frame = cctx.compress(original)
466
474
467 source = io.BytesIO(frame)
475 source = io.BytesIO(frame)
468
476
469 dctx = zstd.ZstdDecompressor()
477 dctx = zstd.ZstdDecompressor()
470 dobj = dctx.decompressobj()
478 dobj = dctx.decompressobj()
471
479
472 chunks = []
480 chunks = []
473 while True:
481 while True:
474 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
482 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
475 chunk = source.read(chunk_size)
483 chunk = source.read(chunk_size)
476 if not chunk:
484 if not chunk:
477 break
485 break
478
486
479 chunks.append(dobj.decompress(chunk))
487 chunks.append(dobj.decompress(chunk))
480
488
481 self.assertEqual(b"".join(chunks), original)
489 self.assertEqual(b"".join(chunks), original)
482
490
483 @hypothesis.settings(
491 @hypothesis.settings(
484 suppress_health_check=[
492 suppress_health_check=[
485 hypothesis.HealthCheck.large_base_example,
493 hypothesis.HealthCheck.large_base_example,
486 hypothesis.HealthCheck.too_slow,
494 hypothesis.HealthCheck.too_slow,
487 ]
495 ]
488 )
496 )
489 @hypothesis.given(
497 @hypothesis.given(
490 original=strategies.sampled_from(random_input_data()),
498 original=strategies.sampled_from(random_input_data()),
491 level=strategies.integers(min_value=1, max_value=5),
499 level=strategies.integers(min_value=1, max_value=5),
492 write_size=strategies.integers(
500 write_size=strategies.integers(
493 min_value=1, max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE
501 min_value=1,
502 max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
494 ),
503 ),
495 chunk_sizes=strategies.data(),
504 chunk_sizes=strategies.data(),
496 )
505 )
497 def test_random_output_sizes(self, original, level, write_size, chunk_sizes):
506 def test_random_output_sizes(
507 self, original, level, write_size, chunk_sizes
508 ):
498 cctx = zstd.ZstdCompressor(level=level)
509 cctx = zstd.ZstdCompressor(level=level)
499 frame = cctx.compress(original)
510 frame = cctx.compress(original)
500
511
501 source = io.BytesIO(frame)
512 source = io.BytesIO(frame)
502
513
503 dctx = zstd.ZstdDecompressor()
514 dctx = zstd.ZstdDecompressor()
504 dobj = dctx.decompressobj(write_size=write_size)
515 dobj = dctx.decompressobj(write_size=write_size)
505
516
506 chunks = []
517 chunks = []
507 while True:
518 while True:
508 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
519 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
509 chunk = source.read(chunk_size)
520 chunk = source.read(chunk_size)
510 if not chunk:
521 if not chunk:
511 break
522 break
512
523
513 chunks.append(dobj.decompress(chunk))
524 chunks.append(dobj.decompress(chunk))
514
525
515 self.assertEqual(b"".join(chunks), original)
526 self.assertEqual(b"".join(chunks), original)
516
527
517
528
518 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
529 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
519 @make_cffi
530 @make_cffi
520 class TestDecompressor_read_to_iter_fuzzing(TestCase):
531 class TestDecompressor_read_to_iter_fuzzing(TestCase):
521 @hypothesis.given(
532 @hypothesis.given(
522 original=strategies.sampled_from(random_input_data()),
533 original=strategies.sampled_from(random_input_data()),
523 level=strategies.integers(min_value=1, max_value=5),
534 level=strategies.integers(min_value=1, max_value=5),
524 read_size=strategies.integers(min_value=1, max_value=4096),
535 read_size=strategies.integers(min_value=1, max_value=4096),
525 write_size=strategies.integers(min_value=1, max_value=4096),
536 write_size=strategies.integers(min_value=1, max_value=4096),
526 )
537 )
527 def test_read_write_size_variance(self, original, level, read_size, write_size):
538 def test_read_write_size_variance(
539 self, original, level, read_size, write_size
540 ):
528 cctx = zstd.ZstdCompressor(level=level)
541 cctx = zstd.ZstdCompressor(level=level)
529 frame = cctx.compress(original)
542 frame = cctx.compress(original)
530
543
531 source = io.BytesIO(frame)
544 source = io.BytesIO(frame)
532
545
533 dctx = zstd.ZstdDecompressor()
546 dctx = zstd.ZstdDecompressor()
534 chunks = list(
547 chunks = list(
535 dctx.read_to_iter(source, read_size=read_size, write_size=write_size)
548 dctx.read_to_iter(
549 source, read_size=read_size, write_size=write_size
550 )
536 )
551 )
537
552
538 self.assertEqual(b"".join(chunks), original)
553 self.assertEqual(b"".join(chunks), original)
539
554
540
555
541 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
556 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
542 class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase):
557 class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase):
543 @hypothesis.given(
558 @hypothesis.given(
544 original=strategies.lists(
559 original=strategies.lists(
545 strategies.sampled_from(random_input_data()), min_size=1, max_size=1024
560 strategies.sampled_from(random_input_data()),
561 min_size=1,
562 max_size=1024,
546 ),
563 ),
547 threads=strategies.integers(min_value=1, max_value=8),
564 threads=strategies.integers(min_value=1, max_value=8),
548 use_dict=strategies.booleans(),
565 use_dict=strategies.booleans(),
549 )
566 )
550 def test_data_equivalence(self, original, threads, use_dict):
567 def test_data_equivalence(self, original, threads, use_dict):
551 kwargs = {}
568 kwargs = {}
552 if use_dict:
569 if use_dict:
553 kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])
570 kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])
554
571
555 cctx = zstd.ZstdCompressor(
572 cctx = zstd.ZstdCompressor(
556 level=1, write_content_size=True, write_checksum=True, **kwargs
573 level=1, write_content_size=True, write_checksum=True, **kwargs
557 )
574 )
558
575
559 if not hasattr(cctx, "multi_compress_to_buffer"):
576 if not hasattr(cctx, "multi_compress_to_buffer"):
560 self.skipTest("multi_compress_to_buffer not available")
577 self.skipTest("multi_compress_to_buffer not available")
561
578
562 frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1)
579 frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1)
563
580
564 dctx = zstd.ZstdDecompressor(**kwargs)
581 dctx = zstd.ZstdDecompressor(**kwargs)
565 result = dctx.multi_decompress_to_buffer(frames_buffer)
582 result = dctx.multi_decompress_to_buffer(frames_buffer)
566
583
567 self.assertEqual(len(result), len(original))
584 self.assertEqual(len(result), len(original))
568 for i, frame in enumerate(result):
585 for i, frame in enumerate(result):
569 self.assertEqual(frame.tobytes(), original[i])
586 self.assertEqual(frame.tobytes(), original[i])
570
587
571 frames_list = [f.tobytes() for f in frames_buffer]
588 frames_list = [f.tobytes() for f in frames_buffer]
572 result = dctx.multi_decompress_to_buffer(frames_list)
589 result = dctx.multi_decompress_to_buffer(frames_list)
573
590
574 self.assertEqual(len(result), len(original))
591 self.assertEqual(len(result), len(original))
575 for i, frame in enumerate(result):
592 for i, frame in enumerate(result):
576 self.assertEqual(frame.tobytes(), original[i])
593 self.assertEqual(frame.tobytes(), original[i])
@@ -1,92 +1,102 b''
1 import struct
1 import struct
2 import sys
2 import sys
3 import unittest
3 import unittest
4
4
5 import zstandard as zstd
5 import zstandard as zstd
6
6
7 from .common import (
7 from .common import (
8 generate_samples,
8 generate_samples,
9 make_cffi,
9 make_cffi,
10 random_input_data,
10 random_input_data,
11 TestCase,
11 TestCase,
12 )
12 )
13
13
14 if sys.version_info[0] >= 3:
14 if sys.version_info[0] >= 3:
15 int_type = int
15 int_type = int
16 else:
16 else:
17 int_type = long
17 int_type = long
18
18
19
19
20 @make_cffi
20 @make_cffi
21 class TestTrainDictionary(TestCase):
21 class TestTrainDictionary(TestCase):
22 def test_no_args(self):
22 def test_no_args(self):
23 with self.assertRaises(TypeError):
23 with self.assertRaises(TypeError):
24 zstd.train_dictionary()
24 zstd.train_dictionary()
25
25
26 def test_bad_args(self):
26 def test_bad_args(self):
27 with self.assertRaises(TypeError):
27 with self.assertRaises(TypeError):
28 zstd.train_dictionary(8192, u"foo")
28 zstd.train_dictionary(8192, u"foo")
29
29
30 with self.assertRaises(ValueError):
30 with self.assertRaises(ValueError):
31 zstd.train_dictionary(8192, [u"foo"])
31 zstd.train_dictionary(8192, [u"foo"])
32
32
33 def test_no_params(self):
33 def test_no_params(self):
34 d = zstd.train_dictionary(8192, random_input_data())
34 d = zstd.train_dictionary(8192, random_input_data())
35 self.assertIsInstance(d.dict_id(), int_type)
35 self.assertIsInstance(d.dict_id(), int_type)
36
36
37 # The dictionary ID may be different across platforms.
37 # The dictionary ID may be different across platforms.
38 expected = b"\x37\xa4\x30\xec" + struct.pack("<I", d.dict_id())
38 expected = b"\x37\xa4\x30\xec" + struct.pack("<I", d.dict_id())
39
39
40 data = d.as_bytes()
40 data = d.as_bytes()
41 self.assertEqual(data[0:8], expected)
41 self.assertEqual(data[0:8], expected)
42
42
43 def test_basic(self):
43 def test_basic(self):
44 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
44 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
45 self.assertIsInstance(d.dict_id(), int_type)
45 self.assertIsInstance(d.dict_id(), int_type)
46
46
47 data = d.as_bytes()
47 data = d.as_bytes()
48 self.assertEqual(data[0:4], b"\x37\xa4\x30\xec")
48 self.assertEqual(data[0:4], b"\x37\xa4\x30\xec")
49
49
50 self.assertEqual(d.k, 64)
50 self.assertEqual(d.k, 64)
51 self.assertEqual(d.d, 16)
51 self.assertEqual(d.d, 16)
52
52
53 def test_set_dict_id(self):
53 def test_set_dict_id(self):
54 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16, dict_id=42)
54 d = zstd.train_dictionary(
55 8192, generate_samples(), k=64, d=16, dict_id=42
56 )
55 self.assertEqual(d.dict_id(), 42)
57 self.assertEqual(d.dict_id(), 42)
56
58
57 def test_optimize(self):
59 def test_optimize(self):
58 d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1, d=16)
60 d = zstd.train_dictionary(
61 8192, generate_samples(), threads=-1, steps=1, d=16
62 )
59
63
60 # This varies by platform.
64 # This varies by platform.
61 self.assertIn(d.k, (50, 2000))
65 self.assertIn(d.k, (50, 2000))
62 self.assertEqual(d.d, 16)
66 self.assertEqual(d.d, 16)
63
67
64
68
65 @make_cffi
69 @make_cffi
66 class TestCompressionDict(TestCase):
70 class TestCompressionDict(TestCase):
67 def test_bad_mode(self):
71 def test_bad_mode(self):
68 with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"):
72 with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"):
69 zstd.ZstdCompressionDict(b"foo", dict_type=42)
73 zstd.ZstdCompressionDict(b"foo", dict_type=42)
70
74
71 def test_bad_precompute_compress(self):
75 def test_bad_precompute_compress(self):
72 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
76 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
73
77
74 with self.assertRaisesRegex(ValueError, "must specify one of level or "):
78 with self.assertRaisesRegex(
79 ValueError, "must specify one of level or "
80 ):
75 d.precompute_compress()
81 d.precompute_compress()
76
82
77 with self.assertRaisesRegex(ValueError, "must only specify one of level or "):
83 with self.assertRaisesRegex(
84 ValueError, "must only specify one of level or "
85 ):
78 d.precompute_compress(
86 d.precompute_compress(
79 level=3, compression_params=zstd.CompressionParameters()
87 level=3, compression_params=zstd.CompressionParameters()
80 )
88 )
81
89
82 def test_precompute_compress_rawcontent(self):
90 def test_precompute_compress_rawcontent(self):
83 d = zstd.ZstdCompressionDict(
91 d = zstd.ZstdCompressionDict(
84 b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT
92 b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT
85 )
93 )
86 d.precompute_compress(level=1)
94 d.precompute_compress(level=1)
87
95
88 d = zstd.ZstdCompressionDict(
96 d = zstd.ZstdCompressionDict(
89 b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT
97 b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT
90 )
98 )
91 with self.assertRaisesRegex(zstd.ZstdError, "unable to precompute dictionary"):
99 with self.assertRaisesRegex(
100 zstd.ZstdError, "unable to precompute dictionary"
101 ):
92 d.precompute_compress(level=1)
102 d.precompute_compress(level=1)
@@ -1,2615 +1,2769 b''
1 # Copyright (c) 2016-present, Gregory Szorc
1 # Copyright (c) 2016-present, Gregory Szorc
2 # All rights reserved.
2 # All rights reserved.
3 #
3 #
4 # This software may be modified and distributed under the terms
4 # This software may be modified and distributed under the terms
5 # of the BSD license. See the LICENSE file for details.
5 # of the BSD license. See the LICENSE file for details.
6
6
7 """Python interface to the Zstandard (zstd) compression library."""
7 """Python interface to the Zstandard (zstd) compression library."""
8
8
9 from __future__ import absolute_import, unicode_literals
9 from __future__ import absolute_import, unicode_literals
10
10
11 # This should match what the C extension exports.
11 # This should match what the C extension exports.
12 __all__ = [
12 __all__ = [
13 #'BufferSegment',
13 #'BufferSegment',
14 #'BufferSegments',
14 #'BufferSegments',
15 #'BufferWithSegments',
15 #'BufferWithSegments',
16 #'BufferWithSegmentsCollection',
16 #'BufferWithSegmentsCollection',
17 "CompressionParameters",
17 "CompressionParameters",
18 "ZstdCompressionDict",
18 "ZstdCompressionDict",
19 "ZstdCompressionParameters",
19 "ZstdCompressionParameters",
20 "ZstdCompressor",
20 "ZstdCompressor",
21 "ZstdError",
21 "ZstdError",
22 "ZstdDecompressor",
22 "ZstdDecompressor",
23 "FrameParameters",
23 "FrameParameters",
24 "estimate_decompression_context_size",
24 "estimate_decompression_context_size",
25 "frame_content_size",
25 "frame_content_size",
26 "frame_header_size",
26 "frame_header_size",
27 "get_frame_parameters",
27 "get_frame_parameters",
28 "train_dictionary",
28 "train_dictionary",
29 # Constants.
29 # Constants.
30 "FLUSH_BLOCK",
30 "FLUSH_BLOCK",
31 "FLUSH_FRAME",
31 "FLUSH_FRAME",
32 "COMPRESSOBJ_FLUSH_FINISH",
32 "COMPRESSOBJ_FLUSH_FINISH",
33 "COMPRESSOBJ_FLUSH_BLOCK",
33 "COMPRESSOBJ_FLUSH_BLOCK",
34 "ZSTD_VERSION",
34 "ZSTD_VERSION",
35 "FRAME_HEADER",
35 "FRAME_HEADER",
36 "CONTENTSIZE_UNKNOWN",
36 "CONTENTSIZE_UNKNOWN",
37 "CONTENTSIZE_ERROR",
37 "CONTENTSIZE_ERROR",
38 "MAX_COMPRESSION_LEVEL",
38 "MAX_COMPRESSION_LEVEL",
39 "COMPRESSION_RECOMMENDED_INPUT_SIZE",
39 "COMPRESSION_RECOMMENDED_INPUT_SIZE",
40 "COMPRESSION_RECOMMENDED_OUTPUT_SIZE",
40 "COMPRESSION_RECOMMENDED_OUTPUT_SIZE",
41 "DECOMPRESSION_RECOMMENDED_INPUT_SIZE",
41 "DECOMPRESSION_RECOMMENDED_INPUT_SIZE",
42 "DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE",
42 "DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE",
43 "MAGIC_NUMBER",
43 "MAGIC_NUMBER",
44 "BLOCKSIZELOG_MAX",
44 "BLOCKSIZELOG_MAX",
45 "BLOCKSIZE_MAX",
45 "BLOCKSIZE_MAX",
46 "WINDOWLOG_MIN",
46 "WINDOWLOG_MIN",
47 "WINDOWLOG_MAX",
47 "WINDOWLOG_MAX",
48 "CHAINLOG_MIN",
48 "CHAINLOG_MIN",
49 "CHAINLOG_MAX",
49 "CHAINLOG_MAX",
50 "HASHLOG_MIN",
50 "HASHLOG_MIN",
51 "HASHLOG_MAX",
51 "HASHLOG_MAX",
52 "HASHLOG3_MAX",
52 "HASHLOG3_MAX",
53 "MINMATCH_MIN",
53 "MINMATCH_MIN",
54 "MINMATCH_MAX",
54 "MINMATCH_MAX",
55 "SEARCHLOG_MIN",
55 "SEARCHLOG_MIN",
56 "SEARCHLOG_MAX",
56 "SEARCHLOG_MAX",
57 "SEARCHLENGTH_MIN",
57 "SEARCHLENGTH_MIN",
58 "SEARCHLENGTH_MAX",
58 "SEARCHLENGTH_MAX",
59 "TARGETLENGTH_MIN",
59 "TARGETLENGTH_MIN",
60 "TARGETLENGTH_MAX",
60 "TARGETLENGTH_MAX",
61 "LDM_MINMATCH_MIN",
61 "LDM_MINMATCH_MIN",
62 "LDM_MINMATCH_MAX",
62 "LDM_MINMATCH_MAX",
63 "LDM_BUCKETSIZELOG_MAX",
63 "LDM_BUCKETSIZELOG_MAX",
64 "STRATEGY_FAST",
64 "STRATEGY_FAST",
65 "STRATEGY_DFAST",
65 "STRATEGY_DFAST",
66 "STRATEGY_GREEDY",
66 "STRATEGY_GREEDY",
67 "STRATEGY_LAZY",
67 "STRATEGY_LAZY",
68 "STRATEGY_LAZY2",
68 "STRATEGY_LAZY2",
69 "STRATEGY_BTLAZY2",
69 "STRATEGY_BTLAZY2",
70 "STRATEGY_BTOPT",
70 "STRATEGY_BTOPT",
71 "STRATEGY_BTULTRA",
71 "STRATEGY_BTULTRA",
72 "STRATEGY_BTULTRA2",
72 "STRATEGY_BTULTRA2",
73 "DICT_TYPE_AUTO",
73 "DICT_TYPE_AUTO",
74 "DICT_TYPE_RAWCONTENT",
74 "DICT_TYPE_RAWCONTENT",
75 "DICT_TYPE_FULLDICT",
75 "DICT_TYPE_FULLDICT",
76 "FORMAT_ZSTD1",
76 "FORMAT_ZSTD1",
77 "FORMAT_ZSTD1_MAGICLESS",
77 "FORMAT_ZSTD1_MAGICLESS",
78 ]
78 ]
79
79
80 import io
80 import io
81 import os
81 import os
82 import sys
82 import sys
83
83
84 from _zstd_cffi import (
84 from _zstd_cffi import (
85 ffi,
85 ffi,
86 lib,
86 lib,
87 )
87 )
88
88
89 if sys.version_info[0] == 2:
89 if sys.version_info[0] == 2:
90 bytes_type = str
90 bytes_type = str
91 int_type = long
91 int_type = long
92 else:
92 else:
93 bytes_type = bytes
93 bytes_type = bytes
94 int_type = int
94 int_type = int
95
95
96
96
97 COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize()
# Recommended streaming buffer sizes, as reported by the zstd library.
COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize()
COMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_CStreamOutSize()
DECOMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_DStreamInSize()
DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_DStreamOutSize()

# Allocator that skips zero-filling, for buffers that are fully overwritten.
new_nonzero = ffi.new_allocator(should_clear_after_alloc=False)


MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel()
MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER
FRAME_HEADER = b"\x28\xb5\x2f\xfd"
CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN
CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR
ZSTD_VERSION = (
    lib.ZSTD_VERSION_MAJOR,
    lib.ZSTD_VERSION_MINOR,
    lib.ZSTD_VERSION_RELEASE,
)

# Bounds for individual compression parameters.
BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX
BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX
WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN
WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX
CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN
CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX
HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN
HASHLOG_MAX = lib.ZSTD_HASHLOG_MAX
HASHLOG3_MAX = lib.ZSTD_HASHLOG3_MAX
MINMATCH_MIN = lib.ZSTD_MINMATCH_MIN
MINMATCH_MAX = lib.ZSTD_MINMATCH_MAX
SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN
SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX
# Aliases of the min match bounds.
SEARCHLENGTH_MIN = lib.ZSTD_MINMATCH_MIN
SEARCHLENGTH_MAX = lib.ZSTD_MINMATCH_MAX
TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN
TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX
LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN
LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX
LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX

# Compression strategies, in increasing order of effort.
STRATEGY_FAST = lib.ZSTD_fast
STRATEGY_DFAST = lib.ZSTD_dfast
STRATEGY_GREEDY = lib.ZSTD_greedy
STRATEGY_LAZY = lib.ZSTD_lazy
STRATEGY_LAZY2 = lib.ZSTD_lazy2
STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2
STRATEGY_BTOPT = lib.ZSTD_btopt
STRATEGY_BTULTRA = lib.ZSTD_btultra
STRATEGY_BTULTRA2 = lib.ZSTD_btultra2

# Dictionary content types.
DICT_TYPE_AUTO = lib.ZSTD_dct_auto
DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent
DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict

# Frame formats.
FORMAT_ZSTD1 = lib.ZSTD_f_zstd1
FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless

# Flush modes for the streaming writer API.
FLUSH_BLOCK = 0
FLUSH_FRAME = 1

# Flush modes for the compressobj() API.
COMPRESSOBJ_FLUSH_FINISH = 0
COMPRESSOBJ_FLUSH_BLOCK = 1
159
159
160
160
161 def _cpu_count():
161 def _cpu_count():
162 # os.cpu_count() was introducd in Python 3.4.
162 # os.cpu_count() was introducd in Python 3.4.
163 try:
163 try:
164 return os.cpu_count() or 0
164 return os.cpu_count() or 0
165 except AttributeError:
165 except AttributeError:
166 pass
166 pass
167
167
168 # Linux.
168 # Linux.
169 try:
169 try:
170 if sys.version_info[0] == 2:
170 if sys.version_info[0] == 2:
171 return os.sysconf(b"SC_NPROCESSORS_ONLN")
171 return os.sysconf(b"SC_NPROCESSORS_ONLN")
172 else:
172 else:
173 return os.sysconf("SC_NPROCESSORS_ONLN")
173 return os.sysconf("SC_NPROCESSORS_ONLN")
174 except (AttributeError, ValueError):
174 except (AttributeError, ValueError):
175 pass
175 pass
176
176
177 # TODO implement on other platforms.
177 # TODO implement on other platforms.
178 return 0
178 return 0
179
179
180
180
class ZstdError(Exception):
    """Raised when a zstd library operation fails."""
183
183
184
184
def _zstd_error(zresult):
    """Resolve a zstd error code into a unicode error message."""
    # ZSTD_getErrorName() yields a C string, which ffi.string() exposes as
    # bytes on both Python 2 and 3. The result is interpolated into unicode
    # error messages, so decode here.
    name = lib.ZSTD_getErrorName(zresult)
    return ffi.string(name).decode("utf-8")
190
190
191
191
def _make_cctx_params(params):
    """Create a native ZSTD_CCtxParams mirroring a ZstdCompressionParameters.

    The returned cffi handle frees the native struct when garbage collected.
    Raises MemoryError if allocation fails.
    """
    result = lib.ZSTD_createCCtxParams()
    if result == ffi.NULL:
        raise MemoryError()

    # Tie native struct lifetime to the cffi handle.
    result = ffi.gc(result, lib.ZSTD_freeCCtxParams)

    # (parameter enum, value) pairs, applied in this order.
    pairs = [
        (lib.ZSTD_c_format, params.format),
        (lib.ZSTD_c_compressionLevel, params.compression_level),
        (lib.ZSTD_c_windowLog, params.window_log),
        (lib.ZSTD_c_hashLog, params.hash_log),
        (lib.ZSTD_c_chainLog, params.chain_log),
        (lib.ZSTD_c_searchLog, params.search_log),
        (lib.ZSTD_c_minMatch, params.min_match),
        (lib.ZSTD_c_targetLength, params.target_length),
        (lib.ZSTD_c_strategy, params.compression_strategy),
        (lib.ZSTD_c_contentSizeFlag, params.write_content_size),
        (lib.ZSTD_c_checksumFlag, params.write_checksum),
        (lib.ZSTD_c_dictIDFlag, params.write_dict_id),
        (lib.ZSTD_c_nbWorkers, params.threads),
        (lib.ZSTD_c_jobSize, params.job_size),
        (lib.ZSTD_c_overlapLog, params.overlap_log),
        (lib.ZSTD_c_forceMaxWindow, params.force_max_window),
        (lib.ZSTD_c_enableLongDistanceMatching, params.enable_ldm),
        (lib.ZSTD_c_ldmHashLog, params.ldm_hash_log),
        (lib.ZSTD_c_ldmMinMatch, params.ldm_min_match),
        (lib.ZSTD_c_ldmBucketSizeLog, params.ldm_bucket_size_log),
        (lib.ZSTD_c_ldmHashRateLog, params.ldm_hash_rate_log),
    ]

    for param, value in pairs:
        _set_compression_parameter(result, param, value)

    return result
227
227
228
228
class ZstdCompressionParameters(object):
    """Low-level zstd compression parameters.

    Wraps a native ``ZSTD_CCtxParams`` structure and exposes each advanced
    compression parameter as a read-only property. Several legacy keyword
    aliases (``compression_strategy``, ``overlap_size_log``,
    ``ldm_hash_every_log``) are accepted alongside their modern names;
    specifying both spellings raises ``ValueError``.
    """

    @staticmethod
    def from_level(level, source_size=0, dict_size=0, **kwargs):
        """Derive parameters from an integer compression level.

        Parameters zstd computes for ``level``/``source_size``/``dict_size``
        fill in any keyword argument the caller did not supply explicitly.
        """
        cparams = lib.ZSTD_getCParams(level, source_size, dict_size)

        # Our keyword names -> ZSTD_compressionParameters struct fields.
        field_names = {
            "window_log": "windowLog",
            "chain_log": "chainLog",
            "hash_log": "hashLog",
            "search_log": "searchLog",
            "min_match": "minMatch",
            "target_length": "targetLength",
            "compression_strategy": "strategy",
        }

        for kwarg, field in field_names.items():
            kwargs.setdefault(kwarg, getattr(cparams, field))

        return ZstdCompressionParameters(**kwargs)

    def __init__(
        self,
        format=0,
        compression_level=0,
        window_log=0,
        hash_log=0,
        chain_log=0,
        search_log=0,
        min_match=0,
        target_length=0,
        strategy=-1,
        compression_strategy=-1,
        write_content_size=1,
        write_checksum=0,
        write_dict_id=0,
        job_size=0,
        overlap_log=-1,
        overlap_size_log=-1,
        force_max_window=0,
        enable_ldm=0,
        ldm_hash_log=0,
        ldm_min_match=0,
        ldm_bucket_size_log=0,
        ldm_hash_rate_log=-1,
        ldm_hash_every_log=-1,
        threads=0,
    ):

        cparams = lib.ZSTD_createCCtxParams()
        if cparams == ffi.NULL:
            raise MemoryError()

        # Free the native struct when this handle is garbage collected.
        cparams = ffi.gc(cparams, lib.ZSTD_freeCCtxParams)

        self._params = cparams

        # threads < 0 means "use all available CPUs".
        if threads < 0:
            threads = _cpu_count()

        # We need to set ZSTD_c_nbWorkers before ZSTD_c_jobSize and
        # ZSTD_c_overlapLog because setting ZSTD_c_nbWorkers resets the
        # other parameters.
        _set_compression_parameter(cparams, lib.ZSTD_c_nbWorkers, threads)

        for param, value in (
            (lib.ZSTD_c_format, format),
            (lib.ZSTD_c_compressionLevel, compression_level),
            (lib.ZSTD_c_windowLog, window_log),
            (lib.ZSTD_c_hashLog, hash_log),
            (lib.ZSTD_c_chainLog, chain_log),
            (lib.ZSTD_c_searchLog, search_log),
            (lib.ZSTD_c_minMatch, min_match),
            (lib.ZSTD_c_targetLength, target_length),
        ):
            _set_compression_parameter(cparams, param, value)

        # Reconcile the legacy "compression_strategy" alias with "strategy".
        if strategy != -1 and compression_strategy != -1:
            raise ValueError(
                "cannot specify both compression_strategy and strategy"
            )

        if compression_strategy != -1:
            strategy = compression_strategy
        elif strategy == -1:
            strategy = 0

        _set_compression_parameter(cparams, lib.ZSTD_c_strategy, strategy)
        for param, value in (
            (lib.ZSTD_c_contentSizeFlag, write_content_size),
            (lib.ZSTD_c_checksumFlag, write_checksum),
            (lib.ZSTD_c_dictIDFlag, write_dict_id),
            (lib.ZSTD_c_jobSize, job_size),
        ):
            _set_compression_parameter(cparams, param, value)

        # Reconcile the legacy "overlap_size_log" alias with "overlap_log".
        if overlap_log != -1 and overlap_size_log != -1:
            raise ValueError(
                "cannot specify both overlap_log and overlap_size_log"
            )

        if overlap_size_log != -1:
            overlap_log = overlap_size_log
        elif overlap_log == -1:
            overlap_log = 0

        for param, value in (
            (lib.ZSTD_c_overlapLog, overlap_log),
            (lib.ZSTD_c_forceMaxWindow, force_max_window),
            (lib.ZSTD_c_enableLongDistanceMatching, enable_ldm),
            (lib.ZSTD_c_ldmHashLog, ldm_hash_log),
            (lib.ZSTD_c_ldmMinMatch, ldm_min_match),
            (lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log),
        ):
            _set_compression_parameter(cparams, param, value)

        # Reconcile the legacy "ldm_hash_every_log" alias with
        # "ldm_hash_rate_log".
        if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1:
            raise ValueError(
                "cannot specify both ldm_hash_rate_log and ldm_hash_every_log"
            )

        if ldm_hash_every_log != -1:
            ldm_hash_rate_log = ldm_hash_every_log
        elif ldm_hash_rate_log == -1:
            ldm_hash_rate_log = 0

        _set_compression_parameter(
            cparams, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log
        )

    # Each property reads the current value back out of the native
    # ZSTD_CCtxParams structure.

    @property
    def format(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_format)

    @property
    def compression_level(self):
        return _get_compression_parameter(
            self._params, lib.ZSTD_c_compressionLevel
        )

    @property
    def window_log(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog)

    @property
    def hash_log(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog)

    @property
    def chain_log(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_chainLog)

    @property
    def search_log(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_searchLog)

    @property
    def min_match(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_minMatch)

    @property
    def target_length(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength)

    @property
    def compression_strategy(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_strategy)

    @property
    def write_content_size(self):
        return _get_compression_parameter(
            self._params, lib.ZSTD_c_contentSizeFlag
        )

    @property
    def write_checksum(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag)

    @property
    def write_dict_id(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag)

    @property
    def job_size(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize)

    @property
    def overlap_log(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog)

    @property
    def overlap_size_log(self):
        # Legacy alias for overlap_log.
        return self.overlap_log

    @property
    def force_max_window(self):
        return _get_compression_parameter(
            self._params, lib.ZSTD_c_forceMaxWindow
        )

    @property
    def enable_ldm(self):
        return _get_compression_parameter(
            self._params, lib.ZSTD_c_enableLongDistanceMatching
        )

    @property
    def ldm_hash_log(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog)

    @property
    def ldm_min_match(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch)

    @property
    def ldm_bucket_size_log(self):
        return _get_compression_parameter(
            self._params, lib.ZSTD_c_ldmBucketSizeLog
        )

    @property
    def ldm_hash_rate_log(self):
        return _get_compression_parameter(
            self._params, lib.ZSTD_c_ldmHashRateLog
        )

    @property
    def ldm_hash_every_log(self):
        # Legacy alias for ldm_hash_rate_log.
        return self.ldm_hash_rate_log

    @property
    def threads(self):
        return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers)

    def estimated_compression_context_size(self):
        """Estimated size in bytes of a compression context using these params."""
        return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params)
447
471
448
472
# Alias kept for callers using the older class name.
CompressionParameters = ZstdCompressionParameters
450
474
451
475
def estimate_decompression_context_size():
    """Estimated size in bytes of a zstd decompression context."""
    return lib.ZSTD_estimateDCtxSize()
454
478
455
479
def _set_compression_parameter(params, param, value):
    """Set one parameter on a ZSTD_CCtxParams; raise ZstdError on failure."""
    rc = lib.ZSTD_CCtxParams_setParameter(params, param, value)
    if lib.ZSTD_isError(rc):
        raise ZstdError(
            "unable to set compression context parameter: %s"
            % _zstd_error(rc)
        )
462
487
463
488
def _get_compression_parameter(params, param):
    """Read one parameter back from a ZSTD_CCtxParams; raise ZstdError on failure."""
    # Output parameter for the C API.
    out = ffi.new("int *")

    rc = lib.ZSTD_CCtxParams_getParameter(params, param, out)
    if lib.ZSTD_isError(rc):
        raise ZstdError(
            "unable to get compression context parameter: %s"
            % _zstd_error(rc)
        )

    return out[0]
499 return result[0]
474
500
475
501
class ZstdCompressionWriter(object):
    """Write-only file-object wrapper that compresses data written to it.

    Compressed output is forwarded to the wrapped ``writer``. All read and
    seek APIs raise ``io.UnsupportedOperation``. ``close()`` finishes the
    zstd frame and also closes the underlying stream when it has a
    ``close()`` method.
    """

    def __init__(
        self, compressor, writer, source_size, write_size, write_return_read
    ):
        self._compressor = compressor
        self._writer = writer
        self._write_size = write_size
        self._write_return_read = bool(write_return_read)
        self._entered = False
        self._closed = False
        self._bytes_compressed = 0

        # Reusable output buffer shared by write() and flush().
        self._dst_buffer = ffi.new("char[]", write_size)
        self._out_buffer = ffi.new("ZSTD_outBuffer *")
        self._out_buffer.dst = self._dst_buffer
        self._out_buffer.size = len(self._dst_buffer)
        self._out_buffer.pos = 0

        rc = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size)
        if lib.ZSTD_isError(rc):
            raise ZstdError(
                "error setting source size: %s" % _zstd_error(rc)
            )

    def __enter__(self):
        if self._closed:
            raise ValueError("stream is closed")

        if self._entered:
            raise ZstdError("cannot __enter__ multiple times")

        self._entered = True
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._entered = False

        # Only finalize the frame on a clean exit; when an exception is
        # propagating the stream is left unflushed.
        if not exc_type and not exc_value and not exc_tb:
            self.close()

        self._compressor = None

        return False

    def memory_size(self):
        """Size in bytes of the underlying compression context."""
        return lib.ZSTD_sizeof_CCtx(self._compressor._cctx)

    def fileno(self):
        fileno_fn = getattr(self._writer, "fileno", None)
        if fileno_fn:
            return fileno_fn()
        raise OSError("fileno not available on underlying writer")

    def close(self):
        """Finish the zstd frame and close the underlying stream (if possible)."""
        if self._closed:
            return

        try:
            self.flush(FLUSH_FRAME)
        finally:
            # Mark closed even if the flush failed.
            self._closed = True

        # Call close() on underlying stream as well.
        close_fn = getattr(self._writer, "close", None)
        if close_fn:
            close_fn()

    @property
    def closed(self):
        return self._closed

    def isatty(self):
        return False

    def readable(self):
        return False

    def readline(self, size=-1):
        raise io.UnsupportedOperation()

    def readlines(self, hint=-1):
        raise io.UnsupportedOperation()

    def seek(self, offset, whence=None):
        raise io.UnsupportedOperation()

    def seekable(self):
        return False

    def truncate(self, size=None):
        raise io.UnsupportedOperation()

    def writable(self):
        return True

    def writelines(self, lines):
        raise NotImplementedError("writelines() is not yet implemented")

    def read(self, size=-1):
        raise io.UnsupportedOperation()

    def readall(self):
        raise io.UnsupportedOperation()

    def readinto(self, b):
        raise io.UnsupportedOperation()

    def write(self, data):
        """Compress ``data``, writing compressed output to the wrapped stream.

        Returns the number of input bytes consumed when
        ``write_return_read`` is enabled, otherwise the number of
        compressed bytes written downstream.
        """
        if self._closed:
            raise ValueError("stream is closed")

        total_write = 0

        data_buffer = ffi.from_buffer(data)

        source = ffi.new("ZSTD_inBuffer *")
        source.src = data_buffer
        source.size = len(data_buffer)
        source.pos = 0

        dest = self._out_buffer
        dest.pos = 0

        # Keep compressing until all input has been consumed, draining the
        # output buffer to the wrapped writer whenever it fills.
        while source.pos < source.size:
            rc = lib.ZSTD_compressStream2(
                self._compressor._cctx,
                dest,
                source,
                lib.ZSTD_e_continue,
            )
            if lib.ZSTD_isError(rc):
                raise ZstdError(
                    "zstd compress error: %s" % _zstd_error(rc)
                )

            if dest.pos:
                self._writer.write(
                    ffi.buffer(dest.dst, dest.pos)[:]
                )
                total_write += dest.pos
                self._bytes_compressed += dest.pos
                dest.pos = 0

        return source.pos if self._write_return_read else total_write

    def flush(self, flush_mode=FLUSH_BLOCK):
        """Flush buffered data; ``FLUSH_FRAME`` also ends the current frame.

        Returns the number of compressed bytes written downstream.
        Raises ValueError for an unknown mode or a closed stream.
        """
        if flush_mode == FLUSH_BLOCK:
            z_mode = lib.ZSTD_e_flush
        elif flush_mode == FLUSH_FRAME:
            z_mode = lib.ZSTD_e_end
        else:
            raise ValueError("unknown flush_mode: %r" % flush_mode)

        if self._closed:
            raise ValueError("stream is closed")

        total_write = 0

        dest = self._out_buffer
        dest.pos = 0

        # Flushing consumes no input.
        source = ffi.new("ZSTD_inBuffer *")
        source.src = ffi.NULL
        source.size = 0
        source.pos = 0

        while True:
            rc = lib.ZSTD_compressStream2(
                self._compressor._cctx, dest, source, z_mode
            )
            if lib.ZSTD_isError(rc):
                raise ZstdError(
                    "zstd compress error: %s" % _zstd_error(rc)
                )

            if dest.pos:
                self._writer.write(
                    ffi.buffer(dest.dst, dest.pos)[:]
                )
                total_write += dest.pos
                self._bytes_compressed += dest.pos
                dest.pos = 0

            # A zero return value means the internal buffers are drained.
            if not rc:
                break

        return total_write

    def tell(self):
        """Total number of compressed bytes emitted so far."""
        return self._bytes_compressed
655
696
656
697
657 class ZstdCompressionObj(object):
698 class ZstdCompressionObj(object):
658 def compress(self, data):
699 def compress(self, data):
659 if self._finished:
700 if self._finished:
660 raise ZstdError("cannot call compress() after compressor finished")
701 raise ZstdError("cannot call compress() after compressor finished")
661
702
662 data_buffer = ffi.from_buffer(data)
703 data_buffer = ffi.from_buffer(data)
663 source = ffi.new("ZSTD_inBuffer *")
704 source = ffi.new("ZSTD_inBuffer *")
664 source.src = data_buffer
705 source.src = data_buffer
665 source.size = len(data_buffer)
706 source.size = len(data_buffer)
666 source.pos = 0
707 source.pos = 0
667
708
668 chunks = []
709 chunks = []
669
710
670 while source.pos < len(data):
711 while source.pos < len(data):
671 zresult = lib.ZSTD_compressStream2(
712 zresult = lib.ZSTD_compressStream2(
672 self._compressor._cctx, self._out, source, lib.ZSTD_e_continue
713 self._compressor._cctx, self._out, source, lib.ZSTD_e_continue
673 )
714 )
674 if lib.ZSTD_isError(zresult):
715 if lib.ZSTD_isError(zresult):
675 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
716 raise ZstdError(
717 "zstd compress error: %s" % _zstd_error(zresult)
718 )
676
719
677 if self._out.pos:
720 if self._out.pos:
678 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
721 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
679 self._out.pos = 0
722 self._out.pos = 0
680
723
681 return b"".join(chunks)
724 return b"".join(chunks)
682
725
683 def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH):
726 def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH):
684 if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK):
727 if flush_mode not in (
728 COMPRESSOBJ_FLUSH_FINISH,
729 COMPRESSOBJ_FLUSH_BLOCK,
730 ):
685 raise ValueError("flush mode not recognized")
731 raise ValueError("flush mode not recognized")
686
732
687 if self._finished:
733 if self._finished:
688 raise ZstdError("compressor object already finished")
734 raise ZstdError("compressor object already finished")
689
735
690 if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
736 if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
691 z_flush_mode = lib.ZSTD_e_flush
737 z_flush_mode = lib.ZSTD_e_flush
692 elif flush_mode == COMPRESSOBJ_FLUSH_FINISH:
738 elif flush_mode == COMPRESSOBJ_FLUSH_FINISH:
693 z_flush_mode = lib.ZSTD_e_end
739 z_flush_mode = lib.ZSTD_e_end
694 self._finished = True
740 self._finished = True
695 else:
741 else:
696 raise ZstdError("unhandled flush mode")
742 raise ZstdError("unhandled flush mode")
697
743
698 assert self._out.pos == 0
744 assert self._out.pos == 0
699
745
700 in_buffer = ffi.new("ZSTD_inBuffer *")
746 in_buffer = ffi.new("ZSTD_inBuffer *")
701 in_buffer.src = ffi.NULL
747 in_buffer.src = ffi.NULL
702 in_buffer.size = 0
748 in_buffer.size = 0
703 in_buffer.pos = 0
749 in_buffer.pos = 0
704
750
705 chunks = []
751 chunks = []
706
752
707 while True:
753 while True:
708 zresult = lib.ZSTD_compressStream2(
754 zresult = lib.ZSTD_compressStream2(
709 self._compressor._cctx, self._out, in_buffer, z_flush_mode
755 self._compressor._cctx, self._out, in_buffer, z_flush_mode
710 )
756 )
711 if lib.ZSTD_isError(zresult):
757 if lib.ZSTD_isError(zresult):
712 raise ZstdError(
758 raise ZstdError(
713 "error ending compression stream: %s" % _zstd_error(zresult)
759 "error ending compression stream: %s" % _zstd_error(zresult)
714 )
760 )
715
761
716 if self._out.pos:
762 if self._out.pos:
717 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
763 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
718 self._out.pos = 0
764 self._out.pos = 0
719
765
720 if not zresult:
766 if not zresult:
721 break
767 break
722
768
723 return b"".join(chunks)
769 return b"".join(chunks)
724
770
725
771
726 class ZstdCompressionChunker(object):
772 class ZstdCompressionChunker(object):
727 def __init__(self, compressor, chunk_size):
773 def __init__(self, compressor, chunk_size):
728 self._compressor = compressor
774 self._compressor = compressor
729 self._out = ffi.new("ZSTD_outBuffer *")
775 self._out = ffi.new("ZSTD_outBuffer *")
730 self._dst_buffer = ffi.new("char[]", chunk_size)
776 self._dst_buffer = ffi.new("char[]", chunk_size)
731 self._out.dst = self._dst_buffer
777 self._out.dst = self._dst_buffer
732 self._out.size = chunk_size
778 self._out.size = chunk_size
733 self._out.pos = 0
779 self._out.pos = 0
734
780
735 self._in = ffi.new("ZSTD_inBuffer *")
781 self._in = ffi.new("ZSTD_inBuffer *")
736 self._in.src = ffi.NULL
782 self._in.src = ffi.NULL
737 self._in.size = 0
783 self._in.size = 0
738 self._in.pos = 0
784 self._in.pos = 0
739 self._finished = False
785 self._finished = False
740
786
741 def compress(self, data):
787 def compress(self, data):
742 if self._finished:
788 if self._finished:
743 raise ZstdError("cannot call compress() after compression finished")
789 raise ZstdError("cannot call compress() after compression finished")
744
790
745 if self._in.src != ffi.NULL:
791 if self._in.src != ffi.NULL:
746 raise ZstdError(
792 raise ZstdError(
747 "cannot perform operation before consuming output "
793 "cannot perform operation before consuming output "
748 "from previous operation"
794 "from previous operation"
749 )
795 )
750
796
751 data_buffer = ffi.from_buffer(data)
797 data_buffer = ffi.from_buffer(data)
752
798
753 if not len(data_buffer):
799 if not len(data_buffer):
754 return
800 return
755
801
756 self._in.src = data_buffer
802 self._in.src = data_buffer
757 self._in.size = len(data_buffer)
803 self._in.size = len(data_buffer)
758 self._in.pos = 0
804 self._in.pos = 0
759
805
760 while self._in.pos < self._in.size:
806 while self._in.pos < self._in.size:
761 zresult = lib.ZSTD_compressStream2(
807 zresult = lib.ZSTD_compressStream2(
762 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_continue
808 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_continue
763 )
809 )
764
810
765 if self._in.pos == self._in.size:
811 if self._in.pos == self._in.size:
766 self._in.src = ffi.NULL
812 self._in.src = ffi.NULL
767 self._in.size = 0
813 self._in.size = 0
768 self._in.pos = 0
814 self._in.pos = 0
769
815
770 if lib.ZSTD_isError(zresult):
816 if lib.ZSTD_isError(zresult):
771 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
817 raise ZstdError(
818 "zstd compress error: %s" % _zstd_error(zresult)
819 )
772
820
773 if self._out.pos == self._out.size:
821 if self._out.pos == self._out.size:
774 yield ffi.buffer(self._out.dst, self._out.pos)[:]
822 yield ffi.buffer(self._out.dst, self._out.pos)[:]
775 self._out.pos = 0
823 self._out.pos = 0
776
824
777 def flush(self):
825 def flush(self):
778 if self._finished:
826 if self._finished:
779 raise ZstdError("cannot call flush() after compression finished")
827 raise ZstdError("cannot call flush() after compression finished")
780
828
781 if self._in.src != ffi.NULL:
829 if self._in.src != ffi.NULL:
782 raise ZstdError(
830 raise ZstdError(
783 "cannot call flush() before consuming output from " "previous operation"
831 "cannot call flush() before consuming output from "
832 "previous operation"
784 )
833 )
785
834
786 while True:
835 while True:
787 zresult = lib.ZSTD_compressStream2(
836 zresult = lib.ZSTD_compressStream2(
788 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush
837 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush
789 )
838 )
790 if lib.ZSTD_isError(zresult):
839 if lib.ZSTD_isError(zresult):
791 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
840 raise ZstdError(
841 "zstd compress error: %s" % _zstd_error(zresult)
842 )
792
843
793 if self._out.pos:
844 if self._out.pos:
794 yield ffi.buffer(self._out.dst, self._out.pos)[:]
845 yield ffi.buffer(self._out.dst, self._out.pos)[:]
795 self._out.pos = 0
846 self._out.pos = 0
796
847
797 if not zresult:
848 if not zresult:
798 return
849 return
799
850
800 def finish(self):
851 def finish(self):
801 if self._finished:
852 if self._finished:
802 raise ZstdError("cannot call finish() after compression finished")
853 raise ZstdError("cannot call finish() after compression finished")
803
854
804 if self._in.src != ffi.NULL:
855 if self._in.src != ffi.NULL:
805 raise ZstdError(
856 raise ZstdError(
806 "cannot call finish() before consuming output from "
857 "cannot call finish() before consuming output from "
807 "previous operation"
858 "previous operation"
808 )
859 )
809
860
810 while True:
861 while True:
811 zresult = lib.ZSTD_compressStream2(
862 zresult = lib.ZSTD_compressStream2(
812 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end
863 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end
813 )
864 )
814 if lib.ZSTD_isError(zresult):
865 if lib.ZSTD_isError(zresult):
815 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
866 raise ZstdError(
867 "zstd compress error: %s" % _zstd_error(zresult)
868 )
816
869
817 if self._out.pos:
870 if self._out.pos:
818 yield ffi.buffer(self._out.dst, self._out.pos)[:]
871 yield ffi.buffer(self._out.dst, self._out.pos)[:]
819 self._out.pos = 0
872 self._out.pos = 0
820
873
821 if not zresult:
874 if not zresult:
822 self._finished = True
875 self._finished = True
823 return
876 return
824
877
825
878
826 class ZstdCompressionReader(object):
879 class ZstdCompressionReader(object):
827 def __init__(self, compressor, source, read_size):
880 def __init__(self, compressor, source, read_size):
828 self._compressor = compressor
881 self._compressor = compressor
829 self._source = source
882 self._source = source
830 self._read_size = read_size
883 self._read_size = read_size
831 self._entered = False
884 self._entered = False
832 self._closed = False
885 self._closed = False
833 self._bytes_compressed = 0
886 self._bytes_compressed = 0
834 self._finished_input = False
887 self._finished_input = False
835 self._finished_output = False
888 self._finished_output = False
836
889
837 self._in_buffer = ffi.new("ZSTD_inBuffer *")
890 self._in_buffer = ffi.new("ZSTD_inBuffer *")
838 # Holds a ref so backing bytes in self._in_buffer stay alive.
891 # Holds a ref so backing bytes in self._in_buffer stay alive.
839 self._source_buffer = None
892 self._source_buffer = None
840
893
841 def __enter__(self):
894 def __enter__(self):
842 if self._entered:
895 if self._entered:
843 raise ValueError("cannot __enter__ multiple times")
896 raise ValueError("cannot __enter__ multiple times")
844
897
845 self._entered = True
898 self._entered = True
846 return self
899 return self
847
900
848 def __exit__(self, exc_type, exc_value, exc_tb):
901 def __exit__(self, exc_type, exc_value, exc_tb):
849 self._entered = False
902 self._entered = False
850 self._closed = True
903 self._closed = True
851 self._source = None
904 self._source = None
852 self._compressor = None
905 self._compressor = None
853
906
854 return False
907 return False
855
908
856 def readable(self):
909 def readable(self):
857 return True
910 return True
858
911
859 def writable(self):
912 def writable(self):
860 return False
913 return False
861
914
862 def seekable(self):
915 def seekable(self):
863 return False
916 return False
864
917
865 def readline(self):
918 def readline(self):
866 raise io.UnsupportedOperation()
919 raise io.UnsupportedOperation()
867
920
868 def readlines(self):
921 def readlines(self):
869 raise io.UnsupportedOperation()
922 raise io.UnsupportedOperation()
870
923
871 def write(self, data):
924 def write(self, data):
872 raise OSError("stream is not writable")
925 raise OSError("stream is not writable")
873
926
874 def writelines(self, ignored):
927 def writelines(self, ignored):
875 raise OSError("stream is not writable")
928 raise OSError("stream is not writable")
876
929
877 def isatty(self):
930 def isatty(self):
878 return False
931 return False
879
932
880 def flush(self):
933 def flush(self):
881 return None
934 return None
882
935
883 def close(self):
936 def close(self):
884 self._closed = True
937 self._closed = True
885 return None
938 return None
886
939
887 @property
940 @property
888 def closed(self):
941 def closed(self):
889 return self._closed
942 return self._closed
890
943
891 def tell(self):
944 def tell(self):
892 return self._bytes_compressed
945 return self._bytes_compressed
893
946
894 def readall(self):
947 def readall(self):
895 chunks = []
948 chunks = []
896
949
897 while True:
950 while True:
898 chunk = self.read(1048576)
951 chunk = self.read(1048576)
899 if not chunk:
952 if not chunk:
900 break
953 break
901
954
902 chunks.append(chunk)
955 chunks.append(chunk)
903
956
904 return b"".join(chunks)
957 return b"".join(chunks)
905
958
906 def __iter__(self):
959 def __iter__(self):
907 raise io.UnsupportedOperation()
960 raise io.UnsupportedOperation()
908
961
909 def __next__(self):
962 def __next__(self):
910 raise io.UnsupportedOperation()
963 raise io.UnsupportedOperation()
911
964
912 next = __next__
965 next = __next__
913
966
914 def _read_input(self):
967 def _read_input(self):
915 if self._finished_input:
968 if self._finished_input:
916 return
969 return
917
970
918 if hasattr(self._source, "read"):
971 if hasattr(self._source, "read"):
919 data = self._source.read(self._read_size)
972 data = self._source.read(self._read_size)
920
973
921 if not data:
974 if not data:
922 self._finished_input = True
975 self._finished_input = True
923 return
976 return
924
977
925 self._source_buffer = ffi.from_buffer(data)
978 self._source_buffer = ffi.from_buffer(data)
926 self._in_buffer.src = self._source_buffer
979 self._in_buffer.src = self._source_buffer
927 self._in_buffer.size = len(self._source_buffer)
980 self._in_buffer.size = len(self._source_buffer)
928 self._in_buffer.pos = 0
981 self._in_buffer.pos = 0
929 else:
982 else:
930 self._source_buffer = ffi.from_buffer(self._source)
983 self._source_buffer = ffi.from_buffer(self._source)
931 self._in_buffer.src = self._source_buffer
984 self._in_buffer.src = self._source_buffer
932 self._in_buffer.size = len(self._source_buffer)
985 self._in_buffer.size = len(self._source_buffer)
933 self._in_buffer.pos = 0
986 self._in_buffer.pos = 0
934
987
935 def _compress_into_buffer(self, out_buffer):
988 def _compress_into_buffer(self, out_buffer):
936 if self._in_buffer.pos >= self._in_buffer.size:
989 if self._in_buffer.pos >= self._in_buffer.size:
937 return
990 return
938
991
939 old_pos = out_buffer.pos
992 old_pos = out_buffer.pos
940
993
941 zresult = lib.ZSTD_compressStream2(
994 zresult = lib.ZSTD_compressStream2(
942 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_continue
995 self._compressor._cctx,
996 out_buffer,
997 self._in_buffer,
998 lib.ZSTD_e_continue,
943 )
999 )
944
1000
945 self._bytes_compressed += out_buffer.pos - old_pos
1001 self._bytes_compressed += out_buffer.pos - old_pos
946
1002
947 if self._in_buffer.pos == self._in_buffer.size:
1003 if self._in_buffer.pos == self._in_buffer.size:
948 self._in_buffer.src = ffi.NULL
1004 self._in_buffer.src = ffi.NULL
949 self._in_buffer.pos = 0
1005 self._in_buffer.pos = 0
950 self._in_buffer.size = 0
1006 self._in_buffer.size = 0
951 self._source_buffer = None
1007 self._source_buffer = None
952
1008
953 if not hasattr(self._source, "read"):
1009 if not hasattr(self._source, "read"):
954 self._finished_input = True
1010 self._finished_input = True
955
1011
956 if lib.ZSTD_isError(zresult):
1012 if lib.ZSTD_isError(zresult):
957 raise ZstdError("zstd compress error: %s", _zstd_error(zresult))
1013 raise ZstdError("zstd compress error: %s", _zstd_error(zresult))
958
1014
959 return out_buffer.pos and out_buffer.pos == out_buffer.size
1015 return out_buffer.pos and out_buffer.pos == out_buffer.size
960
1016
961 def read(self, size=-1):
1017 def read(self, size=-1):
962 if self._closed:
1018 if self._closed:
963 raise ValueError("stream is closed")
1019 raise ValueError("stream is closed")
964
1020
965 if size < -1:
1021 if size < -1:
966 raise ValueError("cannot read negative amounts less than -1")
1022 raise ValueError("cannot read negative amounts less than -1")
967
1023
968 if size == -1:
1024 if size == -1:
969 return self.readall()
1025 return self.readall()
970
1026
971 if self._finished_output or size == 0:
1027 if self._finished_output or size == 0:
972 return b""
1028 return b""
973
1029
974 # Need a dedicated ref to dest buffer otherwise it gets collected.
1030 # Need a dedicated ref to dest buffer otherwise it gets collected.
975 dst_buffer = ffi.new("char[]", size)
1031 dst_buffer = ffi.new("char[]", size)
976 out_buffer = ffi.new("ZSTD_outBuffer *")
1032 out_buffer = ffi.new("ZSTD_outBuffer *")
977 out_buffer.dst = dst_buffer
1033 out_buffer.dst = dst_buffer
978 out_buffer.size = size
1034 out_buffer.size = size
979 out_buffer.pos = 0
1035 out_buffer.pos = 0
980
1036
981 if self._compress_into_buffer(out_buffer):
1037 if self._compress_into_buffer(out_buffer):
982 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1038 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
983
1039
984 while not self._finished_input:
1040 while not self._finished_input:
985 self._read_input()
1041 self._read_input()
986
1042
987 if self._compress_into_buffer(out_buffer):
1043 if self._compress_into_buffer(out_buffer):
988 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1044 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
989
1045
990 # EOF
1046 # EOF
991 old_pos = out_buffer.pos
1047 old_pos = out_buffer.pos
992
1048
993 zresult = lib.ZSTD_compressStream2(
1049 zresult = lib.ZSTD_compressStream2(
994 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1050 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
995 )
1051 )
996
1052
997 self._bytes_compressed += out_buffer.pos - old_pos
1053 self._bytes_compressed += out_buffer.pos - old_pos
998
1054
999 if lib.ZSTD_isError(zresult):
1055 if lib.ZSTD_isError(zresult):
1000 raise ZstdError("error ending compression stream: %s", _zstd_error(zresult))
1056 raise ZstdError(
1057 "error ending compression stream: %s", _zstd_error(zresult)
1058 )
1001
1059
1002 if zresult == 0:
1060 if zresult == 0:
1003 self._finished_output = True
1061 self._finished_output = True
1004
1062
1005 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1063 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1006
1064
1007 def read1(self, size=-1):
1065 def read1(self, size=-1):
1008 if self._closed:
1066 if self._closed:
1009 raise ValueError("stream is closed")
1067 raise ValueError("stream is closed")
1010
1068
1011 if size < -1:
1069 if size < -1:
1012 raise ValueError("cannot read negative amounts less than -1")
1070 raise ValueError("cannot read negative amounts less than -1")
1013
1071
1014 if self._finished_output or size == 0:
1072 if self._finished_output or size == 0:
1015 return b""
1073 return b""
1016
1074
1017 # -1 returns arbitrary number of bytes.
1075 # -1 returns arbitrary number of bytes.
1018 if size == -1:
1076 if size == -1:
1019 size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1077 size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1020
1078
1021 dst_buffer = ffi.new("char[]", size)
1079 dst_buffer = ffi.new("char[]", size)
1022 out_buffer = ffi.new("ZSTD_outBuffer *")
1080 out_buffer = ffi.new("ZSTD_outBuffer *")
1023 out_buffer.dst = dst_buffer
1081 out_buffer.dst = dst_buffer
1024 out_buffer.size = size
1082 out_buffer.size = size
1025 out_buffer.pos = 0
1083 out_buffer.pos = 0
1026
1084
1027 # read1() dictates that we can perform at most 1 call to the
1085 # read1() dictates that we can perform at most 1 call to the
1028 # underlying stream to get input. However, we can't satisfy this
1086 # underlying stream to get input. However, we can't satisfy this
1029 # restriction with compression because not all input generates output.
1087 # restriction with compression because not all input generates output.
1030 # It is possible to perform a block flush in order to ensure output.
1088 # It is possible to perform a block flush in order to ensure output.
1031 # But this may not be desirable behavior. So we allow multiple read()
1089 # But this may not be desirable behavior. So we allow multiple read()
1032 # to the underlying stream. But unlike read(), we stop once we have
1090 # to the underlying stream. But unlike read(), we stop once we have
1033 # any output.
1091 # any output.
1034
1092
1035 self._compress_into_buffer(out_buffer)
1093 self._compress_into_buffer(out_buffer)
1036 if out_buffer.pos:
1094 if out_buffer.pos:
1037 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1095 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1038
1096
1039 while not self._finished_input:
1097 while not self._finished_input:
1040 self._read_input()
1098 self._read_input()
1041
1099
1042 # If we've filled the output buffer, return immediately.
1100 # If we've filled the output buffer, return immediately.
1043 if self._compress_into_buffer(out_buffer):
1101 if self._compress_into_buffer(out_buffer):
1044 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1102 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1045
1103
1046 # If we've populated the output buffer and we're not at EOF,
1104 # If we've populated the output buffer and we're not at EOF,
1047 # also return, as we've satisfied the read1() limits.
1105 # also return, as we've satisfied the read1() limits.
1048 if out_buffer.pos and not self._finished_input:
1106 if out_buffer.pos and not self._finished_input:
1049 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1107 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1050
1108
1051 # Else if we're at EOS and we have room left in the buffer,
1109 # Else if we're at EOS and we have room left in the buffer,
1052 # fall through to below and try to add more data to the output.
1110 # fall through to below and try to add more data to the output.
1053
1111
1054 # EOF.
1112 # EOF.
1055 old_pos = out_buffer.pos
1113 old_pos = out_buffer.pos
1056
1114
1057 zresult = lib.ZSTD_compressStream2(
1115 zresult = lib.ZSTD_compressStream2(
1058 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1116 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1059 )
1117 )
1060
1118
1061 self._bytes_compressed += out_buffer.pos - old_pos
1119 self._bytes_compressed += out_buffer.pos - old_pos
1062
1120
1063 if lib.ZSTD_isError(zresult):
1121 if lib.ZSTD_isError(zresult):
1064 raise ZstdError(
1122 raise ZstdError(
1065 "error ending compression stream: %s" % _zstd_error(zresult)
1123 "error ending compression stream: %s" % _zstd_error(zresult)
1066 )
1124 )
1067
1125
1068 if zresult == 0:
1126 if zresult == 0:
1069 self._finished_output = True
1127 self._finished_output = True
1070
1128
1071 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1129 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1072
1130
1073 def readinto(self, b):
1131 def readinto(self, b):
1074 if self._closed:
1132 if self._closed:
1075 raise ValueError("stream is closed")
1133 raise ValueError("stream is closed")
1076
1134
1077 if self._finished_output:
1135 if self._finished_output:
1078 return 0
1136 return 0
1079
1137
1080 # TODO use writable=True once we require CFFI >= 1.12.
1138 # TODO use writable=True once we require CFFI >= 1.12.
1081 dest_buffer = ffi.from_buffer(b)
1139 dest_buffer = ffi.from_buffer(b)
1082 ffi.memmove(b, b"", 0)
1140 ffi.memmove(b, b"", 0)
1083 out_buffer = ffi.new("ZSTD_outBuffer *")
1141 out_buffer = ffi.new("ZSTD_outBuffer *")
1084 out_buffer.dst = dest_buffer
1142 out_buffer.dst = dest_buffer
1085 out_buffer.size = len(dest_buffer)
1143 out_buffer.size = len(dest_buffer)
1086 out_buffer.pos = 0
1144 out_buffer.pos = 0
1087
1145
1088 if self._compress_into_buffer(out_buffer):
1146 if self._compress_into_buffer(out_buffer):
1089 return out_buffer.pos
1147 return out_buffer.pos
1090
1148
1091 while not self._finished_input:
1149 while not self._finished_input:
1092 self._read_input()
1150 self._read_input()
1093 if self._compress_into_buffer(out_buffer):
1151 if self._compress_into_buffer(out_buffer):
1094 return out_buffer.pos
1152 return out_buffer.pos
1095
1153
1096 # EOF.
1154 # EOF.
1097 old_pos = out_buffer.pos
1155 old_pos = out_buffer.pos
1098 zresult = lib.ZSTD_compressStream2(
1156 zresult = lib.ZSTD_compressStream2(
1099 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1157 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1100 )
1158 )
1101
1159
1102 self._bytes_compressed += out_buffer.pos - old_pos
1160 self._bytes_compressed += out_buffer.pos - old_pos
1103
1161
1104 if lib.ZSTD_isError(zresult):
1162 if lib.ZSTD_isError(zresult):
1105 raise ZstdError("error ending compression stream: %s", _zstd_error(zresult))
1163 raise ZstdError(
1164 "error ending compression stream: %s", _zstd_error(zresult)
1165 )
1106
1166
1107 if zresult == 0:
1167 if zresult == 0:
1108 self._finished_output = True
1168 self._finished_output = True
1109
1169
1110 return out_buffer.pos
1170 return out_buffer.pos
1111
1171
1112 def readinto1(self, b):
1172 def readinto1(self, b):
1113 if self._closed:
1173 if self._closed:
1114 raise ValueError("stream is closed")
1174 raise ValueError("stream is closed")
1115
1175
1116 if self._finished_output:
1176 if self._finished_output:
1117 return 0
1177 return 0
1118
1178
1119 # TODO use writable=True once we require CFFI >= 1.12.
1179 # TODO use writable=True once we require CFFI >= 1.12.
1120 dest_buffer = ffi.from_buffer(b)
1180 dest_buffer = ffi.from_buffer(b)
1121 ffi.memmove(b, b"", 0)
1181 ffi.memmove(b, b"", 0)
1122
1182
1123 out_buffer = ffi.new("ZSTD_outBuffer *")
1183 out_buffer = ffi.new("ZSTD_outBuffer *")
1124 out_buffer.dst = dest_buffer
1184 out_buffer.dst = dest_buffer
1125 out_buffer.size = len(dest_buffer)
1185 out_buffer.size = len(dest_buffer)
1126 out_buffer.pos = 0
1186 out_buffer.pos = 0
1127
1187
1128 self._compress_into_buffer(out_buffer)
1188 self._compress_into_buffer(out_buffer)
1129 if out_buffer.pos:
1189 if out_buffer.pos:
1130 return out_buffer.pos
1190 return out_buffer.pos
1131
1191
1132 while not self._finished_input:
1192 while not self._finished_input:
1133 self._read_input()
1193 self._read_input()
1134
1194
1135 if self._compress_into_buffer(out_buffer):
1195 if self._compress_into_buffer(out_buffer):
1136 return out_buffer.pos
1196 return out_buffer.pos
1137
1197
1138 if out_buffer.pos and not self._finished_input:
1198 if out_buffer.pos and not self._finished_input:
1139 return out_buffer.pos
1199 return out_buffer.pos
1140
1200
1141 # EOF.
1201 # EOF.
1142 old_pos = out_buffer.pos
1202 old_pos = out_buffer.pos
1143
1203
1144 zresult = lib.ZSTD_compressStream2(
1204 zresult = lib.ZSTD_compressStream2(
1145 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1205 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1146 )
1206 )
1147
1207
1148 self._bytes_compressed += out_buffer.pos - old_pos
1208 self._bytes_compressed += out_buffer.pos - old_pos
1149
1209
1150 if lib.ZSTD_isError(zresult):
1210 if lib.ZSTD_isError(zresult):
1151 raise ZstdError(
1211 raise ZstdError(
1152 "error ending compression stream: %s" % _zstd_error(zresult)
1212 "error ending compression stream: %s" % _zstd_error(zresult)
1153 )
1213 )
1154
1214
1155 if zresult == 0:
1215 if zresult == 0:
1156 self._finished_output = True
1216 self._finished_output = True
1157
1217
1158 return out_buffer.pos
1218 return out_buffer.pos
1159
1219
1160
1220
1161 class ZstdCompressor(object):
1221 class ZstdCompressor(object):
1162 def __init__(
1222 def __init__(
1163 self,
1223 self,
1164 level=3,
1224 level=3,
1165 dict_data=None,
1225 dict_data=None,
1166 compression_params=None,
1226 compression_params=None,
1167 write_checksum=None,
1227 write_checksum=None,
1168 write_content_size=None,
1228 write_content_size=None,
1169 write_dict_id=None,
1229 write_dict_id=None,
1170 threads=0,
1230 threads=0,
1171 ):
1231 ):
1172 if level > lib.ZSTD_maxCLevel():
1232 if level > lib.ZSTD_maxCLevel():
1173 raise ValueError("level must be less than %d" % lib.ZSTD_maxCLevel())
1233 raise ValueError(
1234 "level must be less than %d" % lib.ZSTD_maxCLevel()
1235 )
1174
1236
1175 if threads < 0:
1237 if threads < 0:
1176 threads = _cpu_count()
1238 threads = _cpu_count()
1177
1239
1178 if compression_params and write_checksum is not None:
1240 if compression_params and write_checksum is not None:
1179 raise ValueError("cannot define compression_params and " "write_checksum")
1241 raise ValueError(
1242 "cannot define compression_params and " "write_checksum"
1243 )
1180
1244
1181 if compression_params and write_content_size is not None:
1245 if compression_params and write_content_size is not None:
1182 raise ValueError(
1246 raise ValueError(
1183 "cannot define compression_params and " "write_content_size"
1247 "cannot define compression_params and " "write_content_size"
1184 )
1248 )
1185
1249
1186 if compression_params and write_dict_id is not None:
1250 if compression_params and write_dict_id is not None:
1187 raise ValueError("cannot define compression_params and " "write_dict_id")
1251 raise ValueError(
1252 "cannot define compression_params and " "write_dict_id"
1253 )
1188
1254
1189 if compression_params and threads:
1255 if compression_params and threads:
1190 raise ValueError("cannot define compression_params and threads")
1256 raise ValueError("cannot define compression_params and threads")
1191
1257
1192 if compression_params:
1258 if compression_params:
1193 self._params = _make_cctx_params(compression_params)
1259 self._params = _make_cctx_params(compression_params)
1194 else:
1260 else:
1195 if write_dict_id is None:
1261 if write_dict_id is None:
1196 write_dict_id = True
1262 write_dict_id = True
1197
1263
1198 params = lib.ZSTD_createCCtxParams()
1264 params = lib.ZSTD_createCCtxParams()
1199 if params == ffi.NULL:
1265 if params == ffi.NULL:
1200 raise MemoryError()
1266 raise MemoryError()
1201
1267
1202 self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
1268 self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
1203
1269
1204 _set_compression_parameter(self._params, lib.ZSTD_c_compressionLevel, level)
1270 _set_compression_parameter(
1271 self._params, lib.ZSTD_c_compressionLevel, level
1272 )
1205
1273
1206 _set_compression_parameter(
1274 _set_compression_parameter(
1207 self._params,
1275 self._params,
1208 lib.ZSTD_c_contentSizeFlag,
1276 lib.ZSTD_c_contentSizeFlag,
1209 write_content_size if write_content_size is not None else 1,
1277 write_content_size if write_content_size is not None else 1,
1210 )
1278 )
1211
1279
1212 _set_compression_parameter(
1280 _set_compression_parameter(
1213 self._params, lib.ZSTD_c_checksumFlag, 1 if write_checksum else 0
1281 self._params,
1282 lib.ZSTD_c_checksumFlag,
1283 1 if write_checksum else 0,
1214 )
1284 )
1215
1285
1216 _set_compression_parameter(
1286 _set_compression_parameter(
1217 self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0
1287 self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0
1218 )
1288 )
1219
1289
1220 if threads:
1290 if threads:
1221 _set_compression_parameter(self._params, lib.ZSTD_c_nbWorkers, threads)
1291 _set_compression_parameter(
1292 self._params, lib.ZSTD_c_nbWorkers, threads
1293 )
1222
1294
1223 cctx = lib.ZSTD_createCCtx()
1295 cctx = lib.ZSTD_createCCtx()
1224 if cctx == ffi.NULL:
1296 if cctx == ffi.NULL:
1225 raise MemoryError()
1297 raise MemoryError()
1226
1298
1227 self._cctx = cctx
1299 self._cctx = cctx
1228 self._dict_data = dict_data
1300 self._dict_data = dict_data
1229
1301
1230 # We defer setting up garbage collection until after calling
1302 # We defer setting up garbage collection until after calling
1231 # _setup_cctx() to ensure the memory size estimate is more accurate.
1303 # _setup_cctx() to ensure the memory size estimate is more accurate.
1232 try:
1304 try:
1233 self._setup_cctx()
1305 self._setup_cctx()
1234 finally:
1306 finally:
1235 self._cctx = ffi.gc(
1307 self._cctx = ffi.gc(
1236 cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx)
1308 cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx)
1237 )
1309 )
1238
1310
1239 def _setup_cctx(self):
1311 def _setup_cctx(self):
1240 zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx, self._params)
1312 zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(
1313 self._cctx, self._params
1314 )
1241 if lib.ZSTD_isError(zresult):
1315 if lib.ZSTD_isError(zresult):
1242 raise ZstdError(
1316 raise ZstdError(
1243 "could not set compression parameters: %s" % _zstd_error(zresult)
1317 "could not set compression parameters: %s"
1318 % _zstd_error(zresult)
1244 )
1319 )
1245
1320
1246 dict_data = self._dict_data
1321 dict_data = self._dict_data
1247
1322
1248 if dict_data:
1323 if dict_data:
1249 if dict_data._cdict:
1324 if dict_data._cdict:
1250 zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict)
1325 zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict)
1251 else:
1326 else:
1252 zresult = lib.ZSTD_CCtx_loadDictionary_advanced(
1327 zresult = lib.ZSTD_CCtx_loadDictionary_advanced(
1253 self._cctx,
1328 self._cctx,
1254 dict_data.as_bytes(),
1329 dict_data.as_bytes(),
1255 len(dict_data),
1330 len(dict_data),
1256 lib.ZSTD_dlm_byRef,
1331 lib.ZSTD_dlm_byRef,
1257 dict_data._dict_type,
1332 dict_data._dict_type,
1258 )
1333 )
1259
1334
1260 if lib.ZSTD_isError(zresult):
1335 if lib.ZSTD_isError(zresult):
1261 raise ZstdError(
1336 raise ZstdError(
1262 "could not load compression dictionary: %s" % _zstd_error(zresult)
1337 "could not load compression dictionary: %s"
1338 % _zstd_error(zresult)
1263 )
1339 )
1264
1340
1265 def memory_size(self):
1341 def memory_size(self):
1266 return lib.ZSTD_sizeof_CCtx(self._cctx)
1342 return lib.ZSTD_sizeof_CCtx(self._cctx)
1267
1343
1268 def compress(self, data):
1344 def compress(self, data):
1269 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1345 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1270
1346
1271 data_buffer = ffi.from_buffer(data)
1347 data_buffer = ffi.from_buffer(data)
1272
1348
1273 dest_size = lib.ZSTD_compressBound(len(data_buffer))
1349 dest_size = lib.ZSTD_compressBound(len(data_buffer))
1274 out = new_nonzero("char[]", dest_size)
1350 out = new_nonzero("char[]", dest_size)
1275
1351
1276 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer))
1352 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer))
1277 if lib.ZSTD_isError(zresult):
1353 if lib.ZSTD_isError(zresult):
1278 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1354 raise ZstdError(
1355 "error setting source size: %s" % _zstd_error(zresult)
1356 )
1279
1357
1280 out_buffer = ffi.new("ZSTD_outBuffer *")
1358 out_buffer = ffi.new("ZSTD_outBuffer *")
1281 in_buffer = ffi.new("ZSTD_inBuffer *")
1359 in_buffer = ffi.new("ZSTD_inBuffer *")
1282
1360
1283 out_buffer.dst = out
1361 out_buffer.dst = out
1284 out_buffer.size = dest_size
1362 out_buffer.size = dest_size
1285 out_buffer.pos = 0
1363 out_buffer.pos = 0
1286
1364
1287 in_buffer.src = data_buffer
1365 in_buffer.src = data_buffer
1288 in_buffer.size = len(data_buffer)
1366 in_buffer.size = len(data_buffer)
1289 in_buffer.pos = 0
1367 in_buffer.pos = 0
1290
1368
1291 zresult = lib.ZSTD_compressStream2(
1369 zresult = lib.ZSTD_compressStream2(
1292 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1370 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1293 )
1371 )
1294
1372
1295 if lib.ZSTD_isError(zresult):
1373 if lib.ZSTD_isError(zresult):
1296 raise ZstdError("cannot compress: %s" % _zstd_error(zresult))
1374 raise ZstdError("cannot compress: %s" % _zstd_error(zresult))
1297 elif zresult:
1375 elif zresult:
1298 raise ZstdError("unexpected partial frame flush")
1376 raise ZstdError("unexpected partial frame flush")
1299
1377
1300 return ffi.buffer(out, out_buffer.pos)[:]
1378 return ffi.buffer(out, out_buffer.pos)[:]
1301
1379
1302 def compressobj(self, size=-1):
1380 def compressobj(self, size=-1):
1303 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1381 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1304
1382
1305 if size < 0:
1383 if size < 0:
1306 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1384 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1307
1385
1308 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1386 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1309 if lib.ZSTD_isError(zresult):
1387 if lib.ZSTD_isError(zresult):
1310 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1388 raise ZstdError(
1389 "error setting source size: %s" % _zstd_error(zresult)
1390 )
1311
1391
1312 cobj = ZstdCompressionObj()
1392 cobj = ZstdCompressionObj()
1313 cobj._out = ffi.new("ZSTD_outBuffer *")
1393 cobj._out = ffi.new("ZSTD_outBuffer *")
1314 cobj._dst_buffer = ffi.new("char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
1394 cobj._dst_buffer = ffi.new(
1395 "char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1396 )
1315 cobj._out.dst = cobj._dst_buffer
1397 cobj._out.dst = cobj._dst_buffer
1316 cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1398 cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1317 cobj._out.pos = 0
1399 cobj._out.pos = 0
1318 cobj._compressor = self
1400 cobj._compressor = self
1319 cobj._finished = False
1401 cobj._finished = False
1320
1402
1321 return cobj
1403 return cobj
1322
1404
1323 def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
1405 def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
1324 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1406 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1325
1407
1326 if size < 0:
1408 if size < 0:
1327 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1409 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1328
1410
1329 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1411 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1330 if lib.ZSTD_isError(zresult):
1412 if lib.ZSTD_isError(zresult):
1331 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1413 raise ZstdError(
1414 "error setting source size: %s" % _zstd_error(zresult)
1415 )
1332
1416
1333 return ZstdCompressionChunker(self, chunk_size=chunk_size)
1417 return ZstdCompressionChunker(self, chunk_size=chunk_size)
1334
1418
1335 def copy_stream(
1419 def copy_stream(
1336 self,
1420 self,
1337 ifh,
1421 ifh,
1338 ofh,
1422 ofh,
1339 size=-1,
1423 size=-1,
1340 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
1424 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
1341 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1425 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1342 ):
1426 ):
1343
1427
1344 if not hasattr(ifh, "read"):
1428 if not hasattr(ifh, "read"):
1345 raise ValueError("first argument must have a read() method")
1429 raise ValueError("first argument must have a read() method")
1346 if not hasattr(ofh, "write"):
1430 if not hasattr(ofh, "write"):
1347 raise ValueError("second argument must have a write() method")
1431 raise ValueError("second argument must have a write() method")
1348
1432
1349 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1433 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1350
1434
1351 if size < 0:
1435 if size < 0:
1352 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1436 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1353
1437
1354 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1438 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1355 if lib.ZSTD_isError(zresult):
1439 if lib.ZSTD_isError(zresult):
1356 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1440 raise ZstdError(
1441 "error setting source size: %s" % _zstd_error(zresult)
1442 )
1357
1443
1358 in_buffer = ffi.new("ZSTD_inBuffer *")
1444 in_buffer = ffi.new("ZSTD_inBuffer *")
1359 out_buffer = ffi.new("ZSTD_outBuffer *")
1445 out_buffer = ffi.new("ZSTD_outBuffer *")
1360
1446
1361 dst_buffer = ffi.new("char[]", write_size)
1447 dst_buffer = ffi.new("char[]", write_size)
1362 out_buffer.dst = dst_buffer
1448 out_buffer.dst = dst_buffer
1363 out_buffer.size = write_size
1449 out_buffer.size = write_size
1364 out_buffer.pos = 0
1450 out_buffer.pos = 0
1365
1451
1366 total_read, total_write = 0, 0
1452 total_read, total_write = 0, 0
1367
1453
1368 while True:
1454 while True:
1369 data = ifh.read(read_size)
1455 data = ifh.read(read_size)
1370 if not data:
1456 if not data:
1371 break
1457 break
1372
1458
1373 data_buffer = ffi.from_buffer(data)
1459 data_buffer = ffi.from_buffer(data)
1374 total_read += len(data_buffer)
1460 total_read += len(data_buffer)
1375 in_buffer.src = data_buffer
1461 in_buffer.src = data_buffer
1376 in_buffer.size = len(data_buffer)
1462 in_buffer.size = len(data_buffer)
1377 in_buffer.pos = 0
1463 in_buffer.pos = 0
1378
1464
1379 while in_buffer.pos < in_buffer.size:
1465 while in_buffer.pos < in_buffer.size:
1380 zresult = lib.ZSTD_compressStream2(
1466 zresult = lib.ZSTD_compressStream2(
1381 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
1467 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
1382 )
1468 )
1383 if lib.ZSTD_isError(zresult):
1469 if lib.ZSTD_isError(zresult):
1384 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
1470 raise ZstdError(
1471 "zstd compress error: %s" % _zstd_error(zresult)
1472 )
1385
1473
1386 if out_buffer.pos:
1474 if out_buffer.pos:
1387 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
1475 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
1388 total_write += out_buffer.pos
1476 total_write += out_buffer.pos
1389 out_buffer.pos = 0
1477 out_buffer.pos = 0
1390
1478
1391 # We've finished reading. Flush the compressor.
1479 # We've finished reading. Flush the compressor.
1392 while True:
1480 while True:
1393 zresult = lib.ZSTD_compressStream2(
1481 zresult = lib.ZSTD_compressStream2(
1394 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1482 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1395 )
1483 )
1396 if lib.ZSTD_isError(zresult):
1484 if lib.ZSTD_isError(zresult):
1397 raise ZstdError(
1485 raise ZstdError(
1398 "error ending compression stream: %s" % _zstd_error(zresult)
1486 "error ending compression stream: %s" % _zstd_error(zresult)
1399 )
1487 )
1400
1488
1401 if out_buffer.pos:
1489 if out_buffer.pos:
1402 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
1490 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
1403 total_write += out_buffer.pos
1491 total_write += out_buffer.pos
1404 out_buffer.pos = 0
1492 out_buffer.pos = 0
1405
1493
1406 if zresult == 0:
1494 if zresult == 0:
1407 break
1495 break
1408
1496
1409 return total_read, total_write
1497 return total_read, total_write
1410
1498
1411 def stream_reader(
1499 def stream_reader(
1412 self, source, size=-1, read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE
1500 self, source, size=-1, read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE
1413 ):
1501 ):
1414 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1502 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1415
1503
1416 try:
1504 try:
1417 size = len(source)
1505 size = len(source)
1418 except Exception:
1506 except Exception:
1419 pass
1507 pass
1420
1508
1421 if size < 0:
1509 if size < 0:
1422 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1510 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1423
1511
1424 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1512 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1425 if lib.ZSTD_isError(zresult):
1513 if lib.ZSTD_isError(zresult):
1426 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1514 raise ZstdError(
1515 "error setting source size: %s" % _zstd_error(zresult)
1516 )
1427
1517
1428 return ZstdCompressionReader(self, source, read_size)
1518 return ZstdCompressionReader(self, source, read_size)
1429
1519
1430 def stream_writer(
1520 def stream_writer(
1431 self,
1521 self,
1432 writer,
1522 writer,
1433 size=-1,
1523 size=-1,
1434 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1524 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1435 write_return_read=False,
1525 write_return_read=False,
1436 ):
1526 ):
1437
1527
1438 if not hasattr(writer, "write"):
1528 if not hasattr(writer, "write"):
1439 raise ValueError("must pass an object with a write() method")
1529 raise ValueError("must pass an object with a write() method")
1440
1530
1441 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1531 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1442
1532
1443 if size < 0:
1533 if size < 0:
1444 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1534 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1445
1535
1446 return ZstdCompressionWriter(self, writer, size, write_size, write_return_read)
1536 return ZstdCompressionWriter(
1537 self, writer, size, write_size, write_return_read
1538 )
1447
1539
1448 write_to = stream_writer
1540 write_to = stream_writer
1449
1541
1450 def read_to_iter(
1542 def read_to_iter(
1451 self,
1543 self,
1452 reader,
1544 reader,
1453 size=-1,
1545 size=-1,
1454 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
1546 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
1455 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1547 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1456 ):
1548 ):
1457 if hasattr(reader, "read"):
1549 if hasattr(reader, "read"):
1458 have_read = True
1550 have_read = True
1459 elif hasattr(reader, "__getitem__"):
1551 elif hasattr(reader, "__getitem__"):
1460 have_read = False
1552 have_read = False
1461 buffer_offset = 0
1553 buffer_offset = 0
1462 size = len(reader)
1554 size = len(reader)
1463 else:
1555 else:
1464 raise ValueError(
1556 raise ValueError(
1465 "must pass an object with a read() method or "
1557 "must pass an object with a read() method or "
1466 "conforms to buffer protocol"
1558 "conforms to buffer protocol"
1467 )
1559 )
1468
1560
1469 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1561 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1470
1562
1471 if size < 0:
1563 if size < 0:
1472 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1564 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1473
1565
1474 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1566 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1475 if lib.ZSTD_isError(zresult):
1567 if lib.ZSTD_isError(zresult):
1476 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1568 raise ZstdError(
1569 "error setting source size: %s" % _zstd_error(zresult)
1570 )
1477
1571
1478 in_buffer = ffi.new("ZSTD_inBuffer *")
1572 in_buffer = ffi.new("ZSTD_inBuffer *")
1479 out_buffer = ffi.new("ZSTD_outBuffer *")
1573 out_buffer = ffi.new("ZSTD_outBuffer *")
1480
1574
1481 in_buffer.src = ffi.NULL
1575 in_buffer.src = ffi.NULL
1482 in_buffer.size = 0
1576 in_buffer.size = 0
1483 in_buffer.pos = 0
1577 in_buffer.pos = 0
1484
1578
1485 dst_buffer = ffi.new("char[]", write_size)
1579 dst_buffer = ffi.new("char[]", write_size)
1486 out_buffer.dst = dst_buffer
1580 out_buffer.dst = dst_buffer
1487 out_buffer.size = write_size
1581 out_buffer.size = write_size
1488 out_buffer.pos = 0
1582 out_buffer.pos = 0
1489
1583
1490 while True:
1584 while True:
1491 # We should never have output data sitting around after a previous
1585 # We should never have output data sitting around after a previous
1492 # iteration.
1586 # iteration.
1493 assert out_buffer.pos == 0
1587 assert out_buffer.pos == 0
1494
1588
1495 # Collect input data.
1589 # Collect input data.
1496 if have_read:
1590 if have_read:
1497 read_result = reader.read(read_size)
1591 read_result = reader.read(read_size)
1498 else:
1592 else:
1499 remaining = len(reader) - buffer_offset
1593 remaining = len(reader) - buffer_offset
1500 slice_size = min(remaining, read_size)
1594 slice_size = min(remaining, read_size)
1501 read_result = reader[buffer_offset : buffer_offset + slice_size]
1595 read_result = reader[buffer_offset : buffer_offset + slice_size]
1502 buffer_offset += slice_size
1596 buffer_offset += slice_size
1503
1597
1504 # No new input data. Break out of the read loop.
1598 # No new input data. Break out of the read loop.
1505 if not read_result:
1599 if not read_result:
1506 break
1600 break
1507
1601
1508 # Feed all read data into the compressor and emit output until
1602 # Feed all read data into the compressor and emit output until
1509 # exhausted.
1603 # exhausted.
1510 read_buffer = ffi.from_buffer(read_result)
1604 read_buffer = ffi.from_buffer(read_result)
1511 in_buffer.src = read_buffer
1605 in_buffer.src = read_buffer
1512 in_buffer.size = len(read_buffer)
1606 in_buffer.size = len(read_buffer)
1513 in_buffer.pos = 0
1607 in_buffer.pos = 0
1514
1608
1515 while in_buffer.pos < in_buffer.size:
1609 while in_buffer.pos < in_buffer.size:
1516 zresult = lib.ZSTD_compressStream2(
1610 zresult = lib.ZSTD_compressStream2(
1517 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
1611 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
1518 )
1612 )
1519 if lib.ZSTD_isError(zresult):
1613 if lib.ZSTD_isError(zresult):
1520 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
1614 raise ZstdError(
1615 "zstd compress error: %s" % _zstd_error(zresult)
1616 )
1521
1617
1522 if out_buffer.pos:
1618 if out_buffer.pos:
1523 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1619 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1524 out_buffer.pos = 0
1620 out_buffer.pos = 0
1525 yield data
1621 yield data
1526
1622
1527 assert out_buffer.pos == 0
1623 assert out_buffer.pos == 0
1528
1624
1529 # And repeat the loop to collect more data.
1625 # And repeat the loop to collect more data.
1530 continue
1626 continue
1531
1627
1532 # If we get here, input is exhausted. End the stream and emit what
1628 # If we get here, input is exhausted. End the stream and emit what
1533 # remains.
1629 # remains.
1534 while True:
1630 while True:
1535 assert out_buffer.pos == 0
1631 assert out_buffer.pos == 0
1536 zresult = lib.ZSTD_compressStream2(
1632 zresult = lib.ZSTD_compressStream2(
1537 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1633 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1538 )
1634 )
1539 if lib.ZSTD_isError(zresult):
1635 if lib.ZSTD_isError(zresult):
1540 raise ZstdError(
1636 raise ZstdError(
1541 "error ending compression stream: %s" % _zstd_error(zresult)
1637 "error ending compression stream: %s" % _zstd_error(zresult)
1542 )
1638 )
1543
1639
1544 if out_buffer.pos:
1640 if out_buffer.pos:
1545 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1641 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1546 out_buffer.pos = 0
1642 out_buffer.pos = 0
1547 yield data
1643 yield data
1548
1644
1549 if zresult == 0:
1645 if zresult == 0:
1550 break
1646 break
1551
1647
1552 read_from = read_to_iter
1648 read_from = read_to_iter
1553
1649
1554 def frame_progression(self):
1650 def frame_progression(self):
1555 progression = lib.ZSTD_getFrameProgression(self._cctx)
1651 progression = lib.ZSTD_getFrameProgression(self._cctx)
1556
1652
1557 return progression.ingested, progression.consumed, progression.produced
1653 return progression.ingested, progression.consumed, progression.produced
1558
1654
1559
1655
1560 class FrameParameters(object):
1656 class FrameParameters(object):
1561 def __init__(self, fparams):
1657 def __init__(self, fparams):
1562 self.content_size = fparams.frameContentSize
1658 self.content_size = fparams.frameContentSize
1563 self.window_size = fparams.windowSize
1659 self.window_size = fparams.windowSize
1564 self.dict_id = fparams.dictID
1660 self.dict_id = fparams.dictID
1565 self.has_checksum = bool(fparams.checksumFlag)
1661 self.has_checksum = bool(fparams.checksumFlag)
1566
1662
1567
1663
1568 def frame_content_size(data):
1664 def frame_content_size(data):
1569 data_buffer = ffi.from_buffer(data)
1665 data_buffer = ffi.from_buffer(data)
1570
1666
1571 size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
1667 size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
1572
1668
1573 if size == lib.ZSTD_CONTENTSIZE_ERROR:
1669 if size == lib.ZSTD_CONTENTSIZE_ERROR:
1574 raise ZstdError("error when determining content size")
1670 raise ZstdError("error when determining content size")
1575 elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
1671 elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
1576 return -1
1672 return -1
1577 else:
1673 else:
1578 return size
1674 return size
1579
1675
1580
1676
1581 def frame_header_size(data):
1677 def frame_header_size(data):
1582 data_buffer = ffi.from_buffer(data)
1678 data_buffer = ffi.from_buffer(data)
1583
1679
1584 zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer))
1680 zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer))
1585 if lib.ZSTD_isError(zresult):
1681 if lib.ZSTD_isError(zresult):
1586 raise ZstdError(
1682 raise ZstdError(
1587 "could not determine frame header size: %s" % _zstd_error(zresult)
1683 "could not determine frame header size: %s" % _zstd_error(zresult)
1588 )
1684 )
1589
1685
1590 return zresult
1686 return zresult
1591
1687
1592
1688
1593 def get_frame_parameters(data):
1689 def get_frame_parameters(data):
1594 params = ffi.new("ZSTD_frameHeader *")
1690 params = ffi.new("ZSTD_frameHeader *")
1595
1691
1596 data_buffer = ffi.from_buffer(data)
1692 data_buffer = ffi.from_buffer(data)
1597 zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer))
1693 zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer))
1598 if lib.ZSTD_isError(zresult):
1694 if lib.ZSTD_isError(zresult):
1599 raise ZstdError("cannot get frame parameters: %s" % _zstd_error(zresult))
1695 raise ZstdError(
1696 "cannot get frame parameters: %s" % _zstd_error(zresult)
1697 )
1600
1698
1601 if zresult:
1699 if zresult:
1602 raise ZstdError("not enough data for frame parameters; need %d bytes" % zresult)
1700 raise ZstdError(
1701 "not enough data for frame parameters; need %d bytes" % zresult
1702 )
1603
1703
1604 return FrameParameters(params[0])
1704 return FrameParameters(params[0])
1605
1705
1606
1706
1607 class ZstdCompressionDict(object):
1707 class ZstdCompressionDict(object):
1608 def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0):
1708 def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0):
1609 assert isinstance(data, bytes_type)
1709 assert isinstance(data, bytes_type)
1610 self._data = data
1710 self._data = data
1611 self.k = k
1711 self.k = k
1612 self.d = d
1712 self.d = d
1613
1713
1614 if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT, DICT_TYPE_FULLDICT):
1714 if dict_type not in (
1715 DICT_TYPE_AUTO,
1716 DICT_TYPE_RAWCONTENT,
1717 DICT_TYPE_FULLDICT,
1718 ):
1615 raise ValueError(
1719 raise ValueError(
1616 "invalid dictionary load mode: %d; must use " "DICT_TYPE_* constants"
1720 "invalid dictionary load mode: %d; must use "
1721 "DICT_TYPE_* constants"
1617 )
1722 )
1618
1723
1619 self._dict_type = dict_type
1724 self._dict_type = dict_type
1620 self._cdict = None
1725 self._cdict = None
1621
1726
1622 def __len__(self):
1727 def __len__(self):
1623 return len(self._data)
1728 return len(self._data)
1624
1729
1625 def dict_id(self):
1730 def dict_id(self):
1626 return int_type(lib.ZDICT_getDictID(self._data, len(self._data)))
1731 return int_type(lib.ZDICT_getDictID(self._data, len(self._data)))
1627
1732
1628 def as_bytes(self):
1733 def as_bytes(self):
1629 return self._data
1734 return self._data
1630
1735
1631 def precompute_compress(self, level=0, compression_params=None):
1736 def precompute_compress(self, level=0, compression_params=None):
1632 if level and compression_params:
1737 if level and compression_params:
1633 raise ValueError("must only specify one of level or " "compression_params")
1738 raise ValueError(
1739 "must only specify one of level or " "compression_params"
1740 )
1634
1741
1635 if not level and not compression_params:
1742 if not level and not compression_params:
1636 raise ValueError("must specify one of level or compression_params")
1743 raise ValueError("must specify one of level or compression_params")
1637
1744
1638 if level:
1745 if level:
1639 cparams = lib.ZSTD_getCParams(level, 0, len(self._data))
1746 cparams = lib.ZSTD_getCParams(level, 0, len(self._data))
1640 else:
1747 else:
1641 cparams = ffi.new("ZSTD_compressionParameters")
1748 cparams = ffi.new("ZSTD_compressionParameters")
1642 cparams.chainLog = compression_params.chain_log
1749 cparams.chainLog = compression_params.chain_log
1643 cparams.hashLog = compression_params.hash_log
1750 cparams.hashLog = compression_params.hash_log
1644 cparams.minMatch = compression_params.min_match
1751 cparams.minMatch = compression_params.min_match
1645 cparams.searchLog = compression_params.search_log
1752 cparams.searchLog = compression_params.search_log
1646 cparams.strategy = compression_params.compression_strategy
1753 cparams.strategy = compression_params.compression_strategy
1647 cparams.targetLength = compression_params.target_length
1754 cparams.targetLength = compression_params.target_length
1648 cparams.windowLog = compression_params.window_log
1755 cparams.windowLog = compression_params.window_log
1649
1756
1650 cdict = lib.ZSTD_createCDict_advanced(
1757 cdict = lib.ZSTD_createCDict_advanced(
1651 self._data,
1758 self._data,
1652 len(self._data),
1759 len(self._data),
1653 lib.ZSTD_dlm_byRef,
1760 lib.ZSTD_dlm_byRef,
1654 self._dict_type,
1761 self._dict_type,
1655 cparams,
1762 cparams,
1656 lib.ZSTD_defaultCMem,
1763 lib.ZSTD_defaultCMem,
1657 )
1764 )
1658 if cdict == ffi.NULL:
1765 if cdict == ffi.NULL:
1659 raise ZstdError("unable to precompute dictionary")
1766 raise ZstdError("unable to precompute dictionary")
1660
1767
1661 self._cdict = ffi.gc(
1768 self._cdict = ffi.gc(
1662 cdict, lib.ZSTD_freeCDict, size=lib.ZSTD_sizeof_CDict(cdict)
1769 cdict, lib.ZSTD_freeCDict, size=lib.ZSTD_sizeof_CDict(cdict)
1663 )
1770 )
1664
1771
1665 @property
1772 @property
1666 def _ddict(self):
1773 def _ddict(self):
1667 ddict = lib.ZSTD_createDDict_advanced(
1774 ddict = lib.ZSTD_createDDict_advanced(
1668 self._data,
1775 self._data,
1669 len(self._data),
1776 len(self._data),
1670 lib.ZSTD_dlm_byRef,
1777 lib.ZSTD_dlm_byRef,
1671 self._dict_type,
1778 self._dict_type,
1672 lib.ZSTD_defaultCMem,
1779 lib.ZSTD_defaultCMem,
1673 )
1780 )
1674
1781
1675 if ddict == ffi.NULL:
1782 if ddict == ffi.NULL:
1676 raise ZstdError("could not create decompression dict")
1783 raise ZstdError("could not create decompression dict")
1677
1784
1678 ddict = ffi.gc(ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict))
1785 ddict = ffi.gc(
1786 ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict)
1787 )
1679 self.__dict__["_ddict"] = ddict
1788 self.__dict__["_ddict"] = ddict
1680
1789
1681 return ddict
1790 return ddict
1682
1791
1683
1792
1684 def train_dictionary(
1793 def train_dictionary(
1685 dict_size,
1794 dict_size,
1686 samples,
1795 samples,
1687 k=0,
1796 k=0,
1688 d=0,
1797 d=0,
1689 notifications=0,
1798 notifications=0,
1690 dict_id=0,
1799 dict_id=0,
1691 level=0,
1800 level=0,
1692 steps=0,
1801 steps=0,
1693 threads=0,
1802 threads=0,
1694 ):
1803 ):
1695 if not isinstance(samples, list):
1804 if not isinstance(samples, list):
1696 raise TypeError("samples must be a list")
1805 raise TypeError("samples must be a list")
1697
1806
1698 if threads < 0:
1807 if threads < 0:
1699 threads = _cpu_count()
1808 threads = _cpu_count()
1700
1809
1701 total_size = sum(map(len, samples))
1810 total_size = sum(map(len, samples))
1702
1811
1703 samples_buffer = new_nonzero("char[]", total_size)
1812 samples_buffer = new_nonzero("char[]", total_size)
1704 sample_sizes = new_nonzero("size_t[]", len(samples))
1813 sample_sizes = new_nonzero("size_t[]", len(samples))
1705
1814
1706 offset = 0
1815 offset = 0
1707 for i, sample in enumerate(samples):
1816 for i, sample in enumerate(samples):
1708 if not isinstance(sample, bytes_type):
1817 if not isinstance(sample, bytes_type):
1709 raise ValueError("samples must be bytes")
1818 raise ValueError("samples must be bytes")
1710
1819
1711 l = len(sample)
1820 l = len(sample)
1712 ffi.memmove(samples_buffer + offset, sample, l)
1821 ffi.memmove(samples_buffer + offset, sample, l)
1713 offset += l
1822 offset += l
1714 sample_sizes[i] = l
1823 sample_sizes[i] = l
1715
1824
1716 dict_data = new_nonzero("char[]", dict_size)
1825 dict_data = new_nonzero("char[]", dict_size)
1717
1826
1718 dparams = ffi.new("ZDICT_cover_params_t *")[0]
1827 dparams = ffi.new("ZDICT_cover_params_t *")[0]
1719 dparams.k = k
1828 dparams.k = k
1720 dparams.d = d
1829 dparams.d = d
1721 dparams.steps = steps
1830 dparams.steps = steps
1722 dparams.nbThreads = threads
1831 dparams.nbThreads = threads
1723 dparams.zParams.notificationLevel = notifications
1832 dparams.zParams.notificationLevel = notifications
1724 dparams.zParams.dictID = dict_id
1833 dparams.zParams.dictID = dict_id
1725 dparams.zParams.compressionLevel = level
1834 dparams.zParams.compressionLevel = level
1726
1835
1727 if (
1836 if (
1728 not dparams.k
1837 not dparams.k
1729 and not dparams.d
1838 and not dparams.d
1730 and not dparams.steps
1839 and not dparams.steps
1731 and not dparams.nbThreads
1840 and not dparams.nbThreads
1732 and not dparams.zParams.notificationLevel
1841 and not dparams.zParams.notificationLevel
1733 and not dparams.zParams.dictID
1842 and not dparams.zParams.dictID
1734 and not dparams.zParams.compressionLevel
1843 and not dparams.zParams.compressionLevel
1735 ):
1844 ):
1736 zresult = lib.ZDICT_trainFromBuffer(
1845 zresult = lib.ZDICT_trainFromBuffer(
1737 ffi.addressof(dict_data),
1846 ffi.addressof(dict_data),
1738 dict_size,
1847 dict_size,
1739 ffi.addressof(samples_buffer),
1848 ffi.addressof(samples_buffer),
1740 ffi.addressof(sample_sizes, 0),
1849 ffi.addressof(sample_sizes, 0),
1741 len(samples),
1850 len(samples),
1742 )
1851 )
1743 elif dparams.steps or dparams.nbThreads:
1852 elif dparams.steps or dparams.nbThreads:
1744 zresult = lib.ZDICT_optimizeTrainFromBuffer_cover(
1853 zresult = lib.ZDICT_optimizeTrainFromBuffer_cover(
1745 ffi.addressof(dict_data),
1854 ffi.addressof(dict_data),
1746 dict_size,
1855 dict_size,
1747 ffi.addressof(samples_buffer),
1856 ffi.addressof(samples_buffer),
1748 ffi.addressof(sample_sizes, 0),
1857 ffi.addressof(sample_sizes, 0),
1749 len(samples),
1858 len(samples),
1750 ffi.addressof(dparams),
1859 ffi.addressof(dparams),
1751 )
1860 )
1752 else:
1861 else:
1753 zresult = lib.ZDICT_trainFromBuffer_cover(
1862 zresult = lib.ZDICT_trainFromBuffer_cover(
1754 ffi.addressof(dict_data),
1863 ffi.addressof(dict_data),
1755 dict_size,
1864 dict_size,
1756 ffi.addressof(samples_buffer),
1865 ffi.addressof(samples_buffer),
1757 ffi.addressof(sample_sizes, 0),
1866 ffi.addressof(sample_sizes, 0),
1758 len(samples),
1867 len(samples),
1759 dparams,
1868 dparams,
1760 )
1869 )
1761
1870
1762 if lib.ZDICT_isError(zresult):
1871 if lib.ZDICT_isError(zresult):
1763 msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode("utf-8")
1872 msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode("utf-8")
1764 raise ZstdError("cannot train dict: %s" % msg)
1873 raise ZstdError("cannot train dict: %s" % msg)
1765
1874
1766 return ZstdCompressionDict(
1875 return ZstdCompressionDict(
1767 ffi.buffer(dict_data, zresult)[:],
1876 ffi.buffer(dict_data, zresult)[:],
1768 dict_type=DICT_TYPE_FULLDICT,
1877 dict_type=DICT_TYPE_FULLDICT,
1769 k=dparams.k,
1878 k=dparams.k,
1770 d=dparams.d,
1879 d=dparams.d,
1771 )
1880 )
1772
1881
1773
1882
1774 class ZstdDecompressionObj(object):
1883 class ZstdDecompressionObj(object):
1775 def __init__(self, decompressor, write_size):
1884 def __init__(self, decompressor, write_size):
1776 self._decompressor = decompressor
1885 self._decompressor = decompressor
1777 self._write_size = write_size
1886 self._write_size = write_size
1778 self._finished = False
1887 self._finished = False
1779
1888
1780 def decompress(self, data):
1889 def decompress(self, data):
1781 if self._finished:
1890 if self._finished:
1782 raise ZstdError("cannot use a decompressobj multiple times")
1891 raise ZstdError("cannot use a decompressobj multiple times")
1783
1892
1784 in_buffer = ffi.new("ZSTD_inBuffer *")
1893 in_buffer = ffi.new("ZSTD_inBuffer *")
1785 out_buffer = ffi.new("ZSTD_outBuffer *")
1894 out_buffer = ffi.new("ZSTD_outBuffer *")
1786
1895
1787 data_buffer = ffi.from_buffer(data)
1896 data_buffer = ffi.from_buffer(data)
1788
1897
1789 if len(data_buffer) == 0:
1898 if len(data_buffer) == 0:
1790 return b""
1899 return b""
1791
1900
1792 in_buffer.src = data_buffer
1901 in_buffer.src = data_buffer
1793 in_buffer.size = len(data_buffer)
1902 in_buffer.size = len(data_buffer)
1794 in_buffer.pos = 0
1903 in_buffer.pos = 0
1795
1904
1796 dst_buffer = ffi.new("char[]", self._write_size)
1905 dst_buffer = ffi.new("char[]", self._write_size)
1797 out_buffer.dst = dst_buffer
1906 out_buffer.dst = dst_buffer
1798 out_buffer.size = len(dst_buffer)
1907 out_buffer.size = len(dst_buffer)
1799 out_buffer.pos = 0
1908 out_buffer.pos = 0
1800
1909
1801 chunks = []
1910 chunks = []
1802
1911
1803 while True:
1912 while True:
1804 zresult = lib.ZSTD_decompressStream(
1913 zresult = lib.ZSTD_decompressStream(
1805 self._decompressor._dctx, out_buffer, in_buffer
1914 self._decompressor._dctx, out_buffer, in_buffer
1806 )
1915 )
1807 if lib.ZSTD_isError(zresult):
1916 if lib.ZSTD_isError(zresult):
1808 raise ZstdError("zstd decompressor error: %s" % _zstd_error(zresult))
1917 raise ZstdError(
1918 "zstd decompressor error: %s" % _zstd_error(zresult)
1919 )
1809
1920
1810 if zresult == 0:
1921 if zresult == 0:
1811 self._finished = True
1922 self._finished = True
1812 self._decompressor = None
1923 self._decompressor = None
1813
1924
1814 if out_buffer.pos:
1925 if out_buffer.pos:
1815 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
1926 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
1816
1927
1817 if zresult == 0 or (
1928 if zresult == 0 or (
1818 in_buffer.pos == in_buffer.size and out_buffer.pos == 0
1929 in_buffer.pos == in_buffer.size and out_buffer.pos == 0
1819 ):
1930 ):
1820 break
1931 break
1821
1932
1822 out_buffer.pos = 0
1933 out_buffer.pos = 0
1823
1934
1824 return b"".join(chunks)
1935 return b"".join(chunks)
1825
1936
1826 def flush(self, length=0):
1937 def flush(self, length=0):
1827 pass
1938 pass
1828
1939
1829
1940
1830 class ZstdDecompressionReader(object):
1941 class ZstdDecompressionReader(object):
1831 def __init__(self, decompressor, source, read_size, read_across_frames):
1942 def __init__(self, decompressor, source, read_size, read_across_frames):
1832 self._decompressor = decompressor
1943 self._decompressor = decompressor
1833 self._source = source
1944 self._source = source
1834 self._read_size = read_size
1945 self._read_size = read_size
1835 self._read_across_frames = bool(read_across_frames)
1946 self._read_across_frames = bool(read_across_frames)
1836 self._entered = False
1947 self._entered = False
1837 self._closed = False
1948 self._closed = False
1838 self._bytes_decompressed = 0
1949 self._bytes_decompressed = 0
1839 self._finished_input = False
1950 self._finished_input = False
1840 self._finished_output = False
1951 self._finished_output = False
1841 self._in_buffer = ffi.new("ZSTD_inBuffer *")
1952 self._in_buffer = ffi.new("ZSTD_inBuffer *")
1842 # Holds a ref to self._in_buffer.src.
1953 # Holds a ref to self._in_buffer.src.
1843 self._source_buffer = None
1954 self._source_buffer = None
1844
1955
1845 def __enter__(self):
1956 def __enter__(self):
1846 if self._entered:
1957 if self._entered:
1847 raise ValueError("cannot __enter__ multiple times")
1958 raise ValueError("cannot __enter__ multiple times")
1848
1959
1849 self._entered = True
1960 self._entered = True
1850 return self
1961 return self
1851
1962
1852 def __exit__(self, exc_type, exc_value, exc_tb):
1963 def __exit__(self, exc_type, exc_value, exc_tb):
1853 self._entered = False
1964 self._entered = False
1854 self._closed = True
1965 self._closed = True
1855 self._source = None
1966 self._source = None
1856 self._decompressor = None
1967 self._decompressor = None
1857
1968
1858 return False
1969 return False
1859
1970
1860 def readable(self):
1971 def readable(self):
1861 return True
1972 return True
1862
1973
1863 def writable(self):
1974 def writable(self):
1864 return False
1975 return False
1865
1976
1866 def seekable(self):
1977 def seekable(self):
1867 return True
1978 return True
1868
1979
1869 def readline(self):
1980 def readline(self):
1870 raise io.UnsupportedOperation()
1981 raise io.UnsupportedOperation()
1871
1982
1872 def readlines(self):
1983 def readlines(self):
1873 raise io.UnsupportedOperation()
1984 raise io.UnsupportedOperation()
1874
1985
1875 def write(self, data):
1986 def write(self, data):
1876 raise io.UnsupportedOperation()
1987 raise io.UnsupportedOperation()
1877
1988
1878 def writelines(self, lines):
1989 def writelines(self, lines):
1879 raise io.UnsupportedOperation()
1990 raise io.UnsupportedOperation()
1880
1991
1881 def isatty(self):
1992 def isatty(self):
1882 return False
1993 return False
1883
1994
1884 def flush(self):
1995 def flush(self):
1885 return None
1996 return None
1886
1997
1887 def close(self):
1998 def close(self):
1888 self._closed = True
1999 self._closed = True
1889 return None
2000 return None
1890
2001
1891 @property
2002 @property
1892 def closed(self):
2003 def closed(self):
1893 return self._closed
2004 return self._closed
1894
2005
1895 def tell(self):
2006 def tell(self):
1896 return self._bytes_decompressed
2007 return self._bytes_decompressed
1897
2008
1898 def readall(self):
2009 def readall(self):
1899 chunks = []
2010 chunks = []
1900
2011
1901 while True:
2012 while True:
1902 chunk = self.read(1048576)
2013 chunk = self.read(1048576)
1903 if not chunk:
2014 if not chunk:
1904 break
2015 break
1905
2016
1906 chunks.append(chunk)
2017 chunks.append(chunk)
1907
2018
1908 return b"".join(chunks)
2019 return b"".join(chunks)
1909
2020
1910 def __iter__(self):
2021 def __iter__(self):
1911 raise io.UnsupportedOperation()
2022 raise io.UnsupportedOperation()
1912
2023
1913 def __next__(self):
2024 def __next__(self):
1914 raise io.UnsupportedOperation()
2025 raise io.UnsupportedOperation()
1915
2026
1916 next = __next__
2027 next = __next__
1917
2028
1918 def _read_input(self):
2029 def _read_input(self):
1919 # We have data left over in the input buffer. Use it.
2030 # We have data left over in the input buffer. Use it.
1920 if self._in_buffer.pos < self._in_buffer.size:
2031 if self._in_buffer.pos < self._in_buffer.size:
1921 return
2032 return
1922
2033
1923 # All input data exhausted. Nothing to do.
2034 # All input data exhausted. Nothing to do.
1924 if self._finished_input:
2035 if self._finished_input:
1925 return
2036 return
1926
2037
1927 # Else populate the input buffer from our source.
2038 # Else populate the input buffer from our source.
1928 if hasattr(self._source, "read"):
2039 if hasattr(self._source, "read"):
1929 data = self._source.read(self._read_size)
2040 data = self._source.read(self._read_size)
1930
2041
1931 if not data:
2042 if not data:
1932 self._finished_input = True
2043 self._finished_input = True
1933 return
2044 return
1934
2045
1935 self._source_buffer = ffi.from_buffer(data)
2046 self._source_buffer = ffi.from_buffer(data)
1936 self._in_buffer.src = self._source_buffer
2047 self._in_buffer.src = self._source_buffer
1937 self._in_buffer.size = len(self._source_buffer)
2048 self._in_buffer.size = len(self._source_buffer)
1938 self._in_buffer.pos = 0
2049 self._in_buffer.pos = 0
1939 else:
2050 else:
1940 self._source_buffer = ffi.from_buffer(self._source)
2051 self._source_buffer = ffi.from_buffer(self._source)
1941 self._in_buffer.src = self._source_buffer
2052 self._in_buffer.src = self._source_buffer
1942 self._in_buffer.size = len(self._source_buffer)
2053 self._in_buffer.size = len(self._source_buffer)
1943 self._in_buffer.pos = 0
2054 self._in_buffer.pos = 0
1944
2055
1945 def _decompress_into_buffer(self, out_buffer):
2056 def _decompress_into_buffer(self, out_buffer):
1946 """Decompress available input into an output buffer.
2057 """Decompress available input into an output buffer.
1947
2058
1948 Returns True if data in output buffer should be emitted.
2059 Returns True if data in output buffer should be emitted.
1949 """
2060 """
1950 zresult = lib.ZSTD_decompressStream(
2061 zresult = lib.ZSTD_decompressStream(
1951 self._decompressor._dctx, out_buffer, self._in_buffer
2062 self._decompressor._dctx, out_buffer, self._in_buffer
1952 )
2063 )
1953
2064
1954 if self._in_buffer.pos == self._in_buffer.size:
2065 if self._in_buffer.pos == self._in_buffer.size:
1955 self._in_buffer.src = ffi.NULL
2066 self._in_buffer.src = ffi.NULL
1956 self._in_buffer.pos = 0
2067 self._in_buffer.pos = 0
1957 self._in_buffer.size = 0
2068 self._in_buffer.size = 0
1958 self._source_buffer = None
2069 self._source_buffer = None
1959
2070
1960 if not hasattr(self._source, "read"):
2071 if not hasattr(self._source, "read"):
1961 self._finished_input = True
2072 self._finished_input = True
1962
2073
1963 if lib.ZSTD_isError(zresult):
2074 if lib.ZSTD_isError(zresult):
1964 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
2075 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
1965
2076
1966 # Emit data if there is data AND either:
2077 # Emit data if there is data AND either:
1967 # a) output buffer is full (read amount is satisfied)
2078 # a) output buffer is full (read amount is satisfied)
1968 # b) we're at end of a frame and not in frame spanning mode
2079 # b) we're at end of a frame and not in frame spanning mode
1969 return out_buffer.pos and (
2080 return out_buffer.pos and (
1970 out_buffer.pos == out_buffer.size
2081 out_buffer.pos == out_buffer.size
1971 or zresult == 0
2082 or zresult == 0
1972 and not self._read_across_frames
2083 and not self._read_across_frames
1973 )
2084 )
1974
2085
1975 def read(self, size=-1):
2086 def read(self, size=-1):
1976 if self._closed:
2087 if self._closed:
1977 raise ValueError("stream is closed")
2088 raise ValueError("stream is closed")
1978
2089
1979 if size < -1:
2090 if size < -1:
1980 raise ValueError("cannot read negative amounts less than -1")
2091 raise ValueError("cannot read negative amounts less than -1")
1981
2092
1982 if size == -1:
2093 if size == -1:
1983 # This is recursive. But it gets the job done.
2094 # This is recursive. But it gets the job done.
1984 return self.readall()
2095 return self.readall()
1985
2096
1986 if self._finished_output or size == 0:
2097 if self._finished_output or size == 0:
1987 return b""
2098 return b""
1988
2099
1989 # We /could/ call into readinto() here. But that introduces more
2100 # We /could/ call into readinto() here. But that introduces more
1990 # overhead.
2101 # overhead.
1991 dst_buffer = ffi.new("char[]", size)
2102 dst_buffer = ffi.new("char[]", size)
1992 out_buffer = ffi.new("ZSTD_outBuffer *")
2103 out_buffer = ffi.new("ZSTD_outBuffer *")
1993 out_buffer.dst = dst_buffer
2104 out_buffer.dst = dst_buffer
1994 out_buffer.size = size
2105 out_buffer.size = size
1995 out_buffer.pos = 0
2106 out_buffer.pos = 0
1996
2107
1997 self._read_input()
2108 self._read_input()
1998 if self._decompress_into_buffer(out_buffer):
2109 if self._decompress_into_buffer(out_buffer):
1999 self._bytes_decompressed += out_buffer.pos
2110 self._bytes_decompressed += out_buffer.pos
2000 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2111 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2001
2112
2002 while not self._finished_input:
2113 while not self._finished_input:
2003 self._read_input()
2114 self._read_input()
2004 if self._decompress_into_buffer(out_buffer):
2115 if self._decompress_into_buffer(out_buffer):
2005 self._bytes_decompressed += out_buffer.pos
2116 self._bytes_decompressed += out_buffer.pos
2006 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2117 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2007
2118
2008 self._bytes_decompressed += out_buffer.pos
2119 self._bytes_decompressed += out_buffer.pos
2009 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2120 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2010
2121
2011 def readinto(self, b):
2122 def readinto(self, b):
2012 if self._closed:
2123 if self._closed:
2013 raise ValueError("stream is closed")
2124 raise ValueError("stream is closed")
2014
2125
2015 if self._finished_output:
2126 if self._finished_output:
2016 return 0
2127 return 0
2017
2128
2018 # TODO use writable=True once we require CFFI >= 1.12.
2129 # TODO use writable=True once we require CFFI >= 1.12.
2019 dest_buffer = ffi.from_buffer(b)
2130 dest_buffer = ffi.from_buffer(b)
2020 ffi.memmove(b, b"", 0)
2131 ffi.memmove(b, b"", 0)
2021 out_buffer = ffi.new("ZSTD_outBuffer *")
2132 out_buffer = ffi.new("ZSTD_outBuffer *")
2022 out_buffer.dst = dest_buffer
2133 out_buffer.dst = dest_buffer
2023 out_buffer.size = len(dest_buffer)
2134 out_buffer.size = len(dest_buffer)
2024 out_buffer.pos = 0
2135 out_buffer.pos = 0
2025
2136
2026 self._read_input()
2137 self._read_input()
2027 if self._decompress_into_buffer(out_buffer):
2138 if self._decompress_into_buffer(out_buffer):
2028 self._bytes_decompressed += out_buffer.pos
2139 self._bytes_decompressed += out_buffer.pos
2029 return out_buffer.pos
2140 return out_buffer.pos
2030
2141
2031 while not self._finished_input:
2142 while not self._finished_input:
2032 self._read_input()
2143 self._read_input()
2033 if self._decompress_into_buffer(out_buffer):
2144 if self._decompress_into_buffer(out_buffer):
2034 self._bytes_decompressed += out_buffer.pos
2145 self._bytes_decompressed += out_buffer.pos
2035 return out_buffer.pos
2146 return out_buffer.pos
2036
2147
2037 self._bytes_decompressed += out_buffer.pos
2148 self._bytes_decompressed += out_buffer.pos
2038 return out_buffer.pos
2149 return out_buffer.pos
2039
2150
2040 def read1(self, size=-1):
2151 def read1(self, size=-1):
2041 if self._closed:
2152 if self._closed:
2042 raise ValueError("stream is closed")
2153 raise ValueError("stream is closed")
2043
2154
2044 if size < -1:
2155 if size < -1:
2045 raise ValueError("cannot read negative amounts less than -1")
2156 raise ValueError("cannot read negative amounts less than -1")
2046
2157
2047 if self._finished_output or size == 0:
2158 if self._finished_output or size == 0:
2048 return b""
2159 return b""
2049
2160
2050 # -1 returns arbitrary number of bytes.
2161 # -1 returns arbitrary number of bytes.
2051 if size == -1:
2162 if size == -1:
2052 size = DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE
2163 size = DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE
2053
2164
2054 dst_buffer = ffi.new("char[]", size)
2165 dst_buffer = ffi.new("char[]", size)
2055 out_buffer = ffi.new("ZSTD_outBuffer *")
2166 out_buffer = ffi.new("ZSTD_outBuffer *")
2056 out_buffer.dst = dst_buffer
2167 out_buffer.dst = dst_buffer
2057 out_buffer.size = size
2168 out_buffer.size = size
2058 out_buffer.pos = 0
2169 out_buffer.pos = 0
2059
2170
2060 # read1() dictates that we can perform at most 1 call to underlying
2171 # read1() dictates that we can perform at most 1 call to underlying
2061 # stream to get input. However, we can't satisfy this restriction with
2172 # stream to get input. However, we can't satisfy this restriction with
2062 # decompression because not all input generates output. So we allow
2173 # decompression because not all input generates output. So we allow
2063 # multiple read(). But unlike read(), we stop once we have any output.
2174 # multiple read(). But unlike read(), we stop once we have any output.
2064 while not self._finished_input:
2175 while not self._finished_input:
2065 self._read_input()
2176 self._read_input()
2066 self._decompress_into_buffer(out_buffer)
2177 self._decompress_into_buffer(out_buffer)
2067
2178
2068 if out_buffer.pos:
2179 if out_buffer.pos:
2069 break
2180 break
2070
2181
2071 self._bytes_decompressed += out_buffer.pos
2182 self._bytes_decompressed += out_buffer.pos
2072 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2183 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2073
2184
2074 def readinto1(self, b):
2185 def readinto1(self, b):
2075 if self._closed:
2186 if self._closed:
2076 raise ValueError("stream is closed")
2187 raise ValueError("stream is closed")
2077
2188
2078 if self._finished_output:
2189 if self._finished_output:
2079 return 0
2190 return 0
2080
2191
2081 # TODO use writable=True once we require CFFI >= 1.12.
2192 # TODO use writable=True once we require CFFI >= 1.12.
2082 dest_buffer = ffi.from_buffer(b)
2193 dest_buffer = ffi.from_buffer(b)
2083 ffi.memmove(b, b"", 0)
2194 ffi.memmove(b, b"", 0)
2084
2195
2085 out_buffer = ffi.new("ZSTD_outBuffer *")
2196 out_buffer = ffi.new("ZSTD_outBuffer *")
2086 out_buffer.dst = dest_buffer
2197 out_buffer.dst = dest_buffer
2087 out_buffer.size = len(dest_buffer)
2198 out_buffer.size = len(dest_buffer)
2088 out_buffer.pos = 0
2199 out_buffer.pos = 0
2089
2200
2090 while not self._finished_input and not self._finished_output:
2201 while not self._finished_input and not self._finished_output:
2091 self._read_input()
2202 self._read_input()
2092 self._decompress_into_buffer(out_buffer)
2203 self._decompress_into_buffer(out_buffer)
2093
2204
2094 if out_buffer.pos:
2205 if out_buffer.pos:
2095 break
2206 break
2096
2207
2097 self._bytes_decompressed += out_buffer.pos
2208 self._bytes_decompressed += out_buffer.pos
2098 return out_buffer.pos
2209 return out_buffer.pos
2099
2210
2100 def seek(self, pos, whence=os.SEEK_SET):
2211 def seek(self, pos, whence=os.SEEK_SET):
2101 if self._closed:
2212 if self._closed:
2102 raise ValueError("stream is closed")
2213 raise ValueError("stream is closed")
2103
2214
2104 read_amount = 0
2215 read_amount = 0
2105
2216
2106 if whence == os.SEEK_SET:
2217 if whence == os.SEEK_SET:
2107 if pos < 0:
2218 if pos < 0:
2108 raise ValueError("cannot seek to negative position with SEEK_SET")
2219 raise ValueError(
2220 "cannot seek to negative position with SEEK_SET"
2221 )
2109
2222
2110 if pos < self._bytes_decompressed:
2223 if pos < self._bytes_decompressed:
2111 raise ValueError("cannot seek zstd decompression stream " "backwards")
2224 raise ValueError(
2225 "cannot seek zstd decompression stream " "backwards"
2226 )
2112
2227
2113 read_amount = pos - self._bytes_decompressed
2228 read_amount = pos - self._bytes_decompressed
2114
2229
2115 elif whence == os.SEEK_CUR:
2230 elif whence == os.SEEK_CUR:
2116 if pos < 0:
2231 if pos < 0:
2117 raise ValueError("cannot seek zstd decompression stream " "backwards")
2232 raise ValueError(
2233 "cannot seek zstd decompression stream " "backwards"
2234 )
2118
2235
2119 read_amount = pos
2236 read_amount = pos
2120 elif whence == os.SEEK_END:
2237 elif whence == os.SEEK_END:
2121 raise ValueError(
2238 raise ValueError(
2122 "zstd decompression streams cannot be seeked " "with SEEK_END"
2239 "zstd decompression streams cannot be seeked " "with SEEK_END"
2123 )
2240 )
2124
2241
2125 while read_amount:
2242 while read_amount:
2126 result = self.read(min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE))
2243 result = self.read(
2244 min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
2245 )
2127
2246
2128 if not result:
2247 if not result:
2129 break
2248 break
2130
2249
2131 read_amount -= len(result)
2250 read_amount -= len(result)
2132
2251
2133 return self._bytes_decompressed
2252 return self._bytes_decompressed
2134
2253
2135
2254
2136 class ZstdDecompressionWriter(object):
2255 class ZstdDecompressionWriter(object):
2137 def __init__(self, decompressor, writer, write_size, write_return_read):
2256 def __init__(self, decompressor, writer, write_size, write_return_read):
2138 decompressor._ensure_dctx()
2257 decompressor._ensure_dctx()
2139
2258
2140 self._decompressor = decompressor
2259 self._decompressor = decompressor
2141 self._writer = writer
2260 self._writer = writer
2142 self._write_size = write_size
2261 self._write_size = write_size
2143 self._write_return_read = bool(write_return_read)
2262 self._write_return_read = bool(write_return_read)
2144 self._entered = False
2263 self._entered = False
2145 self._closed = False
2264 self._closed = False
2146
2265
2147 def __enter__(self):
2266 def __enter__(self):
2148 if self._closed:
2267 if self._closed:
2149 raise ValueError("stream is closed")
2268 raise ValueError("stream is closed")
2150
2269
2151 if self._entered:
2270 if self._entered:
2152 raise ZstdError("cannot __enter__ multiple times")
2271 raise ZstdError("cannot __enter__ multiple times")
2153
2272
2154 self._entered = True
2273 self._entered = True
2155
2274
2156 return self
2275 return self
2157
2276
2158 def __exit__(self, exc_type, exc_value, exc_tb):
2277 def __exit__(self, exc_type, exc_value, exc_tb):
2159 self._entered = False
2278 self._entered = False
2160 self.close()
2279 self.close()
2161
2280
2162 def memory_size(self):
2281 def memory_size(self):
2163 return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx)
2282 return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx)
2164
2283
2165 def close(self):
2284 def close(self):
2166 if self._closed:
2285 if self._closed:
2167 return
2286 return
2168
2287
2169 try:
2288 try:
2170 self.flush()
2289 self.flush()
2171 finally:
2290 finally:
2172 self._closed = True
2291 self._closed = True
2173
2292
2174 f = getattr(self._writer, "close", None)
2293 f = getattr(self._writer, "close", None)
2175 if f:
2294 if f:
2176 f()
2295 f()
2177
2296
2178 @property
2297 @property
2179 def closed(self):
2298 def closed(self):
2180 return self._closed
2299 return self._closed
2181
2300
2182 def fileno(self):
2301 def fileno(self):
2183 f = getattr(self._writer, "fileno", None)
2302 f = getattr(self._writer, "fileno", None)
2184 if f:
2303 if f:
2185 return f()
2304 return f()
2186 else:
2305 else:
2187 raise OSError("fileno not available on underlying writer")
2306 raise OSError("fileno not available on underlying writer")
2188
2307
2189 def flush(self):
2308 def flush(self):
2190 if self._closed:
2309 if self._closed:
2191 raise ValueError("stream is closed")
2310 raise ValueError("stream is closed")
2192
2311
2193 f = getattr(self._writer, "flush", None)
2312 f = getattr(self._writer, "flush", None)
2194 if f:
2313 if f:
2195 return f()
2314 return f()
2196
2315
2197 def isatty(self):
2316 def isatty(self):
2198 return False
2317 return False
2199
2318
2200 def readable(self):
2319 def readable(self):
2201 return False
2320 return False
2202
2321
2203 def readline(self, size=-1):
2322 def readline(self, size=-1):
2204 raise io.UnsupportedOperation()
2323 raise io.UnsupportedOperation()
2205
2324
2206 def readlines(self, hint=-1):
2325 def readlines(self, hint=-1):
2207 raise io.UnsupportedOperation()
2326 raise io.UnsupportedOperation()
2208
2327
2209 def seek(self, offset, whence=None):
2328 def seek(self, offset, whence=None):
2210 raise io.UnsupportedOperation()
2329 raise io.UnsupportedOperation()
2211
2330
2212 def seekable(self):
2331 def seekable(self):
2213 return False
2332 return False
2214
2333
2215 def tell(self):
2334 def tell(self):
2216 raise io.UnsupportedOperation()
2335 raise io.UnsupportedOperation()
2217
2336
2218 def truncate(self, size=None):
2337 def truncate(self, size=None):
2219 raise io.UnsupportedOperation()
2338 raise io.UnsupportedOperation()
2220
2339
2221 def writable(self):
2340 def writable(self):
2222 return True
2341 return True
2223
2342
2224 def writelines(self, lines):
2343 def writelines(self, lines):
2225 raise io.UnsupportedOperation()
2344 raise io.UnsupportedOperation()
2226
2345
2227 def read(self, size=-1):
2346 def read(self, size=-1):
2228 raise io.UnsupportedOperation()
2347 raise io.UnsupportedOperation()
2229
2348
2230 def readall(self):
2349 def readall(self):
2231 raise io.UnsupportedOperation()
2350 raise io.UnsupportedOperation()
2232
2351
2233 def readinto(self, b):
2352 def readinto(self, b):
2234 raise io.UnsupportedOperation()
2353 raise io.UnsupportedOperation()
2235
2354
2236 def write(self, data):
2355 def write(self, data):
2237 if self._closed:
2356 if self._closed:
2238 raise ValueError("stream is closed")
2357 raise ValueError("stream is closed")
2239
2358
2240 total_write = 0
2359 total_write = 0
2241
2360
2242 in_buffer = ffi.new("ZSTD_inBuffer *")
2361 in_buffer = ffi.new("ZSTD_inBuffer *")
2243 out_buffer = ffi.new("ZSTD_outBuffer *")
2362 out_buffer = ffi.new("ZSTD_outBuffer *")
2244
2363
2245 data_buffer = ffi.from_buffer(data)
2364 data_buffer = ffi.from_buffer(data)
2246 in_buffer.src = data_buffer
2365 in_buffer.src = data_buffer
2247 in_buffer.size = len(data_buffer)
2366 in_buffer.size = len(data_buffer)
2248 in_buffer.pos = 0
2367 in_buffer.pos = 0
2249
2368
2250 dst_buffer = ffi.new("char[]", self._write_size)
2369 dst_buffer = ffi.new("char[]", self._write_size)
2251 out_buffer.dst = dst_buffer
2370 out_buffer.dst = dst_buffer
2252 out_buffer.size = len(dst_buffer)
2371 out_buffer.size = len(dst_buffer)
2253 out_buffer.pos = 0
2372 out_buffer.pos = 0
2254
2373
2255 dctx = self._decompressor._dctx
2374 dctx = self._decompressor._dctx
2256
2375
2257 while in_buffer.pos < in_buffer.size:
2376 while in_buffer.pos < in_buffer.size:
2258 zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer)
2377 zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer)
2259 if lib.ZSTD_isError(zresult):
2378 if lib.ZSTD_isError(zresult):
2260 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
2379 raise ZstdError(
2380 "zstd decompress error: %s" % _zstd_error(zresult)
2381 )
2261
2382
2262 if out_buffer.pos:
2383 if out_buffer.pos:
2263 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
2384 self._writer.write(
2385 ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2386 )
2264 total_write += out_buffer.pos
2387 total_write += out_buffer.pos
2265 out_buffer.pos = 0
2388 out_buffer.pos = 0
2266
2389
2267 if self._write_return_read:
2390 if self._write_return_read:
2268 return in_buffer.pos
2391 return in_buffer.pos
2269 else:
2392 else:
2270 return total_write
2393 return total_write
2271
2394
2272
2395
2273 class ZstdDecompressor(object):
2396 class ZstdDecompressor(object):
2274 def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1):
2397 def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1):
2275 self._dict_data = dict_data
2398 self._dict_data = dict_data
2276 self._max_window_size = max_window_size
2399 self._max_window_size = max_window_size
2277 self._format = format
2400 self._format = format
2278
2401
2279 dctx = lib.ZSTD_createDCtx()
2402 dctx = lib.ZSTD_createDCtx()
2280 if dctx == ffi.NULL:
2403 if dctx == ffi.NULL:
2281 raise MemoryError()
2404 raise MemoryError()
2282
2405
2283 self._dctx = dctx
2406 self._dctx = dctx
2284
2407
2285 # Defer setting up garbage collection until full state is loaded so
2408 # Defer setting up garbage collection until full state is loaded so
2286 # the memory size is more accurate.
2409 # the memory size is more accurate.
2287 try:
2410 try:
2288 self._ensure_dctx()
2411 self._ensure_dctx()
2289 finally:
2412 finally:
2290 self._dctx = ffi.gc(
2413 self._dctx = ffi.gc(
2291 dctx, lib.ZSTD_freeDCtx, size=lib.ZSTD_sizeof_DCtx(dctx)
2414 dctx, lib.ZSTD_freeDCtx, size=lib.ZSTD_sizeof_DCtx(dctx)
2292 )
2415 )
2293
2416
2294 def memory_size(self):
2417 def memory_size(self):
2295 return lib.ZSTD_sizeof_DCtx(self._dctx)
2418 return lib.ZSTD_sizeof_DCtx(self._dctx)
2296
2419
2297 def decompress(self, data, max_output_size=0):
2420 def decompress(self, data, max_output_size=0):
2298 self._ensure_dctx()
2421 self._ensure_dctx()
2299
2422
2300 data_buffer = ffi.from_buffer(data)
2423 data_buffer = ffi.from_buffer(data)
2301
2424
2302 output_size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
2425 output_size = lib.ZSTD_getFrameContentSize(
2426 data_buffer, len(data_buffer)
2427 )
2303
2428
2304 if output_size == lib.ZSTD_CONTENTSIZE_ERROR:
2429 if output_size == lib.ZSTD_CONTENTSIZE_ERROR:
2305 raise ZstdError("error determining content size from frame header")
2430 raise ZstdError("error determining content size from frame header")
2306 elif output_size == 0:
2431 elif output_size == 0:
2307 return b""
2432 return b""
2308 elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2433 elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2309 if not max_output_size:
2434 if not max_output_size:
2310 raise ZstdError("could not determine content size in frame header")
2435 raise ZstdError(
2436 "could not determine content size in frame header"
2437 )
2311
2438
2312 result_buffer = ffi.new("char[]", max_output_size)
2439 result_buffer = ffi.new("char[]", max_output_size)
2313 result_size = max_output_size
2440 result_size = max_output_size
2314 output_size = 0
2441 output_size = 0
2315 else:
2442 else:
2316 result_buffer = ffi.new("char[]", output_size)
2443 result_buffer = ffi.new("char[]", output_size)
2317 result_size = output_size
2444 result_size = output_size
2318
2445
2319 out_buffer = ffi.new("ZSTD_outBuffer *")
2446 out_buffer = ffi.new("ZSTD_outBuffer *")
2320 out_buffer.dst = result_buffer
2447 out_buffer.dst = result_buffer
2321 out_buffer.size = result_size
2448 out_buffer.size = result_size
2322 out_buffer.pos = 0
2449 out_buffer.pos = 0
2323
2450
2324 in_buffer = ffi.new("ZSTD_inBuffer *")
2451 in_buffer = ffi.new("ZSTD_inBuffer *")
2325 in_buffer.src = data_buffer
2452 in_buffer.src = data_buffer
2326 in_buffer.size = len(data_buffer)
2453 in_buffer.size = len(data_buffer)
2327 in_buffer.pos = 0
2454 in_buffer.pos = 0
2328
2455
2329 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2456 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2330 if lib.ZSTD_isError(zresult):
2457 if lib.ZSTD_isError(zresult):
2331 raise ZstdError("decompression error: %s" % _zstd_error(zresult))
2458 raise ZstdError("decompression error: %s" % _zstd_error(zresult))
2332 elif zresult:
2459 elif zresult:
2333 raise ZstdError("decompression error: did not decompress full frame")
2460 raise ZstdError(
2461 "decompression error: did not decompress full frame"
2462 )
2334 elif output_size and out_buffer.pos != output_size:
2463 elif output_size and out_buffer.pos != output_size:
2335 raise ZstdError(
2464 raise ZstdError(
2336 "decompression error: decompressed %d bytes; expected %d"
2465 "decompression error: decompressed %d bytes; expected %d"
2337 % (zresult, output_size)
2466 % (zresult, output_size)
2338 )
2467 )
2339
2468
2340 return ffi.buffer(result_buffer, out_buffer.pos)[:]
2469 return ffi.buffer(result_buffer, out_buffer.pos)[:]
2341
2470
2342 def stream_reader(
2471 def stream_reader(
2343 self,
2472 self,
2344 source,
2473 source,
2345 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2474 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2346 read_across_frames=False,
2475 read_across_frames=False,
2347 ):
2476 ):
2348 self._ensure_dctx()
2477 self._ensure_dctx()
2349 return ZstdDecompressionReader(self, source, read_size, read_across_frames)
2478 return ZstdDecompressionReader(
2479 self, source, read_size, read_across_frames
2480 )
2350
2481
2351 def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
2482 def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
2352 if write_size < 1:
2483 if write_size < 1:
2353 raise ValueError("write_size must be positive")
2484 raise ValueError("write_size must be positive")
2354
2485
2355 self._ensure_dctx()
2486 self._ensure_dctx()
2356 return ZstdDecompressionObj(self, write_size=write_size)
2487 return ZstdDecompressionObj(self, write_size=write_size)
2357
2488
2358 def read_to_iter(
2489 def read_to_iter(
2359 self,
2490 self,
2360 reader,
2491 reader,
2361 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2492 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2362 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2493 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2363 skip_bytes=0,
2494 skip_bytes=0,
2364 ):
2495 ):
2365 if skip_bytes >= read_size:
2496 if skip_bytes >= read_size:
2366 raise ValueError("skip_bytes must be smaller than read_size")
2497 raise ValueError("skip_bytes must be smaller than read_size")
2367
2498
2368 if hasattr(reader, "read"):
2499 if hasattr(reader, "read"):
2369 have_read = True
2500 have_read = True
2370 elif hasattr(reader, "__getitem__"):
2501 elif hasattr(reader, "__getitem__"):
2371 have_read = False
2502 have_read = False
2372 buffer_offset = 0
2503 buffer_offset = 0
2373 size = len(reader)
2504 size = len(reader)
2374 else:
2505 else:
2375 raise ValueError(
2506 raise ValueError(
2376 "must pass an object with a read() method or "
2507 "must pass an object with a read() method or "
2377 "conforms to buffer protocol"
2508 "conforms to buffer protocol"
2378 )
2509 )
2379
2510
2380 if skip_bytes:
2511 if skip_bytes:
2381 if have_read:
2512 if have_read:
2382 reader.read(skip_bytes)
2513 reader.read(skip_bytes)
2383 else:
2514 else:
2384 if skip_bytes > size:
2515 if skip_bytes > size:
2385 raise ValueError("skip_bytes larger than first input chunk")
2516 raise ValueError("skip_bytes larger than first input chunk")
2386
2517
2387 buffer_offset = skip_bytes
2518 buffer_offset = skip_bytes
2388
2519
2389 self._ensure_dctx()
2520 self._ensure_dctx()
2390
2521
2391 in_buffer = ffi.new("ZSTD_inBuffer *")
2522 in_buffer = ffi.new("ZSTD_inBuffer *")
2392 out_buffer = ffi.new("ZSTD_outBuffer *")
2523 out_buffer = ffi.new("ZSTD_outBuffer *")
2393
2524
2394 dst_buffer = ffi.new("char[]", write_size)
2525 dst_buffer = ffi.new("char[]", write_size)
2395 out_buffer.dst = dst_buffer
2526 out_buffer.dst = dst_buffer
2396 out_buffer.size = len(dst_buffer)
2527 out_buffer.size = len(dst_buffer)
2397 out_buffer.pos = 0
2528 out_buffer.pos = 0
2398
2529
2399 while True:
2530 while True:
2400 assert out_buffer.pos == 0
2531 assert out_buffer.pos == 0
2401
2532
2402 if have_read:
2533 if have_read:
2403 read_result = reader.read(read_size)
2534 read_result = reader.read(read_size)
2404 else:
2535 else:
2405 remaining = size - buffer_offset
2536 remaining = size - buffer_offset
2406 slice_size = min(remaining, read_size)
2537 slice_size = min(remaining, read_size)
2407 read_result = reader[buffer_offset : buffer_offset + slice_size]
2538 read_result = reader[buffer_offset : buffer_offset + slice_size]
2408 buffer_offset += slice_size
2539 buffer_offset += slice_size
2409
2540
2410 # No new input. Break out of read loop.
2541 # No new input. Break out of read loop.
2411 if not read_result:
2542 if not read_result:
2412 break
2543 break
2413
2544
2414 # Feed all read data into decompressor and emit output until
2545 # Feed all read data into decompressor and emit output until
2415 # exhausted.
2546 # exhausted.
2416 read_buffer = ffi.from_buffer(read_result)
2547 read_buffer = ffi.from_buffer(read_result)
2417 in_buffer.src = read_buffer
2548 in_buffer.src = read_buffer
2418 in_buffer.size = len(read_buffer)
2549 in_buffer.size = len(read_buffer)
2419 in_buffer.pos = 0
2550 in_buffer.pos = 0
2420
2551
2421 while in_buffer.pos < in_buffer.size:
2552 while in_buffer.pos < in_buffer.size:
2422 assert out_buffer.pos == 0
2553 assert out_buffer.pos == 0
2423
2554
2424 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2555 zresult = lib.ZSTD_decompressStream(
2556 self._dctx, out_buffer, in_buffer
2557 )
2425 if lib.ZSTD_isError(zresult):
2558 if lib.ZSTD_isError(zresult):
2426 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
2559 raise ZstdError(
2560 "zstd decompress error: %s" % _zstd_error(zresult)
2561 )
2427
2562
2428 if out_buffer.pos:
2563 if out_buffer.pos:
2429 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2564 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2430 out_buffer.pos = 0
2565 out_buffer.pos = 0
2431 yield data
2566 yield data
2432
2567
2433 if zresult == 0:
2568 if zresult == 0:
2434 return
2569 return
2435
2570
2436 # Repeat loop to collect more input data.
2571 # Repeat loop to collect more input data.
2437 continue
2572 continue
2438
2573
2439 # If we get here, input is exhausted.
2574 # If we get here, input is exhausted.
2440
2575
2441 read_from = read_to_iter
2576 read_from = read_to_iter
2442
2577
2443 def stream_writer(
2578 def stream_writer(
2444 self,
2579 self,
2445 writer,
2580 writer,
2446 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2581 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2447 write_return_read=False,
2582 write_return_read=False,
2448 ):
2583 ):
2449 if not hasattr(writer, "write"):
2584 if not hasattr(writer, "write"):
2450 raise ValueError("must pass an object with a write() method")
2585 raise ValueError("must pass an object with a write() method")
2451
2586
2452 return ZstdDecompressionWriter(self, writer, write_size, write_return_read)
2587 return ZstdDecompressionWriter(
2588 self, writer, write_size, write_return_read
2589 )
2453
2590
2454 write_to = stream_writer
2591 write_to = stream_writer
2455
2592
2456 def copy_stream(
2593 def copy_stream(
2457 self,
2594 self,
2458 ifh,
2595 ifh,
2459 ofh,
2596 ofh,
2460 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2597 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2461 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2598 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2462 ):
2599 ):
2463 if not hasattr(ifh, "read"):
2600 if not hasattr(ifh, "read"):
2464 raise ValueError("first argument must have a read() method")
2601 raise ValueError("first argument must have a read() method")
2465 if not hasattr(ofh, "write"):
2602 if not hasattr(ofh, "write"):
2466 raise ValueError("second argument must have a write() method")
2603 raise ValueError("second argument must have a write() method")
2467
2604
2468 self._ensure_dctx()
2605 self._ensure_dctx()
2469
2606
2470 in_buffer = ffi.new("ZSTD_inBuffer *")
2607 in_buffer = ffi.new("ZSTD_inBuffer *")
2471 out_buffer = ffi.new("ZSTD_outBuffer *")
2608 out_buffer = ffi.new("ZSTD_outBuffer *")
2472
2609
2473 dst_buffer = ffi.new("char[]", write_size)
2610 dst_buffer = ffi.new("char[]", write_size)
2474 out_buffer.dst = dst_buffer
2611 out_buffer.dst = dst_buffer
2475 out_buffer.size = write_size
2612 out_buffer.size = write_size
2476 out_buffer.pos = 0
2613 out_buffer.pos = 0
2477
2614
2478 total_read, total_write = 0, 0
2615 total_read, total_write = 0, 0
2479
2616
2480 # Read all available input.
2617 # Read all available input.
2481 while True:
2618 while True:
2482 data = ifh.read(read_size)
2619 data = ifh.read(read_size)
2483 if not data:
2620 if not data:
2484 break
2621 break
2485
2622
2486 data_buffer = ffi.from_buffer(data)
2623 data_buffer = ffi.from_buffer(data)
2487 total_read += len(data_buffer)
2624 total_read += len(data_buffer)
2488 in_buffer.src = data_buffer
2625 in_buffer.src = data_buffer
2489 in_buffer.size = len(data_buffer)
2626 in_buffer.size = len(data_buffer)
2490 in_buffer.pos = 0
2627 in_buffer.pos = 0
2491
2628
2492 # Flush all read data to output.
2629 # Flush all read data to output.
2493 while in_buffer.pos < in_buffer.size:
2630 while in_buffer.pos < in_buffer.size:
2494 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2631 zresult = lib.ZSTD_decompressStream(
2632 self._dctx, out_buffer, in_buffer
2633 )
2495 if lib.ZSTD_isError(zresult):
2634 if lib.ZSTD_isError(zresult):
2496 raise ZstdError(
2635 raise ZstdError(
2497 "zstd decompressor error: %s" % _zstd_error(zresult)
2636 "zstd decompressor error: %s" % _zstd_error(zresult)
2498 )
2637 )
2499
2638
2500 if out_buffer.pos:
2639 if out_buffer.pos:
2501 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
2640 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
2502 total_write += out_buffer.pos
2641 total_write += out_buffer.pos
2503 out_buffer.pos = 0
2642 out_buffer.pos = 0
2504
2643
2505 # Continue loop to keep reading.
2644 # Continue loop to keep reading.
2506
2645
2507 return total_read, total_write
2646 return total_read, total_write
2508
2647
2509 def decompress_content_dict_chain(self, frames):
2648 def decompress_content_dict_chain(self, frames):
2510 if not isinstance(frames, list):
2649 if not isinstance(frames, list):
2511 raise TypeError("argument must be a list")
2650 raise TypeError("argument must be a list")
2512
2651
2513 if not frames:
2652 if not frames:
2514 raise ValueError("empty input chain")
2653 raise ValueError("empty input chain")
2515
2654
2516 # First chunk should not be using a dictionary. We handle it specially.
2655 # First chunk should not be using a dictionary. We handle it specially.
2517 chunk = frames[0]
2656 chunk = frames[0]
2518 if not isinstance(chunk, bytes_type):
2657 if not isinstance(chunk, bytes_type):
2519 raise ValueError("chunk 0 must be bytes")
2658 raise ValueError("chunk 0 must be bytes")
2520
2659
2521 # All chunks should be zstd frames and should have content size set.
2660 # All chunks should be zstd frames and should have content size set.
2522 chunk_buffer = ffi.from_buffer(chunk)
2661 chunk_buffer = ffi.from_buffer(chunk)
2523 params = ffi.new("ZSTD_frameHeader *")
2662 params = ffi.new("ZSTD_frameHeader *")
2524 zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
2663 zresult = lib.ZSTD_getFrameHeader(
2664 params, chunk_buffer, len(chunk_buffer)
2665 )
2525 if lib.ZSTD_isError(zresult):
2666 if lib.ZSTD_isError(zresult):
2526 raise ValueError("chunk 0 is not a valid zstd frame")
2667 raise ValueError("chunk 0 is not a valid zstd frame")
2527 elif zresult:
2668 elif zresult:
2528 raise ValueError("chunk 0 is too small to contain a zstd frame")
2669 raise ValueError("chunk 0 is too small to contain a zstd frame")
2529
2670
2530 if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2671 if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2531 raise ValueError("chunk 0 missing content size in frame")
2672 raise ValueError("chunk 0 missing content size in frame")
2532
2673
2533 self._ensure_dctx(load_dict=False)
2674 self._ensure_dctx(load_dict=False)
2534
2675
2535 last_buffer = ffi.new("char[]", params.frameContentSize)
2676 last_buffer = ffi.new("char[]", params.frameContentSize)
2536
2677
2537 out_buffer = ffi.new("ZSTD_outBuffer *")
2678 out_buffer = ffi.new("ZSTD_outBuffer *")
2538 out_buffer.dst = last_buffer
2679 out_buffer.dst = last_buffer
2539 out_buffer.size = len(last_buffer)
2680 out_buffer.size = len(last_buffer)
2540 out_buffer.pos = 0
2681 out_buffer.pos = 0
2541
2682
2542 in_buffer = ffi.new("ZSTD_inBuffer *")
2683 in_buffer = ffi.new("ZSTD_inBuffer *")
2543 in_buffer.src = chunk_buffer
2684 in_buffer.src = chunk_buffer
2544 in_buffer.size = len(chunk_buffer)
2685 in_buffer.size = len(chunk_buffer)
2545 in_buffer.pos = 0
2686 in_buffer.pos = 0
2546
2687
2547 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2688 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2548 if lib.ZSTD_isError(zresult):
2689 if lib.ZSTD_isError(zresult):
2549 raise ZstdError("could not decompress chunk 0: %s" % _zstd_error(zresult))
2690 raise ZstdError(
2691 "could not decompress chunk 0: %s" % _zstd_error(zresult)
2692 )
2550 elif zresult:
2693 elif zresult:
2551 raise ZstdError("chunk 0 did not decompress full frame")
2694 raise ZstdError("chunk 0 did not decompress full frame")
2552
2695
2553 # Special case of chain length of 1
2696 # Special case of chain length of 1
2554 if len(frames) == 1:
2697 if len(frames) == 1:
2555 return ffi.buffer(last_buffer, len(last_buffer))[:]
2698 return ffi.buffer(last_buffer, len(last_buffer))[:]
2556
2699
2557 i = 1
2700 i = 1
2558 while i < len(frames):
2701 while i < len(frames):
2559 chunk = frames[i]
2702 chunk = frames[i]
2560 if not isinstance(chunk, bytes_type):
2703 if not isinstance(chunk, bytes_type):
2561 raise ValueError("chunk %d must be bytes" % i)
2704 raise ValueError("chunk %d must be bytes" % i)
2562
2705
2563 chunk_buffer = ffi.from_buffer(chunk)
2706 chunk_buffer = ffi.from_buffer(chunk)
2564 zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
2707 zresult = lib.ZSTD_getFrameHeader(
2708 params, chunk_buffer, len(chunk_buffer)
2709 )
2565 if lib.ZSTD_isError(zresult):
2710 if lib.ZSTD_isError(zresult):
2566 raise ValueError("chunk %d is not a valid zstd frame" % i)
2711 raise ValueError("chunk %d is not a valid zstd frame" % i)
2567 elif zresult:
2712 elif zresult:
2568 raise ValueError("chunk %d is too small to contain a zstd frame" % i)
2713 raise ValueError(
2714 "chunk %d is too small to contain a zstd frame" % i
2715 )
2569
2716
2570 if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2717 if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2571 raise ValueError("chunk %d missing content size in frame" % i)
2718 raise ValueError("chunk %d missing content size in frame" % i)
2572
2719
2573 dest_buffer = ffi.new("char[]", params.frameContentSize)
2720 dest_buffer = ffi.new("char[]", params.frameContentSize)
2574
2721
2575 out_buffer.dst = dest_buffer
2722 out_buffer.dst = dest_buffer
2576 out_buffer.size = len(dest_buffer)
2723 out_buffer.size = len(dest_buffer)
2577 out_buffer.pos = 0
2724 out_buffer.pos = 0
2578
2725
2579 in_buffer.src = chunk_buffer
2726 in_buffer.src = chunk_buffer
2580 in_buffer.size = len(chunk_buffer)
2727 in_buffer.size = len(chunk_buffer)
2581 in_buffer.pos = 0
2728 in_buffer.pos = 0
2582
2729
2583 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2730 zresult = lib.ZSTD_decompressStream(
2731 self._dctx, out_buffer, in_buffer
2732 )
2584 if lib.ZSTD_isError(zresult):
2733 if lib.ZSTD_isError(zresult):
2585 raise ZstdError(
2734 raise ZstdError(
2586 "could not decompress chunk %d: %s" % _zstd_error(zresult)
2735 "could not decompress chunk %d: %s" % _zstd_error(zresult)
2587 )
2736 )
2588 elif zresult:
2737 elif zresult:
2589 raise ZstdError("chunk %d did not decompress full frame" % i)
2738 raise ZstdError("chunk %d did not decompress full frame" % i)
2590
2739
2591 last_buffer = dest_buffer
2740 last_buffer = dest_buffer
2592 i += 1
2741 i += 1
2593
2742
2594 return ffi.buffer(last_buffer, len(last_buffer))[:]
2743 return ffi.buffer(last_buffer, len(last_buffer))[:]
2595
2744
2596 def _ensure_dctx(self, load_dict=True):
2745 def _ensure_dctx(self, load_dict=True):
2597 lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only)
2746 lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only)
2598
2747
2599 if self._max_window_size:
2748 if self._max_window_size:
2600 zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx, self._max_window_size)
2749 zresult = lib.ZSTD_DCtx_setMaxWindowSize(
2750 self._dctx, self._max_window_size
2751 )
2601 if lib.ZSTD_isError(zresult):
2752 if lib.ZSTD_isError(zresult):
2602 raise ZstdError(
2753 raise ZstdError(
2603 "unable to set max window size: %s" % _zstd_error(zresult)
2754 "unable to set max window size: %s" % _zstd_error(zresult)
2604 )
2755 )
2605
2756
2606 zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format)
2757 zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format)
2607 if lib.ZSTD_isError(zresult):
2758 if lib.ZSTD_isError(zresult):
2608 raise ZstdError("unable to set decoding format: %s" % _zstd_error(zresult))
2759 raise ZstdError(
2760 "unable to set decoding format: %s" % _zstd_error(zresult)
2761 )
2609
2762
2610 if self._dict_data and load_dict:
2763 if self._dict_data and load_dict:
2611 zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict)
2764 zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict)
2612 if lib.ZSTD_isError(zresult):
2765 if lib.ZSTD_isError(zresult):
2613 raise ZstdError(
2766 raise ZstdError(
2614 "unable to reference prepared dictionary: %s" % _zstd_error(zresult)
2767 "unable to reference prepared dictionary: %s"
2768 % _zstd_error(zresult)
2615 )
2769 )
@@ -1,5 +1,5 b''
1 #require black
1 #require black
2
2
3 $ cd $RUNTESTDIR/..
3 $ cd $RUNTESTDIR/..
4 $ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/** - "contrib/python-zstandard/**"'`
4 $ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/**'`
5
5
General Comments 0
You need to be logged in to leave comments. Login now