Show More
@@ -1,15 +1,14 b'' | |||||
1 | [tool.black] |
|
1 | [tool.black] | |
2 | line-length = 80 |
|
2 | line-length = 80 | |
3 | exclude = ''' |
|
3 | exclude = ''' | |
4 | build/ |
|
4 | build/ | |
5 | | wheelhouse/ |
|
5 | | wheelhouse/ | |
6 | | dist/ |
|
6 | | dist/ | |
7 | | packages/ |
|
7 | | packages/ | |
8 | | \.hg/ |
|
8 | | \.hg/ | |
9 | | \.mypy_cache/ |
|
9 | | \.mypy_cache/ | |
10 | | \.venv/ |
|
10 | | \.venv/ | |
11 | | mercurial/thirdparty/ |
|
11 | | mercurial/thirdparty/ | |
12 | | contrib/python-zstandard/ |
|
|||
13 | ''' |
|
12 | ''' | |
14 | skip-string-normalization = true |
|
13 | skip-string-normalization = true | |
15 | quiet = true |
|
14 | quiet = true |
@@ -1,14 +1,14 b'' | |||||
1 | [fix] |
|
1 | [fix] | |
2 | clang-format:command = clang-format --style file |
|
2 | clang-format:command = clang-format --style file | |
3 | clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist" |
|
3 | clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist" | |
4 |
|
4 | |||
5 | rustfmt:command = rustfmt +nightly |
|
5 | rustfmt:command = rustfmt +nightly | |
6 | rustfmt:pattern = set:**.rs |
|
6 | rustfmt:pattern = set:**.rs | |
7 |
|
7 | |||
8 | black:command = black --config=black.toml - |
|
8 | black:command = black --config=black.toml - | |
9 | black:pattern = set:**.py - mercurial/thirdparty/** |
|
9 | black:pattern = set:**.py - mercurial/thirdparty/** | |
10 |
|
10 | |||
11 | # Mercurial doesn't have any Go code, but if we did this is how we |
|
11 | # Mercurial doesn't have any Go code, but if we did this is how we | |
12 | # would configure `hg fix` for Go: |
|
12 | # would configure `hg fix` for Go: | |
13 | go:command = gofmt |
|
13 | go:command = gofmt | |
14 | go:pattern = set:**.go |
|
14 | go:pattern = set:**.go |
@@ -1,225 +1,228 b'' | |||||
1 | # Copyright (c) 2016-present, Gregory Szorc |
|
1 | # Copyright (c) 2016-present, Gregory Szorc | |
2 | # All rights reserved. |
|
2 | # All rights reserved. | |
3 | # |
|
3 | # | |
4 | # This software may be modified and distributed under the terms |
|
4 | # This software may be modified and distributed under the terms | |
5 | # of the BSD license. See the LICENSE file for details. |
|
5 | # of the BSD license. See the LICENSE file for details. | |
6 |
|
6 | |||
7 | from __future__ import absolute_import |
|
7 | from __future__ import absolute_import | |
8 |
|
8 | |||
9 | import cffi |
|
9 | import cffi | |
10 | import distutils.ccompiler |
|
10 | import distutils.ccompiler | |
11 | import os |
|
11 | import os | |
12 | import re |
|
12 | import re | |
13 | import subprocess |
|
13 | import subprocess | |
14 | import tempfile |
|
14 | import tempfile | |
15 |
|
15 | |||
16 |
|
16 | |||
17 | HERE = os.path.abspath(os.path.dirname(__file__)) |
|
17 | HERE = os.path.abspath(os.path.dirname(__file__)) | |
18 |
|
18 | |||
19 | SOURCES = [ |
|
19 | SOURCES = [ | |
20 | "zstd/%s" % p |
|
20 | "zstd/%s" % p | |
21 | for p in ( |
|
21 | for p in ( | |
22 | "common/debug.c", |
|
22 | "common/debug.c", | |
23 | "common/entropy_common.c", |
|
23 | "common/entropy_common.c", | |
24 | "common/error_private.c", |
|
24 | "common/error_private.c", | |
25 | "common/fse_decompress.c", |
|
25 | "common/fse_decompress.c", | |
26 | "common/pool.c", |
|
26 | "common/pool.c", | |
27 | "common/threading.c", |
|
27 | "common/threading.c", | |
28 | "common/xxhash.c", |
|
28 | "common/xxhash.c", | |
29 | "common/zstd_common.c", |
|
29 | "common/zstd_common.c", | |
30 | "compress/fse_compress.c", |
|
30 | "compress/fse_compress.c", | |
31 | "compress/hist.c", |
|
31 | "compress/hist.c", | |
32 | "compress/huf_compress.c", |
|
32 | "compress/huf_compress.c", | |
33 | "compress/zstd_compress.c", |
|
33 | "compress/zstd_compress.c", | |
34 | "compress/zstd_compress_literals.c", |
|
34 | "compress/zstd_compress_literals.c", | |
35 | "compress/zstd_compress_sequences.c", |
|
35 | "compress/zstd_compress_sequences.c", | |
36 | "compress/zstd_double_fast.c", |
|
36 | "compress/zstd_double_fast.c", | |
37 | "compress/zstd_fast.c", |
|
37 | "compress/zstd_fast.c", | |
38 | "compress/zstd_lazy.c", |
|
38 | "compress/zstd_lazy.c", | |
39 | "compress/zstd_ldm.c", |
|
39 | "compress/zstd_ldm.c", | |
40 | "compress/zstd_opt.c", |
|
40 | "compress/zstd_opt.c", | |
41 | "compress/zstdmt_compress.c", |
|
41 | "compress/zstdmt_compress.c", | |
42 | "decompress/huf_decompress.c", |
|
42 | "decompress/huf_decompress.c", | |
43 | "decompress/zstd_ddict.c", |
|
43 | "decompress/zstd_ddict.c", | |
44 | "decompress/zstd_decompress.c", |
|
44 | "decompress/zstd_decompress.c", | |
45 | "decompress/zstd_decompress_block.c", |
|
45 | "decompress/zstd_decompress_block.c", | |
46 | "dictBuilder/cover.c", |
|
46 | "dictBuilder/cover.c", | |
47 | "dictBuilder/fastcover.c", |
|
47 | "dictBuilder/fastcover.c", | |
48 | "dictBuilder/divsufsort.c", |
|
48 | "dictBuilder/divsufsort.c", | |
49 | "dictBuilder/zdict.c", |
|
49 | "dictBuilder/zdict.c", | |
50 | ) |
|
50 | ) | |
51 | ] |
|
51 | ] | |
52 |
|
52 | |||
53 | # Headers whose preprocessed output will be fed into cdef(). |
|
53 | # Headers whose preprocessed output will be fed into cdef(). | |
54 | HEADERS = [ |
|
54 | HEADERS = [ | |
55 | os.path.join(HERE, "zstd", *p) for p in (("zstd.h",), ("dictBuilder", "zdict.h"),) |
|
55 | os.path.join(HERE, "zstd", *p) | |
|
56 | for p in (("zstd.h",), ("dictBuilder", "zdict.h"),) | |||
56 | ] |
|
57 | ] | |
57 |
|
58 | |||
58 | INCLUDE_DIRS = [ |
|
59 | INCLUDE_DIRS = [ | |
59 | os.path.join(HERE, d) |
|
60 | os.path.join(HERE, d) | |
60 | for d in ( |
|
61 | for d in ( | |
61 | "zstd", |
|
62 | "zstd", | |
62 | "zstd/common", |
|
63 | "zstd/common", | |
63 | "zstd/compress", |
|
64 | "zstd/compress", | |
64 | "zstd/decompress", |
|
65 | "zstd/decompress", | |
65 | "zstd/dictBuilder", |
|
66 | "zstd/dictBuilder", | |
66 | ) |
|
67 | ) | |
67 | ] |
|
68 | ] | |
68 |
|
69 | |||
69 | # cffi can't parse some of the primitives in zstd.h. So we invoke the |
|
70 | # cffi can't parse some of the primitives in zstd.h. So we invoke the | |
70 | # preprocessor and feed its output into cffi. |
|
71 | # preprocessor and feed its output into cffi. | |
71 | compiler = distutils.ccompiler.new_compiler() |
|
72 | compiler = distutils.ccompiler.new_compiler() | |
72 |
|
73 | |||
73 | # Needed for MSVC. |
|
74 | # Needed for MSVC. | |
74 | if hasattr(compiler, "initialize"): |
|
75 | if hasattr(compiler, "initialize"): | |
75 | compiler.initialize() |
|
76 | compiler.initialize() | |
76 |
|
77 | |||
77 | # Distutils doesn't set compiler.preprocessor, so invoke the preprocessor |
|
78 | # Distutils doesn't set compiler.preprocessor, so invoke the preprocessor | |
78 | # manually. |
|
79 | # manually. | |
79 | if compiler.compiler_type == "unix": |
|
80 | if compiler.compiler_type == "unix": | |
80 | args = list(compiler.executables["compiler"]) |
|
81 | args = list(compiler.executables["compiler"]) | |
81 | args.extend( |
|
82 | args.extend( | |
82 | ["-E", "-DZSTD_STATIC_LINKING_ONLY", "-DZDICT_STATIC_LINKING_ONLY",] |
|
83 | ["-E", "-DZSTD_STATIC_LINKING_ONLY", "-DZDICT_STATIC_LINKING_ONLY",] | |
83 | ) |
|
84 | ) | |
84 | elif compiler.compiler_type == "msvc": |
|
85 | elif compiler.compiler_type == "msvc": | |
85 | args = [compiler.cc] |
|
86 | args = [compiler.cc] | |
86 | args.extend( |
|
87 | args.extend( | |
87 | ["/EP", "/DZSTD_STATIC_LINKING_ONLY", "/DZDICT_STATIC_LINKING_ONLY",] |
|
88 | ["/EP", "/DZSTD_STATIC_LINKING_ONLY", "/DZDICT_STATIC_LINKING_ONLY",] | |
88 | ) |
|
89 | ) | |
89 | else: |
|
90 | else: | |
90 | raise Exception("unsupported compiler type: %s" % compiler.compiler_type) |
|
91 | raise Exception("unsupported compiler type: %s" % compiler.compiler_type) | |
91 |
|
92 | |||
92 |
|
93 | |||
93 | def preprocess(path): |
|
94 | def preprocess(path): | |
94 | with open(path, "rb") as fh: |
|
95 | with open(path, "rb") as fh: | |
95 | lines = [] |
|
96 | lines = [] | |
96 | it = iter(fh) |
|
97 | it = iter(fh) | |
97 |
|
98 | |||
98 | for l in it: |
|
99 | for l in it: | |
99 | # zstd.h includes <stddef.h>, which is also included by cffi's |
|
100 | # zstd.h includes <stddef.h>, which is also included by cffi's | |
100 | # boilerplate. This can lead to duplicate declarations. So we strip |
|
101 | # boilerplate. This can lead to duplicate declarations. So we strip | |
101 | # this include from the preprocessor invocation. |
|
102 | # this include from the preprocessor invocation. | |
102 | # |
|
103 | # | |
103 | # The same thing happens for including zstd.h, so give it the same |

104 | # The same thing happens for including zstd.h, so give it the same | |
104 | # treatment. |
|
105 | # treatment. | |
105 | # |
|
106 | # | |
106 | # We define ZSTD_STATIC_LINKING_ONLY, which is redundant with the inline |
|
107 | # We define ZSTD_STATIC_LINKING_ONLY, which is redundant with the inline | |
107 | # #define in zstdmt_compress.h and results in a compiler warning. So drop |
|
108 | # #define in zstdmt_compress.h and results in a compiler warning. So drop | |
108 | # the inline #define. |
|
109 | # the inline #define. | |
109 | if l.startswith( |
|
110 | if l.startswith( | |
110 | ( |
|
111 | ( | |
111 | b"#include <stddef.h>", |
|
112 | b"#include <stddef.h>", | |
112 | b'#include "zstd.h"', |
|
113 | b'#include "zstd.h"', | |
113 | b"#define ZSTD_STATIC_LINKING_ONLY", |
|
114 | b"#define ZSTD_STATIC_LINKING_ONLY", | |
114 | ) |
|
115 | ) | |
115 | ): |
|
116 | ): | |
116 | continue |
|
117 | continue | |
117 |
|
118 | |||
118 | # The preprocessor environment on Windows doesn't define include |
|
119 | # The preprocessor environment on Windows doesn't define include | |
119 | # paths, so the #include of limits.h fails. We work around this |
|
120 | # paths, so the #include of limits.h fails. We work around this | |
120 | # by removing that import and defining INT_MAX ourselves. This is |
|
121 | # by removing that import and defining INT_MAX ourselves. This is | |
121 | # a bit hacky. But it gets the job done. |
|
122 | # a bit hacky. But it gets the job done. | |
122 | # TODO make limits.h work on Windows so we ensure INT_MAX is |
|
123 | # TODO make limits.h work on Windows so we ensure INT_MAX is | |
123 | # correct. |
|
124 | # correct. | |
124 | if l.startswith(b"#include <limits.h>"): |
|
125 | if l.startswith(b"#include <limits.h>"): | |
125 | l = b"#define INT_MAX 2147483647\n" |
|
126 | l = b"#define INT_MAX 2147483647\n" | |
126 |
|
127 | |||
127 | # ZSTDLIB_API may not be defined if we dropped zstd.h. It isn't |
|
128 | # ZSTDLIB_API may not be defined if we dropped zstd.h. It isn't | |
128 | # important so just filter it out. |
|
129 | # important so just filter it out. | |
129 | if l.startswith(b"ZSTDLIB_API"): |
|
130 | if l.startswith(b"ZSTDLIB_API"): | |
130 | l = l[len(b"ZSTDLIB_API ") :] |
|
131 | l = l[len(b"ZSTDLIB_API ") :] | |
131 |
|
132 | |||
132 | lines.append(l) |
|
133 | lines.append(l) | |
133 |
|
134 | |||
134 | fd, input_file = tempfile.mkstemp(suffix=".h") |
|
135 | fd, input_file = tempfile.mkstemp(suffix=".h") | |
135 | os.write(fd, b"".join(lines)) |
|
136 | os.write(fd, b"".join(lines)) | |
136 | os.close(fd) |
|
137 | os.close(fd) | |
137 |
|
138 | |||
138 | try: |
|
139 | try: | |
139 | env = dict(os.environ) |
|
140 | env = dict(os.environ) | |
140 | if getattr(compiler, "_paths", None): |
|
141 | if getattr(compiler, "_paths", None): | |
141 | env["PATH"] = compiler._paths |
|
142 | env["PATH"] = compiler._paths | |
142 | process = subprocess.Popen(args + [input_file], stdout=subprocess.PIPE, env=env) |
|
143 | process = subprocess.Popen( | |
|
144 | args + [input_file], stdout=subprocess.PIPE, env=env | |||
|
145 | ) | |||
143 | output = process.communicate()[0] |
|
146 | output = process.communicate()[0] | |
144 | ret = process.poll() |
|
147 | ret = process.poll() | |
145 | if ret: |
|
148 | if ret: | |
146 | raise Exception("preprocessor exited with error") |
|
149 | raise Exception("preprocessor exited with error") | |
147 |
|
150 | |||
148 | return output |
|
151 | return output | |
149 | finally: |
|
152 | finally: | |
150 | os.unlink(input_file) |
|
153 | os.unlink(input_file) | |
151 |
|
154 | |||
152 |
|
155 | |||
153 | def normalize_output(output): |
|
156 | def normalize_output(output): | |
154 | lines = [] |
|
157 | lines = [] | |
155 | for line in output.splitlines(): |
|
158 | for line in output.splitlines(): | |
156 | # CFFI's parser doesn't like __attribute__ on UNIX compilers. |
|
159 | # CFFI's parser doesn't like __attribute__ on UNIX compilers. | |
157 | if line.startswith(b'__attribute__ ((visibility ("default"))) '): |
|
160 | if line.startswith(b'__attribute__ ((visibility ("default"))) '): | |
158 | line = line[len(b'__attribute__ ((visibility ("default"))) ') :] |
|
161 | line = line[len(b'__attribute__ ((visibility ("default"))) ') :] | |
159 |
|
162 | |||
160 | if line.startswith(b"__attribute__((deprecated("): |
|
163 | if line.startswith(b"__attribute__((deprecated("): | |
161 | continue |
|
164 | continue | |
162 | elif b"__declspec(deprecated(" in line: |
|
165 | elif b"__declspec(deprecated(" in line: | |
163 | continue |
|
166 | continue | |
164 |
|
167 | |||
165 | lines.append(line) |
|
168 | lines.append(line) | |
166 |
|
169 | |||
167 | return b"\n".join(lines) |
|
170 | return b"\n".join(lines) | |
168 |
|
171 | |||
169 |
|
172 | |||
170 | ffi = cffi.FFI() |
|
173 | ffi = cffi.FFI() | |
171 | # zstd.h uses a possibly undefined MIN(). Define it until |

174 | # zstd.h uses a possibly undefined MIN(). Define it until | |
172 | # https://github.com/facebook/zstd/issues/976 is fixed. |
|
175 | # https://github.com/facebook/zstd/issues/976 is fixed. | |
173 | # *_DISABLE_DEPRECATE_WARNINGS prevents the compiler from emitting a warning |
|
176 | # *_DISABLE_DEPRECATE_WARNINGS prevents the compiler from emitting a warning | |
174 | # when cffi uses the function. Since we statically link against zstd, even |
|
177 | # when cffi uses the function. Since we statically link against zstd, even | |
175 | # if we use the deprecated functions it shouldn't be a huge problem. |
|
178 | # if we use the deprecated functions it shouldn't be a huge problem. | |
176 | ffi.set_source( |
|
179 | ffi.set_source( | |
177 | "_zstd_cffi", |
|
180 | "_zstd_cffi", | |
178 | """ |
|
181 | """ | |
179 | #define MIN(a,b) ((a)<(b) ? (a) : (b)) |
|
182 | #define MIN(a,b) ((a)<(b) ? (a) : (b)) | |
180 | #define ZSTD_STATIC_LINKING_ONLY |
|
183 | #define ZSTD_STATIC_LINKING_ONLY | |
181 | #include <zstd.h> |
|
184 | #include <zstd.h> | |
182 | #define ZDICT_STATIC_LINKING_ONLY |
|
185 | #define ZDICT_STATIC_LINKING_ONLY | |
183 | #define ZDICT_DISABLE_DEPRECATE_WARNINGS |
|
186 | #define ZDICT_DISABLE_DEPRECATE_WARNINGS | |
184 | #include <zdict.h> |
|
187 | #include <zdict.h> | |
185 | """, |
|
188 | """, | |
186 | sources=SOURCES, |
|
189 | sources=SOURCES, | |
187 | include_dirs=INCLUDE_DIRS, |
|
190 | include_dirs=INCLUDE_DIRS, | |
188 | extra_compile_args=["-DZSTD_MULTITHREAD"], |
|
191 | extra_compile_args=["-DZSTD_MULTITHREAD"], | |
189 | ) |
|
192 | ) | |
190 |
|
193 | |||
191 | DEFINE = re.compile(b"^\\#define ([a-zA-Z0-9_]+) ") |
|
194 | DEFINE = re.compile(b"^\\#define ([a-zA-Z0-9_]+) ") | |
192 |
|
195 | |||
193 | sources = [] |
|
196 | sources = [] | |
194 |
|
197 | |||
195 | # Feed normalized preprocessor output for headers into the cdef parser. |
|
198 | # Feed normalized preprocessor output for headers into the cdef parser. | |
196 | for header in HEADERS: |
|
199 | for header in HEADERS: | |
197 | preprocessed = preprocess(header) |
|
200 | preprocessed = preprocess(header) | |
198 | sources.append(normalize_output(preprocessed)) |
|
201 | sources.append(normalize_output(preprocessed)) | |
199 |
|
202 | |||
200 | # #define's are effectively erased as part of going through preprocessor. |
|
203 | # #define's are effectively erased as part of going through preprocessor. | |
201 | # So perform a manual pass to re-add those to the cdef source. |
|
204 | # So perform a manual pass to re-add those to the cdef source. | |
202 | with open(header, "rb") as fh: |
|
205 | with open(header, "rb") as fh: | |
203 | for line in fh: |
|
206 | for line in fh: | |
204 | line = line.strip() |
|
207 | line = line.strip() | |
205 | m = DEFINE.match(line) |
|
208 | m = DEFINE.match(line) | |
206 | if not m: |
|
209 | if not m: | |
207 | continue |
|
210 | continue | |
208 |
|
211 | |||
209 | if m.group(1) == b"ZSTD_STATIC_LINKING_ONLY": |
|
212 | if m.group(1) == b"ZSTD_STATIC_LINKING_ONLY": | |
210 | continue |
|
213 | continue | |
211 |
|
214 | |||
212 | # The parser doesn't like some constants with complex values. |
|
215 | # The parser doesn't like some constants with complex values. | |
213 | if m.group(1) in (b"ZSTD_LIB_VERSION", b"ZSTD_VERSION_STRING"): |
|
216 | if m.group(1) in (b"ZSTD_LIB_VERSION", b"ZSTD_VERSION_STRING"): | |
214 | continue |
|
217 | continue | |
215 |
|
218 | |||
216 | # The ... is magic syntax by the cdef parser to resolve the |
|
219 | # The ... is magic syntax by the cdef parser to resolve the | |
217 | # value at compile time. |
|
220 | # value at compile time. | |
218 | sources.append(m.group(0) + b" ...") |
|
221 | sources.append(m.group(0) + b" ...") | |
219 |
|
222 | |||
220 | cdeflines = b"\n".join(sources).splitlines() |
|
223 | cdeflines = b"\n".join(sources).splitlines() | |
221 | cdeflines = [l for l in cdeflines if l.strip()] |
|
224 | cdeflines = [l for l in cdeflines if l.strip()] | |
222 | ffi.cdef(b"\n".join(cdeflines).decode("latin1")) |
|
225 | ffi.cdef(b"\n".join(cdeflines).decode("latin1")) | |
223 |
|
226 | |||
224 | if __name__ == "__main__": |
|
227 | if __name__ == "__main__": | |
225 | ffi.compile() |
|
228 | ffi.compile() |
@@ -1,118 +1,120 b'' | |||||
1 | #!/usr/bin/env python |
|
1 | #!/usr/bin/env python | |
2 | # Copyright (c) 2016-present, Gregory Szorc |
|
2 | # Copyright (c) 2016-present, Gregory Szorc | |
3 | # All rights reserved. |
|
3 | # All rights reserved. | |
4 | # |
|
4 | # | |
5 | # This software may be modified and distributed under the terms |
|
5 | # This software may be modified and distributed under the terms | |
6 | # of the BSD license. See the LICENSE file for details. |
|
6 | # of the BSD license. See the LICENSE file for details. | |
7 |
|
7 | |||
8 | from __future__ import print_function |
|
8 | from __future__ import print_function | |
9 |
|
9 | |||
10 | from distutils.version import LooseVersion |
|
10 | from distutils.version import LooseVersion | |
11 | import os |
|
11 | import os | |
12 | import sys |
|
12 | import sys | |
13 | from setuptools import setup |
|
13 | from setuptools import setup | |
14 |
|
14 | |||
15 | # Need change in 1.10 for ffi.from_buffer() to handle all buffer types |
|
15 | # Need change in 1.10 for ffi.from_buffer() to handle all buffer types | |
16 | # (like memoryview). |
|
16 | # (like memoryview). | |
17 | # Need feature in 1.11 for ffi.gc() to declare size of objects so we avoid |
|
17 | # Need feature in 1.11 for ffi.gc() to declare size of objects so we avoid | |
18 | # garbage collection pitfalls. |
|
18 | # garbage collection pitfalls. | |
19 | MINIMUM_CFFI_VERSION = "1.11" |
|
19 | MINIMUM_CFFI_VERSION = "1.11" | |
20 |
|
20 | |||
21 | try: |
|
21 | try: | |
22 | import cffi |
|
22 | import cffi | |
23 |
|
23 | |||
24 | # PyPy (and possibly other distros) have CFFI distributed as part of |
|
24 | # PyPy (and possibly other distros) have CFFI distributed as part of | |
25 | # them. The install_requires for CFFI below won't work. We need to sniff |
|
25 | # them. The install_requires for CFFI below won't work. We need to sniff | |
26 | # out the CFFI version here and reject CFFI if it is too old. |
|
26 | # out the CFFI version here and reject CFFI if it is too old. | |
27 | cffi_version = LooseVersion(cffi.__version__) |
|
27 | cffi_version = LooseVersion(cffi.__version__) | |
28 | if cffi_version < LooseVersion(MINIMUM_CFFI_VERSION): |
|
28 | if cffi_version < LooseVersion(MINIMUM_CFFI_VERSION): | |
29 | print( |
|
29 | print( | |
30 | "CFFI 1.11 or newer required (%s found); " |
|
30 | "CFFI 1.11 or newer required (%s found); " | |
31 | "not building CFFI backend" % cffi_version, |
|
31 | "not building CFFI backend" % cffi_version, | |
32 | file=sys.stderr, |
|
32 | file=sys.stderr, | |
33 | ) |
|
33 | ) | |
34 | cffi = None |
|
34 | cffi = None | |
35 |
|
35 | |||
36 | except ImportError: |
|
36 | except ImportError: | |
37 | cffi = None |
|
37 | cffi = None | |
38 |
|
38 | |||
39 | import setup_zstd |
|
39 | import setup_zstd | |
40 |
|
40 | |||
41 | SUPPORT_LEGACY = False |
|
41 | SUPPORT_LEGACY = False | |
42 | SYSTEM_ZSTD = False |
|
42 | SYSTEM_ZSTD = False | |
43 | WARNINGS_AS_ERRORS = False |
|
43 | WARNINGS_AS_ERRORS = False | |
44 |
|
44 | |||
45 | if os.environ.get("ZSTD_WARNINGS_AS_ERRORS", ""): |
|
45 | if os.environ.get("ZSTD_WARNINGS_AS_ERRORS", ""): | |
46 | WARNINGS_AS_ERRORS = True |
|
46 | WARNINGS_AS_ERRORS = True | |
47 |
|
47 | |||
48 | if "--legacy" in sys.argv: |
|
48 | if "--legacy" in sys.argv: | |
49 | SUPPORT_LEGACY = True |
|
49 | SUPPORT_LEGACY = True | |
50 | sys.argv.remove("--legacy") |
|
50 | sys.argv.remove("--legacy") | |
51 |
|
51 | |||
52 | if "--system-zstd" in sys.argv: |
|
52 | if "--system-zstd" in sys.argv: | |
53 | SYSTEM_ZSTD = True |
|
53 | SYSTEM_ZSTD = True | |
54 | sys.argv.remove("--system-zstd") |
|
54 | sys.argv.remove("--system-zstd") | |
55 |
|
55 | |||
56 | if "--warnings-as-errors" in sys.argv: |
|
56 | if "--warnings-as-errors" in sys.argv: | |
57 | WARNINGS_AS_ERRORS = True |
|
57 | WARNINGS_AS_ERRORS = True | |
58 | sys.argv.remove("--warning-as-errors") |
|
58 | sys.argv.remove("--warning-as-errors") | |
59 |
|
59 | |||
60 | # Code for obtaining the Extension instance is in its own module to |
|
60 | # Code for obtaining the Extension instance is in its own module to | |
61 | # facilitate reuse in other projects. |
|
61 | # facilitate reuse in other projects. | |
62 | extensions = [ |
|
62 | extensions = [ | |
63 | setup_zstd.get_c_extension( |
|
63 | setup_zstd.get_c_extension( | |
64 | name="zstd", |
|
64 | name="zstd", | |
65 | support_legacy=SUPPORT_LEGACY, |
|
65 | support_legacy=SUPPORT_LEGACY, | |
66 | system_zstd=SYSTEM_ZSTD, |
|
66 | system_zstd=SYSTEM_ZSTD, | |
67 | warnings_as_errors=WARNINGS_AS_ERRORS, |
|
67 | warnings_as_errors=WARNINGS_AS_ERRORS, | |
68 | ), |
|
68 | ), | |
69 | ] |
|
69 | ] | |
70 |
|
70 | |||
71 | install_requires = [] |
|
71 | install_requires = [] | |
72 |
|
72 | |||
73 | if cffi: |
|
73 | if cffi: | |
74 | import make_cffi |
|
74 | import make_cffi | |
75 |
|
75 | |||
76 | extensions.append(make_cffi.ffi.distutils_extension()) |
|
76 | extensions.append(make_cffi.ffi.distutils_extension()) | |
77 | install_requires.append("cffi>=%s" % MINIMUM_CFFI_VERSION) |
|
77 | install_requires.append("cffi>=%s" % MINIMUM_CFFI_VERSION) | |
78 |
|
78 | |||
79 | version = None |
|
79 | version = None | |
80 |
|
80 | |||
81 | with open("c-ext/python-zstandard.h", "r") as fh: |
|
81 | with open("c-ext/python-zstandard.h", "r") as fh: | |
82 | for line in fh: |
|
82 | for line in fh: | |
83 | if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"): |
|
83 | if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"): | |
84 | continue |
|
84 | continue | |
85 |
|
85 | |||
86 | version = line.split()[2][1:-1] |
|
86 | version = line.split()[2][1:-1] | |
87 | break |
|
87 | break | |
88 |
|
88 | |||
89 | if not version: |
|
89 | if not version: | |
90 | raise Exception("could not resolve package version; " "this should never happen") |
|
90 | raise Exception( | |
|
91 | "could not resolve package version; " "this should never happen" | |||
|
92 | ) | |||
91 |
|
93 | |||
92 | setup( |
|
94 | setup( | |
93 | name="zstandard", |
|
95 | name="zstandard", | |
94 | version=version, |
|
96 | version=version, | |
95 | description="Zstandard bindings for Python", |
|
97 | description="Zstandard bindings for Python", | |
96 | long_description=open("README.rst", "r").read(), |
|
98 | long_description=open("README.rst", "r").read(), | |
97 | url="https://github.com/indygreg/python-zstandard", |
|
99 | url="https://github.com/indygreg/python-zstandard", | |
98 | author="Gregory Szorc", |
|
100 | author="Gregory Szorc", | |
99 | author_email="gregory.szorc@gmail.com", |
|
101 | author_email="gregory.szorc@gmail.com", | |
100 | license="BSD", |
|
102 | license="BSD", | |
101 | classifiers=[ |
|
103 | classifiers=[ | |
102 | "Development Status :: 4 - Beta", |
|
104 | "Development Status :: 4 - Beta", | |
103 | "Intended Audience :: Developers", |
|
105 | "Intended Audience :: Developers", | |
104 | "License :: OSI Approved :: BSD License", |
|
106 | "License :: OSI Approved :: BSD License", | |
105 | "Programming Language :: C", |
|
107 | "Programming Language :: C", | |
106 | "Programming Language :: Python :: 2.7", |
|
108 | "Programming Language :: Python :: 2.7", | |
107 | "Programming Language :: Python :: 3.5", |
|
109 | "Programming Language :: Python :: 3.5", | |
108 | "Programming Language :: Python :: 3.6", |
|
110 | "Programming Language :: Python :: 3.6", | |
109 | "Programming Language :: Python :: 3.7", |
|
111 | "Programming Language :: Python :: 3.7", | |
110 | "Programming Language :: Python :: 3.8", |
|
112 | "Programming Language :: Python :: 3.8", | |
111 | ], |
|
113 | ], | |
112 | keywords="zstandard zstd compression", |
|
114 | keywords="zstandard zstd compression", | |
113 | packages=["zstandard"], |
|
115 | packages=["zstandard"], | |
114 | ext_modules=extensions, |
|
116 | ext_modules=extensions, | |
115 | test_suite="tests", |
|
117 | test_suite="tests", | |
116 | install_requires=install_requires, |
|
118 | install_requires=install_requires, | |
117 | tests_require=["hypothesis"], |
|
119 | tests_require=["hypothesis"], | |
118 | ) |
|
120 | ) |
@@ -1,206 +1,210 b'' | |||||
1 | # Copyright (c) 2016-present, Gregory Szorc |
|
1 | # Copyright (c) 2016-present, Gregory Szorc | |
2 | # All rights reserved. |
|
2 | # All rights reserved. | |
3 | # |
|
3 | # | |
4 | # This software may be modified and distributed under the terms |
|
4 | # This software may be modified and distributed under the terms | |
5 | # of the BSD license. See the LICENSE file for details. |
|
5 | # of the BSD license. See the LICENSE file for details. | |
6 |
|
6 | |||
7 | import distutils.ccompiler |
|
7 | import distutils.ccompiler | |
8 | import os |
|
8 | import os | |
9 |
|
9 | |||
10 | from distutils.extension import Extension |
|
10 | from distutils.extension import Extension | |
11 |
|
11 | |||
12 |
|
12 | |||
13 | zstd_sources = [ |
|
13 | zstd_sources = [ | |
14 | "zstd/%s" % p |
|
14 | "zstd/%s" % p | |
15 | for p in ( |
|
15 | for p in ( | |
16 | "common/debug.c", |
|
16 | "common/debug.c", | |
17 | "common/entropy_common.c", |
|
17 | "common/entropy_common.c", | |
18 | "common/error_private.c", |
|
18 | "common/error_private.c", | |
19 | "common/fse_decompress.c", |
|
19 | "common/fse_decompress.c", | |
20 | "common/pool.c", |
|
20 | "common/pool.c", | |
21 | "common/threading.c", |
|
21 | "common/threading.c", | |
22 | "common/xxhash.c", |
|
22 | "common/xxhash.c", | |
23 | "common/zstd_common.c", |
|
23 | "common/zstd_common.c", | |
24 | "compress/fse_compress.c", |
|
24 | "compress/fse_compress.c", | |
25 | "compress/hist.c", |
|
25 | "compress/hist.c", | |
26 | "compress/huf_compress.c", |
|
26 | "compress/huf_compress.c", | |
27 | "compress/zstd_compress_literals.c", |
|
27 | "compress/zstd_compress_literals.c", | |
28 | "compress/zstd_compress_sequences.c", |
|
28 | "compress/zstd_compress_sequences.c", | |
29 | "compress/zstd_compress.c", |
|
29 | "compress/zstd_compress.c", | |
30 | "compress/zstd_double_fast.c", |
|
30 | "compress/zstd_double_fast.c", | |
31 | "compress/zstd_fast.c", |
|
31 | "compress/zstd_fast.c", | |
32 | "compress/zstd_lazy.c", |
|
32 | "compress/zstd_lazy.c", | |
33 | "compress/zstd_ldm.c", |
|
33 | "compress/zstd_ldm.c", | |
34 | "compress/zstd_opt.c", |
|
34 | "compress/zstd_opt.c", | |
35 | "compress/zstdmt_compress.c", |
|
35 | "compress/zstdmt_compress.c", | |
36 | "decompress/huf_decompress.c", |
|
36 | "decompress/huf_decompress.c", | |
37 | "decompress/zstd_ddict.c", |
|
37 | "decompress/zstd_ddict.c", | |
38 | "decompress/zstd_decompress.c", |
|
38 | "decompress/zstd_decompress.c", | |
39 | "decompress/zstd_decompress_block.c", |
|
39 | "decompress/zstd_decompress_block.c", | |
40 | "dictBuilder/cover.c", |
|
40 | "dictBuilder/cover.c", | |
41 | "dictBuilder/divsufsort.c", |
|
41 | "dictBuilder/divsufsort.c", | |
42 | "dictBuilder/fastcover.c", |
|
42 | "dictBuilder/fastcover.c", | |
43 | "dictBuilder/zdict.c", |
|
43 | "dictBuilder/zdict.c", | |
44 | ) |
|
44 | ) | |
45 | ] |
|
45 | ] | |
46 |
|
46 | |||
47 | zstd_sources_legacy = [ |
|
47 | zstd_sources_legacy = [ | |
48 | "zstd/%s" % p |
|
48 | "zstd/%s" % p | |
49 | for p in ( |
|
49 | for p in ( | |
50 | "deprecated/zbuff_common.c", |
|
50 | "deprecated/zbuff_common.c", | |
51 | "deprecated/zbuff_compress.c", |
|
51 | "deprecated/zbuff_compress.c", | |
52 | "deprecated/zbuff_decompress.c", |
|
52 | "deprecated/zbuff_decompress.c", | |
53 | "legacy/zstd_v01.c", |
|
53 | "legacy/zstd_v01.c", | |
54 | "legacy/zstd_v02.c", |
|
54 | "legacy/zstd_v02.c", | |
55 | "legacy/zstd_v03.c", |
|
55 | "legacy/zstd_v03.c", | |
56 | "legacy/zstd_v04.c", |
|
56 | "legacy/zstd_v04.c", | |
57 | "legacy/zstd_v05.c", |
|
57 | "legacy/zstd_v05.c", | |
58 | "legacy/zstd_v06.c", |
|
58 | "legacy/zstd_v06.c", | |
59 | "legacy/zstd_v07.c", |
|
59 | "legacy/zstd_v07.c", | |
60 | ) |
|
60 | ) | |
61 | ] |
|
61 | ] | |
62 |
|
62 | |||
63 | zstd_includes = [ |
|
63 | zstd_includes = [ | |
64 | "zstd", |
|
64 | "zstd", | |
65 | "zstd/common", |
|
65 | "zstd/common", | |
66 | "zstd/compress", |
|
66 | "zstd/compress", | |
67 | "zstd/decompress", |
|
67 | "zstd/decompress", | |
68 | "zstd/dictBuilder", |
|
68 | "zstd/dictBuilder", | |
69 | ] |
|
69 | ] | |
70 |
|
70 | |||
71 | zstd_includes_legacy = [ |
|
71 | zstd_includes_legacy = [ | |
72 | "zstd/deprecated", |
|
72 | "zstd/deprecated", | |
73 | "zstd/legacy", |
|
73 | "zstd/legacy", | |
74 | ] |
|
74 | ] | |
75 |
|
75 | |||
76 | ext_includes = [ |
|
76 | ext_includes = [ | |
77 | "c-ext", |
|
77 | "c-ext", | |
78 | "zstd/common", |
|
78 | "zstd/common", | |
79 | ] |
|
79 | ] | |
80 |
|
80 | |||
81 | ext_sources = [ |
|
81 | ext_sources = [ | |
82 | "zstd/common/error_private.c", |
|
82 | "zstd/common/error_private.c", | |
83 | "zstd/common/pool.c", |
|
83 | "zstd/common/pool.c", | |
84 | "zstd/common/threading.c", |
|
84 | "zstd/common/threading.c", | |
85 | "zstd/common/zstd_common.c", |
|
85 | "zstd/common/zstd_common.c", | |
86 | "zstd.c", |
|
86 | "zstd.c", | |
87 | "c-ext/bufferutil.c", |
|
87 | "c-ext/bufferutil.c", | |
88 | "c-ext/compressiondict.c", |
|
88 | "c-ext/compressiondict.c", | |
89 | "c-ext/compressobj.c", |
|
89 | "c-ext/compressobj.c", | |
90 | "c-ext/compressor.c", |
|
90 | "c-ext/compressor.c", | |
91 | "c-ext/compressoriterator.c", |
|
91 | "c-ext/compressoriterator.c", | |
92 | "c-ext/compressionchunker.c", |
|
92 | "c-ext/compressionchunker.c", | |
93 | "c-ext/compressionparams.c", |
|
93 | "c-ext/compressionparams.c", | |
94 | "c-ext/compressionreader.c", |
|
94 | "c-ext/compressionreader.c", | |
95 | "c-ext/compressionwriter.c", |
|
95 | "c-ext/compressionwriter.c", | |
96 | "c-ext/constants.c", |
|
96 | "c-ext/constants.c", | |
97 | "c-ext/decompressobj.c", |
|
97 | "c-ext/decompressobj.c", | |
98 | "c-ext/decompressor.c", |
|
98 | "c-ext/decompressor.c", | |
99 | "c-ext/decompressoriterator.c", |
|
99 | "c-ext/decompressoriterator.c", | |
100 | "c-ext/decompressionreader.c", |
|
100 | "c-ext/decompressionreader.c", | |
101 | "c-ext/decompressionwriter.c", |
|
101 | "c-ext/decompressionwriter.c", | |
102 | "c-ext/frameparams.c", |
|
102 | "c-ext/frameparams.c", | |
103 | ] |
|
103 | ] | |
104 |
|
104 | |||
105 | zstd_depends = [ |
|
105 | zstd_depends = [ | |
106 | "c-ext/python-zstandard.h", |
|
106 | "c-ext/python-zstandard.h", | |
107 | ] |
|
107 | ] | |
108 |
|
108 | |||
109 |
|
109 | |||
110 | def get_c_extension( |
|
110 | def get_c_extension( | |
111 | support_legacy=False, |
|
111 | support_legacy=False, | |
112 | system_zstd=False, |
|
112 | system_zstd=False, | |
113 | name="zstd", |
|
113 | name="zstd", | |
114 | warnings_as_errors=False, |
|
114 | warnings_as_errors=False, | |
115 | root=None, |
|
115 | root=None, | |
116 | ): |
|
116 | ): | |
117 | """Obtain a distutils.extension.Extension for the C extension. |
|
117 | """Obtain a distutils.extension.Extension for the C extension. | |
118 |
|
118 | |||
119 | ``support_legacy`` controls whether to compile in legacy zstd format support. |
|
119 | ``support_legacy`` controls whether to compile in legacy zstd format support. | |
120 |
|
120 | |||
121 | ``system_zstd`` controls whether to compile against the system zstd library. |
|
121 | ``system_zstd`` controls whether to compile against the system zstd library. | |
122 | For this to work, the system zstd library and headers must match what |
|
122 | For this to work, the system zstd library and headers must match what | |
123 | python-zstandard is coded against exactly. |
|
123 | python-zstandard is coded against exactly. | |
124 |
|
124 | |||
125 | ``name`` is the module name of the C extension to produce. |
|
125 | ``name`` is the module name of the C extension to produce. | |
126 |
|
126 | |||
127 | ``warnings_as_errors`` controls whether compiler warnings are turned into |
|
127 | ``warnings_as_errors`` controls whether compiler warnings are turned into | |
128 | compiler errors. |
|
128 | compiler errors. | |
129 |
|
129 | |||
130 | ``root`` defines a root path that source should be computed as relative |
|
130 | ``root`` defines a root path that source should be computed as relative | |
131 | to. This should be the directory with the main ``setup.py`` that is |
|
131 | to. This should be the directory with the main ``setup.py`` that is | |
132 | being invoked. If not defined, paths will be relative to this file. |
|
132 | being invoked. If not defined, paths will be relative to this file. | |
133 | """ |
|
133 | """ | |
134 | actual_root = os.path.abspath(os.path.dirname(__file__)) |
|
134 | actual_root = os.path.abspath(os.path.dirname(__file__)) | |
135 | root = root or actual_root |
|
135 | root = root or actual_root | |
136 |
|
136 | |||
137 | sources = set([os.path.join(actual_root, p) for p in ext_sources]) |
|
137 | sources = set([os.path.join(actual_root, p) for p in ext_sources]) | |
138 | if not system_zstd: |
|
138 | if not system_zstd: | |
139 | sources.update([os.path.join(actual_root, p) for p in zstd_sources]) |
|
139 | sources.update([os.path.join(actual_root, p) for p in zstd_sources]) | |
140 | if support_legacy: |
|
140 | if support_legacy: | |
141 | sources.update([os.path.join(actual_root, p) for p in zstd_sources_legacy]) |
|
141 | sources.update( | |
|
142 | [os.path.join(actual_root, p) for p in zstd_sources_legacy] | |||
|
143 | ) | |||
142 | sources = list(sources) |
|
144 | sources = list(sources) | |
143 |
|
145 | |||
144 | include_dirs = set([os.path.join(actual_root, d) for d in ext_includes]) |
|
146 | include_dirs = set([os.path.join(actual_root, d) for d in ext_includes]) | |
145 | if not system_zstd: |
|
147 | if not system_zstd: | |
146 | include_dirs.update([os.path.join(actual_root, d) for d in zstd_includes]) |
|
148 | include_dirs.update( | |
|
149 | [os.path.join(actual_root, d) for d in zstd_includes] | |||
|
150 | ) | |||
147 | if support_legacy: |
|
151 | if support_legacy: | |
148 | include_dirs.update( |
|
152 | include_dirs.update( | |
149 | [os.path.join(actual_root, d) for d in zstd_includes_legacy] |
|
153 | [os.path.join(actual_root, d) for d in zstd_includes_legacy] | |
150 | ) |
|
154 | ) | |
151 | include_dirs = list(include_dirs) |
|
155 | include_dirs = list(include_dirs) | |
152 |
|
156 | |||
153 | depends = [os.path.join(actual_root, p) for p in zstd_depends] |
|
157 | depends = [os.path.join(actual_root, p) for p in zstd_depends] | |
154 |
|
158 | |||
155 | compiler = distutils.ccompiler.new_compiler() |
|
159 | compiler = distutils.ccompiler.new_compiler() | |
156 |
|
160 | |||
157 | # Needed for MSVC. |
|
161 | # Needed for MSVC. | |
158 | if hasattr(compiler, "initialize"): |
|
162 | if hasattr(compiler, "initialize"): | |
159 | compiler.initialize() |
|
163 | compiler.initialize() | |
160 |
|
164 | |||
161 | if compiler.compiler_type == "unix": |
|
165 | if compiler.compiler_type == "unix": | |
162 | compiler_type = "unix" |
|
166 | compiler_type = "unix" | |
163 | elif compiler.compiler_type == "msvc": |
|
167 | elif compiler.compiler_type == "msvc": | |
164 | compiler_type = "msvc" |
|
168 | compiler_type = "msvc" | |
165 | elif compiler.compiler_type == "mingw32": |
|
169 | elif compiler.compiler_type == "mingw32": | |
166 | compiler_type = "mingw32" |
|
170 | compiler_type = "mingw32" | |
167 | else: |
|
171 | else: | |
168 | raise Exception("unhandled compiler type: %s" % compiler.compiler_type) |
|
172 | raise Exception("unhandled compiler type: %s" % compiler.compiler_type) | |
169 |
|
173 | |||
170 | extra_args = ["-DZSTD_MULTITHREAD"] |
|
174 | extra_args = ["-DZSTD_MULTITHREAD"] | |
171 |
|
175 | |||
172 | if not system_zstd: |
|
176 | if not system_zstd: | |
173 | extra_args.append("-DZSTDLIB_VISIBILITY=") |
|
177 | extra_args.append("-DZSTDLIB_VISIBILITY=") | |
174 | extra_args.append("-DZDICTLIB_VISIBILITY=") |
|
178 | extra_args.append("-DZDICTLIB_VISIBILITY=") | |
175 | extra_args.append("-DZSTDERRORLIB_VISIBILITY=") |
|
179 | extra_args.append("-DZSTDERRORLIB_VISIBILITY=") | |
176 |
|
180 | |||
177 | if compiler_type == "unix": |
|
181 | if compiler_type == "unix": | |
178 | extra_args.append("-fvisibility=hidden") |
|
182 | extra_args.append("-fvisibility=hidden") | |
179 |
|
183 | |||
180 | if not system_zstd and support_legacy: |
|
184 | if not system_zstd and support_legacy: | |
181 | extra_args.append("-DZSTD_LEGACY_SUPPORT=1") |
|
185 | extra_args.append("-DZSTD_LEGACY_SUPPORT=1") | |
182 |
|
186 | |||
183 | if warnings_as_errors: |
|
187 | if warnings_as_errors: | |
184 | if compiler_type in ("unix", "mingw32"): |
|
188 | if compiler_type in ("unix", "mingw32"): | |
185 | extra_args.append("-Werror") |
|
189 | extra_args.append("-Werror") | |
186 | elif compiler_type == "msvc": |
|
190 | elif compiler_type == "msvc": | |
187 | extra_args.append("/WX") |
|
191 | extra_args.append("/WX") | |
188 | else: |
|
192 | else: | |
189 | assert False |
|
193 | assert False | |
190 |
|
194 | |||
191 | libraries = ["zstd"] if system_zstd else [] |
|
195 | libraries = ["zstd"] if system_zstd else [] | |
192 |
|
196 | |||
193 | # Python 3.7 doesn't like absolute paths. So normalize to relative. |
|
197 | # Python 3.7 doesn't like absolute paths. So normalize to relative. | |
194 | sources = [os.path.relpath(p, root) for p in sources] |
|
198 | sources = [os.path.relpath(p, root) for p in sources] | |
195 | include_dirs = [os.path.relpath(p, root) for p in include_dirs] |
|
199 | include_dirs = [os.path.relpath(p, root) for p in include_dirs] | |
196 | depends = [os.path.relpath(p, root) for p in depends] |
|
200 | depends = [os.path.relpath(p, root) for p in depends] | |
197 |
|
201 | |||
198 | # TODO compile with optimizations. |
|
202 | # TODO compile with optimizations. | |
199 | return Extension( |
|
203 | return Extension( | |
200 | name, |
|
204 | name, | |
201 | sources, |
|
205 | sources, | |
202 | include_dirs=include_dirs, |
|
206 | include_dirs=include_dirs, | |
203 | depends=depends, |
|
207 | depends=depends, | |
204 | extra_compile_args=extra_args, |
|
208 | extra_compile_args=extra_args, | |
205 | libraries=libraries, |
|
209 | libraries=libraries, | |
206 | ) |
|
210 | ) |
@@ -1,197 +1,203 b'' | |||||
1 | import imp |
|
1 | import imp | |
2 | import inspect |
|
2 | import inspect | |
3 | import io |
|
3 | import io | |
4 | import os |
|
4 | import os | |
5 | import types |
|
5 | import types | |
6 | import unittest |
|
6 | import unittest | |
7 |
|
7 | |||
8 | try: |
|
8 | try: | |
9 | import hypothesis |
|
9 | import hypothesis | |
10 | except ImportError: |
|
10 | except ImportError: | |
11 | hypothesis = None |
|
11 | hypothesis = None | |
12 |
|
12 | |||
13 |
|
13 | |||
14 | class TestCase(unittest.TestCase): |
|
14 | class TestCase(unittest.TestCase): | |
15 | if not getattr(unittest.TestCase, "assertRaisesRegex", False): |
|
15 | if not getattr(unittest.TestCase, "assertRaisesRegex", False): | |
16 | assertRaisesRegex = unittest.TestCase.assertRaisesRegexp |
|
16 | assertRaisesRegex = unittest.TestCase.assertRaisesRegexp | |
17 |
|
17 | |||
18 |
|
18 | |||
19 | def make_cffi(cls): |
|
19 | def make_cffi(cls): | |
20 | """Decorator to add CFFI versions of each test method.""" |
|
20 | """Decorator to add CFFI versions of each test method.""" | |
21 |
|
21 | |||
22 | # The module containing this class definition should |
|
22 | # The module containing this class definition should | |
23 | # `import zstandard as zstd`. Otherwise things may blow up. |
|
23 | # `import zstandard as zstd`. Otherwise things may blow up. | |
24 | mod = inspect.getmodule(cls) |
|
24 | mod = inspect.getmodule(cls) | |
25 | if not hasattr(mod, "zstd"): |
|
25 | if not hasattr(mod, "zstd"): | |
26 | raise Exception('test module does not contain "zstd" symbol') |
|
26 | raise Exception('test module does not contain "zstd" symbol') | |
27 |
|
27 | |||
28 | if not hasattr(mod.zstd, "backend"): |
|
28 | if not hasattr(mod.zstd, "backend"): | |
29 | raise Exception( |
|
29 | raise Exception( | |
30 | 'zstd symbol does not have "backend" attribute; did ' |
|
30 | 'zstd symbol does not have "backend" attribute; did ' | |
31 | "you `import zstandard as zstd`?" |
|
31 | "you `import zstandard as zstd`?" | |
32 | ) |
|
32 | ) | |
33 |
|
33 | |||
34 | # If `import zstandard` already chose the cffi backend, there is nothing |
|
34 | # If `import zstandard` already chose the cffi backend, there is nothing | |
35 | # for us to do: we only add the cffi variation if the default backend |
|
35 | # for us to do: we only add the cffi variation if the default backend | |
36 | # is the C extension. |
|
36 | # is the C extension. | |
37 | if mod.zstd.backend == "cffi": |
|
37 | if mod.zstd.backend == "cffi": | |
38 | return cls |
|
38 | return cls | |
39 |
|
39 | |||
40 | old_env = dict(os.environ) |
|
40 | old_env = dict(os.environ) | |
41 | os.environ["PYTHON_ZSTANDARD_IMPORT_POLICY"] = "cffi" |
|
41 | os.environ["PYTHON_ZSTANDARD_IMPORT_POLICY"] = "cffi" | |
42 | try: |
|
42 | try: | |
43 | try: |
|
43 | try: | |
44 | mod_info = imp.find_module("zstandard") |
|
44 | mod_info = imp.find_module("zstandard") | |
45 | mod = imp.load_module("zstandard_cffi", *mod_info) |
|
45 | mod = imp.load_module("zstandard_cffi", *mod_info) | |
46 | except ImportError: |
|
46 | except ImportError: | |
47 | return cls |
|
47 | return cls | |
48 | finally: |
|
48 | finally: | |
49 | os.environ.clear() |
|
49 | os.environ.clear() | |
50 | os.environ.update(old_env) |
|
50 | os.environ.update(old_env) | |
51 |
|
51 | |||
52 | if mod.backend != "cffi": |
|
52 | if mod.backend != "cffi": | |
53 | raise Exception("got the zstandard %s backend instead of cffi" % mod.backend) |
|
53 | raise Exception( | |
|
54 | "got the zstandard %s backend instead of cffi" % mod.backend | |||
|
55 | ) | |||
54 |
|
56 | |||
55 | # If CFFI version is available, dynamically construct test methods |
|
57 | # If CFFI version is available, dynamically construct test methods | |
56 | # that use it. |
|
58 | # that use it. | |
57 |
|
59 | |||
58 | for attr in dir(cls): |
|
60 | for attr in dir(cls): | |
59 | fn = getattr(cls, attr) |
|
61 | fn = getattr(cls, attr) | |
60 | if not inspect.ismethod(fn) and not inspect.isfunction(fn): |
|
62 | if not inspect.ismethod(fn) and not inspect.isfunction(fn): | |
61 | continue |
|
63 | continue | |
62 |
|
64 | |||
63 | if not fn.__name__.startswith("test_"): |
|
65 | if not fn.__name__.startswith("test_"): | |
64 | continue |
|
66 | continue | |
65 |
|
67 | |||
66 | name = "%s_cffi" % fn.__name__ |
|
68 | name = "%s_cffi" % fn.__name__ | |
67 |
|
69 | |||
68 | # Replace the "zstd" symbol with the CFFI module instance. Then copy |
|
70 | # Replace the "zstd" symbol with the CFFI module instance. Then copy | |
69 | # the function object and install it in a new attribute. |
|
71 | # the function object and install it in a new attribute. | |
70 | if isinstance(fn, types.FunctionType): |
|
72 | if isinstance(fn, types.FunctionType): | |
71 | globs = dict(fn.__globals__) |
|
73 | globs = dict(fn.__globals__) | |
72 | globs["zstd"] = mod |
|
74 | globs["zstd"] = mod | |
73 | new_fn = types.FunctionType( |
|
75 | new_fn = types.FunctionType( | |
74 | fn.__code__, globs, name, fn.__defaults__, fn.__closure__ |
|
76 | fn.__code__, globs, name, fn.__defaults__, fn.__closure__ | |
75 | ) |
|
77 | ) | |
76 | new_method = new_fn |
|
78 | new_method = new_fn | |
77 | else: |
|
79 | else: | |
78 | globs = dict(fn.__func__.func_globals) |
|
80 | globs = dict(fn.__func__.func_globals) | |
79 | globs["zstd"] = mod |
|
81 | globs["zstd"] = mod | |
80 | new_fn = types.FunctionType( |
|
82 | new_fn = types.FunctionType( | |
81 | fn.__func__.func_code, |
|
83 | fn.__func__.func_code, | |
82 | globs, |
|
84 | globs, | |
83 | name, |
|
85 | name, | |
84 | fn.__func__.func_defaults, |
|
86 | fn.__func__.func_defaults, | |
85 | fn.__func__.func_closure, |
|
87 | fn.__func__.func_closure, | |
86 | ) |
|
88 | ) | |
87 |
new_method = types.UnboundMethodType( |
|
89 | new_method = types.UnboundMethodType( | |
|
90 | new_fn, fn.im_self, fn.im_class | |||
|
91 | ) | |||
88 |
|
92 | |||
89 | setattr(cls, name, new_method) |
|
93 | setattr(cls, name, new_method) | |
90 |
|
94 | |||
91 | return cls |
|
95 | return cls | |
92 |
|
96 | |||
93 |
|
97 | |||
94 | class NonClosingBytesIO(io.BytesIO): |
|
98 | class NonClosingBytesIO(io.BytesIO): | |
95 | """BytesIO that saves the underlying buffer on close(). |
|
99 | """BytesIO that saves the underlying buffer on close(). | |
96 |
|
100 | |||
97 | This allows us to access written data after close(). |
|
101 | This allows us to access written data after close(). | |
98 | """ |
|
102 | """ | |
99 |
|
103 | |||
100 | def __init__(self, *args, **kwargs): |
|
104 | def __init__(self, *args, **kwargs): | |
101 | super(NonClosingBytesIO, self).__init__(*args, **kwargs) |
|
105 | super(NonClosingBytesIO, self).__init__(*args, **kwargs) | |
102 | self._saved_buffer = None |
|
106 | self._saved_buffer = None | |
103 |
|
107 | |||
104 | def close(self): |
|
108 | def close(self): | |
105 | self._saved_buffer = self.getvalue() |
|
109 | self._saved_buffer = self.getvalue() | |
106 | return super(NonClosingBytesIO, self).close() |
|
110 | return super(NonClosingBytesIO, self).close() | |
107 |
|
111 | |||
108 | def getvalue(self): |
|
112 | def getvalue(self): | |
109 | if self.closed: |
|
113 | if self.closed: | |
110 | return self._saved_buffer |
|
114 | return self._saved_buffer | |
111 | else: |
|
115 | else: | |
112 | return super(NonClosingBytesIO, self).getvalue() |
|
116 | return super(NonClosingBytesIO, self).getvalue() | |
113 |
|
117 | |||
114 |
|
118 | |||
115 | class OpCountingBytesIO(NonClosingBytesIO): |
|
119 | class OpCountingBytesIO(NonClosingBytesIO): | |
116 | def __init__(self, *args, **kwargs): |
|
120 | def __init__(self, *args, **kwargs): | |
117 | self._flush_count = 0 |
|
121 | self._flush_count = 0 | |
118 | self._read_count = 0 |
|
122 | self._read_count = 0 | |
119 | self._write_count = 0 |
|
123 | self._write_count = 0 | |
120 | return super(OpCountingBytesIO, self).__init__(*args, **kwargs) |
|
124 | return super(OpCountingBytesIO, self).__init__(*args, **kwargs) | |
121 |
|
125 | |||
122 | def flush(self): |
|
126 | def flush(self): | |
123 | self._flush_count += 1 |
|
127 | self._flush_count += 1 | |
124 | return super(OpCountingBytesIO, self).flush() |
|
128 | return super(OpCountingBytesIO, self).flush() | |
125 |
|
129 | |||
126 | def read(self, *args): |
|
130 | def read(self, *args): | |
127 | self._read_count += 1 |
|
131 | self._read_count += 1 | |
128 | return super(OpCountingBytesIO, self).read(*args) |
|
132 | return super(OpCountingBytesIO, self).read(*args) | |
129 |
|
133 | |||
130 | def write(self, data): |
|
134 | def write(self, data): | |
131 | self._write_count += 1 |
|
135 | self._write_count += 1 | |
132 | return super(OpCountingBytesIO, self).write(data) |
|
136 | return super(OpCountingBytesIO, self).write(data) | |
133 |
|
137 | |||
134 |
|
138 | |||
135 | _source_files = [] |
|
139 | _source_files = [] | |
136 |
|
140 | |||
137 |
|
141 | |||
138 | def random_input_data(): |
|
142 | def random_input_data(): | |
139 | """Obtain the raw content of source files. |
|
143 | """Obtain the raw content of source files. | |
140 |
|
144 | |||
141 | This is used for generating "random" data to feed into fuzzing, since it is |
|
145 | This is used for generating "random" data to feed into fuzzing, since it is | |
142 | faster than random content generation. |
|
146 | faster than random content generation. | |
143 | """ |
|
147 | """ | |
144 | if _source_files: |
|
148 | if _source_files: | |
145 | return _source_files |
|
149 | return _source_files | |
146 |
|
150 | |||
147 | for root, dirs, files in os.walk(os.path.dirname(__file__)): |
|
151 | for root, dirs, files in os.walk(os.path.dirname(__file__)): | |
148 | dirs[:] = list(sorted(dirs)) |
|
152 | dirs[:] = list(sorted(dirs)) | |
149 | for f in sorted(files): |
|
153 | for f in sorted(files): | |
150 | try: |
|
154 | try: | |
151 | with open(os.path.join(root, f), "rb") as fh: |
|
155 | with open(os.path.join(root, f), "rb") as fh: | |
152 | data = fh.read() |
|
156 | data = fh.read() | |
153 | if data: |
|
157 | if data: | |
154 | _source_files.append(data) |
|
158 | _source_files.append(data) | |
155 | except OSError: |
|
159 | except OSError: | |
156 | pass |
|
160 | pass | |
157 |
|
161 | |||
158 | # Also add some actual random data. |
|
162 | # Also add some actual random data. | |
159 | _source_files.append(os.urandom(100)) |
|
163 | _source_files.append(os.urandom(100)) | |
160 | _source_files.append(os.urandom(1000)) |
|
164 | _source_files.append(os.urandom(1000)) | |
161 | _source_files.append(os.urandom(10000)) |
|
165 | _source_files.append(os.urandom(10000)) | |
162 | _source_files.append(os.urandom(100000)) |
|
166 | _source_files.append(os.urandom(100000)) | |
163 | _source_files.append(os.urandom(1000000)) |
|
167 | _source_files.append(os.urandom(1000000)) | |
164 |
|
168 | |||
165 | return _source_files |
|
169 | return _source_files | |
166 |
|
170 | |||
167 |
|
171 | |||
168 | def generate_samples(): |
|
172 | def generate_samples(): | |
169 | inputs = [ |
|
173 | inputs = [ | |
170 | b"foo", |
|
174 | b"foo", | |
171 | b"bar", |
|
175 | b"bar", | |
172 | b"abcdef", |
|
176 | b"abcdef", | |
173 | b"sometext", |
|
177 | b"sometext", | |
174 | b"baz", |
|
178 | b"baz", | |
175 | ] |
|
179 | ] | |
176 |
|
180 | |||
177 | samples = [] |
|
181 | samples = [] | |
178 |
|
182 | |||
179 | for i in range(128): |
|
183 | for i in range(128): | |
180 | samples.append(inputs[i % 5]) |
|
184 | samples.append(inputs[i % 5]) | |
181 | samples.append(inputs[i % 5] * (i + 3)) |
|
185 | samples.append(inputs[i % 5] * (i + 3)) | |
182 | samples.append(inputs[-(i % 5)] * (i + 2)) |
|
186 | samples.append(inputs[-(i % 5)] * (i + 2)) | |
183 |
|
187 | |||
184 | return samples |
|
188 | return samples | |
185 |
|
189 | |||
186 |
|
190 | |||
187 | if hypothesis: |
|
191 | if hypothesis: | |
188 | default_settings = hypothesis.settings(deadline=10000) |
|
192 | default_settings = hypothesis.settings(deadline=10000) | |
189 | hypothesis.settings.register_profile("default", default_settings) |
|
193 | hypothesis.settings.register_profile("default", default_settings) | |
190 |
|
194 | |||
191 | ci_settings = hypothesis.settings(deadline=20000, max_examples=1000) |
|
195 | ci_settings = hypothesis.settings(deadline=20000, max_examples=1000) | |
192 | hypothesis.settings.register_profile("ci", ci_settings) |
|
196 | hypothesis.settings.register_profile("ci", ci_settings) | |
193 |
|
197 | |||
194 | expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) |
|
198 | expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) | |
195 | hypothesis.settings.register_profile("expensive", expensive_settings) |
|
199 | hypothesis.settings.register_profile("expensive", expensive_settings) | |
196 |
|
200 | |||
197 |
hypothesis.settings.load_profile( |
|
201 | hypothesis.settings.load_profile( | |
|
202 | os.environ.get("HYPOTHESIS_PROFILE", "default") | |||
|
203 | ) |
@@ -1,146 +1,153 b'' | |||||
1 | import struct |
|
1 | import struct | |
2 | import unittest |
|
2 | import unittest | |
3 |
|
3 | |||
4 | import zstandard as zstd |
|
4 | import zstandard as zstd | |
5 |
|
5 | |||
6 | from .common import TestCase |
|
6 | from .common import TestCase | |
7 |
|
7 | |||
8 | ss = struct.Struct("=QQ") |
|
8 | ss = struct.Struct("=QQ") | |
9 |
|
9 | |||
10 |
|
10 | |||
11 | class TestBufferWithSegments(TestCase): |
|
11 | class TestBufferWithSegments(TestCase): | |
12 | def test_arguments(self): |
|
12 | def test_arguments(self): | |
13 | if not hasattr(zstd, "BufferWithSegments"): |
|
13 | if not hasattr(zstd, "BufferWithSegments"): | |
14 | self.skipTest("BufferWithSegments not available") |
|
14 | self.skipTest("BufferWithSegments not available") | |
15 |
|
15 | |||
16 | with self.assertRaises(TypeError): |
|
16 | with self.assertRaises(TypeError): | |
17 | zstd.BufferWithSegments() |
|
17 | zstd.BufferWithSegments() | |
18 |
|
18 | |||
19 | with self.assertRaises(TypeError): |
|
19 | with self.assertRaises(TypeError): | |
20 | zstd.BufferWithSegments(b"foo") |
|
20 | zstd.BufferWithSegments(b"foo") | |
21 |
|
21 | |||
22 | # Segments data should be a multiple of 16. |
|
22 | # Segments data should be a multiple of 16. | |
23 | with self.assertRaisesRegex( |
|
23 | with self.assertRaisesRegex( | |
24 | ValueError, "segments array size is not a multiple of 16" |
|
24 | ValueError, "segments array size is not a multiple of 16" | |
25 | ): |
|
25 | ): | |
26 | zstd.BufferWithSegments(b"foo", b"\x00\x00") |
|
26 | zstd.BufferWithSegments(b"foo", b"\x00\x00") | |
27 |
|
27 | |||
28 | def test_invalid_offset(self): |
|
28 | def test_invalid_offset(self): | |
29 | if not hasattr(zstd, "BufferWithSegments"): |
|
29 | if not hasattr(zstd, "BufferWithSegments"): | |
30 | self.skipTest("BufferWithSegments not available") |
|
30 | self.skipTest("BufferWithSegments not available") | |
31 |
|
31 | |||
32 | with self.assertRaisesRegex( |
|
32 | with self.assertRaisesRegex( | |
33 | ValueError, "offset within segments array references memory" |
|
33 | ValueError, "offset within segments array references memory" | |
34 | ): |
|
34 | ): | |
35 | zstd.BufferWithSegments(b"foo", ss.pack(0, 4)) |
|
35 | zstd.BufferWithSegments(b"foo", ss.pack(0, 4)) | |
36 |
|
36 | |||
37 | def test_invalid_getitem(self): |
|
37 | def test_invalid_getitem(self): | |
38 | if not hasattr(zstd, "BufferWithSegments"): |
|
38 | if not hasattr(zstd, "BufferWithSegments"): | |
39 | self.skipTest("BufferWithSegments not available") |
|
39 | self.skipTest("BufferWithSegments not available") | |
40 |
|
40 | |||
41 | b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) |
|
41 | b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) | |
42 |
|
42 | |||
43 | with self.assertRaisesRegex(IndexError, "offset must be non-negative"): |
|
43 | with self.assertRaisesRegex(IndexError, "offset must be non-negative"): | |
44 | test = b[-10] |
|
44 | test = b[-10] | |
45 |
|
45 | |||
46 | with self.assertRaisesRegex(IndexError, "offset must be less than 1"): |
|
46 | with self.assertRaisesRegex(IndexError, "offset must be less than 1"): | |
47 | test = b[1] |
|
47 | test = b[1] | |
48 |
|
48 | |||
49 | with self.assertRaisesRegex(IndexError, "offset must be less than 1"): |
|
49 | with self.assertRaisesRegex(IndexError, "offset must be less than 1"): | |
50 | test = b[2] |
|
50 | test = b[2] | |
51 |
|
51 | |||
52 | def test_single(self): |
|
52 | def test_single(self): | |
53 | if not hasattr(zstd, "BufferWithSegments"): |
|
53 | if not hasattr(zstd, "BufferWithSegments"): | |
54 | self.skipTest("BufferWithSegments not available") |
|
54 | self.skipTest("BufferWithSegments not available") | |
55 |
|
55 | |||
56 | b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) |
|
56 | b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) | |
57 | self.assertEqual(len(b), 1) |
|
57 | self.assertEqual(len(b), 1) | |
58 | self.assertEqual(b.size, 3) |
|
58 | self.assertEqual(b.size, 3) | |
59 | self.assertEqual(b.tobytes(), b"foo") |
|
59 | self.assertEqual(b.tobytes(), b"foo") | |
60 |
|
60 | |||
61 | self.assertEqual(len(b[0]), 3) |
|
61 | self.assertEqual(len(b[0]), 3) | |
62 | self.assertEqual(b[0].offset, 0) |
|
62 | self.assertEqual(b[0].offset, 0) | |
63 | self.assertEqual(b[0].tobytes(), b"foo") |
|
63 | self.assertEqual(b[0].tobytes(), b"foo") | |
64 |
|
64 | |||
65 | def test_multiple(self): |
|
65 | def test_multiple(self): | |
66 | if not hasattr(zstd, "BufferWithSegments"): |
|
66 | if not hasattr(zstd, "BufferWithSegments"): | |
67 | self.skipTest("BufferWithSegments not available") |
|
67 | self.skipTest("BufferWithSegments not available") | |
68 |
|
68 | |||
69 | b = zstd.BufferWithSegments( |
|
69 | b = zstd.BufferWithSegments( | |
70 | b"foofooxfooxy", b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]) |
|
70 | b"foofooxfooxy", | |
|
71 | b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]), | |||
71 | ) |
|
72 | ) | |
72 | self.assertEqual(len(b), 3) |
|
73 | self.assertEqual(len(b), 3) | |
73 | self.assertEqual(b.size, 12) |
|
74 | self.assertEqual(b.size, 12) | |
74 | self.assertEqual(b.tobytes(), b"foofooxfooxy") |
|
75 | self.assertEqual(b.tobytes(), b"foofooxfooxy") | |
75 |
|
76 | |||
76 | self.assertEqual(b[0].tobytes(), b"foo") |
|
77 | self.assertEqual(b[0].tobytes(), b"foo") | |
77 | self.assertEqual(b[1].tobytes(), b"foox") |
|
78 | self.assertEqual(b[1].tobytes(), b"foox") | |
78 | self.assertEqual(b[2].tobytes(), b"fooxy") |
|
79 | self.assertEqual(b[2].tobytes(), b"fooxy") | |
79 |
|
80 | |||
80 |
|
81 | |||
81 | class TestBufferWithSegmentsCollection(TestCase): |
|
82 | class TestBufferWithSegmentsCollection(TestCase): | |
82 | def test_empty_constructor(self): |
|
83 | def test_empty_constructor(self): | |
83 | if not hasattr(zstd, "BufferWithSegmentsCollection"): |
|
84 | if not hasattr(zstd, "BufferWithSegmentsCollection"): | |
84 | self.skipTest("BufferWithSegmentsCollection not available") |
|
85 | self.skipTest("BufferWithSegmentsCollection not available") | |
85 |
|
86 | |||
86 |
with self.assertRaisesRegex( |
|
87 | with self.assertRaisesRegex( | |
|
88 | ValueError, "must pass at least 1 argument" | |||
|
89 | ): | |||
87 | zstd.BufferWithSegmentsCollection() |
|
90 | zstd.BufferWithSegmentsCollection() | |
88 |
|
91 | |||
89 | def test_argument_validation(self): |
|
92 | def test_argument_validation(self): | |
90 | if not hasattr(zstd, "BufferWithSegmentsCollection"): |
|
93 | if not hasattr(zstd, "BufferWithSegmentsCollection"): | |
91 | self.skipTest("BufferWithSegmentsCollection not available") |
|
94 | self.skipTest("BufferWithSegmentsCollection not available") | |
92 |
|
95 | |||
93 | with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"): |
|
96 | with self.assertRaisesRegex( | |
|
97 | TypeError, "arguments must be BufferWithSegments" | |||
|
98 | ): | |||
94 | zstd.BufferWithSegmentsCollection(None) |
|
99 | zstd.BufferWithSegmentsCollection(None) | |
95 |
|
100 | |||
96 | with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"): |
|
101 | with self.assertRaisesRegex( | |
|
102 | TypeError, "arguments must be BufferWithSegments" | |||
|
103 | ): | |||
97 | zstd.BufferWithSegmentsCollection( |
|
104 | zstd.BufferWithSegmentsCollection( | |
98 | zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None |
|
105 | zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None | |
99 | ) |
|
106 | ) | |
100 |
|
107 | |||
101 | with self.assertRaisesRegex( |
|
108 | with self.assertRaisesRegex( | |
102 | ValueError, "ZstdBufferWithSegments cannot be empty" |
|
109 | ValueError, "ZstdBufferWithSegments cannot be empty" | |
103 | ): |
|
110 | ): | |
104 | zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b"")) |
|
111 | zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b"")) | |
105 |
|
112 | |||
106 | def test_length(self): |
|
113 | def test_length(self): | |
107 | if not hasattr(zstd, "BufferWithSegmentsCollection"): |
|
114 | if not hasattr(zstd, "BufferWithSegmentsCollection"): | |
108 | self.skipTest("BufferWithSegmentsCollection not available") |
|
115 | self.skipTest("BufferWithSegmentsCollection not available") | |
109 |
|
116 | |||
110 | b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) |
|
117 | b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) | |
111 | b2 = zstd.BufferWithSegments( |
|
118 | b2 = zstd.BufferWithSegments( | |
112 | b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)]) |
|
119 | b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)]) | |
113 | ) |
|
120 | ) | |
114 |
|
121 | |||
115 | c = zstd.BufferWithSegmentsCollection(b1) |
|
122 | c = zstd.BufferWithSegmentsCollection(b1) | |
116 | self.assertEqual(len(c), 1) |
|
123 | self.assertEqual(len(c), 1) | |
117 | self.assertEqual(c.size(), 3) |
|
124 | self.assertEqual(c.size(), 3) | |
118 |
|
125 | |||
119 | c = zstd.BufferWithSegmentsCollection(b2) |
|
126 | c = zstd.BufferWithSegmentsCollection(b2) | |
120 | self.assertEqual(len(c), 2) |
|
127 | self.assertEqual(len(c), 2) | |
121 | self.assertEqual(c.size(), 6) |
|
128 | self.assertEqual(c.size(), 6) | |
122 |
|
129 | |||
123 | c = zstd.BufferWithSegmentsCollection(b1, b2) |
|
130 | c = zstd.BufferWithSegmentsCollection(b1, b2) | |
124 | self.assertEqual(len(c), 3) |
|
131 | self.assertEqual(len(c), 3) | |
125 | self.assertEqual(c.size(), 9) |
|
132 | self.assertEqual(c.size(), 9) | |
126 |
|
133 | |||
127 | def test_getitem(self): |
|
134 | def test_getitem(self): | |
128 | if not hasattr(zstd, "BufferWithSegmentsCollection"): |
|
135 | if not hasattr(zstd, "BufferWithSegmentsCollection"): | |
129 | self.skipTest("BufferWithSegmentsCollection not available") |
|
136 | self.skipTest("BufferWithSegmentsCollection not available") | |
130 |
|
137 | |||
131 | b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) |
|
138 | b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3)) | |
132 | b2 = zstd.BufferWithSegments( |
|
139 | b2 = zstd.BufferWithSegments( | |
133 | b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)]) |
|
140 | b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)]) | |
134 | ) |
|
141 | ) | |
135 |
|
142 | |||
136 | c = zstd.BufferWithSegmentsCollection(b1, b2) |
|
143 | c = zstd.BufferWithSegmentsCollection(b1, b2) | |
137 |
|
144 | |||
138 | with self.assertRaisesRegex(IndexError, "offset must be less than 3"): |
|
145 | with self.assertRaisesRegex(IndexError, "offset must be less than 3"): | |
139 | c[3] |
|
146 | c[3] | |
140 |
|
147 | |||
141 | with self.assertRaisesRegex(IndexError, "offset must be less than 3"): |
|
148 | with self.assertRaisesRegex(IndexError, "offset must be less than 3"): | |
142 | c[4] |
|
149 | c[4] | |
143 |
|
150 | |||
144 | self.assertEqual(c[0].tobytes(), b"foo") |
|
151 | self.assertEqual(c[0].tobytes(), b"foo") | |
145 | self.assertEqual(c[1].tobytes(), b"bar") |
|
152 | self.assertEqual(c[1].tobytes(), b"bar") | |
146 | self.assertEqual(c[2].tobytes(), b"baz") |
|
153 | self.assertEqual(c[2].tobytes(), b"baz") |
@@ -1,1770 +1,1803 b'' | |||||
1 | import hashlib |
|
1 | import hashlib | |
2 | import io |
|
2 | import io | |
3 | import os |
|
3 | import os | |
4 | import struct |
|
4 | import struct | |
5 | import sys |
|
5 | import sys | |
6 | import tarfile |
|
6 | import tarfile | |
7 | import tempfile |
|
7 | import tempfile | |
8 | import unittest |
|
8 | import unittest | |
9 |
|
9 | |||
10 | import zstandard as zstd |
|
10 | import zstandard as zstd | |
11 |
|
11 | |||
12 | from .common import ( |
|
12 | from .common import ( | |
13 | make_cffi, |
|
13 | make_cffi, | |
14 | NonClosingBytesIO, |
|
14 | NonClosingBytesIO, | |
15 | OpCountingBytesIO, |
|
15 | OpCountingBytesIO, | |
16 | TestCase, |
|
16 | TestCase, | |
17 | ) |
|
17 | ) | |
18 |
|
18 | |||
19 |
|
19 | |||
20 | if sys.version_info[0] >= 3: |
|
20 | if sys.version_info[0] >= 3: | |
21 | next = lambda it: it.__next__() |
|
21 | next = lambda it: it.__next__() | |
22 | else: |
|
22 | else: | |
23 | next = lambda it: it.next() |
|
23 | next = lambda it: it.next() | |
24 |
|
24 | |||
25 |
|
25 | |||
26 | def multithreaded_chunk_size(level, source_size=0): |
|
26 | def multithreaded_chunk_size(level, source_size=0): | |
27 |
params = zstd.ZstdCompressionParameters.from_level( |
|
27 | params = zstd.ZstdCompressionParameters.from_level( | |
|
28 | level, source_size=source_size | |||
|
29 | ) | |||
28 |
|
30 | |||
29 | return 1 << (params.window_log + 2) |
|
31 | return 1 << (params.window_log + 2) | |
30 |
|
32 | |||
31 |
|
33 | |||
32 | @make_cffi |
|
34 | @make_cffi | |
33 | class TestCompressor(TestCase): |
|
35 | class TestCompressor(TestCase): | |
34 | def test_level_bounds(self): |
|
36 | def test_level_bounds(self): | |
35 | with self.assertRaises(ValueError): |
|
37 | with self.assertRaises(ValueError): | |
36 | zstd.ZstdCompressor(level=23) |
|
38 | zstd.ZstdCompressor(level=23) | |
37 |
|
39 | |||
38 | def test_memory_size(self): |
|
40 | def test_memory_size(self): | |
39 | cctx = zstd.ZstdCompressor(level=1) |
|
41 | cctx = zstd.ZstdCompressor(level=1) | |
40 | self.assertGreater(cctx.memory_size(), 100) |
|
42 | self.assertGreater(cctx.memory_size(), 100) | |
41 |
|
43 | |||
42 |
|
44 | |||
43 | @make_cffi |
|
45 | @make_cffi | |
44 | class TestCompressor_compress(TestCase): |
|
46 | class TestCompressor_compress(TestCase): | |
45 | def test_compress_empty(self): |
|
47 | def test_compress_empty(self): | |
46 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
48 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
47 | result = cctx.compress(b"") |
|
49 | result = cctx.compress(b"") | |
48 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") |
|
50 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | |
49 | params = zstd.get_frame_parameters(result) |
|
51 | params = zstd.get_frame_parameters(result) | |
50 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
52 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
51 | self.assertEqual(params.window_size, 524288) |
|
53 | self.assertEqual(params.window_size, 524288) | |
52 | self.assertEqual(params.dict_id, 0) |
|
54 | self.assertEqual(params.dict_id, 0) | |
53 | self.assertFalse(params.has_checksum, 0) |
|
55 | self.assertFalse(params.has_checksum, 0) | |
54 |
|
56 | |||
55 | cctx = zstd.ZstdCompressor() |
|
57 | cctx = zstd.ZstdCompressor() | |
56 | result = cctx.compress(b"") |
|
58 | result = cctx.compress(b"") | |
57 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00") |
|
59 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00") | |
58 | params = zstd.get_frame_parameters(result) |
|
60 | params = zstd.get_frame_parameters(result) | |
59 | self.assertEqual(params.content_size, 0) |
|
61 | self.assertEqual(params.content_size, 0) | |
60 |
|
62 | |||
61 | def test_input_types(self): |
|
63 | def test_input_types(self): | |
62 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
64 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
63 | expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f" |
|
65 | expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f" | |
64 |
|
66 | |||
65 | mutable_array = bytearray(3) |
|
67 | mutable_array = bytearray(3) | |
66 | mutable_array[:] = b"foo" |
|
68 | mutable_array[:] = b"foo" | |
67 |
|
69 | |||
68 | sources = [ |
|
70 | sources = [ | |
69 | memoryview(b"foo"), |
|
71 | memoryview(b"foo"), | |
70 | bytearray(b"foo"), |
|
72 | bytearray(b"foo"), | |
71 | mutable_array, |
|
73 | mutable_array, | |
72 | ] |
|
74 | ] | |
73 |
|
75 | |||
74 | for source in sources: |
|
76 | for source in sources: | |
75 | self.assertEqual(cctx.compress(source), expected) |
|
77 | self.assertEqual(cctx.compress(source), expected) | |
76 |
|
78 | |||
77 | def test_compress_large(self): |
|
79 | def test_compress_large(self): | |
78 | chunks = [] |
|
80 | chunks = [] | |
79 | for i in range(255): |
|
81 | for i in range(255): | |
80 | chunks.append(struct.Struct(">B").pack(i) * 16384) |
|
82 | chunks.append(struct.Struct(">B").pack(i) * 16384) | |
81 |
|
83 | |||
82 | cctx = zstd.ZstdCompressor(level=3, write_content_size=False) |
|
84 | cctx = zstd.ZstdCompressor(level=3, write_content_size=False) | |
83 | result = cctx.compress(b"".join(chunks)) |
|
85 | result = cctx.compress(b"".join(chunks)) | |
84 | self.assertEqual(len(result), 999) |
|
86 | self.assertEqual(len(result), 999) | |
85 | self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") |
|
87 | self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") | |
86 |
|
88 | |||
87 | # This matches the test for read_to_iter() below. |
|
89 | # This matches the test for read_to_iter() below. | |
88 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
90 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
89 | result = cctx.compress(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o") |
|
91 | result = cctx.compress( | |
|
92 | b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o" | |||
|
93 | ) | |||
90 | self.assertEqual( |
|
94 | self.assertEqual( | |
91 | result, |
|
95 | result, | |
92 | b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" |
|
96 | b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" | |
93 | b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" |
|
97 | b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" | |
94 | b"\x02\x09\x00\x00\x6f", |
|
98 | b"\x02\x09\x00\x00\x6f", | |
95 | ) |
|
99 | ) | |
96 |
|
100 | |||
97 | def test_negative_level(self): |
|
101 | def test_negative_level(self): | |
98 | cctx = zstd.ZstdCompressor(level=-4) |
|
102 | cctx = zstd.ZstdCompressor(level=-4) | |
99 | result = cctx.compress(b"foo" * 256) |
|
103 | result = cctx.compress(b"foo" * 256) | |
100 |
|
104 | |||
101 | def test_no_magic(self): |
|
105 | def test_no_magic(self): | |
102 |
params = zstd.ZstdCompressionParameters.from_level( |
|
106 | params = zstd.ZstdCompressionParameters.from_level( | |
|
107 | 1, format=zstd.FORMAT_ZSTD1 | |||
|
108 | ) | |||
103 | cctx = zstd.ZstdCompressor(compression_params=params) |
|
109 | cctx = zstd.ZstdCompressor(compression_params=params) | |
104 | magic = cctx.compress(b"foobar") |
|
110 | magic = cctx.compress(b"foobar") | |
105 |
|
111 | |||
106 | params = zstd.ZstdCompressionParameters.from_level( |
|
112 | params = zstd.ZstdCompressionParameters.from_level( | |
107 | 1, format=zstd.FORMAT_ZSTD1_MAGICLESS |
|
113 | 1, format=zstd.FORMAT_ZSTD1_MAGICLESS | |
108 | ) |
|
114 | ) | |
109 | cctx = zstd.ZstdCompressor(compression_params=params) |
|
115 | cctx = zstd.ZstdCompressor(compression_params=params) | |
110 | no_magic = cctx.compress(b"foobar") |
|
116 | no_magic = cctx.compress(b"foobar") | |
111 |
|
117 | |||
112 | self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd") |
|
118 | self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd") | |
113 | self.assertEqual(magic[4:], no_magic) |
|
119 | self.assertEqual(magic[4:], no_magic) | |
114 |
|
120 | |||
115 | def test_write_checksum(self): |
|
121 | def test_write_checksum(self): | |
116 | cctx = zstd.ZstdCompressor(level=1) |
|
122 | cctx = zstd.ZstdCompressor(level=1) | |
117 | no_checksum = cctx.compress(b"foobar") |
|
123 | no_checksum = cctx.compress(b"foobar") | |
118 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) |
|
124 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | |
119 | with_checksum = cctx.compress(b"foobar") |
|
125 | with_checksum = cctx.compress(b"foobar") | |
120 |
|
126 | |||
121 | self.assertEqual(len(with_checksum), len(no_checksum) + 4) |
|
127 | self.assertEqual(len(with_checksum), len(no_checksum) + 4) | |
122 |
|
128 | |||
123 | no_params = zstd.get_frame_parameters(no_checksum) |
|
129 | no_params = zstd.get_frame_parameters(no_checksum) | |
124 | with_params = zstd.get_frame_parameters(with_checksum) |
|
130 | with_params = zstd.get_frame_parameters(with_checksum) | |
125 |
|
131 | |||
126 | self.assertFalse(no_params.has_checksum) |
|
132 | self.assertFalse(no_params.has_checksum) | |
127 | self.assertTrue(with_params.has_checksum) |
|
133 | self.assertTrue(with_params.has_checksum) | |
128 |
|
134 | |||
129 | def test_write_content_size(self): |
|
135 | def test_write_content_size(self): | |
130 | cctx = zstd.ZstdCompressor(level=1) |
|
136 | cctx = zstd.ZstdCompressor(level=1) | |
131 | with_size = cctx.compress(b"foobar" * 256) |
|
137 | with_size = cctx.compress(b"foobar" * 256) | |
132 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
138 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
133 | no_size = cctx.compress(b"foobar" * 256) |
|
139 | no_size = cctx.compress(b"foobar" * 256) | |
134 |
|
140 | |||
135 | self.assertEqual(len(with_size), len(no_size) + 1) |
|
141 | self.assertEqual(len(with_size), len(no_size) + 1) | |
136 |
|
142 | |||
137 | no_params = zstd.get_frame_parameters(no_size) |
|
143 | no_params = zstd.get_frame_parameters(no_size) | |
138 | with_params = zstd.get_frame_parameters(with_size) |
|
144 | with_params = zstd.get_frame_parameters(with_size) | |
139 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
145 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
140 | self.assertEqual(with_params.content_size, 1536) |
|
146 | self.assertEqual(with_params.content_size, 1536) | |
141 |
|
147 | |||
142 | def test_no_dict_id(self): |
|
148 | def test_no_dict_id(self): | |
143 | samples = [] |
|
149 | samples = [] | |
144 | for i in range(128): |
|
150 | for i in range(128): | |
145 | samples.append(b"foo" * 64) |
|
151 | samples.append(b"foo" * 64) | |
146 | samples.append(b"bar" * 64) |
|
152 | samples.append(b"bar" * 64) | |
147 | samples.append(b"foobar" * 64) |
|
153 | samples.append(b"foobar" * 64) | |
148 |
|
154 | |||
149 | d = zstd.train_dictionary(1024, samples) |
|
155 | d = zstd.train_dictionary(1024, samples) | |
150 |
|
156 | |||
151 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
|
157 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) | |
152 | with_dict_id = cctx.compress(b"foobarfoobar") |
|
158 | with_dict_id = cctx.compress(b"foobarfoobar") | |
153 |
|
159 | |||
154 | cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) |
|
160 | cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) | |
155 | no_dict_id = cctx.compress(b"foobarfoobar") |
|
161 | no_dict_id = cctx.compress(b"foobarfoobar") | |
156 |
|
162 | |||
157 | self.assertEqual(len(with_dict_id), len(no_dict_id) + 4) |
|
163 | self.assertEqual(len(with_dict_id), len(no_dict_id) + 4) | |
158 |
|
164 | |||
159 | no_params = zstd.get_frame_parameters(no_dict_id) |
|
165 | no_params = zstd.get_frame_parameters(no_dict_id) | |
160 | with_params = zstd.get_frame_parameters(with_dict_id) |
|
166 | with_params = zstd.get_frame_parameters(with_dict_id) | |
161 | self.assertEqual(no_params.dict_id, 0) |
|
167 | self.assertEqual(no_params.dict_id, 0) | |
162 | self.assertEqual(with_params.dict_id, 1880053135) |
|
168 | self.assertEqual(with_params.dict_id, 1880053135) | |
163 |
|
169 | |||
164 | def test_compress_dict_multiple(self): |
|
170 | def test_compress_dict_multiple(self): | |
165 | samples = [] |
|
171 | samples = [] | |
166 | for i in range(128): |
|
172 | for i in range(128): | |
167 | samples.append(b"foo" * 64) |
|
173 | samples.append(b"foo" * 64) | |
168 | samples.append(b"bar" * 64) |
|
174 | samples.append(b"bar" * 64) | |
169 | samples.append(b"foobar" * 64) |
|
175 | samples.append(b"foobar" * 64) | |
170 |
|
176 | |||
171 | d = zstd.train_dictionary(8192, samples) |
|
177 | d = zstd.train_dictionary(8192, samples) | |
172 |
|
178 | |||
173 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
|
179 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) | |
174 |
|
180 | |||
175 | for i in range(32): |
|
181 | for i in range(32): | |
176 | cctx.compress(b"foo bar foobar foo bar foobar") |
|
182 | cctx.compress(b"foo bar foobar foo bar foobar") | |
177 |
|
183 | |||
178 | def test_dict_precompute(self): |
|
184 | def test_dict_precompute(self): | |
179 | samples = [] |
|
185 | samples = [] | |
180 | for i in range(128): |
|
186 | for i in range(128): | |
181 | samples.append(b"foo" * 64) |
|
187 | samples.append(b"foo" * 64) | |
182 | samples.append(b"bar" * 64) |
|
188 | samples.append(b"bar" * 64) | |
183 | samples.append(b"foobar" * 64) |
|
189 | samples.append(b"foobar" * 64) | |
184 |
|
190 | |||
185 | d = zstd.train_dictionary(8192, samples) |
|
191 | d = zstd.train_dictionary(8192, samples) | |
186 | d.precompute_compress(level=1) |
|
192 | d.precompute_compress(level=1) | |
187 |
|
193 | |||
188 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
|
194 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) | |
189 |
|
195 | |||
190 | for i in range(32): |
|
196 | for i in range(32): | |
191 | cctx.compress(b"foo bar foobar foo bar foobar") |
|
197 | cctx.compress(b"foo bar foobar foo bar foobar") | |
192 |
|
198 | |||
193 | def test_multithreaded(self): |
|
199 | def test_multithreaded(self): | |
194 | chunk_size = multithreaded_chunk_size(1) |
|
200 | chunk_size = multithreaded_chunk_size(1) | |
195 | source = b"".join([b"x" * chunk_size, b"y" * chunk_size]) |
|
201 | source = b"".join([b"x" * chunk_size, b"y" * chunk_size]) | |
196 |
|
202 | |||
197 | cctx = zstd.ZstdCompressor(level=1, threads=2) |
|
203 | cctx = zstd.ZstdCompressor(level=1, threads=2) | |
198 | compressed = cctx.compress(source) |
|
204 | compressed = cctx.compress(source) | |
199 |
|
205 | |||
200 | params = zstd.get_frame_parameters(compressed) |
|
206 | params = zstd.get_frame_parameters(compressed) | |
201 | self.assertEqual(params.content_size, chunk_size * 2) |
|
207 | self.assertEqual(params.content_size, chunk_size * 2) | |
202 | self.assertEqual(params.dict_id, 0) |
|
208 | self.assertEqual(params.dict_id, 0) | |
203 | self.assertFalse(params.has_checksum) |
|
209 | self.assertFalse(params.has_checksum) | |
204 |
|
210 | |||
205 | dctx = zstd.ZstdDecompressor() |
|
211 | dctx = zstd.ZstdDecompressor() | |
206 | self.assertEqual(dctx.decompress(compressed), source) |
|
212 | self.assertEqual(dctx.decompress(compressed), source) | |
207 |
|
213 | |||
208 | def test_multithreaded_dict(self): |
|
214 | def test_multithreaded_dict(self): | |
209 | samples = [] |
|
215 | samples = [] | |
210 | for i in range(128): |
|
216 | for i in range(128): | |
211 | samples.append(b"foo" * 64) |
|
217 | samples.append(b"foo" * 64) | |
212 | samples.append(b"bar" * 64) |
|
218 | samples.append(b"bar" * 64) | |
213 | samples.append(b"foobar" * 64) |
|
219 | samples.append(b"foobar" * 64) | |
214 |
|
220 | |||
215 | d = zstd.train_dictionary(1024, samples) |
|
221 | d = zstd.train_dictionary(1024, samples) | |
216 |
|
222 | |||
217 | cctx = zstd.ZstdCompressor(dict_data=d, threads=2) |
|
223 | cctx = zstd.ZstdCompressor(dict_data=d, threads=2) | |
218 |
|
224 | |||
219 | result = cctx.compress(b"foo") |
|
225 | result = cctx.compress(b"foo") | |
220 | params = zstd.get_frame_parameters(result) |
|
226 | params = zstd.get_frame_parameters(result) | |
221 | self.assertEqual(params.content_size, 3) |
|
227 | self.assertEqual(params.content_size, 3) | |
222 | self.assertEqual(params.dict_id, d.dict_id()) |
|
228 | self.assertEqual(params.dict_id, d.dict_id()) | |
223 |
|
229 | |||
224 | self.assertEqual( |
|
230 | self.assertEqual( | |
225 | result, |
|
231 | result, | |
226 |
b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" |
|
232 | b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" | |
|
233 | b"\x66\x6f\x6f", | |||
227 | ) |
|
234 | ) | |
228 |
|
235 | |||
229 | def test_multithreaded_compression_params(self): |
|
236 | def test_multithreaded_compression_params(self): | |
230 | params = zstd.ZstdCompressionParameters.from_level(0, threads=2) |
|
237 | params = zstd.ZstdCompressionParameters.from_level(0, threads=2) | |
231 | cctx = zstd.ZstdCompressor(compression_params=params) |
|
238 | cctx = zstd.ZstdCompressor(compression_params=params) | |
232 |
|
239 | |||
233 | result = cctx.compress(b"foo") |
|
240 | result = cctx.compress(b"foo") | |
234 | params = zstd.get_frame_parameters(result) |
|
241 | params = zstd.get_frame_parameters(result) | |
235 | self.assertEqual(params.content_size, 3) |
|
242 | self.assertEqual(params.content_size, 3) | |
236 |
|
243 | |||
237 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f") |
|
244 | self.assertEqual( | |
|
245 | result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f" | |||
|
246 | ) | |||
238 |
|
247 | |||
239 |
|
248 | |||
240 | @make_cffi |
|
249 | @make_cffi | |
241 | class TestCompressor_compressobj(TestCase): |
|
250 | class TestCompressor_compressobj(TestCase): | |
242 | def test_compressobj_empty(self): |
|
251 | def test_compressobj_empty(self): | |
243 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
252 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
244 | cobj = cctx.compressobj() |
|
253 | cobj = cctx.compressobj() | |
245 | self.assertEqual(cobj.compress(b""), b"") |
|
254 | self.assertEqual(cobj.compress(b""), b"") | |
246 | self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") |
|
255 | self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | |
247 |
|
256 | |||
248 | def test_input_types(self): |
|
257 | def test_input_types(self): | |
249 | expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" |
|
258 | expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" | |
250 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
259 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
251 |
|
260 | |||
252 | mutable_array = bytearray(3) |
|
261 | mutable_array = bytearray(3) | |
253 | mutable_array[:] = b"foo" |
|
262 | mutable_array[:] = b"foo" | |
254 |
|
263 | |||
255 | sources = [ |
|
264 | sources = [ | |
256 | memoryview(b"foo"), |
|
265 | memoryview(b"foo"), | |
257 | bytearray(b"foo"), |
|
266 | bytearray(b"foo"), | |
258 | mutable_array, |
|
267 | mutable_array, | |
259 | ] |
|
268 | ] | |
260 |
|
269 | |||
261 | for source in sources: |
|
270 | for source in sources: | |
262 | cobj = cctx.compressobj() |
|
271 | cobj = cctx.compressobj() | |
263 | self.assertEqual(cobj.compress(source), b"") |
|
272 | self.assertEqual(cobj.compress(source), b"") | |
264 | self.assertEqual(cobj.flush(), expected) |
|
273 | self.assertEqual(cobj.flush(), expected) | |
265 |
|
274 | |||
266 | def test_compressobj_large(self): |
|
275 | def test_compressobj_large(self): | |
267 | chunks = [] |
|
276 | chunks = [] | |
268 | for i in range(255): |
|
277 | for i in range(255): | |
269 | chunks.append(struct.Struct(">B").pack(i) * 16384) |
|
278 | chunks.append(struct.Struct(">B").pack(i) * 16384) | |
270 |
|
279 | |||
271 | cctx = zstd.ZstdCompressor(level=3) |
|
280 | cctx = zstd.ZstdCompressor(level=3) | |
272 | cobj = cctx.compressobj() |
|
281 | cobj = cctx.compressobj() | |
273 |
|
282 | |||
274 | result = cobj.compress(b"".join(chunks)) + cobj.flush() |
|
283 | result = cobj.compress(b"".join(chunks)) + cobj.flush() | |
275 | self.assertEqual(len(result), 999) |
|
284 | self.assertEqual(len(result), 999) | |
276 | self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") |
|
285 | self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") | |
277 |
|
286 | |||
278 | params = zstd.get_frame_parameters(result) |
|
287 | params = zstd.get_frame_parameters(result) | |
279 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
288 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
280 | self.assertEqual(params.window_size, 2097152) |
|
289 | self.assertEqual(params.window_size, 2097152) | |
281 | self.assertEqual(params.dict_id, 0) |
|
290 | self.assertEqual(params.dict_id, 0) | |
282 | self.assertFalse(params.has_checksum) |
|
291 | self.assertFalse(params.has_checksum) | |
283 |
|
292 | |||
284 | def test_write_checksum(self): |
|
293 | def test_write_checksum(self): | |
285 | cctx = zstd.ZstdCompressor(level=1) |
|
294 | cctx = zstd.ZstdCompressor(level=1) | |
286 | cobj = cctx.compressobj() |
|
295 | cobj = cctx.compressobj() | |
287 | no_checksum = cobj.compress(b"foobar") + cobj.flush() |
|
296 | no_checksum = cobj.compress(b"foobar") + cobj.flush() | |
288 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) |
|
297 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | |
289 | cobj = cctx.compressobj() |
|
298 | cobj = cctx.compressobj() | |
290 | with_checksum = cobj.compress(b"foobar") + cobj.flush() |
|
299 | with_checksum = cobj.compress(b"foobar") + cobj.flush() | |
291 |
|
300 | |||
292 | no_params = zstd.get_frame_parameters(no_checksum) |
|
301 | no_params = zstd.get_frame_parameters(no_checksum) | |
293 | with_params = zstd.get_frame_parameters(with_checksum) |
|
302 | with_params = zstd.get_frame_parameters(with_checksum) | |
294 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
303 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
295 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
304 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
296 | self.assertEqual(no_params.dict_id, 0) |
|
305 | self.assertEqual(no_params.dict_id, 0) | |
297 | self.assertEqual(with_params.dict_id, 0) |
|
306 | self.assertEqual(with_params.dict_id, 0) | |
298 | self.assertFalse(no_params.has_checksum) |
|
307 | self.assertFalse(no_params.has_checksum) | |
299 | self.assertTrue(with_params.has_checksum) |
|
308 | self.assertTrue(with_params.has_checksum) | |
300 |
|
309 | |||
301 | self.assertEqual(len(with_checksum), len(no_checksum) + 4) |
|
310 | self.assertEqual(len(with_checksum), len(no_checksum) + 4) | |
302 |
|
311 | |||
303 | def test_write_content_size(self): |
|
312 | def test_write_content_size(self): | |
304 | cctx = zstd.ZstdCompressor(level=1) |
|
313 | cctx = zstd.ZstdCompressor(level=1) | |
305 | cobj = cctx.compressobj(size=len(b"foobar" * 256)) |
|
314 | cobj = cctx.compressobj(size=len(b"foobar" * 256)) | |
306 | with_size = cobj.compress(b"foobar" * 256) + cobj.flush() |
|
315 | with_size = cobj.compress(b"foobar" * 256) + cobj.flush() | |
307 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
316 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
308 | cobj = cctx.compressobj(size=len(b"foobar" * 256)) |
|
317 | cobj = cctx.compressobj(size=len(b"foobar" * 256)) | |
309 | no_size = cobj.compress(b"foobar" * 256) + cobj.flush() |
|
318 | no_size = cobj.compress(b"foobar" * 256) + cobj.flush() | |
310 |
|
319 | |||
311 | no_params = zstd.get_frame_parameters(no_size) |
|
320 | no_params = zstd.get_frame_parameters(no_size) | |
312 | with_params = zstd.get_frame_parameters(with_size) |
|
321 | with_params = zstd.get_frame_parameters(with_size) | |
313 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
322 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
314 | self.assertEqual(with_params.content_size, 1536) |
|
323 | self.assertEqual(with_params.content_size, 1536) | |
315 | self.assertEqual(no_params.dict_id, 0) |
|
324 | self.assertEqual(no_params.dict_id, 0) | |
316 | self.assertEqual(with_params.dict_id, 0) |
|
325 | self.assertEqual(with_params.dict_id, 0) | |
317 | self.assertFalse(no_params.has_checksum) |
|
326 | self.assertFalse(no_params.has_checksum) | |
318 | self.assertFalse(with_params.has_checksum) |
|
327 | self.assertFalse(with_params.has_checksum) | |
319 |
|
328 | |||
320 | self.assertEqual(len(with_size), len(no_size) + 1) |
|
329 | self.assertEqual(len(with_size), len(no_size) + 1) | |
321 |
|
330 | |||
322 | def test_compress_after_finished(self): |
|
331 | def test_compress_after_finished(self): | |
323 | cctx = zstd.ZstdCompressor() |
|
332 | cctx = zstd.ZstdCompressor() | |
324 | cobj = cctx.compressobj() |
|
333 | cobj = cctx.compressobj() | |
325 |
|
334 | |||
326 | cobj.compress(b"foo") |
|
335 | cobj.compress(b"foo") | |
327 | cobj.flush() |
|
336 | cobj.flush() | |
328 |
|
337 | |||
329 | with self.assertRaisesRegex( |
|
338 | with self.assertRaisesRegex( | |
330 | zstd.ZstdError, r"cannot call compress\(\) after compressor" |
|
339 | zstd.ZstdError, r"cannot call compress\(\) after compressor" | |
331 | ): |
|
340 | ): | |
332 | cobj.compress(b"foo") |
|
341 | cobj.compress(b"foo") | |
333 |
|
342 | |||
334 | with self.assertRaisesRegex( |
|
343 | with self.assertRaisesRegex( | |
335 | zstd.ZstdError, "compressor object already finished" |
|
344 | zstd.ZstdError, "compressor object already finished" | |
336 | ): |
|
345 | ): | |
337 | cobj.flush() |
|
346 | cobj.flush() | |
338 |
|
347 | |||
339 | def test_flush_block_repeated(self): |
|
348 | def test_flush_block_repeated(self): | |
340 | cctx = zstd.ZstdCompressor(level=1) |
|
349 | cctx = zstd.ZstdCompressor(level=1) | |
341 | cobj = cctx.compressobj() |
|
350 | cobj = cctx.compressobj() | |
342 |
|
351 | |||
343 | self.assertEqual(cobj.compress(b"foo"), b"") |
|
352 | self.assertEqual(cobj.compress(b"foo"), b"") | |
344 | self.assertEqual( |
|
353 | self.assertEqual( | |
345 | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), |
|
354 | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), | |
346 | b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", |
|
355 | b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", | |
347 | ) |
|
356 | ) | |
348 | self.assertEqual(cobj.compress(b"bar"), b"") |
|
357 | self.assertEqual(cobj.compress(b"bar"), b"") | |
349 | # 3 byte header plus content. |
|
358 | # 3 byte header plus content. | |
350 | self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar") |
|
359 | self.assertEqual( | |
|
360 | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar" | |||
|
361 | ) | |||
351 | self.assertEqual(cobj.flush(), b"\x01\x00\x00") |
|
362 | self.assertEqual(cobj.flush(), b"\x01\x00\x00") | |
352 |
|
363 | |||
353 | def test_flush_empty_block(self): |
|
364 | def test_flush_empty_block(self): | |
354 | cctx = zstd.ZstdCompressor(write_checksum=True) |
|
365 | cctx = zstd.ZstdCompressor(write_checksum=True) | |
355 | cobj = cctx.compressobj() |
|
366 | cobj = cctx.compressobj() | |
356 |
|
367 | |||
357 | cobj.compress(b"foobar") |
|
368 | cobj.compress(b"foobar") | |
358 | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) |
|
369 | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) | |
359 | # No-op if no block is active (this is internal to zstd). |
|
370 | # No-op if no block is active (this is internal to zstd). | |
360 | self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"") |
|
371 | self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"") | |
361 |
|
372 | |||
362 | trailing = cobj.flush() |
|
373 | trailing = cobj.flush() | |
363 | # 3 bytes block header + 4 bytes frame checksum |
|
374 | # 3 bytes block header + 4 bytes frame checksum | |
364 | self.assertEqual(len(trailing), 7) |
|
375 | self.assertEqual(len(trailing), 7) | |
365 | header = trailing[0:3] |
|
376 | header = trailing[0:3] | |
366 | self.assertEqual(header, b"\x01\x00\x00") |
|
377 | self.assertEqual(header, b"\x01\x00\x00") | |
367 |
|
378 | |||
368 | def test_multithreaded(self): |
|
379 | def test_multithreaded(self): | |
369 | source = io.BytesIO() |
|
380 | source = io.BytesIO() | |
370 | source.write(b"a" * 1048576) |
|
381 | source.write(b"a" * 1048576) | |
371 | source.write(b"b" * 1048576) |
|
382 | source.write(b"b" * 1048576) | |
372 | source.write(b"c" * 1048576) |
|
383 | source.write(b"c" * 1048576) | |
373 | source.seek(0) |
|
384 | source.seek(0) | |
374 |
|
385 | |||
375 | cctx = zstd.ZstdCompressor(level=1, threads=2) |
|
386 | cctx = zstd.ZstdCompressor(level=1, threads=2) | |
376 | cobj = cctx.compressobj() |
|
387 | cobj = cctx.compressobj() | |
377 |
|
388 | |||
378 | chunks = [] |
|
389 | chunks = [] | |
379 | while True: |
|
390 | while True: | |
380 | d = source.read(8192) |
|
391 | d = source.read(8192) | |
381 | if not d: |
|
392 | if not d: | |
382 | break |
|
393 | break | |
383 |
|
394 | |||
384 | chunks.append(cobj.compress(d)) |
|
395 | chunks.append(cobj.compress(d)) | |
385 |
|
396 | |||
386 | chunks.append(cobj.flush()) |
|
397 | chunks.append(cobj.flush()) | |
387 |
|
398 | |||
388 | compressed = b"".join(chunks) |
|
399 | compressed = b"".join(chunks) | |
389 |
|
400 | |||
390 | self.assertEqual(len(compressed), 119) |
|
401 | self.assertEqual(len(compressed), 119) | |
391 |
|
402 | |||
392 | def test_frame_progression(self): |
|
403 | def test_frame_progression(self): | |
393 | cctx = zstd.ZstdCompressor() |
|
404 | cctx = zstd.ZstdCompressor() | |
394 |
|
405 | |||
395 | self.assertEqual(cctx.frame_progression(), (0, 0, 0)) |
|
406 | self.assertEqual(cctx.frame_progression(), (0, 0, 0)) | |
396 |
|
407 | |||
397 | cobj = cctx.compressobj() |
|
408 | cobj = cctx.compressobj() | |
398 |
|
409 | |||
399 | cobj.compress(b"foobar") |
|
410 | cobj.compress(b"foobar") | |
400 | self.assertEqual(cctx.frame_progression(), (6, 0, 0)) |
|
411 | self.assertEqual(cctx.frame_progression(), (6, 0, 0)) | |
401 |
|
412 | |||
402 | cobj.flush() |
|
413 | cobj.flush() | |
403 | self.assertEqual(cctx.frame_progression(), (6, 6, 15)) |
|
414 | self.assertEqual(cctx.frame_progression(), (6, 6, 15)) | |
404 |
|
415 | |||
405 | def test_bad_size(self): |
|
416 | def test_bad_size(self): | |
406 | cctx = zstd.ZstdCompressor() |
|
417 | cctx = zstd.ZstdCompressor() | |
407 |
|
418 | |||
408 | cobj = cctx.compressobj(size=2) |
|
419 | cobj = cctx.compressobj(size=2) | |
409 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): |
|
420 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | |
410 | cobj.compress(b"foo") |
|
421 | cobj.compress(b"foo") | |
411 |
|
422 | |||
412 | # Try another operation on this instance. |
|
423 | # Try another operation on this instance. | |
413 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): |
|
424 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | |
414 | cobj.compress(b"aa") |
|
425 | cobj.compress(b"aa") | |
415 |
|
426 | |||
416 | # Try another operation on the compressor. |
|
427 | # Try another operation on the compressor. | |
417 | cctx.compressobj(size=4) |
|
428 | cctx.compressobj(size=4) | |
418 | cctx.compress(b"foobar") |
|
429 | cctx.compress(b"foobar") | |
419 |
|
430 | |||
420 |
|
431 | |||
421 | @make_cffi |
|
432 | @make_cffi | |
422 | class TestCompressor_copy_stream(TestCase): |
|
433 | class TestCompressor_copy_stream(TestCase): | |
423 | def test_no_read(self): |
|
434 | def test_no_read(self): | |
424 | source = object() |
|
435 | source = object() | |
425 | dest = io.BytesIO() |
|
436 | dest = io.BytesIO() | |
426 |
|
437 | |||
427 | cctx = zstd.ZstdCompressor() |
|
438 | cctx = zstd.ZstdCompressor() | |
428 | with self.assertRaises(ValueError): |
|
439 | with self.assertRaises(ValueError): | |
429 | cctx.copy_stream(source, dest) |
|
440 | cctx.copy_stream(source, dest) | |
430 |
|
441 | |||
431 | def test_no_write(self): |
|
442 | def test_no_write(self): | |
432 | source = io.BytesIO() |
|
443 | source = io.BytesIO() | |
433 | dest = object() |
|
444 | dest = object() | |
434 |
|
445 | |||
435 | cctx = zstd.ZstdCompressor() |
|
446 | cctx = zstd.ZstdCompressor() | |
436 | with self.assertRaises(ValueError): |
|
447 | with self.assertRaises(ValueError): | |
437 | cctx.copy_stream(source, dest) |
|
448 | cctx.copy_stream(source, dest) | |
438 |
|
449 | |||
439 | def test_empty(self): |
|
450 | def test_empty(self): | |
440 | source = io.BytesIO() |
|
451 | source = io.BytesIO() | |
441 | dest = io.BytesIO() |
|
452 | dest = io.BytesIO() | |
442 |
|
453 | |||
443 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
454 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
444 | r, w = cctx.copy_stream(source, dest) |
|
455 | r, w = cctx.copy_stream(source, dest) | |
445 | self.assertEqual(int(r), 0) |
|
456 | self.assertEqual(int(r), 0) | |
446 | self.assertEqual(w, 9) |
|
457 | self.assertEqual(w, 9) | |
447 |
|
458 | |||
448 | self.assertEqual(dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") |
|
459 | self.assertEqual( | |
|
460 | dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00" | |||
|
461 | ) | |||
449 |
|
462 | |||
450 | def test_large_data(self): |
|
463 | def test_large_data(self): | |
451 | source = io.BytesIO() |
|
464 | source = io.BytesIO() | |
452 | for i in range(255): |
|
465 | for i in range(255): | |
453 | source.write(struct.Struct(">B").pack(i) * 16384) |
|
466 | source.write(struct.Struct(">B").pack(i) * 16384) | |
454 | source.seek(0) |
|
467 | source.seek(0) | |
455 |
|
468 | |||
456 | dest = io.BytesIO() |
|
469 | dest = io.BytesIO() | |
457 | cctx = zstd.ZstdCompressor() |
|
470 | cctx = zstd.ZstdCompressor() | |
458 | r, w = cctx.copy_stream(source, dest) |
|
471 | r, w = cctx.copy_stream(source, dest) | |
459 |
|
472 | |||
460 | self.assertEqual(r, 255 * 16384) |
|
473 | self.assertEqual(r, 255 * 16384) | |
461 | self.assertEqual(w, 999) |
|
474 | self.assertEqual(w, 999) | |
462 |
|
475 | |||
463 | params = zstd.get_frame_parameters(dest.getvalue()) |
|
476 | params = zstd.get_frame_parameters(dest.getvalue()) | |
464 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
477 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
465 | self.assertEqual(params.window_size, 2097152) |
|
478 | self.assertEqual(params.window_size, 2097152) | |
466 | self.assertEqual(params.dict_id, 0) |
|
479 | self.assertEqual(params.dict_id, 0) | |
467 | self.assertFalse(params.has_checksum) |
|
480 | self.assertFalse(params.has_checksum) | |
468 |
|
481 | |||
469 | def test_write_checksum(self): |
|
482 | def test_write_checksum(self): | |
470 | source = io.BytesIO(b"foobar") |
|
483 | source = io.BytesIO(b"foobar") | |
471 | no_checksum = io.BytesIO() |
|
484 | no_checksum = io.BytesIO() | |
472 |
|
485 | |||
473 | cctx = zstd.ZstdCompressor(level=1) |
|
486 | cctx = zstd.ZstdCompressor(level=1) | |
474 | cctx.copy_stream(source, no_checksum) |
|
487 | cctx.copy_stream(source, no_checksum) | |
475 |
|
488 | |||
476 | source.seek(0) |
|
489 | source.seek(0) | |
477 | with_checksum = io.BytesIO() |
|
490 | with_checksum = io.BytesIO() | |
478 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) |
|
491 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | |
479 | cctx.copy_stream(source, with_checksum) |
|
492 | cctx.copy_stream(source, with_checksum) | |
480 |
|
493 | |||
481 | self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4) |
|
494 | self.assertEqual( | |
|
495 | len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 | |||
|
496 | ) | |||
482 |
|
497 | |||
483 | no_params = zstd.get_frame_parameters(no_checksum.getvalue()) |
|
498 | no_params = zstd.get_frame_parameters(no_checksum.getvalue()) | |
484 | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) |
|
499 | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | |
485 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
500 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
486 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
501 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
487 | self.assertEqual(no_params.dict_id, 0) |
|
502 | self.assertEqual(no_params.dict_id, 0) | |
488 | self.assertEqual(with_params.dict_id, 0) |
|
503 | self.assertEqual(with_params.dict_id, 0) | |
489 | self.assertFalse(no_params.has_checksum) |
|
504 | self.assertFalse(no_params.has_checksum) | |
490 | self.assertTrue(with_params.has_checksum) |
|
505 | self.assertTrue(with_params.has_checksum) | |
491 |
|
506 | |||
492 | def test_write_content_size(self): |
|
507 | def test_write_content_size(self): | |
493 | source = io.BytesIO(b"foobar" * 256) |
|
508 | source = io.BytesIO(b"foobar" * 256) | |
494 | no_size = io.BytesIO() |
|
509 | no_size = io.BytesIO() | |
495 |
|
510 | |||
496 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
511 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
497 | cctx.copy_stream(source, no_size) |
|
512 | cctx.copy_stream(source, no_size) | |
498 |
|
513 | |||
499 | source.seek(0) |
|
514 | source.seek(0) | |
500 | with_size = io.BytesIO() |
|
515 | with_size = io.BytesIO() | |
501 | cctx = zstd.ZstdCompressor(level=1) |
|
516 | cctx = zstd.ZstdCompressor(level=1) | |
502 | cctx.copy_stream(source, with_size) |
|
517 | cctx.copy_stream(source, with_size) | |
503 |
|
518 | |||
504 | # Source content size is unknown, so no content size written. |
|
519 | # Source content size is unknown, so no content size written. | |
505 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) |
|
520 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) | |
506 |
|
521 | |||
507 | source.seek(0) |
|
522 | source.seek(0) | |
508 | with_size = io.BytesIO() |
|
523 | with_size = io.BytesIO() | |
509 | cctx.copy_stream(source, with_size, size=len(source.getvalue())) |
|
524 | cctx.copy_stream(source, with_size, size=len(source.getvalue())) | |
510 |
|
525 | |||
511 | # We specified source size, so content size header is present. |
|
526 | # We specified source size, so content size header is present. | |
512 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) |
|
527 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) | |
513 |
|
528 | |||
514 | no_params = zstd.get_frame_parameters(no_size.getvalue()) |
|
529 | no_params = zstd.get_frame_parameters(no_size.getvalue()) | |
515 | with_params = zstd.get_frame_parameters(with_size.getvalue()) |
|
530 | with_params = zstd.get_frame_parameters(with_size.getvalue()) | |
516 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
531 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
517 | self.assertEqual(with_params.content_size, 1536) |
|
532 | self.assertEqual(with_params.content_size, 1536) | |
518 | self.assertEqual(no_params.dict_id, 0) |
|
533 | self.assertEqual(no_params.dict_id, 0) | |
519 | self.assertEqual(with_params.dict_id, 0) |
|
534 | self.assertEqual(with_params.dict_id, 0) | |
520 | self.assertFalse(no_params.has_checksum) |
|
535 | self.assertFalse(no_params.has_checksum) | |
521 | self.assertFalse(with_params.has_checksum) |
|
536 | self.assertFalse(with_params.has_checksum) | |
522 |
|
537 | |||
523 | def test_read_write_size(self): |
|
538 | def test_read_write_size(self): | |
524 | source = OpCountingBytesIO(b"foobarfoobar") |
|
539 | source = OpCountingBytesIO(b"foobarfoobar") | |
525 | dest = OpCountingBytesIO() |
|
540 | dest = OpCountingBytesIO() | |
526 | cctx = zstd.ZstdCompressor() |
|
541 | cctx = zstd.ZstdCompressor() | |
527 | r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1) |
|
542 | r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1) | |
528 |
|
543 | |||
529 | self.assertEqual(r, len(source.getvalue())) |
|
544 | self.assertEqual(r, len(source.getvalue())) | |
530 | self.assertEqual(w, 21) |
|
545 | self.assertEqual(w, 21) | |
531 | self.assertEqual(source._read_count, len(source.getvalue()) + 1) |
|
546 | self.assertEqual(source._read_count, len(source.getvalue()) + 1) | |
532 | self.assertEqual(dest._write_count, len(dest.getvalue())) |
|
547 | self.assertEqual(dest._write_count, len(dest.getvalue())) | |
533 |
|
548 | |||
534 | def test_multithreaded(self): |
|
549 | def test_multithreaded(self): | |
535 | source = io.BytesIO() |
|
550 | source = io.BytesIO() | |
536 | source.write(b"a" * 1048576) |
|
551 | source.write(b"a" * 1048576) | |
537 | source.write(b"b" * 1048576) |
|
552 | source.write(b"b" * 1048576) | |
538 | source.write(b"c" * 1048576) |
|
553 | source.write(b"c" * 1048576) | |
539 | source.seek(0) |
|
554 | source.seek(0) | |
540 |
|
555 | |||
541 | dest = io.BytesIO() |
|
556 | dest = io.BytesIO() | |
542 | cctx = zstd.ZstdCompressor(threads=2, write_content_size=False) |
|
557 | cctx = zstd.ZstdCompressor(threads=2, write_content_size=False) | |
543 | r, w = cctx.copy_stream(source, dest) |
|
558 | r, w = cctx.copy_stream(source, dest) | |
544 | self.assertEqual(r, 3145728) |
|
559 | self.assertEqual(r, 3145728) | |
545 | self.assertEqual(w, 111) |
|
560 | self.assertEqual(w, 111) | |
546 |
|
561 | |||
547 | params = zstd.get_frame_parameters(dest.getvalue()) |
|
562 | params = zstd.get_frame_parameters(dest.getvalue()) | |
548 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
563 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
549 | self.assertEqual(params.dict_id, 0) |
|
564 | self.assertEqual(params.dict_id, 0) | |
550 | self.assertFalse(params.has_checksum) |
|
565 | self.assertFalse(params.has_checksum) | |
551 |
|
566 | |||
552 | # Writing content size and checksum works. |
|
567 | # Writing content size and checksum works. | |
553 | cctx = zstd.ZstdCompressor(threads=2, write_checksum=True) |
|
568 | cctx = zstd.ZstdCompressor(threads=2, write_checksum=True) | |
554 | dest = io.BytesIO() |
|
569 | dest = io.BytesIO() | |
555 | source.seek(0) |
|
570 | source.seek(0) | |
556 | cctx.copy_stream(source, dest, size=len(source.getvalue())) |
|
571 | cctx.copy_stream(source, dest, size=len(source.getvalue())) | |
557 |
|
572 | |||
558 | params = zstd.get_frame_parameters(dest.getvalue()) |
|
573 | params = zstd.get_frame_parameters(dest.getvalue()) | |
559 | self.assertEqual(params.content_size, 3145728) |
|
574 | self.assertEqual(params.content_size, 3145728) | |
560 | self.assertEqual(params.dict_id, 0) |
|
575 | self.assertEqual(params.dict_id, 0) | |
561 | self.assertTrue(params.has_checksum) |
|
576 | self.assertTrue(params.has_checksum) | |
562 |
|
577 | |||
563 | def test_bad_size(self): |
|
578 | def test_bad_size(self): | |
564 | source = io.BytesIO() |
|
579 | source = io.BytesIO() | |
565 | source.write(b"a" * 32768) |
|
580 | source.write(b"a" * 32768) | |
566 | source.write(b"b" * 32768) |
|
581 | source.write(b"b" * 32768) | |
567 | source.seek(0) |
|
582 | source.seek(0) | |
568 |
|
583 | |||
569 | dest = io.BytesIO() |
|
584 | dest = io.BytesIO() | |
570 |
|
585 | |||
571 | cctx = zstd.ZstdCompressor() |
|
586 | cctx = zstd.ZstdCompressor() | |
572 |
|
587 | |||
573 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): |
|
588 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | |
574 | cctx.copy_stream(source, dest, size=42) |
|
589 | cctx.copy_stream(source, dest, size=42) | |
575 |
|
590 | |||
576 | # Try another operation on this compressor. |
|
591 | # Try another operation on this compressor. | |
577 | source.seek(0) |
|
592 | source.seek(0) | |
578 | dest = io.BytesIO() |
|
593 | dest = io.BytesIO() | |
579 | cctx.copy_stream(source, dest) |
|
594 | cctx.copy_stream(source, dest) | |
580 |
|
595 | |||
581 |
|
596 | |||
582 | @make_cffi |
|
597 | @make_cffi | |
583 | class TestCompressor_stream_reader(TestCase): |
|
598 | class TestCompressor_stream_reader(TestCase): | |
584 | def test_context_manager(self): |
|
599 | def test_context_manager(self): | |
585 | cctx = zstd.ZstdCompressor() |
|
600 | cctx = zstd.ZstdCompressor() | |
586 |
|
601 | |||
587 | with cctx.stream_reader(b"foo") as reader: |
|
602 | with cctx.stream_reader(b"foo") as reader: | |
588 |
with self.assertRaisesRegex( |
|
603 | with self.assertRaisesRegex( | |
|
604 | ValueError, "cannot __enter__ multiple times" | |||
|
605 | ): | |||
589 | with reader as reader2: |
|
606 | with reader as reader2: | |
590 | pass |
|
607 | pass | |
591 |
|
608 | |||
592 | def test_no_context_manager(self): |
|
609 | def test_no_context_manager(self): | |
593 | cctx = zstd.ZstdCompressor() |
|
610 | cctx = zstd.ZstdCompressor() | |
594 |
|
611 | |||
595 | reader = cctx.stream_reader(b"foo") |
|
612 | reader = cctx.stream_reader(b"foo") | |
596 | reader.read(4) |
|
613 | reader.read(4) | |
597 | self.assertFalse(reader.closed) |
|
614 | self.assertFalse(reader.closed) | |
598 |
|
615 | |||
599 | reader.close() |
|
616 | reader.close() | |
600 | self.assertTrue(reader.closed) |
|
617 | self.assertTrue(reader.closed) | |
601 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
618 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
602 | reader.read(1) |
|
619 | reader.read(1) | |
603 |
|
620 | |||
604 | def test_not_implemented(self): |
|
621 | def test_not_implemented(self): | |
605 | cctx = zstd.ZstdCompressor() |
|
622 | cctx = zstd.ZstdCompressor() | |
606 |
|
623 | |||
607 | with cctx.stream_reader(b"foo" * 60) as reader: |
|
624 | with cctx.stream_reader(b"foo" * 60) as reader: | |
608 | with self.assertRaises(io.UnsupportedOperation): |
|
625 | with self.assertRaises(io.UnsupportedOperation): | |
609 | reader.readline() |
|
626 | reader.readline() | |
610 |
|
627 | |||
611 | with self.assertRaises(io.UnsupportedOperation): |
|
628 | with self.assertRaises(io.UnsupportedOperation): | |
612 | reader.readlines() |
|
629 | reader.readlines() | |
613 |
|
630 | |||
614 | with self.assertRaises(io.UnsupportedOperation): |
|
631 | with self.assertRaises(io.UnsupportedOperation): | |
615 | iter(reader) |
|
632 | iter(reader) | |
616 |
|
633 | |||
617 | with self.assertRaises(io.UnsupportedOperation): |
|
634 | with self.assertRaises(io.UnsupportedOperation): | |
618 | next(reader) |
|
635 | next(reader) | |
619 |
|
636 | |||
620 | with self.assertRaises(OSError): |
|
637 | with self.assertRaises(OSError): | |
621 | reader.writelines([]) |
|
638 | reader.writelines([]) | |
622 |
|
639 | |||
623 | with self.assertRaises(OSError): |
|
640 | with self.assertRaises(OSError): | |
624 | reader.write(b"foo") |
|
641 | reader.write(b"foo") | |
625 |
|
642 | |||
626 | def test_constant_methods(self): |
|
643 | def test_constant_methods(self): | |
627 | cctx = zstd.ZstdCompressor() |
|
644 | cctx = zstd.ZstdCompressor() | |
628 |
|
645 | |||
629 | with cctx.stream_reader(b"boo") as reader: |
|
646 | with cctx.stream_reader(b"boo") as reader: | |
630 | self.assertTrue(reader.readable()) |
|
647 | self.assertTrue(reader.readable()) | |
631 | self.assertFalse(reader.writable()) |
|
648 | self.assertFalse(reader.writable()) | |
632 | self.assertFalse(reader.seekable()) |
|
649 | self.assertFalse(reader.seekable()) | |
633 | self.assertFalse(reader.isatty()) |
|
650 | self.assertFalse(reader.isatty()) | |
634 | self.assertFalse(reader.closed) |
|
651 | self.assertFalse(reader.closed) | |
635 | self.assertIsNone(reader.flush()) |
|
652 | self.assertIsNone(reader.flush()) | |
636 | self.assertFalse(reader.closed) |
|
653 | self.assertFalse(reader.closed) | |
637 |
|
654 | |||
638 | self.assertTrue(reader.closed) |
|
655 | self.assertTrue(reader.closed) | |
639 |
|
656 | |||
640 | def test_read_closed(self): |
|
657 | def test_read_closed(self): | |
641 | cctx = zstd.ZstdCompressor() |
|
658 | cctx = zstd.ZstdCompressor() | |
642 |
|
659 | |||
643 | with cctx.stream_reader(b"foo" * 60) as reader: |
|
660 | with cctx.stream_reader(b"foo" * 60) as reader: | |
644 | reader.close() |
|
661 | reader.close() | |
645 | self.assertTrue(reader.closed) |
|
662 | self.assertTrue(reader.closed) | |
646 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
663 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
647 | reader.read(10) |
|
664 | reader.read(10) | |
648 |
|
665 | |||
649 | def test_read_sizes(self): |
|
666 | def test_read_sizes(self): | |
650 | cctx = zstd.ZstdCompressor() |
|
667 | cctx = zstd.ZstdCompressor() | |
651 | foo = cctx.compress(b"foo") |
|
668 | foo = cctx.compress(b"foo") | |
652 |
|
669 | |||
653 | with cctx.stream_reader(b"foo") as reader: |
|
670 | with cctx.stream_reader(b"foo") as reader: | |
654 | with self.assertRaisesRegex( |
|
671 | with self.assertRaisesRegex( | |
655 | ValueError, "cannot read negative amounts less than -1" |
|
672 | ValueError, "cannot read negative amounts less than -1" | |
656 | ): |
|
673 | ): | |
657 | reader.read(-2) |
|
674 | reader.read(-2) | |
658 |
|
675 | |||
659 | self.assertEqual(reader.read(0), b"") |
|
676 | self.assertEqual(reader.read(0), b"") | |
660 | self.assertEqual(reader.read(), foo) |
|
677 | self.assertEqual(reader.read(), foo) | |
661 |
|
678 | |||
662 | def test_read_buffer(self): |
|
679 | def test_read_buffer(self): | |
663 | cctx = zstd.ZstdCompressor() |
|
680 | cctx = zstd.ZstdCompressor() | |
664 |
|
681 | |||
665 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) |
|
682 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) | |
666 | frame = cctx.compress(source) |
|
683 | frame = cctx.compress(source) | |
667 |
|
684 | |||
668 | with cctx.stream_reader(source) as reader: |
|
685 | with cctx.stream_reader(source) as reader: | |
669 | self.assertEqual(reader.tell(), 0) |
|
686 | self.assertEqual(reader.tell(), 0) | |
670 |
|
687 | |||
671 | # We should get entire frame in one read. |
|
688 | # We should get entire frame in one read. | |
672 | result = reader.read(8192) |
|
689 | result = reader.read(8192) | |
673 | self.assertEqual(result, frame) |
|
690 | self.assertEqual(result, frame) | |
674 | self.assertEqual(reader.tell(), len(result)) |
|
691 | self.assertEqual(reader.tell(), len(result)) | |
675 | self.assertEqual(reader.read(), b"") |
|
692 | self.assertEqual(reader.read(), b"") | |
676 | self.assertEqual(reader.tell(), len(result)) |
|
693 | self.assertEqual(reader.tell(), len(result)) | |
677 |
|
694 | |||
678 | def test_read_buffer_small_chunks(self): |
|
695 | def test_read_buffer_small_chunks(self): | |
679 | cctx = zstd.ZstdCompressor() |
|
696 | cctx = zstd.ZstdCompressor() | |
680 |
|
697 | |||
681 | source = b"foo" * 60 |
|
698 | source = b"foo" * 60 | |
682 | chunks = [] |
|
699 | chunks = [] | |
683 |
|
700 | |||
684 | with cctx.stream_reader(source) as reader: |
|
701 | with cctx.stream_reader(source) as reader: | |
685 | self.assertEqual(reader.tell(), 0) |
|
702 | self.assertEqual(reader.tell(), 0) | |
686 |
|
703 | |||
687 | while True: |
|
704 | while True: | |
688 | chunk = reader.read(1) |
|
705 | chunk = reader.read(1) | |
689 | if not chunk: |
|
706 | if not chunk: | |
690 | break |
|
707 | break | |
691 |
|
708 | |||
692 | chunks.append(chunk) |
|
709 | chunks.append(chunk) | |
693 | self.assertEqual(reader.tell(), sum(map(len, chunks))) |
|
710 | self.assertEqual(reader.tell(), sum(map(len, chunks))) | |
694 |
|
711 | |||
695 | self.assertEqual(b"".join(chunks), cctx.compress(source)) |
|
712 | self.assertEqual(b"".join(chunks), cctx.compress(source)) | |
696 |
|
713 | |||
697 | def test_read_stream(self): |
|
714 | def test_read_stream(self): | |
698 | cctx = zstd.ZstdCompressor() |
|
715 | cctx = zstd.ZstdCompressor() | |
699 |
|
716 | |||
700 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) |
|
717 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) | |
701 | frame = cctx.compress(source) |
|
718 | frame = cctx.compress(source) | |
702 |
|
719 | |||
703 | with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: |
|
720 | with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: | |
704 | self.assertEqual(reader.tell(), 0) |
|
721 | self.assertEqual(reader.tell(), 0) | |
705 |
|
722 | |||
706 | chunk = reader.read(8192) |
|
723 | chunk = reader.read(8192) | |
707 | self.assertEqual(chunk, frame) |
|
724 | self.assertEqual(chunk, frame) | |
708 | self.assertEqual(reader.tell(), len(chunk)) |
|
725 | self.assertEqual(reader.tell(), len(chunk)) | |
709 | self.assertEqual(reader.read(), b"") |
|
726 | self.assertEqual(reader.read(), b"") | |
710 | self.assertEqual(reader.tell(), len(chunk)) |
|
727 | self.assertEqual(reader.tell(), len(chunk)) | |
711 |
|
728 | |||
712 | def test_read_stream_small_chunks(self): |
|
729 | def test_read_stream_small_chunks(self): | |
713 | cctx = zstd.ZstdCompressor() |
|
730 | cctx = zstd.ZstdCompressor() | |
714 |
|
731 | |||
715 | source = b"foo" * 60 |
|
732 | source = b"foo" * 60 | |
716 | chunks = [] |
|
733 | chunks = [] | |
717 |
|
734 | |||
718 | with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: |
|
735 | with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: | |
719 | self.assertEqual(reader.tell(), 0) |
|
736 | self.assertEqual(reader.tell(), 0) | |
720 |
|
737 | |||
721 | while True: |
|
738 | while True: | |
722 | chunk = reader.read(1) |
|
739 | chunk = reader.read(1) | |
723 | if not chunk: |
|
740 | if not chunk: | |
724 | break |
|
741 | break | |
725 |
|
742 | |||
726 | chunks.append(chunk) |
|
743 | chunks.append(chunk) | |
727 | self.assertEqual(reader.tell(), sum(map(len, chunks))) |
|
744 | self.assertEqual(reader.tell(), sum(map(len, chunks))) | |
728 |
|
745 | |||
729 | self.assertEqual(b"".join(chunks), cctx.compress(source)) |
|
746 | self.assertEqual(b"".join(chunks), cctx.compress(source)) | |
730 |
|
747 | |||
731 | def test_read_after_exit(self): |
|
748 | def test_read_after_exit(self): | |
732 | cctx = zstd.ZstdCompressor() |
|
749 | cctx = zstd.ZstdCompressor() | |
733 |
|
750 | |||
734 | with cctx.stream_reader(b"foo" * 60) as reader: |
|
751 | with cctx.stream_reader(b"foo" * 60) as reader: | |
735 | while reader.read(8192): |
|
752 | while reader.read(8192): | |
736 | pass |
|
753 | pass | |
737 |
|
754 | |||
738 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
755 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
739 | reader.read(10) |
|
756 | reader.read(10) | |
740 |
|
757 | |||
741 | def test_bad_size(self): |
|
758 | def test_bad_size(self): | |
742 | cctx = zstd.ZstdCompressor() |
|
759 | cctx = zstd.ZstdCompressor() | |
743 |
|
760 | |||
744 | source = io.BytesIO(b"foobar") |
|
761 | source = io.BytesIO(b"foobar") | |
745 |
|
762 | |||
746 | with cctx.stream_reader(source, size=2) as reader: |
|
763 | with cctx.stream_reader(source, size=2) as reader: | |
747 |
with self.assertRaisesRegex( |
|
764 | with self.assertRaisesRegex( | |
|
765 | zstd.ZstdError, "Src size is incorrect" | |||
|
766 | ): | |||
748 | reader.read(10) |
|
767 | reader.read(10) | |
749 |
|
768 | |||
750 | # Try another compression operation. |
|
769 | # Try another compression operation. | |
751 | with cctx.stream_reader(source, size=42): |
|
770 | with cctx.stream_reader(source, size=42): | |
752 | pass |
|
771 | pass | |
753 |
|
772 | |||
754 | def test_readall(self): |
|
773 | def test_readall(self): | |
755 | cctx = zstd.ZstdCompressor() |
|
774 | cctx = zstd.ZstdCompressor() | |
756 | frame = cctx.compress(b"foo" * 1024) |
|
775 | frame = cctx.compress(b"foo" * 1024) | |
757 |
|
776 | |||
758 | reader = cctx.stream_reader(b"foo" * 1024) |
|
777 | reader = cctx.stream_reader(b"foo" * 1024) | |
759 | self.assertEqual(reader.readall(), frame) |
|
778 | self.assertEqual(reader.readall(), frame) | |
760 |
|
779 | |||
761 | def test_readinto(self): |
|
780 | def test_readinto(self): | |
762 | cctx = zstd.ZstdCompressor() |
|
781 | cctx = zstd.ZstdCompressor() | |
763 | foo = cctx.compress(b"foo") |
|
782 | foo = cctx.compress(b"foo") | |
764 |
|
783 | |||
765 | reader = cctx.stream_reader(b"foo") |
|
784 | reader = cctx.stream_reader(b"foo") | |
766 | with self.assertRaises(Exception): |
|
785 | with self.assertRaises(Exception): | |
767 | reader.readinto(b"foobar") |
|
786 | reader.readinto(b"foobar") | |
768 |
|
787 | |||
769 | # readinto() with sufficiently large destination. |
|
788 | # readinto() with sufficiently large destination. | |
770 | b = bytearray(1024) |
|
789 | b = bytearray(1024) | |
771 | reader = cctx.stream_reader(b"foo") |
|
790 | reader = cctx.stream_reader(b"foo") | |
772 | self.assertEqual(reader.readinto(b), len(foo)) |
|
791 | self.assertEqual(reader.readinto(b), len(foo)) | |
773 | self.assertEqual(b[0 : len(foo)], foo) |
|
792 | self.assertEqual(b[0 : len(foo)], foo) | |
774 | self.assertEqual(reader.readinto(b), 0) |
|
793 | self.assertEqual(reader.readinto(b), 0) | |
775 | self.assertEqual(b[0 : len(foo)], foo) |
|
794 | self.assertEqual(b[0 : len(foo)], foo) | |
776 |
|
795 | |||
777 | # readinto() with small reads. |
|
796 | # readinto() with small reads. | |
778 | b = bytearray(1024) |
|
797 | b = bytearray(1024) | |
779 | reader = cctx.stream_reader(b"foo", read_size=1) |
|
798 | reader = cctx.stream_reader(b"foo", read_size=1) | |
780 | self.assertEqual(reader.readinto(b), len(foo)) |
|
799 | self.assertEqual(reader.readinto(b), len(foo)) | |
781 | self.assertEqual(b[0 : len(foo)], foo) |
|
800 | self.assertEqual(b[0 : len(foo)], foo) | |
782 |
|
801 | |||
783 | # Too small destination buffer. |
|
802 | # Too small destination buffer. | |
784 | b = bytearray(2) |
|
803 | b = bytearray(2) | |
785 | reader = cctx.stream_reader(b"foo") |
|
804 | reader = cctx.stream_reader(b"foo") | |
786 | self.assertEqual(reader.readinto(b), 2) |
|
805 | self.assertEqual(reader.readinto(b), 2) | |
787 | self.assertEqual(b[:], foo[0:2]) |
|
806 | self.assertEqual(b[:], foo[0:2]) | |
788 | self.assertEqual(reader.readinto(b), 2) |
|
807 | self.assertEqual(reader.readinto(b), 2) | |
789 | self.assertEqual(b[:], foo[2:4]) |
|
808 | self.assertEqual(b[:], foo[2:4]) | |
790 | self.assertEqual(reader.readinto(b), 2) |
|
809 | self.assertEqual(reader.readinto(b), 2) | |
791 | self.assertEqual(b[:], foo[4:6]) |
|
810 | self.assertEqual(b[:], foo[4:6]) | |
792 |
|
811 | |||
793 | def test_readinto1(self): |
|
812 | def test_readinto1(self): | |
794 | cctx = zstd.ZstdCompressor() |
|
813 | cctx = zstd.ZstdCompressor() | |
795 | foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) |
|
814 | foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) | |
796 |
|
815 | |||
797 | reader = cctx.stream_reader(b"foo") |
|
816 | reader = cctx.stream_reader(b"foo") | |
798 | with self.assertRaises(Exception): |
|
817 | with self.assertRaises(Exception): | |
799 | reader.readinto1(b"foobar") |
|
818 | reader.readinto1(b"foobar") | |
800 |
|
819 | |||
801 | b = bytearray(1024) |
|
820 | b = bytearray(1024) | |
802 | source = OpCountingBytesIO(b"foo") |
|
821 | source = OpCountingBytesIO(b"foo") | |
803 | reader = cctx.stream_reader(source) |
|
822 | reader = cctx.stream_reader(source) | |
804 | self.assertEqual(reader.readinto1(b), len(foo)) |
|
823 | self.assertEqual(reader.readinto1(b), len(foo)) | |
805 | self.assertEqual(b[0 : len(foo)], foo) |
|
824 | self.assertEqual(b[0 : len(foo)], foo) | |
806 | self.assertEqual(source._read_count, 2) |
|
825 | self.assertEqual(source._read_count, 2) | |
807 |
|
826 | |||
808 | # readinto1() with small reads. |
|
827 | # readinto1() with small reads. | |
809 | b = bytearray(1024) |
|
828 | b = bytearray(1024) | |
810 | source = OpCountingBytesIO(b"foo") |
|
829 | source = OpCountingBytesIO(b"foo") | |
811 | reader = cctx.stream_reader(source, read_size=1) |
|
830 | reader = cctx.stream_reader(source, read_size=1) | |
812 | self.assertEqual(reader.readinto1(b), len(foo)) |
|
831 | self.assertEqual(reader.readinto1(b), len(foo)) | |
813 | self.assertEqual(b[0 : len(foo)], foo) |
|
832 | self.assertEqual(b[0 : len(foo)], foo) | |
814 | self.assertEqual(source._read_count, 4) |
|
833 | self.assertEqual(source._read_count, 4) | |
815 |
|
834 | |||
816 | def test_read1(self): |
|
835 | def test_read1(self): | |
817 | cctx = zstd.ZstdCompressor() |
|
836 | cctx = zstd.ZstdCompressor() | |
818 | foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) |
|
837 | foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) | |
819 |
|
838 | |||
820 | b = OpCountingBytesIO(b"foo") |
|
839 | b = OpCountingBytesIO(b"foo") | |
821 | reader = cctx.stream_reader(b) |
|
840 | reader = cctx.stream_reader(b) | |
822 |
|
841 | |||
823 | self.assertEqual(reader.read1(), foo) |
|
842 | self.assertEqual(reader.read1(), foo) | |
824 | self.assertEqual(b._read_count, 2) |
|
843 | self.assertEqual(b._read_count, 2) | |
825 |
|
844 | |||
826 | b = OpCountingBytesIO(b"foo") |
|
845 | b = OpCountingBytesIO(b"foo") | |
827 | reader = cctx.stream_reader(b) |
|
846 | reader = cctx.stream_reader(b) | |
828 |
|
847 | |||
829 | self.assertEqual(reader.read1(0), b"") |
|
848 | self.assertEqual(reader.read1(0), b"") | |
830 | self.assertEqual(reader.read1(2), foo[0:2]) |
|
849 | self.assertEqual(reader.read1(2), foo[0:2]) | |
831 | self.assertEqual(b._read_count, 2) |
|
850 | self.assertEqual(b._read_count, 2) | |
832 | self.assertEqual(reader.read1(2), foo[2:4]) |
|
851 | self.assertEqual(reader.read1(2), foo[2:4]) | |
833 | self.assertEqual(reader.read1(1024), foo[4:]) |
|
852 | self.assertEqual(reader.read1(1024), foo[4:]) | |
834 |
|
853 | |||
835 |
|
854 | |||
836 | @make_cffi |
|
855 | @make_cffi | |
837 | class TestCompressor_stream_writer(TestCase): |
|
856 | class TestCompressor_stream_writer(TestCase): | |
838 | def test_io_api(self): |
|
857 | def test_io_api(self): | |
839 | buffer = io.BytesIO() |
|
858 | buffer = io.BytesIO() | |
840 | cctx = zstd.ZstdCompressor() |
|
859 | cctx = zstd.ZstdCompressor() | |
841 | writer = cctx.stream_writer(buffer) |
|
860 | writer = cctx.stream_writer(buffer) | |
842 |
|
861 | |||
843 | self.assertFalse(writer.isatty()) |
|
862 | self.assertFalse(writer.isatty()) | |
844 | self.assertFalse(writer.readable()) |
|
863 | self.assertFalse(writer.readable()) | |
845 |
|
864 | |||
846 | with self.assertRaises(io.UnsupportedOperation): |
|
865 | with self.assertRaises(io.UnsupportedOperation): | |
847 | writer.readline() |
|
866 | writer.readline() | |
848 |
|
867 | |||
849 | with self.assertRaises(io.UnsupportedOperation): |
|
868 | with self.assertRaises(io.UnsupportedOperation): | |
850 | writer.readline(42) |
|
869 | writer.readline(42) | |
851 |
|
870 | |||
852 | with self.assertRaises(io.UnsupportedOperation): |
|
871 | with self.assertRaises(io.UnsupportedOperation): | |
853 | writer.readline(size=42) |
|
872 | writer.readline(size=42) | |
854 |
|
873 | |||
855 | with self.assertRaises(io.UnsupportedOperation): |
|
874 | with self.assertRaises(io.UnsupportedOperation): | |
856 | writer.readlines() |
|
875 | writer.readlines() | |
857 |
|
876 | |||
858 | with self.assertRaises(io.UnsupportedOperation): |
|
877 | with self.assertRaises(io.UnsupportedOperation): | |
859 | writer.readlines(42) |
|
878 | writer.readlines(42) | |
860 |
|
879 | |||
861 | with self.assertRaises(io.UnsupportedOperation): |
|
880 | with self.assertRaises(io.UnsupportedOperation): | |
862 | writer.readlines(hint=42) |
|
881 | writer.readlines(hint=42) | |
863 |
|
882 | |||
864 | with self.assertRaises(io.UnsupportedOperation): |
|
883 | with self.assertRaises(io.UnsupportedOperation): | |
865 | writer.seek(0) |
|
884 | writer.seek(0) | |
866 |
|
885 | |||
867 | with self.assertRaises(io.UnsupportedOperation): |
|
886 | with self.assertRaises(io.UnsupportedOperation): | |
868 | writer.seek(10, os.SEEK_SET) |
|
887 | writer.seek(10, os.SEEK_SET) | |
869 |
|
888 | |||
870 | self.assertFalse(writer.seekable()) |
|
889 | self.assertFalse(writer.seekable()) | |
871 |
|
890 | |||
872 | with self.assertRaises(io.UnsupportedOperation): |
|
891 | with self.assertRaises(io.UnsupportedOperation): | |
873 | writer.truncate() |
|
892 | writer.truncate() | |
874 |
|
893 | |||
875 | with self.assertRaises(io.UnsupportedOperation): |
|
894 | with self.assertRaises(io.UnsupportedOperation): | |
876 | writer.truncate(42) |
|
895 | writer.truncate(42) | |
877 |
|
896 | |||
878 | with self.assertRaises(io.UnsupportedOperation): |
|
897 | with self.assertRaises(io.UnsupportedOperation): | |
879 | writer.truncate(size=42) |
|
898 | writer.truncate(size=42) | |
880 |
|
899 | |||
881 | self.assertTrue(writer.writable()) |
|
900 | self.assertTrue(writer.writable()) | |
882 |
|
901 | |||
883 | with self.assertRaises(NotImplementedError): |
|
902 | with self.assertRaises(NotImplementedError): | |
884 | writer.writelines([]) |
|
903 | writer.writelines([]) | |
885 |
|
904 | |||
886 | with self.assertRaises(io.UnsupportedOperation): |
|
905 | with self.assertRaises(io.UnsupportedOperation): | |
887 | writer.read() |
|
906 | writer.read() | |
888 |
|
907 | |||
889 | with self.assertRaises(io.UnsupportedOperation): |
|
908 | with self.assertRaises(io.UnsupportedOperation): | |
890 | writer.read(42) |
|
909 | writer.read(42) | |
891 |
|
910 | |||
892 | with self.assertRaises(io.UnsupportedOperation): |
|
911 | with self.assertRaises(io.UnsupportedOperation): | |
893 | writer.read(size=42) |
|
912 | writer.read(size=42) | |
894 |
|
913 | |||
895 | with self.assertRaises(io.UnsupportedOperation): |
|
914 | with self.assertRaises(io.UnsupportedOperation): | |
896 | writer.readall() |
|
915 | writer.readall() | |
897 |
|
916 | |||
898 | with self.assertRaises(io.UnsupportedOperation): |
|
917 | with self.assertRaises(io.UnsupportedOperation): | |
899 | writer.readinto(None) |
|
918 | writer.readinto(None) | |
900 |
|
919 | |||
901 | with self.assertRaises(io.UnsupportedOperation): |
|
920 | with self.assertRaises(io.UnsupportedOperation): | |
902 | writer.fileno() |
|
921 | writer.fileno() | |
903 |
|
922 | |||
904 | self.assertFalse(writer.closed) |
|
923 | self.assertFalse(writer.closed) | |
905 |
|
924 | |||
906 | def test_fileno_file(self): |
|
925 | def test_fileno_file(self): | |
907 | with tempfile.TemporaryFile("wb") as tf: |
|
926 | with tempfile.TemporaryFile("wb") as tf: | |
908 | cctx = zstd.ZstdCompressor() |
|
927 | cctx = zstd.ZstdCompressor() | |
909 | writer = cctx.stream_writer(tf) |
|
928 | writer = cctx.stream_writer(tf) | |
910 |
|
929 | |||
911 | self.assertEqual(writer.fileno(), tf.fileno()) |
|
930 | self.assertEqual(writer.fileno(), tf.fileno()) | |
912 |
|
931 | |||
913 | def test_close(self): |
|
932 | def test_close(self): | |
914 | buffer = NonClosingBytesIO() |
|
933 | buffer = NonClosingBytesIO() | |
915 | cctx = zstd.ZstdCompressor(level=1) |
|
934 | cctx = zstd.ZstdCompressor(level=1) | |
916 | writer = cctx.stream_writer(buffer) |
|
935 | writer = cctx.stream_writer(buffer) | |
917 |
|
936 | |||
918 | writer.write(b"foo" * 1024) |
|
937 | writer.write(b"foo" * 1024) | |
919 | self.assertFalse(writer.closed) |
|
938 | self.assertFalse(writer.closed) | |
920 | self.assertFalse(buffer.closed) |
|
939 | self.assertFalse(buffer.closed) | |
921 | writer.close() |
|
940 | writer.close() | |
922 | self.assertTrue(writer.closed) |
|
941 | self.assertTrue(writer.closed) | |
923 | self.assertTrue(buffer.closed) |
|
942 | self.assertTrue(buffer.closed) | |
924 |
|
943 | |||
925 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
944 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
926 | writer.write(b"foo") |
|
945 | writer.write(b"foo") | |
927 |
|
946 | |||
928 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
947 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
929 | writer.flush() |
|
948 | writer.flush() | |
930 |
|
949 | |||
931 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
950 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
932 | with writer: |
|
951 | with writer: | |
933 | pass |
|
952 | pass | |
934 |
|
953 | |||
935 | self.assertEqual( |
|
954 | self.assertEqual( | |
936 | buffer.getvalue(), |
|
955 | buffer.getvalue(), | |
937 | b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f" |
|
956 | b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f" | |
938 | b"\x6f\x01\x00\xfa\xd3\x77\x43", |
|
957 | b"\x6f\x01\x00\xfa\xd3\x77\x43", | |
939 | ) |
|
958 | ) | |
940 |
|
959 | |||
941 | # Context manager exit should close stream. |
|
960 | # Context manager exit should close stream. | |
942 | buffer = io.BytesIO() |
|
961 | buffer = io.BytesIO() | |
943 | writer = cctx.stream_writer(buffer) |
|
962 | writer = cctx.stream_writer(buffer) | |
944 |
|
963 | |||
945 | with writer: |
|
964 | with writer: | |
946 | writer.write(b"foo") |
|
965 | writer.write(b"foo") | |
947 |
|
966 | |||
948 | self.assertTrue(writer.closed) |
|
967 | self.assertTrue(writer.closed) | |
949 |
|
968 | |||
950 | def test_empty(self): |
|
969 | def test_empty(self): | |
951 | buffer = NonClosingBytesIO() |
|
970 | buffer = NonClosingBytesIO() | |
952 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
971 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
953 | with cctx.stream_writer(buffer) as compressor: |
|
972 | with cctx.stream_writer(buffer) as compressor: | |
954 | compressor.write(b"") |
|
973 | compressor.write(b"") | |
955 |
|
974 | |||
956 | result = buffer.getvalue() |
|
975 | result = buffer.getvalue() | |
957 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") |
|
976 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | |
958 |
|
977 | |||
959 | params = zstd.get_frame_parameters(result) |
|
978 | params = zstd.get_frame_parameters(result) | |
960 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
979 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
961 | self.assertEqual(params.window_size, 524288) |
|
980 | self.assertEqual(params.window_size, 524288) | |
962 | self.assertEqual(params.dict_id, 0) |
|
981 | self.assertEqual(params.dict_id, 0) | |
963 | self.assertFalse(params.has_checksum) |
|
982 | self.assertFalse(params.has_checksum) | |
964 |
|
983 | |||
965 | # Test without context manager. |
|
984 | # Test without context manager. | |
966 | buffer = io.BytesIO() |
|
985 | buffer = io.BytesIO() | |
967 | compressor = cctx.stream_writer(buffer) |
|
986 | compressor = cctx.stream_writer(buffer) | |
968 | self.assertEqual(compressor.write(b""), 0) |
|
987 | self.assertEqual(compressor.write(b""), 0) | |
969 | self.assertEqual(buffer.getvalue(), b"") |
|
988 | self.assertEqual(buffer.getvalue(), b"") | |
970 | self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9) |
|
989 | self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9) | |
971 | result = buffer.getvalue() |
|
990 | result = buffer.getvalue() | |
972 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") |
|
991 | self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | |
973 |
|
992 | |||
974 | params = zstd.get_frame_parameters(result) |
|
993 | params = zstd.get_frame_parameters(result) | |
975 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
994 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
976 | self.assertEqual(params.window_size, 524288) |
|
995 | self.assertEqual(params.window_size, 524288) | |
977 | self.assertEqual(params.dict_id, 0) |
|
996 | self.assertEqual(params.dict_id, 0) | |
978 | self.assertFalse(params.has_checksum) |
|
997 | self.assertFalse(params.has_checksum) | |
979 |
|
998 | |||
980 | # Test write_return_read=True |
|
999 | # Test write_return_read=True | |
981 | compressor = cctx.stream_writer(buffer, write_return_read=True) |
|
1000 | compressor = cctx.stream_writer(buffer, write_return_read=True) | |
982 | self.assertEqual(compressor.write(b""), 0) |
|
1001 | self.assertEqual(compressor.write(b""), 0) | |
983 |
|
1002 | |||
984 | def test_input_types(self): |
|
1003 | def test_input_types(self): | |
985 | expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" |
|
1004 | expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" | |
986 | cctx = zstd.ZstdCompressor(level=1) |
|
1005 | cctx = zstd.ZstdCompressor(level=1) | |
987 |
|
1006 | |||
988 | mutable_array = bytearray(3) |
|
1007 | mutable_array = bytearray(3) | |
989 | mutable_array[:] = b"foo" |
|
1008 | mutable_array[:] = b"foo" | |
990 |
|
1009 | |||
991 | sources = [ |
|
1010 | sources = [ | |
992 | memoryview(b"foo"), |
|
1011 | memoryview(b"foo"), | |
993 | bytearray(b"foo"), |
|
1012 | bytearray(b"foo"), | |
994 | mutable_array, |
|
1013 | mutable_array, | |
995 | ] |
|
1014 | ] | |
996 |
|
1015 | |||
997 | for source in sources: |
|
1016 | for source in sources: | |
998 | buffer = NonClosingBytesIO() |
|
1017 | buffer = NonClosingBytesIO() | |
999 | with cctx.stream_writer(buffer) as compressor: |
|
1018 | with cctx.stream_writer(buffer) as compressor: | |
1000 | compressor.write(source) |
|
1019 | compressor.write(source) | |
1001 |
|
1020 | |||
1002 | self.assertEqual(buffer.getvalue(), expected) |
|
1021 | self.assertEqual(buffer.getvalue(), expected) | |
1003 |
|
1022 | |||
1004 | compressor = cctx.stream_writer(buffer, write_return_read=True) |
|
1023 | compressor = cctx.stream_writer(buffer, write_return_read=True) | |
1005 | self.assertEqual(compressor.write(source), len(source)) |
|
1024 | self.assertEqual(compressor.write(source), len(source)) | |
1006 |
|
1025 | |||
1007 | def test_multiple_compress(self): |
|
1026 | def test_multiple_compress(self): | |
1008 | buffer = NonClosingBytesIO() |
|
1027 | buffer = NonClosingBytesIO() | |
1009 | cctx = zstd.ZstdCompressor(level=5) |
|
1028 | cctx = zstd.ZstdCompressor(level=5) | |
1010 | with cctx.stream_writer(buffer) as compressor: |
|
1029 | with cctx.stream_writer(buffer) as compressor: | |
1011 | self.assertEqual(compressor.write(b"foo"), 0) |
|
1030 | self.assertEqual(compressor.write(b"foo"), 0) | |
1012 | self.assertEqual(compressor.write(b"bar"), 0) |
|
1031 | self.assertEqual(compressor.write(b"bar"), 0) | |
1013 | self.assertEqual(compressor.write(b"x" * 8192), 0) |
|
1032 | self.assertEqual(compressor.write(b"x" * 8192), 0) | |
1014 |
|
1033 | |||
1015 | result = buffer.getvalue() |
|
1034 | result = buffer.getvalue() | |
1016 | self.assertEqual( |
|
1035 | self.assertEqual( | |
1017 | result, |
|
1036 | result, | |
1018 | b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" |
|
1037 | b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" | |
1019 | b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", |
|
1038 | b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", | |
1020 | ) |
|
1039 | ) | |
1021 |
|
1040 | |||
1022 | # Test without context manager. |
|
1041 | # Test without context manager. | |
1023 | buffer = io.BytesIO() |
|
1042 | buffer = io.BytesIO() | |
1024 | compressor = cctx.stream_writer(buffer) |
|
1043 | compressor = cctx.stream_writer(buffer) | |
1025 | self.assertEqual(compressor.write(b"foo"), 0) |
|
1044 | self.assertEqual(compressor.write(b"foo"), 0) | |
1026 | self.assertEqual(compressor.write(b"bar"), 0) |
|
1045 | self.assertEqual(compressor.write(b"bar"), 0) | |
1027 | self.assertEqual(compressor.write(b"x" * 8192), 0) |
|
1046 | self.assertEqual(compressor.write(b"x" * 8192), 0) | |
1028 | self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) |
|
1047 | self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) | |
1029 | result = buffer.getvalue() |
|
1048 | result = buffer.getvalue() | |
1030 | self.assertEqual( |
|
1049 | self.assertEqual( | |
1031 | result, |
|
1050 | result, | |
1032 | b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" |
|
1051 | b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" | |
1033 | b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", |
|
1052 | b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", | |
1034 | ) |
|
1053 | ) | |
1035 |
|
1054 | |||
1036 | # Test with write_return_read=True. |
|
1055 | # Test with write_return_read=True. | |
1037 | compressor = cctx.stream_writer(buffer, write_return_read=True) |
|
1056 | compressor = cctx.stream_writer(buffer, write_return_read=True) | |
1038 | self.assertEqual(compressor.write(b"foo"), 3) |
|
1057 | self.assertEqual(compressor.write(b"foo"), 3) | |
1039 | self.assertEqual(compressor.write(b"barbiz"), 6) |
|
1058 | self.assertEqual(compressor.write(b"barbiz"), 6) | |
1040 | self.assertEqual(compressor.write(b"x" * 8192), 8192) |
|
1059 | self.assertEqual(compressor.write(b"x" * 8192), 8192) | |
1041 |
|
1060 | |||
1042 | def test_dictionary(self): |
|
1061 | def test_dictionary(self): | |
1043 | samples = [] |
|
1062 | samples = [] | |
1044 | for i in range(128): |
|
1063 | for i in range(128): | |
1045 | samples.append(b"foo" * 64) |
|
1064 | samples.append(b"foo" * 64) | |
1046 | samples.append(b"bar" * 64) |
|
1065 | samples.append(b"bar" * 64) | |
1047 | samples.append(b"foobar" * 64) |
|
1066 | samples.append(b"foobar" * 64) | |
1048 |
|
1067 | |||
1049 | d = zstd.train_dictionary(8192, samples) |
|
1068 | d = zstd.train_dictionary(8192, samples) | |
1050 |
|
1069 | |||
1051 | h = hashlib.sha1(d.as_bytes()).hexdigest() |
|
1070 | h = hashlib.sha1(d.as_bytes()).hexdigest() | |
1052 | self.assertEqual(h, "7a2e59a876db958f74257141045af8f912e00d4e") |
|
1071 | self.assertEqual(h, "7a2e59a876db958f74257141045af8f912e00d4e") | |
1053 |
|
1072 | |||
1054 | buffer = NonClosingBytesIO() |
|
1073 | buffer = NonClosingBytesIO() | |
1055 | cctx = zstd.ZstdCompressor(level=9, dict_data=d) |
|
1074 | cctx = zstd.ZstdCompressor(level=9, dict_data=d) | |
1056 | with cctx.stream_writer(buffer) as compressor: |
|
1075 | with cctx.stream_writer(buffer) as compressor: | |
1057 | self.assertEqual(compressor.write(b"foo"), 0) |
|
1076 | self.assertEqual(compressor.write(b"foo"), 0) | |
1058 | self.assertEqual(compressor.write(b"bar"), 0) |
|
1077 | self.assertEqual(compressor.write(b"bar"), 0) | |
1059 | self.assertEqual(compressor.write(b"foo" * 16384), 0) |
|
1078 | self.assertEqual(compressor.write(b"foo" * 16384), 0) | |
1060 |
|
1079 | |||
1061 | compressed = buffer.getvalue() |
|
1080 | compressed = buffer.getvalue() | |
1062 |
|
1081 | |||
1063 | params = zstd.get_frame_parameters(compressed) |
|
1082 | params = zstd.get_frame_parameters(compressed) | |
1064 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1083 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1065 | self.assertEqual(params.window_size, 2097152) |
|
1084 | self.assertEqual(params.window_size, 2097152) | |
1066 | self.assertEqual(params.dict_id, d.dict_id()) |
|
1085 | self.assertEqual(params.dict_id, d.dict_id()) | |
1067 | self.assertFalse(params.has_checksum) |
|
1086 | self.assertFalse(params.has_checksum) | |
1068 |
|
1087 | |||
1069 | h = hashlib.sha1(compressed).hexdigest() |
|
1088 | h = hashlib.sha1(compressed).hexdigest() | |
1070 | self.assertEqual(h, "0a7c05635061f58039727cdbe76388c6f4cfef06") |
|
1089 | self.assertEqual(h, "0a7c05635061f58039727cdbe76388c6f4cfef06") | |
1071 |
|
1090 | |||
1072 | source = b"foo" + b"bar" + (b"foo" * 16384) |
|
1091 | source = b"foo" + b"bar" + (b"foo" * 16384) | |
1073 |
|
1092 | |||
1074 | dctx = zstd.ZstdDecompressor(dict_data=d) |
|
1093 | dctx = zstd.ZstdDecompressor(dict_data=d) | |
1075 |
|
1094 | |||
1076 | self.assertEqual( |
|
1095 | self.assertEqual( | |
1077 | dctx.decompress(compressed, max_output_size=len(source)), source |
|
1096 | dctx.decompress(compressed, max_output_size=len(source)), source | |
1078 | ) |
|
1097 | ) | |
1079 |
|
1098 | |||
1080 | def test_compression_params(self): |
|
1099 | def test_compression_params(self): | |
1081 | params = zstd.ZstdCompressionParameters( |
|
1100 | params = zstd.ZstdCompressionParameters( | |
1082 | window_log=20, |
|
1101 | window_log=20, | |
1083 | chain_log=6, |
|
1102 | chain_log=6, | |
1084 | hash_log=12, |
|
1103 | hash_log=12, | |
1085 | min_match=5, |
|
1104 | min_match=5, | |
1086 | search_log=4, |
|
1105 | search_log=4, | |
1087 | target_length=10, |
|
1106 | target_length=10, | |
1088 | strategy=zstd.STRATEGY_FAST, |
|
1107 | strategy=zstd.STRATEGY_FAST, | |
1089 | ) |
|
1108 | ) | |
1090 |
|
1109 | |||
1091 | buffer = NonClosingBytesIO() |
|
1110 | buffer = NonClosingBytesIO() | |
1092 | cctx = zstd.ZstdCompressor(compression_params=params) |
|
1111 | cctx = zstd.ZstdCompressor(compression_params=params) | |
1093 | with cctx.stream_writer(buffer) as compressor: |
|
1112 | with cctx.stream_writer(buffer) as compressor: | |
1094 | self.assertEqual(compressor.write(b"foo"), 0) |
|
1113 | self.assertEqual(compressor.write(b"foo"), 0) | |
1095 | self.assertEqual(compressor.write(b"bar"), 0) |
|
1114 | self.assertEqual(compressor.write(b"bar"), 0) | |
1096 | self.assertEqual(compressor.write(b"foobar" * 16384), 0) |
|
1115 | self.assertEqual(compressor.write(b"foobar" * 16384), 0) | |
1097 |
|
1116 | |||
1098 | compressed = buffer.getvalue() |
|
1117 | compressed = buffer.getvalue() | |
1099 |
|
1118 | |||
1100 | params = zstd.get_frame_parameters(compressed) |
|
1119 | params = zstd.get_frame_parameters(compressed) | |
1101 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1120 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1102 | self.assertEqual(params.window_size, 1048576) |
|
1121 | self.assertEqual(params.window_size, 1048576) | |
1103 | self.assertEqual(params.dict_id, 0) |
|
1122 | self.assertEqual(params.dict_id, 0) | |
1104 | self.assertFalse(params.has_checksum) |
|
1123 | self.assertFalse(params.has_checksum) | |
1105 |
|
1124 | |||
1106 | h = hashlib.sha1(compressed).hexdigest() |
|
1125 | h = hashlib.sha1(compressed).hexdigest() | |
1107 | self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b") |
|
1126 | self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b") | |
1108 |
|
1127 | |||
1109 | def test_write_checksum(self): |
|
1128 | def test_write_checksum(self): | |
1110 | no_checksum = NonClosingBytesIO() |
|
1129 | no_checksum = NonClosingBytesIO() | |
1111 | cctx = zstd.ZstdCompressor(level=1) |
|
1130 | cctx = zstd.ZstdCompressor(level=1) | |
1112 | with cctx.stream_writer(no_checksum) as compressor: |
|
1131 | with cctx.stream_writer(no_checksum) as compressor: | |
1113 | self.assertEqual(compressor.write(b"foobar"), 0) |
|
1132 | self.assertEqual(compressor.write(b"foobar"), 0) | |
1114 |
|
1133 | |||
1115 | with_checksum = NonClosingBytesIO() |
|
1134 | with_checksum = NonClosingBytesIO() | |
1116 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) |
|
1135 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | |
1117 | with cctx.stream_writer(with_checksum) as compressor: |
|
1136 | with cctx.stream_writer(with_checksum) as compressor: | |
1118 | self.assertEqual(compressor.write(b"foobar"), 0) |
|
1137 | self.assertEqual(compressor.write(b"foobar"), 0) | |
1119 |
|
1138 | |||
1120 | no_params = zstd.get_frame_parameters(no_checksum.getvalue()) |
|
1139 | no_params = zstd.get_frame_parameters(no_checksum.getvalue()) | |
1121 | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) |
|
1140 | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | |
1122 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1141 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1123 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1142 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1124 | self.assertEqual(no_params.dict_id, 0) |
|
1143 | self.assertEqual(no_params.dict_id, 0) | |
1125 | self.assertEqual(with_params.dict_id, 0) |
|
1144 | self.assertEqual(with_params.dict_id, 0) | |
1126 | self.assertFalse(no_params.has_checksum) |
|
1145 | self.assertFalse(no_params.has_checksum) | |
1127 | self.assertTrue(with_params.has_checksum) |
|
1146 | self.assertTrue(with_params.has_checksum) | |
1128 |
|
1147 | |||
1129 | self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4) |
|
1148 | self.assertEqual( | |
|
1149 | len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 | |||
|
1150 | ) | |||
1130 |
|
1151 | |||
1131 | def test_write_content_size(self): |
|
1152 | def test_write_content_size(self): | |
1132 | no_size = NonClosingBytesIO() |
|
1153 | no_size = NonClosingBytesIO() | |
1133 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
1154 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
1134 | with cctx.stream_writer(no_size) as compressor: |
|
1155 | with cctx.stream_writer(no_size) as compressor: | |
1135 | self.assertEqual(compressor.write(b"foobar" * 256), 0) |
|
1156 | self.assertEqual(compressor.write(b"foobar" * 256), 0) | |
1136 |
|
1157 | |||
1137 | with_size = NonClosingBytesIO() |
|
1158 | with_size = NonClosingBytesIO() | |
1138 | cctx = zstd.ZstdCompressor(level=1) |
|
1159 | cctx = zstd.ZstdCompressor(level=1) | |
1139 | with cctx.stream_writer(with_size) as compressor: |
|
1160 | with cctx.stream_writer(with_size) as compressor: | |
1140 | self.assertEqual(compressor.write(b"foobar" * 256), 0) |
|
1161 | self.assertEqual(compressor.write(b"foobar" * 256), 0) | |
1141 |
|
1162 | |||
1142 | # Source size is not known in streaming mode, so header not |
|
1163 | # Source size is not known in streaming mode, so header not | |
1143 | # written. |
|
1164 | # written. | |
1144 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) |
|
1165 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) | |
1145 |
|
1166 | |||
1146 | # Declaring size will write the header. |
|
1167 | # Declaring size will write the header. | |
1147 | with_size = NonClosingBytesIO() |
|
1168 | with_size = NonClosingBytesIO() | |
1148 | with cctx.stream_writer(with_size, size=len(b"foobar" * 256)) as compressor: |
|
1169 | with cctx.stream_writer( | |
|
1170 | with_size, size=len(b"foobar" * 256) | |||
|
1171 | ) as compressor: | |||
1149 | self.assertEqual(compressor.write(b"foobar" * 256), 0) |
|
1172 | self.assertEqual(compressor.write(b"foobar" * 256), 0) | |
1150 |
|
1173 | |||
1151 | no_params = zstd.get_frame_parameters(no_size.getvalue()) |
|
1174 | no_params = zstd.get_frame_parameters(no_size.getvalue()) | |
1152 | with_params = zstd.get_frame_parameters(with_size.getvalue()) |
|
1175 | with_params = zstd.get_frame_parameters(with_size.getvalue()) | |
1153 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1176 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1154 | self.assertEqual(with_params.content_size, 1536) |
|
1177 | self.assertEqual(with_params.content_size, 1536) | |
1155 | self.assertEqual(no_params.dict_id, 0) |
|
1178 | self.assertEqual(no_params.dict_id, 0) | |
1156 | self.assertEqual(with_params.dict_id, 0) |
|
1179 | self.assertEqual(with_params.dict_id, 0) | |
1157 | self.assertFalse(no_params.has_checksum) |
|
1180 | self.assertFalse(no_params.has_checksum) | |
1158 | self.assertFalse(with_params.has_checksum) |
|
1181 | self.assertFalse(with_params.has_checksum) | |
1159 |
|
1182 | |||
1160 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) |
|
1183 | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) | |
1161 |
|
1184 | |||
1162 | def test_no_dict_id(self): |
|
1185 | def test_no_dict_id(self): | |
1163 | samples = [] |
|
1186 | samples = [] | |
1164 | for i in range(128): |
|
1187 | for i in range(128): | |
1165 | samples.append(b"foo" * 64) |
|
1188 | samples.append(b"foo" * 64) | |
1166 | samples.append(b"bar" * 64) |
|
1189 | samples.append(b"bar" * 64) | |
1167 | samples.append(b"foobar" * 64) |
|
1190 | samples.append(b"foobar" * 64) | |
1168 |
|
1191 | |||
1169 | d = zstd.train_dictionary(1024, samples) |
|
1192 | d = zstd.train_dictionary(1024, samples) | |
1170 |
|
1193 | |||
1171 | with_dict_id = NonClosingBytesIO() |
|
1194 | with_dict_id = NonClosingBytesIO() | |
1172 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
|
1195 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) | |
1173 | with cctx.stream_writer(with_dict_id) as compressor: |
|
1196 | with cctx.stream_writer(with_dict_id) as compressor: | |
1174 | self.assertEqual(compressor.write(b"foobarfoobar"), 0) |
|
1197 | self.assertEqual(compressor.write(b"foobarfoobar"), 0) | |
1175 |
|
1198 | |||
1176 | self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03") |
|
1199 | self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03") | |
1177 |
|
1200 | |||
1178 | cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) |
|
1201 | cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) | |
1179 | no_dict_id = NonClosingBytesIO() |
|
1202 | no_dict_id = NonClosingBytesIO() | |
1180 | with cctx.stream_writer(no_dict_id) as compressor: |
|
1203 | with cctx.stream_writer(no_dict_id) as compressor: | |
1181 | self.assertEqual(compressor.write(b"foobarfoobar"), 0) |
|
1204 | self.assertEqual(compressor.write(b"foobarfoobar"), 0) | |
1182 |
|
1205 | |||
1183 | self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00") |
|
1206 | self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00") | |
1184 |
|
1207 | |||
1185 | no_params = zstd.get_frame_parameters(no_dict_id.getvalue()) |
|
1208 | no_params = zstd.get_frame_parameters(no_dict_id.getvalue()) | |
1186 | with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) |
|
1209 | with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) | |
1187 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1210 | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1188 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1211 | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1189 | self.assertEqual(no_params.dict_id, 0) |
|
1212 | self.assertEqual(no_params.dict_id, 0) | |
1190 | self.assertEqual(with_params.dict_id, d.dict_id()) |
|
1213 | self.assertEqual(with_params.dict_id, d.dict_id()) | |
1191 | self.assertFalse(no_params.has_checksum) |
|
1214 | self.assertFalse(no_params.has_checksum) | |
1192 | self.assertFalse(with_params.has_checksum) |
|
1215 | self.assertFalse(with_params.has_checksum) | |
1193 |
|
1216 | |||
1194 | self.assertEqual(len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4) |
|
1217 | self.assertEqual( | |
|
1218 | len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4 | |||
|
1219 | ) | |||
1195 |
|
1220 | |||
1196 | def test_memory_size(self): |
|
1221 | def test_memory_size(self): | |
1197 | cctx = zstd.ZstdCompressor(level=3) |
|
1222 | cctx = zstd.ZstdCompressor(level=3) | |
1198 | buffer = io.BytesIO() |
|
1223 | buffer = io.BytesIO() | |
1199 | with cctx.stream_writer(buffer) as compressor: |
|
1224 | with cctx.stream_writer(buffer) as compressor: | |
1200 | compressor.write(b"foo") |
|
1225 | compressor.write(b"foo") | |
1201 | size = compressor.memory_size() |
|
1226 | size = compressor.memory_size() | |
1202 |
|
1227 | |||
1203 | self.assertGreater(size, 100000) |
|
1228 | self.assertGreater(size, 100000) | |
1204 |
|
1229 | |||
1205 | def test_write_size(self): |
|
1230 | def test_write_size(self): | |
1206 | cctx = zstd.ZstdCompressor(level=3) |
|
1231 | cctx = zstd.ZstdCompressor(level=3) | |
1207 | dest = OpCountingBytesIO() |
|
1232 | dest = OpCountingBytesIO() | |
1208 | with cctx.stream_writer(dest, write_size=1) as compressor: |
|
1233 | with cctx.stream_writer(dest, write_size=1) as compressor: | |
1209 | self.assertEqual(compressor.write(b"foo"), 0) |
|
1234 | self.assertEqual(compressor.write(b"foo"), 0) | |
1210 | self.assertEqual(compressor.write(b"bar"), 0) |
|
1235 | self.assertEqual(compressor.write(b"bar"), 0) | |
1211 | self.assertEqual(compressor.write(b"foobar"), 0) |
|
1236 | self.assertEqual(compressor.write(b"foobar"), 0) | |
1212 |
|
1237 | |||
1213 | self.assertEqual(len(dest.getvalue()), dest._write_count) |
|
1238 | self.assertEqual(len(dest.getvalue()), dest._write_count) | |
1214 |
|
1239 | |||
1215 | def test_flush_repeated(self): |
|
1240 | def test_flush_repeated(self): | |
1216 | cctx = zstd.ZstdCompressor(level=3) |
|
1241 | cctx = zstd.ZstdCompressor(level=3) | |
1217 | dest = OpCountingBytesIO() |
|
1242 | dest = OpCountingBytesIO() | |
1218 | with cctx.stream_writer(dest) as compressor: |
|
1243 | with cctx.stream_writer(dest) as compressor: | |
1219 | self.assertEqual(compressor.write(b"foo"), 0) |
|
1244 | self.assertEqual(compressor.write(b"foo"), 0) | |
1220 | self.assertEqual(dest._write_count, 0) |
|
1245 | self.assertEqual(dest._write_count, 0) | |
1221 | self.assertEqual(compressor.flush(), 12) |
|
1246 | self.assertEqual(compressor.flush(), 12) | |
1222 | self.assertEqual(dest._write_count, 1) |
|
1247 | self.assertEqual(dest._write_count, 1) | |
1223 | self.assertEqual(compressor.write(b"bar"), 0) |
|
1248 | self.assertEqual(compressor.write(b"bar"), 0) | |
1224 | self.assertEqual(dest._write_count, 1) |
|
1249 | self.assertEqual(dest._write_count, 1) | |
1225 | self.assertEqual(compressor.flush(), 6) |
|
1250 | self.assertEqual(compressor.flush(), 6) | |
1226 | self.assertEqual(dest._write_count, 2) |
|
1251 | self.assertEqual(dest._write_count, 2) | |
1227 | self.assertEqual(compressor.write(b"baz"), 0) |
|
1252 | self.assertEqual(compressor.write(b"baz"), 0) | |
1228 |
|
1253 | |||
1229 | self.assertEqual(dest._write_count, 3) |
|
1254 | self.assertEqual(dest._write_count, 3) | |
1230 |
|
1255 | |||
1231 | def test_flush_empty_block(self): |
|
1256 | def test_flush_empty_block(self): | |
1232 | cctx = zstd.ZstdCompressor(level=3, write_checksum=True) |
|
1257 | cctx = zstd.ZstdCompressor(level=3, write_checksum=True) | |
1233 | dest = OpCountingBytesIO() |
|
1258 | dest = OpCountingBytesIO() | |
1234 | with cctx.stream_writer(dest) as compressor: |
|
1259 | with cctx.stream_writer(dest) as compressor: | |
1235 | self.assertEqual(compressor.write(b"foobar" * 8192), 0) |
|
1260 | self.assertEqual(compressor.write(b"foobar" * 8192), 0) | |
1236 | count = dest._write_count |
|
1261 | count = dest._write_count | |
1237 | offset = dest.tell() |
|
1262 | offset = dest.tell() | |
1238 | self.assertEqual(compressor.flush(), 23) |
|
1263 | self.assertEqual(compressor.flush(), 23) | |
1239 | self.assertGreater(dest._write_count, count) |
|
1264 | self.assertGreater(dest._write_count, count) | |
1240 | self.assertGreater(dest.tell(), offset) |
|
1265 | self.assertGreater(dest.tell(), offset) | |
1241 | offset = dest.tell() |
|
1266 | offset = dest.tell() | |
1242 | # Ending the write here should cause an empty block to be written |
|
1267 | # Ending the write here should cause an empty block to be written | |
1243 | # to denote end of frame. |
|
1268 | # to denote end of frame. | |
1244 |
|
1269 | |||
1245 | trailing = dest.getvalue()[offset:] |
|
1270 | trailing = dest.getvalue()[offset:] | |
1246 | # 3 bytes block header + 4 bytes frame checksum |
|
1271 | # 3 bytes block header + 4 bytes frame checksum | |
1247 | self.assertEqual(len(trailing), 7) |
|
1272 | self.assertEqual(len(trailing), 7) | |
1248 |
|
1273 | |||
1249 | header = trailing[0:3] |
|
1274 | header = trailing[0:3] | |
1250 | self.assertEqual(header, b"\x01\x00\x00") |
|
1275 | self.assertEqual(header, b"\x01\x00\x00") | |
1251 |
|
1276 | |||
1252 | def test_flush_frame(self): |
|
1277 | def test_flush_frame(self): | |
1253 | cctx = zstd.ZstdCompressor(level=3) |
|
1278 | cctx = zstd.ZstdCompressor(level=3) | |
1254 | dest = OpCountingBytesIO() |
|
1279 | dest = OpCountingBytesIO() | |
1255 |
|
1280 | |||
1256 | with cctx.stream_writer(dest) as compressor: |
|
1281 | with cctx.stream_writer(dest) as compressor: | |
1257 | self.assertEqual(compressor.write(b"foobar" * 8192), 0) |
|
1282 | self.assertEqual(compressor.write(b"foobar" * 8192), 0) | |
1258 | self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) |
|
1283 | self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) | |
1259 | compressor.write(b"biz" * 16384) |
|
1284 | compressor.write(b"biz" * 16384) | |
1260 |
|
1285 | |||
1261 | self.assertEqual( |
|
1286 | self.assertEqual( | |
1262 | dest.getvalue(), |
|
1287 | dest.getvalue(), | |
1263 | # Frame 1. |
|
1288 | # Frame 1. | |
1264 | b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f" |
|
1289 | b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f" | |
1265 | b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08" |
|
1290 | b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08" | |
1266 | # Frame 2. |
|
1291 | # Frame 2. | |
1267 | b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a" |
|
1292 | b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a" | |
1268 | b"\x01\x00\xfa\x3f\x75\x37\x04", |
|
1293 | b"\x01\x00\xfa\x3f\x75\x37\x04", | |
1269 | ) |
|
1294 | ) | |
1270 |
|
1295 | |||
1271 | def test_bad_flush_mode(self): |
|
1296 | def test_bad_flush_mode(self): | |
1272 | cctx = zstd.ZstdCompressor() |
|
1297 | cctx = zstd.ZstdCompressor() | |
1273 | dest = io.BytesIO() |
|
1298 | dest = io.BytesIO() | |
1274 | with cctx.stream_writer(dest) as compressor: |
|
1299 | with cctx.stream_writer(dest) as compressor: | |
1275 | with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"): |
|
1300 | with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"): | |
1276 | compressor.flush(flush_mode=42) |
|
1301 | compressor.flush(flush_mode=42) | |
1277 |
|
1302 | |||
1278 | def test_multithreaded(self): |
|
1303 | def test_multithreaded(self): | |
1279 | dest = NonClosingBytesIO() |
|
1304 | dest = NonClosingBytesIO() | |
1280 | cctx = zstd.ZstdCompressor(threads=2) |
|
1305 | cctx = zstd.ZstdCompressor(threads=2) | |
1281 | with cctx.stream_writer(dest) as compressor: |
|
1306 | with cctx.stream_writer(dest) as compressor: | |
1282 | compressor.write(b"a" * 1048576) |
|
1307 | compressor.write(b"a" * 1048576) | |
1283 | compressor.write(b"b" * 1048576) |
|
1308 | compressor.write(b"b" * 1048576) | |
1284 | compressor.write(b"c" * 1048576) |
|
1309 | compressor.write(b"c" * 1048576) | |
1285 |
|
1310 | |||
1286 | self.assertEqual(len(dest.getvalue()), 111) |
|
1311 | self.assertEqual(len(dest.getvalue()), 111) | |
1287 |
|
1312 | |||
1288 | def test_tell(self): |
|
1313 | def test_tell(self): | |
1289 | dest = io.BytesIO() |
|
1314 | dest = io.BytesIO() | |
1290 | cctx = zstd.ZstdCompressor() |
|
1315 | cctx = zstd.ZstdCompressor() | |
1291 | with cctx.stream_writer(dest) as compressor: |
|
1316 | with cctx.stream_writer(dest) as compressor: | |
1292 | self.assertEqual(compressor.tell(), 0) |
|
1317 | self.assertEqual(compressor.tell(), 0) | |
1293 |
|
1318 | |||
1294 | for i in range(256): |
|
1319 | for i in range(256): | |
1295 | compressor.write(b"foo" * (i + 1)) |
|
1320 | compressor.write(b"foo" * (i + 1)) | |
1296 | self.assertEqual(compressor.tell(), dest.tell()) |
|
1321 | self.assertEqual(compressor.tell(), dest.tell()) | |
1297 |
|
1322 | |||
1298 | def test_bad_size(self): |
|
1323 | def test_bad_size(self): | |
1299 | cctx = zstd.ZstdCompressor() |
|
1324 | cctx = zstd.ZstdCompressor() | |
1300 |
|
1325 | |||
1301 | dest = io.BytesIO() |
|
1326 | dest = io.BytesIO() | |
1302 |
|
1327 | |||
1303 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): |
|
1328 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | |
1304 | with cctx.stream_writer(dest, size=2) as compressor: |
|
1329 | with cctx.stream_writer(dest, size=2) as compressor: | |
1305 | compressor.write(b"foo") |
|
1330 | compressor.write(b"foo") | |
1306 |
|
1331 | |||
1307 | # Test another operation. |
|
1332 | # Test another operation. | |
1308 | with cctx.stream_writer(dest, size=42): |
|
1333 | with cctx.stream_writer(dest, size=42): | |
1309 | pass |
|
1334 | pass | |
1310 |
|
1335 | |||
1311 | def test_tarfile_compat(self): |
|
1336 | def test_tarfile_compat(self): | |
1312 | dest = NonClosingBytesIO() |
|
1337 | dest = NonClosingBytesIO() | |
1313 | cctx = zstd.ZstdCompressor() |
|
1338 | cctx = zstd.ZstdCompressor() | |
1314 | with cctx.stream_writer(dest) as compressor: |
|
1339 | with cctx.stream_writer(dest) as compressor: | |
1315 | with tarfile.open("tf", mode="w|", fileobj=compressor) as tf: |
|
1340 | with tarfile.open("tf", mode="w|", fileobj=compressor) as tf: | |
1316 | tf.add(__file__, "test_compressor.py") |
|
1341 | tf.add(__file__, "test_compressor.py") | |
1317 |
|
1342 | |||
1318 | dest = io.BytesIO(dest.getvalue()) |
|
1343 | dest = io.BytesIO(dest.getvalue()) | |
1319 |
|
1344 | |||
1320 | dctx = zstd.ZstdDecompressor() |
|
1345 | dctx = zstd.ZstdDecompressor() | |
1321 | with dctx.stream_reader(dest) as reader: |
|
1346 | with dctx.stream_reader(dest) as reader: | |
1322 | with tarfile.open(mode="r|", fileobj=reader) as tf: |
|
1347 | with tarfile.open(mode="r|", fileobj=reader) as tf: | |
1323 | for member in tf: |
|
1348 | for member in tf: | |
1324 | self.assertEqual(member.name, "test_compressor.py") |
|
1349 | self.assertEqual(member.name, "test_compressor.py") | |
1325 |
|
1350 | |||
1326 |
|
1351 | |||
1327 | @make_cffi |
|
1352 | @make_cffi | |
1328 | class TestCompressor_read_to_iter(TestCase): |
|
1353 | class TestCompressor_read_to_iter(TestCase): | |
1329 | def test_type_validation(self): |
|
1354 | def test_type_validation(self): | |
1330 | cctx = zstd.ZstdCompressor() |
|
1355 | cctx = zstd.ZstdCompressor() | |
1331 |
|
1356 | |||
1332 | # Object with read() works. |
|
1357 | # Object with read() works. | |
1333 | for chunk in cctx.read_to_iter(io.BytesIO()): |
|
1358 | for chunk in cctx.read_to_iter(io.BytesIO()): | |
1334 | pass |
|
1359 | pass | |
1335 |
|
1360 | |||
1336 | # Buffer protocol works. |
|
1361 | # Buffer protocol works. | |
1337 | for chunk in cctx.read_to_iter(b"foobar"): |
|
1362 | for chunk in cctx.read_to_iter(b"foobar"): | |
1338 | pass |
|
1363 | pass | |
1339 |
|
1364 | |||
1340 |
with self.assertRaisesRegex( |
|
1365 | with self.assertRaisesRegex( | |
|
1366 | ValueError, "must pass an object with a read" | |||
|
1367 | ): | |||
1341 | for chunk in cctx.read_to_iter(True): |
|
1368 | for chunk in cctx.read_to_iter(True): | |
1342 | pass |
|
1369 | pass | |
1343 |
|
1370 | |||
1344 | def test_read_empty(self): |
|
1371 | def test_read_empty(self): | |
1345 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
1372 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
1346 |
|
1373 | |||
1347 | source = io.BytesIO() |
|
1374 | source = io.BytesIO() | |
1348 | it = cctx.read_to_iter(source) |
|
1375 | it = cctx.read_to_iter(source) | |
1349 | chunks = list(it) |
|
1376 | chunks = list(it) | |
1350 | self.assertEqual(len(chunks), 1) |
|
1377 | self.assertEqual(len(chunks), 1) | |
1351 | compressed = b"".join(chunks) |
|
1378 | compressed = b"".join(chunks) | |
1352 | self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") |
|
1379 | self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | |
1353 |
|
1380 | |||
1354 | # And again with the buffer protocol. |
|
1381 | # And again with the buffer protocol. | |
1355 | it = cctx.read_to_iter(b"") |
|
1382 | it = cctx.read_to_iter(b"") | |
1356 | chunks = list(it) |
|
1383 | chunks = list(it) | |
1357 | self.assertEqual(len(chunks), 1) |
|
1384 | self.assertEqual(len(chunks), 1) | |
1358 | compressed2 = b"".join(chunks) |
|
1385 | compressed2 = b"".join(chunks) | |
1359 | self.assertEqual(compressed2, compressed) |
|
1386 | self.assertEqual(compressed2, compressed) | |
1360 |
|
1387 | |||
1361 | def test_read_large(self): |
|
1388 | def test_read_large(self): | |
1362 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
1389 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
1363 |
|
1390 | |||
1364 | source = io.BytesIO() |
|
1391 | source = io.BytesIO() | |
1365 | source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) |
|
1392 | source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) | |
1366 | source.write(b"o") |
|
1393 | source.write(b"o") | |
1367 | source.seek(0) |
|
1394 | source.seek(0) | |
1368 |
|
1395 | |||
1369 | # Creating an iterator should not perform any compression until |
|
1396 | # Creating an iterator should not perform any compression until | |
1370 | # first read. |
|
1397 | # first read. | |
1371 | it = cctx.read_to_iter(source, size=len(source.getvalue())) |
|
1398 | it = cctx.read_to_iter(source, size=len(source.getvalue())) | |
1372 | self.assertEqual(source.tell(), 0) |
|
1399 | self.assertEqual(source.tell(), 0) | |
1373 |
|
1400 | |||
1374 | # We should have exactly 2 output chunks. |
|
1401 | # We should have exactly 2 output chunks. | |
1375 | chunks = [] |
|
1402 | chunks = [] | |
1376 | chunk = next(it) |
|
1403 | chunk = next(it) | |
1377 | self.assertIsNotNone(chunk) |
|
1404 | self.assertIsNotNone(chunk) | |
1378 | self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) |
|
1405 | self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) | |
1379 | chunks.append(chunk) |
|
1406 | chunks.append(chunk) | |
1380 | chunk = next(it) |
|
1407 | chunk = next(it) | |
1381 | self.assertIsNotNone(chunk) |
|
1408 | self.assertIsNotNone(chunk) | |
1382 | chunks.append(chunk) |
|
1409 | chunks.append(chunk) | |
1383 |
|
1410 | |||
1384 | self.assertEqual(source.tell(), len(source.getvalue())) |
|
1411 | self.assertEqual(source.tell(), len(source.getvalue())) | |
1385 |
|
1412 | |||
1386 | with self.assertRaises(StopIteration): |
|
1413 | with self.assertRaises(StopIteration): | |
1387 | next(it) |
|
1414 | next(it) | |
1388 |
|
1415 | |||
1389 | # And again for good measure. |
|
1416 | # And again for good measure. | |
1390 | with self.assertRaises(StopIteration): |
|
1417 | with self.assertRaises(StopIteration): | |
1391 | next(it) |
|
1418 | next(it) | |
1392 |
|
1419 | |||
1393 | # We should get the same output as the one-shot compression mechanism. |
|
1420 | # We should get the same output as the one-shot compression mechanism. | |
1394 | self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) |
|
1421 | self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) | |
1395 |
|
1422 | |||
1396 | params = zstd.get_frame_parameters(b"".join(chunks)) |
|
1423 | params = zstd.get_frame_parameters(b"".join(chunks)) | |
1397 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1424 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1398 | self.assertEqual(params.window_size, 262144) |
|
1425 | self.assertEqual(params.window_size, 262144) | |
1399 | self.assertEqual(params.dict_id, 0) |
|
1426 | self.assertEqual(params.dict_id, 0) | |
1400 | self.assertFalse(params.has_checksum) |
|
1427 | self.assertFalse(params.has_checksum) | |
1401 |
|
1428 | |||
1402 | # Now check the buffer protocol. |
|
1429 | # Now check the buffer protocol. | |
1403 | it = cctx.read_to_iter(source.getvalue()) |
|
1430 | it = cctx.read_to_iter(source.getvalue()) | |
1404 | chunks = list(it) |
|
1431 | chunks = list(it) | |
1405 | self.assertEqual(len(chunks), 2) |
|
1432 | self.assertEqual(len(chunks), 2) | |
1406 |
|
1433 | |||
1407 | params = zstd.get_frame_parameters(b"".join(chunks)) |
|
1434 | params = zstd.get_frame_parameters(b"".join(chunks)) | |
1408 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
1435 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
1409 | # self.assertEqual(params.window_size, 262144) |
|
1436 | # self.assertEqual(params.window_size, 262144) | |
1410 | self.assertEqual(params.dict_id, 0) |
|
1437 | self.assertEqual(params.dict_id, 0) | |
1411 | self.assertFalse(params.has_checksum) |
|
1438 | self.assertFalse(params.has_checksum) | |
1412 |
|
1439 | |||
1413 | self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) |
|
1440 | self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) | |
1414 |
|
1441 | |||
1415 | def test_read_write_size(self): |
|
1442 | def test_read_write_size(self): | |
1416 | source = OpCountingBytesIO(b"foobarfoobar") |
|
1443 | source = OpCountingBytesIO(b"foobarfoobar") | |
1417 | cctx = zstd.ZstdCompressor(level=3) |
|
1444 | cctx = zstd.ZstdCompressor(level=3) | |
1418 | for chunk in cctx.read_to_iter(source, read_size=1, write_size=1): |
|
1445 | for chunk in cctx.read_to_iter(source, read_size=1, write_size=1): | |
1419 | self.assertEqual(len(chunk), 1) |
|
1446 | self.assertEqual(len(chunk), 1) | |
1420 |
|
1447 | |||
1421 | self.assertEqual(source._read_count, len(source.getvalue()) + 1) |
|
1448 | self.assertEqual(source._read_count, len(source.getvalue()) + 1) | |
1422 |
|
1449 | |||
1423 | def test_multithreaded(self): |
|
1450 | def test_multithreaded(self): | |
1424 | source = io.BytesIO() |
|
1451 | source = io.BytesIO() | |
1425 | source.write(b"a" * 1048576) |
|
1452 | source.write(b"a" * 1048576) | |
1426 | source.write(b"b" * 1048576) |
|
1453 | source.write(b"b" * 1048576) | |
1427 | source.write(b"c" * 1048576) |
|
1454 | source.write(b"c" * 1048576) | |
1428 | source.seek(0) |
|
1455 | source.seek(0) | |
1429 |
|
1456 | |||
1430 | cctx = zstd.ZstdCompressor(threads=2) |
|
1457 | cctx = zstd.ZstdCompressor(threads=2) | |
1431 |
|
1458 | |||
1432 | compressed = b"".join(cctx.read_to_iter(source)) |
|
1459 | compressed = b"".join(cctx.read_to_iter(source)) | |
1433 | self.assertEqual(len(compressed), 111) |
|
1460 | self.assertEqual(len(compressed), 111) | |
1434 |
|
1461 | |||
1435 | def test_bad_size(self): |
|
1462 | def test_bad_size(self): | |
1436 | cctx = zstd.ZstdCompressor() |
|
1463 | cctx = zstd.ZstdCompressor() | |
1437 |
|
1464 | |||
1438 | source = io.BytesIO(b"a" * 42) |
|
1465 | source = io.BytesIO(b"a" * 42) | |
1439 |
|
1466 | |||
1440 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): |
|
1467 | with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | |
1441 | b"".join(cctx.read_to_iter(source, size=2)) |
|
1468 | b"".join(cctx.read_to_iter(source, size=2)) | |
1442 |
|
1469 | |||
1443 | # Test another operation on errored compressor. |
|
1470 | # Test another operation on errored compressor. | |
1444 | b"".join(cctx.read_to_iter(source)) |
|
1471 | b"".join(cctx.read_to_iter(source)) | |
1445 |
|
1472 | |||
1446 |
|
1473 | |||
1447 | @make_cffi |
|
1474 | @make_cffi | |
1448 | class TestCompressor_chunker(TestCase): |
|
1475 | class TestCompressor_chunker(TestCase): | |
1449 | def test_empty(self): |
|
1476 | def test_empty(self): | |
1450 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
1477 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
1451 | chunker = cctx.chunker() |
|
1478 | chunker = cctx.chunker() | |
1452 |
|
1479 | |||
1453 | it = chunker.compress(b"") |
|
1480 | it = chunker.compress(b"") | |
1454 |
|
1481 | |||
1455 | with self.assertRaises(StopIteration): |
|
1482 | with self.assertRaises(StopIteration): | |
1456 | next(it) |
|
1483 | next(it) | |
1457 |
|
1484 | |||
1458 | it = chunker.finish() |
|
1485 | it = chunker.finish() | |
1459 |
|
1486 | |||
1460 | self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00") |
|
1487 | self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00") | |
1461 |
|
1488 | |||
1462 | with self.assertRaises(StopIteration): |
|
1489 | with self.assertRaises(StopIteration): | |
1463 | next(it) |
|
1490 | next(it) | |
1464 |
|
1491 | |||
1465 | def test_simple_input(self): |
|
1492 | def test_simple_input(self): | |
1466 | cctx = zstd.ZstdCompressor() |
|
1493 | cctx = zstd.ZstdCompressor() | |
1467 | chunker = cctx.chunker() |
|
1494 | chunker = cctx.chunker() | |
1468 |
|
1495 | |||
1469 | it = chunker.compress(b"foobar") |
|
1496 | it = chunker.compress(b"foobar") | |
1470 |
|
1497 | |||
1471 | with self.assertRaises(StopIteration): |
|
1498 | with self.assertRaises(StopIteration): | |
1472 | next(it) |
|
1499 | next(it) | |
1473 |
|
1500 | |||
1474 | it = chunker.compress(b"baz" * 30) |
|
1501 | it = chunker.compress(b"baz" * 30) | |
1475 |
|
1502 | |||
1476 | with self.assertRaises(StopIteration): |
|
1503 | with self.assertRaises(StopIteration): | |
1477 | next(it) |
|
1504 | next(it) | |
1478 |
|
1505 | |||
1479 | it = chunker.finish() |
|
1506 | it = chunker.finish() | |
1480 |
|
1507 | |||
1481 | self.assertEqual( |
|
1508 | self.assertEqual( | |
1482 | next(it), |
|
1509 | next(it), | |
1483 | b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f" |
|
1510 | b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f" | |
1484 | b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e", |
|
1511 | b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e", | |
1485 | ) |
|
1512 | ) | |
1486 |
|
1513 | |||
1487 | with self.assertRaises(StopIteration): |
|
1514 | with self.assertRaises(StopIteration): | |
1488 | next(it) |
|
1515 | next(it) | |
1489 |
|
1516 | |||
1490 | def test_input_size(self): |
|
1517 | def test_input_size(self): | |
1491 | cctx = zstd.ZstdCompressor() |
|
1518 | cctx = zstd.ZstdCompressor() | |
1492 | chunker = cctx.chunker(size=1024) |
|
1519 | chunker = cctx.chunker(size=1024) | |
1493 |
|
1520 | |||
1494 | it = chunker.compress(b"x" * 1000) |
|
1521 | it = chunker.compress(b"x" * 1000) | |
1495 |
|
1522 | |||
1496 | with self.assertRaises(StopIteration): |
|
1523 | with self.assertRaises(StopIteration): | |
1497 | next(it) |
|
1524 | next(it) | |
1498 |
|
1525 | |||
1499 | it = chunker.compress(b"y" * 24) |
|
1526 | it = chunker.compress(b"y" * 24) | |
1500 |
|
1527 | |||
1501 | with self.assertRaises(StopIteration): |
|
1528 | with self.assertRaises(StopIteration): | |
1502 | next(it) |
|
1529 | next(it) | |
1503 |
|
1530 | |||
1504 | chunks = list(chunker.finish()) |
|
1531 | chunks = list(chunker.finish()) | |
1505 |
|
1532 | |||
1506 | self.assertEqual( |
|
1533 | self.assertEqual( | |
1507 | chunks, |
|
1534 | chunks, | |
1508 | [ |
|
1535 | [ | |
1509 | b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" |
|
1536 | b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" | |
1510 | b"\xa0\x16\xe3\x2b\x80\x05" |
|
1537 | b"\xa0\x16\xe3\x2b\x80\x05" | |
1511 | ], |
|
1538 | ], | |
1512 | ) |
|
1539 | ) | |
1513 |
|
1540 | |||
1514 | dctx = zstd.ZstdDecompressor() |
|
1541 | dctx = zstd.ZstdDecompressor() | |
1515 |
|
1542 | |||
1516 | self.assertEqual(dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24)) |
|
1543 | self.assertEqual( | |
|
1544 | dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24) | |||
|
1545 | ) | |||
1517 |
|
1546 | |||
1518 | def test_small_chunk_size(self): |
|
1547 | def test_small_chunk_size(self): | |
1519 | cctx = zstd.ZstdCompressor() |
|
1548 | cctx = zstd.ZstdCompressor() | |
1520 | chunker = cctx.chunker(chunk_size=1) |
|
1549 | chunker = cctx.chunker(chunk_size=1) | |
1521 |
|
1550 | |||
1522 | chunks = list(chunker.compress(b"foo" * 1024)) |
|
1551 | chunks = list(chunker.compress(b"foo" * 1024)) | |
1523 | self.assertEqual(chunks, []) |
|
1552 | self.assertEqual(chunks, []) | |
1524 |
|
1553 | |||
1525 | chunks = list(chunker.finish()) |
|
1554 | chunks = list(chunker.finish()) | |
1526 | self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) |
|
1555 | self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) | |
1527 |
|
1556 | |||
1528 | self.assertEqual( |
|
1557 | self.assertEqual( | |
1529 | b"".join(chunks), |
|
1558 | b"".join(chunks), | |
1530 | b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" |
|
1559 | b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" | |
1531 | b"\xfa\xd3\x77\x43", |
|
1560 | b"\xfa\xd3\x77\x43", | |
1532 | ) |
|
1561 | ) | |
1533 |
|
1562 | |||
1534 | dctx = zstd.ZstdDecompressor() |
|
1563 | dctx = zstd.ZstdDecompressor() | |
1535 | self.assertEqual( |
|
1564 | self.assertEqual( | |
1536 |
dctx.decompress(b"".join(chunks), max_output_size=10000), |
|
1565 | dctx.decompress(b"".join(chunks), max_output_size=10000), | |
|
1566 | b"foo" * 1024, | |||
1537 | ) |
|
1567 | ) | |
1538 |
|
1568 | |||
1539 | def test_input_types(self): |
|
1569 | def test_input_types(self): | |
1540 | cctx = zstd.ZstdCompressor() |
|
1570 | cctx = zstd.ZstdCompressor() | |
1541 |
|
1571 | |||
1542 | mutable_array = bytearray(3) |
|
1572 | mutable_array = bytearray(3) | |
1543 | mutable_array[:] = b"foo" |
|
1573 | mutable_array[:] = b"foo" | |
1544 |
|
1574 | |||
1545 | sources = [ |
|
1575 | sources = [ | |
1546 | memoryview(b"foo"), |
|
1576 | memoryview(b"foo"), | |
1547 | bytearray(b"foo"), |
|
1577 | bytearray(b"foo"), | |
1548 | mutable_array, |
|
1578 | mutable_array, | |
1549 | ] |
|
1579 | ] | |
1550 |
|
1580 | |||
1551 | for source in sources: |
|
1581 | for source in sources: | |
1552 | chunker = cctx.chunker() |
|
1582 | chunker = cctx.chunker() | |
1553 |
|
1583 | |||
1554 | self.assertEqual(list(chunker.compress(source)), []) |
|
1584 | self.assertEqual(list(chunker.compress(source)), []) | |
1555 | self.assertEqual( |
|
1585 | self.assertEqual( | |
1556 | list(chunker.finish()), |
|
1586 | list(chunker.finish()), | |
1557 | [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"], |
|
1587 | [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"], | |
1558 | ) |
|
1588 | ) | |
1559 |
|
1589 | |||
1560 | def test_flush(self): |
|
1590 | def test_flush(self): | |
1561 | cctx = zstd.ZstdCompressor() |
|
1591 | cctx = zstd.ZstdCompressor() | |
1562 | chunker = cctx.chunker() |
|
1592 | chunker = cctx.chunker() | |
1563 |
|
1593 | |||
1564 | self.assertEqual(list(chunker.compress(b"foo" * 1024)), []) |
|
1594 | self.assertEqual(list(chunker.compress(b"foo" * 1024)), []) | |
1565 | self.assertEqual(list(chunker.compress(b"bar" * 1024)), []) |
|
1595 | self.assertEqual(list(chunker.compress(b"bar" * 1024)), []) | |
1566 |
|
1596 | |||
1567 | chunks1 = list(chunker.flush()) |
|
1597 | chunks1 = list(chunker.flush()) | |
1568 |
|
1598 | |||
1569 | self.assertEqual( |
|
1599 | self.assertEqual( | |
1570 | chunks1, |
|
1600 | chunks1, | |
1571 | [ |
|
1601 | [ | |
1572 | b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72" |
|
1602 | b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72" | |
1573 | b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02" |
|
1603 | b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02" | |
1574 | ], |
|
1604 | ], | |
1575 | ) |
|
1605 | ) | |
1576 |
|
1606 | |||
1577 | self.assertEqual(list(chunker.flush()), []) |
|
1607 | self.assertEqual(list(chunker.flush()), []) | |
1578 | self.assertEqual(list(chunker.flush()), []) |
|
1608 | self.assertEqual(list(chunker.flush()), []) | |
1579 |
|
1609 | |||
1580 | self.assertEqual(list(chunker.compress(b"baz" * 1024)), []) |
|
1610 | self.assertEqual(list(chunker.compress(b"baz" * 1024)), []) | |
1581 |
|
1611 | |||
1582 | chunks2 = list(chunker.flush()) |
|
1612 | chunks2 = list(chunker.flush()) | |
1583 | self.assertEqual(len(chunks2), 1) |
|
1613 | self.assertEqual(len(chunks2), 1) | |
1584 |
|
1614 | |||
1585 | chunks3 = list(chunker.finish()) |
|
1615 | chunks3 = list(chunker.finish()) | |
1586 | self.assertEqual(len(chunks2), 1) |
|
1616 | self.assertEqual(len(chunks2), 1) | |
1587 |
|
1617 | |||
1588 | dctx = zstd.ZstdDecompressor() |
|
1618 | dctx = zstd.ZstdDecompressor() | |
1589 |
|
1619 | |||
1590 | self.assertEqual( |
|
1620 | self.assertEqual( | |
1591 | dctx.decompress( |
|
1621 | dctx.decompress( | |
1592 | b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000 |
|
1622 | b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000 | |
1593 | ), |
|
1623 | ), | |
1594 | (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024), |
|
1624 | (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024), | |
1595 | ) |
|
1625 | ) | |
1596 |
|
1626 | |||
1597 | def test_compress_after_finish(self): |
|
1627 | def test_compress_after_finish(self): | |
1598 | cctx = zstd.ZstdCompressor() |
|
1628 | cctx = zstd.ZstdCompressor() | |
1599 | chunker = cctx.chunker() |
|
1629 | chunker = cctx.chunker() | |
1600 |
|
1630 | |||
1601 | list(chunker.compress(b"foo")) |
|
1631 | list(chunker.compress(b"foo")) | |
1602 | list(chunker.finish()) |
|
1632 | list(chunker.finish()) | |
1603 |
|
1633 | |||
1604 | with self.assertRaisesRegex( |
|
1634 | with self.assertRaisesRegex( | |
1605 | zstd.ZstdError, r"cannot call compress\(\) after compression finished" |
|
1635 | zstd.ZstdError, | |
|
1636 | r"cannot call compress\(\) after compression finished", | |||
1606 | ): |
|
1637 | ): | |
1607 | list(chunker.compress(b"foo")) |
|
1638 | list(chunker.compress(b"foo")) | |
1608 |
|
1639 | |||
1609 | def test_flush_after_finish(self): |
|
1640 | def test_flush_after_finish(self): | |
1610 | cctx = zstd.ZstdCompressor() |
|
1641 | cctx = zstd.ZstdCompressor() | |
1611 | chunker = cctx.chunker() |
|
1642 | chunker = cctx.chunker() | |
1612 |
|
1643 | |||
1613 | list(chunker.compress(b"foo")) |
|
1644 | list(chunker.compress(b"foo")) | |
1614 | list(chunker.finish()) |
|
1645 | list(chunker.finish()) | |
1615 |
|
1646 | |||
1616 | with self.assertRaisesRegex( |
|
1647 | with self.assertRaisesRegex( | |
1617 | zstd.ZstdError, r"cannot call flush\(\) after compression finished" |
|
1648 | zstd.ZstdError, r"cannot call flush\(\) after compression finished" | |
1618 | ): |
|
1649 | ): | |
1619 | list(chunker.flush()) |
|
1650 | list(chunker.flush()) | |
1620 |
|
1651 | |||
1621 | def test_finish_after_finish(self): |
|
1652 | def test_finish_after_finish(self): | |
1622 | cctx = zstd.ZstdCompressor() |
|
1653 | cctx = zstd.ZstdCompressor() | |
1623 | chunker = cctx.chunker() |
|
1654 | chunker = cctx.chunker() | |
1624 |
|
1655 | |||
1625 | list(chunker.compress(b"foo")) |
|
1656 | list(chunker.compress(b"foo")) | |
1626 | list(chunker.finish()) |
|
1657 | list(chunker.finish()) | |
1627 |
|
1658 | |||
1628 | with self.assertRaisesRegex( |
|
1659 | with self.assertRaisesRegex( | |
1629 | zstd.ZstdError, r"cannot call finish\(\) after compression finished" |
|
1660 | zstd.ZstdError, r"cannot call finish\(\) after compression finished" | |
1630 | ): |
|
1661 | ): | |
1631 | list(chunker.finish()) |
|
1662 | list(chunker.finish()) | |
1632 |
|
1663 | |||
1633 |
|
1664 | |||
1634 | class TestCompressor_multi_compress_to_buffer(TestCase): |
|
1665 | class TestCompressor_multi_compress_to_buffer(TestCase): | |
1635 | def test_invalid_inputs(self): |
|
1666 | def test_invalid_inputs(self): | |
1636 | cctx = zstd.ZstdCompressor() |
|
1667 | cctx = zstd.ZstdCompressor() | |
1637 |
|
1668 | |||
1638 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1669 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1639 | self.skipTest("multi_compress_to_buffer not available") |
|
1670 | self.skipTest("multi_compress_to_buffer not available") | |
1640 |
|
1671 | |||
1641 | with self.assertRaises(TypeError): |
|
1672 | with self.assertRaises(TypeError): | |
1642 | cctx.multi_compress_to_buffer(True) |
|
1673 | cctx.multi_compress_to_buffer(True) | |
1643 |
|
1674 | |||
1644 | with self.assertRaises(TypeError): |
|
1675 | with self.assertRaises(TypeError): | |
1645 | cctx.multi_compress_to_buffer((1, 2)) |
|
1676 | cctx.multi_compress_to_buffer((1, 2)) | |
1646 |
|
1677 | |||
1647 |
with self.assertRaisesRegex( |
|
1678 | with self.assertRaisesRegex( | |
|
1679 | TypeError, "item 0 not a bytes like object" | |||
|
1680 | ): | |||
1648 | cctx.multi_compress_to_buffer([u"foo"]) |
|
1681 | cctx.multi_compress_to_buffer([u"foo"]) | |
1649 |
|
1682 | |||
1650 | def test_empty_input(self): |
|
1683 | def test_empty_input(self): | |
1651 | cctx = zstd.ZstdCompressor() |
|
1684 | cctx = zstd.ZstdCompressor() | |
1652 |
|
1685 | |||
1653 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1686 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1654 | self.skipTest("multi_compress_to_buffer not available") |
|
1687 | self.skipTest("multi_compress_to_buffer not available") | |
1655 |
|
1688 | |||
1656 | with self.assertRaisesRegex(ValueError, "no source elements found"): |
|
1689 | with self.assertRaisesRegex(ValueError, "no source elements found"): | |
1657 | cctx.multi_compress_to_buffer([]) |
|
1690 | cctx.multi_compress_to_buffer([]) | |
1658 |
|
1691 | |||
1659 | with self.assertRaisesRegex(ValueError, "source elements are empty"): |
|
1692 | with self.assertRaisesRegex(ValueError, "source elements are empty"): | |
1660 | cctx.multi_compress_to_buffer([b"", b"", b""]) |
|
1693 | cctx.multi_compress_to_buffer([b"", b"", b""]) | |
1661 |
|
1694 | |||
1662 | def test_list_input(self): |
|
1695 | def test_list_input(self): | |
1663 | cctx = zstd.ZstdCompressor(write_checksum=True) |
|
1696 | cctx = zstd.ZstdCompressor(write_checksum=True) | |
1664 |
|
1697 | |||
1665 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1698 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1666 | self.skipTest("multi_compress_to_buffer not available") |
|
1699 | self.skipTest("multi_compress_to_buffer not available") | |
1667 |
|
1700 | |||
1668 | original = [b"foo" * 12, b"bar" * 6] |
|
1701 | original = [b"foo" * 12, b"bar" * 6] | |
1669 | frames = [cctx.compress(c) for c in original] |
|
1702 | frames = [cctx.compress(c) for c in original] | |
1670 | b = cctx.multi_compress_to_buffer(original) |
|
1703 | b = cctx.multi_compress_to_buffer(original) | |
1671 |
|
1704 | |||
1672 | self.assertIsInstance(b, zstd.BufferWithSegmentsCollection) |
|
1705 | self.assertIsInstance(b, zstd.BufferWithSegmentsCollection) | |
1673 |
|
1706 | |||
1674 | self.assertEqual(len(b), 2) |
|
1707 | self.assertEqual(len(b), 2) | |
1675 | self.assertEqual(b.size(), 44) |
|
1708 | self.assertEqual(b.size(), 44) | |
1676 |
|
1709 | |||
1677 | self.assertEqual(b[0].tobytes(), frames[0]) |
|
1710 | self.assertEqual(b[0].tobytes(), frames[0]) | |
1678 | self.assertEqual(b[1].tobytes(), frames[1]) |
|
1711 | self.assertEqual(b[1].tobytes(), frames[1]) | |
1679 |
|
1712 | |||
1680 | def test_buffer_with_segments_input(self): |
|
1713 | def test_buffer_with_segments_input(self): | |
1681 | cctx = zstd.ZstdCompressor(write_checksum=True) |
|
1714 | cctx = zstd.ZstdCompressor(write_checksum=True) | |
1682 |
|
1715 | |||
1683 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1716 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1684 | self.skipTest("multi_compress_to_buffer not available") |
|
1717 | self.skipTest("multi_compress_to_buffer not available") | |
1685 |
|
1718 | |||
1686 | original = [b"foo" * 4, b"bar" * 6] |
|
1719 | original = [b"foo" * 4, b"bar" * 6] | |
1687 | frames = [cctx.compress(c) for c in original] |
|
1720 | frames = [cctx.compress(c) for c in original] | |
1688 |
|
1721 | |||
1689 | offsets = struct.pack( |
|
1722 | offsets = struct.pack( | |
1690 | "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) |
|
1723 | "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) | |
1691 | ) |
|
1724 | ) | |
1692 | segments = zstd.BufferWithSegments(b"".join(original), offsets) |
|
1725 | segments = zstd.BufferWithSegments(b"".join(original), offsets) | |
1693 |
|
1726 | |||
1694 | result = cctx.multi_compress_to_buffer(segments) |
|
1727 | result = cctx.multi_compress_to_buffer(segments) | |
1695 |
|
1728 | |||
1696 | self.assertEqual(len(result), 2) |
|
1729 | self.assertEqual(len(result), 2) | |
1697 | self.assertEqual(result.size(), 47) |
|
1730 | self.assertEqual(result.size(), 47) | |
1698 |
|
1731 | |||
1699 | self.assertEqual(result[0].tobytes(), frames[0]) |
|
1732 | self.assertEqual(result[0].tobytes(), frames[0]) | |
1700 | self.assertEqual(result[1].tobytes(), frames[1]) |
|
1733 | self.assertEqual(result[1].tobytes(), frames[1]) | |
1701 |
|
1734 | |||
1702 | def test_buffer_with_segments_collection_input(self): |
|
1735 | def test_buffer_with_segments_collection_input(self): | |
1703 | cctx = zstd.ZstdCompressor(write_checksum=True) |
|
1736 | cctx = zstd.ZstdCompressor(write_checksum=True) | |
1704 |
|
1737 | |||
1705 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1738 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1706 | self.skipTest("multi_compress_to_buffer not available") |
|
1739 | self.skipTest("multi_compress_to_buffer not available") | |
1707 |
|
1740 | |||
1708 | original = [ |
|
1741 | original = [ | |
1709 | b"foo1", |
|
1742 | b"foo1", | |
1710 | b"foo2" * 2, |
|
1743 | b"foo2" * 2, | |
1711 | b"foo3" * 3, |
|
1744 | b"foo3" * 3, | |
1712 | b"foo4" * 4, |
|
1745 | b"foo4" * 4, | |
1713 | b"foo5" * 5, |
|
1746 | b"foo5" * 5, | |
1714 | ] |
|
1747 | ] | |
1715 |
|
1748 | |||
1716 | frames = [cctx.compress(c) for c in original] |
|
1749 | frames = [cctx.compress(c) for c in original] | |
1717 |
|
1750 | |||
1718 | b = b"".join([original[0], original[1]]) |
|
1751 | b = b"".join([original[0], original[1]]) | |
1719 | b1 = zstd.BufferWithSegments( |
|
1752 | b1 = zstd.BufferWithSegments( | |
1720 | b, |
|
1753 | b, | |
1721 | struct.pack( |
|
1754 | struct.pack( | |
1722 | "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) |
|
1755 | "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) | |
1723 | ), |
|
1756 | ), | |
1724 | ) |
|
1757 | ) | |
1725 | b = b"".join([original[2], original[3], original[4]]) |
|
1758 | b = b"".join([original[2], original[3], original[4]]) | |
1726 | b2 = zstd.BufferWithSegments( |
|
1759 | b2 = zstd.BufferWithSegments( | |
1727 | b, |
|
1760 | b, | |
1728 | struct.pack( |
|
1761 | struct.pack( | |
1729 | "=QQQQQQ", |
|
1762 | "=QQQQQQ", | |
1730 | 0, |
|
1763 | 0, | |
1731 | len(original[2]), |
|
1764 | len(original[2]), | |
1732 | len(original[2]), |
|
1765 | len(original[2]), | |
1733 | len(original[3]), |
|
1766 | len(original[3]), | |
1734 | len(original[2]) + len(original[3]), |
|
1767 | len(original[2]) + len(original[3]), | |
1735 | len(original[4]), |
|
1768 | len(original[4]), | |
1736 | ), |
|
1769 | ), | |
1737 | ) |
|
1770 | ) | |
1738 |
|
1771 | |||
1739 | c = zstd.BufferWithSegmentsCollection(b1, b2) |
|
1772 | c = zstd.BufferWithSegmentsCollection(b1, b2) | |
1740 |
|
1773 | |||
1741 | result = cctx.multi_compress_to_buffer(c) |
|
1774 | result = cctx.multi_compress_to_buffer(c) | |
1742 |
|
1775 | |||
1743 | self.assertEqual(len(result), len(frames)) |
|
1776 | self.assertEqual(len(result), len(frames)) | |
1744 |
|
1777 | |||
1745 | for i, frame in enumerate(frames): |
|
1778 | for i, frame in enumerate(frames): | |
1746 | self.assertEqual(result[i].tobytes(), frame) |
|
1779 | self.assertEqual(result[i].tobytes(), frame) | |
1747 |
|
1780 | |||
1748 | def test_multiple_threads(self): |
|
1781 | def test_multiple_threads(self): | |
1749 | # threads argument will cause multi-threaded ZSTD APIs to be used, which will |
|
1782 | # threads argument will cause multi-threaded ZSTD APIs to be used, which will | |
1750 | # make output different. |
|
1783 | # make output different. | |
1751 | refcctx = zstd.ZstdCompressor(write_checksum=True) |
|
1784 | refcctx = zstd.ZstdCompressor(write_checksum=True) | |
1752 | reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)] |
|
1785 | reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)] | |
1753 |
|
1786 | |||
1754 | cctx = zstd.ZstdCompressor(write_checksum=True) |
|
1787 | cctx = zstd.ZstdCompressor(write_checksum=True) | |
1755 |
|
1788 | |||
1756 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1789 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1757 | self.skipTest("multi_compress_to_buffer not available") |
|
1790 | self.skipTest("multi_compress_to_buffer not available") | |
1758 |
|
1791 | |||
1759 | frames = [] |
|
1792 | frames = [] | |
1760 | frames.extend(b"x" * 64 for i in range(256)) |
|
1793 | frames.extend(b"x" * 64 for i in range(256)) | |
1761 | frames.extend(b"y" * 64 for i in range(256)) |
|
1794 | frames.extend(b"y" * 64 for i in range(256)) | |
1762 |
|
1795 | |||
1763 | result = cctx.multi_compress_to_buffer(frames, threads=-1) |
|
1796 | result = cctx.multi_compress_to_buffer(frames, threads=-1) | |
1764 |
|
1797 | |||
1765 | self.assertEqual(len(result), 512) |
|
1798 | self.assertEqual(len(result), 512) | |
1766 | for i in range(512): |
|
1799 | for i in range(512): | |
1767 | if i < 256: |
|
1800 | if i < 256: | |
1768 | self.assertEqual(result[i].tobytes(), reference[0]) |
|
1801 | self.assertEqual(result[i].tobytes(), reference[0]) | |
1769 | else: |
|
1802 | else: | |
1770 | self.assertEqual(result[i].tobytes(), reference[1]) |
|
1803 | self.assertEqual(result[i].tobytes(), reference[1]) |
@@ -1,836 +1,884 b'' | |||||
1 | import io |
|
1 | import io | |
2 | import os |
|
2 | import os | |
3 | import unittest |
|
3 | import unittest | |
4 |
|
4 | |||
5 | try: |
|
5 | try: | |
6 | import hypothesis |
|
6 | import hypothesis | |
7 | import hypothesis.strategies as strategies |
|
7 | import hypothesis.strategies as strategies | |
8 | except ImportError: |
|
8 | except ImportError: | |
9 | raise unittest.SkipTest("hypothesis not available") |
|
9 | raise unittest.SkipTest("hypothesis not available") | |
10 |
|
10 | |||
11 | import zstandard as zstd |
|
11 | import zstandard as zstd | |
12 |
|
12 | |||
13 | from .common import ( |
|
13 | from .common import ( | |
14 | make_cffi, |
|
14 | make_cffi, | |
15 | NonClosingBytesIO, |
|
15 | NonClosingBytesIO, | |
16 | random_input_data, |
|
16 | random_input_data, | |
17 | TestCase, |
|
17 | TestCase, | |
18 | ) |
|
18 | ) | |
19 |
|
19 | |||
20 |
|
20 | |||
21 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
21 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
22 | @make_cffi |
|
22 | @make_cffi | |
23 | class TestCompressor_stream_reader_fuzzing(TestCase): |
|
23 | class TestCompressor_stream_reader_fuzzing(TestCase): | |
24 | @hypothesis.settings( |
|
24 | @hypothesis.settings( | |
25 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
25 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
26 | ) |
|
26 | ) | |
27 | @hypothesis.given( |
|
27 | @hypothesis.given( | |
28 | original=strategies.sampled_from(random_input_data()), |
|
28 | original=strategies.sampled_from(random_input_data()), | |
29 | level=strategies.integers(min_value=1, max_value=5), |
|
29 | level=strategies.integers(min_value=1, max_value=5), | |
30 | source_read_size=strategies.integers(1, 16384), |
|
30 | source_read_size=strategies.integers(1, 16384), | |
31 |
read_size=strategies.integers( |
|
31 | read_size=strategies.integers( | |
|
32 | -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
33 | ), | |||
32 | ) |
|
34 | ) | |
33 | def test_stream_source_read(self, original, level, source_read_size, read_size): |
|
35 | def test_stream_source_read( | |
|
36 | self, original, level, source_read_size, read_size | |||
|
37 | ): | |||
34 | if read_size == 0: |
|
38 | if read_size == 0: | |
35 | read_size = -1 |
|
39 | read_size = -1 | |
36 |
|
40 | |||
37 | refctx = zstd.ZstdCompressor(level=level) |
|
41 | refctx = zstd.ZstdCompressor(level=level) | |
38 | ref_frame = refctx.compress(original) |
|
42 | ref_frame = refctx.compress(original) | |
39 |
|
43 | |||
40 | cctx = zstd.ZstdCompressor(level=level) |
|
44 | cctx = zstd.ZstdCompressor(level=level) | |
41 | with cctx.stream_reader( |
|
45 | with cctx.stream_reader( | |
42 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
46 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
43 | ) as reader: |
|
47 | ) as reader: | |
44 | chunks = [] |
|
48 | chunks = [] | |
45 | while True: |
|
49 | while True: | |
46 | chunk = reader.read(read_size) |
|
50 | chunk = reader.read(read_size) | |
47 | if not chunk: |
|
51 | if not chunk: | |
48 | break |
|
52 | break | |
49 |
|
53 | |||
50 | chunks.append(chunk) |
|
54 | chunks.append(chunk) | |
51 |
|
55 | |||
52 | self.assertEqual(b"".join(chunks), ref_frame) |
|
56 | self.assertEqual(b"".join(chunks), ref_frame) | |
53 |
|
57 | |||
54 | @hypothesis.settings( |
|
58 | @hypothesis.settings( | |
55 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
59 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
56 | ) |
|
60 | ) | |
57 | @hypothesis.given( |
|
61 | @hypothesis.given( | |
58 | original=strategies.sampled_from(random_input_data()), |
|
62 | original=strategies.sampled_from(random_input_data()), | |
59 | level=strategies.integers(min_value=1, max_value=5), |
|
63 | level=strategies.integers(min_value=1, max_value=5), | |
60 | source_read_size=strategies.integers(1, 16384), |
|
64 | source_read_size=strategies.integers(1, 16384), | |
61 |
read_size=strategies.integers( |
|
65 | read_size=strategies.integers( | |
|
66 | -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
67 | ), | |||
62 | ) |
|
68 | ) | |
63 | def test_buffer_source_read(self, original, level, source_read_size, read_size): |
|
69 | def test_buffer_source_read( | |
|
70 | self, original, level, source_read_size, read_size | |||
|
71 | ): | |||
64 | if read_size == 0: |
|
72 | if read_size == 0: | |
65 | read_size = -1 |
|
73 | read_size = -1 | |
66 |
|
74 | |||
67 | refctx = zstd.ZstdCompressor(level=level) |
|
75 | refctx = zstd.ZstdCompressor(level=level) | |
68 | ref_frame = refctx.compress(original) |
|
76 | ref_frame = refctx.compress(original) | |
69 |
|
77 | |||
70 | cctx = zstd.ZstdCompressor(level=level) |
|
78 | cctx = zstd.ZstdCompressor(level=level) | |
71 | with cctx.stream_reader( |
|
79 | with cctx.stream_reader( | |
72 | original, size=len(original), read_size=source_read_size |
|
80 | original, size=len(original), read_size=source_read_size | |
73 | ) as reader: |
|
81 | ) as reader: | |
74 | chunks = [] |
|
82 | chunks = [] | |
75 | while True: |
|
83 | while True: | |
76 | chunk = reader.read(read_size) |
|
84 | chunk = reader.read(read_size) | |
77 | if not chunk: |
|
85 | if not chunk: | |
78 | break |
|
86 | break | |
79 |
|
87 | |||
80 | chunks.append(chunk) |
|
88 | chunks.append(chunk) | |
81 |
|
89 | |||
82 | self.assertEqual(b"".join(chunks), ref_frame) |
|
90 | self.assertEqual(b"".join(chunks), ref_frame) | |
83 |
|
91 | |||
84 | @hypothesis.settings( |
|
92 | @hypothesis.settings( | |
85 | suppress_health_check=[ |
|
93 | suppress_health_check=[ | |
86 | hypothesis.HealthCheck.large_base_example, |
|
94 | hypothesis.HealthCheck.large_base_example, | |
87 | hypothesis.HealthCheck.too_slow, |
|
95 | hypothesis.HealthCheck.too_slow, | |
88 | ] |
|
96 | ] | |
89 | ) |
|
97 | ) | |
90 | @hypothesis.given( |
|
98 | @hypothesis.given( | |
91 | original=strategies.sampled_from(random_input_data()), |
|
99 | original=strategies.sampled_from(random_input_data()), | |
92 | level=strategies.integers(min_value=1, max_value=5), |
|
100 | level=strategies.integers(min_value=1, max_value=5), | |
93 | source_read_size=strategies.integers(1, 16384), |
|
101 | source_read_size=strategies.integers(1, 16384), | |
94 | read_sizes=strategies.data(), |
|
102 | read_sizes=strategies.data(), | |
95 | ) |
|
103 | ) | |
96 | def test_stream_source_read_variance( |
|
104 | def test_stream_source_read_variance( | |
97 | self, original, level, source_read_size, read_sizes |
|
105 | self, original, level, source_read_size, read_sizes | |
98 | ): |
|
106 | ): | |
99 | refctx = zstd.ZstdCompressor(level=level) |
|
107 | refctx = zstd.ZstdCompressor(level=level) | |
100 | ref_frame = refctx.compress(original) |
|
108 | ref_frame = refctx.compress(original) | |
101 |
|
109 | |||
102 | cctx = zstd.ZstdCompressor(level=level) |
|
110 | cctx = zstd.ZstdCompressor(level=level) | |
103 | with cctx.stream_reader( |
|
111 | with cctx.stream_reader( | |
104 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
112 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
105 | ) as reader: |
|
113 | ) as reader: | |
106 | chunks = [] |
|
114 | chunks = [] | |
107 | while True: |
|
115 | while True: | |
108 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) |
|
116 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) | |
109 | chunk = reader.read(read_size) |
|
117 | chunk = reader.read(read_size) | |
110 | if not chunk and read_size: |
|
118 | if not chunk and read_size: | |
111 | break |
|
119 | break | |
112 |
|
120 | |||
113 | chunks.append(chunk) |
|
121 | chunks.append(chunk) | |
114 |
|
122 | |||
115 | self.assertEqual(b"".join(chunks), ref_frame) |
|
123 | self.assertEqual(b"".join(chunks), ref_frame) | |
116 |
|
124 | |||
117 | @hypothesis.settings( |
|
125 | @hypothesis.settings( | |
118 | suppress_health_check=[ |
|
126 | suppress_health_check=[ | |
119 | hypothesis.HealthCheck.large_base_example, |
|
127 | hypothesis.HealthCheck.large_base_example, | |
120 | hypothesis.HealthCheck.too_slow, |
|
128 | hypothesis.HealthCheck.too_slow, | |
121 | ] |
|
129 | ] | |
122 | ) |
|
130 | ) | |
123 | @hypothesis.given( |
|
131 | @hypothesis.given( | |
124 | original=strategies.sampled_from(random_input_data()), |
|
132 | original=strategies.sampled_from(random_input_data()), | |
125 | level=strategies.integers(min_value=1, max_value=5), |
|
133 | level=strategies.integers(min_value=1, max_value=5), | |
126 | source_read_size=strategies.integers(1, 16384), |
|
134 | source_read_size=strategies.integers(1, 16384), | |
127 | read_sizes=strategies.data(), |
|
135 | read_sizes=strategies.data(), | |
128 | ) |
|
136 | ) | |
129 | def test_buffer_source_read_variance( |
|
137 | def test_buffer_source_read_variance( | |
130 | self, original, level, source_read_size, read_sizes |
|
138 | self, original, level, source_read_size, read_sizes | |
131 | ): |
|
139 | ): | |
132 |
|
140 | |||
133 | refctx = zstd.ZstdCompressor(level=level) |
|
141 | refctx = zstd.ZstdCompressor(level=level) | |
134 | ref_frame = refctx.compress(original) |
|
142 | ref_frame = refctx.compress(original) | |
135 |
|
143 | |||
136 | cctx = zstd.ZstdCompressor(level=level) |
|
144 | cctx = zstd.ZstdCompressor(level=level) | |
137 | with cctx.stream_reader( |
|
145 | with cctx.stream_reader( | |
138 | original, size=len(original), read_size=source_read_size |
|
146 | original, size=len(original), read_size=source_read_size | |
139 | ) as reader: |
|
147 | ) as reader: | |
140 | chunks = [] |
|
148 | chunks = [] | |
141 | while True: |
|
149 | while True: | |
142 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) |
|
150 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) | |
143 | chunk = reader.read(read_size) |
|
151 | chunk = reader.read(read_size) | |
144 | if not chunk and read_size: |
|
152 | if not chunk and read_size: | |
145 | break |
|
153 | break | |
146 |
|
154 | |||
147 | chunks.append(chunk) |
|
155 | chunks.append(chunk) | |
148 |
|
156 | |||
149 | self.assertEqual(b"".join(chunks), ref_frame) |
|
157 | self.assertEqual(b"".join(chunks), ref_frame) | |
150 |
|
158 | |||
151 | @hypothesis.settings( |
|
159 | @hypothesis.settings( | |
152 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
160 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
153 | ) |
|
161 | ) | |
154 | @hypothesis.given( |
|
162 | @hypothesis.given( | |
155 | original=strategies.sampled_from(random_input_data()), |
|
163 | original=strategies.sampled_from(random_input_data()), | |
156 | level=strategies.integers(min_value=1, max_value=5), |
|
164 | level=strategies.integers(min_value=1, max_value=5), | |
157 | source_read_size=strategies.integers(1, 16384), |
|
165 | source_read_size=strategies.integers(1, 16384), | |
158 |
read_size=strategies.integers( |
|
166 | read_size=strategies.integers( | |
|
167 | 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
168 | ), | |||
159 | ) |
|
169 | ) | |
160 | def test_stream_source_readinto(self, original, level, source_read_size, read_size): |
|
170 | def test_stream_source_readinto( | |
|
171 | self, original, level, source_read_size, read_size | |||
|
172 | ): | |||
161 | refctx = zstd.ZstdCompressor(level=level) |
|
173 | refctx = zstd.ZstdCompressor(level=level) | |
162 | ref_frame = refctx.compress(original) |
|
174 | ref_frame = refctx.compress(original) | |
163 |
|
175 | |||
164 | cctx = zstd.ZstdCompressor(level=level) |
|
176 | cctx = zstd.ZstdCompressor(level=level) | |
165 | with cctx.stream_reader( |
|
177 | with cctx.stream_reader( | |
166 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
178 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
167 | ) as reader: |
|
179 | ) as reader: | |
168 | chunks = [] |
|
180 | chunks = [] | |
169 | while True: |
|
181 | while True: | |
170 | b = bytearray(read_size) |
|
182 | b = bytearray(read_size) | |
171 | count = reader.readinto(b) |
|
183 | count = reader.readinto(b) | |
172 |
|
184 | |||
173 | if not count: |
|
185 | if not count: | |
174 | break |
|
186 | break | |
175 |
|
187 | |||
176 | chunks.append(bytes(b[0:count])) |
|
188 | chunks.append(bytes(b[0:count])) | |
177 |
|
189 | |||
178 | self.assertEqual(b"".join(chunks), ref_frame) |
|
190 | self.assertEqual(b"".join(chunks), ref_frame) | |
179 |
|
191 | |||
180 | @hypothesis.settings( |
|
192 | @hypothesis.settings( | |
181 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
193 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
182 | ) |
|
194 | ) | |
183 | @hypothesis.given( |
|
195 | @hypothesis.given( | |
184 | original=strategies.sampled_from(random_input_data()), |
|
196 | original=strategies.sampled_from(random_input_data()), | |
185 | level=strategies.integers(min_value=1, max_value=5), |
|
197 | level=strategies.integers(min_value=1, max_value=5), | |
186 | source_read_size=strategies.integers(1, 16384), |
|
198 | source_read_size=strategies.integers(1, 16384), | |
187 |
read_size=strategies.integers( |
|
199 | read_size=strategies.integers( | |
|
200 | 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
201 | ), | |||
188 | ) |
|
202 | ) | |
189 | def test_buffer_source_readinto(self, original, level, source_read_size, read_size): |
|
203 | def test_buffer_source_readinto( | |
|
204 | self, original, level, source_read_size, read_size | |||
|
205 | ): | |||
190 |
|
206 | |||
191 | refctx = zstd.ZstdCompressor(level=level) |
|
207 | refctx = zstd.ZstdCompressor(level=level) | |
192 | ref_frame = refctx.compress(original) |
|
208 | ref_frame = refctx.compress(original) | |
193 |
|
209 | |||
194 | cctx = zstd.ZstdCompressor(level=level) |
|
210 | cctx = zstd.ZstdCompressor(level=level) | |
195 | with cctx.stream_reader( |
|
211 | with cctx.stream_reader( | |
196 | original, size=len(original), read_size=source_read_size |
|
212 | original, size=len(original), read_size=source_read_size | |
197 | ) as reader: |
|
213 | ) as reader: | |
198 | chunks = [] |
|
214 | chunks = [] | |
199 | while True: |
|
215 | while True: | |
200 | b = bytearray(read_size) |
|
216 | b = bytearray(read_size) | |
201 | count = reader.readinto(b) |
|
217 | count = reader.readinto(b) | |
202 |
|
218 | |||
203 | if not count: |
|
219 | if not count: | |
204 | break |
|
220 | break | |
205 |
|
221 | |||
206 | chunks.append(bytes(b[0:count])) |
|
222 | chunks.append(bytes(b[0:count])) | |
207 |
|
223 | |||
208 | self.assertEqual(b"".join(chunks), ref_frame) |
|
224 | self.assertEqual(b"".join(chunks), ref_frame) | |
209 |
|
225 | |||
210 | @hypothesis.settings( |
|
226 | @hypothesis.settings( | |
211 | suppress_health_check=[ |
|
227 | suppress_health_check=[ | |
212 | hypothesis.HealthCheck.large_base_example, |
|
228 | hypothesis.HealthCheck.large_base_example, | |
213 | hypothesis.HealthCheck.too_slow, |
|
229 | hypothesis.HealthCheck.too_slow, | |
214 | ] |
|
230 | ] | |
215 | ) |
|
231 | ) | |
216 | @hypothesis.given( |
|
232 | @hypothesis.given( | |
217 | original=strategies.sampled_from(random_input_data()), |
|
233 | original=strategies.sampled_from(random_input_data()), | |
218 | level=strategies.integers(min_value=1, max_value=5), |
|
234 | level=strategies.integers(min_value=1, max_value=5), | |
219 | source_read_size=strategies.integers(1, 16384), |
|
235 | source_read_size=strategies.integers(1, 16384), | |
220 | read_sizes=strategies.data(), |
|
236 | read_sizes=strategies.data(), | |
221 | ) |
|
237 | ) | |
222 | def test_stream_source_readinto_variance( |
|
238 | def test_stream_source_readinto_variance( | |
223 | self, original, level, source_read_size, read_sizes |
|
239 | self, original, level, source_read_size, read_sizes | |
224 | ): |
|
240 | ): | |
225 | refctx = zstd.ZstdCompressor(level=level) |
|
241 | refctx = zstd.ZstdCompressor(level=level) | |
226 | ref_frame = refctx.compress(original) |
|
242 | ref_frame = refctx.compress(original) | |
227 |
|
243 | |||
228 | cctx = zstd.ZstdCompressor(level=level) |
|
244 | cctx = zstd.ZstdCompressor(level=level) | |
229 | with cctx.stream_reader( |
|
245 | with cctx.stream_reader( | |
230 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
246 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
231 | ) as reader: |
|
247 | ) as reader: | |
232 | chunks = [] |
|
248 | chunks = [] | |
233 | while True: |
|
249 | while True: | |
234 | read_size = read_sizes.draw(strategies.integers(1, 16384)) |
|
250 | read_size = read_sizes.draw(strategies.integers(1, 16384)) | |
235 | b = bytearray(read_size) |
|
251 | b = bytearray(read_size) | |
236 | count = reader.readinto(b) |
|
252 | count = reader.readinto(b) | |
237 |
|
253 | |||
238 | if not count: |
|
254 | if not count: | |
239 | break |
|
255 | break | |
240 |
|
256 | |||
241 | chunks.append(bytes(b[0:count])) |
|
257 | chunks.append(bytes(b[0:count])) | |
242 |
|
258 | |||
243 | self.assertEqual(b"".join(chunks), ref_frame) |
|
259 | self.assertEqual(b"".join(chunks), ref_frame) | |
244 |
|
260 | |||
245 | @hypothesis.settings( |
|
261 | @hypothesis.settings( | |
246 | suppress_health_check=[ |
|
262 | suppress_health_check=[ | |
247 | hypothesis.HealthCheck.large_base_example, |
|
263 | hypothesis.HealthCheck.large_base_example, | |
248 | hypothesis.HealthCheck.too_slow, |
|
264 | hypothesis.HealthCheck.too_slow, | |
249 | ] |
|
265 | ] | |
250 | ) |
|
266 | ) | |
251 | @hypothesis.given( |
|
267 | @hypothesis.given( | |
252 | original=strategies.sampled_from(random_input_data()), |
|
268 | original=strategies.sampled_from(random_input_data()), | |
253 | level=strategies.integers(min_value=1, max_value=5), |
|
269 | level=strategies.integers(min_value=1, max_value=5), | |
254 | source_read_size=strategies.integers(1, 16384), |
|
270 | source_read_size=strategies.integers(1, 16384), | |
255 | read_sizes=strategies.data(), |
|
271 | read_sizes=strategies.data(), | |
256 | ) |
|
272 | ) | |
257 | def test_buffer_source_readinto_variance( |
|
273 | def test_buffer_source_readinto_variance( | |
258 | self, original, level, source_read_size, read_sizes |
|
274 | self, original, level, source_read_size, read_sizes | |
259 | ): |
|
275 | ): | |
260 |
|
276 | |||
261 | refctx = zstd.ZstdCompressor(level=level) |
|
277 | refctx = zstd.ZstdCompressor(level=level) | |
262 | ref_frame = refctx.compress(original) |
|
278 | ref_frame = refctx.compress(original) | |
263 |
|
279 | |||
264 | cctx = zstd.ZstdCompressor(level=level) |
|
280 | cctx = zstd.ZstdCompressor(level=level) | |
265 | with cctx.stream_reader( |
|
281 | with cctx.stream_reader( | |
266 | original, size=len(original), read_size=source_read_size |
|
282 | original, size=len(original), read_size=source_read_size | |
267 | ) as reader: |
|
283 | ) as reader: | |
268 | chunks = [] |
|
284 | chunks = [] | |
269 | while True: |
|
285 | while True: | |
270 | read_size = read_sizes.draw(strategies.integers(1, 16384)) |
|
286 | read_size = read_sizes.draw(strategies.integers(1, 16384)) | |
271 | b = bytearray(read_size) |
|
287 | b = bytearray(read_size) | |
272 | count = reader.readinto(b) |
|
288 | count = reader.readinto(b) | |
273 |
|
289 | |||
274 | if not count: |
|
290 | if not count: | |
275 | break |
|
291 | break | |
276 |
|
292 | |||
277 | chunks.append(bytes(b[0:count])) |
|
293 | chunks.append(bytes(b[0:count])) | |
278 |
|
294 | |||
279 | self.assertEqual(b"".join(chunks), ref_frame) |
|
295 | self.assertEqual(b"".join(chunks), ref_frame) | |
280 |
|
296 | |||
281 | @hypothesis.settings( |
|
297 | @hypothesis.settings( | |
282 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
298 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
283 | ) |
|
299 | ) | |
284 | @hypothesis.given( |
|
300 | @hypothesis.given( | |
285 | original=strategies.sampled_from(random_input_data()), |
|
301 | original=strategies.sampled_from(random_input_data()), | |
286 | level=strategies.integers(min_value=1, max_value=5), |
|
302 | level=strategies.integers(min_value=1, max_value=5), | |
287 | source_read_size=strategies.integers(1, 16384), |
|
303 | source_read_size=strategies.integers(1, 16384), | |
288 |
read_size=strategies.integers( |
|
304 | read_size=strategies.integers( | |
|
305 | -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
306 | ), | |||
289 | ) |
|
307 | ) | |
290 | def test_stream_source_read1(self, original, level, source_read_size, read_size): |
|
308 | def test_stream_source_read1( | |
|
309 | self, original, level, source_read_size, read_size | |||
|
310 | ): | |||
291 | if read_size == 0: |
|
311 | if read_size == 0: | |
292 | read_size = -1 |
|
312 | read_size = -1 | |
293 |
|
313 | |||
294 | refctx = zstd.ZstdCompressor(level=level) |
|
314 | refctx = zstd.ZstdCompressor(level=level) | |
295 | ref_frame = refctx.compress(original) |
|
315 | ref_frame = refctx.compress(original) | |
296 |
|
316 | |||
297 | cctx = zstd.ZstdCompressor(level=level) |
|
317 | cctx = zstd.ZstdCompressor(level=level) | |
298 | with cctx.stream_reader( |
|
318 | with cctx.stream_reader( | |
299 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
319 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
300 | ) as reader: |
|
320 | ) as reader: | |
301 | chunks = [] |
|
321 | chunks = [] | |
302 | while True: |
|
322 | while True: | |
303 | chunk = reader.read1(read_size) |
|
323 | chunk = reader.read1(read_size) | |
304 | if not chunk: |
|
324 | if not chunk: | |
305 | break |
|
325 | break | |
306 |
|
326 | |||
307 | chunks.append(chunk) |
|
327 | chunks.append(chunk) | |
308 |
|
328 | |||
309 | self.assertEqual(b"".join(chunks), ref_frame) |
|
329 | self.assertEqual(b"".join(chunks), ref_frame) | |
310 |
|
330 | |||
311 | @hypothesis.settings( |
|
331 | @hypothesis.settings( | |
312 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
332 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
313 | ) |
|
333 | ) | |
314 | @hypothesis.given( |
|
334 | @hypothesis.given( | |
315 | original=strategies.sampled_from(random_input_data()), |
|
335 | original=strategies.sampled_from(random_input_data()), | |
316 | level=strategies.integers(min_value=1, max_value=5), |
|
336 | level=strategies.integers(min_value=1, max_value=5), | |
317 | source_read_size=strategies.integers(1, 16384), |
|
337 | source_read_size=strategies.integers(1, 16384), | |
318 |
read_size=strategies.integers( |
|
338 | read_size=strategies.integers( | |
|
339 | -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
340 | ), | |||
319 | ) |
|
341 | ) | |
320 | def test_buffer_source_read1(self, original, level, source_read_size, read_size): |
|
342 | def test_buffer_source_read1( | |
|
343 | self, original, level, source_read_size, read_size | |||
|
344 | ): | |||
321 | if read_size == 0: |
|
345 | if read_size == 0: | |
322 | read_size = -1 |
|
346 | read_size = -1 | |
323 |
|
347 | |||
324 | refctx = zstd.ZstdCompressor(level=level) |
|
348 | refctx = zstd.ZstdCompressor(level=level) | |
325 | ref_frame = refctx.compress(original) |
|
349 | ref_frame = refctx.compress(original) | |
326 |
|
350 | |||
327 | cctx = zstd.ZstdCompressor(level=level) |
|
351 | cctx = zstd.ZstdCompressor(level=level) | |
328 | with cctx.stream_reader( |
|
352 | with cctx.stream_reader( | |
329 | original, size=len(original), read_size=source_read_size |
|
353 | original, size=len(original), read_size=source_read_size | |
330 | ) as reader: |
|
354 | ) as reader: | |
331 | chunks = [] |
|
355 | chunks = [] | |
332 | while True: |
|
356 | while True: | |
333 | chunk = reader.read1(read_size) |
|
357 | chunk = reader.read1(read_size) | |
334 | if not chunk: |
|
358 | if not chunk: | |
335 | break |
|
359 | break | |
336 |
|
360 | |||
337 | chunks.append(chunk) |
|
361 | chunks.append(chunk) | |
338 |
|
362 | |||
339 | self.assertEqual(b"".join(chunks), ref_frame) |
|
363 | self.assertEqual(b"".join(chunks), ref_frame) | |
340 |
|
364 | |||
341 | @hypothesis.settings( |
|
365 | @hypothesis.settings( | |
342 | suppress_health_check=[ |
|
366 | suppress_health_check=[ | |
343 | hypothesis.HealthCheck.large_base_example, |
|
367 | hypothesis.HealthCheck.large_base_example, | |
344 | hypothesis.HealthCheck.too_slow, |
|
368 | hypothesis.HealthCheck.too_slow, | |
345 | ] |
|
369 | ] | |
346 | ) |
|
370 | ) | |
347 | @hypothesis.given( |
|
371 | @hypothesis.given( | |
348 | original=strategies.sampled_from(random_input_data()), |
|
372 | original=strategies.sampled_from(random_input_data()), | |
349 | level=strategies.integers(min_value=1, max_value=5), |
|
373 | level=strategies.integers(min_value=1, max_value=5), | |
350 | source_read_size=strategies.integers(1, 16384), |
|
374 | source_read_size=strategies.integers(1, 16384), | |
351 | read_sizes=strategies.data(), |
|
375 | read_sizes=strategies.data(), | |
352 | ) |
|
376 | ) | |
353 | def test_stream_source_read1_variance( |
|
377 | def test_stream_source_read1_variance( | |
354 | self, original, level, source_read_size, read_sizes |
|
378 | self, original, level, source_read_size, read_sizes | |
355 | ): |
|
379 | ): | |
356 | refctx = zstd.ZstdCompressor(level=level) |
|
380 | refctx = zstd.ZstdCompressor(level=level) | |
357 | ref_frame = refctx.compress(original) |
|
381 | ref_frame = refctx.compress(original) | |
358 |
|
382 | |||
359 | cctx = zstd.ZstdCompressor(level=level) |
|
383 | cctx = zstd.ZstdCompressor(level=level) | |
360 | with cctx.stream_reader( |
|
384 | with cctx.stream_reader( | |
361 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
385 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
362 | ) as reader: |
|
386 | ) as reader: | |
363 | chunks = [] |
|
387 | chunks = [] | |
364 | while True: |
|
388 | while True: | |
365 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) |
|
389 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) | |
366 | chunk = reader.read1(read_size) |
|
390 | chunk = reader.read1(read_size) | |
367 | if not chunk and read_size: |
|
391 | if not chunk and read_size: | |
368 | break |
|
392 | break | |
369 |
|
393 | |||
370 | chunks.append(chunk) |
|
394 | chunks.append(chunk) | |
371 |
|
395 | |||
372 | self.assertEqual(b"".join(chunks), ref_frame) |
|
396 | self.assertEqual(b"".join(chunks), ref_frame) | |
373 |
|
397 | |||
374 | @hypothesis.settings( |
|
398 | @hypothesis.settings( | |
375 | suppress_health_check=[ |
|
399 | suppress_health_check=[ | |
376 | hypothesis.HealthCheck.large_base_example, |
|
400 | hypothesis.HealthCheck.large_base_example, | |
377 | hypothesis.HealthCheck.too_slow, |
|
401 | hypothesis.HealthCheck.too_slow, | |
378 | ] |
|
402 | ] | |
379 | ) |
|
403 | ) | |
380 | @hypothesis.given( |
|
404 | @hypothesis.given( | |
381 | original=strategies.sampled_from(random_input_data()), |
|
405 | original=strategies.sampled_from(random_input_data()), | |
382 | level=strategies.integers(min_value=1, max_value=5), |
|
406 | level=strategies.integers(min_value=1, max_value=5), | |
383 | source_read_size=strategies.integers(1, 16384), |
|
407 | source_read_size=strategies.integers(1, 16384), | |
384 | read_sizes=strategies.data(), |
|
408 | read_sizes=strategies.data(), | |
385 | ) |
|
409 | ) | |
386 | def test_buffer_source_read1_variance( |
|
410 | def test_buffer_source_read1_variance( | |
387 | self, original, level, source_read_size, read_sizes |
|
411 | self, original, level, source_read_size, read_sizes | |
388 | ): |
|
412 | ): | |
389 |
|
413 | |||
390 | refctx = zstd.ZstdCompressor(level=level) |
|
414 | refctx = zstd.ZstdCompressor(level=level) | |
391 | ref_frame = refctx.compress(original) |
|
415 | ref_frame = refctx.compress(original) | |
392 |
|
416 | |||
393 | cctx = zstd.ZstdCompressor(level=level) |
|
417 | cctx = zstd.ZstdCompressor(level=level) | |
394 | with cctx.stream_reader( |
|
418 | with cctx.stream_reader( | |
395 | original, size=len(original), read_size=source_read_size |
|
419 | original, size=len(original), read_size=source_read_size | |
396 | ) as reader: |
|
420 | ) as reader: | |
397 | chunks = [] |
|
421 | chunks = [] | |
398 | while True: |
|
422 | while True: | |
399 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) |
|
423 | read_size = read_sizes.draw(strategies.integers(-1, 16384)) | |
400 | chunk = reader.read1(read_size) |
|
424 | chunk = reader.read1(read_size) | |
401 | if not chunk and read_size: |
|
425 | if not chunk and read_size: | |
402 | break |
|
426 | break | |
403 |
|
427 | |||
404 | chunks.append(chunk) |
|
428 | chunks.append(chunk) | |
405 |
|
429 | |||
406 | self.assertEqual(b"".join(chunks), ref_frame) |
|
430 | self.assertEqual(b"".join(chunks), ref_frame) | |
407 |
|
431 | |||
408 | @hypothesis.settings( |
|
432 | @hypothesis.settings( | |
409 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
433 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
410 | ) |
|
434 | ) | |
411 | @hypothesis.given( |
|
435 | @hypothesis.given( | |
412 | original=strategies.sampled_from(random_input_data()), |
|
436 | original=strategies.sampled_from(random_input_data()), | |
413 | level=strategies.integers(min_value=1, max_value=5), |
|
437 | level=strategies.integers(min_value=1, max_value=5), | |
414 | source_read_size=strategies.integers(1, 16384), |
|
438 | source_read_size=strategies.integers(1, 16384), | |
415 |
read_size=strategies.integers( |
|
439 | read_size=strategies.integers( | |
|
440 | 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
441 | ), | |||
416 | ) |
|
442 | ) | |
417 | def test_stream_source_readinto1( |
|
443 | def test_stream_source_readinto1( | |
418 | self, original, level, source_read_size, read_size |
|
444 | self, original, level, source_read_size, read_size | |
419 | ): |
|
445 | ): | |
420 | if read_size == 0: |
|
446 | if read_size == 0: | |
421 | read_size = -1 |
|
447 | read_size = -1 | |
422 |
|
448 | |||
423 | refctx = zstd.ZstdCompressor(level=level) |
|
449 | refctx = zstd.ZstdCompressor(level=level) | |
424 | ref_frame = refctx.compress(original) |
|
450 | ref_frame = refctx.compress(original) | |
425 |
|
451 | |||
426 | cctx = zstd.ZstdCompressor(level=level) |
|
452 | cctx = zstd.ZstdCompressor(level=level) | |
427 | with cctx.stream_reader( |
|
453 | with cctx.stream_reader( | |
428 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
454 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
429 | ) as reader: |
|
455 | ) as reader: | |
430 | chunks = [] |
|
456 | chunks = [] | |
431 | while True: |
|
457 | while True: | |
432 | b = bytearray(read_size) |
|
458 | b = bytearray(read_size) | |
433 | count = reader.readinto1(b) |
|
459 | count = reader.readinto1(b) | |
434 |
|
460 | |||
435 | if not count: |
|
461 | if not count: | |
436 | break |
|
462 | break | |
437 |
|
463 | |||
438 | chunks.append(bytes(b[0:count])) |
|
464 | chunks.append(bytes(b[0:count])) | |
439 |
|
465 | |||
440 | self.assertEqual(b"".join(chunks), ref_frame) |
|
466 | self.assertEqual(b"".join(chunks), ref_frame) | |
441 |
|
467 | |||
442 | @hypothesis.settings( |
|
468 | @hypothesis.settings( | |
443 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
469 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
444 | ) |
|
470 | ) | |
445 | @hypothesis.given( |
|
471 | @hypothesis.given( | |
446 | original=strategies.sampled_from(random_input_data()), |
|
472 | original=strategies.sampled_from(random_input_data()), | |
447 | level=strategies.integers(min_value=1, max_value=5), |
|
473 | level=strategies.integers(min_value=1, max_value=5), | |
448 | source_read_size=strategies.integers(1, 16384), |
|
474 | source_read_size=strategies.integers(1, 16384), | |
449 |
read_size=strategies.integers( |
|
475 | read_size=strategies.integers( | |
|
476 | 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
477 | ), | |||
450 | ) |
|
478 | ) | |
451 | def test_buffer_source_readinto1( |
|
479 | def test_buffer_source_readinto1( | |
452 | self, original, level, source_read_size, read_size |
|
480 | self, original, level, source_read_size, read_size | |
453 | ): |
|
481 | ): | |
454 | if read_size == 0: |
|
482 | if read_size == 0: | |
455 | read_size = -1 |
|
483 | read_size = -1 | |
456 |
|
484 | |||
457 | refctx = zstd.ZstdCompressor(level=level) |
|
485 | refctx = zstd.ZstdCompressor(level=level) | |
458 | ref_frame = refctx.compress(original) |
|
486 | ref_frame = refctx.compress(original) | |
459 |
|
487 | |||
460 | cctx = zstd.ZstdCompressor(level=level) |
|
488 | cctx = zstd.ZstdCompressor(level=level) | |
461 | with cctx.stream_reader( |
|
489 | with cctx.stream_reader( | |
462 | original, size=len(original), read_size=source_read_size |
|
490 | original, size=len(original), read_size=source_read_size | |
463 | ) as reader: |
|
491 | ) as reader: | |
464 | chunks = [] |
|
492 | chunks = [] | |
465 | while True: |
|
493 | while True: | |
466 | b = bytearray(read_size) |
|
494 | b = bytearray(read_size) | |
467 | count = reader.readinto1(b) |
|
495 | count = reader.readinto1(b) | |
468 |
|
496 | |||
469 | if not count: |
|
497 | if not count: | |
470 | break |
|
498 | break | |
471 |
|
499 | |||
472 | chunks.append(bytes(b[0:count])) |
|
500 | chunks.append(bytes(b[0:count])) | |
473 |
|
501 | |||
474 | self.assertEqual(b"".join(chunks), ref_frame) |
|
502 | self.assertEqual(b"".join(chunks), ref_frame) | |
475 |
|
503 | |||
476 | @hypothesis.settings( |
|
504 | @hypothesis.settings( | |
477 | suppress_health_check=[ |
|
505 | suppress_health_check=[ | |
478 | hypothesis.HealthCheck.large_base_example, |
|
506 | hypothesis.HealthCheck.large_base_example, | |
479 | hypothesis.HealthCheck.too_slow, |
|
507 | hypothesis.HealthCheck.too_slow, | |
480 | ] |
|
508 | ] | |
481 | ) |
|
509 | ) | |
482 | @hypothesis.given( |
|
510 | @hypothesis.given( | |
483 | original=strategies.sampled_from(random_input_data()), |
|
511 | original=strategies.sampled_from(random_input_data()), | |
484 | level=strategies.integers(min_value=1, max_value=5), |
|
512 | level=strategies.integers(min_value=1, max_value=5), | |
485 | source_read_size=strategies.integers(1, 16384), |
|
513 | source_read_size=strategies.integers(1, 16384), | |
486 | read_sizes=strategies.data(), |
|
514 | read_sizes=strategies.data(), | |
487 | ) |
|
515 | ) | |
488 | def test_stream_source_readinto1_variance( |
|
516 | def test_stream_source_readinto1_variance( | |
489 | self, original, level, source_read_size, read_sizes |
|
517 | self, original, level, source_read_size, read_sizes | |
490 | ): |
|
518 | ): | |
491 | refctx = zstd.ZstdCompressor(level=level) |
|
519 | refctx = zstd.ZstdCompressor(level=level) | |
492 | ref_frame = refctx.compress(original) |
|
520 | ref_frame = refctx.compress(original) | |
493 |
|
521 | |||
494 | cctx = zstd.ZstdCompressor(level=level) |
|
522 | cctx = zstd.ZstdCompressor(level=level) | |
495 | with cctx.stream_reader( |
|
523 | with cctx.stream_reader( | |
496 | io.BytesIO(original), size=len(original), read_size=source_read_size |
|
524 | io.BytesIO(original), size=len(original), read_size=source_read_size | |
497 | ) as reader: |
|
525 | ) as reader: | |
498 | chunks = [] |
|
526 | chunks = [] | |
499 | while True: |
|
527 | while True: | |
500 | read_size = read_sizes.draw(strategies.integers(1, 16384)) |
|
528 | read_size = read_sizes.draw(strategies.integers(1, 16384)) | |
501 | b = bytearray(read_size) |
|
529 | b = bytearray(read_size) | |
502 | count = reader.readinto1(b) |
|
530 | count = reader.readinto1(b) | |
503 |
|
531 | |||
504 | if not count: |
|
532 | if not count: | |
505 | break |
|
533 | break | |
506 |
|
534 | |||
507 | chunks.append(bytes(b[0:count])) |
|
535 | chunks.append(bytes(b[0:count])) | |
508 |
|
536 | |||
509 | self.assertEqual(b"".join(chunks), ref_frame) |
|
537 | self.assertEqual(b"".join(chunks), ref_frame) | |
510 |
|
538 | |||
511 | @hypothesis.settings( |
|
539 | @hypothesis.settings( | |
512 | suppress_health_check=[ |
|
540 | suppress_health_check=[ | |
513 | hypothesis.HealthCheck.large_base_example, |
|
541 | hypothesis.HealthCheck.large_base_example, | |
514 | hypothesis.HealthCheck.too_slow, |
|
542 | hypothesis.HealthCheck.too_slow, | |
515 | ] |
|
543 | ] | |
516 | ) |
|
544 | ) | |
517 | @hypothesis.given( |
|
545 | @hypothesis.given( | |
518 | original=strategies.sampled_from(random_input_data()), |
|
546 | original=strategies.sampled_from(random_input_data()), | |
519 | level=strategies.integers(min_value=1, max_value=5), |
|
547 | level=strategies.integers(min_value=1, max_value=5), | |
520 | source_read_size=strategies.integers(1, 16384), |
|
548 | source_read_size=strategies.integers(1, 16384), | |
521 | read_sizes=strategies.data(), |
|
549 | read_sizes=strategies.data(), | |
522 | ) |
|
550 | ) | |
523 | def test_buffer_source_readinto1_variance( |
|
551 | def test_buffer_source_readinto1_variance( | |
524 | self, original, level, source_read_size, read_sizes |
|
552 | self, original, level, source_read_size, read_sizes | |
525 | ): |
|
553 | ): | |
526 |
|
554 | |||
527 | refctx = zstd.ZstdCompressor(level=level) |
|
555 | refctx = zstd.ZstdCompressor(level=level) | |
528 | ref_frame = refctx.compress(original) |
|
556 | ref_frame = refctx.compress(original) | |
529 |
|
557 | |||
530 | cctx = zstd.ZstdCompressor(level=level) |
|
558 | cctx = zstd.ZstdCompressor(level=level) | |
531 | with cctx.stream_reader( |
|
559 | with cctx.stream_reader( | |
532 | original, size=len(original), read_size=source_read_size |
|
560 | original, size=len(original), read_size=source_read_size | |
533 | ) as reader: |
|
561 | ) as reader: | |
534 | chunks = [] |
|
562 | chunks = [] | |
535 | while True: |
|
563 | while True: | |
536 | read_size = read_sizes.draw(strategies.integers(1, 16384)) |
|
564 | read_size = read_sizes.draw(strategies.integers(1, 16384)) | |
537 | b = bytearray(read_size) |
|
565 | b = bytearray(read_size) | |
538 | count = reader.readinto1(b) |
|
566 | count = reader.readinto1(b) | |
539 |
|
567 | |||
540 | if not count: |
|
568 | if not count: | |
541 | break |
|
569 | break | |
542 |
|
570 | |||
543 | chunks.append(bytes(b[0:count])) |
|
571 | chunks.append(bytes(b[0:count])) | |
544 |
|
572 | |||
545 | self.assertEqual(b"".join(chunks), ref_frame) |
|
573 | self.assertEqual(b"".join(chunks), ref_frame) | |
546 |
|
574 | |||
547 |
|
575 | |||
548 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
576 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
549 | @make_cffi |
|
577 | @make_cffi | |
550 | class TestCompressor_stream_writer_fuzzing(TestCase): |
|
578 | class TestCompressor_stream_writer_fuzzing(TestCase): | |
551 | @hypothesis.given( |
|
579 | @hypothesis.given( | |
552 | original=strategies.sampled_from(random_input_data()), |
|
580 | original=strategies.sampled_from(random_input_data()), | |
553 | level=strategies.integers(min_value=1, max_value=5), |
|
581 | level=strategies.integers(min_value=1, max_value=5), | |
554 | write_size=strategies.integers(min_value=1, max_value=1048576), |
|
582 | write_size=strategies.integers(min_value=1, max_value=1048576), | |
555 | ) |
|
583 | ) | |
556 | def test_write_size_variance(self, original, level, write_size): |
|
584 | def test_write_size_variance(self, original, level, write_size): | |
557 | refctx = zstd.ZstdCompressor(level=level) |
|
585 | refctx = zstd.ZstdCompressor(level=level) | |
558 | ref_frame = refctx.compress(original) |
|
586 | ref_frame = refctx.compress(original) | |
559 |
|
587 | |||
560 | cctx = zstd.ZstdCompressor(level=level) |
|
588 | cctx = zstd.ZstdCompressor(level=level) | |
561 | b = NonClosingBytesIO() |
|
589 | b = NonClosingBytesIO() | |
562 | with cctx.stream_writer( |
|
590 | with cctx.stream_writer( | |
563 | b, size=len(original), write_size=write_size |
|
591 | b, size=len(original), write_size=write_size | |
564 | ) as compressor: |
|
592 | ) as compressor: | |
565 | compressor.write(original) |
|
593 | compressor.write(original) | |
566 |
|
594 | |||
567 | self.assertEqual(b.getvalue(), ref_frame) |
|
595 | self.assertEqual(b.getvalue(), ref_frame) | |
568 |
|
596 | |||
569 |
|
597 | |||
570 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
598 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
571 | @make_cffi |
|
599 | @make_cffi | |
572 | class TestCompressor_copy_stream_fuzzing(TestCase): |
|
600 | class TestCompressor_copy_stream_fuzzing(TestCase): | |
573 | @hypothesis.given( |
|
601 | @hypothesis.given( | |
574 | original=strategies.sampled_from(random_input_data()), |
|
602 | original=strategies.sampled_from(random_input_data()), | |
575 | level=strategies.integers(min_value=1, max_value=5), |
|
603 | level=strategies.integers(min_value=1, max_value=5), | |
576 | read_size=strategies.integers(min_value=1, max_value=1048576), |
|
604 | read_size=strategies.integers(min_value=1, max_value=1048576), | |
577 | write_size=strategies.integers(min_value=1, max_value=1048576), |
|
605 | write_size=strategies.integers(min_value=1, max_value=1048576), | |
578 | ) |
|
606 | ) | |
579 |
def test_read_write_size_variance( |
|
607 | def test_read_write_size_variance( | |
|
608 | self, original, level, read_size, write_size | |||
|
609 | ): | |||
580 | refctx = zstd.ZstdCompressor(level=level) |
|
610 | refctx = zstd.ZstdCompressor(level=level) | |
581 | ref_frame = refctx.compress(original) |
|
611 | ref_frame = refctx.compress(original) | |
582 |
|
612 | |||
583 | cctx = zstd.ZstdCompressor(level=level) |
|
613 | cctx = zstd.ZstdCompressor(level=level) | |
584 | source = io.BytesIO(original) |
|
614 | source = io.BytesIO(original) | |
585 | dest = io.BytesIO() |
|
615 | dest = io.BytesIO() | |
586 |
|
616 | |||
587 | cctx.copy_stream( |
|
617 | cctx.copy_stream( | |
588 | source, dest, size=len(original), read_size=read_size, write_size=write_size |
|
618 | source, | |
|
619 | dest, | |||
|
620 | size=len(original), | |||
|
621 | read_size=read_size, | |||
|
622 | write_size=write_size, | |||
589 | ) |
|
623 | ) | |
590 |
|
624 | |||
591 | self.assertEqual(dest.getvalue(), ref_frame) |
|
625 | self.assertEqual(dest.getvalue(), ref_frame) | |
592 |
|
626 | |||
593 |
|
627 | |||
594 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
628 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
595 | @make_cffi |
|
629 | @make_cffi | |
596 | class TestCompressor_compressobj_fuzzing(TestCase): |
|
630 | class TestCompressor_compressobj_fuzzing(TestCase): | |
597 | @hypothesis.settings( |
|
631 | @hypothesis.settings( | |
598 | suppress_health_check=[ |
|
632 | suppress_health_check=[ | |
599 | hypothesis.HealthCheck.large_base_example, |
|
633 | hypothesis.HealthCheck.large_base_example, | |
600 | hypothesis.HealthCheck.too_slow, |
|
634 | hypothesis.HealthCheck.too_slow, | |
601 | ] |
|
635 | ] | |
602 | ) |
|
636 | ) | |
603 | @hypothesis.given( |
|
637 | @hypothesis.given( | |
604 | original=strategies.sampled_from(random_input_data()), |
|
638 | original=strategies.sampled_from(random_input_data()), | |
605 | level=strategies.integers(min_value=1, max_value=5), |
|
639 | level=strategies.integers(min_value=1, max_value=5), | |
606 | chunk_sizes=strategies.data(), |
|
640 | chunk_sizes=strategies.data(), | |
607 | ) |
|
641 | ) | |
608 | def test_random_input_sizes(self, original, level, chunk_sizes): |
|
642 | def test_random_input_sizes(self, original, level, chunk_sizes): | |
609 | refctx = zstd.ZstdCompressor(level=level) |
|
643 | refctx = zstd.ZstdCompressor(level=level) | |
610 | ref_frame = refctx.compress(original) |
|
644 | ref_frame = refctx.compress(original) | |
611 |
|
645 | |||
612 | cctx = zstd.ZstdCompressor(level=level) |
|
646 | cctx = zstd.ZstdCompressor(level=level) | |
613 | cobj = cctx.compressobj(size=len(original)) |
|
647 | cobj = cctx.compressobj(size=len(original)) | |
614 |
|
648 | |||
615 | chunks = [] |
|
649 | chunks = [] | |
616 | i = 0 |
|
650 | i = 0 | |
617 | while True: |
|
651 | while True: | |
618 | chunk_size = chunk_sizes.draw(strategies.integers(1, 4096)) |
|
652 | chunk_size = chunk_sizes.draw(strategies.integers(1, 4096)) | |
619 | source = original[i : i + chunk_size] |
|
653 | source = original[i : i + chunk_size] | |
620 | if not source: |
|
654 | if not source: | |
621 | break |
|
655 | break | |
622 |
|
656 | |||
623 | chunks.append(cobj.compress(source)) |
|
657 | chunks.append(cobj.compress(source)) | |
624 | i += chunk_size |
|
658 | i += chunk_size | |
625 |
|
659 | |||
626 | chunks.append(cobj.flush()) |
|
660 | chunks.append(cobj.flush()) | |
627 |
|
661 | |||
628 | self.assertEqual(b"".join(chunks), ref_frame) |
|
662 | self.assertEqual(b"".join(chunks), ref_frame) | |
629 |
|
663 | |||
630 | @hypothesis.settings( |
|
664 | @hypothesis.settings( | |
631 | suppress_health_check=[ |
|
665 | suppress_health_check=[ | |
632 | hypothesis.HealthCheck.large_base_example, |
|
666 | hypothesis.HealthCheck.large_base_example, | |
633 | hypothesis.HealthCheck.too_slow, |
|
667 | hypothesis.HealthCheck.too_slow, | |
634 | ] |
|
668 | ] | |
635 | ) |
|
669 | ) | |
636 | @hypothesis.given( |
|
670 | @hypothesis.given( | |
637 | original=strategies.sampled_from(random_input_data()), |
|
671 | original=strategies.sampled_from(random_input_data()), | |
638 | level=strategies.integers(min_value=1, max_value=5), |
|
672 | level=strategies.integers(min_value=1, max_value=5), | |
639 | chunk_sizes=strategies.data(), |
|
673 | chunk_sizes=strategies.data(), | |
640 | flushes=strategies.data(), |
|
674 | flushes=strategies.data(), | |
641 | ) |
|
675 | ) | |
642 | def test_flush_block(self, original, level, chunk_sizes, flushes): |
|
676 | def test_flush_block(self, original, level, chunk_sizes, flushes): | |
643 | cctx = zstd.ZstdCompressor(level=level) |
|
677 | cctx = zstd.ZstdCompressor(level=level) | |
644 | cobj = cctx.compressobj() |
|
678 | cobj = cctx.compressobj() | |
645 |
|
679 | |||
646 | dctx = zstd.ZstdDecompressor() |
|
680 | dctx = zstd.ZstdDecompressor() | |
647 | dobj = dctx.decompressobj() |
|
681 | dobj = dctx.decompressobj() | |
648 |
|
682 | |||
649 | compressed_chunks = [] |
|
683 | compressed_chunks = [] | |
650 | decompressed_chunks = [] |
|
684 | decompressed_chunks = [] | |
651 | i = 0 |
|
685 | i = 0 | |
652 | while True: |
|
686 | while True: | |
653 | input_size = chunk_sizes.draw(strategies.integers(1, 4096)) |
|
687 | input_size = chunk_sizes.draw(strategies.integers(1, 4096)) | |
654 | source = original[i : i + input_size] |
|
688 | source = original[i : i + input_size] | |
655 | if not source: |
|
689 | if not source: | |
656 | break |
|
690 | break | |
657 |
|
691 | |||
658 | i += input_size |
|
692 | i += input_size | |
659 |
|
693 | |||
660 | chunk = cobj.compress(source) |
|
694 | chunk = cobj.compress(source) | |
661 | compressed_chunks.append(chunk) |
|
695 | compressed_chunks.append(chunk) | |
662 | decompressed_chunks.append(dobj.decompress(chunk)) |
|
696 | decompressed_chunks.append(dobj.decompress(chunk)) | |
663 |
|
697 | |||
664 | if not flushes.draw(strategies.booleans()): |
|
698 | if not flushes.draw(strategies.booleans()): | |
665 | continue |
|
699 | continue | |
666 |
|
700 | |||
667 | chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) |
|
701 | chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) | |
668 | compressed_chunks.append(chunk) |
|
702 | compressed_chunks.append(chunk) | |
669 | decompressed_chunks.append(dobj.decompress(chunk)) |
|
703 | decompressed_chunks.append(dobj.decompress(chunk)) | |
670 |
|
704 | |||
671 | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) |
|
705 | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | |
672 |
|
706 | |||
673 | chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH) |
|
707 | chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH) | |
674 | compressed_chunks.append(chunk) |
|
708 | compressed_chunks.append(chunk) | |
675 | decompressed_chunks.append(dobj.decompress(chunk)) |
|
709 | decompressed_chunks.append(dobj.decompress(chunk)) | |
676 |
|
710 | |||
677 | self.assertEqual( |
|
711 | self.assertEqual( | |
678 | dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)), |
|
712 | dctx.decompress( | |
|
713 | b"".join(compressed_chunks), max_output_size=len(original) | |||
|
714 | ), | |||
679 | original, |
|
715 | original, | |
680 | ) |
|
716 | ) | |
681 | self.assertEqual(b"".join(decompressed_chunks), original) |
|
717 | self.assertEqual(b"".join(decompressed_chunks), original) | |
682 |
|
718 | |||
683 |
|
719 | |||
684 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
720 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
685 | @make_cffi |
|
721 | @make_cffi | |
686 | class TestCompressor_read_to_iter_fuzzing(TestCase): |
|
722 | class TestCompressor_read_to_iter_fuzzing(TestCase): | |
687 | @hypothesis.given( |
|
723 | @hypothesis.given( | |
688 | original=strategies.sampled_from(random_input_data()), |
|
724 | original=strategies.sampled_from(random_input_data()), | |
689 | level=strategies.integers(min_value=1, max_value=5), |
|
725 | level=strategies.integers(min_value=1, max_value=5), | |
690 | read_size=strategies.integers(min_value=1, max_value=4096), |
|
726 | read_size=strategies.integers(min_value=1, max_value=4096), | |
691 | write_size=strategies.integers(min_value=1, max_value=4096), |
|
727 | write_size=strategies.integers(min_value=1, max_value=4096), | |
692 | ) |
|
728 | ) | |
693 |
def test_read_write_size_variance( |
|
729 | def test_read_write_size_variance( | |
|
730 | self, original, level, read_size, write_size | |||
|
731 | ): | |||
694 | refcctx = zstd.ZstdCompressor(level=level) |
|
732 | refcctx = zstd.ZstdCompressor(level=level) | |
695 | ref_frame = refcctx.compress(original) |
|
733 | ref_frame = refcctx.compress(original) | |
696 |
|
734 | |||
697 | source = io.BytesIO(original) |
|
735 | source = io.BytesIO(original) | |
698 |
|
736 | |||
699 | cctx = zstd.ZstdCompressor(level=level) |
|
737 | cctx = zstd.ZstdCompressor(level=level) | |
700 | chunks = list( |
|
738 | chunks = list( | |
701 | cctx.read_to_iter( |
|
739 | cctx.read_to_iter( | |
702 | source, size=len(original), read_size=read_size, write_size=write_size |
|
740 | source, | |
|
741 | size=len(original), | |||
|
742 | read_size=read_size, | |||
|
743 | write_size=write_size, | |||
703 | ) |
|
744 | ) | |
704 | ) |
|
745 | ) | |
705 |
|
746 | |||
706 | self.assertEqual(b"".join(chunks), ref_frame) |
|
747 | self.assertEqual(b"".join(chunks), ref_frame) | |
707 |
|
748 | |||
708 |
|
749 | |||
709 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
750 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
710 | class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase): |
|
751 | class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase): | |
711 | @hypothesis.given( |
|
752 | @hypothesis.given( | |
712 | original=strategies.lists( |
|
753 | original=strategies.lists( | |
713 |
strategies.sampled_from(random_input_data()), |
|
754 | strategies.sampled_from(random_input_data()), | |
|
755 | min_size=1, | |||
|
756 | max_size=1024, | |||
714 | ), |
|
757 | ), | |
715 | threads=strategies.integers(min_value=1, max_value=8), |
|
758 | threads=strategies.integers(min_value=1, max_value=8), | |
716 | use_dict=strategies.booleans(), |
|
759 | use_dict=strategies.booleans(), | |
717 | ) |
|
760 | ) | |
718 | def test_data_equivalence(self, original, threads, use_dict): |
|
761 | def test_data_equivalence(self, original, threads, use_dict): | |
719 | kwargs = {} |
|
762 | kwargs = {} | |
720 |
|
763 | |||
721 | # Use a content dictionary because it is cheap to create. |
|
764 | # Use a content dictionary because it is cheap to create. | |
722 | if use_dict: |
|
765 | if use_dict: | |
723 | kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) |
|
766 | kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) | |
724 |
|
767 | |||
725 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True, **kwargs) |
|
768 | cctx = zstd.ZstdCompressor(level=1, write_checksum=True, **kwargs) | |
726 |
|
769 | |||
727 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
770 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
728 | self.skipTest("multi_compress_to_buffer not available") |
|
771 | self.skipTest("multi_compress_to_buffer not available") | |
729 |
|
772 | |||
730 | result = cctx.multi_compress_to_buffer(original, threads=-1) |
|
773 | result = cctx.multi_compress_to_buffer(original, threads=-1) | |
731 |
|
774 | |||
732 | self.assertEqual(len(result), len(original)) |
|
775 | self.assertEqual(len(result), len(original)) | |
733 |
|
776 | |||
734 | # The frame produced via the batch APIs may not be bit identical to that |
|
777 | # The frame produced via the batch APIs may not be bit identical to that | |
735 | # produced by compress() because compression parameters are adjusted |
|
778 | # produced by compress() because compression parameters are adjusted | |
736 | # from the first input in batch mode. So the only thing we can do is |
|
779 | # from the first input in batch mode. So the only thing we can do is | |
737 | # verify the decompressed data matches the input. |
|
780 | # verify the decompressed data matches the input. | |
738 | dctx = zstd.ZstdDecompressor(**kwargs) |
|
781 | dctx = zstd.ZstdDecompressor(**kwargs) | |
739 |
|
782 | |||
740 | for i, frame in enumerate(result): |
|
783 | for i, frame in enumerate(result): | |
741 | self.assertEqual(dctx.decompress(frame), original[i]) |
|
784 | self.assertEqual(dctx.decompress(frame), original[i]) | |
742 |
|
785 | |||
743 |
|
786 | |||
744 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
787 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
745 | @make_cffi |
|
788 | @make_cffi | |
746 | class TestCompressor_chunker_fuzzing(TestCase): |
|
789 | class TestCompressor_chunker_fuzzing(TestCase): | |
747 | @hypothesis.settings( |
|
790 | @hypothesis.settings( | |
748 | suppress_health_check=[ |
|
791 | suppress_health_check=[ | |
749 | hypothesis.HealthCheck.large_base_example, |
|
792 | hypothesis.HealthCheck.large_base_example, | |
750 | hypothesis.HealthCheck.too_slow, |
|
793 | hypothesis.HealthCheck.too_slow, | |
751 | ] |
|
794 | ] | |
752 | ) |
|
795 | ) | |
753 | @hypothesis.given( |
|
796 | @hypothesis.given( | |
754 | original=strategies.sampled_from(random_input_data()), |
|
797 | original=strategies.sampled_from(random_input_data()), | |
755 | level=strategies.integers(min_value=1, max_value=5), |
|
798 | level=strategies.integers(min_value=1, max_value=5), | |
756 | chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), |
|
799 | chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), | |
757 | input_sizes=strategies.data(), |
|
800 | input_sizes=strategies.data(), | |
758 | ) |
|
801 | ) | |
759 | def test_random_input_sizes(self, original, level, chunk_size, input_sizes): |
|
802 | def test_random_input_sizes(self, original, level, chunk_size, input_sizes): | |
760 | cctx = zstd.ZstdCompressor(level=level) |
|
803 | cctx = zstd.ZstdCompressor(level=level) | |
761 | chunker = cctx.chunker(chunk_size=chunk_size) |
|
804 | chunker = cctx.chunker(chunk_size=chunk_size) | |
762 |
|
805 | |||
763 | chunks = [] |
|
806 | chunks = [] | |
764 | i = 0 |
|
807 | i = 0 | |
765 | while True: |
|
808 | while True: | |
766 | input_size = input_sizes.draw(strategies.integers(1, 4096)) |
|
809 | input_size = input_sizes.draw(strategies.integers(1, 4096)) | |
767 | source = original[i : i + input_size] |
|
810 | source = original[i : i + input_size] | |
768 | if not source: |
|
811 | if not source: | |
769 | break |
|
812 | break | |
770 |
|
813 | |||
771 | chunks.extend(chunker.compress(source)) |
|
814 | chunks.extend(chunker.compress(source)) | |
772 | i += input_size |
|
815 | i += input_size | |
773 |
|
816 | |||
774 | chunks.extend(chunker.finish()) |
|
817 | chunks.extend(chunker.finish()) | |
775 |
|
818 | |||
776 | dctx = zstd.ZstdDecompressor() |
|
819 | dctx = zstd.ZstdDecompressor() | |
777 |
|
820 | |||
778 | self.assertEqual( |
|
821 | self.assertEqual( | |
779 |
dctx.decompress(b"".join(chunks), max_output_size=len(original)), |
|
822 | dctx.decompress(b"".join(chunks), max_output_size=len(original)), | |
|
823 | original, | |||
780 | ) |
|
824 | ) | |
781 |
|
825 | |||
782 | self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1])) |
|
826 | self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1])) | |
783 |
|
827 | |||
784 | @hypothesis.settings( |
|
828 | @hypothesis.settings( | |
785 | suppress_health_check=[ |
|
829 | suppress_health_check=[ | |
786 | hypothesis.HealthCheck.large_base_example, |
|
830 | hypothesis.HealthCheck.large_base_example, | |
787 | hypothesis.HealthCheck.too_slow, |
|
831 | hypothesis.HealthCheck.too_slow, | |
788 | ] |
|
832 | ] | |
789 | ) |
|
833 | ) | |
790 | @hypothesis.given( |
|
834 | @hypothesis.given( | |
791 | original=strategies.sampled_from(random_input_data()), |
|
835 | original=strategies.sampled_from(random_input_data()), | |
792 | level=strategies.integers(min_value=1, max_value=5), |
|
836 | level=strategies.integers(min_value=1, max_value=5), | |
793 | chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), |
|
837 | chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), | |
794 | input_sizes=strategies.data(), |
|
838 | input_sizes=strategies.data(), | |
795 | flushes=strategies.data(), |
|
839 | flushes=strategies.data(), | |
796 | ) |
|
840 | ) | |
797 | def test_flush_block(self, original, level, chunk_size, input_sizes, flushes): |
|
841 | def test_flush_block( | |
|
842 | self, original, level, chunk_size, input_sizes, flushes | |||
|
843 | ): | |||
798 | cctx = zstd.ZstdCompressor(level=level) |
|
844 | cctx = zstd.ZstdCompressor(level=level) | |
799 | chunker = cctx.chunker(chunk_size=chunk_size) |
|
845 | chunker = cctx.chunker(chunk_size=chunk_size) | |
800 |
|
846 | |||
801 | dctx = zstd.ZstdDecompressor() |
|
847 | dctx = zstd.ZstdDecompressor() | |
802 | dobj = dctx.decompressobj() |
|
848 | dobj = dctx.decompressobj() | |
803 |
|
849 | |||
804 | compressed_chunks = [] |
|
850 | compressed_chunks = [] | |
805 | decompressed_chunks = [] |
|
851 | decompressed_chunks = [] | |
806 | i = 0 |
|
852 | i = 0 | |
807 | while True: |
|
853 | while True: | |
808 | input_size = input_sizes.draw(strategies.integers(1, 4096)) |
|
854 | input_size = input_sizes.draw(strategies.integers(1, 4096)) | |
809 | source = original[i : i + input_size] |
|
855 | source = original[i : i + input_size] | |
810 | if not source: |
|
856 | if not source: | |
811 | break |
|
857 | break | |
812 |
|
858 | |||
813 | i += input_size |
|
859 | i += input_size | |
814 |
|
860 | |||
815 | chunks = list(chunker.compress(source)) |
|
861 | chunks = list(chunker.compress(source)) | |
816 | compressed_chunks.extend(chunks) |
|
862 | compressed_chunks.extend(chunks) | |
817 | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) |
|
863 | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | |
818 |
|
864 | |||
819 | if not flushes.draw(strategies.booleans()): |
|
865 | if not flushes.draw(strategies.booleans()): | |
820 | continue |
|
866 | continue | |
821 |
|
867 | |||
822 | chunks = list(chunker.flush()) |
|
868 | chunks = list(chunker.flush()) | |
823 | compressed_chunks.extend(chunks) |
|
869 | compressed_chunks.extend(chunks) | |
824 | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) |
|
870 | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | |
825 |
|
871 | |||
826 | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) |
|
872 | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | |
827 |
|
873 | |||
828 | chunks = list(chunker.finish()) |
|
874 | chunks = list(chunker.finish()) | |
829 | compressed_chunks.extend(chunks) |
|
875 | compressed_chunks.extend(chunks) | |
830 | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) |
|
876 | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | |
831 |
|
877 | |||
832 | self.assertEqual( |
|
878 | self.assertEqual( | |
833 | dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)), |
|
879 | dctx.decompress( | |
|
880 | b"".join(compressed_chunks), max_output_size=len(original) | |||
|
881 | ), | |||
834 | original, |
|
882 | original, | |
835 | ) |
|
883 | ) | |
836 | self.assertEqual(b"".join(decompressed_chunks), original) |
|
884 | self.assertEqual(b"".join(decompressed_chunks), original) |
@@ -1,241 +1,255 b'' | |||||
1 | import sys |
|
1 | import sys | |
2 | import unittest |
|
2 | import unittest | |
3 |
|
3 | |||
4 | import zstandard as zstd |
|
4 | import zstandard as zstd | |
5 |
|
5 | |||
6 | from .common import ( |
|
6 | from .common import ( | |
7 | make_cffi, |
|
7 | make_cffi, | |
8 | TestCase, |
|
8 | TestCase, | |
9 | ) |
|
9 | ) | |
10 |
|
10 | |||
11 |
|
11 | |||
12 | @make_cffi |
|
12 | @make_cffi | |
13 | class TestCompressionParameters(TestCase): |
|
13 | class TestCompressionParameters(TestCase): | |
14 | def test_bounds(self): |
|
14 | def test_bounds(self): | |
15 | zstd.ZstdCompressionParameters( |
|
15 | zstd.ZstdCompressionParameters( | |
16 | window_log=zstd.WINDOWLOG_MIN, |
|
16 | window_log=zstd.WINDOWLOG_MIN, | |
17 | chain_log=zstd.CHAINLOG_MIN, |
|
17 | chain_log=zstd.CHAINLOG_MIN, | |
18 | hash_log=zstd.HASHLOG_MIN, |
|
18 | hash_log=zstd.HASHLOG_MIN, | |
19 | search_log=zstd.SEARCHLOG_MIN, |
|
19 | search_log=zstd.SEARCHLOG_MIN, | |
20 | min_match=zstd.MINMATCH_MIN + 1, |
|
20 | min_match=zstd.MINMATCH_MIN + 1, | |
21 | target_length=zstd.TARGETLENGTH_MIN, |
|
21 | target_length=zstd.TARGETLENGTH_MIN, | |
22 | strategy=zstd.STRATEGY_FAST, |
|
22 | strategy=zstd.STRATEGY_FAST, | |
23 | ) |
|
23 | ) | |
24 |
|
24 | |||
25 | zstd.ZstdCompressionParameters( |
|
25 | zstd.ZstdCompressionParameters( | |
26 | window_log=zstd.WINDOWLOG_MAX, |
|
26 | window_log=zstd.WINDOWLOG_MAX, | |
27 | chain_log=zstd.CHAINLOG_MAX, |
|
27 | chain_log=zstd.CHAINLOG_MAX, | |
28 | hash_log=zstd.HASHLOG_MAX, |
|
28 | hash_log=zstd.HASHLOG_MAX, | |
29 | search_log=zstd.SEARCHLOG_MAX, |
|
29 | search_log=zstd.SEARCHLOG_MAX, | |
30 | min_match=zstd.MINMATCH_MAX - 1, |
|
30 | min_match=zstd.MINMATCH_MAX - 1, | |
31 | target_length=zstd.TARGETLENGTH_MAX, |
|
31 | target_length=zstd.TARGETLENGTH_MAX, | |
32 | strategy=zstd.STRATEGY_BTULTRA2, |
|
32 | strategy=zstd.STRATEGY_BTULTRA2, | |
33 | ) |
|
33 | ) | |
34 |
|
34 | |||
35 | def test_from_level(self): |
|
35 | def test_from_level(self): | |
36 | p = zstd.ZstdCompressionParameters.from_level(1) |
|
36 | p = zstd.ZstdCompressionParameters.from_level(1) | |
37 | self.assertIsInstance(p, zstd.CompressionParameters) |
|
37 | self.assertIsInstance(p, zstd.CompressionParameters) | |
38 |
|
38 | |||
39 | self.assertEqual(p.window_log, 19) |
|
39 | self.assertEqual(p.window_log, 19) | |
40 |
|
40 | |||
41 | p = zstd.ZstdCompressionParameters.from_level(-4) |
|
41 | p = zstd.ZstdCompressionParameters.from_level(-4) | |
42 | self.assertEqual(p.window_log, 19) |
|
42 | self.assertEqual(p.window_log, 19) | |
43 |
|
43 | |||
44 | def test_members(self): |
|
44 | def test_members(self): | |
45 | p = zstd.ZstdCompressionParameters( |
|
45 | p = zstd.ZstdCompressionParameters( | |
46 | window_log=10, |
|
46 | window_log=10, | |
47 | chain_log=6, |
|
47 | chain_log=6, | |
48 | hash_log=7, |
|
48 | hash_log=7, | |
49 | search_log=4, |
|
49 | search_log=4, | |
50 | min_match=5, |
|
50 | min_match=5, | |
51 | target_length=8, |
|
51 | target_length=8, | |
52 | strategy=1, |
|
52 | strategy=1, | |
53 | ) |
|
53 | ) | |
54 | self.assertEqual(p.window_log, 10) |
|
54 | self.assertEqual(p.window_log, 10) | |
55 | self.assertEqual(p.chain_log, 6) |
|
55 | self.assertEqual(p.chain_log, 6) | |
56 | self.assertEqual(p.hash_log, 7) |
|
56 | self.assertEqual(p.hash_log, 7) | |
57 | self.assertEqual(p.search_log, 4) |
|
57 | self.assertEqual(p.search_log, 4) | |
58 | self.assertEqual(p.min_match, 5) |
|
58 | self.assertEqual(p.min_match, 5) | |
59 | self.assertEqual(p.target_length, 8) |
|
59 | self.assertEqual(p.target_length, 8) | |
60 | self.assertEqual(p.compression_strategy, 1) |
|
60 | self.assertEqual(p.compression_strategy, 1) | |
61 |
|
61 | |||
62 | p = zstd.ZstdCompressionParameters(compression_level=2) |
|
62 | p = zstd.ZstdCompressionParameters(compression_level=2) | |
63 | self.assertEqual(p.compression_level, 2) |
|
63 | self.assertEqual(p.compression_level, 2) | |
64 |
|
64 | |||
65 | p = zstd.ZstdCompressionParameters(threads=4) |
|
65 | p = zstd.ZstdCompressionParameters(threads=4) | |
66 | self.assertEqual(p.threads, 4) |
|
66 | self.assertEqual(p.threads, 4) | |
67 |
|
67 | |||
68 |
p = zstd.ZstdCompressionParameters( |
|
68 | p = zstd.ZstdCompressionParameters( | |
|
69 | threads=2, job_size=1048576, overlap_log=6 | |||
|
70 | ) | |||
69 | self.assertEqual(p.threads, 2) |
|
71 | self.assertEqual(p.threads, 2) | |
70 | self.assertEqual(p.job_size, 1048576) |
|
72 | self.assertEqual(p.job_size, 1048576) | |
71 | self.assertEqual(p.overlap_log, 6) |
|
73 | self.assertEqual(p.overlap_log, 6) | |
72 | self.assertEqual(p.overlap_size_log, 6) |
|
74 | self.assertEqual(p.overlap_size_log, 6) | |
73 |
|
75 | |||
74 | p = zstd.ZstdCompressionParameters(compression_level=-1) |
|
76 | p = zstd.ZstdCompressionParameters(compression_level=-1) | |
75 | self.assertEqual(p.compression_level, -1) |
|
77 | self.assertEqual(p.compression_level, -1) | |
76 |
|
78 | |||
77 | p = zstd.ZstdCompressionParameters(compression_level=-2) |
|
79 | p = zstd.ZstdCompressionParameters(compression_level=-2) | |
78 | self.assertEqual(p.compression_level, -2) |
|
80 | self.assertEqual(p.compression_level, -2) | |
79 |
|
81 | |||
80 | p = zstd.ZstdCompressionParameters(force_max_window=True) |
|
82 | p = zstd.ZstdCompressionParameters(force_max_window=True) | |
81 | self.assertEqual(p.force_max_window, 1) |
|
83 | self.assertEqual(p.force_max_window, 1) | |
82 |
|
84 | |||
83 | p = zstd.ZstdCompressionParameters(enable_ldm=True) |
|
85 | p = zstd.ZstdCompressionParameters(enable_ldm=True) | |
84 | self.assertEqual(p.enable_ldm, 1) |
|
86 | self.assertEqual(p.enable_ldm, 1) | |
85 |
|
87 | |||
86 | p = zstd.ZstdCompressionParameters(ldm_hash_log=7) |
|
88 | p = zstd.ZstdCompressionParameters(ldm_hash_log=7) | |
87 | self.assertEqual(p.ldm_hash_log, 7) |
|
89 | self.assertEqual(p.ldm_hash_log, 7) | |
88 |
|
90 | |||
89 | p = zstd.ZstdCompressionParameters(ldm_min_match=6) |
|
91 | p = zstd.ZstdCompressionParameters(ldm_min_match=6) | |
90 | self.assertEqual(p.ldm_min_match, 6) |
|
92 | self.assertEqual(p.ldm_min_match, 6) | |
91 |
|
93 | |||
92 | p = zstd.ZstdCompressionParameters(ldm_bucket_size_log=7) |
|
94 | p = zstd.ZstdCompressionParameters(ldm_bucket_size_log=7) | |
93 | self.assertEqual(p.ldm_bucket_size_log, 7) |
|
95 | self.assertEqual(p.ldm_bucket_size_log, 7) | |
94 |
|
96 | |||
95 | p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) |
|
97 | p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) | |
96 | self.assertEqual(p.ldm_hash_every_log, 8) |
|
98 | self.assertEqual(p.ldm_hash_every_log, 8) | |
97 | self.assertEqual(p.ldm_hash_rate_log, 8) |
|
99 | self.assertEqual(p.ldm_hash_rate_log, 8) | |
98 |
|
100 | |||
99 | def test_estimated_compression_context_size(self): |
|
101 | def test_estimated_compression_context_size(self): | |
100 | p = zstd.ZstdCompressionParameters( |
|
102 | p = zstd.ZstdCompressionParameters( | |
101 | window_log=20, |
|
103 | window_log=20, | |
102 | chain_log=16, |
|
104 | chain_log=16, | |
103 | hash_log=17, |
|
105 | hash_log=17, | |
104 | search_log=1, |
|
106 | search_log=1, | |
105 | min_match=5, |
|
107 | min_match=5, | |
106 | target_length=16, |
|
108 | target_length=16, | |
107 | strategy=zstd.STRATEGY_DFAST, |
|
109 | strategy=zstd.STRATEGY_DFAST, | |
108 | ) |
|
110 | ) | |
109 |
|
111 | |||
110 | # 32-bit has slightly different values from 64-bit. |
|
112 | # 32-bit has slightly different values from 64-bit. | |
111 | self.assertAlmostEqual( |
|
113 | self.assertAlmostEqual( | |
112 | p.estimated_compression_context_size(), 1294464, delta=400 |
|
114 | p.estimated_compression_context_size(), 1294464, delta=400 | |
113 | ) |
|
115 | ) | |
114 |
|
116 | |||
115 | def test_strategy(self): |
|
117 | def test_strategy(self): | |
116 | with self.assertRaisesRegex( |
|
118 | with self.assertRaisesRegex( | |
117 | ValueError, "cannot specify both compression_strategy" |
|
119 | ValueError, "cannot specify both compression_strategy" | |
118 | ): |
|
120 | ): | |
119 | zstd.ZstdCompressionParameters(strategy=0, compression_strategy=0) |
|
121 | zstd.ZstdCompressionParameters(strategy=0, compression_strategy=0) | |
120 |
|
122 | |||
121 | p = zstd.ZstdCompressionParameters(strategy=2) |
|
123 | p = zstd.ZstdCompressionParameters(strategy=2) | |
122 | self.assertEqual(p.compression_strategy, 2) |
|
124 | self.assertEqual(p.compression_strategy, 2) | |
123 |
|
125 | |||
124 | p = zstd.ZstdCompressionParameters(strategy=3) |
|
126 | p = zstd.ZstdCompressionParameters(strategy=3) | |
125 | self.assertEqual(p.compression_strategy, 3) |
|
127 | self.assertEqual(p.compression_strategy, 3) | |
126 |
|
128 | |||
127 | def test_ldm_hash_rate_log(self): |
|
129 | def test_ldm_hash_rate_log(self): | |
128 | with self.assertRaisesRegex( |
|
130 | with self.assertRaisesRegex( | |
129 | ValueError, "cannot specify both ldm_hash_rate_log" |
|
131 | ValueError, "cannot specify both ldm_hash_rate_log" | |
130 | ): |
|
132 | ): | |
131 |
zstd.ZstdCompressionParameters( |
|
133 | zstd.ZstdCompressionParameters( | |
|
134 | ldm_hash_rate_log=8, ldm_hash_every_log=4 | |||
|
135 | ) | |||
132 |
|
136 | |||
133 | p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) |
|
137 | p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) | |
134 | self.assertEqual(p.ldm_hash_every_log, 8) |
|
138 | self.assertEqual(p.ldm_hash_every_log, 8) | |
135 |
|
139 | |||
136 | p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16) |
|
140 | p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16) | |
137 | self.assertEqual(p.ldm_hash_every_log, 16) |
|
141 | self.assertEqual(p.ldm_hash_every_log, 16) | |
138 |
|
142 | |||
139 | def test_overlap_log(self): |
|
143 | def test_overlap_log(self): | |
140 |
with self.assertRaisesRegex( |
|
144 | with self.assertRaisesRegex( | |
|
145 | ValueError, "cannot specify both overlap_log" | |||
|
146 | ): | |||
141 | zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9) |
|
147 | zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9) | |
142 |
|
148 | |||
143 | p = zstd.ZstdCompressionParameters(overlap_log=2) |
|
149 | p = zstd.ZstdCompressionParameters(overlap_log=2) | |
144 | self.assertEqual(p.overlap_log, 2) |
|
150 | self.assertEqual(p.overlap_log, 2) | |
145 | self.assertEqual(p.overlap_size_log, 2) |
|
151 | self.assertEqual(p.overlap_size_log, 2) | |
146 |
|
152 | |||
147 | p = zstd.ZstdCompressionParameters(overlap_size_log=4) |
|
153 | p = zstd.ZstdCompressionParameters(overlap_size_log=4) | |
148 | self.assertEqual(p.overlap_log, 4) |
|
154 | self.assertEqual(p.overlap_log, 4) | |
149 | self.assertEqual(p.overlap_size_log, 4) |
|
155 | self.assertEqual(p.overlap_size_log, 4) | |
150 |
|
156 | |||
151 |
|
157 | |||
152 | @make_cffi |
|
158 | @make_cffi | |
153 | class TestFrameParameters(TestCase): |
|
159 | class TestFrameParameters(TestCase): | |
154 | def test_invalid_type(self): |
|
160 | def test_invalid_type(self): | |
155 | with self.assertRaises(TypeError): |
|
161 | with self.assertRaises(TypeError): | |
156 | zstd.get_frame_parameters(None) |
|
162 | zstd.get_frame_parameters(None) | |
157 |
|
163 | |||
158 | # Python 3 doesn't appear to convert unicode to Py_buffer. |
|
164 | # Python 3 doesn't appear to convert unicode to Py_buffer. | |
159 | if sys.version_info[0] >= 3: |
|
165 | if sys.version_info[0] >= 3: | |
160 | with self.assertRaises(TypeError): |
|
166 | with self.assertRaises(TypeError): | |
161 | zstd.get_frame_parameters(u"foobarbaz") |
|
167 | zstd.get_frame_parameters(u"foobarbaz") | |
162 | else: |
|
168 | else: | |
163 | # CPython will convert unicode to Py_buffer. But CFFI won't. |
|
169 | # CPython will convert unicode to Py_buffer. But CFFI won't. | |
164 | if zstd.backend == "cffi": |
|
170 | if zstd.backend == "cffi": | |
165 | with self.assertRaises(TypeError): |
|
171 | with self.assertRaises(TypeError): | |
166 | zstd.get_frame_parameters(u"foobarbaz") |
|
172 | zstd.get_frame_parameters(u"foobarbaz") | |
167 | else: |
|
173 | else: | |
168 | with self.assertRaises(zstd.ZstdError): |
|
174 | with self.assertRaises(zstd.ZstdError): | |
169 | zstd.get_frame_parameters(u"foobarbaz") |
|
175 | zstd.get_frame_parameters(u"foobarbaz") | |
170 |
|
176 | |||
171 | def test_invalid_input_sizes(self): |
|
177 | def test_invalid_input_sizes(self): | |
172 |
with self.assertRaisesRegex( |
|
178 | with self.assertRaisesRegex( | |
|
179 | zstd.ZstdError, "not enough data for frame" | |||
|
180 | ): | |||
173 | zstd.get_frame_parameters(b"") |
|
181 | zstd.get_frame_parameters(b"") | |
174 |
|
182 | |||
175 |
with self.assertRaisesRegex( |
|
183 | with self.assertRaisesRegex( | |
|
184 | zstd.ZstdError, "not enough data for frame" | |||
|
185 | ): | |||
176 | zstd.get_frame_parameters(zstd.FRAME_HEADER) |
|
186 | zstd.get_frame_parameters(zstd.FRAME_HEADER) | |
177 |
|
187 | |||
178 | def test_invalid_frame(self): |
|
188 | def test_invalid_frame(self): | |
179 | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): |
|
189 | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | |
180 | zstd.get_frame_parameters(b"foobarbaz") |
|
190 | zstd.get_frame_parameters(b"foobarbaz") | |
181 |
|
191 | |||
182 | def test_attributes(self): |
|
192 | def test_attributes(self): | |
183 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00") |
|
193 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00") | |
184 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
194 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
185 | self.assertEqual(params.window_size, 1024) |
|
195 | self.assertEqual(params.window_size, 1024) | |
186 | self.assertEqual(params.dict_id, 0) |
|
196 | self.assertEqual(params.dict_id, 0) | |
187 | self.assertFalse(params.has_checksum) |
|
197 | self.assertFalse(params.has_checksum) | |
188 |
|
198 | |||
189 | # Lowest 2 bits indicate a dictionary and length. Here, the dict id is 1 byte. |
|
199 | # Lowest 2 bits indicate a dictionary and length. Here, the dict id is 1 byte. | |
190 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x01\x00\xff") |
|
200 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x01\x00\xff") | |
191 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
201 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
192 | self.assertEqual(params.window_size, 1024) |
|
202 | self.assertEqual(params.window_size, 1024) | |
193 | self.assertEqual(params.dict_id, 255) |
|
203 | self.assertEqual(params.dict_id, 255) | |
194 | self.assertFalse(params.has_checksum) |
|
204 | self.assertFalse(params.has_checksum) | |
195 |
|
205 | |||
196 | # Lowest 3rd bit indicates if checksum is present. |
|
206 | # Lowest 3rd bit indicates if checksum is present. | |
197 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00") |
|
207 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00") | |
198 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
208 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
199 | self.assertEqual(params.window_size, 1024) |
|
209 | self.assertEqual(params.window_size, 1024) | |
200 | self.assertEqual(params.dict_id, 0) |
|
210 | self.assertEqual(params.dict_id, 0) | |
201 | self.assertTrue(params.has_checksum) |
|
211 | self.assertTrue(params.has_checksum) | |
202 |
|
212 | |||
203 | # Upper 2 bits indicate content size. |
|
213 | # Upper 2 bits indicate content size. | |
204 |
params = zstd.get_frame_parameters( |
|
214 | params = zstd.get_frame_parameters( | |
|
215 | zstd.FRAME_HEADER + b"\x40\x00\xff\x00" | |||
|
216 | ) | |||
205 | self.assertEqual(params.content_size, 511) |
|
217 | self.assertEqual(params.content_size, 511) | |
206 | self.assertEqual(params.window_size, 1024) |
|
218 | self.assertEqual(params.window_size, 1024) | |
207 | self.assertEqual(params.dict_id, 0) |
|
219 | self.assertEqual(params.dict_id, 0) | |
208 | self.assertFalse(params.has_checksum) |
|
220 | self.assertFalse(params.has_checksum) | |
209 |
|
221 | |||
210 | # Window descriptor is 2nd byte after frame header. |
|
222 | # Window descriptor is 2nd byte after frame header. | |
211 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40") |
|
223 | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40") | |
212 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
224 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
213 | self.assertEqual(params.window_size, 262144) |
|
225 | self.assertEqual(params.window_size, 262144) | |
214 | self.assertEqual(params.dict_id, 0) |
|
226 | self.assertEqual(params.dict_id, 0) | |
215 | self.assertFalse(params.has_checksum) |
|
227 | self.assertFalse(params.has_checksum) | |
216 |
|
228 | |||
217 | # Set multiple things. |
|
229 | # Set multiple things. | |
218 |
params = zstd.get_frame_parameters( |
|
230 | params = zstd.get_frame_parameters( | |
|
231 | zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00" | |||
|
232 | ) | |||
219 | self.assertEqual(params.content_size, 272) |
|
233 | self.assertEqual(params.content_size, 272) | |
220 | self.assertEqual(params.window_size, 262144) |
|
234 | self.assertEqual(params.window_size, 262144) | |
221 | self.assertEqual(params.dict_id, 15) |
|
235 | self.assertEqual(params.dict_id, 15) | |
222 | self.assertTrue(params.has_checksum) |
|
236 | self.assertTrue(params.has_checksum) | |
223 |
|
237 | |||
224 | def test_input_types(self): |
|
238 | def test_input_types(self): | |
225 | v = zstd.FRAME_HEADER + b"\x00\x00" |
|
239 | v = zstd.FRAME_HEADER + b"\x00\x00" | |
226 |
|
240 | |||
227 | mutable_array = bytearray(len(v)) |
|
241 | mutable_array = bytearray(len(v)) | |
228 | mutable_array[:] = v |
|
242 | mutable_array[:] = v | |
229 |
|
243 | |||
230 | sources = [ |
|
244 | sources = [ | |
231 | memoryview(v), |
|
245 | memoryview(v), | |
232 | bytearray(v), |
|
246 | bytearray(v), | |
233 | mutable_array, |
|
247 | mutable_array, | |
234 | ] |
|
248 | ] | |
235 |
|
249 | |||
236 | for source in sources: |
|
250 | for source in sources: | |
237 | params = zstd.get_frame_parameters(source) |
|
251 | params = zstd.get_frame_parameters(source) | |
238 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) |
|
252 | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | |
239 | self.assertEqual(params.window_size, 1024) |
|
253 | self.assertEqual(params.window_size, 1024) | |
240 | self.assertEqual(params.dict_id, 0) |
|
254 | self.assertEqual(params.dict_id, 0) | |
241 | self.assertFalse(params.has_checksum) |
|
255 | self.assertFalse(params.has_checksum) |
@@ -1,105 +1,121 b'' | |||||
1 | import io |
|
1 | import io | |
2 | import os |
|
2 | import os | |
3 | import sys |
|
3 | import sys | |
4 | import unittest |
|
4 | import unittest | |
5 |
|
5 | |||
6 | try: |
|
6 | try: | |
7 | import hypothesis |
|
7 | import hypothesis | |
8 | import hypothesis.strategies as strategies |
|
8 | import hypothesis.strategies as strategies | |
9 | except ImportError: |
|
9 | except ImportError: | |
10 | raise unittest.SkipTest("hypothesis not available") |
|
10 | raise unittest.SkipTest("hypothesis not available") | |
11 |
|
11 | |||
12 | import zstandard as zstd |
|
12 | import zstandard as zstd | |
13 |
|
13 | |||
14 | from .common import ( |
|
14 | from .common import ( | |
15 | make_cffi, |
|
15 | make_cffi, | |
16 | TestCase, |
|
16 | TestCase, | |
17 | ) |
|
17 | ) | |
18 |
|
18 | |||
19 |
|
19 | |||
20 | s_windowlog = strategies.integers( |
|
20 | s_windowlog = strategies.integers( | |
21 | min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX |
|
21 | min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX | |
22 | ) |
|
22 | ) | |
23 | s_chainlog = strategies.integers( |
|
23 | s_chainlog = strategies.integers( | |
24 | min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX |
|
24 | min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX | |
25 | ) |
|
25 | ) | |
26 | s_hashlog = strategies.integers(min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX) |
|
26 | s_hashlog = strategies.integers( | |
|
27 | min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX | |||
|
28 | ) | |||
27 | s_searchlog = strategies.integers( |
|
29 | s_searchlog = strategies.integers( | |
28 | min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX |
|
30 | min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX | |
29 | ) |
|
31 | ) | |
30 | s_minmatch = strategies.integers( |
|
32 | s_minmatch = strategies.integers( | |
31 | min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX |
|
33 | min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX | |
32 | ) |
|
34 | ) | |
33 | s_targetlength = strategies.integers( |
|
35 | s_targetlength = strategies.integers( | |
34 | min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX |
|
36 | min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX | |
35 | ) |
|
37 | ) | |
36 | s_strategy = strategies.sampled_from( |
|
38 | s_strategy = strategies.sampled_from( | |
37 | ( |
|
39 | ( | |
38 | zstd.STRATEGY_FAST, |
|
40 | zstd.STRATEGY_FAST, | |
39 | zstd.STRATEGY_DFAST, |
|
41 | zstd.STRATEGY_DFAST, | |
40 | zstd.STRATEGY_GREEDY, |
|
42 | zstd.STRATEGY_GREEDY, | |
41 | zstd.STRATEGY_LAZY, |
|
43 | zstd.STRATEGY_LAZY, | |
42 | zstd.STRATEGY_LAZY2, |
|
44 | zstd.STRATEGY_LAZY2, | |
43 | zstd.STRATEGY_BTLAZY2, |
|
45 | zstd.STRATEGY_BTLAZY2, | |
44 | zstd.STRATEGY_BTOPT, |
|
46 | zstd.STRATEGY_BTOPT, | |
45 | zstd.STRATEGY_BTULTRA, |
|
47 | zstd.STRATEGY_BTULTRA, | |
46 | zstd.STRATEGY_BTULTRA2, |
|
48 | zstd.STRATEGY_BTULTRA2, | |
47 | ) |
|
49 | ) | |
48 | ) |
|
50 | ) | |
49 |
|
51 | |||
50 |
|
52 | |||
51 | @make_cffi |
|
53 | @make_cffi | |
52 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
54 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
53 | class TestCompressionParametersHypothesis(TestCase): |
|
55 | class TestCompressionParametersHypothesis(TestCase): | |
54 | @hypothesis.given( |
|
56 | @hypothesis.given( | |
55 | s_windowlog, |
|
57 | s_windowlog, | |
56 | s_chainlog, |
|
58 | s_chainlog, | |
57 | s_hashlog, |
|
59 | s_hashlog, | |
58 | s_searchlog, |
|
60 | s_searchlog, | |
59 | s_minmatch, |
|
61 | s_minmatch, | |
60 | s_targetlength, |
|
62 | s_targetlength, | |
61 | s_strategy, |
|
63 | s_strategy, | |
62 | ) |
|
64 | ) | |
63 | def test_valid_init( |
|
65 | def test_valid_init( | |
64 | self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy |
|
66 | self, | |
|
67 | windowlog, | |||
|
68 | chainlog, | |||
|
69 | hashlog, | |||
|
70 | searchlog, | |||
|
71 | minmatch, | |||
|
72 | targetlength, | |||
|
73 | strategy, | |||
65 | ): |
|
74 | ): | |
66 | zstd.ZstdCompressionParameters( |
|
75 | zstd.ZstdCompressionParameters( | |
67 | window_log=windowlog, |
|
76 | window_log=windowlog, | |
68 | chain_log=chainlog, |
|
77 | chain_log=chainlog, | |
69 | hash_log=hashlog, |
|
78 | hash_log=hashlog, | |
70 | search_log=searchlog, |
|
79 | search_log=searchlog, | |
71 | min_match=minmatch, |
|
80 | min_match=minmatch, | |
72 | target_length=targetlength, |
|
81 | target_length=targetlength, | |
73 | strategy=strategy, |
|
82 | strategy=strategy, | |
74 | ) |
|
83 | ) | |
75 |
|
84 | |||
76 | @hypothesis.given( |
|
85 | @hypothesis.given( | |
77 | s_windowlog, |
|
86 | s_windowlog, | |
78 | s_chainlog, |
|
87 | s_chainlog, | |
79 | s_hashlog, |
|
88 | s_hashlog, | |
80 | s_searchlog, |
|
89 | s_searchlog, | |
81 | s_minmatch, |
|
90 | s_minmatch, | |
82 | s_targetlength, |
|
91 | s_targetlength, | |
83 | s_strategy, |
|
92 | s_strategy, | |
84 | ) |
|
93 | ) | |
85 | def test_estimated_compression_context_size( |
|
94 | def test_estimated_compression_context_size( | |
86 | self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy |
|
95 | self, | |
|
96 | windowlog, | |||
|
97 | chainlog, | |||
|
98 | hashlog, | |||
|
99 | searchlog, | |||
|
100 | minmatch, | |||
|
101 | targetlength, | |||
|
102 | strategy, | |||
87 | ): |
|
103 | ): | |
88 | if minmatch == zstd.MINMATCH_MIN and strategy in ( |
|
104 | if minmatch == zstd.MINMATCH_MIN and strategy in ( | |
89 | zstd.STRATEGY_FAST, |
|
105 | zstd.STRATEGY_FAST, | |
90 | zstd.STRATEGY_GREEDY, |
|
106 | zstd.STRATEGY_GREEDY, | |
91 | ): |
|
107 | ): | |
92 | minmatch += 1 |
|
108 | minmatch += 1 | |
93 | elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST: |
|
109 | elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST: | |
94 | minmatch -= 1 |
|
110 | minmatch -= 1 | |
95 |
|
111 | |||
96 | p = zstd.ZstdCompressionParameters( |
|
112 | p = zstd.ZstdCompressionParameters( | |
97 | window_log=windowlog, |
|
113 | window_log=windowlog, | |
98 | chain_log=chainlog, |
|
114 | chain_log=chainlog, | |
99 | hash_log=hashlog, |
|
115 | hash_log=hashlog, | |
100 | search_log=searchlog, |
|
116 | search_log=searchlog, | |
101 | min_match=minmatch, |
|
117 | min_match=minmatch, | |
102 | target_length=targetlength, |
|
118 | target_length=targetlength, | |
103 | strategy=strategy, |
|
119 | strategy=strategy, | |
104 | ) |
|
120 | ) | |
105 | size = p.estimated_compression_context_size() |
|
121 | size = p.estimated_compression_context_size() |
@@ -1,1670 +1,1714 b'' | |||||
1 | import io |
|
1 | import io | |
2 | import os |
|
2 | import os | |
3 | import random |
|
3 | import random | |
4 | import struct |
|
4 | import struct | |
5 | import sys |
|
5 | import sys | |
6 | import tempfile |
|
6 | import tempfile | |
7 | import unittest |
|
7 | import unittest | |
8 |
|
8 | |||
9 | import zstandard as zstd |
|
9 | import zstandard as zstd | |
10 |
|
10 | |||
11 | from .common import ( |
|
11 | from .common import ( | |
12 | generate_samples, |
|
12 | generate_samples, | |
13 | make_cffi, |
|
13 | make_cffi, | |
14 | NonClosingBytesIO, |
|
14 | NonClosingBytesIO, | |
15 | OpCountingBytesIO, |
|
15 | OpCountingBytesIO, | |
16 | TestCase, |
|
16 | TestCase, | |
17 | ) |
|
17 | ) | |
18 |
|
18 | |||
19 |
|
19 | |||
20 | if sys.version_info[0] >= 3: |
|
20 | if sys.version_info[0] >= 3: | |
21 | next = lambda it: it.__next__() |
|
21 | next = lambda it: it.__next__() | |
22 | else: |
|
22 | else: | |
23 | next = lambda it: it.next() |
|
23 | next = lambda it: it.next() | |
24 |
|
24 | |||
25 |
|
25 | |||
26 | @make_cffi |
|
26 | @make_cffi | |
27 | class TestFrameHeaderSize(TestCase): |
|
27 | class TestFrameHeaderSize(TestCase): | |
28 | def test_empty(self): |
|
28 | def test_empty(self): | |
29 | with self.assertRaisesRegex( |
|
29 | with self.assertRaisesRegex( | |
30 | zstd.ZstdError, |
|
30 | zstd.ZstdError, | |
31 | "could not determine frame header size: Src size " "is incorrect", |
|
31 | "could not determine frame header size: Src size " "is incorrect", | |
32 | ): |
|
32 | ): | |
33 | zstd.frame_header_size(b"") |
|
33 | zstd.frame_header_size(b"") | |
34 |
|
34 | |||
35 | def test_too_small(self): |
|
35 | def test_too_small(self): | |
36 | with self.assertRaisesRegex( |
|
36 | with self.assertRaisesRegex( | |
37 | zstd.ZstdError, |
|
37 | zstd.ZstdError, | |
38 | "could not determine frame header size: Src size " "is incorrect", |
|
38 | "could not determine frame header size: Src size " "is incorrect", | |
39 | ): |
|
39 | ): | |
40 | zstd.frame_header_size(b"foob") |
|
40 | zstd.frame_header_size(b"foob") | |
41 |
|
41 | |||
42 | def test_basic(self): |
|
42 | def test_basic(self): | |
43 | # It doesn't matter that it isn't a valid frame. |
|
43 | # It doesn't matter that it isn't a valid frame. | |
44 | self.assertEqual(zstd.frame_header_size(b"long enough but no magic"), 6) |
|
44 | self.assertEqual(zstd.frame_header_size(b"long enough but no magic"), 6) | |
45 |
|
45 | |||
46 |
|
46 | |||
47 | @make_cffi |
|
47 | @make_cffi | |
48 | class TestFrameContentSize(TestCase): |
|
48 | class TestFrameContentSize(TestCase): | |
49 | def test_empty(self): |
|
49 | def test_empty(self): | |
50 | with self.assertRaisesRegex( |
|
50 | with self.assertRaisesRegex( | |
51 | zstd.ZstdError, "error when determining content size" |
|
51 | zstd.ZstdError, "error when determining content size" | |
52 | ): |
|
52 | ): | |
53 | zstd.frame_content_size(b"") |
|
53 | zstd.frame_content_size(b"") | |
54 |
|
54 | |||
55 | def test_too_small(self): |
|
55 | def test_too_small(self): | |
56 | with self.assertRaisesRegex( |
|
56 | with self.assertRaisesRegex( | |
57 | zstd.ZstdError, "error when determining content size" |
|
57 | zstd.ZstdError, "error when determining content size" | |
58 | ): |
|
58 | ): | |
59 | zstd.frame_content_size(b"foob") |
|
59 | zstd.frame_content_size(b"foob") | |
60 |
|
60 | |||
61 | def test_bad_frame(self): |
|
61 | def test_bad_frame(self): | |
62 | with self.assertRaisesRegex( |
|
62 | with self.assertRaisesRegex( | |
63 | zstd.ZstdError, "error when determining content size" |
|
63 | zstd.ZstdError, "error when determining content size" | |
64 | ): |
|
64 | ): | |
65 | zstd.frame_content_size(b"invalid frame header") |
|
65 | zstd.frame_content_size(b"invalid frame header") | |
66 |
|
66 | |||
67 | def test_unknown(self): |
|
67 | def test_unknown(self): | |
68 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
68 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
69 | frame = cctx.compress(b"foobar") |
|
69 | frame = cctx.compress(b"foobar") | |
70 |
|
70 | |||
71 | self.assertEqual(zstd.frame_content_size(frame), -1) |
|
71 | self.assertEqual(zstd.frame_content_size(frame), -1) | |
72 |
|
72 | |||
73 | def test_empty(self): |
|
73 | def test_empty(self): | |
74 | cctx = zstd.ZstdCompressor() |
|
74 | cctx = zstd.ZstdCompressor() | |
75 | frame = cctx.compress(b"") |
|
75 | frame = cctx.compress(b"") | |
76 |
|
76 | |||
77 | self.assertEqual(zstd.frame_content_size(frame), 0) |
|
77 | self.assertEqual(zstd.frame_content_size(frame), 0) | |
78 |
|
78 | |||
79 | def test_basic(self): |
|
79 | def test_basic(self): | |
80 | cctx = zstd.ZstdCompressor() |
|
80 | cctx = zstd.ZstdCompressor() | |
81 | frame = cctx.compress(b"foobar") |
|
81 | frame = cctx.compress(b"foobar") | |
82 |
|
82 | |||
83 | self.assertEqual(zstd.frame_content_size(frame), 6) |
|
83 | self.assertEqual(zstd.frame_content_size(frame), 6) | |
84 |
|
84 | |||
85 |
|
85 | |||
86 | @make_cffi |
|
86 | @make_cffi | |
87 | class TestDecompressor(TestCase): |
|
87 | class TestDecompressor(TestCase): | |
88 | def test_memory_size(self): |
|
88 | def test_memory_size(self): | |
89 | dctx = zstd.ZstdDecompressor() |
|
89 | dctx = zstd.ZstdDecompressor() | |
90 |
|
90 | |||
91 | self.assertGreater(dctx.memory_size(), 100) |
|
91 | self.assertGreater(dctx.memory_size(), 100) | |
92 |
|
92 | |||
93 |
|
93 | |||
94 | @make_cffi |
|
94 | @make_cffi | |
95 | class TestDecompressor_decompress(TestCase): |
|
95 | class TestDecompressor_decompress(TestCase): | |
96 | def test_empty_input(self): |
|
96 | def test_empty_input(self): | |
97 | dctx = zstd.ZstdDecompressor() |
|
97 | dctx = zstd.ZstdDecompressor() | |
98 |
|
98 | |||
99 | with self.assertRaisesRegex( |
|
99 | with self.assertRaisesRegex( | |
100 | zstd.ZstdError, "error determining content size from frame header" |
|
100 | zstd.ZstdError, "error determining content size from frame header" | |
101 | ): |
|
101 | ): | |
102 | dctx.decompress(b"") |
|
102 | dctx.decompress(b"") | |
103 |
|
103 | |||
104 | def test_invalid_input(self): |
|
104 | def test_invalid_input(self): | |
105 | dctx = zstd.ZstdDecompressor() |
|
105 | dctx = zstd.ZstdDecompressor() | |
106 |
|
106 | |||
107 | with self.assertRaisesRegex( |
|
107 | with self.assertRaisesRegex( | |
108 | zstd.ZstdError, "error determining content size from frame header" |
|
108 | zstd.ZstdError, "error determining content size from frame header" | |
109 | ): |
|
109 | ): | |
110 | dctx.decompress(b"foobar") |
|
110 | dctx.decompress(b"foobar") | |
111 |
|
111 | |||
112 | def test_input_types(self): |
|
112 | def test_input_types(self): | |
113 | cctx = zstd.ZstdCompressor(level=1) |
|
113 | cctx = zstd.ZstdCompressor(level=1) | |
114 | compressed = cctx.compress(b"foo") |
|
114 | compressed = cctx.compress(b"foo") | |
115 |
|
115 | |||
116 | mutable_array = bytearray(len(compressed)) |
|
116 | mutable_array = bytearray(len(compressed)) | |
117 | mutable_array[:] = compressed |
|
117 | mutable_array[:] = compressed | |
118 |
|
118 | |||
119 | sources = [ |
|
119 | sources = [ | |
120 | memoryview(compressed), |
|
120 | memoryview(compressed), | |
121 | bytearray(compressed), |
|
121 | bytearray(compressed), | |
122 | mutable_array, |
|
122 | mutable_array, | |
123 | ] |
|
123 | ] | |
124 |
|
124 | |||
125 | dctx = zstd.ZstdDecompressor() |
|
125 | dctx = zstd.ZstdDecompressor() | |
126 | for source in sources: |
|
126 | for source in sources: | |
127 | self.assertEqual(dctx.decompress(source), b"foo") |
|
127 | self.assertEqual(dctx.decompress(source), b"foo") | |
128 |
|
128 | |||
129 | def test_no_content_size_in_frame(self): |
|
129 | def test_no_content_size_in_frame(self): | |
130 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
130 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
131 | compressed = cctx.compress(b"foobar") |
|
131 | compressed = cctx.compress(b"foobar") | |
132 |
|
132 | |||
133 | dctx = zstd.ZstdDecompressor() |
|
133 | dctx = zstd.ZstdDecompressor() | |
134 | with self.assertRaisesRegex( |
|
134 | with self.assertRaisesRegex( | |
135 | zstd.ZstdError, "could not determine content size in frame header" |
|
135 | zstd.ZstdError, "could not determine content size in frame header" | |
136 | ): |
|
136 | ): | |
137 | dctx.decompress(compressed) |
|
137 | dctx.decompress(compressed) | |
138 |
|
138 | |||
139 | def test_content_size_present(self): |
|
139 | def test_content_size_present(self): | |
140 | cctx = zstd.ZstdCompressor() |
|
140 | cctx = zstd.ZstdCompressor() | |
141 | compressed = cctx.compress(b"foobar") |
|
141 | compressed = cctx.compress(b"foobar") | |
142 |
|
142 | |||
143 | dctx = zstd.ZstdDecompressor() |
|
143 | dctx = zstd.ZstdDecompressor() | |
144 | decompressed = dctx.decompress(compressed) |
|
144 | decompressed = dctx.decompress(compressed) | |
145 | self.assertEqual(decompressed, b"foobar") |
|
145 | self.assertEqual(decompressed, b"foobar") | |
146 |
|
146 | |||
147 | def test_empty_roundtrip(self): |
|
147 | def test_empty_roundtrip(self): | |
148 | cctx = zstd.ZstdCompressor() |
|
148 | cctx = zstd.ZstdCompressor() | |
149 | compressed = cctx.compress(b"") |
|
149 | compressed = cctx.compress(b"") | |
150 |
|
150 | |||
151 | dctx = zstd.ZstdDecompressor() |
|
151 | dctx = zstd.ZstdDecompressor() | |
152 | decompressed = dctx.decompress(compressed) |
|
152 | decompressed = dctx.decompress(compressed) | |
153 |
|
153 | |||
154 | self.assertEqual(decompressed, b"") |
|
154 | self.assertEqual(decompressed, b"") | |
155 |
|
155 | |||
156 | def test_max_output_size(self): |
|
156 | def test_max_output_size(self): | |
157 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
157 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
158 | source = b"foobar" * 256 |
|
158 | source = b"foobar" * 256 | |
159 | compressed = cctx.compress(source) |
|
159 | compressed = cctx.compress(source) | |
160 |
|
160 | |||
161 | dctx = zstd.ZstdDecompressor() |
|
161 | dctx = zstd.ZstdDecompressor() | |
162 | # Will fit into buffer exactly the size of input. |
|
162 | # Will fit into buffer exactly the size of input. | |
163 | decompressed = dctx.decompress(compressed, max_output_size=len(source)) |
|
163 | decompressed = dctx.decompress(compressed, max_output_size=len(source)) | |
164 | self.assertEqual(decompressed, source) |
|
164 | self.assertEqual(decompressed, source) | |
165 |
|
165 | |||
166 | # Input size - 1 fails |
|
166 | # Input size - 1 fails | |
167 | with self.assertRaisesRegex( |
|
167 | with self.assertRaisesRegex( | |
168 | zstd.ZstdError, "decompression error: did not decompress full frame" |
|
168 | zstd.ZstdError, "decompression error: did not decompress full frame" | |
169 | ): |
|
169 | ): | |
170 | dctx.decompress(compressed, max_output_size=len(source) - 1) |
|
170 | dctx.decompress(compressed, max_output_size=len(source) - 1) | |
171 |
|
171 | |||
172 | # Input size + 1 works |
|
172 | # Input size + 1 works | |
173 |
decompressed = dctx.decompress( |
|
173 | decompressed = dctx.decompress( | |
|
174 | compressed, max_output_size=len(source) + 1 | |||
|
175 | ) | |||
174 | self.assertEqual(decompressed, source) |
|
176 | self.assertEqual(decompressed, source) | |
175 |
|
177 | |||
176 | # A much larger buffer works. |
|
178 | # A much larger buffer works. | |
177 |
decompressed = dctx.decompress( |
|
179 | decompressed = dctx.decompress( | |
|
180 | compressed, max_output_size=len(source) * 64 | |||
|
181 | ) | |||
178 | self.assertEqual(decompressed, source) |
|
182 | self.assertEqual(decompressed, source) | |
179 |
|
183 | |||
180 | def test_stupidly_large_output_buffer(self): |
|
184 | def test_stupidly_large_output_buffer(self): | |
181 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
185 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
182 | compressed = cctx.compress(b"foobar" * 256) |
|
186 | compressed = cctx.compress(b"foobar" * 256) | |
183 | dctx = zstd.ZstdDecompressor() |
|
187 | dctx = zstd.ZstdDecompressor() | |
184 |
|
188 | |||
185 | # Will get OverflowError on some Python distributions that can't |
|
189 | # Will get OverflowError on some Python distributions that can't | |
186 | # handle really large integers. |
|
190 | # handle really large integers. | |
187 | with self.assertRaises((MemoryError, OverflowError)): |
|
191 | with self.assertRaises((MemoryError, OverflowError)): | |
188 | dctx.decompress(compressed, max_output_size=2 ** 62) |
|
192 | dctx.decompress(compressed, max_output_size=2 ** 62) | |
189 |
|
193 | |||
190 | def test_dictionary(self): |
|
194 | def test_dictionary(self): | |
191 | samples = [] |
|
195 | samples = [] | |
192 | for i in range(128): |
|
196 | for i in range(128): | |
193 | samples.append(b"foo" * 64) |
|
197 | samples.append(b"foo" * 64) | |
194 | samples.append(b"bar" * 64) |
|
198 | samples.append(b"bar" * 64) | |
195 | samples.append(b"foobar" * 64) |
|
199 | samples.append(b"foobar" * 64) | |
196 |
|
200 | |||
197 | d = zstd.train_dictionary(8192, samples) |
|
201 | d = zstd.train_dictionary(8192, samples) | |
198 |
|
202 | |||
199 | orig = b"foobar" * 16384 |
|
203 | orig = b"foobar" * 16384 | |
200 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
|
204 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) | |
201 | compressed = cctx.compress(orig) |
|
205 | compressed = cctx.compress(orig) | |
202 |
|
206 | |||
203 | dctx = zstd.ZstdDecompressor(dict_data=d) |
|
207 | dctx = zstd.ZstdDecompressor(dict_data=d) | |
204 | decompressed = dctx.decompress(compressed) |
|
208 | decompressed = dctx.decompress(compressed) | |
205 |
|
209 | |||
206 | self.assertEqual(decompressed, orig) |
|
210 | self.assertEqual(decompressed, orig) | |
207 |
|
211 | |||
208 | def test_dictionary_multiple(self): |
|
212 | def test_dictionary_multiple(self): | |
209 | samples = [] |
|
213 | samples = [] | |
210 | for i in range(128): |
|
214 | for i in range(128): | |
211 | samples.append(b"foo" * 64) |
|
215 | samples.append(b"foo" * 64) | |
212 | samples.append(b"bar" * 64) |
|
216 | samples.append(b"bar" * 64) | |
213 | samples.append(b"foobar" * 64) |
|
217 | samples.append(b"foobar" * 64) | |
214 |
|
218 | |||
215 | d = zstd.train_dictionary(8192, samples) |
|
219 | d = zstd.train_dictionary(8192, samples) | |
216 |
|
220 | |||
217 | sources = (b"foobar" * 8192, b"foo" * 8192, b"bar" * 8192) |
|
221 | sources = (b"foobar" * 8192, b"foo" * 8192, b"bar" * 8192) | |
218 | compressed = [] |
|
222 | compressed = [] | |
219 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
|
223 | cctx = zstd.ZstdCompressor(level=1, dict_data=d) | |
220 | for source in sources: |
|
224 | for source in sources: | |
221 | compressed.append(cctx.compress(source)) |
|
225 | compressed.append(cctx.compress(source)) | |
222 |
|
226 | |||
223 | dctx = zstd.ZstdDecompressor(dict_data=d) |
|
227 | dctx = zstd.ZstdDecompressor(dict_data=d) | |
224 | for i in range(len(sources)): |
|
228 | for i in range(len(sources)): | |
225 | decompressed = dctx.decompress(compressed[i]) |
|
229 | decompressed = dctx.decompress(compressed[i]) | |
226 | self.assertEqual(decompressed, sources[i]) |
|
230 | self.assertEqual(decompressed, sources[i]) | |
227 |
|
231 | |||
228 | def test_max_window_size(self): |
|
232 | def test_max_window_size(self): | |
229 | with open(__file__, "rb") as fh: |
|
233 | with open(__file__, "rb") as fh: | |
230 | source = fh.read() |
|
234 | source = fh.read() | |
231 |
|
235 | |||
232 | # If we write a content size, the decompressor engages single pass |
|
236 | # If we write a content size, the decompressor engages single pass | |
233 | # mode and the window size doesn't come into play. |
|
237 | # mode and the window size doesn't come into play. | |
234 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
238 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
235 | frame = cctx.compress(source) |
|
239 | frame = cctx.compress(source) | |
236 |
|
240 | |||
237 | dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) |
|
241 | dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) | |
238 |
|
242 | |||
239 | with self.assertRaisesRegex( |
|
243 | with self.assertRaisesRegex( | |
240 | zstd.ZstdError, "decompression error: Frame requires too much memory" |
|
244 | zstd.ZstdError, | |
|
245 | "decompression error: Frame requires too much memory", | |||
241 | ): |
|
246 | ): | |
242 | dctx.decompress(frame, max_output_size=len(source)) |
|
247 | dctx.decompress(frame, max_output_size=len(source)) | |
243 |
|
248 | |||
244 |
|
249 | |||
245 | @make_cffi |
|
250 | @make_cffi | |
246 | class TestDecompressor_copy_stream(TestCase): |
|
251 | class TestDecompressor_copy_stream(TestCase): | |
247 | def test_no_read(self): |
|
252 | def test_no_read(self): | |
248 | source = object() |
|
253 | source = object() | |
249 | dest = io.BytesIO() |
|
254 | dest = io.BytesIO() | |
250 |
|
255 | |||
251 | dctx = zstd.ZstdDecompressor() |
|
256 | dctx = zstd.ZstdDecompressor() | |
252 | with self.assertRaises(ValueError): |
|
257 | with self.assertRaises(ValueError): | |
253 | dctx.copy_stream(source, dest) |
|
258 | dctx.copy_stream(source, dest) | |
254 |
|
259 | |||
255 | def test_no_write(self): |
|
260 | def test_no_write(self): | |
256 | source = io.BytesIO() |
|
261 | source = io.BytesIO() | |
257 | dest = object() |
|
262 | dest = object() | |
258 |
|
263 | |||
259 | dctx = zstd.ZstdDecompressor() |
|
264 | dctx = zstd.ZstdDecompressor() | |
260 | with self.assertRaises(ValueError): |
|
265 | with self.assertRaises(ValueError): | |
261 | dctx.copy_stream(source, dest) |
|
266 | dctx.copy_stream(source, dest) | |
262 |
|
267 | |||
263 | def test_empty(self): |
|
268 | def test_empty(self): | |
264 | source = io.BytesIO() |
|
269 | source = io.BytesIO() | |
265 | dest = io.BytesIO() |
|
270 | dest = io.BytesIO() | |
266 |
|
271 | |||
267 | dctx = zstd.ZstdDecompressor() |
|
272 | dctx = zstd.ZstdDecompressor() | |
268 | # TODO should this raise an error? |
|
273 | # TODO should this raise an error? | |
269 | r, w = dctx.copy_stream(source, dest) |
|
274 | r, w = dctx.copy_stream(source, dest) | |
270 |
|
275 | |||
271 | self.assertEqual(r, 0) |
|
276 | self.assertEqual(r, 0) | |
272 | self.assertEqual(w, 0) |
|
277 | self.assertEqual(w, 0) | |
273 | self.assertEqual(dest.getvalue(), b"") |
|
278 | self.assertEqual(dest.getvalue(), b"") | |
274 |
|
279 | |||
275 | def test_large_data(self): |
|
280 | def test_large_data(self): | |
276 | source = io.BytesIO() |
|
281 | source = io.BytesIO() | |
277 | for i in range(255): |
|
282 | for i in range(255): | |
278 | source.write(struct.Struct(">B").pack(i) * 16384) |
|
283 | source.write(struct.Struct(">B").pack(i) * 16384) | |
279 | source.seek(0) |
|
284 | source.seek(0) | |
280 |
|
285 | |||
281 | compressed = io.BytesIO() |
|
286 | compressed = io.BytesIO() | |
282 | cctx = zstd.ZstdCompressor() |
|
287 | cctx = zstd.ZstdCompressor() | |
283 | cctx.copy_stream(source, compressed) |
|
288 | cctx.copy_stream(source, compressed) | |
284 |
|
289 | |||
285 | compressed.seek(0) |
|
290 | compressed.seek(0) | |
286 | dest = io.BytesIO() |
|
291 | dest = io.BytesIO() | |
287 | dctx = zstd.ZstdDecompressor() |
|
292 | dctx = zstd.ZstdDecompressor() | |
288 | r, w = dctx.copy_stream(compressed, dest) |
|
293 | r, w = dctx.copy_stream(compressed, dest) | |
289 |
|
294 | |||
290 | self.assertEqual(r, len(compressed.getvalue())) |
|
295 | self.assertEqual(r, len(compressed.getvalue())) | |
291 | self.assertEqual(w, len(source.getvalue())) |
|
296 | self.assertEqual(w, len(source.getvalue())) | |
292 |
|
297 | |||
293 | def test_read_write_size(self): |
|
298 | def test_read_write_size(self): | |
294 | source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) |
|
299 | source = OpCountingBytesIO( | |
|
300 | zstd.ZstdCompressor().compress(b"foobarfoobar") | |||
|
301 | ) | |||
295 |
|
302 | |||
296 | dest = OpCountingBytesIO() |
|
303 | dest = OpCountingBytesIO() | |
297 | dctx = zstd.ZstdDecompressor() |
|
304 | dctx = zstd.ZstdDecompressor() | |
298 | r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) |
|
305 | r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) | |
299 |
|
306 | |||
300 | self.assertEqual(r, len(source.getvalue())) |
|
307 | self.assertEqual(r, len(source.getvalue())) | |
301 | self.assertEqual(w, len(b"foobarfoobar")) |
|
308 | self.assertEqual(w, len(b"foobarfoobar")) | |
302 | self.assertEqual(source._read_count, len(source.getvalue()) + 1) |
|
309 | self.assertEqual(source._read_count, len(source.getvalue()) + 1) | |
303 | self.assertEqual(dest._write_count, len(dest.getvalue())) |
|
310 | self.assertEqual(dest._write_count, len(dest.getvalue())) | |
304 |
|
311 | |||
305 |
|
312 | |||
306 | @make_cffi |
|
313 | @make_cffi | |
307 | class TestDecompressor_stream_reader(TestCase): |
|
314 | class TestDecompressor_stream_reader(TestCase): | |
308 | def test_context_manager(self): |
|
315 | def test_context_manager(self): | |
309 | dctx = zstd.ZstdDecompressor() |
|
316 | dctx = zstd.ZstdDecompressor() | |
310 |
|
317 | |||
311 | with dctx.stream_reader(b"foo") as reader: |
|
318 | with dctx.stream_reader(b"foo") as reader: | |
312 |
with self.assertRaisesRegex( |
|
319 | with self.assertRaisesRegex( | |
|
320 | ValueError, "cannot __enter__ multiple times" | |||
|
321 | ): | |||
313 | with reader as reader2: |
|
322 | with reader as reader2: | |
314 | pass |
|
323 | pass | |
315 |
|
324 | |||
316 | def test_not_implemented(self): |
|
325 | def test_not_implemented(self): | |
317 | dctx = zstd.ZstdDecompressor() |
|
326 | dctx = zstd.ZstdDecompressor() | |
318 |
|
327 | |||
319 | with dctx.stream_reader(b"foo") as reader: |
|
328 | with dctx.stream_reader(b"foo") as reader: | |
320 | with self.assertRaises(io.UnsupportedOperation): |
|
329 | with self.assertRaises(io.UnsupportedOperation): | |
321 | reader.readline() |
|
330 | reader.readline() | |
322 |
|
331 | |||
323 | with self.assertRaises(io.UnsupportedOperation): |
|
332 | with self.assertRaises(io.UnsupportedOperation): | |
324 | reader.readlines() |
|
333 | reader.readlines() | |
325 |
|
334 | |||
326 | with self.assertRaises(io.UnsupportedOperation): |
|
335 | with self.assertRaises(io.UnsupportedOperation): | |
327 | iter(reader) |
|
336 | iter(reader) | |
328 |
|
337 | |||
329 | with self.assertRaises(io.UnsupportedOperation): |
|
338 | with self.assertRaises(io.UnsupportedOperation): | |
330 | next(reader) |
|
339 | next(reader) | |
331 |
|
340 | |||
332 | with self.assertRaises(io.UnsupportedOperation): |
|
341 | with self.assertRaises(io.UnsupportedOperation): | |
333 | reader.write(b"foo") |
|
342 | reader.write(b"foo") | |
334 |
|
343 | |||
335 | with self.assertRaises(io.UnsupportedOperation): |
|
344 | with self.assertRaises(io.UnsupportedOperation): | |
336 | reader.writelines([]) |
|
345 | reader.writelines([]) | |
337 |
|
346 | |||
338 | def test_constant_methods(self): |
|
347 | def test_constant_methods(self): | |
339 | dctx = zstd.ZstdDecompressor() |
|
348 | dctx = zstd.ZstdDecompressor() | |
340 |
|
349 | |||
341 | with dctx.stream_reader(b"foo") as reader: |
|
350 | with dctx.stream_reader(b"foo") as reader: | |
342 | self.assertFalse(reader.closed) |
|
351 | self.assertFalse(reader.closed) | |
343 | self.assertTrue(reader.readable()) |
|
352 | self.assertTrue(reader.readable()) | |
344 | self.assertFalse(reader.writable()) |
|
353 | self.assertFalse(reader.writable()) | |
345 | self.assertTrue(reader.seekable()) |
|
354 | self.assertTrue(reader.seekable()) | |
346 | self.assertFalse(reader.isatty()) |
|
355 | self.assertFalse(reader.isatty()) | |
347 | self.assertFalse(reader.closed) |
|
356 | self.assertFalse(reader.closed) | |
348 | self.assertIsNone(reader.flush()) |
|
357 | self.assertIsNone(reader.flush()) | |
349 | self.assertFalse(reader.closed) |
|
358 | self.assertFalse(reader.closed) | |
350 |
|
359 | |||
351 | self.assertTrue(reader.closed) |
|
360 | self.assertTrue(reader.closed) | |
352 |
|
361 | |||
353 | def test_read_closed(self): |
|
362 | def test_read_closed(self): | |
354 | dctx = zstd.ZstdDecompressor() |
|
363 | dctx = zstd.ZstdDecompressor() | |
355 |
|
364 | |||
356 | with dctx.stream_reader(b"foo") as reader: |
|
365 | with dctx.stream_reader(b"foo") as reader: | |
357 | reader.close() |
|
366 | reader.close() | |
358 | self.assertTrue(reader.closed) |
|
367 | self.assertTrue(reader.closed) | |
359 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
368 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
360 | reader.read(1) |
|
369 | reader.read(1) | |
361 |
|
370 | |||
362 | def test_read_sizes(self): |
|
371 | def test_read_sizes(self): | |
363 | cctx = zstd.ZstdCompressor() |
|
372 | cctx = zstd.ZstdCompressor() | |
364 | foo = cctx.compress(b"foo") |
|
373 | foo = cctx.compress(b"foo") | |
365 |
|
374 | |||
366 | dctx = zstd.ZstdDecompressor() |
|
375 | dctx = zstd.ZstdDecompressor() | |
367 |
|
376 | |||
368 | with dctx.stream_reader(foo) as reader: |
|
377 | with dctx.stream_reader(foo) as reader: | |
369 | with self.assertRaisesRegex( |
|
378 | with self.assertRaisesRegex( | |
370 | ValueError, "cannot read negative amounts less than -1" |
|
379 | ValueError, "cannot read negative amounts less than -1" | |
371 | ): |
|
380 | ): | |
372 | reader.read(-2) |
|
381 | reader.read(-2) | |
373 |
|
382 | |||
374 | self.assertEqual(reader.read(0), b"") |
|
383 | self.assertEqual(reader.read(0), b"") | |
375 | self.assertEqual(reader.read(), b"foo") |
|
384 | self.assertEqual(reader.read(), b"foo") | |
376 |
|
385 | |||
377 | def test_read_buffer(self): |
|
386 | def test_read_buffer(self): | |
378 | cctx = zstd.ZstdCompressor() |
|
387 | cctx = zstd.ZstdCompressor() | |
379 |
|
388 | |||
380 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) |
|
389 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) | |
381 | frame = cctx.compress(source) |
|
390 | frame = cctx.compress(source) | |
382 |
|
391 | |||
383 | dctx = zstd.ZstdDecompressor() |
|
392 | dctx = zstd.ZstdDecompressor() | |
384 |
|
393 | |||
385 | with dctx.stream_reader(frame) as reader: |
|
394 | with dctx.stream_reader(frame) as reader: | |
386 | self.assertEqual(reader.tell(), 0) |
|
395 | self.assertEqual(reader.tell(), 0) | |
387 |
|
396 | |||
388 | # We should get entire frame in one read. |
|
397 | # We should get entire frame in one read. | |
389 | result = reader.read(8192) |
|
398 | result = reader.read(8192) | |
390 | self.assertEqual(result, source) |
|
399 | self.assertEqual(result, source) | |
391 | self.assertEqual(reader.tell(), len(source)) |
|
400 | self.assertEqual(reader.tell(), len(source)) | |
392 |
|
401 | |||
393 | # Read after EOF should return empty bytes. |
|
402 | # Read after EOF should return empty bytes. | |
394 | self.assertEqual(reader.read(1), b"") |
|
403 | self.assertEqual(reader.read(1), b"") | |
395 | self.assertEqual(reader.tell(), len(result)) |
|
404 | self.assertEqual(reader.tell(), len(result)) | |
396 |
|
405 | |||
397 | self.assertTrue(reader.closed) |
|
406 | self.assertTrue(reader.closed) | |
398 |
|
407 | |||
399 | def test_read_buffer_small_chunks(self): |
|
408 | def test_read_buffer_small_chunks(self): | |
400 | cctx = zstd.ZstdCompressor() |
|
409 | cctx = zstd.ZstdCompressor() | |
401 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) |
|
410 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) | |
402 | frame = cctx.compress(source) |
|
411 | frame = cctx.compress(source) | |
403 |
|
412 | |||
404 | dctx = zstd.ZstdDecompressor() |
|
413 | dctx = zstd.ZstdDecompressor() | |
405 | chunks = [] |
|
414 | chunks = [] | |
406 |
|
415 | |||
407 | with dctx.stream_reader(frame, read_size=1) as reader: |
|
416 | with dctx.stream_reader(frame, read_size=1) as reader: | |
408 | while True: |
|
417 | while True: | |
409 | chunk = reader.read(1) |
|
418 | chunk = reader.read(1) | |
410 | if not chunk: |
|
419 | if not chunk: | |
411 | break |
|
420 | break | |
412 |
|
421 | |||
413 | chunks.append(chunk) |
|
422 | chunks.append(chunk) | |
414 | self.assertEqual(reader.tell(), sum(map(len, chunks))) |
|
423 | self.assertEqual(reader.tell(), sum(map(len, chunks))) | |
415 |
|
424 | |||
416 | self.assertEqual(b"".join(chunks), source) |
|
425 | self.assertEqual(b"".join(chunks), source) | |
417 |
|
426 | |||
418 | def test_read_stream(self): |
|
427 | def test_read_stream(self): | |
419 | cctx = zstd.ZstdCompressor() |
|
428 | cctx = zstd.ZstdCompressor() | |
420 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) |
|
429 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) | |
421 | frame = cctx.compress(source) |
|
430 | frame = cctx.compress(source) | |
422 |
|
431 | |||
423 | dctx = zstd.ZstdDecompressor() |
|
432 | dctx = zstd.ZstdDecompressor() | |
424 | with dctx.stream_reader(io.BytesIO(frame)) as reader: |
|
433 | with dctx.stream_reader(io.BytesIO(frame)) as reader: | |
425 | self.assertEqual(reader.tell(), 0) |
|
434 | self.assertEqual(reader.tell(), 0) | |
426 |
|
435 | |||
427 | chunk = reader.read(8192) |
|
436 | chunk = reader.read(8192) | |
428 | self.assertEqual(chunk, source) |
|
437 | self.assertEqual(chunk, source) | |
429 | self.assertEqual(reader.tell(), len(source)) |
|
438 | self.assertEqual(reader.tell(), len(source)) | |
430 | self.assertEqual(reader.read(1), b"") |
|
439 | self.assertEqual(reader.read(1), b"") | |
431 | self.assertEqual(reader.tell(), len(source)) |
|
440 | self.assertEqual(reader.tell(), len(source)) | |
432 | self.assertFalse(reader.closed) |
|
441 | self.assertFalse(reader.closed) | |
433 |
|
442 | |||
434 | self.assertTrue(reader.closed) |
|
443 | self.assertTrue(reader.closed) | |
435 |
|
444 | |||
436 | def test_read_stream_small_chunks(self): |
|
445 | def test_read_stream_small_chunks(self): | |
437 | cctx = zstd.ZstdCompressor() |
|
446 | cctx = zstd.ZstdCompressor() | |
438 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) |
|
447 | source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) | |
439 | frame = cctx.compress(source) |
|
448 | frame = cctx.compress(source) | |
440 |
|
449 | |||
441 | dctx = zstd.ZstdDecompressor() |
|
450 | dctx = zstd.ZstdDecompressor() | |
442 | chunks = [] |
|
451 | chunks = [] | |
443 |
|
452 | |||
444 | with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader: |
|
453 | with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader: | |
445 | while True: |
|
454 | while True: | |
446 | chunk = reader.read(1) |
|
455 | chunk = reader.read(1) | |
447 | if not chunk: |
|
456 | if not chunk: | |
448 | break |
|
457 | break | |
449 |
|
458 | |||
450 | chunks.append(chunk) |
|
459 | chunks.append(chunk) | |
451 | self.assertEqual(reader.tell(), sum(map(len, chunks))) |
|
460 | self.assertEqual(reader.tell(), sum(map(len, chunks))) | |
452 |
|
461 | |||
453 | self.assertEqual(b"".join(chunks), source) |
|
462 | self.assertEqual(b"".join(chunks), source) | |
454 |
|
463 | |||
455 | def test_read_after_exit(self): |
|
464 | def test_read_after_exit(self): | |
456 | cctx = zstd.ZstdCompressor() |
|
465 | cctx = zstd.ZstdCompressor() | |
457 | frame = cctx.compress(b"foo" * 60) |
|
466 | frame = cctx.compress(b"foo" * 60) | |
458 |
|
467 | |||
459 | dctx = zstd.ZstdDecompressor() |
|
468 | dctx = zstd.ZstdDecompressor() | |
460 |
|
469 | |||
461 | with dctx.stream_reader(frame) as reader: |
|
470 | with dctx.stream_reader(frame) as reader: | |
462 | while reader.read(16): |
|
471 | while reader.read(16): | |
463 | pass |
|
472 | pass | |
464 |
|
473 | |||
465 | self.assertTrue(reader.closed) |
|
474 | self.assertTrue(reader.closed) | |
466 |
|
475 | |||
467 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
476 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
468 | reader.read(10) |
|
477 | reader.read(10) | |
469 |
|
478 | |||
470 | def test_illegal_seeks(self): |
|
479 | def test_illegal_seeks(self): | |
471 | cctx = zstd.ZstdCompressor() |
|
480 | cctx = zstd.ZstdCompressor() | |
472 | frame = cctx.compress(b"foo" * 60) |
|
481 | frame = cctx.compress(b"foo" * 60) | |
473 |
|
482 | |||
474 | dctx = zstd.ZstdDecompressor() |
|
483 | dctx = zstd.ZstdDecompressor() | |
475 |
|
484 | |||
476 | with dctx.stream_reader(frame) as reader: |
|
485 | with dctx.stream_reader(frame) as reader: | |
477 |
with self.assertRaisesRegex( |
|
486 | with self.assertRaisesRegex( | |
|
487 | ValueError, "cannot seek to negative position" | |||
|
488 | ): | |||
478 | reader.seek(-1, os.SEEK_SET) |
|
489 | reader.seek(-1, os.SEEK_SET) | |
479 |
|
490 | |||
480 | reader.read(1) |
|
491 | reader.read(1) | |
481 |
|
492 | |||
482 | with self.assertRaisesRegex( |
|
493 | with self.assertRaisesRegex( | |
483 | ValueError, "cannot seek zstd decompression stream backwards" |
|
494 | ValueError, "cannot seek zstd decompression stream backwards" | |
484 | ): |
|
495 | ): | |
485 | reader.seek(0, os.SEEK_SET) |
|
496 | reader.seek(0, os.SEEK_SET) | |
486 |
|
497 | |||
487 | with self.assertRaisesRegex( |
|
498 | with self.assertRaisesRegex( | |
488 | ValueError, "cannot seek zstd decompression stream backwards" |
|
499 | ValueError, "cannot seek zstd decompression stream backwards" | |
489 | ): |
|
500 | ): | |
490 | reader.seek(-1, os.SEEK_CUR) |
|
501 | reader.seek(-1, os.SEEK_CUR) | |
491 |
|
502 | |||
492 | with self.assertRaisesRegex( |
|
503 | with self.assertRaisesRegex( | |
493 | ValueError, "zstd decompression streams cannot be seeked with SEEK_END" |
|
504 | ValueError, | |
|
505 | "zstd decompression streams cannot be seeked with SEEK_END", | |||
494 | ): |
|
506 | ): | |
495 | reader.seek(0, os.SEEK_END) |
|
507 | reader.seek(0, os.SEEK_END) | |
496 |
|
508 | |||
497 | reader.close() |
|
509 | reader.close() | |
498 |
|
510 | |||
499 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
511 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
500 | reader.seek(4, os.SEEK_SET) |
|
512 | reader.seek(4, os.SEEK_SET) | |
501 |
|
513 | |||
502 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
514 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
503 | reader.seek(0) |
|
515 | reader.seek(0) | |
504 |
|
516 | |||
505 | def test_seek(self): |
|
517 | def test_seek(self): | |
506 | source = b"foobar" * 60 |
|
518 | source = b"foobar" * 60 | |
507 | cctx = zstd.ZstdCompressor() |
|
519 | cctx = zstd.ZstdCompressor() | |
508 | frame = cctx.compress(source) |
|
520 | frame = cctx.compress(source) | |
509 |
|
521 | |||
510 | dctx = zstd.ZstdDecompressor() |
|
522 | dctx = zstd.ZstdDecompressor() | |
511 |
|
523 | |||
512 | with dctx.stream_reader(frame) as reader: |
|
524 | with dctx.stream_reader(frame) as reader: | |
513 | reader.seek(3) |
|
525 | reader.seek(3) | |
514 | self.assertEqual(reader.read(3), b"bar") |
|
526 | self.assertEqual(reader.read(3), b"bar") | |
515 |
|
527 | |||
516 | reader.seek(4, os.SEEK_CUR) |
|
528 | reader.seek(4, os.SEEK_CUR) | |
517 | self.assertEqual(reader.read(2), b"ar") |
|
529 | self.assertEqual(reader.read(2), b"ar") | |
518 |
|
530 | |||
519 | def test_no_context_manager(self): |
|
531 | def test_no_context_manager(self): | |
520 | source = b"foobar" * 60 |
|
532 | source = b"foobar" * 60 | |
521 | cctx = zstd.ZstdCompressor() |
|
533 | cctx = zstd.ZstdCompressor() | |
522 | frame = cctx.compress(source) |
|
534 | frame = cctx.compress(source) | |
523 |
|
535 | |||
524 | dctx = zstd.ZstdDecompressor() |
|
536 | dctx = zstd.ZstdDecompressor() | |
525 | reader = dctx.stream_reader(frame) |
|
537 | reader = dctx.stream_reader(frame) | |
526 |
|
538 | |||
527 | self.assertEqual(reader.read(6), b"foobar") |
|
539 | self.assertEqual(reader.read(6), b"foobar") | |
528 | self.assertEqual(reader.read(18), b"foobar" * 3) |
|
540 | self.assertEqual(reader.read(18), b"foobar" * 3) | |
529 | self.assertFalse(reader.closed) |
|
541 | self.assertFalse(reader.closed) | |
530 |
|
542 | |||
531 | # Calling close prevents subsequent use. |
|
543 | # Calling close prevents subsequent use. | |
532 | reader.close() |
|
544 | reader.close() | |
533 | self.assertTrue(reader.closed) |
|
545 | self.assertTrue(reader.closed) | |
534 |
|
546 | |||
535 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
547 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
536 | reader.read(6) |
|
548 | reader.read(6) | |
537 |
|
549 | |||
538 | def test_read_after_error(self): |
|
550 | def test_read_after_error(self): | |
539 | source = io.BytesIO(b"") |
|
551 | source = io.BytesIO(b"") | |
540 | dctx = zstd.ZstdDecompressor() |
|
552 | dctx = zstd.ZstdDecompressor() | |
541 |
|
553 | |||
542 | reader = dctx.stream_reader(source) |
|
554 | reader = dctx.stream_reader(source) | |
543 |
|
555 | |||
544 | with reader: |
|
556 | with reader: | |
545 | reader.read(0) |
|
557 | reader.read(0) | |
546 |
|
558 | |||
547 | with reader: |
|
559 | with reader: | |
548 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
560 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
549 | reader.read(100) |
|
561 | reader.read(100) | |
550 |
|
562 | |||
551 | def test_partial_read(self): |
|
563 | def test_partial_read(self): | |
552 | # Inspired by https://github.com/indygreg/python-zstandard/issues/71. |
|
564 | # Inspired by https://github.com/indygreg/python-zstandard/issues/71. | |
553 | buffer = io.BytesIO() |
|
565 | buffer = io.BytesIO() | |
554 | cctx = zstd.ZstdCompressor() |
|
566 | cctx = zstd.ZstdCompressor() | |
555 | writer = cctx.stream_writer(buffer) |
|
567 | writer = cctx.stream_writer(buffer) | |
556 | writer.write(bytearray(os.urandom(1000000))) |
|
568 | writer.write(bytearray(os.urandom(1000000))) | |
557 | writer.flush(zstd.FLUSH_FRAME) |
|
569 | writer.flush(zstd.FLUSH_FRAME) | |
558 | buffer.seek(0) |
|
570 | buffer.seek(0) | |
559 |
|
571 | |||
560 | dctx = zstd.ZstdDecompressor() |
|
572 | dctx = zstd.ZstdDecompressor() | |
561 | reader = dctx.stream_reader(buffer) |
|
573 | reader = dctx.stream_reader(buffer) | |
562 |
|
574 | |||
563 | while True: |
|
575 | while True: | |
564 | chunk = reader.read(8192) |
|
576 | chunk = reader.read(8192) | |
565 | if not chunk: |
|
577 | if not chunk: | |
566 | break |
|
578 | break | |
567 |
|
579 | |||
568 | def test_read_multiple_frames(self): |
|
580 | def test_read_multiple_frames(self): | |
569 | cctx = zstd.ZstdCompressor() |
|
581 | cctx = zstd.ZstdCompressor() | |
570 | source = io.BytesIO() |
|
582 | source = io.BytesIO() | |
571 | writer = cctx.stream_writer(source) |
|
583 | writer = cctx.stream_writer(source) | |
572 | writer.write(b"foo") |
|
584 | writer.write(b"foo") | |
573 | writer.flush(zstd.FLUSH_FRAME) |
|
585 | writer.flush(zstd.FLUSH_FRAME) | |
574 | writer.write(b"bar") |
|
586 | writer.write(b"bar") | |
575 | writer.flush(zstd.FLUSH_FRAME) |
|
587 | writer.flush(zstd.FLUSH_FRAME) | |
576 |
|
588 | |||
577 | dctx = zstd.ZstdDecompressor() |
|
589 | dctx = zstd.ZstdDecompressor() | |
578 |
|
590 | |||
579 | reader = dctx.stream_reader(source.getvalue()) |
|
591 | reader = dctx.stream_reader(source.getvalue()) | |
580 | self.assertEqual(reader.read(2), b"fo") |
|
592 | self.assertEqual(reader.read(2), b"fo") | |
581 | self.assertEqual(reader.read(2), b"o") |
|
593 | self.assertEqual(reader.read(2), b"o") | |
582 | self.assertEqual(reader.read(2), b"ba") |
|
594 | self.assertEqual(reader.read(2), b"ba") | |
583 | self.assertEqual(reader.read(2), b"r") |
|
595 | self.assertEqual(reader.read(2), b"r") | |
584 |
|
596 | |||
585 | source.seek(0) |
|
597 | source.seek(0) | |
586 | reader = dctx.stream_reader(source) |
|
598 | reader = dctx.stream_reader(source) | |
587 | self.assertEqual(reader.read(2), b"fo") |
|
599 | self.assertEqual(reader.read(2), b"fo") | |
588 | self.assertEqual(reader.read(2), b"o") |
|
600 | self.assertEqual(reader.read(2), b"o") | |
589 | self.assertEqual(reader.read(2), b"ba") |
|
601 | self.assertEqual(reader.read(2), b"ba") | |
590 | self.assertEqual(reader.read(2), b"r") |
|
602 | self.assertEqual(reader.read(2), b"r") | |
591 |
|
603 | |||
592 | reader = dctx.stream_reader(source.getvalue()) |
|
604 | reader = dctx.stream_reader(source.getvalue()) | |
593 | self.assertEqual(reader.read(3), b"foo") |
|
605 | self.assertEqual(reader.read(3), b"foo") | |
594 | self.assertEqual(reader.read(3), b"bar") |
|
606 | self.assertEqual(reader.read(3), b"bar") | |
595 |
|
607 | |||
596 | source.seek(0) |
|
608 | source.seek(0) | |
597 | reader = dctx.stream_reader(source) |
|
609 | reader = dctx.stream_reader(source) | |
598 | self.assertEqual(reader.read(3), b"foo") |
|
610 | self.assertEqual(reader.read(3), b"foo") | |
599 | self.assertEqual(reader.read(3), b"bar") |
|
611 | self.assertEqual(reader.read(3), b"bar") | |
600 |
|
612 | |||
601 | reader = dctx.stream_reader(source.getvalue()) |
|
613 | reader = dctx.stream_reader(source.getvalue()) | |
602 | self.assertEqual(reader.read(4), b"foo") |
|
614 | self.assertEqual(reader.read(4), b"foo") | |
603 | self.assertEqual(reader.read(4), b"bar") |
|
615 | self.assertEqual(reader.read(4), b"bar") | |
604 |
|
616 | |||
605 | source.seek(0) |
|
617 | source.seek(0) | |
606 | reader = dctx.stream_reader(source) |
|
618 | reader = dctx.stream_reader(source) | |
607 | self.assertEqual(reader.read(4), b"foo") |
|
619 | self.assertEqual(reader.read(4), b"foo") | |
608 | self.assertEqual(reader.read(4), b"bar") |
|
620 | self.assertEqual(reader.read(4), b"bar") | |
609 |
|
621 | |||
610 | reader = dctx.stream_reader(source.getvalue()) |
|
622 | reader = dctx.stream_reader(source.getvalue()) | |
611 | self.assertEqual(reader.read(128), b"foo") |
|
623 | self.assertEqual(reader.read(128), b"foo") | |
612 | self.assertEqual(reader.read(128), b"bar") |
|
624 | self.assertEqual(reader.read(128), b"bar") | |
613 |
|
625 | |||
614 | source.seek(0) |
|
626 | source.seek(0) | |
615 | reader = dctx.stream_reader(source) |
|
627 | reader = dctx.stream_reader(source) | |
616 | self.assertEqual(reader.read(128), b"foo") |
|
628 | self.assertEqual(reader.read(128), b"foo") | |
617 | self.assertEqual(reader.read(128), b"bar") |
|
629 | self.assertEqual(reader.read(128), b"bar") | |
618 |
|
630 | |||
619 | # Now tests for reads spanning frames. |
|
631 | # Now tests for reads spanning frames. | |
620 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
632 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
621 | self.assertEqual(reader.read(3), b"foo") |
|
633 | self.assertEqual(reader.read(3), b"foo") | |
622 | self.assertEqual(reader.read(3), b"bar") |
|
634 | self.assertEqual(reader.read(3), b"bar") | |
623 |
|
635 | |||
624 | source.seek(0) |
|
636 | source.seek(0) | |
625 | reader = dctx.stream_reader(source, read_across_frames=True) |
|
637 | reader = dctx.stream_reader(source, read_across_frames=True) | |
626 | self.assertEqual(reader.read(3), b"foo") |
|
638 | self.assertEqual(reader.read(3), b"foo") | |
627 | self.assertEqual(reader.read(3), b"bar") |
|
639 | self.assertEqual(reader.read(3), b"bar") | |
628 |
|
640 | |||
629 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
641 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
630 | self.assertEqual(reader.read(6), b"foobar") |
|
642 | self.assertEqual(reader.read(6), b"foobar") | |
631 |
|
643 | |||
632 | source.seek(0) |
|
644 | source.seek(0) | |
633 | reader = dctx.stream_reader(source, read_across_frames=True) |
|
645 | reader = dctx.stream_reader(source, read_across_frames=True) | |
634 | self.assertEqual(reader.read(6), b"foobar") |
|
646 | self.assertEqual(reader.read(6), b"foobar") | |
635 |
|
647 | |||
636 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
648 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
637 | self.assertEqual(reader.read(7), b"foobar") |
|
649 | self.assertEqual(reader.read(7), b"foobar") | |
638 |
|
650 | |||
639 | source.seek(0) |
|
651 | source.seek(0) | |
640 | reader = dctx.stream_reader(source, read_across_frames=True) |
|
652 | reader = dctx.stream_reader(source, read_across_frames=True) | |
641 | self.assertEqual(reader.read(7), b"foobar") |
|
653 | self.assertEqual(reader.read(7), b"foobar") | |
642 |
|
654 | |||
643 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
655 | reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
644 | self.assertEqual(reader.read(128), b"foobar") |
|
656 | self.assertEqual(reader.read(128), b"foobar") | |
645 |
|
657 | |||
646 | source.seek(0) |
|
658 | source.seek(0) | |
647 | reader = dctx.stream_reader(source, read_across_frames=True) |
|
659 | reader = dctx.stream_reader(source, read_across_frames=True) | |
648 | self.assertEqual(reader.read(128), b"foobar") |
|
660 | self.assertEqual(reader.read(128), b"foobar") | |
649 |
|
661 | |||
650 | def test_readinto(self): |
|
662 | def test_readinto(self): | |
651 | cctx = zstd.ZstdCompressor() |
|
663 | cctx = zstd.ZstdCompressor() | |
652 | foo = cctx.compress(b"foo") |
|
664 | foo = cctx.compress(b"foo") | |
653 |
|
665 | |||
654 | dctx = zstd.ZstdDecompressor() |
|
666 | dctx = zstd.ZstdDecompressor() | |
655 |
|
667 | |||
656 | # Attempting to readinto() a non-writable buffer fails. |
|
668 | # Attempting to readinto() a non-writable buffer fails. | |
657 | # The exact exception varies based on the backend. |
|
669 | # The exact exception varies based on the backend. | |
658 | reader = dctx.stream_reader(foo) |
|
670 | reader = dctx.stream_reader(foo) | |
659 | with self.assertRaises(Exception): |
|
671 | with self.assertRaises(Exception): | |
660 | reader.readinto(b"foobar") |
|
672 | reader.readinto(b"foobar") | |
661 |
|
673 | |||
662 | # readinto() with sufficiently large destination. |
|
674 | # readinto() with sufficiently large destination. | |
663 | b = bytearray(1024) |
|
675 | b = bytearray(1024) | |
664 | reader = dctx.stream_reader(foo) |
|
676 | reader = dctx.stream_reader(foo) | |
665 | self.assertEqual(reader.readinto(b), 3) |
|
677 | self.assertEqual(reader.readinto(b), 3) | |
666 | self.assertEqual(b[0:3], b"foo") |
|
678 | self.assertEqual(b[0:3], b"foo") | |
667 | self.assertEqual(reader.readinto(b), 0) |
|
679 | self.assertEqual(reader.readinto(b), 0) | |
668 | self.assertEqual(b[0:3], b"foo") |
|
680 | self.assertEqual(b[0:3], b"foo") | |
669 |
|
681 | |||
670 | # readinto() with small reads. |
|
682 | # readinto() with small reads. | |
671 | b = bytearray(1024) |
|
683 | b = bytearray(1024) | |
672 | reader = dctx.stream_reader(foo, read_size=1) |
|
684 | reader = dctx.stream_reader(foo, read_size=1) | |
673 | self.assertEqual(reader.readinto(b), 3) |
|
685 | self.assertEqual(reader.readinto(b), 3) | |
674 | self.assertEqual(b[0:3], b"foo") |
|
686 | self.assertEqual(b[0:3], b"foo") | |
675 |
|
687 | |||
676 | # Too small destination buffer. |
|
688 | # Too small destination buffer. | |
677 | b = bytearray(2) |
|
689 | b = bytearray(2) | |
678 | reader = dctx.stream_reader(foo) |
|
690 | reader = dctx.stream_reader(foo) | |
679 | self.assertEqual(reader.readinto(b), 2) |
|
691 | self.assertEqual(reader.readinto(b), 2) | |
680 | self.assertEqual(b[:], b"fo") |
|
692 | self.assertEqual(b[:], b"fo") | |
681 |
|
693 | |||
682 | def test_readinto1(self): |
|
694 | def test_readinto1(self): | |
683 | cctx = zstd.ZstdCompressor() |
|
695 | cctx = zstd.ZstdCompressor() | |
684 | foo = cctx.compress(b"foo") |
|
696 | foo = cctx.compress(b"foo") | |
685 |
|
697 | |||
686 | dctx = zstd.ZstdDecompressor() |
|
698 | dctx = zstd.ZstdDecompressor() | |
687 |
|
699 | |||
688 | reader = dctx.stream_reader(foo) |
|
700 | reader = dctx.stream_reader(foo) | |
689 | with self.assertRaises(Exception): |
|
701 | with self.assertRaises(Exception): | |
690 | reader.readinto1(b"foobar") |
|
702 | reader.readinto1(b"foobar") | |
691 |
|
703 | |||
692 | # Sufficiently large destination. |
|
704 | # Sufficiently large destination. | |
693 | b = bytearray(1024) |
|
705 | b = bytearray(1024) | |
694 | reader = dctx.stream_reader(foo) |
|
706 | reader = dctx.stream_reader(foo) | |
695 | self.assertEqual(reader.readinto1(b), 3) |
|
707 | self.assertEqual(reader.readinto1(b), 3) | |
696 | self.assertEqual(b[0:3], b"foo") |
|
708 | self.assertEqual(b[0:3], b"foo") | |
697 | self.assertEqual(reader.readinto1(b), 0) |
|
709 | self.assertEqual(reader.readinto1(b), 0) | |
698 | self.assertEqual(b[0:3], b"foo") |
|
710 | self.assertEqual(b[0:3], b"foo") | |
699 |
|
711 | |||
700 | # readinto() with small reads. |
|
712 | # readinto() with small reads. | |
701 | b = bytearray(1024) |
|
713 | b = bytearray(1024) | |
702 | reader = dctx.stream_reader(foo, read_size=1) |
|
714 | reader = dctx.stream_reader(foo, read_size=1) | |
703 | self.assertEqual(reader.readinto1(b), 3) |
|
715 | self.assertEqual(reader.readinto1(b), 3) | |
704 | self.assertEqual(b[0:3], b"foo") |
|
716 | self.assertEqual(b[0:3], b"foo") | |
705 |
|
717 | |||
706 | # Too small destination buffer. |
|
718 | # Too small destination buffer. | |
707 | b = bytearray(2) |
|
719 | b = bytearray(2) | |
708 | reader = dctx.stream_reader(foo) |
|
720 | reader = dctx.stream_reader(foo) | |
709 | self.assertEqual(reader.readinto1(b), 2) |
|
721 | self.assertEqual(reader.readinto1(b), 2) | |
710 | self.assertEqual(b[:], b"fo") |
|
722 | self.assertEqual(b[:], b"fo") | |
711 |
|
723 | |||
712 | def test_readall(self): |
|
724 | def test_readall(self): | |
713 | cctx = zstd.ZstdCompressor() |
|
725 | cctx = zstd.ZstdCompressor() | |
714 | foo = cctx.compress(b"foo") |
|
726 | foo = cctx.compress(b"foo") | |
715 |
|
727 | |||
716 | dctx = zstd.ZstdDecompressor() |
|
728 | dctx = zstd.ZstdDecompressor() | |
717 | reader = dctx.stream_reader(foo) |
|
729 | reader = dctx.stream_reader(foo) | |
718 |
|
730 | |||
719 | self.assertEqual(reader.readall(), b"foo") |
|
731 | self.assertEqual(reader.readall(), b"foo") | |
720 |
|
732 | |||
721 | def test_read1(self): |
|
733 | def test_read1(self): | |
722 | cctx = zstd.ZstdCompressor() |
|
734 | cctx = zstd.ZstdCompressor() | |
723 | foo = cctx.compress(b"foo") |
|
735 | foo = cctx.compress(b"foo") | |
724 |
|
736 | |||
725 | dctx = zstd.ZstdDecompressor() |
|
737 | dctx = zstd.ZstdDecompressor() | |
726 |
|
738 | |||
727 | b = OpCountingBytesIO(foo) |
|
739 | b = OpCountingBytesIO(foo) | |
728 | reader = dctx.stream_reader(b) |
|
740 | reader = dctx.stream_reader(b) | |
729 |
|
741 | |||
730 | self.assertEqual(reader.read1(), b"foo") |
|
742 | self.assertEqual(reader.read1(), b"foo") | |
731 | self.assertEqual(b._read_count, 1) |
|
743 | self.assertEqual(b._read_count, 1) | |
732 |
|
744 | |||
733 | b = OpCountingBytesIO(foo) |
|
745 | b = OpCountingBytesIO(foo) | |
734 | reader = dctx.stream_reader(b) |
|
746 | reader = dctx.stream_reader(b) | |
735 |
|
747 | |||
736 | self.assertEqual(reader.read1(0), b"") |
|
748 | self.assertEqual(reader.read1(0), b"") | |
737 | self.assertEqual(reader.read1(2), b"fo") |
|
749 | self.assertEqual(reader.read1(2), b"fo") | |
738 | self.assertEqual(b._read_count, 1) |
|
750 | self.assertEqual(b._read_count, 1) | |
739 | self.assertEqual(reader.read1(1), b"o") |
|
751 | self.assertEqual(reader.read1(1), b"o") | |
740 | self.assertEqual(b._read_count, 1) |
|
752 | self.assertEqual(b._read_count, 1) | |
741 | self.assertEqual(reader.read1(1), b"") |
|
753 | self.assertEqual(reader.read1(1), b"") | |
742 | self.assertEqual(b._read_count, 2) |
|
754 | self.assertEqual(b._read_count, 2) | |
743 |
|
755 | |||
744 | def test_read_lines(self): |
|
756 | def test_read_lines(self): | |
745 | cctx = zstd.ZstdCompressor() |
|
757 | cctx = zstd.ZstdCompressor() | |
746 | source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024)) |
|
758 | source = b"\n".join( | |
|
759 | ("line %d" % i).encode("ascii") for i in range(1024) | |||
|
760 | ) | |||
747 |
|
761 | |||
748 | frame = cctx.compress(source) |
|
762 | frame = cctx.compress(source) | |
749 |
|
763 | |||
750 | dctx = zstd.ZstdDecompressor() |
|
764 | dctx = zstd.ZstdDecompressor() | |
751 | reader = dctx.stream_reader(frame) |
|
765 | reader = dctx.stream_reader(frame) | |
752 | tr = io.TextIOWrapper(reader, encoding="utf-8") |
|
766 | tr = io.TextIOWrapper(reader, encoding="utf-8") | |
753 |
|
767 | |||
754 | lines = [] |
|
768 | lines = [] | |
755 | for line in tr: |
|
769 | for line in tr: | |
756 | lines.append(line.encode("utf-8")) |
|
770 | lines.append(line.encode("utf-8")) | |
757 |
|
771 | |||
758 | self.assertEqual(len(lines), 1024) |
|
772 | self.assertEqual(len(lines), 1024) | |
759 | self.assertEqual(b"".join(lines), source) |
|
773 | self.assertEqual(b"".join(lines), source) | |
760 |
|
774 | |||
761 | reader = dctx.stream_reader(frame) |
|
775 | reader = dctx.stream_reader(frame) | |
762 | tr = io.TextIOWrapper(reader, encoding="utf-8") |
|
776 | tr = io.TextIOWrapper(reader, encoding="utf-8") | |
763 |
|
777 | |||
764 | lines = tr.readlines() |
|
778 | lines = tr.readlines() | |
765 | self.assertEqual(len(lines), 1024) |
|
779 | self.assertEqual(len(lines), 1024) | |
766 | self.assertEqual("".join(lines).encode("utf-8"), source) |
|
780 | self.assertEqual("".join(lines).encode("utf-8"), source) | |
767 |
|
781 | |||
768 | reader = dctx.stream_reader(frame) |
|
782 | reader = dctx.stream_reader(frame) | |
769 | tr = io.TextIOWrapper(reader, encoding="utf-8") |
|
783 | tr = io.TextIOWrapper(reader, encoding="utf-8") | |
770 |
|
784 | |||
771 | lines = [] |
|
785 | lines = [] | |
772 | while True: |
|
786 | while True: | |
773 | line = tr.readline() |
|
787 | line = tr.readline() | |
774 | if not line: |
|
788 | if not line: | |
775 | break |
|
789 | break | |
776 |
|
790 | |||
777 | lines.append(line.encode("utf-8")) |
|
791 | lines.append(line.encode("utf-8")) | |
778 |
|
792 | |||
779 | self.assertEqual(len(lines), 1024) |
|
793 | self.assertEqual(len(lines), 1024) | |
780 | self.assertEqual(b"".join(lines), source) |
|
794 | self.assertEqual(b"".join(lines), source) | |
781 |
|
795 | |||
782 |
|
796 | |||
783 | @make_cffi |
|
797 | @make_cffi | |
784 | class TestDecompressor_decompressobj(TestCase): |
|
798 | class TestDecompressor_decompressobj(TestCase): | |
785 | def test_simple(self): |
|
799 | def test_simple(self): | |
786 | data = zstd.ZstdCompressor(level=1).compress(b"foobar") |
|
800 | data = zstd.ZstdCompressor(level=1).compress(b"foobar") | |
787 |
|
801 | |||
788 | dctx = zstd.ZstdDecompressor() |
|
802 | dctx = zstd.ZstdDecompressor() | |
789 | dobj = dctx.decompressobj() |
|
803 | dobj = dctx.decompressobj() | |
790 | self.assertEqual(dobj.decompress(data), b"foobar") |
|
804 | self.assertEqual(dobj.decompress(data), b"foobar") | |
791 | self.assertIsNone(dobj.flush()) |
|
805 | self.assertIsNone(dobj.flush()) | |
792 | self.assertIsNone(dobj.flush(10)) |
|
806 | self.assertIsNone(dobj.flush(10)) | |
793 | self.assertIsNone(dobj.flush(length=100)) |
|
807 | self.assertIsNone(dobj.flush(length=100)) | |
794 |
|
808 | |||
795 | def test_input_types(self): |
|
809 | def test_input_types(self): | |
796 | compressed = zstd.ZstdCompressor(level=1).compress(b"foo") |
|
810 | compressed = zstd.ZstdCompressor(level=1).compress(b"foo") | |
797 |
|
811 | |||
798 | dctx = zstd.ZstdDecompressor() |
|
812 | dctx = zstd.ZstdDecompressor() | |
799 |
|
813 | |||
800 | mutable_array = bytearray(len(compressed)) |
|
814 | mutable_array = bytearray(len(compressed)) | |
801 | mutable_array[:] = compressed |
|
815 | mutable_array[:] = compressed | |
802 |
|
816 | |||
803 | sources = [ |
|
817 | sources = [ | |
804 | memoryview(compressed), |
|
818 | memoryview(compressed), | |
805 | bytearray(compressed), |
|
819 | bytearray(compressed), | |
806 | mutable_array, |
|
820 | mutable_array, | |
807 | ] |
|
821 | ] | |
808 |
|
822 | |||
809 | for source in sources: |
|
823 | for source in sources: | |
810 | dobj = dctx.decompressobj() |
|
824 | dobj = dctx.decompressobj() | |
811 | self.assertIsNone(dobj.flush()) |
|
825 | self.assertIsNone(dobj.flush()) | |
812 | self.assertIsNone(dobj.flush(10)) |
|
826 | self.assertIsNone(dobj.flush(10)) | |
813 | self.assertIsNone(dobj.flush(length=100)) |
|
827 | self.assertIsNone(dobj.flush(length=100)) | |
814 | self.assertEqual(dobj.decompress(source), b"foo") |
|
828 | self.assertEqual(dobj.decompress(source), b"foo") | |
815 | self.assertIsNone(dobj.flush()) |
|
829 | self.assertIsNone(dobj.flush()) | |
816 |
|
830 | |||
817 | def test_reuse(self): |
|
831 | def test_reuse(self): | |
818 | data = zstd.ZstdCompressor(level=1).compress(b"foobar") |
|
832 | data = zstd.ZstdCompressor(level=1).compress(b"foobar") | |
819 |
|
833 | |||
820 | dctx = zstd.ZstdDecompressor() |
|
834 | dctx = zstd.ZstdDecompressor() | |
821 | dobj = dctx.decompressobj() |
|
835 | dobj = dctx.decompressobj() | |
822 | dobj.decompress(data) |
|
836 | dobj.decompress(data) | |
823 |
|
837 | |||
824 |
with self.assertRaisesRegex( |
|
838 | with self.assertRaisesRegex( | |
|
839 | zstd.ZstdError, "cannot use a decompressobj" | |||
|
840 | ): | |||
825 | dobj.decompress(data) |
|
841 | dobj.decompress(data) | |
826 | self.assertIsNone(dobj.flush()) |
|
842 | self.assertIsNone(dobj.flush()) | |
827 |
|
843 | |||
828 | def test_bad_write_size(self): |
|
844 | def test_bad_write_size(self): | |
829 | dctx = zstd.ZstdDecompressor() |
|
845 | dctx = zstd.ZstdDecompressor() | |
830 |
|
846 | |||
831 | with self.assertRaisesRegex(ValueError, "write_size must be positive"): |
|
847 | with self.assertRaisesRegex(ValueError, "write_size must be positive"): | |
832 | dctx.decompressobj(write_size=0) |
|
848 | dctx.decompressobj(write_size=0) | |
833 |
|
849 | |||
834 | def test_write_size(self): |
|
850 | def test_write_size(self): | |
835 | source = b"foo" * 64 + b"bar" * 128 |
|
851 | source = b"foo" * 64 + b"bar" * 128 | |
836 | data = zstd.ZstdCompressor(level=1).compress(source) |
|
852 | data = zstd.ZstdCompressor(level=1).compress(source) | |
837 |
|
853 | |||
838 | dctx = zstd.ZstdDecompressor() |
|
854 | dctx = zstd.ZstdDecompressor() | |
839 |
|
855 | |||
840 | for i in range(128): |
|
856 | for i in range(128): | |
841 | dobj = dctx.decompressobj(write_size=i + 1) |
|
857 | dobj = dctx.decompressobj(write_size=i + 1) | |
842 | self.assertEqual(dobj.decompress(data), source) |
|
858 | self.assertEqual(dobj.decompress(data), source) | |
843 |
|
859 | |||
844 |
|
860 | |||
845 | def decompress_via_writer(data): |
|
861 | def decompress_via_writer(data): | |
846 | buffer = io.BytesIO() |
|
862 | buffer = io.BytesIO() | |
847 | dctx = zstd.ZstdDecompressor() |
|
863 | dctx = zstd.ZstdDecompressor() | |
848 | decompressor = dctx.stream_writer(buffer) |
|
864 | decompressor = dctx.stream_writer(buffer) | |
849 | decompressor.write(data) |
|
865 | decompressor.write(data) | |
850 |
|
866 | |||
851 | return buffer.getvalue() |
|
867 | return buffer.getvalue() | |
852 |
|
868 | |||
853 |
|
869 | |||
854 | @make_cffi |
|
870 | @make_cffi | |
855 | class TestDecompressor_stream_writer(TestCase): |
|
871 | class TestDecompressor_stream_writer(TestCase): | |
856 | def test_io_api(self): |
|
872 | def test_io_api(self): | |
857 | buffer = io.BytesIO() |
|
873 | buffer = io.BytesIO() | |
858 | dctx = zstd.ZstdDecompressor() |
|
874 | dctx = zstd.ZstdDecompressor() | |
859 | writer = dctx.stream_writer(buffer) |
|
875 | writer = dctx.stream_writer(buffer) | |
860 |
|
876 | |||
861 | self.assertFalse(writer.closed) |
|
877 | self.assertFalse(writer.closed) | |
862 | self.assertFalse(writer.isatty()) |
|
878 | self.assertFalse(writer.isatty()) | |
863 | self.assertFalse(writer.readable()) |
|
879 | self.assertFalse(writer.readable()) | |
864 |
|
880 | |||
865 | with self.assertRaises(io.UnsupportedOperation): |
|
881 | with self.assertRaises(io.UnsupportedOperation): | |
866 | writer.readline() |
|
882 | writer.readline() | |
867 |
|
883 | |||
868 | with self.assertRaises(io.UnsupportedOperation): |
|
884 | with self.assertRaises(io.UnsupportedOperation): | |
869 | writer.readline(42) |
|
885 | writer.readline(42) | |
870 |
|
886 | |||
871 | with self.assertRaises(io.UnsupportedOperation): |
|
887 | with self.assertRaises(io.UnsupportedOperation): | |
872 | writer.readline(size=42) |
|
888 | writer.readline(size=42) | |
873 |
|
889 | |||
874 | with self.assertRaises(io.UnsupportedOperation): |
|
890 | with self.assertRaises(io.UnsupportedOperation): | |
875 | writer.readlines() |
|
891 | writer.readlines() | |
876 |
|
892 | |||
877 | with self.assertRaises(io.UnsupportedOperation): |
|
893 | with self.assertRaises(io.UnsupportedOperation): | |
878 | writer.readlines(42) |
|
894 | writer.readlines(42) | |
879 |
|
895 | |||
880 | with self.assertRaises(io.UnsupportedOperation): |
|
896 | with self.assertRaises(io.UnsupportedOperation): | |
881 | writer.readlines(hint=42) |
|
897 | writer.readlines(hint=42) | |
882 |
|
898 | |||
883 | with self.assertRaises(io.UnsupportedOperation): |
|
899 | with self.assertRaises(io.UnsupportedOperation): | |
884 | writer.seek(0) |
|
900 | writer.seek(0) | |
885 |
|
901 | |||
886 | with self.assertRaises(io.UnsupportedOperation): |
|
902 | with self.assertRaises(io.UnsupportedOperation): | |
887 | writer.seek(10, os.SEEK_SET) |
|
903 | writer.seek(10, os.SEEK_SET) | |
888 |
|
904 | |||
889 | self.assertFalse(writer.seekable()) |
|
905 | self.assertFalse(writer.seekable()) | |
890 |
|
906 | |||
891 | with self.assertRaises(io.UnsupportedOperation): |
|
907 | with self.assertRaises(io.UnsupportedOperation): | |
892 | writer.tell() |
|
908 | writer.tell() | |
893 |
|
909 | |||
894 | with self.assertRaises(io.UnsupportedOperation): |
|
910 | with self.assertRaises(io.UnsupportedOperation): | |
895 | writer.truncate() |
|
911 | writer.truncate() | |
896 |
|
912 | |||
897 | with self.assertRaises(io.UnsupportedOperation): |
|
913 | with self.assertRaises(io.UnsupportedOperation): | |
898 | writer.truncate(42) |
|
914 | writer.truncate(42) | |
899 |
|
915 | |||
900 | with self.assertRaises(io.UnsupportedOperation): |
|
916 | with self.assertRaises(io.UnsupportedOperation): | |
901 | writer.truncate(size=42) |
|
917 | writer.truncate(size=42) | |
902 |
|
918 | |||
903 | self.assertTrue(writer.writable()) |
|
919 | self.assertTrue(writer.writable()) | |
904 |
|
920 | |||
905 | with self.assertRaises(io.UnsupportedOperation): |
|
921 | with self.assertRaises(io.UnsupportedOperation): | |
906 | writer.writelines([]) |
|
922 | writer.writelines([]) | |
907 |
|
923 | |||
908 | with self.assertRaises(io.UnsupportedOperation): |
|
924 | with self.assertRaises(io.UnsupportedOperation): | |
909 | writer.read() |
|
925 | writer.read() | |
910 |
|
926 | |||
911 | with self.assertRaises(io.UnsupportedOperation): |
|
927 | with self.assertRaises(io.UnsupportedOperation): | |
912 | writer.read(42) |
|
928 | writer.read(42) | |
913 |
|
929 | |||
914 | with self.assertRaises(io.UnsupportedOperation): |
|
930 | with self.assertRaises(io.UnsupportedOperation): | |
915 | writer.read(size=42) |
|
931 | writer.read(size=42) | |
916 |
|
932 | |||
917 | with self.assertRaises(io.UnsupportedOperation): |
|
933 | with self.assertRaises(io.UnsupportedOperation): | |
918 | writer.readall() |
|
934 | writer.readall() | |
919 |
|
935 | |||
920 | with self.assertRaises(io.UnsupportedOperation): |
|
936 | with self.assertRaises(io.UnsupportedOperation): | |
921 | writer.readinto(None) |
|
937 | writer.readinto(None) | |
922 |
|
938 | |||
923 | with self.assertRaises(io.UnsupportedOperation): |
|
939 | with self.assertRaises(io.UnsupportedOperation): | |
924 | writer.fileno() |
|
940 | writer.fileno() | |
925 |
|
941 | |||
926 | def test_fileno_file(self): |
|
942 | def test_fileno_file(self): | |
927 | with tempfile.TemporaryFile("wb") as tf: |
|
943 | with tempfile.TemporaryFile("wb") as tf: | |
928 | dctx = zstd.ZstdDecompressor() |
|
944 | dctx = zstd.ZstdDecompressor() | |
929 | writer = dctx.stream_writer(tf) |
|
945 | writer = dctx.stream_writer(tf) | |
930 |
|
946 | |||
931 | self.assertEqual(writer.fileno(), tf.fileno()) |
|
947 | self.assertEqual(writer.fileno(), tf.fileno()) | |
932 |
|
948 | |||
933 | def test_close(self): |
|
949 | def test_close(self): | |
934 | foo = zstd.ZstdCompressor().compress(b"foo") |
|
950 | foo = zstd.ZstdCompressor().compress(b"foo") | |
935 |
|
951 | |||
936 | buffer = NonClosingBytesIO() |
|
952 | buffer = NonClosingBytesIO() | |
937 | dctx = zstd.ZstdDecompressor() |
|
953 | dctx = zstd.ZstdDecompressor() | |
938 | writer = dctx.stream_writer(buffer) |
|
954 | writer = dctx.stream_writer(buffer) | |
939 |
|
955 | |||
940 | writer.write(foo) |
|
956 | writer.write(foo) | |
941 | self.assertFalse(writer.closed) |
|
957 | self.assertFalse(writer.closed) | |
942 | self.assertFalse(buffer.closed) |
|
958 | self.assertFalse(buffer.closed) | |
943 | writer.close() |
|
959 | writer.close() | |
944 | self.assertTrue(writer.closed) |
|
960 | self.assertTrue(writer.closed) | |
945 | self.assertTrue(buffer.closed) |
|
961 | self.assertTrue(buffer.closed) | |
946 |
|
962 | |||
947 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
963 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
948 | writer.write(b"") |
|
964 | writer.write(b"") | |
949 |
|
965 | |||
950 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
966 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
951 | writer.flush() |
|
967 | writer.flush() | |
952 |
|
968 | |||
953 | with self.assertRaisesRegex(ValueError, "stream is closed"): |
|
969 | with self.assertRaisesRegex(ValueError, "stream is closed"): | |
954 | with writer: |
|
970 | with writer: | |
955 | pass |
|
971 | pass | |
956 |
|
972 | |||
957 | self.assertEqual(buffer.getvalue(), b"foo") |
|
973 | self.assertEqual(buffer.getvalue(), b"foo") | |
958 |
|
974 | |||
959 | # Context manager exit should close stream. |
|
975 | # Context manager exit should close stream. | |
960 | buffer = NonClosingBytesIO() |
|
976 | buffer = NonClosingBytesIO() | |
961 | writer = dctx.stream_writer(buffer) |
|
977 | writer = dctx.stream_writer(buffer) | |
962 |
|
978 | |||
963 | with writer: |
|
979 | with writer: | |
964 | writer.write(foo) |
|
980 | writer.write(foo) | |
965 |
|
981 | |||
966 | self.assertTrue(writer.closed) |
|
982 | self.assertTrue(writer.closed) | |
967 | self.assertEqual(buffer.getvalue(), b"foo") |
|
983 | self.assertEqual(buffer.getvalue(), b"foo") | |
968 |
|
984 | |||
969 | def test_flush(self): |
|
985 | def test_flush(self): | |
970 | buffer = OpCountingBytesIO() |
|
986 | buffer = OpCountingBytesIO() | |
971 | dctx = zstd.ZstdDecompressor() |
|
987 | dctx = zstd.ZstdDecompressor() | |
972 | writer = dctx.stream_writer(buffer) |
|
988 | writer = dctx.stream_writer(buffer) | |
973 |
|
989 | |||
974 | writer.flush() |
|
990 | writer.flush() | |
975 | self.assertEqual(buffer._flush_count, 1) |
|
991 | self.assertEqual(buffer._flush_count, 1) | |
976 | writer.flush() |
|
992 | writer.flush() | |
977 | self.assertEqual(buffer._flush_count, 2) |
|
993 | self.assertEqual(buffer._flush_count, 2) | |
978 |
|
994 | |||
979 | def test_empty_roundtrip(self): |
|
995 | def test_empty_roundtrip(self): | |
980 | cctx = zstd.ZstdCompressor() |
|
996 | cctx = zstd.ZstdCompressor() | |
981 | empty = cctx.compress(b"") |
|
997 | empty = cctx.compress(b"") | |
982 | self.assertEqual(decompress_via_writer(empty), b"") |
|
998 | self.assertEqual(decompress_via_writer(empty), b"") | |
983 |
|
999 | |||
984 | def test_input_types(self): |
|
1000 | def test_input_types(self): | |
985 | cctx = zstd.ZstdCompressor(level=1) |
|
1001 | cctx = zstd.ZstdCompressor(level=1) | |
986 | compressed = cctx.compress(b"foo") |
|
1002 | compressed = cctx.compress(b"foo") | |
987 |
|
1003 | |||
988 | mutable_array = bytearray(len(compressed)) |
|
1004 | mutable_array = bytearray(len(compressed)) | |
989 | mutable_array[:] = compressed |
|
1005 | mutable_array[:] = compressed | |
990 |
|
1006 | |||
991 | sources = [ |
|
1007 | sources = [ | |
992 | memoryview(compressed), |
|
1008 | memoryview(compressed), | |
993 | bytearray(compressed), |
|
1009 | bytearray(compressed), | |
994 | mutable_array, |
|
1010 | mutable_array, | |
995 | ] |
|
1011 | ] | |
996 |
|
1012 | |||
997 | dctx = zstd.ZstdDecompressor() |
|
1013 | dctx = zstd.ZstdDecompressor() | |
998 | for source in sources: |
|
1014 | for source in sources: | |
999 | buffer = io.BytesIO() |
|
1015 | buffer = io.BytesIO() | |
1000 |
|
1016 | |||
1001 | decompressor = dctx.stream_writer(buffer) |
|
1017 | decompressor = dctx.stream_writer(buffer) | |
1002 | decompressor.write(source) |
|
1018 | decompressor.write(source) | |
1003 | self.assertEqual(buffer.getvalue(), b"foo") |
|
1019 | self.assertEqual(buffer.getvalue(), b"foo") | |
1004 |
|
1020 | |||
1005 | buffer = NonClosingBytesIO() |
|
1021 | buffer = NonClosingBytesIO() | |
1006 |
|
1022 | |||
1007 | with dctx.stream_writer(buffer) as decompressor: |
|
1023 | with dctx.stream_writer(buffer) as decompressor: | |
1008 | self.assertEqual(decompressor.write(source), 3) |
|
1024 | self.assertEqual(decompressor.write(source), 3) | |
1009 |
|
1025 | |||
1010 | self.assertEqual(buffer.getvalue(), b"foo") |
|
1026 | self.assertEqual(buffer.getvalue(), b"foo") | |
1011 |
|
1027 | |||
1012 | buffer = io.BytesIO() |
|
1028 | buffer = io.BytesIO() | |
1013 | writer = dctx.stream_writer(buffer, write_return_read=True) |
|
1029 | writer = dctx.stream_writer(buffer, write_return_read=True) | |
1014 | self.assertEqual(writer.write(source), len(source)) |
|
1030 | self.assertEqual(writer.write(source), len(source)) | |
1015 | self.assertEqual(buffer.getvalue(), b"foo") |
|
1031 | self.assertEqual(buffer.getvalue(), b"foo") | |
1016 |
|
1032 | |||
1017 | def test_large_roundtrip(self): |
|
1033 | def test_large_roundtrip(self): | |
1018 | chunks = [] |
|
1034 | chunks = [] | |
1019 | for i in range(255): |
|
1035 | for i in range(255): | |
1020 | chunks.append(struct.Struct(">B").pack(i) * 16384) |
|
1036 | chunks.append(struct.Struct(">B").pack(i) * 16384) | |
1021 | orig = b"".join(chunks) |
|
1037 | orig = b"".join(chunks) | |
1022 | cctx = zstd.ZstdCompressor() |
|
1038 | cctx = zstd.ZstdCompressor() | |
1023 | compressed = cctx.compress(orig) |
|
1039 | compressed = cctx.compress(orig) | |
1024 |
|
1040 | |||
1025 | self.assertEqual(decompress_via_writer(compressed), orig) |
|
1041 | self.assertEqual(decompress_via_writer(compressed), orig) | |
1026 |
|
1042 | |||
1027 | def test_multiple_calls(self): |
|
1043 | def test_multiple_calls(self): | |
1028 | chunks = [] |
|
1044 | chunks = [] | |
1029 | for i in range(255): |
|
1045 | for i in range(255): | |
1030 | for j in range(255): |
|
1046 | for j in range(255): | |
1031 | chunks.append(struct.Struct(">B").pack(j) * i) |
|
1047 | chunks.append(struct.Struct(">B").pack(j) * i) | |
1032 |
|
1048 | |||
1033 | orig = b"".join(chunks) |
|
1049 | orig = b"".join(chunks) | |
1034 | cctx = zstd.ZstdCompressor() |
|
1050 | cctx = zstd.ZstdCompressor() | |
1035 | compressed = cctx.compress(orig) |
|
1051 | compressed = cctx.compress(orig) | |
1036 |
|
1052 | |||
1037 | buffer = NonClosingBytesIO() |
|
1053 | buffer = NonClosingBytesIO() | |
1038 | dctx = zstd.ZstdDecompressor() |
|
1054 | dctx = zstd.ZstdDecompressor() | |
1039 | with dctx.stream_writer(buffer) as decompressor: |
|
1055 | with dctx.stream_writer(buffer) as decompressor: | |
1040 | pos = 0 |
|
1056 | pos = 0 | |
1041 | while pos < len(compressed): |
|
1057 | while pos < len(compressed): | |
1042 | pos2 = pos + 8192 |
|
1058 | pos2 = pos + 8192 | |
1043 | decompressor.write(compressed[pos:pos2]) |
|
1059 | decompressor.write(compressed[pos:pos2]) | |
1044 | pos += 8192 |
|
1060 | pos += 8192 | |
1045 | self.assertEqual(buffer.getvalue(), orig) |
|
1061 | self.assertEqual(buffer.getvalue(), orig) | |
1046 |
|
1062 | |||
1047 | # Again with write_return_read=True |
|
1063 | # Again with write_return_read=True | |
1048 | buffer = io.BytesIO() |
|
1064 | buffer = io.BytesIO() | |
1049 | writer = dctx.stream_writer(buffer, write_return_read=True) |
|
1065 | writer = dctx.stream_writer(buffer, write_return_read=True) | |
1050 | pos = 0 |
|
1066 | pos = 0 | |
1051 | while pos < len(compressed): |
|
1067 | while pos < len(compressed): | |
1052 | pos2 = pos + 8192 |
|
1068 | pos2 = pos + 8192 | |
1053 | chunk = compressed[pos:pos2] |
|
1069 | chunk = compressed[pos:pos2] | |
1054 | self.assertEqual(writer.write(chunk), len(chunk)) |
|
1070 | self.assertEqual(writer.write(chunk), len(chunk)) | |
1055 | pos += 8192 |
|
1071 | pos += 8192 | |
1056 | self.assertEqual(buffer.getvalue(), orig) |
|
1072 | self.assertEqual(buffer.getvalue(), orig) | |
1057 |
|
1073 | |||
1058 | def test_dictionary(self): |
|
1074 | def test_dictionary(self): | |
1059 | samples = [] |
|
1075 | samples = [] | |
1060 | for i in range(128): |
|
1076 | for i in range(128): | |
1061 | samples.append(b"foo" * 64) |
|
1077 | samples.append(b"foo" * 64) | |
1062 | samples.append(b"bar" * 64) |
|
1078 | samples.append(b"bar" * 64) | |
1063 | samples.append(b"foobar" * 64) |
|
1079 | samples.append(b"foobar" * 64) | |
1064 |
|
1080 | |||
1065 | d = zstd.train_dictionary(8192, samples) |
|
1081 | d = zstd.train_dictionary(8192, samples) | |
1066 |
|
1082 | |||
1067 | orig = b"foobar" * 16384 |
|
1083 | orig = b"foobar" * 16384 | |
1068 | buffer = NonClosingBytesIO() |
|
1084 | buffer = NonClosingBytesIO() | |
1069 | cctx = zstd.ZstdCompressor(dict_data=d) |
|
1085 | cctx = zstd.ZstdCompressor(dict_data=d) | |
1070 | with cctx.stream_writer(buffer) as compressor: |
|
1086 | with cctx.stream_writer(buffer) as compressor: | |
1071 | self.assertEqual(compressor.write(orig), 0) |
|
1087 | self.assertEqual(compressor.write(orig), 0) | |
1072 |
|
1088 | |||
1073 | compressed = buffer.getvalue() |
|
1089 | compressed = buffer.getvalue() | |
1074 | buffer = io.BytesIO() |
|
1090 | buffer = io.BytesIO() | |
1075 |
|
1091 | |||
1076 | dctx = zstd.ZstdDecompressor(dict_data=d) |
|
1092 | dctx = zstd.ZstdDecompressor(dict_data=d) | |
1077 | decompressor = dctx.stream_writer(buffer) |
|
1093 | decompressor = dctx.stream_writer(buffer) | |
1078 | self.assertEqual(decompressor.write(compressed), len(orig)) |
|
1094 | self.assertEqual(decompressor.write(compressed), len(orig)) | |
1079 | self.assertEqual(buffer.getvalue(), orig) |
|
1095 | self.assertEqual(buffer.getvalue(), orig) | |
1080 |
|
1096 | |||
1081 | buffer = NonClosingBytesIO() |
|
1097 | buffer = NonClosingBytesIO() | |
1082 |
|
1098 | |||
1083 | with dctx.stream_writer(buffer) as decompressor: |
|
1099 | with dctx.stream_writer(buffer) as decompressor: | |
1084 | self.assertEqual(decompressor.write(compressed), len(orig)) |
|
1100 | self.assertEqual(decompressor.write(compressed), len(orig)) | |
1085 |
|
1101 | |||
1086 | self.assertEqual(buffer.getvalue(), orig) |
|
1102 | self.assertEqual(buffer.getvalue(), orig) | |
1087 |
|
1103 | |||
1088 | def test_memory_size(self): |
|
1104 | def test_memory_size(self): | |
1089 | dctx = zstd.ZstdDecompressor() |
|
1105 | dctx = zstd.ZstdDecompressor() | |
1090 | buffer = io.BytesIO() |
|
1106 | buffer = io.BytesIO() | |
1091 |
|
1107 | |||
1092 | decompressor = dctx.stream_writer(buffer) |
|
1108 | decompressor = dctx.stream_writer(buffer) | |
1093 | size = decompressor.memory_size() |
|
1109 | size = decompressor.memory_size() | |
1094 | self.assertGreater(size, 100000) |
|
1110 | self.assertGreater(size, 100000) | |
1095 |
|
1111 | |||
1096 | with dctx.stream_writer(buffer) as decompressor: |
|
1112 | with dctx.stream_writer(buffer) as decompressor: | |
1097 | size = decompressor.memory_size() |
|
1113 | size = decompressor.memory_size() | |
1098 |
|
1114 | |||
1099 | self.assertGreater(size, 100000) |
|
1115 | self.assertGreater(size, 100000) | |
1100 |
|
1116 | |||
1101 | def test_write_size(self): |
|
1117 | def test_write_size(self): | |
1102 | source = zstd.ZstdCompressor().compress(b"foobarfoobar") |
|
1118 | source = zstd.ZstdCompressor().compress(b"foobarfoobar") | |
1103 | dest = OpCountingBytesIO() |
|
1119 | dest = OpCountingBytesIO() | |
1104 | dctx = zstd.ZstdDecompressor() |
|
1120 | dctx = zstd.ZstdDecompressor() | |
1105 | with dctx.stream_writer(dest, write_size=1) as decompressor: |
|
1121 | with dctx.stream_writer(dest, write_size=1) as decompressor: | |
1106 | s = struct.Struct(">B") |
|
1122 | s = struct.Struct(">B") | |
1107 | for c in source: |
|
1123 | for c in source: | |
1108 | if not isinstance(c, str): |
|
1124 | if not isinstance(c, str): | |
1109 | c = s.pack(c) |
|
1125 | c = s.pack(c) | |
1110 | decompressor.write(c) |
|
1126 | decompressor.write(c) | |
1111 |
|
1127 | |||
1112 | self.assertEqual(dest.getvalue(), b"foobarfoobar") |
|
1128 | self.assertEqual(dest.getvalue(), b"foobarfoobar") | |
1113 | self.assertEqual(dest._write_count, len(dest.getvalue())) |
|
1129 | self.assertEqual(dest._write_count, len(dest.getvalue())) | |
1114 |
|
1130 | |||
1115 |
|
1131 | |||
1116 | @make_cffi |
|
1132 | @make_cffi | |
1117 | class TestDecompressor_read_to_iter(TestCase): |
|
1133 | class TestDecompressor_read_to_iter(TestCase): | |
1118 | def test_type_validation(self): |
|
1134 | def test_type_validation(self): | |
1119 | dctx = zstd.ZstdDecompressor() |
|
1135 | dctx = zstd.ZstdDecompressor() | |
1120 |
|
1136 | |||
1121 | # Object with read() works. |
|
1137 | # Object with read() works. | |
1122 | dctx.read_to_iter(io.BytesIO()) |
|
1138 | dctx.read_to_iter(io.BytesIO()) | |
1123 |
|
1139 | |||
1124 | # Buffer protocol works. |
|
1140 | # Buffer protocol works. | |
1125 | dctx.read_to_iter(b"foobar") |
|
1141 | dctx.read_to_iter(b"foobar") | |
1126 |
|
1142 | |||
1127 |
with self.assertRaisesRegex( |
|
1143 | with self.assertRaisesRegex( | |
|
1144 | ValueError, "must pass an object with a read" | |||
|
1145 | ): | |||
1128 | b"".join(dctx.read_to_iter(True)) |
|
1146 | b"".join(dctx.read_to_iter(True)) | |
1129 |
|
1147 | |||
1130 | def test_empty_input(self): |
|
1148 | def test_empty_input(self): | |
1131 | dctx = zstd.ZstdDecompressor() |
|
1149 | dctx = zstd.ZstdDecompressor() | |
1132 |
|
1150 | |||
1133 | source = io.BytesIO() |
|
1151 | source = io.BytesIO() | |
1134 | it = dctx.read_to_iter(source) |
|
1152 | it = dctx.read_to_iter(source) | |
1135 | # TODO this is arguably wrong. Should get an error about missing frame foo. |
|
1153 | # TODO this is arguably wrong. Should get an error about missing frame foo. | |
1136 | with self.assertRaises(StopIteration): |
|
1154 | with self.assertRaises(StopIteration): | |
1137 | next(it) |
|
1155 | next(it) | |
1138 |
|
1156 | |||
1139 | it = dctx.read_to_iter(b"") |
|
1157 | it = dctx.read_to_iter(b"") | |
1140 | with self.assertRaises(StopIteration): |
|
1158 | with self.assertRaises(StopIteration): | |
1141 | next(it) |
|
1159 | next(it) | |
1142 |
|
1160 | |||
1143 | def test_invalid_input(self): |
|
1161 | def test_invalid_input(self): | |
1144 | dctx = zstd.ZstdDecompressor() |
|
1162 | dctx = zstd.ZstdDecompressor() | |
1145 |
|
1163 | |||
1146 | source = io.BytesIO(b"foobar") |
|
1164 | source = io.BytesIO(b"foobar") | |
1147 | it = dctx.read_to_iter(source) |
|
1165 | it = dctx.read_to_iter(source) | |
1148 | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): |
|
1166 | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | |
1149 | next(it) |
|
1167 | next(it) | |
1150 |
|
1168 | |||
1151 | it = dctx.read_to_iter(b"foobar") |
|
1169 | it = dctx.read_to_iter(b"foobar") | |
1152 | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): |
|
1170 | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | |
1153 | next(it) |
|
1171 | next(it) | |
1154 |
|
1172 | |||
1155 | def test_empty_roundtrip(self): |
|
1173 | def test_empty_roundtrip(self): | |
1156 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
|
1174 | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | |
1157 | empty = cctx.compress(b"") |
|
1175 | empty = cctx.compress(b"") | |
1158 |
|
1176 | |||
1159 | source = io.BytesIO(empty) |
|
1177 | source = io.BytesIO(empty) | |
1160 | source.seek(0) |
|
1178 | source.seek(0) | |
1161 |
|
1179 | |||
1162 | dctx = zstd.ZstdDecompressor() |
|
1180 | dctx = zstd.ZstdDecompressor() | |
1163 | it = dctx.read_to_iter(source) |
|
1181 | it = dctx.read_to_iter(source) | |
1164 |
|
1182 | |||
1165 | # No chunks should be emitted since there is no data. |
|
1183 | # No chunks should be emitted since there is no data. | |
1166 | with self.assertRaises(StopIteration): |
|
1184 | with self.assertRaises(StopIteration): | |
1167 | next(it) |
|
1185 | next(it) | |
1168 |
|
1186 | |||
1169 | # Again for good measure. |
|
1187 | # Again for good measure. | |
1170 | with self.assertRaises(StopIteration): |
|
1188 | with self.assertRaises(StopIteration): | |
1171 | next(it) |
|
1189 | next(it) | |
1172 |
|
1190 | |||
1173 | def test_skip_bytes_too_large(self): |
|
1191 | def test_skip_bytes_too_large(self): | |
1174 | dctx = zstd.ZstdDecompressor() |
|
1192 | dctx = zstd.ZstdDecompressor() | |
1175 |
|
1193 | |||
1176 | with self.assertRaisesRegex( |
|
1194 | with self.assertRaisesRegex( | |
1177 | ValueError, "skip_bytes must be smaller than read_size" |
|
1195 | ValueError, "skip_bytes must be smaller than read_size" | |
1178 | ): |
|
1196 | ): | |
1179 | b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1)) |
|
1197 | b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1)) | |
1180 |
|
1198 | |||
1181 | with self.assertRaisesRegex( |
|
1199 | with self.assertRaisesRegex( | |
1182 | ValueError, "skip_bytes larger than first input chunk" |
|
1200 | ValueError, "skip_bytes larger than first input chunk" | |
1183 | ): |
|
1201 | ): | |
1184 | b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10)) |
|
1202 | b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10)) | |
1185 |
|
1203 | |||
1186 | def test_skip_bytes(self): |
|
1204 | def test_skip_bytes(self): | |
1187 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
1205 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
1188 | compressed = cctx.compress(b"foobar") |
|
1206 | compressed = cctx.compress(b"foobar") | |
1189 |
|
1207 | |||
1190 | dctx = zstd.ZstdDecompressor() |
|
1208 | dctx = zstd.ZstdDecompressor() | |
1191 | output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3)) |
|
1209 | output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3)) | |
1192 | self.assertEqual(output, b"foobar") |
|
1210 | self.assertEqual(output, b"foobar") | |
1193 |
|
1211 | |||
1194 | def test_large_output(self): |
|
1212 | def test_large_output(self): | |
1195 | source = io.BytesIO() |
|
1213 | source = io.BytesIO() | |
1196 | source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) |
|
1214 | source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) | |
1197 | source.write(b"o") |
|
1215 | source.write(b"o") | |
1198 | source.seek(0) |
|
1216 | source.seek(0) | |
1199 |
|
1217 | |||
1200 | cctx = zstd.ZstdCompressor(level=1) |
|
1218 | cctx = zstd.ZstdCompressor(level=1) | |
1201 | compressed = io.BytesIO(cctx.compress(source.getvalue())) |
|
1219 | compressed = io.BytesIO(cctx.compress(source.getvalue())) | |
1202 | compressed.seek(0) |
|
1220 | compressed.seek(0) | |
1203 |
|
1221 | |||
1204 | dctx = zstd.ZstdDecompressor() |
|
1222 | dctx = zstd.ZstdDecompressor() | |
1205 | it = dctx.read_to_iter(compressed) |
|
1223 | it = dctx.read_to_iter(compressed) | |
1206 |
|
1224 | |||
1207 | chunks = [] |
|
1225 | chunks = [] | |
1208 | chunks.append(next(it)) |
|
1226 | chunks.append(next(it)) | |
1209 | chunks.append(next(it)) |
|
1227 | chunks.append(next(it)) | |
1210 |
|
1228 | |||
1211 | with self.assertRaises(StopIteration): |
|
1229 | with self.assertRaises(StopIteration): | |
1212 | next(it) |
|
1230 | next(it) | |
1213 |
|
1231 | |||
1214 | decompressed = b"".join(chunks) |
|
1232 | decompressed = b"".join(chunks) | |
1215 | self.assertEqual(decompressed, source.getvalue()) |
|
1233 | self.assertEqual(decompressed, source.getvalue()) | |
1216 |
|
1234 | |||
1217 | # And again with buffer protocol. |
|
1235 | # And again with buffer protocol. | |
1218 | it = dctx.read_to_iter(compressed.getvalue()) |
|
1236 | it = dctx.read_to_iter(compressed.getvalue()) | |
1219 | chunks = [] |
|
1237 | chunks = [] | |
1220 | chunks.append(next(it)) |
|
1238 | chunks.append(next(it)) | |
1221 | chunks.append(next(it)) |
|
1239 | chunks.append(next(it)) | |
1222 |
|
1240 | |||
1223 | with self.assertRaises(StopIteration): |
|
1241 | with self.assertRaises(StopIteration): | |
1224 | next(it) |
|
1242 | next(it) | |
1225 |
|
1243 | |||
1226 | decompressed = b"".join(chunks) |
|
1244 | decompressed = b"".join(chunks) | |
1227 | self.assertEqual(decompressed, source.getvalue()) |
|
1245 | self.assertEqual(decompressed, source.getvalue()) | |
1228 |
|
1246 | |||
1229 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
1247 | @unittest.skipUnless( | |
|
1248 | "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set" | |||
|
1249 | ) | |||
1230 | def test_large_input(self): |
|
1250 | def test_large_input(self): | |
1231 | bytes = list(struct.Struct(">B").pack(i) for i in range(256)) |
|
1251 | bytes = list(struct.Struct(">B").pack(i) for i in range(256)) | |
1232 | compressed = NonClosingBytesIO() |
|
1252 | compressed = NonClosingBytesIO() | |
1233 | input_size = 0 |
|
1253 | input_size = 0 | |
1234 | cctx = zstd.ZstdCompressor(level=1) |
|
1254 | cctx = zstd.ZstdCompressor(level=1) | |
1235 | with cctx.stream_writer(compressed) as compressor: |
|
1255 | with cctx.stream_writer(compressed) as compressor: | |
1236 | while True: |
|
1256 | while True: | |
1237 | compressor.write(random.choice(bytes)) |
|
1257 | compressor.write(random.choice(bytes)) | |
1238 | input_size += 1 |
|
1258 | input_size += 1 | |
1239 |
|
1259 | |||
1240 | have_compressed = ( |
|
1260 | have_compressed = ( | |
1241 | len(compressed.getvalue()) |
|
1261 | len(compressed.getvalue()) | |
1242 | > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE |
|
1262 | > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | |
1243 | ) |
|
1263 | ) | |
1244 | have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 |
|
1264 | have_raw = ( | |
|
1265 | input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | |||
|
1266 | ) | |||
1245 | if have_compressed and have_raw: |
|
1267 | if have_compressed and have_raw: | |
1246 | break |
|
1268 | break | |
1247 |
|
1269 | |||
1248 | compressed = io.BytesIO(compressed.getvalue()) |
|
1270 | compressed = io.BytesIO(compressed.getvalue()) | |
1249 | self.assertGreater( |
|
1271 | self.assertGreater( | |
1250 |
len(compressed.getvalue()), |
|
1272 | len(compressed.getvalue()), | |
|
1273 | zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | |||
1251 | ) |
|
1274 | ) | |
1252 |
|
1275 | |||
1253 | dctx = zstd.ZstdDecompressor() |
|
1276 | dctx = zstd.ZstdDecompressor() | |
1254 | it = dctx.read_to_iter(compressed) |
|
1277 | it = dctx.read_to_iter(compressed) | |
1255 |
|
1278 | |||
1256 | chunks = [] |
|
1279 | chunks = [] | |
1257 | chunks.append(next(it)) |
|
1280 | chunks.append(next(it)) | |
1258 | chunks.append(next(it)) |
|
1281 | chunks.append(next(it)) | |
1259 | chunks.append(next(it)) |
|
1282 | chunks.append(next(it)) | |
1260 |
|
1283 | |||
1261 | with self.assertRaises(StopIteration): |
|
1284 | with self.assertRaises(StopIteration): | |
1262 | next(it) |
|
1285 | next(it) | |
1263 |
|
1286 | |||
1264 | decompressed = b"".join(chunks) |
|
1287 | decompressed = b"".join(chunks) | |
1265 | self.assertEqual(len(decompressed), input_size) |
|
1288 | self.assertEqual(len(decompressed), input_size) | |
1266 |
|
1289 | |||
1267 | # And again with buffer protocol. |
|
1290 | # And again with buffer protocol. | |
1268 | it = dctx.read_to_iter(compressed.getvalue()) |
|
1291 | it = dctx.read_to_iter(compressed.getvalue()) | |
1269 |
|
1292 | |||
1270 | chunks = [] |
|
1293 | chunks = [] | |
1271 | chunks.append(next(it)) |
|
1294 | chunks.append(next(it)) | |
1272 | chunks.append(next(it)) |
|
1295 | chunks.append(next(it)) | |
1273 | chunks.append(next(it)) |
|
1296 | chunks.append(next(it)) | |
1274 |
|
1297 | |||
1275 | with self.assertRaises(StopIteration): |
|
1298 | with self.assertRaises(StopIteration): | |
1276 | next(it) |
|
1299 | next(it) | |
1277 |
|
1300 | |||
1278 | decompressed = b"".join(chunks) |
|
1301 | decompressed = b"".join(chunks) | |
1279 | self.assertEqual(len(decompressed), input_size) |
|
1302 | self.assertEqual(len(decompressed), input_size) | |
1280 |
|
1303 | |||
1281 | def test_interesting(self): |
|
1304 | def test_interesting(self): | |
1282 | # Found this edge case via fuzzing. |
|
1305 | # Found this edge case via fuzzing. | |
1283 | cctx = zstd.ZstdCompressor(level=1) |
|
1306 | cctx = zstd.ZstdCompressor(level=1) | |
1284 |
|
1307 | |||
1285 | source = io.BytesIO() |
|
1308 | source = io.BytesIO() | |
1286 |
|
1309 | |||
1287 | compressed = NonClosingBytesIO() |
|
1310 | compressed = NonClosingBytesIO() | |
1288 | with cctx.stream_writer(compressed) as compressor: |
|
1311 | with cctx.stream_writer(compressed) as compressor: | |
1289 | for i in range(256): |
|
1312 | for i in range(256): | |
1290 | chunk = b"\0" * 1024 |
|
1313 | chunk = b"\0" * 1024 | |
1291 | compressor.write(chunk) |
|
1314 | compressor.write(chunk) | |
1292 | source.write(chunk) |
|
1315 | source.write(chunk) | |
1293 |
|
1316 | |||
1294 | dctx = zstd.ZstdDecompressor() |
|
1317 | dctx = zstd.ZstdDecompressor() | |
1295 |
|
1318 | |||
1296 | simple = dctx.decompress( |
|
1319 | simple = dctx.decompress( | |
1297 | compressed.getvalue(), max_output_size=len(source.getvalue()) |
|
1320 | compressed.getvalue(), max_output_size=len(source.getvalue()) | |
1298 | ) |
|
1321 | ) | |
1299 | self.assertEqual(simple, source.getvalue()) |
|
1322 | self.assertEqual(simple, source.getvalue()) | |
1300 |
|
1323 | |||
1301 | compressed = io.BytesIO(compressed.getvalue()) |
|
1324 | compressed = io.BytesIO(compressed.getvalue()) | |
1302 | streamed = b"".join(dctx.read_to_iter(compressed)) |
|
1325 | streamed = b"".join(dctx.read_to_iter(compressed)) | |
1303 | self.assertEqual(streamed, source.getvalue()) |
|
1326 | self.assertEqual(streamed, source.getvalue()) | |
1304 |
|
1327 | |||
1305 | def test_read_write_size(self): |
|
1328 | def test_read_write_size(self): | |
1306 | source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) |
|
1329 | source = OpCountingBytesIO( | |
|
1330 | zstd.ZstdCompressor().compress(b"foobarfoobar") | |||
|
1331 | ) | |||
1307 | dctx = zstd.ZstdDecompressor() |
|
1332 | dctx = zstd.ZstdDecompressor() | |
1308 | for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): |
|
1333 | for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): | |
1309 | self.assertEqual(len(chunk), 1) |
|
1334 | self.assertEqual(len(chunk), 1) | |
1310 |
|
1335 | |||
1311 | self.assertEqual(source._read_count, len(source.getvalue())) |
|
1336 | self.assertEqual(source._read_count, len(source.getvalue())) | |
1312 |
|
1337 | |||
1313 | def test_magic_less(self): |
|
1338 | def test_magic_less(self): | |
1314 | params = zstd.CompressionParameters.from_level( |
|
1339 | params = zstd.CompressionParameters.from_level( | |
1315 | 1, format=zstd.FORMAT_ZSTD1_MAGICLESS |
|
1340 | 1, format=zstd.FORMAT_ZSTD1_MAGICLESS | |
1316 | ) |
|
1341 | ) | |
1317 | cctx = zstd.ZstdCompressor(compression_params=params) |
|
1342 | cctx = zstd.ZstdCompressor(compression_params=params) | |
1318 | frame = cctx.compress(b"foobar") |
|
1343 | frame = cctx.compress(b"foobar") | |
1319 |
|
1344 | |||
1320 | self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd") |
|
1345 | self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd") | |
1321 |
|
1346 | |||
1322 | dctx = zstd.ZstdDecompressor() |
|
1347 | dctx = zstd.ZstdDecompressor() | |
1323 | with self.assertRaisesRegex( |
|
1348 | with self.assertRaisesRegex( | |
1324 | zstd.ZstdError, "error determining content size from frame header" |
|
1349 | zstd.ZstdError, "error determining content size from frame header" | |
1325 | ): |
|
1350 | ): | |
1326 | dctx.decompress(frame) |
|
1351 | dctx.decompress(frame) | |
1327 |
|
1352 | |||
1328 | dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) |
|
1353 | dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) | |
1329 | res = b"".join(dctx.read_to_iter(frame)) |
|
1354 | res = b"".join(dctx.read_to_iter(frame)) | |
1330 | self.assertEqual(res, b"foobar") |
|
1355 | self.assertEqual(res, b"foobar") | |
1331 |
|
1356 | |||
1332 |
|
1357 | |||
1333 | @make_cffi |
|
1358 | @make_cffi | |
1334 | class TestDecompressor_content_dict_chain(TestCase): |
|
1359 | class TestDecompressor_content_dict_chain(TestCase): | |
1335 | def test_bad_inputs_simple(self): |
|
1360 | def test_bad_inputs_simple(self): | |
1336 | dctx = zstd.ZstdDecompressor() |
|
1361 | dctx = zstd.ZstdDecompressor() | |
1337 |
|
1362 | |||
1338 | with self.assertRaises(TypeError): |
|
1363 | with self.assertRaises(TypeError): | |
1339 | dctx.decompress_content_dict_chain(b"foo") |
|
1364 | dctx.decompress_content_dict_chain(b"foo") | |
1340 |
|
1365 | |||
1341 | with self.assertRaises(TypeError): |
|
1366 | with self.assertRaises(TypeError): | |
1342 | dctx.decompress_content_dict_chain((b"foo", b"bar")) |
|
1367 | dctx.decompress_content_dict_chain((b"foo", b"bar")) | |
1343 |
|
1368 | |||
1344 | with self.assertRaisesRegex(ValueError, "empty input chain"): |
|
1369 | with self.assertRaisesRegex(ValueError, "empty input chain"): | |
1345 | dctx.decompress_content_dict_chain([]) |
|
1370 | dctx.decompress_content_dict_chain([]) | |
1346 |
|
1371 | |||
1347 | with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): |
|
1372 | with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): | |
1348 | dctx.decompress_content_dict_chain([u"foo"]) |
|
1373 | dctx.decompress_content_dict_chain([u"foo"]) | |
1349 |
|
1374 | |||
1350 | with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): |
|
1375 | with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): | |
1351 | dctx.decompress_content_dict_chain([True]) |
|
1376 | dctx.decompress_content_dict_chain([True]) | |
1352 |
|
1377 | |||
1353 | with self.assertRaisesRegex( |
|
1378 | with self.assertRaisesRegex( | |
1354 | ValueError, "chunk 0 is too small to contain a zstd frame" |
|
1379 | ValueError, "chunk 0 is too small to contain a zstd frame" | |
1355 | ): |
|
1380 | ): | |
1356 | dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) |
|
1381 | dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | |
1357 |
|
1382 | |||
1358 | with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"): |
|
1383 | with self.assertRaisesRegex( | |
|
1384 | ValueError, "chunk 0 is not a valid zstd frame" | |||
|
1385 | ): | |||
1359 | dctx.decompress_content_dict_chain([b"foo" * 8]) |
|
1386 | dctx.decompress_content_dict_chain([b"foo" * 8]) | |
1360 |
|
1387 | |||
1361 |
no_size = zstd.ZstdCompressor(write_content_size=False).compress( |
|
1388 | no_size = zstd.ZstdCompressor(write_content_size=False).compress( | |
|
1389 | b"foo" * 64 | |||
|
1390 | ) | |||
1362 |
|
1391 | |||
1363 | with self.assertRaisesRegex( |
|
1392 | with self.assertRaisesRegex( | |
1364 | ValueError, "chunk 0 missing content size in frame" |
|
1393 | ValueError, "chunk 0 missing content size in frame" | |
1365 | ): |
|
1394 | ): | |
1366 | dctx.decompress_content_dict_chain([no_size]) |
|
1395 | dctx.decompress_content_dict_chain([no_size]) | |
1367 |
|
1396 | |||
1368 | # Corrupt first frame. |
|
1397 | # Corrupt first frame. | |
1369 | frame = zstd.ZstdCompressor().compress(b"foo" * 64) |
|
1398 | frame = zstd.ZstdCompressor().compress(b"foo" * 64) | |
1370 | frame = frame[0:12] + frame[15:] |
|
1399 | frame = frame[0:12] + frame[15:] | |
1371 | with self.assertRaisesRegex( |
|
1400 | with self.assertRaisesRegex( | |
1372 | zstd.ZstdError, "chunk 0 did not decompress full frame" |
|
1401 | zstd.ZstdError, "chunk 0 did not decompress full frame" | |
1373 | ): |
|
1402 | ): | |
1374 | dctx.decompress_content_dict_chain([frame]) |
|
1403 | dctx.decompress_content_dict_chain([frame]) | |
1375 |
|
1404 | |||
1376 | def test_bad_subsequent_input(self): |
|
1405 | def test_bad_subsequent_input(self): | |
1377 | initial = zstd.ZstdCompressor().compress(b"foo" * 64) |
|
1406 | initial = zstd.ZstdCompressor().compress(b"foo" * 64) | |
1378 |
|
1407 | |||
1379 | dctx = zstd.ZstdDecompressor() |
|
1408 | dctx = zstd.ZstdDecompressor() | |
1380 |
|
1409 | |||
1381 | with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): |
|
1410 | with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): | |
1382 | dctx.decompress_content_dict_chain([initial, u"foo"]) |
|
1411 | dctx.decompress_content_dict_chain([initial, u"foo"]) | |
1383 |
|
1412 | |||
1384 | with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): |
|
1413 | with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): | |
1385 | dctx.decompress_content_dict_chain([initial, None]) |
|
1414 | dctx.decompress_content_dict_chain([initial, None]) | |
1386 |
|
1415 | |||
1387 | with self.assertRaisesRegex( |
|
1416 | with self.assertRaisesRegex( | |
1388 | ValueError, "chunk 1 is too small to contain a zstd frame" |
|
1417 | ValueError, "chunk 1 is too small to contain a zstd frame" | |
1389 | ): |
|
1418 | ): | |
1390 | dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) |
|
1419 | dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | |
1391 |
|
1420 | |||
1392 | with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"): |
|
1421 | with self.assertRaisesRegex( | |
|
1422 | ValueError, "chunk 1 is not a valid zstd frame" | |||
|
1423 | ): | |||
1393 | dctx.decompress_content_dict_chain([initial, b"foo" * 8]) |
|
1424 | dctx.decompress_content_dict_chain([initial, b"foo" * 8]) | |
1394 |
|
1425 | |||
1395 |
no_size = zstd.ZstdCompressor(write_content_size=False).compress( |
|
1426 | no_size = zstd.ZstdCompressor(write_content_size=False).compress( | |
|
1427 | b"foo" * 64 | |||
|
1428 | ) | |||
1396 |
|
1429 | |||
1397 | with self.assertRaisesRegex( |
|
1430 | with self.assertRaisesRegex( | |
1398 | ValueError, "chunk 1 missing content size in frame" |
|
1431 | ValueError, "chunk 1 missing content size in frame" | |
1399 | ): |
|
1432 | ): | |
1400 | dctx.decompress_content_dict_chain([initial, no_size]) |
|
1433 | dctx.decompress_content_dict_chain([initial, no_size]) | |
1401 |
|
1434 | |||
1402 | # Corrupt second frame. |
|
1435 | # Corrupt second frame. | |
1403 | cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64)) |
|
1436 | cctx = zstd.ZstdCompressor( | |
|
1437 | dict_data=zstd.ZstdCompressionDict(b"foo" * 64) | |||
|
1438 | ) | |||
1404 | frame = cctx.compress(b"bar" * 64) |
|
1439 | frame = cctx.compress(b"bar" * 64) | |
1405 | frame = frame[0:12] + frame[15:] |
|
1440 | frame = frame[0:12] + frame[15:] | |
1406 |
|
1441 | |||
1407 | with self.assertRaisesRegex( |
|
1442 | with self.assertRaisesRegex( | |
1408 | zstd.ZstdError, "chunk 1 did not decompress full frame" |
|
1443 | zstd.ZstdError, "chunk 1 did not decompress full frame" | |
1409 | ): |
|
1444 | ): | |
1410 | dctx.decompress_content_dict_chain([initial, frame]) |
|
1445 | dctx.decompress_content_dict_chain([initial, frame]) | |
1411 |
|
1446 | |||
1412 | def test_simple(self): |
|
1447 | def test_simple(self): | |
1413 | original = [ |
|
1448 | original = [ | |
1414 | b"foo" * 64, |
|
1449 | b"foo" * 64, | |
1415 | b"foobar" * 64, |
|
1450 | b"foobar" * 64, | |
1416 | b"baz" * 64, |
|
1451 | b"baz" * 64, | |
1417 | b"foobaz" * 64, |
|
1452 | b"foobaz" * 64, | |
1418 | b"foobarbaz" * 64, |
|
1453 | b"foobarbaz" * 64, | |
1419 | ] |
|
1454 | ] | |
1420 |
|
1455 | |||
1421 | chunks = [] |
|
1456 | chunks = [] | |
1422 | chunks.append(zstd.ZstdCompressor().compress(original[0])) |
|
1457 | chunks.append(zstd.ZstdCompressor().compress(original[0])) | |
1423 | for i, chunk in enumerate(original[1:]): |
|
1458 | for i, chunk in enumerate(original[1:]): | |
1424 | d = zstd.ZstdCompressionDict(original[i]) |
|
1459 | d = zstd.ZstdCompressionDict(original[i]) | |
1425 | cctx = zstd.ZstdCompressor(dict_data=d) |
|
1460 | cctx = zstd.ZstdCompressor(dict_data=d) | |
1426 | chunks.append(cctx.compress(chunk)) |
|
1461 | chunks.append(cctx.compress(chunk)) | |
1427 |
|
1462 | |||
1428 | for i in range(1, len(original)): |
|
1463 | for i in range(1, len(original)): | |
1429 | chain = chunks[0:i] |
|
1464 | chain = chunks[0:i] | |
1430 | expected = original[i - 1] |
|
1465 | expected = original[i - 1] | |
1431 | dctx = zstd.ZstdDecompressor() |
|
1466 | dctx = zstd.ZstdDecompressor() | |
1432 | decompressed = dctx.decompress_content_dict_chain(chain) |
|
1467 | decompressed = dctx.decompress_content_dict_chain(chain) | |
1433 | self.assertEqual(decompressed, expected) |
|
1468 | self.assertEqual(decompressed, expected) | |
1434 |
|
1469 | |||
1435 |
|
1470 | |||
1436 | # TODO enable for CFFI |
|
1471 | # TODO enable for CFFI | |
1437 | class TestDecompressor_multi_decompress_to_buffer(TestCase): |
|
1472 | class TestDecompressor_multi_decompress_to_buffer(TestCase): | |
1438 | def test_invalid_inputs(self): |
|
1473 | def test_invalid_inputs(self): | |
1439 | dctx = zstd.ZstdDecompressor() |
|
1474 | dctx = zstd.ZstdDecompressor() | |
1440 |
|
1475 | |||
1441 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1476 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1442 | self.skipTest("multi_decompress_to_buffer not available") |
|
1477 | self.skipTest("multi_decompress_to_buffer not available") | |
1443 |
|
1478 | |||
1444 | with self.assertRaises(TypeError): |
|
1479 | with self.assertRaises(TypeError): | |
1445 | dctx.multi_decompress_to_buffer(True) |
|
1480 | dctx.multi_decompress_to_buffer(True) | |
1446 |
|
1481 | |||
1447 | with self.assertRaises(TypeError): |
|
1482 | with self.assertRaises(TypeError): | |
1448 | dctx.multi_decompress_to_buffer((1, 2)) |
|
1483 | dctx.multi_decompress_to_buffer((1, 2)) | |
1449 |
|
1484 | |||
1450 |
with self.assertRaisesRegex( |
|
1485 | with self.assertRaisesRegex( | |
|
1486 | TypeError, "item 0 not a bytes like object" | |||
|
1487 | ): | |||
1451 | dctx.multi_decompress_to_buffer([u"foo"]) |
|
1488 | dctx.multi_decompress_to_buffer([u"foo"]) | |
1452 |
|
1489 | |||
1453 | with self.assertRaisesRegex( |
|
1490 | with self.assertRaisesRegex( | |
1454 | ValueError, "could not determine decompressed size of item 0" |
|
1491 | ValueError, "could not determine decompressed size of item 0" | |
1455 | ): |
|
1492 | ): | |
1456 | dctx.multi_decompress_to_buffer([b"foobarbaz"]) |
|
1493 | dctx.multi_decompress_to_buffer([b"foobarbaz"]) | |
1457 |
|
1494 | |||
1458 | def test_list_input(self): |
|
1495 | def test_list_input(self): | |
1459 | cctx = zstd.ZstdCompressor() |
|
1496 | cctx = zstd.ZstdCompressor() | |
1460 |
|
1497 | |||
1461 | original = [b"foo" * 4, b"bar" * 6] |
|
1498 | original = [b"foo" * 4, b"bar" * 6] | |
1462 | frames = [cctx.compress(d) for d in original] |
|
1499 | frames = [cctx.compress(d) for d in original] | |
1463 |
|
1500 | |||
1464 | dctx = zstd.ZstdDecompressor() |
|
1501 | dctx = zstd.ZstdDecompressor() | |
1465 |
|
1502 | |||
1466 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1503 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1467 | self.skipTest("multi_decompress_to_buffer not available") |
|
1504 | self.skipTest("multi_decompress_to_buffer not available") | |
1468 |
|
1505 | |||
1469 | result = dctx.multi_decompress_to_buffer(frames) |
|
1506 | result = dctx.multi_decompress_to_buffer(frames) | |
1470 |
|
1507 | |||
1471 | self.assertEqual(len(result), len(frames)) |
|
1508 | self.assertEqual(len(result), len(frames)) | |
1472 | self.assertEqual(result.size(), sum(map(len, original))) |
|
1509 | self.assertEqual(result.size(), sum(map(len, original))) | |
1473 |
|
1510 | |||
1474 | for i, data in enumerate(original): |
|
1511 | for i, data in enumerate(original): | |
1475 | self.assertEqual(result[i].tobytes(), data) |
|
1512 | self.assertEqual(result[i].tobytes(), data) | |
1476 |
|
1513 | |||
1477 | self.assertEqual(result[0].offset, 0) |
|
1514 | self.assertEqual(result[0].offset, 0) | |
1478 | self.assertEqual(len(result[0]), 12) |
|
1515 | self.assertEqual(len(result[0]), 12) | |
1479 | self.assertEqual(result[1].offset, 12) |
|
1516 | self.assertEqual(result[1].offset, 12) | |
1480 | self.assertEqual(len(result[1]), 18) |
|
1517 | self.assertEqual(len(result[1]), 18) | |
1481 |
|
1518 | |||
1482 | def test_list_input_frame_sizes(self): |
|
1519 | def test_list_input_frame_sizes(self): | |
1483 | cctx = zstd.ZstdCompressor() |
|
1520 | cctx = zstd.ZstdCompressor() | |
1484 |
|
1521 | |||
1485 | original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] |
|
1522 | original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] | |
1486 | frames = [cctx.compress(d) for d in original] |
|
1523 | frames = [cctx.compress(d) for d in original] | |
1487 | sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) |
|
1524 | sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) | |
1488 |
|
1525 | |||
1489 | dctx = zstd.ZstdDecompressor() |
|
1526 | dctx = zstd.ZstdDecompressor() | |
1490 |
|
1527 | |||
1491 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1528 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1492 | self.skipTest("multi_decompress_to_buffer not available") |
|
1529 | self.skipTest("multi_decompress_to_buffer not available") | |
1493 |
|
1530 | |||
1494 |
result = dctx.multi_decompress_to_buffer( |
|
1531 | result = dctx.multi_decompress_to_buffer( | |
|
1532 | frames, decompressed_sizes=sizes | |||
|
1533 | ) | |||
1495 |
|
1534 | |||
1496 | self.assertEqual(len(result), len(frames)) |
|
1535 | self.assertEqual(len(result), len(frames)) | |
1497 | self.assertEqual(result.size(), sum(map(len, original))) |
|
1536 | self.assertEqual(result.size(), sum(map(len, original))) | |
1498 |
|
1537 | |||
1499 | for i, data in enumerate(original): |
|
1538 | for i, data in enumerate(original): | |
1500 | self.assertEqual(result[i].tobytes(), data) |
|
1539 | self.assertEqual(result[i].tobytes(), data) | |
1501 |
|
1540 | |||
1502 | def test_buffer_with_segments_input(self): |
|
1541 | def test_buffer_with_segments_input(self): | |
1503 | cctx = zstd.ZstdCompressor() |
|
1542 | cctx = zstd.ZstdCompressor() | |
1504 |
|
1543 | |||
1505 | original = [b"foo" * 4, b"bar" * 6] |
|
1544 | original = [b"foo" * 4, b"bar" * 6] | |
1506 | frames = [cctx.compress(d) for d in original] |
|
1545 | frames = [cctx.compress(d) for d in original] | |
1507 |
|
1546 | |||
1508 | dctx = zstd.ZstdDecompressor() |
|
1547 | dctx = zstd.ZstdDecompressor() | |
1509 |
|
1548 | |||
1510 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1549 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1511 | self.skipTest("multi_decompress_to_buffer not available") |
|
1550 | self.skipTest("multi_decompress_to_buffer not available") | |
1512 |
|
1551 | |||
1513 | segments = struct.pack( |
|
1552 | segments = struct.pack( | |
1514 | "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) |
|
1553 | "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) | |
1515 | ) |
|
1554 | ) | |
1516 | b = zstd.BufferWithSegments(b"".join(frames), segments) |
|
1555 | b = zstd.BufferWithSegments(b"".join(frames), segments) | |
1517 |
|
1556 | |||
1518 | result = dctx.multi_decompress_to_buffer(b) |
|
1557 | result = dctx.multi_decompress_to_buffer(b) | |
1519 |
|
1558 | |||
1520 | self.assertEqual(len(result), len(frames)) |
|
1559 | self.assertEqual(len(result), len(frames)) | |
1521 | self.assertEqual(result[0].offset, 0) |
|
1560 | self.assertEqual(result[0].offset, 0) | |
1522 | self.assertEqual(len(result[0]), 12) |
|
1561 | self.assertEqual(len(result[0]), 12) | |
1523 | self.assertEqual(result[1].offset, 12) |
|
1562 | self.assertEqual(result[1].offset, 12) | |
1524 | self.assertEqual(len(result[1]), 18) |
|
1563 | self.assertEqual(len(result[1]), 18) | |
1525 |
|
1564 | |||
1526 | def test_buffer_with_segments_sizes(self): |
|
1565 | def test_buffer_with_segments_sizes(self): | |
1527 | cctx = zstd.ZstdCompressor(write_content_size=False) |
|
1566 | cctx = zstd.ZstdCompressor(write_content_size=False) | |
1528 | original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] |
|
1567 | original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] | |
1529 | frames = [cctx.compress(d) for d in original] |
|
1568 | frames = [cctx.compress(d) for d in original] | |
1530 | sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) |
|
1569 | sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) | |
1531 |
|
1570 | |||
1532 | dctx = zstd.ZstdDecompressor() |
|
1571 | dctx = zstd.ZstdDecompressor() | |
1533 |
|
1572 | |||
1534 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1573 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1535 | self.skipTest("multi_decompress_to_buffer not available") |
|
1574 | self.skipTest("multi_decompress_to_buffer not available") | |
1536 |
|
1575 | |||
1537 | segments = struct.pack( |
|
1576 | segments = struct.pack( | |
1538 | "=QQQQQQ", |
|
1577 | "=QQQQQQ", | |
1539 | 0, |
|
1578 | 0, | |
1540 | len(frames[0]), |
|
1579 | len(frames[0]), | |
1541 | len(frames[0]), |
|
1580 | len(frames[0]), | |
1542 | len(frames[1]), |
|
1581 | len(frames[1]), | |
1543 | len(frames[0]) + len(frames[1]), |
|
1582 | len(frames[0]) + len(frames[1]), | |
1544 | len(frames[2]), |
|
1583 | len(frames[2]), | |
1545 | ) |
|
1584 | ) | |
1546 | b = zstd.BufferWithSegments(b"".join(frames), segments) |
|
1585 | b = zstd.BufferWithSegments(b"".join(frames), segments) | |
1547 |
|
1586 | |||
1548 | result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) |
|
1587 | result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) | |
1549 |
|
1588 | |||
1550 | self.assertEqual(len(result), len(frames)) |
|
1589 | self.assertEqual(len(result), len(frames)) | |
1551 | self.assertEqual(result.size(), sum(map(len, original))) |
|
1590 | self.assertEqual(result.size(), sum(map(len, original))) | |
1552 |
|
1591 | |||
1553 | for i, data in enumerate(original): |
|
1592 | for i, data in enumerate(original): | |
1554 | self.assertEqual(result[i].tobytes(), data) |
|
1593 | self.assertEqual(result[i].tobytes(), data) | |
1555 |
|
1594 | |||
1556 | def test_buffer_with_segments_collection_input(self): |
|
1595 | def test_buffer_with_segments_collection_input(self): | |
1557 | cctx = zstd.ZstdCompressor() |
|
1596 | cctx = zstd.ZstdCompressor() | |
1558 |
|
1597 | |||
1559 | original = [ |
|
1598 | original = [ | |
1560 | b"foo0" * 2, |
|
1599 | b"foo0" * 2, | |
1561 | b"foo1" * 3, |
|
1600 | b"foo1" * 3, | |
1562 | b"foo2" * 4, |
|
1601 | b"foo2" * 4, | |
1563 | b"foo3" * 5, |
|
1602 | b"foo3" * 5, | |
1564 | b"foo4" * 6, |
|
1603 | b"foo4" * 6, | |
1565 | ] |
|
1604 | ] | |
1566 |
|
1605 | |||
1567 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
1606 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
1568 | self.skipTest("multi_compress_to_buffer not available") |
|
1607 | self.skipTest("multi_compress_to_buffer not available") | |
1569 |
|
1608 | |||
1570 | frames = cctx.multi_compress_to_buffer(original) |
|
1609 | frames = cctx.multi_compress_to_buffer(original) | |
1571 |
|
1610 | |||
1572 | # Check round trip. |
|
1611 | # Check round trip. | |
1573 | dctx = zstd.ZstdDecompressor() |
|
1612 | dctx = zstd.ZstdDecompressor() | |
1574 |
|
1613 | |||
1575 | decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) |
|
1614 | decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) | |
1576 |
|
1615 | |||
1577 | self.assertEqual(len(decompressed), len(original)) |
|
1616 | self.assertEqual(len(decompressed), len(original)) | |
1578 |
|
1617 | |||
1579 | for i, data in enumerate(original): |
|
1618 | for i, data in enumerate(original): | |
1580 | self.assertEqual(data, decompressed[i].tobytes()) |
|
1619 | self.assertEqual(data, decompressed[i].tobytes()) | |
1581 |
|
1620 | |||
1582 | # And a manual mode. |
|
1621 | # And a manual mode. | |
1583 | b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) |
|
1622 | b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) | |
1584 | b1 = zstd.BufferWithSegments( |
|
1623 | b1 = zstd.BufferWithSegments( | |
1585 | b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])) |
|
1624 | b, | |
|
1625 | struct.pack( | |||
|
1626 | "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) | |||
|
1627 | ), | |||
1586 | ) |
|
1628 | ) | |
1587 |
|
1629 | |||
1588 | b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]) |
|
1630 | b = b"".join( | |
|
1631 | [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()] | |||
|
1632 | ) | |||
1589 | b2 = zstd.BufferWithSegments( |
|
1633 | b2 = zstd.BufferWithSegments( | |
1590 | b, |
|
1634 | b, | |
1591 | struct.pack( |
|
1635 | struct.pack( | |
1592 | "=QQQQQQ", |
|
1636 | "=QQQQQQ", | |
1593 | 0, |
|
1637 | 0, | |
1594 | len(frames[2]), |
|
1638 | len(frames[2]), | |
1595 | len(frames[2]), |
|
1639 | len(frames[2]), | |
1596 | len(frames[3]), |
|
1640 | len(frames[3]), | |
1597 | len(frames[2]) + len(frames[3]), |
|
1641 | len(frames[2]) + len(frames[3]), | |
1598 | len(frames[4]), |
|
1642 | len(frames[4]), | |
1599 | ), |
|
1643 | ), | |
1600 | ) |
|
1644 | ) | |
1601 |
|
1645 | |||
1602 | c = zstd.BufferWithSegmentsCollection(b1, b2) |
|
1646 | c = zstd.BufferWithSegmentsCollection(b1, b2) | |
1603 |
|
1647 | |||
1604 | dctx = zstd.ZstdDecompressor() |
|
1648 | dctx = zstd.ZstdDecompressor() | |
1605 | decompressed = dctx.multi_decompress_to_buffer(c) |
|
1649 | decompressed = dctx.multi_decompress_to_buffer(c) | |
1606 |
|
1650 | |||
1607 | self.assertEqual(len(decompressed), 5) |
|
1651 | self.assertEqual(len(decompressed), 5) | |
1608 | for i in range(5): |
|
1652 | for i in range(5): | |
1609 | self.assertEqual(decompressed[i].tobytes(), original[i]) |
|
1653 | self.assertEqual(decompressed[i].tobytes(), original[i]) | |
1610 |
|
1654 | |||
1611 | def test_dict(self): |
|
1655 | def test_dict(self): | |
1612 | d = zstd.train_dictionary(16384, generate_samples(), k=64, d=16) |
|
1656 | d = zstd.train_dictionary(16384, generate_samples(), k=64, d=16) | |
1613 |
|
1657 | |||
1614 | cctx = zstd.ZstdCompressor(dict_data=d, level=1) |
|
1658 | cctx = zstd.ZstdCompressor(dict_data=d, level=1) | |
1615 | frames = [cctx.compress(s) for s in generate_samples()] |
|
1659 | frames = [cctx.compress(s) for s in generate_samples()] | |
1616 |
|
1660 | |||
1617 | dctx = zstd.ZstdDecompressor(dict_data=d) |
|
1661 | dctx = zstd.ZstdDecompressor(dict_data=d) | |
1618 |
|
1662 | |||
1619 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1663 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1620 | self.skipTest("multi_decompress_to_buffer not available") |
|
1664 | self.skipTest("multi_decompress_to_buffer not available") | |
1621 |
|
1665 | |||
1622 | result = dctx.multi_decompress_to_buffer(frames) |
|
1666 | result = dctx.multi_decompress_to_buffer(frames) | |
1623 |
|
1667 | |||
1624 | self.assertEqual([o.tobytes() for o in result], generate_samples()) |
|
1668 | self.assertEqual([o.tobytes() for o in result], generate_samples()) | |
1625 |
|
1669 | |||
1626 | def test_multiple_threads(self): |
|
1670 | def test_multiple_threads(self): | |
1627 | cctx = zstd.ZstdCompressor() |
|
1671 | cctx = zstd.ZstdCompressor() | |
1628 |
|
1672 | |||
1629 | frames = [] |
|
1673 | frames = [] | |
1630 | frames.extend(cctx.compress(b"x" * 64) for i in range(256)) |
|
1674 | frames.extend(cctx.compress(b"x" * 64) for i in range(256)) | |
1631 | frames.extend(cctx.compress(b"y" * 64) for i in range(256)) |
|
1675 | frames.extend(cctx.compress(b"y" * 64) for i in range(256)) | |
1632 |
|
1676 | |||
1633 | dctx = zstd.ZstdDecompressor() |
|
1677 | dctx = zstd.ZstdDecompressor() | |
1634 |
|
1678 | |||
1635 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1679 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1636 | self.skipTest("multi_decompress_to_buffer not available") |
|
1680 | self.skipTest("multi_decompress_to_buffer not available") | |
1637 |
|
1681 | |||
1638 | result = dctx.multi_decompress_to_buffer(frames, threads=-1) |
|
1682 | result = dctx.multi_decompress_to_buffer(frames, threads=-1) | |
1639 |
|
1683 | |||
1640 | self.assertEqual(len(result), len(frames)) |
|
1684 | self.assertEqual(len(result), len(frames)) | |
1641 | self.assertEqual(result.size(), 2 * 64 * 256) |
|
1685 | self.assertEqual(result.size(), 2 * 64 * 256) | |
1642 | self.assertEqual(result[0].tobytes(), b"x" * 64) |
|
1686 | self.assertEqual(result[0].tobytes(), b"x" * 64) | |
1643 | self.assertEqual(result[256].tobytes(), b"y" * 64) |
|
1687 | self.assertEqual(result[256].tobytes(), b"y" * 64) | |
1644 |
|
1688 | |||
1645 | def test_item_failure(self): |
|
1689 | def test_item_failure(self): | |
1646 | cctx = zstd.ZstdCompressor() |
|
1690 | cctx = zstd.ZstdCompressor() | |
1647 | frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)] |
|
1691 | frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)] | |
1648 |
|
1692 | |||
1649 | frames[1] = frames[1][0:15] + b"extra" + frames[1][15:] |
|
1693 | frames[1] = frames[1][0:15] + b"extra" + frames[1][15:] | |
1650 |
|
1694 | |||
1651 | dctx = zstd.ZstdDecompressor() |
|
1695 | dctx = zstd.ZstdDecompressor() | |
1652 |
|
1696 | |||
1653 | if not hasattr(dctx, "multi_decompress_to_buffer"): |
|
1697 | if not hasattr(dctx, "multi_decompress_to_buffer"): | |
1654 | self.skipTest("multi_decompress_to_buffer not available") |
|
1698 | self.skipTest("multi_decompress_to_buffer not available") | |
1655 |
|
1699 | |||
1656 | with self.assertRaisesRegex( |
|
1700 | with self.assertRaisesRegex( | |
1657 | zstd.ZstdError, |
|
1701 | zstd.ZstdError, | |
1658 | "error decompressing item 1: (" |
|
1702 | "error decompressing item 1: (" | |
1659 | "Corrupted block|" |
|
1703 | "Corrupted block|" | |
1660 | "Destination buffer is too small)", |
|
1704 | "Destination buffer is too small)", | |
1661 | ): |
|
1705 | ): | |
1662 | dctx.multi_decompress_to_buffer(frames) |
|
1706 | dctx.multi_decompress_to_buffer(frames) | |
1663 |
|
1707 | |||
1664 | with self.assertRaisesRegex( |
|
1708 | with self.assertRaisesRegex( | |
1665 | zstd.ZstdError, |
|
1709 | zstd.ZstdError, | |
1666 | "error decompressing item 1: (" |
|
1710 | "error decompressing item 1: (" | |
1667 | "Corrupted block|" |
|
1711 | "Corrupted block|" | |
1668 | "Destination buffer is too small)", |
|
1712 | "Destination buffer is too small)", | |
1669 | ): |
|
1713 | ): | |
1670 | dctx.multi_decompress_to_buffer(frames, threads=2) |
|
1714 | dctx.multi_decompress_to_buffer(frames, threads=2) |
@@ -1,576 +1,593 b'' | |||||
1 | import io |
|
1 | import io | |
2 | import os |
|
2 | import os | |
3 | import unittest |
|
3 | import unittest | |
4 |
|
4 | |||
5 | try: |
|
5 | try: | |
6 | import hypothesis |
|
6 | import hypothesis | |
7 | import hypothesis.strategies as strategies |
|
7 | import hypothesis.strategies as strategies | |
8 | except ImportError: |
|
8 | except ImportError: | |
9 | raise unittest.SkipTest("hypothesis not available") |
|
9 | raise unittest.SkipTest("hypothesis not available") | |
10 |
|
10 | |||
11 | import zstandard as zstd |
|
11 | import zstandard as zstd | |
12 |
|
12 | |||
13 | from .common import ( |
|
13 | from .common import ( | |
14 | make_cffi, |
|
14 | make_cffi, | |
15 | NonClosingBytesIO, |
|
15 | NonClosingBytesIO, | |
16 | random_input_data, |
|
16 | random_input_data, | |
17 | TestCase, |
|
17 | TestCase, | |
18 | ) |
|
18 | ) | |
19 |
|
19 | |||
20 |
|
20 | |||
21 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
21 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
22 | @make_cffi |
|
22 | @make_cffi | |
23 | class TestDecompressor_stream_reader_fuzzing(TestCase): |
|
23 | class TestDecompressor_stream_reader_fuzzing(TestCase): | |
24 | @hypothesis.settings( |
|
24 | @hypothesis.settings( | |
25 | suppress_health_check=[ |
|
25 | suppress_health_check=[ | |
26 | hypothesis.HealthCheck.large_base_example, |
|
26 | hypothesis.HealthCheck.large_base_example, | |
27 | hypothesis.HealthCheck.too_slow, |
|
27 | hypothesis.HealthCheck.too_slow, | |
28 | ] |
|
28 | ] | |
29 | ) |
|
29 | ) | |
30 | @hypothesis.given( |
|
30 | @hypothesis.given( | |
31 | original=strategies.sampled_from(random_input_data()), |
|
31 | original=strategies.sampled_from(random_input_data()), | |
32 | level=strategies.integers(min_value=1, max_value=5), |
|
32 | level=strategies.integers(min_value=1, max_value=5), | |
33 | streaming=strategies.booleans(), |
|
33 | streaming=strategies.booleans(), | |
34 | source_read_size=strategies.integers(1, 1048576), |
|
34 | source_read_size=strategies.integers(1, 1048576), | |
35 | read_sizes=strategies.data(), |
|
35 | read_sizes=strategies.data(), | |
36 | ) |
|
36 | ) | |
37 | def test_stream_source_read_variance( |
|
37 | def test_stream_source_read_variance( | |
38 | self, original, level, streaming, source_read_size, read_sizes |
|
38 | self, original, level, streaming, source_read_size, read_sizes | |
39 | ): |
|
39 | ): | |
40 | cctx = zstd.ZstdCompressor(level=level) |
|
40 | cctx = zstd.ZstdCompressor(level=level) | |
41 |
|
41 | |||
42 | if streaming: |
|
42 | if streaming: | |
43 | source = io.BytesIO() |
|
43 | source = io.BytesIO() | |
44 | writer = cctx.stream_writer(source) |
|
44 | writer = cctx.stream_writer(source) | |
45 | writer.write(original) |
|
45 | writer.write(original) | |
46 | writer.flush(zstd.FLUSH_FRAME) |
|
46 | writer.flush(zstd.FLUSH_FRAME) | |
47 | source.seek(0) |
|
47 | source.seek(0) | |
48 | else: |
|
48 | else: | |
49 | frame = cctx.compress(original) |
|
49 | frame = cctx.compress(original) | |
50 | source = io.BytesIO(frame) |
|
50 | source = io.BytesIO(frame) | |
51 |
|
51 | |||
52 | dctx = zstd.ZstdDecompressor() |
|
52 | dctx = zstd.ZstdDecompressor() | |
53 |
|
53 | |||
54 | chunks = [] |
|
54 | chunks = [] | |
55 | with dctx.stream_reader(source, read_size=source_read_size) as reader: |
|
55 | with dctx.stream_reader(source, read_size=source_read_size) as reader: | |
56 | while True: |
|
56 | while True: | |
57 | read_size = read_sizes.draw(strategies.integers(-1, 131072)) |
|
57 | read_size = read_sizes.draw(strategies.integers(-1, 131072)) | |
58 | chunk = reader.read(read_size) |
|
58 | chunk = reader.read(read_size) | |
59 | if not chunk and read_size: |
|
59 | if not chunk and read_size: | |
60 | break |
|
60 | break | |
61 |
|
61 | |||
62 | chunks.append(chunk) |
|
62 | chunks.append(chunk) | |
63 |
|
63 | |||
64 | self.assertEqual(b"".join(chunks), original) |
|
64 | self.assertEqual(b"".join(chunks), original) | |
65 |
|
65 | |||
66 | # Similar to above except we have a constant read() size. |
|
66 | # Similar to above except we have a constant read() size. | |
67 | @hypothesis.settings( |
|
67 | @hypothesis.settings( | |
68 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
68 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
69 | ) |
|
69 | ) | |
70 | @hypothesis.given( |
|
70 | @hypothesis.given( | |
71 | original=strategies.sampled_from(random_input_data()), |
|
71 | original=strategies.sampled_from(random_input_data()), | |
72 | level=strategies.integers(min_value=1, max_value=5), |
|
72 | level=strategies.integers(min_value=1, max_value=5), | |
73 | streaming=strategies.booleans(), |
|
73 | streaming=strategies.booleans(), | |
74 | source_read_size=strategies.integers(1, 1048576), |
|
74 | source_read_size=strategies.integers(1, 1048576), | |
75 | read_size=strategies.integers(-1, 131072), |
|
75 | read_size=strategies.integers(-1, 131072), | |
76 | ) |
|
76 | ) | |
77 | def test_stream_source_read_size( |
|
77 | def test_stream_source_read_size( | |
78 | self, original, level, streaming, source_read_size, read_size |
|
78 | self, original, level, streaming, source_read_size, read_size | |
79 | ): |
|
79 | ): | |
80 | if read_size == 0: |
|
80 | if read_size == 0: | |
81 | read_size = 1 |
|
81 | read_size = 1 | |
82 |
|
82 | |||
83 | cctx = zstd.ZstdCompressor(level=level) |
|
83 | cctx = zstd.ZstdCompressor(level=level) | |
84 |
|
84 | |||
85 | if streaming: |
|
85 | if streaming: | |
86 | source = io.BytesIO() |
|
86 | source = io.BytesIO() | |
87 | writer = cctx.stream_writer(source) |
|
87 | writer = cctx.stream_writer(source) | |
88 | writer.write(original) |
|
88 | writer.write(original) | |
89 | writer.flush(zstd.FLUSH_FRAME) |
|
89 | writer.flush(zstd.FLUSH_FRAME) | |
90 | source.seek(0) |
|
90 | source.seek(0) | |
91 | else: |
|
91 | else: | |
92 | frame = cctx.compress(original) |
|
92 | frame = cctx.compress(original) | |
93 | source = io.BytesIO(frame) |
|
93 | source = io.BytesIO(frame) | |
94 |
|
94 | |||
95 | dctx = zstd.ZstdDecompressor() |
|
95 | dctx = zstd.ZstdDecompressor() | |
96 |
|
96 | |||
97 | chunks = [] |
|
97 | chunks = [] | |
98 | reader = dctx.stream_reader(source, read_size=source_read_size) |
|
98 | reader = dctx.stream_reader(source, read_size=source_read_size) | |
99 | while True: |
|
99 | while True: | |
100 | chunk = reader.read(read_size) |
|
100 | chunk = reader.read(read_size) | |
101 | if not chunk and read_size: |
|
101 | if not chunk and read_size: | |
102 | break |
|
102 | break | |
103 |
|
103 | |||
104 | chunks.append(chunk) |
|
104 | chunks.append(chunk) | |
105 |
|
105 | |||
106 | self.assertEqual(b"".join(chunks), original) |
|
106 | self.assertEqual(b"".join(chunks), original) | |
107 |
|
107 | |||
108 | @hypothesis.settings( |
|
108 | @hypothesis.settings( | |
109 | suppress_health_check=[ |
|
109 | suppress_health_check=[ | |
110 | hypothesis.HealthCheck.large_base_example, |
|
110 | hypothesis.HealthCheck.large_base_example, | |
111 | hypothesis.HealthCheck.too_slow, |
|
111 | hypothesis.HealthCheck.too_slow, | |
112 | ] |
|
112 | ] | |
113 | ) |
|
113 | ) | |
114 | @hypothesis.given( |
|
114 | @hypothesis.given( | |
115 | original=strategies.sampled_from(random_input_data()), |
|
115 | original=strategies.sampled_from(random_input_data()), | |
116 | level=strategies.integers(min_value=1, max_value=5), |
|
116 | level=strategies.integers(min_value=1, max_value=5), | |
117 | streaming=strategies.booleans(), |
|
117 | streaming=strategies.booleans(), | |
118 | source_read_size=strategies.integers(1, 1048576), |
|
118 | source_read_size=strategies.integers(1, 1048576), | |
119 | read_sizes=strategies.data(), |
|
119 | read_sizes=strategies.data(), | |
120 | ) |
|
120 | ) | |
121 | def test_buffer_source_read_variance( |
|
121 | def test_buffer_source_read_variance( | |
122 | self, original, level, streaming, source_read_size, read_sizes |
|
122 | self, original, level, streaming, source_read_size, read_sizes | |
123 | ): |
|
123 | ): | |
124 | cctx = zstd.ZstdCompressor(level=level) |
|
124 | cctx = zstd.ZstdCompressor(level=level) | |
125 |
|
125 | |||
126 | if streaming: |
|
126 | if streaming: | |
127 | source = io.BytesIO() |
|
127 | source = io.BytesIO() | |
128 | writer = cctx.stream_writer(source) |
|
128 | writer = cctx.stream_writer(source) | |
129 | writer.write(original) |
|
129 | writer.write(original) | |
130 | writer.flush(zstd.FLUSH_FRAME) |
|
130 | writer.flush(zstd.FLUSH_FRAME) | |
131 | frame = source.getvalue() |
|
131 | frame = source.getvalue() | |
132 | else: |
|
132 | else: | |
133 | frame = cctx.compress(original) |
|
133 | frame = cctx.compress(original) | |
134 |
|
134 | |||
135 | dctx = zstd.ZstdDecompressor() |
|
135 | dctx = zstd.ZstdDecompressor() | |
136 | chunks = [] |
|
136 | chunks = [] | |
137 |
|
137 | |||
138 | with dctx.stream_reader(frame, read_size=source_read_size) as reader: |
|
138 | with dctx.stream_reader(frame, read_size=source_read_size) as reader: | |
139 | while True: |
|
139 | while True: | |
140 | read_size = read_sizes.draw(strategies.integers(-1, 131072)) |
|
140 | read_size = read_sizes.draw(strategies.integers(-1, 131072)) | |
141 | chunk = reader.read(read_size) |
|
141 | chunk = reader.read(read_size) | |
142 | if not chunk and read_size: |
|
142 | if not chunk and read_size: | |
143 | break |
|
143 | break | |
144 |
|
144 | |||
145 | chunks.append(chunk) |
|
145 | chunks.append(chunk) | |
146 |
|
146 | |||
147 | self.assertEqual(b"".join(chunks), original) |
|
147 | self.assertEqual(b"".join(chunks), original) | |
148 |
|
148 | |||
149 | # Similar to above except we have a constant read() size. |
|
149 | # Similar to above except we have a constant read() size. | |
150 | @hypothesis.settings( |
|
150 | @hypothesis.settings( | |
151 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
151 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
152 | ) |
|
152 | ) | |
153 | @hypothesis.given( |
|
153 | @hypothesis.given( | |
154 | original=strategies.sampled_from(random_input_data()), |
|
154 | original=strategies.sampled_from(random_input_data()), | |
155 | level=strategies.integers(min_value=1, max_value=5), |
|
155 | level=strategies.integers(min_value=1, max_value=5), | |
156 | streaming=strategies.booleans(), |
|
156 | streaming=strategies.booleans(), | |
157 | source_read_size=strategies.integers(1, 1048576), |
|
157 | source_read_size=strategies.integers(1, 1048576), | |
158 | read_size=strategies.integers(-1, 131072), |
|
158 | read_size=strategies.integers(-1, 131072), | |
159 | ) |
|
159 | ) | |
160 | def test_buffer_source_constant_read_size( |
|
160 | def test_buffer_source_constant_read_size( | |
161 | self, original, level, streaming, source_read_size, read_size |
|
161 | self, original, level, streaming, source_read_size, read_size | |
162 | ): |
|
162 | ): | |
163 | if read_size == 0: |
|
163 | if read_size == 0: | |
164 | read_size = -1 |
|
164 | read_size = -1 | |
165 |
|
165 | |||
166 | cctx = zstd.ZstdCompressor(level=level) |
|
166 | cctx = zstd.ZstdCompressor(level=level) | |
167 |
|
167 | |||
168 | if streaming: |
|
168 | if streaming: | |
169 | source = io.BytesIO() |
|
169 | source = io.BytesIO() | |
170 | writer = cctx.stream_writer(source) |
|
170 | writer = cctx.stream_writer(source) | |
171 | writer.write(original) |
|
171 | writer.write(original) | |
172 | writer.flush(zstd.FLUSH_FRAME) |
|
172 | writer.flush(zstd.FLUSH_FRAME) | |
173 | frame = source.getvalue() |
|
173 | frame = source.getvalue() | |
174 | else: |
|
174 | else: | |
175 | frame = cctx.compress(original) |
|
175 | frame = cctx.compress(original) | |
176 |
|
176 | |||
177 | dctx = zstd.ZstdDecompressor() |
|
177 | dctx = zstd.ZstdDecompressor() | |
178 | chunks = [] |
|
178 | chunks = [] | |
179 |
|
179 | |||
180 | reader = dctx.stream_reader(frame, read_size=source_read_size) |
|
180 | reader = dctx.stream_reader(frame, read_size=source_read_size) | |
181 | while True: |
|
181 | while True: | |
182 | chunk = reader.read(read_size) |
|
182 | chunk = reader.read(read_size) | |
183 | if not chunk and read_size: |
|
183 | if not chunk and read_size: | |
184 | break |
|
184 | break | |
185 |
|
185 | |||
186 | chunks.append(chunk) |
|
186 | chunks.append(chunk) | |
187 |
|
187 | |||
188 | self.assertEqual(b"".join(chunks), original) |
|
188 | self.assertEqual(b"".join(chunks), original) | |
189 |
|
189 | |||
190 | @hypothesis.settings( |
|
190 | @hypothesis.settings( | |
191 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] |
|
191 | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | |
192 | ) |
|
192 | ) | |
193 | @hypothesis.given( |
|
193 | @hypothesis.given( | |
194 | original=strategies.sampled_from(random_input_data()), |
|
194 | original=strategies.sampled_from(random_input_data()), | |
195 | level=strategies.integers(min_value=1, max_value=5), |
|
195 | level=strategies.integers(min_value=1, max_value=5), | |
196 | streaming=strategies.booleans(), |
|
196 | streaming=strategies.booleans(), | |
197 | source_read_size=strategies.integers(1, 1048576), |
|
197 | source_read_size=strategies.integers(1, 1048576), | |
198 | ) |
|
198 | ) | |
199 | def test_stream_source_readall(self, original, level, streaming, source_read_size): |
|
199 | def test_stream_source_readall( | |
|
200 | self, original, level, streaming, source_read_size | |||
|
201 | ): | |||
200 | cctx = zstd.ZstdCompressor(level=level) |
|
202 | cctx = zstd.ZstdCompressor(level=level) | |
201 |
|
203 | |||
202 | if streaming: |
|
204 | if streaming: | |
203 | source = io.BytesIO() |
|
205 | source = io.BytesIO() | |
204 | writer = cctx.stream_writer(source) |
|
206 | writer = cctx.stream_writer(source) | |
205 | writer.write(original) |
|
207 | writer.write(original) | |
206 | writer.flush(zstd.FLUSH_FRAME) |
|
208 | writer.flush(zstd.FLUSH_FRAME) | |
207 | source.seek(0) |
|
209 | source.seek(0) | |
208 | else: |
|
210 | else: | |
209 | frame = cctx.compress(original) |
|
211 | frame = cctx.compress(original) | |
210 | source = io.BytesIO(frame) |
|
212 | source = io.BytesIO(frame) | |
211 |
|
213 | |||
212 | dctx = zstd.ZstdDecompressor() |
|
214 | dctx = zstd.ZstdDecompressor() | |
213 |
|
215 | |||
214 | data = dctx.stream_reader(source, read_size=source_read_size).readall() |
|
216 | data = dctx.stream_reader(source, read_size=source_read_size).readall() | |
215 | self.assertEqual(data, original) |
|
217 | self.assertEqual(data, original) | |
216 |
|
218 | |||
217 | @hypothesis.settings( |
|
219 | @hypothesis.settings( | |
218 | suppress_health_check=[ |
|
220 | suppress_health_check=[ | |
219 | hypothesis.HealthCheck.large_base_example, |
|
221 | hypothesis.HealthCheck.large_base_example, | |
220 | hypothesis.HealthCheck.too_slow, |
|
222 | hypothesis.HealthCheck.too_slow, | |
221 | ] |
|
223 | ] | |
222 | ) |
|
224 | ) | |
223 | @hypothesis.given( |
|
225 | @hypothesis.given( | |
224 | original=strategies.sampled_from(random_input_data()), |
|
226 | original=strategies.sampled_from(random_input_data()), | |
225 | level=strategies.integers(min_value=1, max_value=5), |
|
227 | level=strategies.integers(min_value=1, max_value=5), | |
226 | streaming=strategies.booleans(), |
|
228 | streaming=strategies.booleans(), | |
227 | source_read_size=strategies.integers(1, 1048576), |
|
229 | source_read_size=strategies.integers(1, 1048576), | |
228 | read_sizes=strategies.data(), |
|
230 | read_sizes=strategies.data(), | |
229 | ) |
|
231 | ) | |
230 | def test_stream_source_read1_variance( |
|
232 | def test_stream_source_read1_variance( | |
231 | self, original, level, streaming, source_read_size, read_sizes |
|
233 | self, original, level, streaming, source_read_size, read_sizes | |
232 | ): |
|
234 | ): | |
233 | cctx = zstd.ZstdCompressor(level=level) |
|
235 | cctx = zstd.ZstdCompressor(level=level) | |
234 |
|
236 | |||
235 | if streaming: |
|
237 | if streaming: | |
236 | source = io.BytesIO() |
|
238 | source = io.BytesIO() | |
237 | writer = cctx.stream_writer(source) |
|
239 | writer = cctx.stream_writer(source) | |
238 | writer.write(original) |
|
240 | writer.write(original) | |
239 | writer.flush(zstd.FLUSH_FRAME) |
|
241 | writer.flush(zstd.FLUSH_FRAME) | |
240 | source.seek(0) |
|
242 | source.seek(0) | |
241 | else: |
|
243 | else: | |
242 | frame = cctx.compress(original) |
|
244 | frame = cctx.compress(original) | |
243 | source = io.BytesIO(frame) |
|
245 | source = io.BytesIO(frame) | |
244 |
|
246 | |||
245 | dctx = zstd.ZstdDecompressor() |
|
247 | dctx = zstd.ZstdDecompressor() | |
246 |
|
248 | |||
247 | chunks = [] |
|
249 | chunks = [] | |
248 | with dctx.stream_reader(source, read_size=source_read_size) as reader: |
|
250 | with dctx.stream_reader(source, read_size=source_read_size) as reader: | |
249 | while True: |
|
251 | while True: | |
250 | read_size = read_sizes.draw(strategies.integers(-1, 131072)) |
|
252 | read_size = read_sizes.draw(strategies.integers(-1, 131072)) | |
251 | chunk = reader.read1(read_size) |
|
253 | chunk = reader.read1(read_size) | |
252 | if not chunk and read_size: |
|
254 | if not chunk and read_size: | |
253 | break |
|
255 | break | |
254 |
|
256 | |||
255 | chunks.append(chunk) |
|
257 | chunks.append(chunk) | |
256 |
|
258 | |||
257 | self.assertEqual(b"".join(chunks), original) |
|
259 | self.assertEqual(b"".join(chunks), original) | |
258 |
|
260 | |||
259 | @hypothesis.settings( |
|
261 | @hypothesis.settings( | |
260 | suppress_health_check=[ |
|
262 | suppress_health_check=[ | |
261 | hypothesis.HealthCheck.large_base_example, |
|
263 | hypothesis.HealthCheck.large_base_example, | |
262 | hypothesis.HealthCheck.too_slow, |
|
264 | hypothesis.HealthCheck.too_slow, | |
263 | ] |
|
265 | ] | |
264 | ) |
|
266 | ) | |
265 | @hypothesis.given( |
|
267 | @hypothesis.given( | |
266 | original=strategies.sampled_from(random_input_data()), |
|
268 | original=strategies.sampled_from(random_input_data()), | |
267 | level=strategies.integers(min_value=1, max_value=5), |
|
269 | level=strategies.integers(min_value=1, max_value=5), | |
268 | streaming=strategies.booleans(), |
|
270 | streaming=strategies.booleans(), | |
269 | source_read_size=strategies.integers(1, 1048576), |
|
271 | source_read_size=strategies.integers(1, 1048576), | |
270 | read_sizes=strategies.data(), |
|
272 | read_sizes=strategies.data(), | |
271 | ) |
|
273 | ) | |
272 | def test_stream_source_readinto1_variance( |
|
274 | def test_stream_source_readinto1_variance( | |
273 | self, original, level, streaming, source_read_size, read_sizes |
|
275 | self, original, level, streaming, source_read_size, read_sizes | |
274 | ): |
|
276 | ): | |
275 | cctx = zstd.ZstdCompressor(level=level) |
|
277 | cctx = zstd.ZstdCompressor(level=level) | |
276 |
|
278 | |||
277 | if streaming: |
|
279 | if streaming: | |
278 | source = io.BytesIO() |
|
280 | source = io.BytesIO() | |
279 | writer = cctx.stream_writer(source) |
|
281 | writer = cctx.stream_writer(source) | |
280 | writer.write(original) |
|
282 | writer.write(original) | |
281 | writer.flush(zstd.FLUSH_FRAME) |
|
283 | writer.flush(zstd.FLUSH_FRAME) | |
282 | source.seek(0) |
|
284 | source.seek(0) | |
283 | else: |
|
285 | else: | |
284 | frame = cctx.compress(original) |
|
286 | frame = cctx.compress(original) | |
285 | source = io.BytesIO(frame) |
|
287 | source = io.BytesIO(frame) | |
286 |
|
288 | |||
287 | dctx = zstd.ZstdDecompressor() |
|
289 | dctx = zstd.ZstdDecompressor() | |
288 |
|
290 | |||
289 | chunks = [] |
|
291 | chunks = [] | |
290 | with dctx.stream_reader(source, read_size=source_read_size) as reader: |
|
292 | with dctx.stream_reader(source, read_size=source_read_size) as reader: | |
291 | while True: |
|
293 | while True: | |
292 | read_size = read_sizes.draw(strategies.integers(1, 131072)) |
|
294 | read_size = read_sizes.draw(strategies.integers(1, 131072)) | |
293 | b = bytearray(read_size) |
|
295 | b = bytearray(read_size) | |
294 | count = reader.readinto1(b) |
|
296 | count = reader.readinto1(b) | |
295 |
|
297 | |||
296 | if not count: |
|
298 | if not count: | |
297 | break |
|
299 | break | |
298 |
|
300 | |||
299 | chunks.append(bytes(b[0:count])) |
|
301 | chunks.append(bytes(b[0:count])) | |
300 |
|
302 | |||
301 | self.assertEqual(b"".join(chunks), original) |
|
303 | self.assertEqual(b"".join(chunks), original) | |
302 |
|
304 | |||
303 | @hypothesis.settings( |
|
305 | @hypothesis.settings( | |
304 | suppress_health_check=[ |
|
306 | suppress_health_check=[ | |
305 | hypothesis.HealthCheck.large_base_example, |
|
307 | hypothesis.HealthCheck.large_base_example, | |
306 | hypothesis.HealthCheck.too_slow, |
|
308 | hypothesis.HealthCheck.too_slow, | |
307 | ] |
|
309 | ] | |
308 | ) |
|
310 | ) | |
309 | @hypothesis.given( |
|
311 | @hypothesis.given( | |
310 | original=strategies.sampled_from(random_input_data()), |
|
312 | original=strategies.sampled_from(random_input_data()), | |
311 | level=strategies.integers(min_value=1, max_value=5), |
|
313 | level=strategies.integers(min_value=1, max_value=5), | |
312 | source_read_size=strategies.integers(1, 1048576), |
|
314 | source_read_size=strategies.integers(1, 1048576), | |
313 | seek_amounts=strategies.data(), |
|
315 | seek_amounts=strategies.data(), | |
314 | read_sizes=strategies.data(), |
|
316 | read_sizes=strategies.data(), | |
315 | ) |
|
317 | ) | |
316 | def test_relative_seeks( |
|
318 | def test_relative_seeks( | |
317 | self, original, level, source_read_size, seek_amounts, read_sizes |
|
319 | self, original, level, source_read_size, seek_amounts, read_sizes | |
318 | ): |
|
320 | ): | |
319 | cctx = zstd.ZstdCompressor(level=level) |
|
321 | cctx = zstd.ZstdCompressor(level=level) | |
320 | frame = cctx.compress(original) |
|
322 | frame = cctx.compress(original) | |
321 |
|
323 | |||
322 | dctx = zstd.ZstdDecompressor() |
|
324 | dctx = zstd.ZstdDecompressor() | |
323 |
|
325 | |||
324 | with dctx.stream_reader(frame, read_size=source_read_size) as reader: |
|
326 | with dctx.stream_reader(frame, read_size=source_read_size) as reader: | |
325 | while True: |
|
327 | while True: | |
326 | amount = seek_amounts.draw(strategies.integers(0, 16384)) |
|
328 | amount = seek_amounts.draw(strategies.integers(0, 16384)) | |
327 | reader.seek(amount, os.SEEK_CUR) |
|
329 | reader.seek(amount, os.SEEK_CUR) | |
328 |
|
330 | |||
329 | offset = reader.tell() |
|
331 | offset = reader.tell() | |
330 | read_amount = read_sizes.draw(strategies.integers(1, 16384)) |
|
332 | read_amount = read_sizes.draw(strategies.integers(1, 16384)) | |
331 | chunk = reader.read(read_amount) |
|
333 | chunk = reader.read(read_amount) | |
332 |
|
334 | |||
333 | if not chunk: |
|
335 | if not chunk: | |
334 | break |
|
336 | break | |
335 |
|
337 | |||
336 | self.assertEqual(original[offset : offset + len(chunk)], chunk) |
|
338 | self.assertEqual(original[offset : offset + len(chunk)], chunk) | |
337 |
|
339 | |||
338 | @hypothesis.settings( |
|
340 | @hypothesis.settings( | |
339 | suppress_health_check=[ |
|
341 | suppress_health_check=[ | |
340 | hypothesis.HealthCheck.large_base_example, |
|
342 | hypothesis.HealthCheck.large_base_example, | |
341 | hypothesis.HealthCheck.too_slow, |
|
343 | hypothesis.HealthCheck.too_slow, | |
342 | ] |
|
344 | ] | |
343 | ) |
|
345 | ) | |
344 | @hypothesis.given( |
|
346 | @hypothesis.given( | |
345 | originals=strategies.data(), |
|
347 | originals=strategies.data(), | |
346 | frame_count=strategies.integers(min_value=2, max_value=10), |
|
348 | frame_count=strategies.integers(min_value=2, max_value=10), | |
347 | level=strategies.integers(min_value=1, max_value=5), |
|
349 | level=strategies.integers(min_value=1, max_value=5), | |
348 | source_read_size=strategies.integers(1, 1048576), |
|
350 | source_read_size=strategies.integers(1, 1048576), | |
349 | read_sizes=strategies.data(), |
|
351 | read_sizes=strategies.data(), | |
350 | ) |
|
352 | ) | |
351 | def test_multiple_frames( |
|
353 | def test_multiple_frames( | |
352 | self, originals, frame_count, level, source_read_size, read_sizes |
|
354 | self, originals, frame_count, level, source_read_size, read_sizes | |
353 | ): |
|
355 | ): | |
354 |
|
356 | |||
355 | cctx = zstd.ZstdCompressor(level=level) |
|
357 | cctx = zstd.ZstdCompressor(level=level) | |
356 | source = io.BytesIO() |
|
358 | source = io.BytesIO() | |
357 | buffer = io.BytesIO() |
|
359 | buffer = io.BytesIO() | |
358 | writer = cctx.stream_writer(buffer) |
|
360 | writer = cctx.stream_writer(buffer) | |
359 |
|
361 | |||
360 | for i in range(frame_count): |
|
362 | for i in range(frame_count): | |
361 | data = originals.draw(strategies.sampled_from(random_input_data())) |
|
363 | data = originals.draw(strategies.sampled_from(random_input_data())) | |
362 | source.write(data) |
|
364 | source.write(data) | |
363 | writer.write(data) |
|
365 | writer.write(data) | |
364 | writer.flush(zstd.FLUSH_FRAME) |
|
366 | writer.flush(zstd.FLUSH_FRAME) | |
365 |
|
367 | |||
366 | dctx = zstd.ZstdDecompressor() |
|
368 | dctx = zstd.ZstdDecompressor() | |
367 | buffer.seek(0) |
|
369 | buffer.seek(0) | |
368 | reader = dctx.stream_reader( |
|
370 | reader = dctx.stream_reader( | |
369 | buffer, read_size=source_read_size, read_across_frames=True |
|
371 | buffer, read_size=source_read_size, read_across_frames=True | |
370 | ) |
|
372 | ) | |
371 |
|
373 | |||
372 | chunks = [] |
|
374 | chunks = [] | |
373 |
|
375 | |||
374 | while True: |
|
376 | while True: | |
375 | read_amount = read_sizes.draw(strategies.integers(-1, 16384)) |
|
377 | read_amount = read_sizes.draw(strategies.integers(-1, 16384)) | |
376 | chunk = reader.read(read_amount) |
|
378 | chunk = reader.read(read_amount) | |
377 |
|
379 | |||
378 | if not chunk and read_amount: |
|
380 | if not chunk and read_amount: | |
379 | break |
|
381 | break | |
380 |
|
382 | |||
381 | chunks.append(chunk) |
|
383 | chunks.append(chunk) | |
382 |
|
384 | |||
383 | self.assertEqual(source.getvalue(), b"".join(chunks)) |
|
385 | self.assertEqual(source.getvalue(), b"".join(chunks)) | |
384 |
|
386 | |||
385 |
|
387 | |||
386 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
388 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
387 | @make_cffi |
|
389 | @make_cffi | |
388 | class TestDecompressor_stream_writer_fuzzing(TestCase): |
|
390 | class TestDecompressor_stream_writer_fuzzing(TestCase): | |
389 | @hypothesis.settings( |
|
391 | @hypothesis.settings( | |
390 | suppress_health_check=[ |
|
392 | suppress_health_check=[ | |
391 | hypothesis.HealthCheck.large_base_example, |
|
393 | hypothesis.HealthCheck.large_base_example, | |
392 | hypothesis.HealthCheck.too_slow, |
|
394 | hypothesis.HealthCheck.too_slow, | |
393 | ] |
|
395 | ] | |
394 | ) |
|
396 | ) | |
395 | @hypothesis.given( |
|
397 | @hypothesis.given( | |
396 | original=strategies.sampled_from(random_input_data()), |
|
398 | original=strategies.sampled_from(random_input_data()), | |
397 | level=strategies.integers(min_value=1, max_value=5), |
|
399 | level=strategies.integers(min_value=1, max_value=5), | |
398 | write_size=strategies.integers(min_value=1, max_value=8192), |
|
400 | write_size=strategies.integers(min_value=1, max_value=8192), | |
399 | input_sizes=strategies.data(), |
|
401 | input_sizes=strategies.data(), | |
400 | ) |
|
402 | ) | |
401 | def test_write_size_variance(self, original, level, write_size, input_sizes): |
|
403 | def test_write_size_variance( | |
|
404 | self, original, level, write_size, input_sizes | |||
|
405 | ): | |||
402 | cctx = zstd.ZstdCompressor(level=level) |
|
406 | cctx = zstd.ZstdCompressor(level=level) | |
403 | frame = cctx.compress(original) |
|
407 | frame = cctx.compress(original) | |
404 |
|
408 | |||
405 | dctx = zstd.ZstdDecompressor() |
|
409 | dctx = zstd.ZstdDecompressor() | |
406 | source = io.BytesIO(frame) |
|
410 | source = io.BytesIO(frame) | |
407 | dest = NonClosingBytesIO() |
|
411 | dest = NonClosingBytesIO() | |
408 |
|
412 | |||
409 | with dctx.stream_writer(dest, write_size=write_size) as decompressor: |
|
413 | with dctx.stream_writer(dest, write_size=write_size) as decompressor: | |
410 | while True: |
|
414 | while True: | |
411 | input_size = input_sizes.draw(strategies.integers(1, 4096)) |
|
415 | input_size = input_sizes.draw(strategies.integers(1, 4096)) | |
412 | chunk = source.read(input_size) |
|
416 | chunk = source.read(input_size) | |
413 | if not chunk: |
|
417 | if not chunk: | |
414 | break |
|
418 | break | |
415 |
|
419 | |||
416 | decompressor.write(chunk) |
|
420 | decompressor.write(chunk) | |
417 |
|
421 | |||
418 | self.assertEqual(dest.getvalue(), original) |
|
422 | self.assertEqual(dest.getvalue(), original) | |
419 |
|
423 | |||
420 |
|
424 | |||
421 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
425 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
422 | @make_cffi |
|
426 | @make_cffi | |
423 | class TestDecompressor_copy_stream_fuzzing(TestCase): |
|
427 | class TestDecompressor_copy_stream_fuzzing(TestCase): | |
424 | @hypothesis.settings( |
|
428 | @hypothesis.settings( | |
425 | suppress_health_check=[ |
|
429 | suppress_health_check=[ | |
426 | hypothesis.HealthCheck.large_base_example, |
|
430 | hypothesis.HealthCheck.large_base_example, | |
427 | hypothesis.HealthCheck.too_slow, |
|
431 | hypothesis.HealthCheck.too_slow, | |
428 | ] |
|
432 | ] | |
429 | ) |
|
433 | ) | |
430 | @hypothesis.given( |
|
434 | @hypothesis.given( | |
431 | original=strategies.sampled_from(random_input_data()), |
|
435 | original=strategies.sampled_from(random_input_data()), | |
432 | level=strategies.integers(min_value=1, max_value=5), |
|
436 | level=strategies.integers(min_value=1, max_value=5), | |
433 | read_size=strategies.integers(min_value=1, max_value=8192), |
|
437 | read_size=strategies.integers(min_value=1, max_value=8192), | |
434 | write_size=strategies.integers(min_value=1, max_value=8192), |
|
438 | write_size=strategies.integers(min_value=1, max_value=8192), | |
435 | ) |
|
439 | ) | |
436 |
def test_read_write_size_variance( |
|
440 | def test_read_write_size_variance( | |
|
441 | self, original, level, read_size, write_size | |||
|
442 | ): | |||
437 | cctx = zstd.ZstdCompressor(level=level) |
|
443 | cctx = zstd.ZstdCompressor(level=level) | |
438 | frame = cctx.compress(original) |
|
444 | frame = cctx.compress(original) | |
439 |
|
445 | |||
440 | source = io.BytesIO(frame) |
|
446 | source = io.BytesIO(frame) | |
441 | dest = io.BytesIO() |
|
447 | dest = io.BytesIO() | |
442 |
|
448 | |||
443 | dctx = zstd.ZstdDecompressor() |
|
449 | dctx = zstd.ZstdDecompressor() | |
444 | dctx.copy_stream(source, dest, read_size=read_size, write_size=write_size) |
|
450 | dctx.copy_stream( | |
|
451 | source, dest, read_size=read_size, write_size=write_size | |||
|
452 | ) | |||
445 |
|
453 | |||
446 | self.assertEqual(dest.getvalue(), original) |
|
454 | self.assertEqual(dest.getvalue(), original) | |
447 |
|
455 | |||
448 |
|
456 | |||
449 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
457 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
450 | @make_cffi |
|
458 | @make_cffi | |
451 | class TestDecompressor_decompressobj_fuzzing(TestCase): |
|
459 | class TestDecompressor_decompressobj_fuzzing(TestCase): | |
452 | @hypothesis.settings( |
|
460 | @hypothesis.settings( | |
453 | suppress_health_check=[ |
|
461 | suppress_health_check=[ | |
454 | hypothesis.HealthCheck.large_base_example, |
|
462 | hypothesis.HealthCheck.large_base_example, | |
455 | hypothesis.HealthCheck.too_slow, |
|
463 | hypothesis.HealthCheck.too_slow, | |
456 | ] |
|
464 | ] | |
457 | ) |
|
465 | ) | |
458 | @hypothesis.given( |
|
466 | @hypothesis.given( | |
459 | original=strategies.sampled_from(random_input_data()), |
|
467 | original=strategies.sampled_from(random_input_data()), | |
460 | level=strategies.integers(min_value=1, max_value=5), |
|
468 | level=strategies.integers(min_value=1, max_value=5), | |
461 | chunk_sizes=strategies.data(), |
|
469 | chunk_sizes=strategies.data(), | |
462 | ) |
|
470 | ) | |
463 | def test_random_input_sizes(self, original, level, chunk_sizes): |
|
471 | def test_random_input_sizes(self, original, level, chunk_sizes): | |
464 | cctx = zstd.ZstdCompressor(level=level) |
|
472 | cctx = zstd.ZstdCompressor(level=level) | |
465 | frame = cctx.compress(original) |
|
473 | frame = cctx.compress(original) | |
466 |
|
474 | |||
467 | source = io.BytesIO(frame) |
|
475 | source = io.BytesIO(frame) | |
468 |
|
476 | |||
469 | dctx = zstd.ZstdDecompressor() |
|
477 | dctx = zstd.ZstdDecompressor() | |
470 | dobj = dctx.decompressobj() |
|
478 | dobj = dctx.decompressobj() | |
471 |
|
479 | |||
472 | chunks = [] |
|
480 | chunks = [] | |
473 | while True: |
|
481 | while True: | |
474 | chunk_size = chunk_sizes.draw(strategies.integers(1, 4096)) |
|
482 | chunk_size = chunk_sizes.draw(strategies.integers(1, 4096)) | |
475 | chunk = source.read(chunk_size) |
|
483 | chunk = source.read(chunk_size) | |
476 | if not chunk: |
|
484 | if not chunk: | |
477 | break |
|
485 | break | |
478 |
|
486 | |||
479 | chunks.append(dobj.decompress(chunk)) |
|
487 | chunks.append(dobj.decompress(chunk)) | |
480 |
|
488 | |||
481 | self.assertEqual(b"".join(chunks), original) |
|
489 | self.assertEqual(b"".join(chunks), original) | |
482 |
|
490 | |||
483 | @hypothesis.settings( |
|
491 | @hypothesis.settings( | |
484 | suppress_health_check=[ |
|
492 | suppress_health_check=[ | |
485 | hypothesis.HealthCheck.large_base_example, |
|
493 | hypothesis.HealthCheck.large_base_example, | |
486 | hypothesis.HealthCheck.too_slow, |
|
494 | hypothesis.HealthCheck.too_slow, | |
487 | ] |
|
495 | ] | |
488 | ) |
|
496 | ) | |
489 | @hypothesis.given( |
|
497 | @hypothesis.given( | |
490 | original=strategies.sampled_from(random_input_data()), |
|
498 | original=strategies.sampled_from(random_input_data()), | |
491 | level=strategies.integers(min_value=1, max_value=5), |
|
499 | level=strategies.integers(min_value=1, max_value=5), | |
492 | write_size=strategies.integers( |
|
500 | write_size=strategies.integers( | |
493 | min_value=1, max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE |
|
501 | min_value=1, | |
|
502 | max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |||
494 | ), |
|
503 | ), | |
495 | chunk_sizes=strategies.data(), |
|
504 | chunk_sizes=strategies.data(), | |
496 | ) |
|
505 | ) | |
497 | def test_random_output_sizes(self, original, level, write_size, chunk_sizes): |
|
506 | def test_random_output_sizes( | |
|
507 | self, original, level, write_size, chunk_sizes | |||
|
508 | ): | |||
498 | cctx = zstd.ZstdCompressor(level=level) |
|
509 | cctx = zstd.ZstdCompressor(level=level) | |
499 | frame = cctx.compress(original) |
|
510 | frame = cctx.compress(original) | |
500 |
|
511 | |||
501 | source = io.BytesIO(frame) |
|
512 | source = io.BytesIO(frame) | |
502 |
|
513 | |||
503 | dctx = zstd.ZstdDecompressor() |
|
514 | dctx = zstd.ZstdDecompressor() | |
504 | dobj = dctx.decompressobj(write_size=write_size) |
|
515 | dobj = dctx.decompressobj(write_size=write_size) | |
505 |
|
516 | |||
506 | chunks = [] |
|
517 | chunks = [] | |
507 | while True: |
|
518 | while True: | |
508 | chunk_size = chunk_sizes.draw(strategies.integers(1, 4096)) |
|
519 | chunk_size = chunk_sizes.draw(strategies.integers(1, 4096)) | |
509 | chunk = source.read(chunk_size) |
|
520 | chunk = source.read(chunk_size) | |
510 | if not chunk: |
|
521 | if not chunk: | |
511 | break |
|
522 | break | |
512 |
|
523 | |||
513 | chunks.append(dobj.decompress(chunk)) |
|
524 | chunks.append(dobj.decompress(chunk)) | |
514 |
|
525 | |||
515 | self.assertEqual(b"".join(chunks), original) |
|
526 | self.assertEqual(b"".join(chunks), original) | |
516 |
|
527 | |||
517 |
|
528 | |||
518 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
529 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
519 | @make_cffi |
|
530 | @make_cffi | |
520 | class TestDecompressor_read_to_iter_fuzzing(TestCase): |
|
531 | class TestDecompressor_read_to_iter_fuzzing(TestCase): | |
521 | @hypothesis.given( |
|
532 | @hypothesis.given( | |
522 | original=strategies.sampled_from(random_input_data()), |
|
533 | original=strategies.sampled_from(random_input_data()), | |
523 | level=strategies.integers(min_value=1, max_value=5), |
|
534 | level=strategies.integers(min_value=1, max_value=5), | |
524 | read_size=strategies.integers(min_value=1, max_value=4096), |
|
535 | read_size=strategies.integers(min_value=1, max_value=4096), | |
525 | write_size=strategies.integers(min_value=1, max_value=4096), |
|
536 | write_size=strategies.integers(min_value=1, max_value=4096), | |
526 | ) |
|
537 | ) | |
527 |
def test_read_write_size_variance( |
|
538 | def test_read_write_size_variance( | |
|
539 | self, original, level, read_size, write_size | |||
|
540 | ): | |||
528 | cctx = zstd.ZstdCompressor(level=level) |
|
541 | cctx = zstd.ZstdCompressor(level=level) | |
529 | frame = cctx.compress(original) |
|
542 | frame = cctx.compress(original) | |
530 |
|
543 | |||
531 | source = io.BytesIO(frame) |
|
544 | source = io.BytesIO(frame) | |
532 |
|
545 | |||
533 | dctx = zstd.ZstdDecompressor() |
|
546 | dctx = zstd.ZstdDecompressor() | |
534 | chunks = list( |
|
547 | chunks = list( | |
535 | dctx.read_to_iter(source, read_size=read_size, write_size=write_size) |
|
548 | dctx.read_to_iter( | |
|
549 | source, read_size=read_size, write_size=write_size | |||
|
550 | ) | |||
536 | ) |
|
551 | ) | |
537 |
|
552 | |||
538 | self.assertEqual(b"".join(chunks), original) |
|
553 | self.assertEqual(b"".join(chunks), original) | |
539 |
|
554 | |||
540 |
|
555 | |||
541 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") |
|
556 | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | |
542 | class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase): |
|
557 | class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase): | |
543 | @hypothesis.given( |
|
558 | @hypothesis.given( | |
544 | original=strategies.lists( |
|
559 | original=strategies.lists( | |
545 |
strategies.sampled_from(random_input_data()), |
|
560 | strategies.sampled_from(random_input_data()), | |
|
561 | min_size=1, | |||
|
562 | max_size=1024, | |||
546 | ), |
|
563 | ), | |
547 | threads=strategies.integers(min_value=1, max_value=8), |
|
564 | threads=strategies.integers(min_value=1, max_value=8), | |
548 | use_dict=strategies.booleans(), |
|
565 | use_dict=strategies.booleans(), | |
549 | ) |
|
566 | ) | |
550 | def test_data_equivalence(self, original, threads, use_dict): |
|
567 | def test_data_equivalence(self, original, threads, use_dict): | |
551 | kwargs = {} |
|
568 | kwargs = {} | |
552 | if use_dict: |
|
569 | if use_dict: | |
553 | kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) |
|
570 | kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) | |
554 |
|
571 | |||
555 | cctx = zstd.ZstdCompressor( |
|
572 | cctx = zstd.ZstdCompressor( | |
556 | level=1, write_content_size=True, write_checksum=True, **kwargs |
|
573 | level=1, write_content_size=True, write_checksum=True, **kwargs | |
557 | ) |
|
574 | ) | |
558 |
|
575 | |||
559 | if not hasattr(cctx, "multi_compress_to_buffer"): |
|
576 | if not hasattr(cctx, "multi_compress_to_buffer"): | |
560 | self.skipTest("multi_compress_to_buffer not available") |
|
577 | self.skipTest("multi_compress_to_buffer not available") | |
561 |
|
578 | |||
562 | frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1) |
|
579 | frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1) | |
563 |
|
580 | |||
564 | dctx = zstd.ZstdDecompressor(**kwargs) |
|
581 | dctx = zstd.ZstdDecompressor(**kwargs) | |
565 | result = dctx.multi_decompress_to_buffer(frames_buffer) |
|
582 | result = dctx.multi_decompress_to_buffer(frames_buffer) | |
566 |
|
583 | |||
567 | self.assertEqual(len(result), len(original)) |
|
584 | self.assertEqual(len(result), len(original)) | |
568 | for i, frame in enumerate(result): |
|
585 | for i, frame in enumerate(result): | |
569 | self.assertEqual(frame.tobytes(), original[i]) |
|
586 | self.assertEqual(frame.tobytes(), original[i]) | |
570 |
|
587 | |||
571 | frames_list = [f.tobytes() for f in frames_buffer] |
|
588 | frames_list = [f.tobytes() for f in frames_buffer] | |
572 | result = dctx.multi_decompress_to_buffer(frames_list) |
|
589 | result = dctx.multi_decompress_to_buffer(frames_list) | |
573 |
|
590 | |||
574 | self.assertEqual(len(result), len(original)) |
|
591 | self.assertEqual(len(result), len(original)) | |
575 | for i, frame in enumerate(result): |
|
592 | for i, frame in enumerate(result): | |
576 | self.assertEqual(frame.tobytes(), original[i]) |
|
593 | self.assertEqual(frame.tobytes(), original[i]) |
@@ -1,92 +1,102 b'' | |||||
1 | import struct |
|
1 | import struct | |
2 | import sys |
|
2 | import sys | |
3 | import unittest |
|
3 | import unittest | |
4 |
|
4 | |||
5 | import zstandard as zstd |
|
5 | import zstandard as zstd | |
6 |
|
6 | |||
7 | from .common import ( |
|
7 | from .common import ( | |
8 | generate_samples, |
|
8 | generate_samples, | |
9 | make_cffi, |
|
9 | make_cffi, | |
10 | random_input_data, |
|
10 | random_input_data, | |
11 | TestCase, |
|
11 | TestCase, | |
12 | ) |
|
12 | ) | |
13 |
|
13 | |||
14 | if sys.version_info[0] >= 3: |
|
14 | if sys.version_info[0] >= 3: | |
15 | int_type = int |
|
15 | int_type = int | |
16 | else: |
|
16 | else: | |
17 | int_type = long |
|
17 | int_type = long | |
18 |
|
18 | |||
19 |
|
19 | |||
20 | @make_cffi |
|
20 | @make_cffi | |
21 | class TestTrainDictionary(TestCase): |
|
21 | class TestTrainDictionary(TestCase): | |
22 | def test_no_args(self): |
|
22 | def test_no_args(self): | |
23 | with self.assertRaises(TypeError): |
|
23 | with self.assertRaises(TypeError): | |
24 | zstd.train_dictionary() |
|
24 | zstd.train_dictionary() | |
25 |
|
25 | |||
26 | def test_bad_args(self): |
|
26 | def test_bad_args(self): | |
27 | with self.assertRaises(TypeError): |
|
27 | with self.assertRaises(TypeError): | |
28 | zstd.train_dictionary(8192, u"foo") |
|
28 | zstd.train_dictionary(8192, u"foo") | |
29 |
|
29 | |||
30 | with self.assertRaises(ValueError): |
|
30 | with self.assertRaises(ValueError): | |
31 | zstd.train_dictionary(8192, [u"foo"]) |
|
31 | zstd.train_dictionary(8192, [u"foo"]) | |
32 |
|
32 | |||
33 | def test_no_params(self): |
|
33 | def test_no_params(self): | |
34 | d = zstd.train_dictionary(8192, random_input_data()) |
|
34 | d = zstd.train_dictionary(8192, random_input_data()) | |
35 | self.assertIsInstance(d.dict_id(), int_type) |
|
35 | self.assertIsInstance(d.dict_id(), int_type) | |
36 |
|
36 | |||
37 | # The dictionary ID may be different across platforms. |
|
37 | # The dictionary ID may be different across platforms. | |
38 | expected = b"\x37\xa4\x30\xec" + struct.pack("<I", d.dict_id()) |
|
38 | expected = b"\x37\xa4\x30\xec" + struct.pack("<I", d.dict_id()) | |
39 |
|
39 | |||
40 | data = d.as_bytes() |
|
40 | data = d.as_bytes() | |
41 | self.assertEqual(data[0:8], expected) |
|
41 | self.assertEqual(data[0:8], expected) | |
42 |
|
42 | |||
43 | def test_basic(self): |
|
43 | def test_basic(self): | |
44 | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) |
|
44 | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) | |
45 | self.assertIsInstance(d.dict_id(), int_type) |
|
45 | self.assertIsInstance(d.dict_id(), int_type) | |
46 |
|
46 | |||
47 | data = d.as_bytes() |
|
47 | data = d.as_bytes() | |
48 | self.assertEqual(data[0:4], b"\x37\xa4\x30\xec") |
|
48 | self.assertEqual(data[0:4], b"\x37\xa4\x30\xec") | |
49 |
|
49 | |||
50 | self.assertEqual(d.k, 64) |
|
50 | self.assertEqual(d.k, 64) | |
51 | self.assertEqual(d.d, 16) |
|
51 | self.assertEqual(d.d, 16) | |
52 |
|
52 | |||
53 | def test_set_dict_id(self): |
|
53 | def test_set_dict_id(self): | |
54 | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16, dict_id=42) |
|
54 | d = zstd.train_dictionary( | |
|
55 | 8192, generate_samples(), k=64, d=16, dict_id=42 | |||
|
56 | ) | |||
55 | self.assertEqual(d.dict_id(), 42) |
|
57 | self.assertEqual(d.dict_id(), 42) | |
56 |
|
58 | |||
57 | def test_optimize(self): |
|
59 | def test_optimize(self): | |
58 | d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1, d=16) |
|
60 | d = zstd.train_dictionary( | |
|
61 | 8192, generate_samples(), threads=-1, steps=1, d=16 | |||
|
62 | ) | |||
59 |
|
63 | |||
60 | # This varies by platform. |
|
64 | # This varies by platform. | |
61 | self.assertIn(d.k, (50, 2000)) |
|
65 | self.assertIn(d.k, (50, 2000)) | |
62 | self.assertEqual(d.d, 16) |
|
66 | self.assertEqual(d.d, 16) | |
63 |
|
67 | |||
64 |
|
68 | |||
65 | @make_cffi |
|
69 | @make_cffi | |
66 | class TestCompressionDict(TestCase): |
|
70 | class TestCompressionDict(TestCase): | |
67 | def test_bad_mode(self): |
|
71 | def test_bad_mode(self): | |
68 | with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"): |
|
72 | with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"): | |
69 | zstd.ZstdCompressionDict(b"foo", dict_type=42) |
|
73 | zstd.ZstdCompressionDict(b"foo", dict_type=42) | |
70 |
|
74 | |||
71 | def test_bad_precompute_compress(self): |
|
75 | def test_bad_precompute_compress(self): | |
72 | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) |
|
76 | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) | |
73 |
|
77 | |||
74 |
with self.assertRaisesRegex( |
|
78 | with self.assertRaisesRegex( | |
|
79 | ValueError, "must specify one of level or " | |||
|
80 | ): | |||
75 | d.precompute_compress() |
|
81 | d.precompute_compress() | |
76 |
|
82 | |||
77 | with self.assertRaisesRegex(ValueError, "must only specify one of level or "): |
|
83 | with self.assertRaisesRegex( | |
|
84 | ValueError, "must only specify one of level or " | |||
|
85 | ): | |||
78 | d.precompute_compress( |
|
86 | d.precompute_compress( | |
79 | level=3, compression_params=zstd.CompressionParameters() |
|
87 | level=3, compression_params=zstd.CompressionParameters() | |
80 | ) |
|
88 | ) | |
81 |
|
89 | |||
82 | def test_precompute_compress_rawcontent(self): |
|
90 | def test_precompute_compress_rawcontent(self): | |
83 | d = zstd.ZstdCompressionDict( |
|
91 | d = zstd.ZstdCompressionDict( | |
84 | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT |
|
92 | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT | |
85 | ) |
|
93 | ) | |
86 | d.precompute_compress(level=1) |
|
94 | d.precompute_compress(level=1) | |
87 |
|
95 | |||
88 | d = zstd.ZstdCompressionDict( |
|
96 | d = zstd.ZstdCompressionDict( | |
89 | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT |
|
97 | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT | |
90 | ) |
|
98 | ) | |
91 | with self.assertRaisesRegex(zstd.ZstdError, "unable to precompute dictionary"): |
|
99 | with self.assertRaisesRegex( | |
|
100 | zstd.ZstdError, "unable to precompute dictionary" | |||
|
101 | ): | |||
92 | d.precompute_compress(level=1) |
|
102 | d.precompute_compress(level=1) |
@@ -1,2615 +1,2769 b'' | |||||
1 | # Copyright (c) 2016-present, Gregory Szorc |
|
1 | # Copyright (c) 2016-present, Gregory Szorc | |
2 | # All rights reserved. |
|
2 | # All rights reserved. | |
3 | # |
|
3 | # | |
4 | # This software may be modified and distributed under the terms |
|
4 | # This software may be modified and distributed under the terms | |
5 | # of the BSD license. See the LICENSE file for details. |
|
5 | # of the BSD license. See the LICENSE file for details. | |
6 |
|
6 | |||
7 | """Python interface to the Zstandard (zstd) compression library.""" |
|
7 | """Python interface to the Zstandard (zstd) compression library.""" | |
8 |
|
8 | |||
9 | from __future__ import absolute_import, unicode_literals |
|
9 | from __future__ import absolute_import, unicode_literals | |
10 |
|
10 | |||
11 | # This should match what the C extension exports. |
|
11 | # This should match what the C extension exports. | |
12 | __all__ = [ |
|
12 | __all__ = [ | |
13 | #'BufferSegment', |
|
13 | #'BufferSegment', | |
14 | #'BufferSegments', |
|
14 | #'BufferSegments', | |
15 | #'BufferWithSegments', |
|
15 | #'BufferWithSegments', | |
16 | #'BufferWithSegmentsCollection', |
|
16 | #'BufferWithSegmentsCollection', | |
17 | "CompressionParameters", |
|
17 | "CompressionParameters", | |
18 | "ZstdCompressionDict", |
|
18 | "ZstdCompressionDict", | |
19 | "ZstdCompressionParameters", |
|
19 | "ZstdCompressionParameters", | |
20 | "ZstdCompressor", |
|
20 | "ZstdCompressor", | |
21 | "ZstdError", |
|
21 | "ZstdError", | |
22 | "ZstdDecompressor", |
|
22 | "ZstdDecompressor", | |
23 | "FrameParameters", |
|
23 | "FrameParameters", | |
24 | "estimate_decompression_context_size", |
|
24 | "estimate_decompression_context_size", | |
25 | "frame_content_size", |
|
25 | "frame_content_size", | |
26 | "frame_header_size", |
|
26 | "frame_header_size", | |
27 | "get_frame_parameters", |
|
27 | "get_frame_parameters", | |
28 | "train_dictionary", |
|
28 | "train_dictionary", | |
29 | # Constants. |
|
29 | # Constants. | |
30 | "FLUSH_BLOCK", |
|
30 | "FLUSH_BLOCK", | |
31 | "FLUSH_FRAME", |
|
31 | "FLUSH_FRAME", | |
32 | "COMPRESSOBJ_FLUSH_FINISH", |
|
32 | "COMPRESSOBJ_FLUSH_FINISH", | |
33 | "COMPRESSOBJ_FLUSH_BLOCK", |
|
33 | "COMPRESSOBJ_FLUSH_BLOCK", | |
34 | "ZSTD_VERSION", |
|
34 | "ZSTD_VERSION", | |
35 | "FRAME_HEADER", |
|
35 | "FRAME_HEADER", | |
36 | "CONTENTSIZE_UNKNOWN", |
|
36 | "CONTENTSIZE_UNKNOWN", | |
37 | "CONTENTSIZE_ERROR", |
|
37 | "CONTENTSIZE_ERROR", | |
38 | "MAX_COMPRESSION_LEVEL", |
|
38 | "MAX_COMPRESSION_LEVEL", | |
39 | "COMPRESSION_RECOMMENDED_INPUT_SIZE", |
|
39 | "COMPRESSION_RECOMMENDED_INPUT_SIZE", | |
40 | "COMPRESSION_RECOMMENDED_OUTPUT_SIZE", |
|
40 | "COMPRESSION_RECOMMENDED_OUTPUT_SIZE", | |
41 | "DECOMPRESSION_RECOMMENDED_INPUT_SIZE", |
|
41 | "DECOMPRESSION_RECOMMENDED_INPUT_SIZE", | |
42 | "DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE", |
|
42 | "DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE", | |
43 | "MAGIC_NUMBER", |
|
43 | "MAGIC_NUMBER", | |
44 | "BLOCKSIZELOG_MAX", |
|
44 | "BLOCKSIZELOG_MAX", | |
45 | "BLOCKSIZE_MAX", |
|
45 | "BLOCKSIZE_MAX", | |
46 | "WINDOWLOG_MIN", |
|
46 | "WINDOWLOG_MIN", | |
47 | "WINDOWLOG_MAX", |
|
47 | "WINDOWLOG_MAX", | |
48 | "CHAINLOG_MIN", |
|
48 | "CHAINLOG_MIN", | |
49 | "CHAINLOG_MAX", |
|
49 | "CHAINLOG_MAX", | |
50 | "HASHLOG_MIN", |
|
50 | "HASHLOG_MIN", | |
51 | "HASHLOG_MAX", |
|
51 | "HASHLOG_MAX", | |
52 | "HASHLOG3_MAX", |
|
52 | "HASHLOG3_MAX", | |
53 | "MINMATCH_MIN", |
|
53 | "MINMATCH_MIN", | |
54 | "MINMATCH_MAX", |
|
54 | "MINMATCH_MAX", | |
55 | "SEARCHLOG_MIN", |
|
55 | "SEARCHLOG_MIN", | |
56 | "SEARCHLOG_MAX", |
|
56 | "SEARCHLOG_MAX", | |
57 | "SEARCHLENGTH_MIN", |
|
57 | "SEARCHLENGTH_MIN", | |
58 | "SEARCHLENGTH_MAX", |
|
58 | "SEARCHLENGTH_MAX", | |
59 | "TARGETLENGTH_MIN", |
|
59 | "TARGETLENGTH_MIN", | |
60 | "TARGETLENGTH_MAX", |
|
60 | "TARGETLENGTH_MAX", | |
61 | "LDM_MINMATCH_MIN", |
|
61 | "LDM_MINMATCH_MIN", | |
62 | "LDM_MINMATCH_MAX", |
|
62 | "LDM_MINMATCH_MAX", | |
63 | "LDM_BUCKETSIZELOG_MAX", |
|
63 | "LDM_BUCKETSIZELOG_MAX", | |
64 | "STRATEGY_FAST", |
|
64 | "STRATEGY_FAST", | |
65 | "STRATEGY_DFAST", |
|
65 | "STRATEGY_DFAST", | |
66 | "STRATEGY_GREEDY", |
|
66 | "STRATEGY_GREEDY", | |
67 | "STRATEGY_LAZY", |
|
67 | "STRATEGY_LAZY", | |
68 | "STRATEGY_LAZY2", |
|
68 | "STRATEGY_LAZY2", | |
69 | "STRATEGY_BTLAZY2", |
|
69 | "STRATEGY_BTLAZY2", | |
70 | "STRATEGY_BTOPT", |
|
70 | "STRATEGY_BTOPT", | |
71 | "STRATEGY_BTULTRA", |
|
71 | "STRATEGY_BTULTRA", | |
72 | "STRATEGY_BTULTRA2", |
|
72 | "STRATEGY_BTULTRA2", | |
73 | "DICT_TYPE_AUTO", |
|
73 | "DICT_TYPE_AUTO", | |
74 | "DICT_TYPE_RAWCONTENT", |
|
74 | "DICT_TYPE_RAWCONTENT", | |
75 | "DICT_TYPE_FULLDICT", |
|
75 | "DICT_TYPE_FULLDICT", | |
76 | "FORMAT_ZSTD1", |
|
76 | "FORMAT_ZSTD1", | |
77 | "FORMAT_ZSTD1_MAGICLESS", |
|
77 | "FORMAT_ZSTD1_MAGICLESS", | |
78 | ] |
|
78 | ] | |
79 |
|
79 | |||
80 | import io |
|
80 | import io | |
81 | import os |
|
81 | import os | |
82 | import sys |
|
82 | import sys | |
83 |
|
83 | |||
84 | from _zstd_cffi import ( |
|
84 | from _zstd_cffi import ( | |
85 | ffi, |
|
85 | ffi, | |
86 | lib, |
|
86 | lib, | |
87 | ) |
|
87 | ) | |
88 |
|
88 | |||
89 | if sys.version_info[0] == 2: |
|
89 | if sys.version_info[0] == 2: | |
90 | bytes_type = str |
|
90 | bytes_type = str | |
91 | int_type = long |
|
91 | int_type = long | |
92 | else: |
|
92 | else: | |
93 | bytes_type = bytes |
|
93 | bytes_type = bytes | |
94 | int_type = int |
|
94 | int_type = int | |
95 |
|
95 | |||
96 |
|
96 | |||
97 | COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize() |
|
97 | COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize() | |
98 | COMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_CStreamOutSize() |
|
98 | COMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_CStreamOutSize() | |
99 | DECOMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_DStreamInSize() |
|
99 | DECOMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_DStreamInSize() | |
100 | DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_DStreamOutSize() |
|
100 | DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_DStreamOutSize() | |
101 |
|
101 | |||
102 | new_nonzero = ffi.new_allocator(should_clear_after_alloc=False) |
|
102 | new_nonzero = ffi.new_allocator(should_clear_after_alloc=False) | |
103 |
|
103 | |||
104 |
|
104 | |||
105 | MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel() |
|
105 | MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel() | |
106 | MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER |
|
106 | MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER | |
107 | FRAME_HEADER = b"\x28\xb5\x2f\xfd" |
|
107 | FRAME_HEADER = b"\x28\xb5\x2f\xfd" | |
108 | CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
108 | CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
109 | CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR |
|
109 | CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR | |
110 | ZSTD_VERSION = ( |
|
110 | ZSTD_VERSION = ( | |
111 | lib.ZSTD_VERSION_MAJOR, |
|
111 | lib.ZSTD_VERSION_MAJOR, | |
112 | lib.ZSTD_VERSION_MINOR, |
|
112 | lib.ZSTD_VERSION_MINOR, | |
113 | lib.ZSTD_VERSION_RELEASE, |
|
113 | lib.ZSTD_VERSION_RELEASE, | |
114 | ) |
|
114 | ) | |
115 |
|
115 | |||
116 | BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX |
|
116 | BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX | |
117 | BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX |
|
117 | BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX | |
118 | WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN |
|
118 | WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN | |
119 | WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX |
|
119 | WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX | |
120 | CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN |
|
120 | CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN | |
121 | CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX |
|
121 | CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX | |
122 | HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN |
|
122 | HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN | |
123 | HASHLOG_MAX = lib.ZSTD_HASHLOG_MAX |
|
123 | HASHLOG_MAX = lib.ZSTD_HASHLOG_MAX | |
124 | HASHLOG3_MAX = lib.ZSTD_HASHLOG3_MAX |
|
124 | HASHLOG3_MAX = lib.ZSTD_HASHLOG3_MAX | |
125 | MINMATCH_MIN = lib.ZSTD_MINMATCH_MIN |
|
125 | MINMATCH_MIN = lib.ZSTD_MINMATCH_MIN | |
126 | MINMATCH_MAX = lib.ZSTD_MINMATCH_MAX |
|
126 | MINMATCH_MAX = lib.ZSTD_MINMATCH_MAX | |
127 | SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN |
|
127 | SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN | |
128 | SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX |
|
128 | SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX | |
129 | SEARCHLENGTH_MIN = lib.ZSTD_MINMATCH_MIN |
|
129 | SEARCHLENGTH_MIN = lib.ZSTD_MINMATCH_MIN | |
130 | SEARCHLENGTH_MAX = lib.ZSTD_MINMATCH_MAX |
|
130 | SEARCHLENGTH_MAX = lib.ZSTD_MINMATCH_MAX | |
131 | TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN |
|
131 | TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN | |
132 | TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX |
|
132 | TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX | |
133 | LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN |
|
133 | LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN | |
134 | LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX |
|
134 | LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX | |
135 | LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX |
|
135 | LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX | |
136 |
|
136 | |||
137 | STRATEGY_FAST = lib.ZSTD_fast |
|
137 | STRATEGY_FAST = lib.ZSTD_fast | |
138 | STRATEGY_DFAST = lib.ZSTD_dfast |
|
138 | STRATEGY_DFAST = lib.ZSTD_dfast | |
139 | STRATEGY_GREEDY = lib.ZSTD_greedy |
|
139 | STRATEGY_GREEDY = lib.ZSTD_greedy | |
140 | STRATEGY_LAZY = lib.ZSTD_lazy |
|
140 | STRATEGY_LAZY = lib.ZSTD_lazy | |
141 | STRATEGY_LAZY2 = lib.ZSTD_lazy2 |
|
141 | STRATEGY_LAZY2 = lib.ZSTD_lazy2 | |
142 | STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2 |
|
142 | STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2 | |
143 | STRATEGY_BTOPT = lib.ZSTD_btopt |
|
143 | STRATEGY_BTOPT = lib.ZSTD_btopt | |
144 | STRATEGY_BTULTRA = lib.ZSTD_btultra |
|
144 | STRATEGY_BTULTRA = lib.ZSTD_btultra | |
145 | STRATEGY_BTULTRA2 = lib.ZSTD_btultra2 |
|
145 | STRATEGY_BTULTRA2 = lib.ZSTD_btultra2 | |
146 |
|
146 | |||
147 | DICT_TYPE_AUTO = lib.ZSTD_dct_auto |
|
147 | DICT_TYPE_AUTO = lib.ZSTD_dct_auto | |
148 | DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent |
|
148 | DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent | |
149 | DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict |
|
149 | DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict | |
150 |
|
150 | |||
151 | FORMAT_ZSTD1 = lib.ZSTD_f_zstd1 |
|
151 | FORMAT_ZSTD1 = lib.ZSTD_f_zstd1 | |
152 | FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless |
|
152 | FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless | |
153 |
|
153 | |||
154 | FLUSH_BLOCK = 0 |
|
154 | FLUSH_BLOCK = 0 | |
155 | FLUSH_FRAME = 1 |
|
155 | FLUSH_FRAME = 1 | |
156 |
|
156 | |||
157 | COMPRESSOBJ_FLUSH_FINISH = 0 |
|
157 | COMPRESSOBJ_FLUSH_FINISH = 0 | |
158 | COMPRESSOBJ_FLUSH_BLOCK = 1 |
|
158 | COMPRESSOBJ_FLUSH_BLOCK = 1 | |
159 |
|
159 | |||
160 |
|
160 | |||
161 | def _cpu_count(): |
|
161 | def _cpu_count(): | |
162 | # os.cpu_count() was introducd in Python 3.4. |
|
162 | # os.cpu_count() was introducd in Python 3.4. | |
163 | try: |
|
163 | try: | |
164 | return os.cpu_count() or 0 |
|
164 | return os.cpu_count() or 0 | |
165 | except AttributeError: |
|
165 | except AttributeError: | |
166 | pass |
|
166 | pass | |
167 |
|
167 | |||
168 | # Linux. |
|
168 | # Linux. | |
169 | try: |
|
169 | try: | |
170 | if sys.version_info[0] == 2: |
|
170 | if sys.version_info[0] == 2: | |
171 | return os.sysconf(b"SC_NPROCESSORS_ONLN") |
|
171 | return os.sysconf(b"SC_NPROCESSORS_ONLN") | |
172 | else: |
|
172 | else: | |
173 | return os.sysconf("SC_NPROCESSORS_ONLN") |
|
173 | return os.sysconf("SC_NPROCESSORS_ONLN") | |
174 | except (AttributeError, ValueError): |
|
174 | except (AttributeError, ValueError): | |
175 | pass |
|
175 | pass | |
176 |
|
176 | |||
177 | # TODO implement on other platforms. |
|
177 | # TODO implement on other platforms. | |
178 | return 0 |
|
178 | return 0 | |
179 |
|
179 | |||
180 |
|
180 | |||
181 | class ZstdError(Exception): |
|
181 | class ZstdError(Exception): | |
182 | pass |
|
182 | pass | |
183 |
|
183 | |||
184 |
|
184 | |||
185 | def _zstd_error(zresult): |
|
185 | def _zstd_error(zresult): | |
186 | # Resolves to bytes on Python 2 and 3. We use the string for formatting |
|
186 | # Resolves to bytes on Python 2 and 3. We use the string for formatting | |
187 | # into error messages, which will be literal unicode. So convert it to |
|
187 | # into error messages, which will be literal unicode. So convert it to | |
188 | # unicode. |
|
188 | # unicode. | |
189 | return ffi.string(lib.ZSTD_getErrorName(zresult)).decode("utf-8") |
|
189 | return ffi.string(lib.ZSTD_getErrorName(zresult)).decode("utf-8") | |
190 |
|
190 | |||
191 |
|
191 | |||
192 | def _make_cctx_params(params): |
|
192 | def _make_cctx_params(params): | |
193 | res = lib.ZSTD_createCCtxParams() |
|
193 | res = lib.ZSTD_createCCtxParams() | |
194 | if res == ffi.NULL: |
|
194 | if res == ffi.NULL: | |
195 | raise MemoryError() |
|
195 | raise MemoryError() | |
196 |
|
196 | |||
197 | res = ffi.gc(res, lib.ZSTD_freeCCtxParams) |
|
197 | res = ffi.gc(res, lib.ZSTD_freeCCtxParams) | |
198 |
|
198 | |||
199 | attrs = [ |
|
199 | attrs = [ | |
200 | (lib.ZSTD_c_format, params.format), |
|
200 | (lib.ZSTD_c_format, params.format), | |
201 | (lib.ZSTD_c_compressionLevel, params.compression_level), |
|
201 | (lib.ZSTD_c_compressionLevel, params.compression_level), | |
202 | (lib.ZSTD_c_windowLog, params.window_log), |
|
202 | (lib.ZSTD_c_windowLog, params.window_log), | |
203 | (lib.ZSTD_c_hashLog, params.hash_log), |
|
203 | (lib.ZSTD_c_hashLog, params.hash_log), | |
204 | (lib.ZSTD_c_chainLog, params.chain_log), |
|
204 | (lib.ZSTD_c_chainLog, params.chain_log), | |
205 | (lib.ZSTD_c_searchLog, params.search_log), |
|
205 | (lib.ZSTD_c_searchLog, params.search_log), | |
206 | (lib.ZSTD_c_minMatch, params.min_match), |
|
206 | (lib.ZSTD_c_minMatch, params.min_match), | |
207 | (lib.ZSTD_c_targetLength, params.target_length), |
|
207 | (lib.ZSTD_c_targetLength, params.target_length), | |
208 | (lib.ZSTD_c_strategy, params.compression_strategy), |
|
208 | (lib.ZSTD_c_strategy, params.compression_strategy), | |
209 | (lib.ZSTD_c_contentSizeFlag, params.write_content_size), |
|
209 | (lib.ZSTD_c_contentSizeFlag, params.write_content_size), | |
210 | (lib.ZSTD_c_checksumFlag, params.write_checksum), |
|
210 | (lib.ZSTD_c_checksumFlag, params.write_checksum), | |
211 | (lib.ZSTD_c_dictIDFlag, params.write_dict_id), |
|
211 | (lib.ZSTD_c_dictIDFlag, params.write_dict_id), | |
212 | (lib.ZSTD_c_nbWorkers, params.threads), |
|
212 | (lib.ZSTD_c_nbWorkers, params.threads), | |
213 | (lib.ZSTD_c_jobSize, params.job_size), |
|
213 | (lib.ZSTD_c_jobSize, params.job_size), | |
214 | (lib.ZSTD_c_overlapLog, params.overlap_log), |
|
214 | (lib.ZSTD_c_overlapLog, params.overlap_log), | |
215 | (lib.ZSTD_c_forceMaxWindow, params.force_max_window), |
|
215 | (lib.ZSTD_c_forceMaxWindow, params.force_max_window), | |
216 | (lib.ZSTD_c_enableLongDistanceMatching, params.enable_ldm), |
|
216 | (lib.ZSTD_c_enableLongDistanceMatching, params.enable_ldm), | |
217 | (lib.ZSTD_c_ldmHashLog, params.ldm_hash_log), |
|
217 | (lib.ZSTD_c_ldmHashLog, params.ldm_hash_log), | |
218 | (lib.ZSTD_c_ldmMinMatch, params.ldm_min_match), |
|
218 | (lib.ZSTD_c_ldmMinMatch, params.ldm_min_match), | |
219 | (lib.ZSTD_c_ldmBucketSizeLog, params.ldm_bucket_size_log), |
|
219 | (lib.ZSTD_c_ldmBucketSizeLog, params.ldm_bucket_size_log), | |
220 | (lib.ZSTD_c_ldmHashRateLog, params.ldm_hash_rate_log), |
|
220 | (lib.ZSTD_c_ldmHashRateLog, params.ldm_hash_rate_log), | |
221 | ] |
|
221 | ] | |
222 |
|
222 | |||
223 | for param, value in attrs: |
|
223 | for param, value in attrs: | |
224 | _set_compression_parameter(res, param, value) |
|
224 | _set_compression_parameter(res, param, value) | |
225 |
|
225 | |||
226 | return res |
|
226 | return res | |
227 |
|
227 | |||
228 |
|
228 | |||
229 | class ZstdCompressionParameters(object): |
|
229 | class ZstdCompressionParameters(object): | |
230 | @staticmethod |
|
230 | @staticmethod | |
231 | def from_level(level, source_size=0, dict_size=0, **kwargs): |
|
231 | def from_level(level, source_size=0, dict_size=0, **kwargs): | |
232 | params = lib.ZSTD_getCParams(level, source_size, dict_size) |
|
232 | params = lib.ZSTD_getCParams(level, source_size, dict_size) | |
233 |
|
233 | |||
234 | args = { |
|
234 | args = { | |
235 | "window_log": "windowLog", |
|
235 | "window_log": "windowLog", | |
236 | "chain_log": "chainLog", |
|
236 | "chain_log": "chainLog", | |
237 | "hash_log": "hashLog", |
|
237 | "hash_log": "hashLog", | |
238 | "search_log": "searchLog", |
|
238 | "search_log": "searchLog", | |
239 | "min_match": "minMatch", |
|
239 | "min_match": "minMatch", | |
240 | "target_length": "targetLength", |
|
240 | "target_length": "targetLength", | |
241 | "compression_strategy": "strategy", |
|
241 | "compression_strategy": "strategy", | |
242 | } |
|
242 | } | |
243 |
|
243 | |||
244 | for arg, attr in args.items(): |
|
244 | for arg, attr in args.items(): | |
245 | if arg not in kwargs: |
|
245 | if arg not in kwargs: | |
246 | kwargs[arg] = getattr(params, attr) |
|
246 | kwargs[arg] = getattr(params, attr) | |
247 |
|
247 | |||
248 | return ZstdCompressionParameters(**kwargs) |
|
248 | return ZstdCompressionParameters(**kwargs) | |
249 |
|
249 | |||
250 | def __init__( |
|
250 | def __init__( | |
251 | self, |
|
251 | self, | |
252 | format=0, |
|
252 | format=0, | |
253 | compression_level=0, |
|
253 | compression_level=0, | |
254 | window_log=0, |
|
254 | window_log=0, | |
255 | hash_log=0, |
|
255 | hash_log=0, | |
256 | chain_log=0, |
|
256 | chain_log=0, | |
257 | search_log=0, |
|
257 | search_log=0, | |
258 | min_match=0, |
|
258 | min_match=0, | |
259 | target_length=0, |
|
259 | target_length=0, | |
260 | strategy=-1, |
|
260 | strategy=-1, | |
261 | compression_strategy=-1, |
|
261 | compression_strategy=-1, | |
262 | write_content_size=1, |
|
262 | write_content_size=1, | |
263 | write_checksum=0, |
|
263 | write_checksum=0, | |
264 | write_dict_id=0, |
|
264 | write_dict_id=0, | |
265 | job_size=0, |
|
265 | job_size=0, | |
266 | overlap_log=-1, |
|
266 | overlap_log=-1, | |
267 | overlap_size_log=-1, |
|
267 | overlap_size_log=-1, | |
268 | force_max_window=0, |
|
268 | force_max_window=0, | |
269 | enable_ldm=0, |
|
269 | enable_ldm=0, | |
270 | ldm_hash_log=0, |
|
270 | ldm_hash_log=0, | |
271 | ldm_min_match=0, |
|
271 | ldm_min_match=0, | |
272 | ldm_bucket_size_log=0, |
|
272 | ldm_bucket_size_log=0, | |
273 | ldm_hash_rate_log=-1, |
|
273 | ldm_hash_rate_log=-1, | |
274 | ldm_hash_every_log=-1, |
|
274 | ldm_hash_every_log=-1, | |
275 | threads=0, |
|
275 | threads=0, | |
276 | ): |
|
276 | ): | |
277 |
|
277 | |||
278 | params = lib.ZSTD_createCCtxParams() |
|
278 | params = lib.ZSTD_createCCtxParams() | |
279 | if params == ffi.NULL: |
|
279 | if params == ffi.NULL: | |
280 | raise MemoryError() |
|
280 | raise MemoryError() | |
281 |
|
281 | |||
282 | params = ffi.gc(params, lib.ZSTD_freeCCtxParams) |
|
282 | params = ffi.gc(params, lib.ZSTD_freeCCtxParams) | |
283 |
|
283 | |||
284 | self._params = params |
|
284 | self._params = params | |
285 |
|
285 | |||
286 | if threads < 0: |
|
286 | if threads < 0: | |
287 | threads = _cpu_count() |
|
287 | threads = _cpu_count() | |
288 |
|
288 | |||
289 | # We need to set ZSTD_c_nbWorkers before ZSTD_c_jobSize and ZSTD_c_overlapLog |
|
289 | # We need to set ZSTD_c_nbWorkers before ZSTD_c_jobSize and ZSTD_c_overlapLog | |
290 | # because setting ZSTD_c_nbWorkers resets the other parameters. |
|
290 | # because setting ZSTD_c_nbWorkers resets the other parameters. | |
291 | _set_compression_parameter(params, lib.ZSTD_c_nbWorkers, threads) |
|
291 | _set_compression_parameter(params, lib.ZSTD_c_nbWorkers, threads) | |
292 |
|
292 | |||
293 | _set_compression_parameter(params, lib.ZSTD_c_format, format) |
|
293 | _set_compression_parameter(params, lib.ZSTD_c_format, format) | |
294 | _set_compression_parameter( |
|
294 | _set_compression_parameter( | |
295 | params, lib.ZSTD_c_compressionLevel, compression_level |
|
295 | params, lib.ZSTD_c_compressionLevel, compression_level | |
296 | ) |
|
296 | ) | |
297 | _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log) |
|
297 | _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log) | |
298 | _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log) |
|
298 | _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log) | |
299 | _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log) |
|
299 | _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log) | |
300 | _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log) |
|
300 | _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log) | |
301 | _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match) |
|
301 | _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match) | |
302 | _set_compression_parameter(params, lib.ZSTD_c_targetLength, target_length) |
|
302 | _set_compression_parameter( | |
|
303 | params, lib.ZSTD_c_targetLength, target_length | |||
|
304 | ) | |||
303 |
|
305 | |||
304 | if strategy != -1 and compression_strategy != -1: |
|
306 | if strategy != -1 and compression_strategy != -1: | |
305 | raise ValueError("cannot specify both compression_strategy and strategy") |
|
307 | raise ValueError( | |
|
308 | "cannot specify both compression_strategy and strategy" | |||
|
309 | ) | |||
306 |
|
310 | |||
307 | if compression_strategy != -1: |
|
311 | if compression_strategy != -1: | |
308 | strategy = compression_strategy |
|
312 | strategy = compression_strategy | |
309 | elif strategy == -1: |
|
313 | elif strategy == -1: | |
310 | strategy = 0 |
|
314 | strategy = 0 | |
311 |
|
315 | |||
312 | _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy) |
|
316 | _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy) | |
313 | _set_compression_parameter( |
|
317 | _set_compression_parameter( | |
314 | params, lib.ZSTD_c_contentSizeFlag, write_content_size |
|
318 | params, lib.ZSTD_c_contentSizeFlag, write_content_size | |
315 | ) |
|
319 | ) | |
316 | _set_compression_parameter(params, lib.ZSTD_c_checksumFlag, write_checksum) |
|
320 | _set_compression_parameter( | |
|
321 | params, lib.ZSTD_c_checksumFlag, write_checksum | |||
|
322 | ) | |||
317 | _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id) |
|
323 | _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id) | |
318 | _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size) |
|
324 | _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size) | |
319 |
|
325 | |||
320 | if overlap_log != -1 and overlap_size_log != -1: |
|
326 | if overlap_log != -1 and overlap_size_log != -1: | |
321 | raise ValueError("cannot specify both overlap_log and overlap_size_log") |
|
327 | raise ValueError( | |
|
328 | "cannot specify both overlap_log and overlap_size_log" | |||
|
329 | ) | |||
322 |
|
330 | |||
323 | if overlap_size_log != -1: |
|
331 | if overlap_size_log != -1: | |
324 | overlap_log = overlap_size_log |
|
332 | overlap_log = overlap_size_log | |
325 | elif overlap_log == -1: |
|
333 | elif overlap_log == -1: | |
326 | overlap_log = 0 |
|
334 | overlap_log = 0 | |
327 |
|
335 | |||
328 | _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log) |
|
336 | _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log) | |
329 | _set_compression_parameter(params, lib.ZSTD_c_forceMaxWindow, force_max_window) |
|
337 | _set_compression_parameter( | |
|
338 | params, lib.ZSTD_c_forceMaxWindow, force_max_window | |||
|
339 | ) | |||
330 | _set_compression_parameter( |
|
340 | _set_compression_parameter( | |
331 | params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm |
|
341 | params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm | |
332 | ) |
|
342 | ) | |
333 | _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log) |
|
343 | _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log) | |
334 |
_set_compression_parameter( |
|
344 | _set_compression_parameter( | |
|
345 | params, lib.ZSTD_c_ldmMinMatch, ldm_min_match | |||
|
346 | ) | |||
335 | _set_compression_parameter( |
|
347 | _set_compression_parameter( | |
336 | params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log |
|
348 | params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log | |
337 | ) |
|
349 | ) | |
338 |
|
350 | |||
339 | if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1: |
|
351 | if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1: | |
340 | raise ValueError( |
|
352 | raise ValueError( | |
341 | "cannot specify both ldm_hash_rate_log and ldm_hash_every_log" |
|
353 | "cannot specify both ldm_hash_rate_log and ldm_hash_every_log" | |
342 | ) |
|
354 | ) | |
343 |
|
355 | |||
344 | if ldm_hash_every_log != -1: |
|
356 | if ldm_hash_every_log != -1: | |
345 | ldm_hash_rate_log = ldm_hash_every_log |
|
357 | ldm_hash_rate_log = ldm_hash_every_log | |
346 | elif ldm_hash_rate_log == -1: |
|
358 | elif ldm_hash_rate_log == -1: | |
347 | ldm_hash_rate_log = 0 |
|
359 | ldm_hash_rate_log = 0 | |
348 |
|
360 | |||
349 | _set_compression_parameter(params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log) |
|
361 | _set_compression_parameter( | |
|
362 | params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log | |||
|
363 | ) | |||
350 |
|
364 | |||
351 | @property |
|
365 | @property | |
352 | def format(self): |
|
366 | def format(self): | |
353 | return _get_compression_parameter(self._params, lib.ZSTD_c_format) |
|
367 | return _get_compression_parameter(self._params, lib.ZSTD_c_format) | |
354 |
|
368 | |||
355 | @property |
|
369 | @property | |
356 | def compression_level(self): |
|
370 | def compression_level(self): | |
357 |
return _get_compression_parameter( |
|
371 | return _get_compression_parameter( | |
|
372 | self._params, lib.ZSTD_c_compressionLevel | |||
|
373 | ) | |||
358 |
|
374 | |||
359 | @property |
|
375 | @property | |
360 | def window_log(self): |
|
376 | def window_log(self): | |
361 | return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog) |
|
377 | return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog) | |
362 |
|
378 | |||
363 | @property |
|
379 | @property | |
364 | def hash_log(self): |
|
380 | def hash_log(self): | |
365 | return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog) |
|
381 | return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog) | |
366 |
|
382 | |||
367 | @property |
|
383 | @property | |
368 | def chain_log(self): |
|
384 | def chain_log(self): | |
369 | return _get_compression_parameter(self._params, lib.ZSTD_c_chainLog) |
|
385 | return _get_compression_parameter(self._params, lib.ZSTD_c_chainLog) | |
370 |
|
386 | |||
371 | @property |
|
387 | @property | |
372 | def search_log(self): |
|
388 | def search_log(self): | |
373 | return _get_compression_parameter(self._params, lib.ZSTD_c_searchLog) |
|
389 | return _get_compression_parameter(self._params, lib.ZSTD_c_searchLog) | |
374 |
|
390 | |||
375 | @property |
|
391 | @property | |
376 | def min_match(self): |
|
392 | def min_match(self): | |
377 | return _get_compression_parameter(self._params, lib.ZSTD_c_minMatch) |
|
393 | return _get_compression_parameter(self._params, lib.ZSTD_c_minMatch) | |
378 |
|
394 | |||
379 | @property |
|
395 | @property | |
380 | def target_length(self): |
|
396 | def target_length(self): | |
381 | return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength) |
|
397 | return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength) | |
382 |
|
398 | |||
383 | @property |
|
399 | @property | |
384 | def compression_strategy(self): |
|
400 | def compression_strategy(self): | |
385 | return _get_compression_parameter(self._params, lib.ZSTD_c_strategy) |
|
401 | return _get_compression_parameter(self._params, lib.ZSTD_c_strategy) | |
386 |
|
402 | |||
387 | @property |
|
403 | @property | |
388 | def write_content_size(self): |
|
404 | def write_content_size(self): | |
389 |
return _get_compression_parameter( |
|
405 | return _get_compression_parameter( | |
|
406 | self._params, lib.ZSTD_c_contentSizeFlag | |||
|
407 | ) | |||
390 |
|
408 | |||
391 | @property |
|
409 | @property | |
392 | def write_checksum(self): |
|
410 | def write_checksum(self): | |
393 | return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag) |
|
411 | return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag) | |
394 |
|
412 | |||
395 | @property |
|
413 | @property | |
396 | def write_dict_id(self): |
|
414 | def write_dict_id(self): | |
397 | return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag) |
|
415 | return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag) | |
398 |
|
416 | |||
399 | @property |
|
417 | @property | |
400 | def job_size(self): |
|
418 | def job_size(self): | |
401 | return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize) |
|
419 | return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize) | |
402 |
|
420 | |||
403 | @property |
|
421 | @property | |
404 | def overlap_log(self): |
|
422 | def overlap_log(self): | |
405 | return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog) |
|
423 | return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog) | |
406 |
|
424 | |||
407 | @property |
|
425 | @property | |
408 | def overlap_size_log(self): |
|
426 | def overlap_size_log(self): | |
409 | return self.overlap_log |
|
427 | return self.overlap_log | |
410 |
|
428 | |||
411 | @property |
|
429 | @property | |
412 | def force_max_window(self): |
|
430 | def force_max_window(self): | |
413 |
return _get_compression_parameter( |
|
431 | return _get_compression_parameter( | |
|
432 | self._params, lib.ZSTD_c_forceMaxWindow | |||
|
433 | ) | |||
414 |
|
434 | |||
415 | @property |
|
435 | @property | |
416 | def enable_ldm(self): |
|
436 | def enable_ldm(self): | |
417 | return _get_compression_parameter( |
|
437 | return _get_compression_parameter( | |
418 | self._params, lib.ZSTD_c_enableLongDistanceMatching |
|
438 | self._params, lib.ZSTD_c_enableLongDistanceMatching | |
419 | ) |
|
439 | ) | |
420 |
|
440 | |||
421 | @property |
|
441 | @property | |
422 | def ldm_hash_log(self): |
|
442 | def ldm_hash_log(self): | |
423 | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog) |
|
443 | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog) | |
424 |
|
444 | |||
425 | @property |
|
445 | @property | |
426 | def ldm_min_match(self): |
|
446 | def ldm_min_match(self): | |
427 | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch) |
|
447 | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch) | |
428 |
|
448 | |||
429 | @property |
|
449 | @property | |
430 | def ldm_bucket_size_log(self): |
|
450 | def ldm_bucket_size_log(self): | |
431 |
return _get_compression_parameter( |
|
451 | return _get_compression_parameter( | |
|
452 | self._params, lib.ZSTD_c_ldmBucketSizeLog | |||
|
453 | ) | |||
432 |
|
454 | |||
433 | @property |
|
455 | @property | |
434 | def ldm_hash_rate_log(self): |
|
456 | def ldm_hash_rate_log(self): | |
435 |
return _get_compression_parameter( |
|
457 | return _get_compression_parameter( | |
|
458 | self._params, lib.ZSTD_c_ldmHashRateLog | |||
|
459 | ) | |||
436 |
|
460 | |||
437 | @property |
|
461 | @property | |
438 | def ldm_hash_every_log(self): |
|
462 | def ldm_hash_every_log(self): | |
439 | return self.ldm_hash_rate_log |
|
463 | return self.ldm_hash_rate_log | |
440 |
|
464 | |||
441 | @property |
|
465 | @property | |
442 | def threads(self): |
|
466 | def threads(self): | |
443 | return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers) |
|
467 | return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers) | |
444 |
|
468 | |||
445 | def estimated_compression_context_size(self): |
|
469 | def estimated_compression_context_size(self): | |
446 | return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params) |
|
470 | return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params) | |
447 |
|
471 | |||
448 |
|
472 | |||
449 | CompressionParameters = ZstdCompressionParameters |
|
473 | CompressionParameters = ZstdCompressionParameters | |
450 |
|
474 | |||
451 |
|
475 | |||
452 | def estimate_decompression_context_size(): |
|
476 | def estimate_decompression_context_size(): | |
453 | return lib.ZSTD_estimateDCtxSize() |
|
477 | return lib.ZSTD_estimateDCtxSize() | |
454 |
|
478 | |||
455 |
|
479 | |||
456 | def _set_compression_parameter(params, param, value): |
|
480 | def _set_compression_parameter(params, param, value): | |
457 | zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value) |
|
481 | zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value) | |
458 | if lib.ZSTD_isError(zresult): |
|
482 | if lib.ZSTD_isError(zresult): | |
459 | raise ZstdError( |
|
483 | raise ZstdError( | |
460 |
"unable to set compression context parameter: %s" |
|
484 | "unable to set compression context parameter: %s" | |
|
485 | % _zstd_error(zresult) | |||
461 | ) |
|
486 | ) | |
462 |
|
487 | |||
463 |
|
488 | |||
464 | def _get_compression_parameter(params, param): |
|
489 | def _get_compression_parameter(params, param): | |
465 | result = ffi.new("int *") |
|
490 | result = ffi.new("int *") | |
466 |
|
491 | |||
467 | zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result) |
|
492 | zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result) | |
468 | if lib.ZSTD_isError(zresult): |
|
493 | if lib.ZSTD_isError(zresult): | |
469 | raise ZstdError( |
|
494 | raise ZstdError( | |
470 |
"unable to get compression context parameter: %s" |
|
495 | "unable to get compression context parameter: %s" | |
|
496 | % _zstd_error(zresult) | |||
471 | ) |
|
497 | ) | |
472 |
|
498 | |||
473 | return result[0] |
|
499 | return result[0] | |
474 |
|
500 | |||
475 |
|
501 | |||
476 | class ZstdCompressionWriter(object): |
|
502 | class ZstdCompressionWriter(object): | |
477 | def __init__(self, compressor, writer, source_size, write_size, write_return_read): |
|
503 | def __init__( | |
|
504 | self, compressor, writer, source_size, write_size, write_return_read | |||
|
505 | ): | |||
478 | self._compressor = compressor |
|
506 | self._compressor = compressor | |
479 | self._writer = writer |
|
507 | self._writer = writer | |
480 | self._write_size = write_size |
|
508 | self._write_size = write_size | |
481 | self._write_return_read = bool(write_return_read) |
|
509 | self._write_return_read = bool(write_return_read) | |
482 | self._entered = False |
|
510 | self._entered = False | |
483 | self._closed = False |
|
511 | self._closed = False | |
484 | self._bytes_compressed = 0 |
|
512 | self._bytes_compressed = 0 | |
485 |
|
513 | |||
486 | self._dst_buffer = ffi.new("char[]", write_size) |
|
514 | self._dst_buffer = ffi.new("char[]", write_size) | |
487 | self._out_buffer = ffi.new("ZSTD_outBuffer *") |
|
515 | self._out_buffer = ffi.new("ZSTD_outBuffer *") | |
488 | self._out_buffer.dst = self._dst_buffer |
|
516 | self._out_buffer.dst = self._dst_buffer | |
489 | self._out_buffer.size = len(self._dst_buffer) |
|
517 | self._out_buffer.size = len(self._dst_buffer) | |
490 | self._out_buffer.pos = 0 |
|
518 | self._out_buffer.pos = 0 | |
491 |
|
519 | |||
492 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size) |
|
520 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size) | |
493 | if lib.ZSTD_isError(zresult): |
|
521 | if lib.ZSTD_isError(zresult): | |
494 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
522 | raise ZstdError( | |
|
523 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
524 | ) | |||
495 |
|
525 | |||
496 | def __enter__(self): |
|
526 | def __enter__(self): | |
497 | if self._closed: |
|
527 | if self._closed: | |
498 | raise ValueError("stream is closed") |
|
528 | raise ValueError("stream is closed") | |
499 |
|
529 | |||
500 | if self._entered: |
|
530 | if self._entered: | |
501 | raise ZstdError("cannot __enter__ multiple times") |
|
531 | raise ZstdError("cannot __enter__ multiple times") | |
502 |
|
532 | |||
503 | self._entered = True |
|
533 | self._entered = True | |
504 | return self |
|
534 | return self | |
505 |
|
535 | |||
506 | def __exit__(self, exc_type, exc_value, exc_tb): |
|
536 | def __exit__(self, exc_type, exc_value, exc_tb): | |
507 | self._entered = False |
|
537 | self._entered = False | |
508 |
|
538 | |||
509 | if not exc_type and not exc_value and not exc_tb: |
|
539 | if not exc_type and not exc_value and not exc_tb: | |
510 | self.close() |
|
540 | self.close() | |
511 |
|
541 | |||
512 | self._compressor = None |
|
542 | self._compressor = None | |
513 |
|
543 | |||
514 | return False |
|
544 | return False | |
515 |
|
545 | |||
516 | def memory_size(self): |
|
546 | def memory_size(self): | |
517 | return lib.ZSTD_sizeof_CCtx(self._compressor._cctx) |
|
547 | return lib.ZSTD_sizeof_CCtx(self._compressor._cctx) | |
518 |
|
548 | |||
519 | def fileno(self): |
|
549 | def fileno(self): | |
520 | f = getattr(self._writer, "fileno", None) |
|
550 | f = getattr(self._writer, "fileno", None) | |
521 | if f: |
|
551 | if f: | |
522 | return f() |
|
552 | return f() | |
523 | else: |
|
553 | else: | |
524 | raise OSError("fileno not available on underlying writer") |
|
554 | raise OSError("fileno not available on underlying writer") | |
525 |
|
555 | |||
526 | def close(self): |
|
556 | def close(self): | |
527 | if self._closed: |
|
557 | if self._closed: | |
528 | return |
|
558 | return | |
529 |
|
559 | |||
530 | try: |
|
560 | try: | |
531 | self.flush(FLUSH_FRAME) |
|
561 | self.flush(FLUSH_FRAME) | |
532 | finally: |
|
562 | finally: | |
533 | self._closed = True |
|
563 | self._closed = True | |
534 |
|
564 | |||
535 | # Call close() on underlying stream as well. |
|
565 | # Call close() on underlying stream as well. | |
536 | f = getattr(self._writer, "close", None) |
|
566 | f = getattr(self._writer, "close", None) | |
537 | if f: |
|
567 | if f: | |
538 | f() |
|
568 | f() | |
539 |
|
569 | |||
540 | @property |
|
570 | @property | |
541 | def closed(self): |
|
571 | def closed(self): | |
542 | return self._closed |
|
572 | return self._closed | |
543 |
|
573 | |||
544 | def isatty(self): |
|
574 | def isatty(self): | |
545 | return False |
|
575 | return False | |
546 |
|
576 | |||
547 | def readable(self): |
|
577 | def readable(self): | |
548 | return False |
|
578 | return False | |
549 |
|
579 | |||
550 | def readline(self, size=-1): |
|
580 | def readline(self, size=-1): | |
551 | raise io.UnsupportedOperation() |
|
581 | raise io.UnsupportedOperation() | |
552 |
|
582 | |||
553 | def readlines(self, hint=-1): |
|
583 | def readlines(self, hint=-1): | |
554 | raise io.UnsupportedOperation() |
|
584 | raise io.UnsupportedOperation() | |
555 |
|
585 | |||
556 | def seek(self, offset, whence=None): |
|
586 | def seek(self, offset, whence=None): | |
557 | raise io.UnsupportedOperation() |
|
587 | raise io.UnsupportedOperation() | |
558 |
|
588 | |||
559 | def seekable(self): |
|
589 | def seekable(self): | |
560 | return False |
|
590 | return False | |
561 |
|
591 | |||
562 | def truncate(self, size=None): |
|
592 | def truncate(self, size=None): | |
563 | raise io.UnsupportedOperation() |
|
593 | raise io.UnsupportedOperation() | |
564 |
|
594 | |||
565 | def writable(self): |
|
595 | def writable(self): | |
566 | return True |
|
596 | return True | |
567 |
|
597 | |||
568 | def writelines(self, lines): |
|
598 | def writelines(self, lines): | |
569 | raise NotImplementedError("writelines() is not yet implemented") |
|
599 | raise NotImplementedError("writelines() is not yet implemented") | |
570 |
|
600 | |||
571 | def read(self, size=-1): |
|
601 | def read(self, size=-1): | |
572 | raise io.UnsupportedOperation() |
|
602 | raise io.UnsupportedOperation() | |
573 |
|
603 | |||
574 | def readall(self): |
|
604 | def readall(self): | |
575 | raise io.UnsupportedOperation() |
|
605 | raise io.UnsupportedOperation() | |
576 |
|
606 | |||
577 | def readinto(self, b): |
|
607 | def readinto(self, b): | |
578 | raise io.UnsupportedOperation() |
|
608 | raise io.UnsupportedOperation() | |
579 |
|
609 | |||
580 | def write(self, data): |
|
610 | def write(self, data): | |
581 | if self._closed: |
|
611 | if self._closed: | |
582 | raise ValueError("stream is closed") |
|
612 | raise ValueError("stream is closed") | |
583 |
|
613 | |||
584 | total_write = 0 |
|
614 | total_write = 0 | |
585 |
|
615 | |||
586 | data_buffer = ffi.from_buffer(data) |
|
616 | data_buffer = ffi.from_buffer(data) | |
587 |
|
617 | |||
588 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
618 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
589 | in_buffer.src = data_buffer |
|
619 | in_buffer.src = data_buffer | |
590 | in_buffer.size = len(data_buffer) |
|
620 | in_buffer.size = len(data_buffer) | |
591 | in_buffer.pos = 0 |
|
621 | in_buffer.pos = 0 | |
592 |
|
622 | |||
593 | out_buffer = self._out_buffer |
|
623 | out_buffer = self._out_buffer | |
594 | out_buffer.pos = 0 |
|
624 | out_buffer.pos = 0 | |
595 |
|
625 | |||
596 | while in_buffer.pos < in_buffer.size: |
|
626 | while in_buffer.pos < in_buffer.size: | |
597 | zresult = lib.ZSTD_compressStream2( |
|
627 | zresult = lib.ZSTD_compressStream2( | |
598 |
self._compressor._cctx, |
|
628 | self._compressor._cctx, | |
|
629 | out_buffer, | |||
|
630 | in_buffer, | |||
|
631 | lib.ZSTD_e_continue, | |||
599 | ) |
|
632 | ) | |
600 | if lib.ZSTD_isError(zresult): |
|
633 | if lib.ZSTD_isError(zresult): | |
601 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
634 | raise ZstdError( | |
|
635 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
636 | ) | |||
602 |
|
637 | |||
603 | if out_buffer.pos: |
|
638 | if out_buffer.pos: | |
604 |
self._writer.write( |
|
639 | self._writer.write( | |
|
640 | ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||
|
641 | ) | |||
605 | total_write += out_buffer.pos |
|
642 | total_write += out_buffer.pos | |
606 | self._bytes_compressed += out_buffer.pos |
|
643 | self._bytes_compressed += out_buffer.pos | |
607 | out_buffer.pos = 0 |
|
644 | out_buffer.pos = 0 | |
608 |
|
645 | |||
609 | if self._write_return_read: |
|
646 | if self._write_return_read: | |
610 | return in_buffer.pos |
|
647 | return in_buffer.pos | |
611 | else: |
|
648 | else: | |
612 | return total_write |
|
649 | return total_write | |
613 |
|
650 | |||
614 | def flush(self, flush_mode=FLUSH_BLOCK): |
|
651 | def flush(self, flush_mode=FLUSH_BLOCK): | |
615 | if flush_mode == FLUSH_BLOCK: |
|
652 | if flush_mode == FLUSH_BLOCK: | |
616 | flush = lib.ZSTD_e_flush |
|
653 | flush = lib.ZSTD_e_flush | |
617 | elif flush_mode == FLUSH_FRAME: |
|
654 | elif flush_mode == FLUSH_FRAME: | |
618 | flush = lib.ZSTD_e_end |
|
655 | flush = lib.ZSTD_e_end | |
619 | else: |
|
656 | else: | |
620 | raise ValueError("unknown flush_mode: %r" % flush_mode) |
|
657 | raise ValueError("unknown flush_mode: %r" % flush_mode) | |
621 |
|
658 | |||
622 | if self._closed: |
|
659 | if self._closed: | |
623 | raise ValueError("stream is closed") |
|
660 | raise ValueError("stream is closed") | |
624 |
|
661 | |||
625 | total_write = 0 |
|
662 | total_write = 0 | |
626 |
|
663 | |||
627 | out_buffer = self._out_buffer |
|
664 | out_buffer = self._out_buffer | |
628 | out_buffer.pos = 0 |
|
665 | out_buffer.pos = 0 | |
629 |
|
666 | |||
630 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
667 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
631 | in_buffer.src = ffi.NULL |
|
668 | in_buffer.src = ffi.NULL | |
632 | in_buffer.size = 0 |
|
669 | in_buffer.size = 0 | |
633 | in_buffer.pos = 0 |
|
670 | in_buffer.pos = 0 | |
634 |
|
671 | |||
635 | while True: |
|
672 | while True: | |
636 | zresult = lib.ZSTD_compressStream2( |
|
673 | zresult = lib.ZSTD_compressStream2( | |
637 | self._compressor._cctx, out_buffer, in_buffer, flush |
|
674 | self._compressor._cctx, out_buffer, in_buffer, flush | |
638 | ) |
|
675 | ) | |
639 | if lib.ZSTD_isError(zresult): |
|
676 | if lib.ZSTD_isError(zresult): | |
640 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
677 | raise ZstdError( | |
|
678 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
679 | ) | |||
641 |
|
680 | |||
642 | if out_buffer.pos: |
|
681 | if out_buffer.pos: | |
643 |
self._writer.write( |
|
682 | self._writer.write( | |
|
683 | ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||
|
684 | ) | |||
644 | total_write += out_buffer.pos |
|
685 | total_write += out_buffer.pos | |
645 | self._bytes_compressed += out_buffer.pos |
|
686 | self._bytes_compressed += out_buffer.pos | |
646 | out_buffer.pos = 0 |
|
687 | out_buffer.pos = 0 | |
647 |
|
688 | |||
648 | if not zresult: |
|
689 | if not zresult: | |
649 | break |
|
690 | break | |
650 |
|
691 | |||
651 | return total_write |
|
692 | return total_write | |
652 |
|
693 | |||
653 | def tell(self): |
|
694 | def tell(self): | |
654 | return self._bytes_compressed |
|
695 | return self._bytes_compressed | |
655 |
|
696 | |||
656 |
|
697 | |||
657 | class ZstdCompressionObj(object): |
|
698 | class ZstdCompressionObj(object): | |
658 | def compress(self, data): |
|
699 | def compress(self, data): | |
659 | if self._finished: |
|
700 | if self._finished: | |
660 | raise ZstdError("cannot call compress() after compressor finished") |
|
701 | raise ZstdError("cannot call compress() after compressor finished") | |
661 |
|
702 | |||
662 | data_buffer = ffi.from_buffer(data) |
|
703 | data_buffer = ffi.from_buffer(data) | |
663 | source = ffi.new("ZSTD_inBuffer *") |
|
704 | source = ffi.new("ZSTD_inBuffer *") | |
664 | source.src = data_buffer |
|
705 | source.src = data_buffer | |
665 | source.size = len(data_buffer) |
|
706 | source.size = len(data_buffer) | |
666 | source.pos = 0 |
|
707 | source.pos = 0 | |
667 |
|
708 | |||
668 | chunks = [] |
|
709 | chunks = [] | |
669 |
|
710 | |||
670 | while source.pos < len(data): |
|
711 | while source.pos < len(data): | |
671 | zresult = lib.ZSTD_compressStream2( |
|
712 | zresult = lib.ZSTD_compressStream2( | |
672 | self._compressor._cctx, self._out, source, lib.ZSTD_e_continue |
|
713 | self._compressor._cctx, self._out, source, lib.ZSTD_e_continue | |
673 | ) |
|
714 | ) | |
674 | if lib.ZSTD_isError(zresult): |
|
715 | if lib.ZSTD_isError(zresult): | |
675 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
716 | raise ZstdError( | |
|
717 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
718 | ) | |||
676 |
|
719 | |||
677 | if self._out.pos: |
|
720 | if self._out.pos: | |
678 | chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) |
|
721 | chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) | |
679 | self._out.pos = 0 |
|
722 | self._out.pos = 0 | |
680 |
|
723 | |||
681 | return b"".join(chunks) |
|
724 | return b"".join(chunks) | |
682 |
|
725 | |||
683 | def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH): |
|
726 | def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH): | |
684 | if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK): |
|
727 | if flush_mode not in ( | |
|
728 | COMPRESSOBJ_FLUSH_FINISH, | |||
|
729 | COMPRESSOBJ_FLUSH_BLOCK, | |||
|
730 | ): | |||
685 | raise ValueError("flush mode not recognized") |
|
731 | raise ValueError("flush mode not recognized") | |
686 |
|
732 | |||
687 | if self._finished: |
|
733 | if self._finished: | |
688 | raise ZstdError("compressor object already finished") |
|
734 | raise ZstdError("compressor object already finished") | |
689 |
|
735 | |||
690 | if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: |
|
736 | if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: | |
691 | z_flush_mode = lib.ZSTD_e_flush |
|
737 | z_flush_mode = lib.ZSTD_e_flush | |
692 | elif flush_mode == COMPRESSOBJ_FLUSH_FINISH: |
|
738 | elif flush_mode == COMPRESSOBJ_FLUSH_FINISH: | |
693 | z_flush_mode = lib.ZSTD_e_end |
|
739 | z_flush_mode = lib.ZSTD_e_end | |
694 | self._finished = True |
|
740 | self._finished = True | |
695 | else: |
|
741 | else: | |
696 | raise ZstdError("unhandled flush mode") |
|
742 | raise ZstdError("unhandled flush mode") | |
697 |
|
743 | |||
698 | assert self._out.pos == 0 |
|
744 | assert self._out.pos == 0 | |
699 |
|
745 | |||
700 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
746 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
701 | in_buffer.src = ffi.NULL |
|
747 | in_buffer.src = ffi.NULL | |
702 | in_buffer.size = 0 |
|
748 | in_buffer.size = 0 | |
703 | in_buffer.pos = 0 |
|
749 | in_buffer.pos = 0 | |
704 |
|
750 | |||
705 | chunks = [] |
|
751 | chunks = [] | |
706 |
|
752 | |||
707 | while True: |
|
753 | while True: | |
708 | zresult = lib.ZSTD_compressStream2( |
|
754 | zresult = lib.ZSTD_compressStream2( | |
709 | self._compressor._cctx, self._out, in_buffer, z_flush_mode |
|
755 | self._compressor._cctx, self._out, in_buffer, z_flush_mode | |
710 | ) |
|
756 | ) | |
711 | if lib.ZSTD_isError(zresult): |
|
757 | if lib.ZSTD_isError(zresult): | |
712 | raise ZstdError( |
|
758 | raise ZstdError( | |
713 | "error ending compression stream: %s" % _zstd_error(zresult) |
|
759 | "error ending compression stream: %s" % _zstd_error(zresult) | |
714 | ) |
|
760 | ) | |
715 |
|
761 | |||
716 | if self._out.pos: |
|
762 | if self._out.pos: | |
717 | chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) |
|
763 | chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) | |
718 | self._out.pos = 0 |
|
764 | self._out.pos = 0 | |
719 |
|
765 | |||
720 | if not zresult: |
|
766 | if not zresult: | |
721 | break |
|
767 | break | |
722 |
|
768 | |||
723 | return b"".join(chunks) |
|
769 | return b"".join(chunks) | |
724 |
|
770 | |||
725 |
|
771 | |||
726 | class ZstdCompressionChunker(object): |
|
772 | class ZstdCompressionChunker(object): | |
727 | def __init__(self, compressor, chunk_size): |
|
773 | def __init__(self, compressor, chunk_size): | |
728 | self._compressor = compressor |
|
774 | self._compressor = compressor | |
729 | self._out = ffi.new("ZSTD_outBuffer *") |
|
775 | self._out = ffi.new("ZSTD_outBuffer *") | |
730 | self._dst_buffer = ffi.new("char[]", chunk_size) |
|
776 | self._dst_buffer = ffi.new("char[]", chunk_size) | |
731 | self._out.dst = self._dst_buffer |
|
777 | self._out.dst = self._dst_buffer | |
732 | self._out.size = chunk_size |
|
778 | self._out.size = chunk_size | |
733 | self._out.pos = 0 |
|
779 | self._out.pos = 0 | |
734 |
|
780 | |||
735 | self._in = ffi.new("ZSTD_inBuffer *") |
|
781 | self._in = ffi.new("ZSTD_inBuffer *") | |
736 | self._in.src = ffi.NULL |
|
782 | self._in.src = ffi.NULL | |
737 | self._in.size = 0 |
|
783 | self._in.size = 0 | |
738 | self._in.pos = 0 |
|
784 | self._in.pos = 0 | |
739 | self._finished = False |
|
785 | self._finished = False | |
740 |
|
786 | |||
741 | def compress(self, data): |
|
787 | def compress(self, data): | |
742 | if self._finished: |
|
788 | if self._finished: | |
743 | raise ZstdError("cannot call compress() after compression finished") |
|
789 | raise ZstdError("cannot call compress() after compression finished") | |
744 |
|
790 | |||
745 | if self._in.src != ffi.NULL: |
|
791 | if self._in.src != ffi.NULL: | |
746 | raise ZstdError( |
|
792 | raise ZstdError( | |
747 | "cannot perform operation before consuming output " |
|
793 | "cannot perform operation before consuming output " | |
748 | "from previous operation" |
|
794 | "from previous operation" | |
749 | ) |
|
795 | ) | |
750 |
|
796 | |||
751 | data_buffer = ffi.from_buffer(data) |
|
797 | data_buffer = ffi.from_buffer(data) | |
752 |
|
798 | |||
753 | if not len(data_buffer): |
|
799 | if not len(data_buffer): | |
754 | return |
|
800 | return | |
755 |
|
801 | |||
756 | self._in.src = data_buffer |
|
802 | self._in.src = data_buffer | |
757 | self._in.size = len(data_buffer) |
|
803 | self._in.size = len(data_buffer) | |
758 | self._in.pos = 0 |
|
804 | self._in.pos = 0 | |
759 |
|
805 | |||
760 | while self._in.pos < self._in.size: |
|
806 | while self._in.pos < self._in.size: | |
761 | zresult = lib.ZSTD_compressStream2( |
|
807 | zresult = lib.ZSTD_compressStream2( | |
762 | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_continue |
|
808 | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_continue | |
763 | ) |
|
809 | ) | |
764 |
|
810 | |||
765 | if self._in.pos == self._in.size: |
|
811 | if self._in.pos == self._in.size: | |
766 | self._in.src = ffi.NULL |
|
812 | self._in.src = ffi.NULL | |
767 | self._in.size = 0 |
|
813 | self._in.size = 0 | |
768 | self._in.pos = 0 |
|
814 | self._in.pos = 0 | |
769 |
|
815 | |||
770 | if lib.ZSTD_isError(zresult): |
|
816 | if lib.ZSTD_isError(zresult): | |
771 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
817 | raise ZstdError( | |
|
818 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
819 | ) | |||
772 |
|
820 | |||
773 | if self._out.pos == self._out.size: |
|
821 | if self._out.pos == self._out.size: | |
774 | yield ffi.buffer(self._out.dst, self._out.pos)[:] |
|
822 | yield ffi.buffer(self._out.dst, self._out.pos)[:] | |
775 | self._out.pos = 0 |
|
823 | self._out.pos = 0 | |
776 |
|
824 | |||
777 | def flush(self): |
|
825 | def flush(self): | |
778 | if self._finished: |
|
826 | if self._finished: | |
779 | raise ZstdError("cannot call flush() after compression finished") |
|
827 | raise ZstdError("cannot call flush() after compression finished") | |
780 |
|
828 | |||
781 | if self._in.src != ffi.NULL: |
|
829 | if self._in.src != ffi.NULL: | |
782 | raise ZstdError( |
|
830 | raise ZstdError( | |
783 |
"cannot call flush() before consuming output from " |
|
831 | "cannot call flush() before consuming output from " | |
|
832 | "previous operation" | |||
784 | ) |
|
833 | ) | |
785 |
|
834 | |||
786 | while True: |
|
835 | while True: | |
787 | zresult = lib.ZSTD_compressStream2( |
|
836 | zresult = lib.ZSTD_compressStream2( | |
788 | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush |
|
837 | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush | |
789 | ) |
|
838 | ) | |
790 | if lib.ZSTD_isError(zresult): |
|
839 | if lib.ZSTD_isError(zresult): | |
791 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
840 | raise ZstdError( | |
|
841 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
842 | ) | |||
792 |
|
843 | |||
793 | if self._out.pos: |
|
844 | if self._out.pos: | |
794 | yield ffi.buffer(self._out.dst, self._out.pos)[:] |
|
845 | yield ffi.buffer(self._out.dst, self._out.pos)[:] | |
795 | self._out.pos = 0 |
|
846 | self._out.pos = 0 | |
796 |
|
847 | |||
797 | if not zresult: |
|
848 | if not zresult: | |
798 | return |
|
849 | return | |
799 |
|
850 | |||
800 | def finish(self): |
|
851 | def finish(self): | |
801 | if self._finished: |
|
852 | if self._finished: | |
802 | raise ZstdError("cannot call finish() after compression finished") |
|
853 | raise ZstdError("cannot call finish() after compression finished") | |
803 |
|
854 | |||
804 | if self._in.src != ffi.NULL: |
|
855 | if self._in.src != ffi.NULL: | |
805 | raise ZstdError( |
|
856 | raise ZstdError( | |
806 | "cannot call finish() before consuming output from " |
|
857 | "cannot call finish() before consuming output from " | |
807 | "previous operation" |
|
858 | "previous operation" | |
808 | ) |
|
859 | ) | |
809 |
|
860 | |||
810 | while True: |
|
861 | while True: | |
811 | zresult = lib.ZSTD_compressStream2( |
|
862 | zresult = lib.ZSTD_compressStream2( | |
812 | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end |
|
863 | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end | |
813 | ) |
|
864 | ) | |
814 | if lib.ZSTD_isError(zresult): |
|
865 | if lib.ZSTD_isError(zresult): | |
815 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
866 | raise ZstdError( | |
|
867 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
868 | ) | |||
816 |
|
869 | |||
817 | if self._out.pos: |
|
870 | if self._out.pos: | |
818 | yield ffi.buffer(self._out.dst, self._out.pos)[:] |
|
871 | yield ffi.buffer(self._out.dst, self._out.pos)[:] | |
819 | self._out.pos = 0 |
|
872 | self._out.pos = 0 | |
820 |
|
873 | |||
821 | if not zresult: |
|
874 | if not zresult: | |
822 | self._finished = True |
|
875 | self._finished = True | |
823 | return |
|
876 | return | |
824 |
|
877 | |||
825 |
|
878 | |||
826 | class ZstdCompressionReader(object): |
|
879 | class ZstdCompressionReader(object): | |
827 | def __init__(self, compressor, source, read_size): |
|
880 | def __init__(self, compressor, source, read_size): | |
828 | self._compressor = compressor |
|
881 | self._compressor = compressor | |
829 | self._source = source |
|
882 | self._source = source | |
830 | self._read_size = read_size |
|
883 | self._read_size = read_size | |
831 | self._entered = False |
|
884 | self._entered = False | |
832 | self._closed = False |
|
885 | self._closed = False | |
833 | self._bytes_compressed = 0 |
|
886 | self._bytes_compressed = 0 | |
834 | self._finished_input = False |
|
887 | self._finished_input = False | |
835 | self._finished_output = False |
|
888 | self._finished_output = False | |
836 |
|
889 | |||
837 | self._in_buffer = ffi.new("ZSTD_inBuffer *") |
|
890 | self._in_buffer = ffi.new("ZSTD_inBuffer *") | |
838 | # Holds a ref so backing bytes in self._in_buffer stay alive. |
|
891 | # Holds a ref so backing bytes in self._in_buffer stay alive. | |
839 | self._source_buffer = None |
|
892 | self._source_buffer = None | |
840 |
|
893 | |||
841 | def __enter__(self): |
|
894 | def __enter__(self): | |
842 | if self._entered: |
|
895 | if self._entered: | |
843 | raise ValueError("cannot __enter__ multiple times") |
|
896 | raise ValueError("cannot __enter__ multiple times") | |
844 |
|
897 | |||
845 | self._entered = True |
|
898 | self._entered = True | |
846 | return self |
|
899 | return self | |
847 |
|
900 | |||
848 | def __exit__(self, exc_type, exc_value, exc_tb): |
|
901 | def __exit__(self, exc_type, exc_value, exc_tb): | |
849 | self._entered = False |
|
902 | self._entered = False | |
850 | self._closed = True |
|
903 | self._closed = True | |
851 | self._source = None |
|
904 | self._source = None | |
852 | self._compressor = None |
|
905 | self._compressor = None | |
853 |
|
906 | |||
854 | return False |
|
907 | return False | |
855 |
|
908 | |||
856 | def readable(self): |
|
909 | def readable(self): | |
857 | return True |
|
910 | return True | |
858 |
|
911 | |||
859 | def writable(self): |
|
912 | def writable(self): | |
860 | return False |
|
913 | return False | |
861 |
|
914 | |||
862 | def seekable(self): |
|
915 | def seekable(self): | |
863 | return False |
|
916 | return False | |
864 |
|
917 | |||
865 | def readline(self): |
|
918 | def readline(self): | |
866 | raise io.UnsupportedOperation() |
|
919 | raise io.UnsupportedOperation() | |
867 |
|
920 | |||
868 | def readlines(self): |
|
921 | def readlines(self): | |
869 | raise io.UnsupportedOperation() |
|
922 | raise io.UnsupportedOperation() | |
870 |
|
923 | |||
871 | def write(self, data): |
|
924 | def write(self, data): | |
872 | raise OSError("stream is not writable") |
|
925 | raise OSError("stream is not writable") | |
873 |
|
926 | |||
874 | def writelines(self, ignored): |
|
927 | def writelines(self, ignored): | |
875 | raise OSError("stream is not writable") |
|
928 | raise OSError("stream is not writable") | |
876 |
|
929 | |||
877 | def isatty(self): |
|
930 | def isatty(self): | |
878 | return False |
|
931 | return False | |
879 |
|
932 | |||
880 | def flush(self): |
|
933 | def flush(self): | |
881 | return None |
|
934 | return None | |
882 |
|
935 | |||
883 | def close(self): |
|
936 | def close(self): | |
884 | self._closed = True |
|
937 | self._closed = True | |
885 | return None |
|
938 | return None | |
886 |
|
939 | |||
887 | @property |
|
940 | @property | |
888 | def closed(self): |
|
941 | def closed(self): | |
889 | return self._closed |
|
942 | return self._closed | |
890 |
|
943 | |||
891 | def tell(self): |
|
944 | def tell(self): | |
892 | return self._bytes_compressed |
|
945 | return self._bytes_compressed | |
893 |
|
946 | |||
894 | def readall(self): |
|
947 | def readall(self): | |
895 | chunks = [] |
|
948 | chunks = [] | |
896 |
|
949 | |||
897 | while True: |
|
950 | while True: | |
898 | chunk = self.read(1048576) |
|
951 | chunk = self.read(1048576) | |
899 | if not chunk: |
|
952 | if not chunk: | |
900 | break |
|
953 | break | |
901 |
|
954 | |||
902 | chunks.append(chunk) |
|
955 | chunks.append(chunk) | |
903 |
|
956 | |||
904 | return b"".join(chunks) |
|
957 | return b"".join(chunks) | |
905 |
|
958 | |||
906 | def __iter__(self): |
|
959 | def __iter__(self): | |
907 | raise io.UnsupportedOperation() |
|
960 | raise io.UnsupportedOperation() | |
908 |
|
961 | |||
909 | def __next__(self): |
|
962 | def __next__(self): | |
910 | raise io.UnsupportedOperation() |
|
963 | raise io.UnsupportedOperation() | |
911 |
|
964 | |||
912 | next = __next__ |
|
965 | next = __next__ | |
913 |
|
966 | |||
914 | def _read_input(self): |
|
967 | def _read_input(self): | |
915 | if self._finished_input: |
|
968 | if self._finished_input: | |
916 | return |
|
969 | return | |
917 |
|
970 | |||
918 | if hasattr(self._source, "read"): |
|
971 | if hasattr(self._source, "read"): | |
919 | data = self._source.read(self._read_size) |
|
972 | data = self._source.read(self._read_size) | |
920 |
|
973 | |||
921 | if not data: |
|
974 | if not data: | |
922 | self._finished_input = True |
|
975 | self._finished_input = True | |
923 | return |
|
976 | return | |
924 |
|
977 | |||
925 | self._source_buffer = ffi.from_buffer(data) |
|
978 | self._source_buffer = ffi.from_buffer(data) | |
926 | self._in_buffer.src = self._source_buffer |
|
979 | self._in_buffer.src = self._source_buffer | |
927 | self._in_buffer.size = len(self._source_buffer) |
|
980 | self._in_buffer.size = len(self._source_buffer) | |
928 | self._in_buffer.pos = 0 |
|
981 | self._in_buffer.pos = 0 | |
929 | else: |
|
982 | else: | |
930 | self._source_buffer = ffi.from_buffer(self._source) |
|
983 | self._source_buffer = ffi.from_buffer(self._source) | |
931 | self._in_buffer.src = self._source_buffer |
|
984 | self._in_buffer.src = self._source_buffer | |
932 | self._in_buffer.size = len(self._source_buffer) |
|
985 | self._in_buffer.size = len(self._source_buffer) | |
933 | self._in_buffer.pos = 0 |
|
986 | self._in_buffer.pos = 0 | |
934 |
|
987 | |||
935 | def _compress_into_buffer(self, out_buffer): |
|
988 | def _compress_into_buffer(self, out_buffer): | |
936 | if self._in_buffer.pos >= self._in_buffer.size: |
|
989 | if self._in_buffer.pos >= self._in_buffer.size: | |
937 | return |
|
990 | return | |
938 |
|
991 | |||
939 | old_pos = out_buffer.pos |
|
992 | old_pos = out_buffer.pos | |
940 |
|
993 | |||
941 | zresult = lib.ZSTD_compressStream2( |
|
994 | zresult = lib.ZSTD_compressStream2( | |
942 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_continue |
|
995 | self._compressor._cctx, | |
|
996 | out_buffer, | |||
|
997 | self._in_buffer, | |||
|
998 | lib.ZSTD_e_continue, | |||
943 | ) |
|
999 | ) | |
944 |
|
1000 | |||
945 | self._bytes_compressed += out_buffer.pos - old_pos |
|
1001 | self._bytes_compressed += out_buffer.pos - old_pos | |
946 |
|
1002 | |||
947 | if self._in_buffer.pos == self._in_buffer.size: |
|
1003 | if self._in_buffer.pos == self._in_buffer.size: | |
948 | self._in_buffer.src = ffi.NULL |
|
1004 | self._in_buffer.src = ffi.NULL | |
949 | self._in_buffer.pos = 0 |
|
1005 | self._in_buffer.pos = 0 | |
950 | self._in_buffer.size = 0 |
|
1006 | self._in_buffer.size = 0 | |
951 | self._source_buffer = None |
|
1007 | self._source_buffer = None | |
952 |
|
1008 | |||
953 | if not hasattr(self._source, "read"): |
|
1009 | if not hasattr(self._source, "read"): | |
954 | self._finished_input = True |
|
1010 | self._finished_input = True | |
955 |
|
1011 | |||
956 | if lib.ZSTD_isError(zresult): |
|
1012 | if lib.ZSTD_isError(zresult): | |
957 | raise ZstdError("zstd compress error: %s", _zstd_error(zresult)) |
|
1013 | raise ZstdError("zstd compress error: %s", _zstd_error(zresult)) | |
958 |
|
1014 | |||
959 | return out_buffer.pos and out_buffer.pos == out_buffer.size |
|
1015 | return out_buffer.pos and out_buffer.pos == out_buffer.size | |
960 |
|
1016 | |||
961 | def read(self, size=-1): |
|
1017 | def read(self, size=-1): | |
962 | if self._closed: |
|
1018 | if self._closed: | |
963 | raise ValueError("stream is closed") |
|
1019 | raise ValueError("stream is closed") | |
964 |
|
1020 | |||
965 | if size < -1: |
|
1021 | if size < -1: | |
966 | raise ValueError("cannot read negative amounts less than -1") |
|
1022 | raise ValueError("cannot read negative amounts less than -1") | |
967 |
|
1023 | |||
968 | if size == -1: |
|
1024 | if size == -1: | |
969 | return self.readall() |
|
1025 | return self.readall() | |
970 |
|
1026 | |||
971 | if self._finished_output or size == 0: |
|
1027 | if self._finished_output or size == 0: | |
972 | return b"" |
|
1028 | return b"" | |
973 |
|
1029 | |||
974 | # Need a dedicated ref to dest buffer otherwise it gets collected. |
|
1030 | # Need a dedicated ref to dest buffer otherwise it gets collected. | |
975 | dst_buffer = ffi.new("char[]", size) |
|
1031 | dst_buffer = ffi.new("char[]", size) | |
976 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1032 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
977 | out_buffer.dst = dst_buffer |
|
1033 | out_buffer.dst = dst_buffer | |
978 | out_buffer.size = size |
|
1034 | out_buffer.size = size | |
979 | out_buffer.pos = 0 |
|
1035 | out_buffer.pos = 0 | |
980 |
|
1036 | |||
981 | if self._compress_into_buffer(out_buffer): |
|
1037 | if self._compress_into_buffer(out_buffer): | |
982 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1038 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
983 |
|
1039 | |||
984 | while not self._finished_input: |
|
1040 | while not self._finished_input: | |
985 | self._read_input() |
|
1041 | self._read_input() | |
986 |
|
1042 | |||
987 | if self._compress_into_buffer(out_buffer): |
|
1043 | if self._compress_into_buffer(out_buffer): | |
988 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1044 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
989 |
|
1045 | |||
990 | # EOF |
|
1046 | # EOF | |
991 | old_pos = out_buffer.pos |
|
1047 | old_pos = out_buffer.pos | |
992 |
|
1048 | |||
993 | zresult = lib.ZSTD_compressStream2( |
|
1049 | zresult = lib.ZSTD_compressStream2( | |
994 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end |
|
1050 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | |
995 | ) |
|
1051 | ) | |
996 |
|
1052 | |||
997 | self._bytes_compressed += out_buffer.pos - old_pos |
|
1053 | self._bytes_compressed += out_buffer.pos - old_pos | |
998 |
|
1054 | |||
999 | if lib.ZSTD_isError(zresult): |
|
1055 | if lib.ZSTD_isError(zresult): | |
1000 | raise ZstdError("error ending compression stream: %s", _zstd_error(zresult)) |
|
1056 | raise ZstdError( | |
|
1057 | "error ending compression stream: %s", _zstd_error(zresult) | |||
|
1058 | ) | |||
1001 |
|
1059 | |||
1002 | if zresult == 0: |
|
1060 | if zresult == 0: | |
1003 | self._finished_output = True |
|
1061 | self._finished_output = True | |
1004 |
|
1062 | |||
1005 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1063 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1006 |
|
1064 | |||
1007 | def read1(self, size=-1): |
|
1065 | def read1(self, size=-1): | |
1008 | if self._closed: |
|
1066 | if self._closed: | |
1009 | raise ValueError("stream is closed") |
|
1067 | raise ValueError("stream is closed") | |
1010 |
|
1068 | |||
1011 | if size < -1: |
|
1069 | if size < -1: | |
1012 | raise ValueError("cannot read negative amounts less than -1") |
|
1070 | raise ValueError("cannot read negative amounts less than -1") | |
1013 |
|
1071 | |||
1014 | if self._finished_output or size == 0: |
|
1072 | if self._finished_output or size == 0: | |
1015 | return b"" |
|
1073 | return b"" | |
1016 |
|
1074 | |||
1017 | # -1 returns arbitrary number of bytes. |
|
1075 | # -1 returns arbitrary number of bytes. | |
1018 | if size == -1: |
|
1076 | if size == -1: | |
1019 | size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE |
|
1077 | size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |
1020 |
|
1078 | |||
1021 | dst_buffer = ffi.new("char[]", size) |
|
1079 | dst_buffer = ffi.new("char[]", size) | |
1022 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1080 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1023 | out_buffer.dst = dst_buffer |
|
1081 | out_buffer.dst = dst_buffer | |
1024 | out_buffer.size = size |
|
1082 | out_buffer.size = size | |
1025 | out_buffer.pos = 0 |
|
1083 | out_buffer.pos = 0 | |
1026 |
|
1084 | |||
1027 | # read1() dictates that we can perform at most 1 call to the |
|
1085 | # read1() dictates that we can perform at most 1 call to the | |
1028 | # underlying stream to get input. However, we can't satisfy this |
|
1086 | # underlying stream to get input. However, we can't satisfy this | |
1029 | # restriction with compression because not all input generates output. |
|
1087 | # restriction with compression because not all input generates output. | |
1030 | # It is possible to perform a block flush in order to ensure output. |
|
1088 | # It is possible to perform a block flush in order to ensure output. | |
1031 | # But this may not be desirable behavior. So we allow multiple read() |
|
1089 | # But this may not be desirable behavior. So we allow multiple read() | |
1032 | # to the underlying stream. But unlike read(), we stop once we have |
|
1090 | # to the underlying stream. But unlike read(), we stop once we have | |
1033 | # any output. |
|
1091 | # any output. | |
1034 |
|
1092 | |||
1035 | self._compress_into_buffer(out_buffer) |
|
1093 | self._compress_into_buffer(out_buffer) | |
1036 | if out_buffer.pos: |
|
1094 | if out_buffer.pos: | |
1037 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1095 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1038 |
|
1096 | |||
1039 | while not self._finished_input: |
|
1097 | while not self._finished_input: | |
1040 | self._read_input() |
|
1098 | self._read_input() | |
1041 |
|
1099 | |||
1042 | # If we've filled the output buffer, return immediately. |
|
1100 | # If we've filled the output buffer, return immediately. | |
1043 | if self._compress_into_buffer(out_buffer): |
|
1101 | if self._compress_into_buffer(out_buffer): | |
1044 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1102 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1045 |
|
1103 | |||
1046 | # If we've populated the output buffer and we're not at EOF, |
|
1104 | # If we've populated the output buffer and we're not at EOF, | |
1047 | # also return, as we've satisfied the read1() limits. |
|
1105 | # also return, as we've satisfied the read1() limits. | |
1048 | if out_buffer.pos and not self._finished_input: |
|
1106 | if out_buffer.pos and not self._finished_input: | |
1049 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1107 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1050 |
|
1108 | |||
1051 | # Else if we're at EOS and we have room left in the buffer, |
|
1109 | # Else if we're at EOS and we have room left in the buffer, | |
1052 | # fall through to below and try to add more data to the output. |
|
1110 | # fall through to below and try to add more data to the output. | |
1053 |
|
1111 | |||
1054 | # EOF. |
|
1112 | # EOF. | |
1055 | old_pos = out_buffer.pos |
|
1113 | old_pos = out_buffer.pos | |
1056 |
|
1114 | |||
1057 | zresult = lib.ZSTD_compressStream2( |
|
1115 | zresult = lib.ZSTD_compressStream2( | |
1058 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end |
|
1116 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | |
1059 | ) |
|
1117 | ) | |
1060 |
|
1118 | |||
1061 | self._bytes_compressed += out_buffer.pos - old_pos |
|
1119 | self._bytes_compressed += out_buffer.pos - old_pos | |
1062 |
|
1120 | |||
1063 | if lib.ZSTD_isError(zresult): |
|
1121 | if lib.ZSTD_isError(zresult): | |
1064 | raise ZstdError( |
|
1122 | raise ZstdError( | |
1065 | "error ending compression stream: %s" % _zstd_error(zresult) |
|
1123 | "error ending compression stream: %s" % _zstd_error(zresult) | |
1066 | ) |
|
1124 | ) | |
1067 |
|
1125 | |||
1068 | if zresult == 0: |
|
1126 | if zresult == 0: | |
1069 | self._finished_output = True |
|
1127 | self._finished_output = True | |
1070 |
|
1128 | |||
1071 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1129 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1072 |
|
1130 | |||
1073 | def readinto(self, b): |
|
1131 | def readinto(self, b): | |
1074 | if self._closed: |
|
1132 | if self._closed: | |
1075 | raise ValueError("stream is closed") |
|
1133 | raise ValueError("stream is closed") | |
1076 |
|
1134 | |||
1077 | if self._finished_output: |
|
1135 | if self._finished_output: | |
1078 | return 0 |
|
1136 | return 0 | |
1079 |
|
1137 | |||
1080 | # TODO use writable=True once we require CFFI >= 1.12. |
|
1138 | # TODO use writable=True once we require CFFI >= 1.12. | |
1081 | dest_buffer = ffi.from_buffer(b) |
|
1139 | dest_buffer = ffi.from_buffer(b) | |
1082 | ffi.memmove(b, b"", 0) |
|
1140 | ffi.memmove(b, b"", 0) | |
1083 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1141 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1084 | out_buffer.dst = dest_buffer |
|
1142 | out_buffer.dst = dest_buffer | |
1085 | out_buffer.size = len(dest_buffer) |
|
1143 | out_buffer.size = len(dest_buffer) | |
1086 | out_buffer.pos = 0 |
|
1144 | out_buffer.pos = 0 | |
1087 |
|
1145 | |||
1088 | if self._compress_into_buffer(out_buffer): |
|
1146 | if self._compress_into_buffer(out_buffer): | |
1089 | return out_buffer.pos |
|
1147 | return out_buffer.pos | |
1090 |
|
1148 | |||
1091 | while not self._finished_input: |
|
1149 | while not self._finished_input: | |
1092 | self._read_input() |
|
1150 | self._read_input() | |
1093 | if self._compress_into_buffer(out_buffer): |
|
1151 | if self._compress_into_buffer(out_buffer): | |
1094 | return out_buffer.pos |
|
1152 | return out_buffer.pos | |
1095 |
|
1153 | |||
1096 | # EOF. |
|
1154 | # EOF. | |
1097 | old_pos = out_buffer.pos |
|
1155 | old_pos = out_buffer.pos | |
1098 | zresult = lib.ZSTD_compressStream2( |
|
1156 | zresult = lib.ZSTD_compressStream2( | |
1099 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end |
|
1157 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | |
1100 | ) |
|
1158 | ) | |
1101 |
|
1159 | |||
1102 | self._bytes_compressed += out_buffer.pos - old_pos |
|
1160 | self._bytes_compressed += out_buffer.pos - old_pos | |
1103 |
|
1161 | |||
1104 | if lib.ZSTD_isError(zresult): |
|
1162 | if lib.ZSTD_isError(zresult): | |
1105 | raise ZstdError("error ending compression stream: %s", _zstd_error(zresult)) |
|
1163 | raise ZstdError( | |
|
1164 | "error ending compression stream: %s", _zstd_error(zresult) | |||
|
1165 | ) | |||
1106 |
|
1166 | |||
1107 | if zresult == 0: |
|
1167 | if zresult == 0: | |
1108 | self._finished_output = True |
|
1168 | self._finished_output = True | |
1109 |
|
1169 | |||
1110 | return out_buffer.pos |
|
1170 | return out_buffer.pos | |
1111 |
|
1171 | |||
1112 | def readinto1(self, b): |
|
1172 | def readinto1(self, b): | |
1113 | if self._closed: |
|
1173 | if self._closed: | |
1114 | raise ValueError("stream is closed") |
|
1174 | raise ValueError("stream is closed") | |
1115 |
|
1175 | |||
1116 | if self._finished_output: |
|
1176 | if self._finished_output: | |
1117 | return 0 |
|
1177 | return 0 | |
1118 |
|
1178 | |||
1119 | # TODO use writable=True once we require CFFI >= 1.12. |
|
1179 | # TODO use writable=True once we require CFFI >= 1.12. | |
1120 | dest_buffer = ffi.from_buffer(b) |
|
1180 | dest_buffer = ffi.from_buffer(b) | |
1121 | ffi.memmove(b, b"", 0) |
|
1181 | ffi.memmove(b, b"", 0) | |
1122 |
|
1182 | |||
1123 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1183 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1124 | out_buffer.dst = dest_buffer |
|
1184 | out_buffer.dst = dest_buffer | |
1125 | out_buffer.size = len(dest_buffer) |
|
1185 | out_buffer.size = len(dest_buffer) | |
1126 | out_buffer.pos = 0 |
|
1186 | out_buffer.pos = 0 | |
1127 |
|
1187 | |||
1128 | self._compress_into_buffer(out_buffer) |
|
1188 | self._compress_into_buffer(out_buffer) | |
1129 | if out_buffer.pos: |
|
1189 | if out_buffer.pos: | |
1130 | return out_buffer.pos |
|
1190 | return out_buffer.pos | |
1131 |
|
1191 | |||
1132 | while not self._finished_input: |
|
1192 | while not self._finished_input: | |
1133 | self._read_input() |
|
1193 | self._read_input() | |
1134 |
|
1194 | |||
1135 | if self._compress_into_buffer(out_buffer): |
|
1195 | if self._compress_into_buffer(out_buffer): | |
1136 | return out_buffer.pos |
|
1196 | return out_buffer.pos | |
1137 |
|
1197 | |||
1138 | if out_buffer.pos and not self._finished_input: |
|
1198 | if out_buffer.pos and not self._finished_input: | |
1139 | return out_buffer.pos |
|
1199 | return out_buffer.pos | |
1140 |
|
1200 | |||
1141 | # EOF. |
|
1201 | # EOF. | |
1142 | old_pos = out_buffer.pos |
|
1202 | old_pos = out_buffer.pos | |
1143 |
|
1203 | |||
1144 | zresult = lib.ZSTD_compressStream2( |
|
1204 | zresult = lib.ZSTD_compressStream2( | |
1145 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end |
|
1205 | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | |
1146 | ) |
|
1206 | ) | |
1147 |
|
1207 | |||
1148 | self._bytes_compressed += out_buffer.pos - old_pos |
|
1208 | self._bytes_compressed += out_buffer.pos - old_pos | |
1149 |
|
1209 | |||
1150 | if lib.ZSTD_isError(zresult): |
|
1210 | if lib.ZSTD_isError(zresult): | |
1151 | raise ZstdError( |
|
1211 | raise ZstdError( | |
1152 | "error ending compression stream: %s" % _zstd_error(zresult) |
|
1212 | "error ending compression stream: %s" % _zstd_error(zresult) | |
1153 | ) |
|
1213 | ) | |
1154 |
|
1214 | |||
1155 | if zresult == 0: |
|
1215 | if zresult == 0: | |
1156 | self._finished_output = True |
|
1216 | self._finished_output = True | |
1157 |
|
1217 | |||
1158 | return out_buffer.pos |
|
1218 | return out_buffer.pos | |
1159 |
|
1219 | |||
1160 |
|
1220 | |||
1161 | class ZstdCompressor(object): |
|
1221 | class ZstdCompressor(object): | |
1162 | def __init__( |
|
1222 | def __init__( | |
1163 | self, |
|
1223 | self, | |
1164 | level=3, |
|
1224 | level=3, | |
1165 | dict_data=None, |
|
1225 | dict_data=None, | |
1166 | compression_params=None, |
|
1226 | compression_params=None, | |
1167 | write_checksum=None, |
|
1227 | write_checksum=None, | |
1168 | write_content_size=None, |
|
1228 | write_content_size=None, | |
1169 | write_dict_id=None, |
|
1229 | write_dict_id=None, | |
1170 | threads=0, |
|
1230 | threads=0, | |
1171 | ): |
|
1231 | ): | |
1172 | if level > lib.ZSTD_maxCLevel(): |
|
1232 | if level > lib.ZSTD_maxCLevel(): | |
1173 | raise ValueError("level must be less than %d" % lib.ZSTD_maxCLevel()) |
|
1233 | raise ValueError( | |
|
1234 | "level must be less than %d" % lib.ZSTD_maxCLevel() | |||
|
1235 | ) | |||
1174 |
|
1236 | |||
1175 | if threads < 0: |
|
1237 | if threads < 0: | |
1176 | threads = _cpu_count() |
|
1238 | threads = _cpu_count() | |
1177 |
|
1239 | |||
1178 | if compression_params and write_checksum is not None: |
|
1240 | if compression_params and write_checksum is not None: | |
1179 | raise ValueError("cannot define compression_params and " "write_checksum") |
|
1241 | raise ValueError( | |
|
1242 | "cannot define compression_params and " "write_checksum" | |||
|
1243 | ) | |||
1180 |
|
1244 | |||
1181 | if compression_params and write_content_size is not None: |
|
1245 | if compression_params and write_content_size is not None: | |
1182 | raise ValueError( |
|
1246 | raise ValueError( | |
1183 | "cannot define compression_params and " "write_content_size" |
|
1247 | "cannot define compression_params and " "write_content_size" | |
1184 | ) |
|
1248 | ) | |
1185 |
|
1249 | |||
1186 | if compression_params and write_dict_id is not None: |
|
1250 | if compression_params and write_dict_id is not None: | |
1187 | raise ValueError("cannot define compression_params and " "write_dict_id") |
|
1251 | raise ValueError( | |
|
1252 | "cannot define compression_params and " "write_dict_id" | |||
|
1253 | ) | |||
1188 |
|
1254 | |||
1189 | if compression_params and threads: |
|
1255 | if compression_params and threads: | |
1190 | raise ValueError("cannot define compression_params and threads") |
|
1256 | raise ValueError("cannot define compression_params and threads") | |
1191 |
|
1257 | |||
1192 | if compression_params: |
|
1258 | if compression_params: | |
1193 | self._params = _make_cctx_params(compression_params) |
|
1259 | self._params = _make_cctx_params(compression_params) | |
1194 | else: |
|
1260 | else: | |
1195 | if write_dict_id is None: |
|
1261 | if write_dict_id is None: | |
1196 | write_dict_id = True |
|
1262 | write_dict_id = True | |
1197 |
|
1263 | |||
1198 | params = lib.ZSTD_createCCtxParams() |
|
1264 | params = lib.ZSTD_createCCtxParams() | |
1199 | if params == ffi.NULL: |
|
1265 | if params == ffi.NULL: | |
1200 | raise MemoryError() |
|
1266 | raise MemoryError() | |
1201 |
|
1267 | |||
1202 | self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) |
|
1268 | self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) | |
1203 |
|
1269 | |||
1204 |
_set_compression_parameter( |
|
1270 | _set_compression_parameter( | |
|
1271 | self._params, lib.ZSTD_c_compressionLevel, level | |||
|
1272 | ) | |||
1205 |
|
1273 | |||
1206 | _set_compression_parameter( |
|
1274 | _set_compression_parameter( | |
1207 | self._params, |
|
1275 | self._params, | |
1208 | lib.ZSTD_c_contentSizeFlag, |
|
1276 | lib.ZSTD_c_contentSizeFlag, | |
1209 | write_content_size if write_content_size is not None else 1, |
|
1277 | write_content_size if write_content_size is not None else 1, | |
1210 | ) |
|
1278 | ) | |
1211 |
|
1279 | |||
1212 | _set_compression_parameter( |
|
1280 | _set_compression_parameter( | |
1213 | self._params, lib.ZSTD_c_checksumFlag, 1 if write_checksum else 0 |
|
1281 | self._params, | |
|
1282 | lib.ZSTD_c_checksumFlag, | |||
|
1283 | 1 if write_checksum else 0, | |||
1214 | ) |
|
1284 | ) | |
1215 |
|
1285 | |||
1216 | _set_compression_parameter( |
|
1286 | _set_compression_parameter( | |
1217 | self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0 |
|
1287 | self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0 | |
1218 | ) |
|
1288 | ) | |
1219 |
|
1289 | |||
1220 | if threads: |
|
1290 | if threads: | |
1221 |
_set_compression_parameter( |
|
1291 | _set_compression_parameter( | |
|
1292 | self._params, lib.ZSTD_c_nbWorkers, threads | |||
|
1293 | ) | |||
1222 |
|
1294 | |||
1223 | cctx = lib.ZSTD_createCCtx() |
|
1295 | cctx = lib.ZSTD_createCCtx() | |
1224 | if cctx == ffi.NULL: |
|
1296 | if cctx == ffi.NULL: | |
1225 | raise MemoryError() |
|
1297 | raise MemoryError() | |
1226 |
|
1298 | |||
1227 | self._cctx = cctx |
|
1299 | self._cctx = cctx | |
1228 | self._dict_data = dict_data |
|
1300 | self._dict_data = dict_data | |
1229 |
|
1301 | |||
1230 | # We defer setting up garbage collection until after calling |
|
1302 | # We defer setting up garbage collection until after calling | |
1231 | # _setup_cctx() to ensure the memory size estimate is more accurate. |
|
1303 | # _setup_cctx() to ensure the memory size estimate is more accurate. | |
1232 | try: |
|
1304 | try: | |
1233 | self._setup_cctx() |
|
1305 | self._setup_cctx() | |
1234 | finally: |
|
1306 | finally: | |
1235 | self._cctx = ffi.gc( |
|
1307 | self._cctx = ffi.gc( | |
1236 | cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx) |
|
1308 | cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx) | |
1237 | ) |
|
1309 | ) | |
1238 |
|
1310 | |||
1239 | def _setup_cctx(self): |
|
1311 | def _setup_cctx(self): | |
1240 |
zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams( |
|
1312 | zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams( | |
|
1313 | self._cctx, self._params | |||
|
1314 | ) | |||
1241 | if lib.ZSTD_isError(zresult): |
|
1315 | if lib.ZSTD_isError(zresult): | |
1242 | raise ZstdError( |
|
1316 | raise ZstdError( | |
1243 |
"could not set compression parameters: %s" |
|
1317 | "could not set compression parameters: %s" | |
|
1318 | % _zstd_error(zresult) | |||
1244 | ) |
|
1319 | ) | |
1245 |
|
1320 | |||
1246 | dict_data = self._dict_data |
|
1321 | dict_data = self._dict_data | |
1247 |
|
1322 | |||
1248 | if dict_data: |
|
1323 | if dict_data: | |
1249 | if dict_data._cdict: |
|
1324 | if dict_data._cdict: | |
1250 | zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict) |
|
1325 | zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict) | |
1251 | else: |
|
1326 | else: | |
1252 | zresult = lib.ZSTD_CCtx_loadDictionary_advanced( |
|
1327 | zresult = lib.ZSTD_CCtx_loadDictionary_advanced( | |
1253 | self._cctx, |
|
1328 | self._cctx, | |
1254 | dict_data.as_bytes(), |
|
1329 | dict_data.as_bytes(), | |
1255 | len(dict_data), |
|
1330 | len(dict_data), | |
1256 | lib.ZSTD_dlm_byRef, |
|
1331 | lib.ZSTD_dlm_byRef, | |
1257 | dict_data._dict_type, |
|
1332 | dict_data._dict_type, | |
1258 | ) |
|
1333 | ) | |
1259 |
|
1334 | |||
1260 | if lib.ZSTD_isError(zresult): |
|
1335 | if lib.ZSTD_isError(zresult): | |
1261 | raise ZstdError( |
|
1336 | raise ZstdError( | |
1262 |
"could not load compression dictionary: %s" |
|
1337 | "could not load compression dictionary: %s" | |
|
1338 | % _zstd_error(zresult) | |||
1263 | ) |
|
1339 | ) | |
1264 |
|
1340 | |||
1265 | def memory_size(self): |
|
1341 | def memory_size(self): | |
1266 | return lib.ZSTD_sizeof_CCtx(self._cctx) |
|
1342 | return lib.ZSTD_sizeof_CCtx(self._cctx) | |
1267 |
|
1343 | |||
1268 | def compress(self, data): |
|
1344 | def compress(self, data): | |
1269 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1345 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1270 |
|
1346 | |||
1271 | data_buffer = ffi.from_buffer(data) |
|
1347 | data_buffer = ffi.from_buffer(data) | |
1272 |
|
1348 | |||
1273 | dest_size = lib.ZSTD_compressBound(len(data_buffer)) |
|
1349 | dest_size = lib.ZSTD_compressBound(len(data_buffer)) | |
1274 | out = new_nonzero("char[]", dest_size) |
|
1350 | out = new_nonzero("char[]", dest_size) | |
1275 |
|
1351 | |||
1276 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) |
|
1352 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) | |
1277 | if lib.ZSTD_isError(zresult): |
|
1353 | if lib.ZSTD_isError(zresult): | |
1278 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
1354 | raise ZstdError( | |
|
1355 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
1356 | ) | |||
1279 |
|
1357 | |||
1280 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1358 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1281 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
1359 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
1282 |
|
1360 | |||
1283 | out_buffer.dst = out |
|
1361 | out_buffer.dst = out | |
1284 | out_buffer.size = dest_size |
|
1362 | out_buffer.size = dest_size | |
1285 | out_buffer.pos = 0 |
|
1363 | out_buffer.pos = 0 | |
1286 |
|
1364 | |||
1287 | in_buffer.src = data_buffer |
|
1365 | in_buffer.src = data_buffer | |
1288 | in_buffer.size = len(data_buffer) |
|
1366 | in_buffer.size = len(data_buffer) | |
1289 | in_buffer.pos = 0 |
|
1367 | in_buffer.pos = 0 | |
1290 |
|
1368 | |||
1291 | zresult = lib.ZSTD_compressStream2( |
|
1369 | zresult = lib.ZSTD_compressStream2( | |
1292 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end |
|
1370 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end | |
1293 | ) |
|
1371 | ) | |
1294 |
|
1372 | |||
1295 | if lib.ZSTD_isError(zresult): |
|
1373 | if lib.ZSTD_isError(zresult): | |
1296 | raise ZstdError("cannot compress: %s" % _zstd_error(zresult)) |
|
1374 | raise ZstdError("cannot compress: %s" % _zstd_error(zresult)) | |
1297 | elif zresult: |
|
1375 | elif zresult: | |
1298 | raise ZstdError("unexpected partial frame flush") |
|
1376 | raise ZstdError("unexpected partial frame flush") | |
1299 |
|
1377 | |||
1300 | return ffi.buffer(out, out_buffer.pos)[:] |
|
1378 | return ffi.buffer(out, out_buffer.pos)[:] | |
1301 |
|
1379 | |||
1302 | def compressobj(self, size=-1): |
|
1380 | def compressobj(self, size=-1): | |
1303 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1381 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1304 |
|
1382 | |||
1305 | if size < 0: |
|
1383 | if size < 0: | |
1306 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
1384 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
1307 |
|
1385 | |||
1308 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) |
|
1386 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | |
1309 | if lib.ZSTD_isError(zresult): |
|
1387 | if lib.ZSTD_isError(zresult): | |
1310 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
1388 | raise ZstdError( | |
|
1389 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
1390 | ) | |||
1311 |
|
1391 | |||
1312 | cobj = ZstdCompressionObj() |
|
1392 | cobj = ZstdCompressionObj() | |
1313 | cobj._out = ffi.new("ZSTD_outBuffer *") |
|
1393 | cobj._out = ffi.new("ZSTD_outBuffer *") | |
1314 |
cobj._dst_buffer = ffi.new( |
|
1394 | cobj._dst_buffer = ffi.new( | |
|
1395 | "char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||
|
1396 | ) | |||
1315 | cobj._out.dst = cobj._dst_buffer |
|
1397 | cobj._out.dst = cobj._dst_buffer | |
1316 | cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE |
|
1398 | cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |
1317 | cobj._out.pos = 0 |
|
1399 | cobj._out.pos = 0 | |
1318 | cobj._compressor = self |
|
1400 | cobj._compressor = self | |
1319 | cobj._finished = False |
|
1401 | cobj._finished = False | |
1320 |
|
1402 | |||
1321 | return cobj |
|
1403 | return cobj | |
1322 |
|
1404 | |||
1323 | def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): |
|
1405 | def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): | |
1324 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1406 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1325 |
|
1407 | |||
1326 | if size < 0: |
|
1408 | if size < 0: | |
1327 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
1409 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
1328 |
|
1410 | |||
1329 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) |
|
1411 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | |
1330 | if lib.ZSTD_isError(zresult): |
|
1412 | if lib.ZSTD_isError(zresult): | |
1331 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
1413 | raise ZstdError( | |
|
1414 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
1415 | ) | |||
1332 |
|
1416 | |||
1333 | return ZstdCompressionChunker(self, chunk_size=chunk_size) |
|
1417 | return ZstdCompressionChunker(self, chunk_size=chunk_size) | |
1334 |
|
1418 | |||
1335 | def copy_stream( |
|
1419 | def copy_stream( | |
1336 | self, |
|
1420 | self, | |
1337 | ifh, |
|
1421 | ifh, | |
1338 | ofh, |
|
1422 | ofh, | |
1339 | size=-1, |
|
1423 | size=-1, | |
1340 | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, |
|
1424 | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | |
1341 | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, |
|
1425 | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |
1342 | ): |
|
1426 | ): | |
1343 |
|
1427 | |||
1344 | if not hasattr(ifh, "read"): |
|
1428 | if not hasattr(ifh, "read"): | |
1345 | raise ValueError("first argument must have a read() method") |
|
1429 | raise ValueError("first argument must have a read() method") | |
1346 | if not hasattr(ofh, "write"): |
|
1430 | if not hasattr(ofh, "write"): | |
1347 | raise ValueError("second argument must have a write() method") |
|
1431 | raise ValueError("second argument must have a write() method") | |
1348 |
|
1432 | |||
1349 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1433 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1350 |
|
1434 | |||
1351 | if size < 0: |
|
1435 | if size < 0: | |
1352 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
1436 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
1353 |
|
1437 | |||
1354 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) |
|
1438 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | |
1355 | if lib.ZSTD_isError(zresult): |
|
1439 | if lib.ZSTD_isError(zresult): | |
1356 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
1440 | raise ZstdError( | |
|
1441 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
1442 | ) | |||
1357 |
|
1443 | |||
1358 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
1444 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
1359 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1445 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1360 |
|
1446 | |||
1361 | dst_buffer = ffi.new("char[]", write_size) |
|
1447 | dst_buffer = ffi.new("char[]", write_size) | |
1362 | out_buffer.dst = dst_buffer |
|
1448 | out_buffer.dst = dst_buffer | |
1363 | out_buffer.size = write_size |
|
1449 | out_buffer.size = write_size | |
1364 | out_buffer.pos = 0 |
|
1450 | out_buffer.pos = 0 | |
1365 |
|
1451 | |||
1366 | total_read, total_write = 0, 0 |
|
1452 | total_read, total_write = 0, 0 | |
1367 |
|
1453 | |||
1368 | while True: |
|
1454 | while True: | |
1369 | data = ifh.read(read_size) |
|
1455 | data = ifh.read(read_size) | |
1370 | if not data: |
|
1456 | if not data: | |
1371 | break |
|
1457 | break | |
1372 |
|
1458 | |||
1373 | data_buffer = ffi.from_buffer(data) |
|
1459 | data_buffer = ffi.from_buffer(data) | |
1374 | total_read += len(data_buffer) |
|
1460 | total_read += len(data_buffer) | |
1375 | in_buffer.src = data_buffer |
|
1461 | in_buffer.src = data_buffer | |
1376 | in_buffer.size = len(data_buffer) |
|
1462 | in_buffer.size = len(data_buffer) | |
1377 | in_buffer.pos = 0 |
|
1463 | in_buffer.pos = 0 | |
1378 |
|
1464 | |||
1379 | while in_buffer.pos < in_buffer.size: |
|
1465 | while in_buffer.pos < in_buffer.size: | |
1380 | zresult = lib.ZSTD_compressStream2( |
|
1466 | zresult = lib.ZSTD_compressStream2( | |
1381 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue |
|
1467 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | |
1382 | ) |
|
1468 | ) | |
1383 | if lib.ZSTD_isError(zresult): |
|
1469 | if lib.ZSTD_isError(zresult): | |
1384 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
1470 | raise ZstdError( | |
|
1471 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
1472 | ) | |||
1385 |
|
1473 | |||
1386 | if out_buffer.pos: |
|
1474 | if out_buffer.pos: | |
1387 | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) |
|
1475 | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | |
1388 | total_write += out_buffer.pos |
|
1476 | total_write += out_buffer.pos | |
1389 | out_buffer.pos = 0 |
|
1477 | out_buffer.pos = 0 | |
1390 |
|
1478 | |||
1391 | # We've finished reading. Flush the compressor. |
|
1479 | # We've finished reading. Flush the compressor. | |
1392 | while True: |
|
1480 | while True: | |
1393 | zresult = lib.ZSTD_compressStream2( |
|
1481 | zresult = lib.ZSTD_compressStream2( | |
1394 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end |
|
1482 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end | |
1395 | ) |
|
1483 | ) | |
1396 | if lib.ZSTD_isError(zresult): |
|
1484 | if lib.ZSTD_isError(zresult): | |
1397 | raise ZstdError( |
|
1485 | raise ZstdError( | |
1398 | "error ending compression stream: %s" % _zstd_error(zresult) |
|
1486 | "error ending compression stream: %s" % _zstd_error(zresult) | |
1399 | ) |
|
1487 | ) | |
1400 |
|
1488 | |||
1401 | if out_buffer.pos: |
|
1489 | if out_buffer.pos: | |
1402 | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) |
|
1490 | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | |
1403 | total_write += out_buffer.pos |
|
1491 | total_write += out_buffer.pos | |
1404 | out_buffer.pos = 0 |
|
1492 | out_buffer.pos = 0 | |
1405 |
|
1493 | |||
1406 | if zresult == 0: |
|
1494 | if zresult == 0: | |
1407 | break |
|
1495 | break | |
1408 |
|
1496 | |||
1409 | return total_read, total_write |
|
1497 | return total_read, total_write | |
1410 |
|
1498 | |||
1411 | def stream_reader( |
|
1499 | def stream_reader( | |
1412 | self, source, size=-1, read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE |
|
1500 | self, source, size=-1, read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE | |
1413 | ): |
|
1501 | ): | |
1414 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1502 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1415 |
|
1503 | |||
1416 | try: |
|
1504 | try: | |
1417 | size = len(source) |
|
1505 | size = len(source) | |
1418 | except Exception: |
|
1506 | except Exception: | |
1419 | pass |
|
1507 | pass | |
1420 |
|
1508 | |||
1421 | if size < 0: |
|
1509 | if size < 0: | |
1422 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
1510 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
1423 |
|
1511 | |||
1424 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) |
|
1512 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | |
1425 | if lib.ZSTD_isError(zresult): |
|
1513 | if lib.ZSTD_isError(zresult): | |
1426 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
1514 | raise ZstdError( | |
|
1515 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
1516 | ) | |||
1427 |
|
1517 | |||
1428 | return ZstdCompressionReader(self, source, read_size) |
|
1518 | return ZstdCompressionReader(self, source, read_size) | |
1429 |
|
1519 | |||
1430 | def stream_writer( |
|
1520 | def stream_writer( | |
1431 | self, |
|
1521 | self, | |
1432 | writer, |
|
1522 | writer, | |
1433 | size=-1, |
|
1523 | size=-1, | |
1434 | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, |
|
1524 | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |
1435 | write_return_read=False, |
|
1525 | write_return_read=False, | |
1436 | ): |
|
1526 | ): | |
1437 |
|
1527 | |||
1438 | if not hasattr(writer, "write"): |
|
1528 | if not hasattr(writer, "write"): | |
1439 | raise ValueError("must pass an object with a write() method") |
|
1529 | raise ValueError("must pass an object with a write() method") | |
1440 |
|
1530 | |||
1441 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1531 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1442 |
|
1532 | |||
1443 | if size < 0: |
|
1533 | if size < 0: | |
1444 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
1534 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
1445 |
|
1535 | |||
1446 | return ZstdCompressionWriter(self, writer, size, write_size, write_return_read) |
|
1536 | return ZstdCompressionWriter( | |
|
1537 | self, writer, size, write_size, write_return_read | |||
|
1538 | ) | |||
1447 |
|
1539 | |||
1448 | write_to = stream_writer |
|
1540 | write_to = stream_writer | |
1449 |
|
1541 | |||
1450 | def read_to_iter( |
|
1542 | def read_to_iter( | |
1451 | self, |
|
1543 | self, | |
1452 | reader, |
|
1544 | reader, | |
1453 | size=-1, |
|
1545 | size=-1, | |
1454 | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, |
|
1546 | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | |
1455 | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, |
|
1547 | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |
1456 | ): |
|
1548 | ): | |
1457 | if hasattr(reader, "read"): |
|
1549 | if hasattr(reader, "read"): | |
1458 | have_read = True |
|
1550 | have_read = True | |
1459 | elif hasattr(reader, "__getitem__"): |
|
1551 | elif hasattr(reader, "__getitem__"): | |
1460 | have_read = False |
|
1552 | have_read = False | |
1461 | buffer_offset = 0 |
|
1553 | buffer_offset = 0 | |
1462 | size = len(reader) |
|
1554 | size = len(reader) | |
1463 | else: |
|
1555 | else: | |
1464 | raise ValueError( |
|
1556 | raise ValueError( | |
1465 | "must pass an object with a read() method or " |
|
1557 | "must pass an object with a read() method or " | |
1466 | "conforms to buffer protocol" |
|
1558 | "conforms to buffer protocol" | |
1467 | ) |
|
1559 | ) | |
1468 |
|
1560 | |||
1469 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) |
|
1561 | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | |
1470 |
|
1562 | |||
1471 | if size < 0: |
|
1563 | if size < 0: | |
1472 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN |
|
1564 | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | |
1473 |
|
1565 | |||
1474 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) |
|
1566 | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | |
1475 | if lib.ZSTD_isError(zresult): |
|
1567 | if lib.ZSTD_isError(zresult): | |
1476 | raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) |
|
1568 | raise ZstdError( | |
|
1569 | "error setting source size: %s" % _zstd_error(zresult) | |||
|
1570 | ) | |||
1477 |
|
1571 | |||
1478 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
1572 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
1479 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1573 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1480 |
|
1574 | |||
1481 | in_buffer.src = ffi.NULL |
|
1575 | in_buffer.src = ffi.NULL | |
1482 | in_buffer.size = 0 |
|
1576 | in_buffer.size = 0 | |
1483 | in_buffer.pos = 0 |
|
1577 | in_buffer.pos = 0 | |
1484 |
|
1578 | |||
1485 | dst_buffer = ffi.new("char[]", write_size) |
|
1579 | dst_buffer = ffi.new("char[]", write_size) | |
1486 | out_buffer.dst = dst_buffer |
|
1580 | out_buffer.dst = dst_buffer | |
1487 | out_buffer.size = write_size |
|
1581 | out_buffer.size = write_size | |
1488 | out_buffer.pos = 0 |
|
1582 | out_buffer.pos = 0 | |
1489 |
|
1583 | |||
1490 | while True: |
|
1584 | while True: | |
1491 | # We should never have output data sitting around after a previous |
|
1585 | # We should never have output data sitting around after a previous | |
1492 | # iteration. |
|
1586 | # iteration. | |
1493 | assert out_buffer.pos == 0 |
|
1587 | assert out_buffer.pos == 0 | |
1494 |
|
1588 | |||
1495 | # Collect input data. |
|
1589 | # Collect input data. | |
1496 | if have_read: |
|
1590 | if have_read: | |
1497 | read_result = reader.read(read_size) |
|
1591 | read_result = reader.read(read_size) | |
1498 | else: |
|
1592 | else: | |
1499 | remaining = len(reader) - buffer_offset |
|
1593 | remaining = len(reader) - buffer_offset | |
1500 | slice_size = min(remaining, read_size) |
|
1594 | slice_size = min(remaining, read_size) | |
1501 | read_result = reader[buffer_offset : buffer_offset + slice_size] |
|
1595 | read_result = reader[buffer_offset : buffer_offset + slice_size] | |
1502 | buffer_offset += slice_size |
|
1596 | buffer_offset += slice_size | |
1503 |
|
1597 | |||
1504 | # No new input data. Break out of the read loop. |
|
1598 | # No new input data. Break out of the read loop. | |
1505 | if not read_result: |
|
1599 | if not read_result: | |
1506 | break |
|
1600 | break | |
1507 |
|
1601 | |||
1508 | # Feed all read data into the compressor and emit output until |
|
1602 | # Feed all read data into the compressor and emit output until | |
1509 | # exhausted. |
|
1603 | # exhausted. | |
1510 | read_buffer = ffi.from_buffer(read_result) |
|
1604 | read_buffer = ffi.from_buffer(read_result) | |
1511 | in_buffer.src = read_buffer |
|
1605 | in_buffer.src = read_buffer | |
1512 | in_buffer.size = len(read_buffer) |
|
1606 | in_buffer.size = len(read_buffer) | |
1513 | in_buffer.pos = 0 |
|
1607 | in_buffer.pos = 0 | |
1514 |
|
1608 | |||
1515 | while in_buffer.pos < in_buffer.size: |
|
1609 | while in_buffer.pos < in_buffer.size: | |
1516 | zresult = lib.ZSTD_compressStream2( |
|
1610 | zresult = lib.ZSTD_compressStream2( | |
1517 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue |
|
1611 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | |
1518 | ) |
|
1612 | ) | |
1519 | if lib.ZSTD_isError(zresult): |
|
1613 | if lib.ZSTD_isError(zresult): | |
1520 | raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) |
|
1614 | raise ZstdError( | |
|
1615 | "zstd compress error: %s" % _zstd_error(zresult) | |||
|
1616 | ) | |||
1521 |
|
1617 | |||
1522 | if out_buffer.pos: |
|
1618 | if out_buffer.pos: | |
1523 | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1619 | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1524 | out_buffer.pos = 0 |
|
1620 | out_buffer.pos = 0 | |
1525 | yield data |
|
1621 | yield data | |
1526 |
|
1622 | |||
1527 | assert out_buffer.pos == 0 |
|
1623 | assert out_buffer.pos == 0 | |
1528 |
|
1624 | |||
1529 | # And repeat the loop to collect more data. |
|
1625 | # And repeat the loop to collect more data. | |
1530 | continue |
|
1626 | continue | |
1531 |
|
1627 | |||
1532 | # If we get here, input is exhausted. End the stream and emit what |
|
1628 | # If we get here, input is exhausted. End the stream and emit what | |
1533 | # remains. |
|
1629 | # remains. | |
1534 | while True: |
|
1630 | while True: | |
1535 | assert out_buffer.pos == 0 |
|
1631 | assert out_buffer.pos == 0 | |
1536 | zresult = lib.ZSTD_compressStream2( |
|
1632 | zresult = lib.ZSTD_compressStream2( | |
1537 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end |
|
1633 | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end | |
1538 | ) |
|
1634 | ) | |
1539 | if lib.ZSTD_isError(zresult): |
|
1635 | if lib.ZSTD_isError(zresult): | |
1540 | raise ZstdError( |
|
1636 | raise ZstdError( | |
1541 | "error ending compression stream: %s" % _zstd_error(zresult) |
|
1637 | "error ending compression stream: %s" % _zstd_error(zresult) | |
1542 | ) |
|
1638 | ) | |
1543 |
|
1639 | |||
1544 | if out_buffer.pos: |
|
1640 | if out_buffer.pos: | |
1545 | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
1641 | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
1546 | out_buffer.pos = 0 |
|
1642 | out_buffer.pos = 0 | |
1547 | yield data |
|
1643 | yield data | |
1548 |
|
1644 | |||
1549 | if zresult == 0: |
|
1645 | if zresult == 0: | |
1550 | break |
|
1646 | break | |
1551 |
|
1647 | |||
1552 | read_from = read_to_iter |
|
1648 | read_from = read_to_iter | |
1553 |
|
1649 | |||
1554 | def frame_progression(self): |
|
1650 | def frame_progression(self): | |
1555 | progression = lib.ZSTD_getFrameProgression(self._cctx) |
|
1651 | progression = lib.ZSTD_getFrameProgression(self._cctx) | |
1556 |
|
1652 | |||
1557 | return progression.ingested, progression.consumed, progression.produced |
|
1653 | return progression.ingested, progression.consumed, progression.produced | |
1558 |
|
1654 | |||
1559 |
|
1655 | |||
1560 | class FrameParameters(object): |
|
1656 | class FrameParameters(object): | |
1561 | def __init__(self, fparams): |
|
1657 | def __init__(self, fparams): | |
1562 | self.content_size = fparams.frameContentSize |
|
1658 | self.content_size = fparams.frameContentSize | |
1563 | self.window_size = fparams.windowSize |
|
1659 | self.window_size = fparams.windowSize | |
1564 | self.dict_id = fparams.dictID |
|
1660 | self.dict_id = fparams.dictID | |
1565 | self.has_checksum = bool(fparams.checksumFlag) |
|
1661 | self.has_checksum = bool(fparams.checksumFlag) | |
1566 |
|
1662 | |||
1567 |
|
1663 | |||
1568 | def frame_content_size(data): |
|
1664 | def frame_content_size(data): | |
1569 | data_buffer = ffi.from_buffer(data) |
|
1665 | data_buffer = ffi.from_buffer(data) | |
1570 |
|
1666 | |||
1571 | size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer)) |
|
1667 | size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer)) | |
1572 |
|
1668 | |||
1573 | if size == lib.ZSTD_CONTENTSIZE_ERROR: |
|
1669 | if size == lib.ZSTD_CONTENTSIZE_ERROR: | |
1574 | raise ZstdError("error when determining content size") |
|
1670 | raise ZstdError("error when determining content size") | |
1575 | elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN: |
|
1671 | elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN: | |
1576 | return -1 |
|
1672 | return -1 | |
1577 | else: |
|
1673 | else: | |
1578 | return size |
|
1674 | return size | |
1579 |
|
1675 | |||
1580 |
|
1676 | |||
1581 | def frame_header_size(data): |
|
1677 | def frame_header_size(data): | |
1582 | data_buffer = ffi.from_buffer(data) |
|
1678 | data_buffer = ffi.from_buffer(data) | |
1583 |
|
1679 | |||
1584 | zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer)) |
|
1680 | zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer)) | |
1585 | if lib.ZSTD_isError(zresult): |
|
1681 | if lib.ZSTD_isError(zresult): | |
1586 | raise ZstdError( |
|
1682 | raise ZstdError( | |
1587 | "could not determine frame header size: %s" % _zstd_error(zresult) |
|
1683 | "could not determine frame header size: %s" % _zstd_error(zresult) | |
1588 | ) |
|
1684 | ) | |
1589 |
|
1685 | |||
1590 | return zresult |
|
1686 | return zresult | |
1591 |
|
1687 | |||
1592 |
|
1688 | |||
1593 | def get_frame_parameters(data): |
|
1689 | def get_frame_parameters(data): | |
1594 | params = ffi.new("ZSTD_frameHeader *") |
|
1690 | params = ffi.new("ZSTD_frameHeader *") | |
1595 |
|
1691 | |||
1596 | data_buffer = ffi.from_buffer(data) |
|
1692 | data_buffer = ffi.from_buffer(data) | |
1597 | zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer)) |
|
1693 | zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer)) | |
1598 | if lib.ZSTD_isError(zresult): |
|
1694 | if lib.ZSTD_isError(zresult): | |
1599 | raise ZstdError("cannot get frame parameters: %s" % _zstd_error(zresult)) |
|
1695 | raise ZstdError( | |
|
1696 | "cannot get frame parameters: %s" % _zstd_error(zresult) | |||
|
1697 | ) | |||
1600 |
|
1698 | |||
1601 | if zresult: |
|
1699 | if zresult: | |
1602 | raise ZstdError("not enough data for frame parameters; need %d bytes" % zresult) |
|
1700 | raise ZstdError( | |
|
1701 | "not enough data for frame parameters; need %d bytes" % zresult | |||
|
1702 | ) | |||
1603 |
|
1703 | |||
1604 | return FrameParameters(params[0]) |
|
1704 | return FrameParameters(params[0]) | |
1605 |
|
1705 | |||
1606 |
|
1706 | |||
1607 | class ZstdCompressionDict(object): |
|
1707 | class ZstdCompressionDict(object): | |
1608 | def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0): |
|
1708 | def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0): | |
1609 | assert isinstance(data, bytes_type) |
|
1709 | assert isinstance(data, bytes_type) | |
1610 | self._data = data |
|
1710 | self._data = data | |
1611 | self.k = k |
|
1711 | self.k = k | |
1612 | self.d = d |
|
1712 | self.d = d | |
1613 |
|
1713 | |||
1614 | if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT, DICT_TYPE_FULLDICT): |
|
1714 | if dict_type not in ( | |
|
1715 | DICT_TYPE_AUTO, | |||
|
1716 | DICT_TYPE_RAWCONTENT, | |||
|
1717 | DICT_TYPE_FULLDICT, | |||
|
1718 | ): | |||
1615 | raise ValueError( |
|
1719 | raise ValueError( | |
1616 |
"invalid dictionary load mode: %d; must use " |
|
1720 | "invalid dictionary load mode: %d; must use " | |
|
1721 | "DICT_TYPE_* constants" | |||
1617 | ) |
|
1722 | ) | |
1618 |
|
1723 | |||
1619 | self._dict_type = dict_type |
|
1724 | self._dict_type = dict_type | |
1620 | self._cdict = None |
|
1725 | self._cdict = None | |
1621 |
|
1726 | |||
1622 | def __len__(self): |
|
1727 | def __len__(self): | |
1623 | return len(self._data) |
|
1728 | return len(self._data) | |
1624 |
|
1729 | |||
1625 | def dict_id(self): |
|
1730 | def dict_id(self): | |
1626 | return int_type(lib.ZDICT_getDictID(self._data, len(self._data))) |
|
1731 | return int_type(lib.ZDICT_getDictID(self._data, len(self._data))) | |
1627 |
|
1732 | |||
1628 | def as_bytes(self): |
|
1733 | def as_bytes(self): | |
1629 | return self._data |
|
1734 | return self._data | |
1630 |
|
1735 | |||
1631 | def precompute_compress(self, level=0, compression_params=None): |
|
1736 | def precompute_compress(self, level=0, compression_params=None): | |
1632 | if level and compression_params: |
|
1737 | if level and compression_params: | |
1633 | raise ValueError("must only specify one of level or " "compression_params") |
|
1738 | raise ValueError( | |
|
1739 | "must only specify one of level or " "compression_params" | |||
|
1740 | ) | |||
1634 |
|
1741 | |||
1635 | if not level and not compression_params: |
|
1742 | if not level and not compression_params: | |
1636 | raise ValueError("must specify one of level or compression_params") |
|
1743 | raise ValueError("must specify one of level or compression_params") | |
1637 |
|
1744 | |||
1638 | if level: |
|
1745 | if level: | |
1639 | cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) |
|
1746 | cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) | |
1640 | else: |
|
1747 | else: | |
1641 | cparams = ffi.new("ZSTD_compressionParameters") |
|
1748 | cparams = ffi.new("ZSTD_compressionParameters") | |
1642 | cparams.chainLog = compression_params.chain_log |
|
1749 | cparams.chainLog = compression_params.chain_log | |
1643 | cparams.hashLog = compression_params.hash_log |
|
1750 | cparams.hashLog = compression_params.hash_log | |
1644 | cparams.minMatch = compression_params.min_match |
|
1751 | cparams.minMatch = compression_params.min_match | |
1645 | cparams.searchLog = compression_params.search_log |
|
1752 | cparams.searchLog = compression_params.search_log | |
1646 | cparams.strategy = compression_params.compression_strategy |
|
1753 | cparams.strategy = compression_params.compression_strategy | |
1647 | cparams.targetLength = compression_params.target_length |
|
1754 | cparams.targetLength = compression_params.target_length | |
1648 | cparams.windowLog = compression_params.window_log |
|
1755 | cparams.windowLog = compression_params.window_log | |
1649 |
|
1756 | |||
1650 | cdict = lib.ZSTD_createCDict_advanced( |
|
1757 | cdict = lib.ZSTD_createCDict_advanced( | |
1651 | self._data, |
|
1758 | self._data, | |
1652 | len(self._data), |
|
1759 | len(self._data), | |
1653 | lib.ZSTD_dlm_byRef, |
|
1760 | lib.ZSTD_dlm_byRef, | |
1654 | self._dict_type, |
|
1761 | self._dict_type, | |
1655 | cparams, |
|
1762 | cparams, | |
1656 | lib.ZSTD_defaultCMem, |
|
1763 | lib.ZSTD_defaultCMem, | |
1657 | ) |
|
1764 | ) | |
1658 | if cdict == ffi.NULL: |
|
1765 | if cdict == ffi.NULL: | |
1659 | raise ZstdError("unable to precompute dictionary") |
|
1766 | raise ZstdError("unable to precompute dictionary") | |
1660 |
|
1767 | |||
1661 | self._cdict = ffi.gc( |
|
1768 | self._cdict = ffi.gc( | |
1662 | cdict, lib.ZSTD_freeCDict, size=lib.ZSTD_sizeof_CDict(cdict) |
|
1769 | cdict, lib.ZSTD_freeCDict, size=lib.ZSTD_sizeof_CDict(cdict) | |
1663 | ) |
|
1770 | ) | |
1664 |
|
1771 | |||
1665 | @property |
|
1772 | @property | |
1666 | def _ddict(self): |
|
1773 | def _ddict(self): | |
1667 | ddict = lib.ZSTD_createDDict_advanced( |
|
1774 | ddict = lib.ZSTD_createDDict_advanced( | |
1668 | self._data, |
|
1775 | self._data, | |
1669 | len(self._data), |
|
1776 | len(self._data), | |
1670 | lib.ZSTD_dlm_byRef, |
|
1777 | lib.ZSTD_dlm_byRef, | |
1671 | self._dict_type, |
|
1778 | self._dict_type, | |
1672 | lib.ZSTD_defaultCMem, |
|
1779 | lib.ZSTD_defaultCMem, | |
1673 | ) |
|
1780 | ) | |
1674 |
|
1781 | |||
1675 | if ddict == ffi.NULL: |
|
1782 | if ddict == ffi.NULL: | |
1676 | raise ZstdError("could not create decompression dict") |
|
1783 | raise ZstdError("could not create decompression dict") | |
1677 |
|
1784 | |||
1678 | ddict = ffi.gc(ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict)) |
|
1785 | ddict = ffi.gc( | |
|
1786 | ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict) | |||
|
1787 | ) | |||
1679 | self.__dict__["_ddict"] = ddict |
|
1788 | self.__dict__["_ddict"] = ddict | |
1680 |
|
1789 | |||
1681 | return ddict |
|
1790 | return ddict | |
1682 |
|
1791 | |||
1683 |
|
1792 | |||
1684 | def train_dictionary( |
|
1793 | def train_dictionary( | |
1685 | dict_size, |
|
1794 | dict_size, | |
1686 | samples, |
|
1795 | samples, | |
1687 | k=0, |
|
1796 | k=0, | |
1688 | d=0, |
|
1797 | d=0, | |
1689 | notifications=0, |
|
1798 | notifications=0, | |
1690 | dict_id=0, |
|
1799 | dict_id=0, | |
1691 | level=0, |
|
1800 | level=0, | |
1692 | steps=0, |
|
1801 | steps=0, | |
1693 | threads=0, |
|
1802 | threads=0, | |
1694 | ): |
|
1803 | ): | |
1695 | if not isinstance(samples, list): |
|
1804 | if not isinstance(samples, list): | |
1696 | raise TypeError("samples must be a list") |
|
1805 | raise TypeError("samples must be a list") | |
1697 |
|
1806 | |||
1698 | if threads < 0: |
|
1807 | if threads < 0: | |
1699 | threads = _cpu_count() |
|
1808 | threads = _cpu_count() | |
1700 |
|
1809 | |||
1701 | total_size = sum(map(len, samples)) |
|
1810 | total_size = sum(map(len, samples)) | |
1702 |
|
1811 | |||
1703 | samples_buffer = new_nonzero("char[]", total_size) |
|
1812 | samples_buffer = new_nonzero("char[]", total_size) | |
1704 | sample_sizes = new_nonzero("size_t[]", len(samples)) |
|
1813 | sample_sizes = new_nonzero("size_t[]", len(samples)) | |
1705 |
|
1814 | |||
1706 | offset = 0 |
|
1815 | offset = 0 | |
1707 | for i, sample in enumerate(samples): |
|
1816 | for i, sample in enumerate(samples): | |
1708 | if not isinstance(sample, bytes_type): |
|
1817 | if not isinstance(sample, bytes_type): | |
1709 | raise ValueError("samples must be bytes") |
|
1818 | raise ValueError("samples must be bytes") | |
1710 |
|
1819 | |||
1711 | l = len(sample) |
|
1820 | l = len(sample) | |
1712 | ffi.memmove(samples_buffer + offset, sample, l) |
|
1821 | ffi.memmove(samples_buffer + offset, sample, l) | |
1713 | offset += l |
|
1822 | offset += l | |
1714 | sample_sizes[i] = l |
|
1823 | sample_sizes[i] = l | |
1715 |
|
1824 | |||
1716 | dict_data = new_nonzero("char[]", dict_size) |
|
1825 | dict_data = new_nonzero("char[]", dict_size) | |
1717 |
|
1826 | |||
1718 | dparams = ffi.new("ZDICT_cover_params_t *")[0] |
|
1827 | dparams = ffi.new("ZDICT_cover_params_t *")[0] | |
1719 | dparams.k = k |
|
1828 | dparams.k = k | |
1720 | dparams.d = d |
|
1829 | dparams.d = d | |
1721 | dparams.steps = steps |
|
1830 | dparams.steps = steps | |
1722 | dparams.nbThreads = threads |
|
1831 | dparams.nbThreads = threads | |
1723 | dparams.zParams.notificationLevel = notifications |
|
1832 | dparams.zParams.notificationLevel = notifications | |
1724 | dparams.zParams.dictID = dict_id |
|
1833 | dparams.zParams.dictID = dict_id | |
1725 | dparams.zParams.compressionLevel = level |
|
1834 | dparams.zParams.compressionLevel = level | |
1726 |
|
1835 | |||
1727 | if ( |
|
1836 | if ( | |
1728 | not dparams.k |
|
1837 | not dparams.k | |
1729 | and not dparams.d |
|
1838 | and not dparams.d | |
1730 | and not dparams.steps |
|
1839 | and not dparams.steps | |
1731 | and not dparams.nbThreads |
|
1840 | and not dparams.nbThreads | |
1732 | and not dparams.zParams.notificationLevel |
|
1841 | and not dparams.zParams.notificationLevel | |
1733 | and not dparams.zParams.dictID |
|
1842 | and not dparams.zParams.dictID | |
1734 | and not dparams.zParams.compressionLevel |
|
1843 | and not dparams.zParams.compressionLevel | |
1735 | ): |
|
1844 | ): | |
1736 | zresult = lib.ZDICT_trainFromBuffer( |
|
1845 | zresult = lib.ZDICT_trainFromBuffer( | |
1737 | ffi.addressof(dict_data), |
|
1846 | ffi.addressof(dict_data), | |
1738 | dict_size, |
|
1847 | dict_size, | |
1739 | ffi.addressof(samples_buffer), |
|
1848 | ffi.addressof(samples_buffer), | |
1740 | ffi.addressof(sample_sizes, 0), |
|
1849 | ffi.addressof(sample_sizes, 0), | |
1741 | len(samples), |
|
1850 | len(samples), | |
1742 | ) |
|
1851 | ) | |
1743 | elif dparams.steps or dparams.nbThreads: |
|
1852 | elif dparams.steps or dparams.nbThreads: | |
1744 | zresult = lib.ZDICT_optimizeTrainFromBuffer_cover( |
|
1853 | zresult = lib.ZDICT_optimizeTrainFromBuffer_cover( | |
1745 | ffi.addressof(dict_data), |
|
1854 | ffi.addressof(dict_data), | |
1746 | dict_size, |
|
1855 | dict_size, | |
1747 | ffi.addressof(samples_buffer), |
|
1856 | ffi.addressof(samples_buffer), | |
1748 | ffi.addressof(sample_sizes, 0), |
|
1857 | ffi.addressof(sample_sizes, 0), | |
1749 | len(samples), |
|
1858 | len(samples), | |
1750 | ffi.addressof(dparams), |
|
1859 | ffi.addressof(dparams), | |
1751 | ) |
|
1860 | ) | |
1752 | else: |
|
1861 | else: | |
1753 | zresult = lib.ZDICT_trainFromBuffer_cover( |
|
1862 | zresult = lib.ZDICT_trainFromBuffer_cover( | |
1754 | ffi.addressof(dict_data), |
|
1863 | ffi.addressof(dict_data), | |
1755 | dict_size, |
|
1864 | dict_size, | |
1756 | ffi.addressof(samples_buffer), |
|
1865 | ffi.addressof(samples_buffer), | |
1757 | ffi.addressof(sample_sizes, 0), |
|
1866 | ffi.addressof(sample_sizes, 0), | |
1758 | len(samples), |
|
1867 | len(samples), | |
1759 | dparams, |
|
1868 | dparams, | |
1760 | ) |
|
1869 | ) | |
1761 |
|
1870 | |||
1762 | if lib.ZDICT_isError(zresult): |
|
1871 | if lib.ZDICT_isError(zresult): | |
1763 | msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode("utf-8") |
|
1872 | msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode("utf-8") | |
1764 | raise ZstdError("cannot train dict: %s" % msg) |
|
1873 | raise ZstdError("cannot train dict: %s" % msg) | |
1765 |
|
1874 | |||
1766 | return ZstdCompressionDict( |
|
1875 | return ZstdCompressionDict( | |
1767 | ffi.buffer(dict_data, zresult)[:], |
|
1876 | ffi.buffer(dict_data, zresult)[:], | |
1768 | dict_type=DICT_TYPE_FULLDICT, |
|
1877 | dict_type=DICT_TYPE_FULLDICT, | |
1769 | k=dparams.k, |
|
1878 | k=dparams.k, | |
1770 | d=dparams.d, |
|
1879 | d=dparams.d, | |
1771 | ) |
|
1880 | ) | |
1772 |
|
1881 | |||
1773 |
|
1882 | |||
1774 | class ZstdDecompressionObj(object): |
|
1883 | class ZstdDecompressionObj(object): | |
1775 | def __init__(self, decompressor, write_size): |
|
1884 | def __init__(self, decompressor, write_size): | |
1776 | self._decompressor = decompressor |
|
1885 | self._decompressor = decompressor | |
1777 | self._write_size = write_size |
|
1886 | self._write_size = write_size | |
1778 | self._finished = False |
|
1887 | self._finished = False | |
1779 |
|
1888 | |||
1780 | def decompress(self, data): |
|
1889 | def decompress(self, data): | |
1781 | if self._finished: |
|
1890 | if self._finished: | |
1782 | raise ZstdError("cannot use a decompressobj multiple times") |
|
1891 | raise ZstdError("cannot use a decompressobj multiple times") | |
1783 |
|
1892 | |||
1784 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
1893 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
1785 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
1894 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1786 |
|
1895 | |||
1787 | data_buffer = ffi.from_buffer(data) |
|
1896 | data_buffer = ffi.from_buffer(data) | |
1788 |
|
1897 | |||
1789 | if len(data_buffer) == 0: |
|
1898 | if len(data_buffer) == 0: | |
1790 | return b"" |
|
1899 | return b"" | |
1791 |
|
1900 | |||
1792 | in_buffer.src = data_buffer |
|
1901 | in_buffer.src = data_buffer | |
1793 | in_buffer.size = len(data_buffer) |
|
1902 | in_buffer.size = len(data_buffer) | |
1794 | in_buffer.pos = 0 |
|
1903 | in_buffer.pos = 0 | |
1795 |
|
1904 | |||
1796 | dst_buffer = ffi.new("char[]", self._write_size) |
|
1905 | dst_buffer = ffi.new("char[]", self._write_size) | |
1797 | out_buffer.dst = dst_buffer |
|
1906 | out_buffer.dst = dst_buffer | |
1798 | out_buffer.size = len(dst_buffer) |
|
1907 | out_buffer.size = len(dst_buffer) | |
1799 | out_buffer.pos = 0 |
|
1908 | out_buffer.pos = 0 | |
1800 |
|
1909 | |||
1801 | chunks = [] |
|
1910 | chunks = [] | |
1802 |
|
1911 | |||
1803 | while True: |
|
1912 | while True: | |
1804 | zresult = lib.ZSTD_decompressStream( |
|
1913 | zresult = lib.ZSTD_decompressStream( | |
1805 | self._decompressor._dctx, out_buffer, in_buffer |
|
1914 | self._decompressor._dctx, out_buffer, in_buffer | |
1806 | ) |
|
1915 | ) | |
1807 | if lib.ZSTD_isError(zresult): |
|
1916 | if lib.ZSTD_isError(zresult): | |
1808 | raise ZstdError("zstd decompressor error: %s" % _zstd_error(zresult)) |
|
1917 | raise ZstdError( | |
|
1918 | "zstd decompressor error: %s" % _zstd_error(zresult) | |||
|
1919 | ) | |||
1809 |
|
1920 | |||
1810 | if zresult == 0: |
|
1921 | if zresult == 0: | |
1811 | self._finished = True |
|
1922 | self._finished = True | |
1812 | self._decompressor = None |
|
1923 | self._decompressor = None | |
1813 |
|
1924 | |||
1814 | if out_buffer.pos: |
|
1925 | if out_buffer.pos: | |
1815 | chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) |
|
1926 | chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | |
1816 |
|
1927 | |||
1817 | if zresult == 0 or ( |
|
1928 | if zresult == 0 or ( | |
1818 | in_buffer.pos == in_buffer.size and out_buffer.pos == 0 |
|
1929 | in_buffer.pos == in_buffer.size and out_buffer.pos == 0 | |
1819 | ): |
|
1930 | ): | |
1820 | break |
|
1931 | break | |
1821 |
|
1932 | |||
1822 | out_buffer.pos = 0 |
|
1933 | out_buffer.pos = 0 | |
1823 |
|
1934 | |||
1824 | return b"".join(chunks) |
|
1935 | return b"".join(chunks) | |
1825 |
|
1936 | |||
1826 | def flush(self, length=0): |
|
1937 | def flush(self, length=0): | |
1827 | pass |
|
1938 | pass | |
1828 |
|
1939 | |||
1829 |
|
1940 | |||
1830 | class ZstdDecompressionReader(object): |
|
1941 | class ZstdDecompressionReader(object): | |
1831 | def __init__(self, decompressor, source, read_size, read_across_frames): |
|
1942 | def __init__(self, decompressor, source, read_size, read_across_frames): | |
1832 | self._decompressor = decompressor |
|
1943 | self._decompressor = decompressor | |
1833 | self._source = source |
|
1944 | self._source = source | |
1834 | self._read_size = read_size |
|
1945 | self._read_size = read_size | |
1835 | self._read_across_frames = bool(read_across_frames) |
|
1946 | self._read_across_frames = bool(read_across_frames) | |
1836 | self._entered = False |
|
1947 | self._entered = False | |
1837 | self._closed = False |
|
1948 | self._closed = False | |
1838 | self._bytes_decompressed = 0 |
|
1949 | self._bytes_decompressed = 0 | |
1839 | self._finished_input = False |
|
1950 | self._finished_input = False | |
1840 | self._finished_output = False |
|
1951 | self._finished_output = False | |
1841 | self._in_buffer = ffi.new("ZSTD_inBuffer *") |
|
1952 | self._in_buffer = ffi.new("ZSTD_inBuffer *") | |
1842 | # Holds a ref to self._in_buffer.src. |
|
1953 | # Holds a ref to self._in_buffer.src. | |
1843 | self._source_buffer = None |
|
1954 | self._source_buffer = None | |
1844 |
|
1955 | |||
1845 | def __enter__(self): |
|
1956 | def __enter__(self): | |
1846 | if self._entered: |
|
1957 | if self._entered: | |
1847 | raise ValueError("cannot __enter__ multiple times") |
|
1958 | raise ValueError("cannot __enter__ multiple times") | |
1848 |
|
1959 | |||
1849 | self._entered = True |
|
1960 | self._entered = True | |
1850 | return self |
|
1961 | return self | |
1851 |
|
1962 | |||
1852 | def __exit__(self, exc_type, exc_value, exc_tb): |
|
1963 | def __exit__(self, exc_type, exc_value, exc_tb): | |
1853 | self._entered = False |
|
1964 | self._entered = False | |
1854 | self._closed = True |
|
1965 | self._closed = True | |
1855 | self._source = None |
|
1966 | self._source = None | |
1856 | self._decompressor = None |
|
1967 | self._decompressor = None | |
1857 |
|
1968 | |||
1858 | return False |
|
1969 | return False | |
1859 |
|
1970 | |||
1860 | def readable(self): |
|
1971 | def readable(self): | |
1861 | return True |
|
1972 | return True | |
1862 |
|
1973 | |||
1863 | def writable(self): |
|
1974 | def writable(self): | |
1864 | return False |
|
1975 | return False | |
1865 |
|
1976 | |||
1866 | def seekable(self): |
|
1977 | def seekable(self): | |
1867 | return True |
|
1978 | return True | |
1868 |
|
1979 | |||
1869 | def readline(self): |
|
1980 | def readline(self): | |
1870 | raise io.UnsupportedOperation() |
|
1981 | raise io.UnsupportedOperation() | |
1871 |
|
1982 | |||
1872 | def readlines(self): |
|
1983 | def readlines(self): | |
1873 | raise io.UnsupportedOperation() |
|
1984 | raise io.UnsupportedOperation() | |
1874 |
|
1985 | |||
1875 | def write(self, data): |
|
1986 | def write(self, data): | |
1876 | raise io.UnsupportedOperation() |
|
1987 | raise io.UnsupportedOperation() | |
1877 |
|
1988 | |||
1878 | def writelines(self, lines): |
|
1989 | def writelines(self, lines): | |
1879 | raise io.UnsupportedOperation() |
|
1990 | raise io.UnsupportedOperation() | |
1880 |
|
1991 | |||
1881 | def isatty(self): |
|
1992 | def isatty(self): | |
1882 | return False |
|
1993 | return False | |
1883 |
|
1994 | |||
1884 | def flush(self): |
|
1995 | def flush(self): | |
1885 | return None |
|
1996 | return None | |
1886 |
|
1997 | |||
1887 | def close(self): |
|
1998 | def close(self): | |
1888 | self._closed = True |
|
1999 | self._closed = True | |
1889 | return None |
|
2000 | return None | |
1890 |
|
2001 | |||
1891 | @property |
|
2002 | @property | |
1892 | def closed(self): |
|
2003 | def closed(self): | |
1893 | return self._closed |
|
2004 | return self._closed | |
1894 |
|
2005 | |||
1895 | def tell(self): |
|
2006 | def tell(self): | |
1896 | return self._bytes_decompressed |
|
2007 | return self._bytes_decompressed | |
1897 |
|
2008 | |||
1898 | def readall(self): |
|
2009 | def readall(self): | |
1899 | chunks = [] |
|
2010 | chunks = [] | |
1900 |
|
2011 | |||
1901 | while True: |
|
2012 | while True: | |
1902 | chunk = self.read(1048576) |
|
2013 | chunk = self.read(1048576) | |
1903 | if not chunk: |
|
2014 | if not chunk: | |
1904 | break |
|
2015 | break | |
1905 |
|
2016 | |||
1906 | chunks.append(chunk) |
|
2017 | chunks.append(chunk) | |
1907 |
|
2018 | |||
1908 | return b"".join(chunks) |
|
2019 | return b"".join(chunks) | |
1909 |
|
2020 | |||
1910 | def __iter__(self): |
|
2021 | def __iter__(self): | |
1911 | raise io.UnsupportedOperation() |
|
2022 | raise io.UnsupportedOperation() | |
1912 |
|
2023 | |||
1913 | def __next__(self): |
|
2024 | def __next__(self): | |
1914 | raise io.UnsupportedOperation() |
|
2025 | raise io.UnsupportedOperation() | |
1915 |
|
2026 | |||
1916 | next = __next__ |
|
2027 | next = __next__ | |
1917 |
|
2028 | |||
1918 | def _read_input(self): |
|
2029 | def _read_input(self): | |
1919 | # We have data left over in the input buffer. Use it. |
|
2030 | # We have data left over in the input buffer. Use it. | |
1920 | if self._in_buffer.pos < self._in_buffer.size: |
|
2031 | if self._in_buffer.pos < self._in_buffer.size: | |
1921 | return |
|
2032 | return | |
1922 |
|
2033 | |||
1923 | # All input data exhausted. Nothing to do. |
|
2034 | # All input data exhausted. Nothing to do. | |
1924 | if self._finished_input: |
|
2035 | if self._finished_input: | |
1925 | return |
|
2036 | return | |
1926 |
|
2037 | |||
1927 | # Else populate the input buffer from our source. |
|
2038 | # Else populate the input buffer from our source. | |
1928 | if hasattr(self._source, "read"): |
|
2039 | if hasattr(self._source, "read"): | |
1929 | data = self._source.read(self._read_size) |
|
2040 | data = self._source.read(self._read_size) | |
1930 |
|
2041 | |||
1931 | if not data: |
|
2042 | if not data: | |
1932 | self._finished_input = True |
|
2043 | self._finished_input = True | |
1933 | return |
|
2044 | return | |
1934 |
|
2045 | |||
1935 | self._source_buffer = ffi.from_buffer(data) |
|
2046 | self._source_buffer = ffi.from_buffer(data) | |
1936 | self._in_buffer.src = self._source_buffer |
|
2047 | self._in_buffer.src = self._source_buffer | |
1937 | self._in_buffer.size = len(self._source_buffer) |
|
2048 | self._in_buffer.size = len(self._source_buffer) | |
1938 | self._in_buffer.pos = 0 |
|
2049 | self._in_buffer.pos = 0 | |
1939 | else: |
|
2050 | else: | |
1940 | self._source_buffer = ffi.from_buffer(self._source) |
|
2051 | self._source_buffer = ffi.from_buffer(self._source) | |
1941 | self._in_buffer.src = self._source_buffer |
|
2052 | self._in_buffer.src = self._source_buffer | |
1942 | self._in_buffer.size = len(self._source_buffer) |
|
2053 | self._in_buffer.size = len(self._source_buffer) | |
1943 | self._in_buffer.pos = 0 |
|
2054 | self._in_buffer.pos = 0 | |
1944 |
|
2055 | |||
1945 | def _decompress_into_buffer(self, out_buffer): |
|
2056 | def _decompress_into_buffer(self, out_buffer): | |
1946 | """Decompress available input into an output buffer. |
|
2057 | """Decompress available input into an output buffer. | |
1947 |
|
2058 | |||
1948 | Returns True if data in output buffer should be emitted. |
|
2059 | Returns True if data in output buffer should be emitted. | |
1949 | """ |
|
2060 | """ | |
1950 | zresult = lib.ZSTD_decompressStream( |
|
2061 | zresult = lib.ZSTD_decompressStream( | |
1951 | self._decompressor._dctx, out_buffer, self._in_buffer |
|
2062 | self._decompressor._dctx, out_buffer, self._in_buffer | |
1952 | ) |
|
2063 | ) | |
1953 |
|
2064 | |||
1954 | if self._in_buffer.pos == self._in_buffer.size: |
|
2065 | if self._in_buffer.pos == self._in_buffer.size: | |
1955 | self._in_buffer.src = ffi.NULL |
|
2066 | self._in_buffer.src = ffi.NULL | |
1956 | self._in_buffer.pos = 0 |
|
2067 | self._in_buffer.pos = 0 | |
1957 | self._in_buffer.size = 0 |
|
2068 | self._in_buffer.size = 0 | |
1958 | self._source_buffer = None |
|
2069 | self._source_buffer = None | |
1959 |
|
2070 | |||
1960 | if not hasattr(self._source, "read"): |
|
2071 | if not hasattr(self._source, "read"): | |
1961 | self._finished_input = True |
|
2072 | self._finished_input = True | |
1962 |
|
2073 | |||
1963 | if lib.ZSTD_isError(zresult): |
|
2074 | if lib.ZSTD_isError(zresult): | |
1964 | raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) |
|
2075 | raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) | |
1965 |
|
2076 | |||
1966 | # Emit data if there is data AND either: |
|
2077 | # Emit data if there is data AND either: | |
1967 | # a) output buffer is full (read amount is satisfied) |
|
2078 | # a) output buffer is full (read amount is satisfied) | |
1968 | # b) we're at end of a frame and not in frame spanning mode |
|
2079 | # b) we're at end of a frame and not in frame spanning mode | |
1969 | return out_buffer.pos and ( |
|
2080 | return out_buffer.pos and ( | |
1970 | out_buffer.pos == out_buffer.size |
|
2081 | out_buffer.pos == out_buffer.size | |
1971 | or zresult == 0 |
|
2082 | or zresult == 0 | |
1972 | and not self._read_across_frames |
|
2083 | and not self._read_across_frames | |
1973 | ) |
|
2084 | ) | |
1974 |
|
2085 | |||
1975 | def read(self, size=-1): |
|
2086 | def read(self, size=-1): | |
1976 | if self._closed: |
|
2087 | if self._closed: | |
1977 | raise ValueError("stream is closed") |
|
2088 | raise ValueError("stream is closed") | |
1978 |
|
2089 | |||
1979 | if size < -1: |
|
2090 | if size < -1: | |
1980 | raise ValueError("cannot read negative amounts less than -1") |
|
2091 | raise ValueError("cannot read negative amounts less than -1") | |
1981 |
|
2092 | |||
1982 | if size == -1: |
|
2093 | if size == -1: | |
1983 | # This is recursive. But it gets the job done. |
|
2094 | # This is recursive. But it gets the job done. | |
1984 | return self.readall() |
|
2095 | return self.readall() | |
1985 |
|
2096 | |||
1986 | if self._finished_output or size == 0: |
|
2097 | if self._finished_output or size == 0: | |
1987 | return b"" |
|
2098 | return b"" | |
1988 |
|
2099 | |||
1989 | # We /could/ call into readinto() here. But that introduces more |
|
2100 | # We /could/ call into readinto() here. But that introduces more | |
1990 | # overhead. |
|
2101 | # overhead. | |
1991 | dst_buffer = ffi.new("char[]", size) |
|
2102 | dst_buffer = ffi.new("char[]", size) | |
1992 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2103 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
1993 | out_buffer.dst = dst_buffer |
|
2104 | out_buffer.dst = dst_buffer | |
1994 | out_buffer.size = size |
|
2105 | out_buffer.size = size | |
1995 | out_buffer.pos = 0 |
|
2106 | out_buffer.pos = 0 | |
1996 |
|
2107 | |||
1997 | self._read_input() |
|
2108 | self._read_input() | |
1998 | if self._decompress_into_buffer(out_buffer): |
|
2109 | if self._decompress_into_buffer(out_buffer): | |
1999 | self._bytes_decompressed += out_buffer.pos |
|
2110 | self._bytes_decompressed += out_buffer.pos | |
2000 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
2111 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
2001 |
|
2112 | |||
2002 | while not self._finished_input: |
|
2113 | while not self._finished_input: | |
2003 | self._read_input() |
|
2114 | self._read_input() | |
2004 | if self._decompress_into_buffer(out_buffer): |
|
2115 | if self._decompress_into_buffer(out_buffer): | |
2005 | self._bytes_decompressed += out_buffer.pos |
|
2116 | self._bytes_decompressed += out_buffer.pos | |
2006 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
2117 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
2007 |
|
2118 | |||
2008 | self._bytes_decompressed += out_buffer.pos |
|
2119 | self._bytes_decompressed += out_buffer.pos | |
2009 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
2120 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
2010 |
|
2121 | |||
2011 | def readinto(self, b): |
|
2122 | def readinto(self, b): | |
2012 | if self._closed: |
|
2123 | if self._closed: | |
2013 | raise ValueError("stream is closed") |
|
2124 | raise ValueError("stream is closed") | |
2014 |
|
2125 | |||
2015 | if self._finished_output: |
|
2126 | if self._finished_output: | |
2016 | return 0 |
|
2127 | return 0 | |
2017 |
|
2128 | |||
2018 | # TODO use writable=True once we require CFFI >= 1.12. |
|
2129 | # TODO use writable=True once we require CFFI >= 1.12. | |
2019 | dest_buffer = ffi.from_buffer(b) |
|
2130 | dest_buffer = ffi.from_buffer(b) | |
2020 | ffi.memmove(b, b"", 0) |
|
2131 | ffi.memmove(b, b"", 0) | |
2021 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2132 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2022 | out_buffer.dst = dest_buffer |
|
2133 | out_buffer.dst = dest_buffer | |
2023 | out_buffer.size = len(dest_buffer) |
|
2134 | out_buffer.size = len(dest_buffer) | |
2024 | out_buffer.pos = 0 |
|
2135 | out_buffer.pos = 0 | |
2025 |
|
2136 | |||
2026 | self._read_input() |
|
2137 | self._read_input() | |
2027 | if self._decompress_into_buffer(out_buffer): |
|
2138 | if self._decompress_into_buffer(out_buffer): | |
2028 | self._bytes_decompressed += out_buffer.pos |
|
2139 | self._bytes_decompressed += out_buffer.pos | |
2029 | return out_buffer.pos |
|
2140 | return out_buffer.pos | |
2030 |
|
2141 | |||
2031 | while not self._finished_input: |
|
2142 | while not self._finished_input: | |
2032 | self._read_input() |
|
2143 | self._read_input() | |
2033 | if self._decompress_into_buffer(out_buffer): |
|
2144 | if self._decompress_into_buffer(out_buffer): | |
2034 | self._bytes_decompressed += out_buffer.pos |
|
2145 | self._bytes_decompressed += out_buffer.pos | |
2035 | return out_buffer.pos |
|
2146 | return out_buffer.pos | |
2036 |
|
2147 | |||
2037 | self._bytes_decompressed += out_buffer.pos |
|
2148 | self._bytes_decompressed += out_buffer.pos | |
2038 | return out_buffer.pos |
|
2149 | return out_buffer.pos | |
2039 |
|
2150 | |||
2040 | def read1(self, size=-1): |
|
2151 | def read1(self, size=-1): | |
2041 | if self._closed: |
|
2152 | if self._closed: | |
2042 | raise ValueError("stream is closed") |
|
2153 | raise ValueError("stream is closed") | |
2043 |
|
2154 | |||
2044 | if size < -1: |
|
2155 | if size < -1: | |
2045 | raise ValueError("cannot read negative amounts less than -1") |
|
2156 | raise ValueError("cannot read negative amounts less than -1") | |
2046 |
|
2157 | |||
2047 | if self._finished_output or size == 0: |
|
2158 | if self._finished_output or size == 0: | |
2048 | return b"" |
|
2159 | return b"" | |
2049 |
|
2160 | |||
2050 | # -1 returns arbitrary number of bytes. |
|
2161 | # -1 returns arbitrary number of bytes. | |
2051 | if size == -1: |
|
2162 | if size == -1: | |
2052 | size = DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE |
|
2163 | size = DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE | |
2053 |
|
2164 | |||
2054 | dst_buffer = ffi.new("char[]", size) |
|
2165 | dst_buffer = ffi.new("char[]", size) | |
2055 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2166 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2056 | out_buffer.dst = dst_buffer |
|
2167 | out_buffer.dst = dst_buffer | |
2057 | out_buffer.size = size |
|
2168 | out_buffer.size = size | |
2058 | out_buffer.pos = 0 |
|
2169 | out_buffer.pos = 0 | |
2059 |
|
2170 | |||
2060 | # read1() dictates that we can perform at most 1 call to underlying |
|
2171 | # read1() dictates that we can perform at most 1 call to underlying | |
2061 | # stream to get input. However, we can't satisfy this restriction with |
|
2172 | # stream to get input. However, we can't satisfy this restriction with | |
2062 | # decompression because not all input generates output. So we allow |
|
2173 | # decompression because not all input generates output. So we allow | |
2063 | # multiple read(). But unlike read(), we stop once we have any output. |
|
2174 | # multiple read(). But unlike read(), we stop once we have any output. | |
2064 | while not self._finished_input: |
|
2175 | while not self._finished_input: | |
2065 | self._read_input() |
|
2176 | self._read_input() | |
2066 | self._decompress_into_buffer(out_buffer) |
|
2177 | self._decompress_into_buffer(out_buffer) | |
2067 |
|
2178 | |||
2068 | if out_buffer.pos: |
|
2179 | if out_buffer.pos: | |
2069 | break |
|
2180 | break | |
2070 |
|
2181 | |||
2071 | self._bytes_decompressed += out_buffer.pos |
|
2182 | self._bytes_decompressed += out_buffer.pos | |
2072 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
2183 | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
2073 |
|
2184 | |||
2074 | def readinto1(self, b): |
|
2185 | def readinto1(self, b): | |
2075 | if self._closed: |
|
2186 | if self._closed: | |
2076 | raise ValueError("stream is closed") |
|
2187 | raise ValueError("stream is closed") | |
2077 |
|
2188 | |||
2078 | if self._finished_output: |
|
2189 | if self._finished_output: | |
2079 | return 0 |
|
2190 | return 0 | |
2080 |
|
2191 | |||
2081 | # TODO use writable=True once we require CFFI >= 1.12. |
|
2192 | # TODO use writable=True once we require CFFI >= 1.12. | |
2082 | dest_buffer = ffi.from_buffer(b) |
|
2193 | dest_buffer = ffi.from_buffer(b) | |
2083 | ffi.memmove(b, b"", 0) |
|
2194 | ffi.memmove(b, b"", 0) | |
2084 |
|
2195 | |||
2085 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2196 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2086 | out_buffer.dst = dest_buffer |
|
2197 | out_buffer.dst = dest_buffer | |
2087 | out_buffer.size = len(dest_buffer) |
|
2198 | out_buffer.size = len(dest_buffer) | |
2088 | out_buffer.pos = 0 |
|
2199 | out_buffer.pos = 0 | |
2089 |
|
2200 | |||
2090 | while not self._finished_input and not self._finished_output: |
|
2201 | while not self._finished_input and not self._finished_output: | |
2091 | self._read_input() |
|
2202 | self._read_input() | |
2092 | self._decompress_into_buffer(out_buffer) |
|
2203 | self._decompress_into_buffer(out_buffer) | |
2093 |
|
2204 | |||
2094 | if out_buffer.pos: |
|
2205 | if out_buffer.pos: | |
2095 | break |
|
2206 | break | |
2096 |
|
2207 | |||
2097 | self._bytes_decompressed += out_buffer.pos |
|
2208 | self._bytes_decompressed += out_buffer.pos | |
2098 | return out_buffer.pos |
|
2209 | return out_buffer.pos | |
2099 |
|
2210 | |||
2100 | def seek(self, pos, whence=os.SEEK_SET): |
|
2211 | def seek(self, pos, whence=os.SEEK_SET): | |
2101 | if self._closed: |
|
2212 | if self._closed: | |
2102 | raise ValueError("stream is closed") |
|
2213 | raise ValueError("stream is closed") | |
2103 |
|
2214 | |||
2104 | read_amount = 0 |
|
2215 | read_amount = 0 | |
2105 |
|
2216 | |||
2106 | if whence == os.SEEK_SET: |
|
2217 | if whence == os.SEEK_SET: | |
2107 | if pos < 0: |
|
2218 | if pos < 0: | |
2108 | raise ValueError("cannot seek to negative position with SEEK_SET") |
|
2219 | raise ValueError( | |
|
2220 | "cannot seek to negative position with SEEK_SET" | |||
|
2221 | ) | |||
2109 |
|
2222 | |||
2110 | if pos < self._bytes_decompressed: |
|
2223 | if pos < self._bytes_decompressed: | |
2111 | raise ValueError("cannot seek zstd decompression stream " "backwards") |
|
2224 | raise ValueError( | |
|
2225 | "cannot seek zstd decompression stream " "backwards" | |||
|
2226 | ) | |||
2112 |
|
2227 | |||
2113 | read_amount = pos - self._bytes_decompressed |
|
2228 | read_amount = pos - self._bytes_decompressed | |
2114 |
|
2229 | |||
2115 | elif whence == os.SEEK_CUR: |
|
2230 | elif whence == os.SEEK_CUR: | |
2116 | if pos < 0: |
|
2231 | if pos < 0: | |
2117 | raise ValueError("cannot seek zstd decompression stream " "backwards") |
|
2232 | raise ValueError( | |
|
2233 | "cannot seek zstd decompression stream " "backwards" | |||
|
2234 | ) | |||
2118 |
|
2235 | |||
2119 | read_amount = pos |
|
2236 | read_amount = pos | |
2120 | elif whence == os.SEEK_END: |
|
2237 | elif whence == os.SEEK_END: | |
2121 | raise ValueError( |
|
2238 | raise ValueError( | |
2122 | "zstd decompression streams cannot be seeked " "with SEEK_END" |
|
2239 | "zstd decompression streams cannot be seeked " "with SEEK_END" | |
2123 | ) |
|
2240 | ) | |
2124 |
|
2241 | |||
2125 | while read_amount: |
|
2242 | while read_amount: | |
2126 | result = self.read(min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)) |
|
2243 | result = self.read( | |
|
2244 | min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) | |||
|
2245 | ) | |||
2127 |
|
2246 | |||
2128 | if not result: |
|
2247 | if not result: | |
2129 | break |
|
2248 | break | |
2130 |
|
2249 | |||
2131 | read_amount -= len(result) |
|
2250 | read_amount -= len(result) | |
2132 |
|
2251 | |||
2133 | return self._bytes_decompressed |
|
2252 | return self._bytes_decompressed | |
2134 |
|
2253 | |||
2135 |
|
2254 | |||
2136 | class ZstdDecompressionWriter(object): |
|
2255 | class ZstdDecompressionWriter(object): | |
2137 | def __init__(self, decompressor, writer, write_size, write_return_read): |
|
2256 | def __init__(self, decompressor, writer, write_size, write_return_read): | |
2138 | decompressor._ensure_dctx() |
|
2257 | decompressor._ensure_dctx() | |
2139 |
|
2258 | |||
2140 | self._decompressor = decompressor |
|
2259 | self._decompressor = decompressor | |
2141 | self._writer = writer |
|
2260 | self._writer = writer | |
2142 | self._write_size = write_size |
|
2261 | self._write_size = write_size | |
2143 | self._write_return_read = bool(write_return_read) |
|
2262 | self._write_return_read = bool(write_return_read) | |
2144 | self._entered = False |
|
2263 | self._entered = False | |
2145 | self._closed = False |
|
2264 | self._closed = False | |
2146 |
|
2265 | |||
2147 | def __enter__(self): |
|
2266 | def __enter__(self): | |
2148 | if self._closed: |
|
2267 | if self._closed: | |
2149 | raise ValueError("stream is closed") |
|
2268 | raise ValueError("stream is closed") | |
2150 |
|
2269 | |||
2151 | if self._entered: |
|
2270 | if self._entered: | |
2152 | raise ZstdError("cannot __enter__ multiple times") |
|
2271 | raise ZstdError("cannot __enter__ multiple times") | |
2153 |
|
2272 | |||
2154 | self._entered = True |
|
2273 | self._entered = True | |
2155 |
|
2274 | |||
2156 | return self |
|
2275 | return self | |
2157 |
|
2276 | |||
2158 | def __exit__(self, exc_type, exc_value, exc_tb): |
|
2277 | def __exit__(self, exc_type, exc_value, exc_tb): | |
2159 | self._entered = False |
|
2278 | self._entered = False | |
2160 | self.close() |
|
2279 | self.close() | |
2161 |
|
2280 | |||
2162 | def memory_size(self): |
|
2281 | def memory_size(self): | |
2163 | return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx) |
|
2282 | return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx) | |
2164 |
|
2283 | |||
2165 | def close(self): |
|
2284 | def close(self): | |
2166 | if self._closed: |
|
2285 | if self._closed: | |
2167 | return |
|
2286 | return | |
2168 |
|
2287 | |||
2169 | try: |
|
2288 | try: | |
2170 | self.flush() |
|
2289 | self.flush() | |
2171 | finally: |
|
2290 | finally: | |
2172 | self._closed = True |
|
2291 | self._closed = True | |
2173 |
|
2292 | |||
2174 | f = getattr(self._writer, "close", None) |
|
2293 | f = getattr(self._writer, "close", None) | |
2175 | if f: |
|
2294 | if f: | |
2176 | f() |
|
2295 | f() | |
2177 |
|
2296 | |||
2178 | @property |
|
2297 | @property | |
2179 | def closed(self): |
|
2298 | def closed(self): | |
2180 | return self._closed |
|
2299 | return self._closed | |
2181 |
|
2300 | |||
2182 | def fileno(self): |
|
2301 | def fileno(self): | |
2183 | f = getattr(self._writer, "fileno", None) |
|
2302 | f = getattr(self._writer, "fileno", None) | |
2184 | if f: |
|
2303 | if f: | |
2185 | return f() |
|
2304 | return f() | |
2186 | else: |
|
2305 | else: | |
2187 | raise OSError("fileno not available on underlying writer") |
|
2306 | raise OSError("fileno not available on underlying writer") | |
2188 |
|
2307 | |||
2189 | def flush(self): |
|
2308 | def flush(self): | |
2190 | if self._closed: |
|
2309 | if self._closed: | |
2191 | raise ValueError("stream is closed") |
|
2310 | raise ValueError("stream is closed") | |
2192 |
|
2311 | |||
2193 | f = getattr(self._writer, "flush", None) |
|
2312 | f = getattr(self._writer, "flush", None) | |
2194 | if f: |
|
2313 | if f: | |
2195 | return f() |
|
2314 | return f() | |
2196 |
|
2315 | |||
2197 | def isatty(self): |
|
2316 | def isatty(self): | |
2198 | return False |
|
2317 | return False | |
2199 |
|
2318 | |||
2200 | def readable(self): |
|
2319 | def readable(self): | |
2201 | return False |
|
2320 | return False | |
2202 |
|
2321 | |||
2203 | def readline(self, size=-1): |
|
2322 | def readline(self, size=-1): | |
2204 | raise io.UnsupportedOperation() |
|
2323 | raise io.UnsupportedOperation() | |
2205 |
|
2324 | |||
2206 | def readlines(self, hint=-1): |
|
2325 | def readlines(self, hint=-1): | |
2207 | raise io.UnsupportedOperation() |
|
2326 | raise io.UnsupportedOperation() | |
2208 |
|
2327 | |||
2209 | def seek(self, offset, whence=None): |
|
2328 | def seek(self, offset, whence=None): | |
2210 | raise io.UnsupportedOperation() |
|
2329 | raise io.UnsupportedOperation() | |
2211 |
|
2330 | |||
2212 | def seekable(self): |
|
2331 | def seekable(self): | |
2213 | return False |
|
2332 | return False | |
2214 |
|
2333 | |||
2215 | def tell(self): |
|
2334 | def tell(self): | |
2216 | raise io.UnsupportedOperation() |
|
2335 | raise io.UnsupportedOperation() | |
2217 |
|
2336 | |||
2218 | def truncate(self, size=None): |
|
2337 | def truncate(self, size=None): | |
2219 | raise io.UnsupportedOperation() |
|
2338 | raise io.UnsupportedOperation() | |
2220 |
|
2339 | |||
2221 | def writable(self): |
|
2340 | def writable(self): | |
2222 | return True |
|
2341 | return True | |
2223 |
|
2342 | |||
2224 | def writelines(self, lines): |
|
2343 | def writelines(self, lines): | |
2225 | raise io.UnsupportedOperation() |
|
2344 | raise io.UnsupportedOperation() | |
2226 |
|
2345 | |||
2227 | def read(self, size=-1): |
|
2346 | def read(self, size=-1): | |
2228 | raise io.UnsupportedOperation() |
|
2347 | raise io.UnsupportedOperation() | |
2229 |
|
2348 | |||
2230 | def readall(self): |
|
2349 | def readall(self): | |
2231 | raise io.UnsupportedOperation() |
|
2350 | raise io.UnsupportedOperation() | |
2232 |
|
2351 | |||
2233 | def readinto(self, b): |
|
2352 | def readinto(self, b): | |
2234 | raise io.UnsupportedOperation() |
|
2353 | raise io.UnsupportedOperation() | |
2235 |
|
2354 | |||
2236 | def write(self, data): |
|
2355 | def write(self, data): | |
2237 | if self._closed: |
|
2356 | if self._closed: | |
2238 | raise ValueError("stream is closed") |
|
2357 | raise ValueError("stream is closed") | |
2239 |
|
2358 | |||
2240 | total_write = 0 |
|
2359 | total_write = 0 | |
2241 |
|
2360 | |||
2242 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
2361 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
2243 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2362 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2244 |
|
2363 | |||
2245 | data_buffer = ffi.from_buffer(data) |
|
2364 | data_buffer = ffi.from_buffer(data) | |
2246 | in_buffer.src = data_buffer |
|
2365 | in_buffer.src = data_buffer | |
2247 | in_buffer.size = len(data_buffer) |
|
2366 | in_buffer.size = len(data_buffer) | |
2248 | in_buffer.pos = 0 |
|
2367 | in_buffer.pos = 0 | |
2249 |
|
2368 | |||
2250 | dst_buffer = ffi.new("char[]", self._write_size) |
|
2369 | dst_buffer = ffi.new("char[]", self._write_size) | |
2251 | out_buffer.dst = dst_buffer |
|
2370 | out_buffer.dst = dst_buffer | |
2252 | out_buffer.size = len(dst_buffer) |
|
2371 | out_buffer.size = len(dst_buffer) | |
2253 | out_buffer.pos = 0 |
|
2372 | out_buffer.pos = 0 | |
2254 |
|
2373 | |||
2255 | dctx = self._decompressor._dctx |
|
2374 | dctx = self._decompressor._dctx | |
2256 |
|
2375 | |||
2257 | while in_buffer.pos < in_buffer.size: |
|
2376 | while in_buffer.pos < in_buffer.size: | |
2258 | zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer) |
|
2377 | zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer) | |
2259 | if lib.ZSTD_isError(zresult): |
|
2378 | if lib.ZSTD_isError(zresult): | |
2260 | raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) |
|
2379 | raise ZstdError( | |
|
2380 | "zstd decompress error: %s" % _zstd_error(zresult) | |||
|
2381 | ) | |||
2261 |
|
2382 | |||
2262 | if out_buffer.pos: |
|
2383 | if out_buffer.pos: | |
2263 |
self._writer.write( |
|
2384 | self._writer.write( | |
|
2385 | ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||
|
2386 | ) | |||
2264 | total_write += out_buffer.pos |
|
2387 | total_write += out_buffer.pos | |
2265 | out_buffer.pos = 0 |
|
2388 | out_buffer.pos = 0 | |
2266 |
|
2389 | |||
2267 | if self._write_return_read: |
|
2390 | if self._write_return_read: | |
2268 | return in_buffer.pos |
|
2391 | return in_buffer.pos | |
2269 | else: |
|
2392 | else: | |
2270 | return total_write |
|
2393 | return total_write | |
2271 |
|
2394 | |||
2272 |
|
2395 | |||
2273 | class ZstdDecompressor(object): |
|
2396 | class ZstdDecompressor(object): | |
2274 | def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1): |
|
2397 | def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1): | |
2275 | self._dict_data = dict_data |
|
2398 | self._dict_data = dict_data | |
2276 | self._max_window_size = max_window_size |
|
2399 | self._max_window_size = max_window_size | |
2277 | self._format = format |
|
2400 | self._format = format | |
2278 |
|
2401 | |||
2279 | dctx = lib.ZSTD_createDCtx() |
|
2402 | dctx = lib.ZSTD_createDCtx() | |
2280 | if dctx == ffi.NULL: |
|
2403 | if dctx == ffi.NULL: | |
2281 | raise MemoryError() |
|
2404 | raise MemoryError() | |
2282 |
|
2405 | |||
2283 | self._dctx = dctx |
|
2406 | self._dctx = dctx | |
2284 |
|
2407 | |||
2285 | # Defer setting up garbage collection until full state is loaded so |
|
2408 | # Defer setting up garbage collection until full state is loaded so | |
2286 | # the memory size is more accurate. |
|
2409 | # the memory size is more accurate. | |
2287 | try: |
|
2410 | try: | |
2288 | self._ensure_dctx() |
|
2411 | self._ensure_dctx() | |
2289 | finally: |
|
2412 | finally: | |
2290 | self._dctx = ffi.gc( |
|
2413 | self._dctx = ffi.gc( | |
2291 | dctx, lib.ZSTD_freeDCtx, size=lib.ZSTD_sizeof_DCtx(dctx) |
|
2414 | dctx, lib.ZSTD_freeDCtx, size=lib.ZSTD_sizeof_DCtx(dctx) | |
2292 | ) |
|
2415 | ) | |
2293 |
|
2416 | |||
2294 | def memory_size(self): |
|
2417 | def memory_size(self): | |
2295 | return lib.ZSTD_sizeof_DCtx(self._dctx) |
|
2418 | return lib.ZSTD_sizeof_DCtx(self._dctx) | |
2296 |
|
2419 | |||
2297 | def decompress(self, data, max_output_size=0): |
|
2420 | def decompress(self, data, max_output_size=0): | |
2298 | self._ensure_dctx() |
|
2421 | self._ensure_dctx() | |
2299 |
|
2422 | |||
2300 | data_buffer = ffi.from_buffer(data) |
|
2423 | data_buffer = ffi.from_buffer(data) | |
2301 |
|
2424 | |||
2302 |
output_size = lib.ZSTD_getFrameContentSize( |
|
2425 | output_size = lib.ZSTD_getFrameContentSize( | |
|
2426 | data_buffer, len(data_buffer) | |||
|
2427 | ) | |||
2303 |
|
2428 | |||
2304 | if output_size == lib.ZSTD_CONTENTSIZE_ERROR: |
|
2429 | if output_size == lib.ZSTD_CONTENTSIZE_ERROR: | |
2305 | raise ZstdError("error determining content size from frame header") |
|
2430 | raise ZstdError("error determining content size from frame header") | |
2306 | elif output_size == 0: |
|
2431 | elif output_size == 0: | |
2307 | return b"" |
|
2432 | return b"" | |
2308 | elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN: |
|
2433 | elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN: | |
2309 | if not max_output_size: |
|
2434 | if not max_output_size: | |
2310 | raise ZstdError("could not determine content size in frame header") |
|
2435 | raise ZstdError( | |
|
2436 | "could not determine content size in frame header" | |||
|
2437 | ) | |||
2311 |
|
2438 | |||
2312 | result_buffer = ffi.new("char[]", max_output_size) |
|
2439 | result_buffer = ffi.new("char[]", max_output_size) | |
2313 | result_size = max_output_size |
|
2440 | result_size = max_output_size | |
2314 | output_size = 0 |
|
2441 | output_size = 0 | |
2315 | else: |
|
2442 | else: | |
2316 | result_buffer = ffi.new("char[]", output_size) |
|
2443 | result_buffer = ffi.new("char[]", output_size) | |
2317 | result_size = output_size |
|
2444 | result_size = output_size | |
2318 |
|
2445 | |||
2319 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2446 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2320 | out_buffer.dst = result_buffer |
|
2447 | out_buffer.dst = result_buffer | |
2321 | out_buffer.size = result_size |
|
2448 | out_buffer.size = result_size | |
2322 | out_buffer.pos = 0 |
|
2449 | out_buffer.pos = 0 | |
2323 |
|
2450 | |||
2324 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
2451 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
2325 | in_buffer.src = data_buffer |
|
2452 | in_buffer.src = data_buffer | |
2326 | in_buffer.size = len(data_buffer) |
|
2453 | in_buffer.size = len(data_buffer) | |
2327 | in_buffer.pos = 0 |
|
2454 | in_buffer.pos = 0 | |
2328 |
|
2455 | |||
2329 | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) |
|
2456 | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | |
2330 | if lib.ZSTD_isError(zresult): |
|
2457 | if lib.ZSTD_isError(zresult): | |
2331 | raise ZstdError("decompression error: %s" % _zstd_error(zresult)) |
|
2458 | raise ZstdError("decompression error: %s" % _zstd_error(zresult)) | |
2332 | elif zresult: |
|
2459 | elif zresult: | |
2333 | raise ZstdError("decompression error: did not decompress full frame") |
|
2460 | raise ZstdError( | |
|
2461 | "decompression error: did not decompress full frame" | |||
|
2462 | ) | |||
2334 | elif output_size and out_buffer.pos != output_size: |
|
2463 | elif output_size and out_buffer.pos != output_size: | |
2335 | raise ZstdError( |
|
2464 | raise ZstdError( | |
2336 | "decompression error: decompressed %d bytes; expected %d" |
|
2465 | "decompression error: decompressed %d bytes; expected %d" | |
2337 | % (zresult, output_size) |
|
2466 | % (zresult, output_size) | |
2338 | ) |
|
2467 | ) | |
2339 |
|
2468 | |||
2340 | return ffi.buffer(result_buffer, out_buffer.pos)[:] |
|
2469 | return ffi.buffer(result_buffer, out_buffer.pos)[:] | |
2341 |
|
2470 | |||
2342 | def stream_reader( |
|
2471 | def stream_reader( | |
2343 | self, |
|
2472 | self, | |
2344 | source, |
|
2473 | source, | |
2345 | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, |
|
2474 | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | |
2346 | read_across_frames=False, |
|
2475 | read_across_frames=False, | |
2347 | ): |
|
2476 | ): | |
2348 | self._ensure_dctx() |
|
2477 | self._ensure_dctx() | |
2349 |
return ZstdDecompressionReader( |
|
2478 | return ZstdDecompressionReader( | |
|
2479 | self, source, read_size, read_across_frames | |||
|
2480 | ) | |||
2350 |
|
2481 | |||
2351 | def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): |
|
2482 | def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): | |
2352 | if write_size < 1: |
|
2483 | if write_size < 1: | |
2353 | raise ValueError("write_size must be positive") |
|
2484 | raise ValueError("write_size must be positive") | |
2354 |
|
2485 | |||
2355 | self._ensure_dctx() |
|
2486 | self._ensure_dctx() | |
2356 | return ZstdDecompressionObj(self, write_size=write_size) |
|
2487 | return ZstdDecompressionObj(self, write_size=write_size) | |
2357 |
|
2488 | |||
2358 | def read_to_iter( |
|
2489 | def read_to_iter( | |
2359 | self, |
|
2490 | self, | |
2360 | reader, |
|
2491 | reader, | |
2361 | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, |
|
2492 | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | |
2362 | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, |
|
2493 | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |
2363 | skip_bytes=0, |
|
2494 | skip_bytes=0, | |
2364 | ): |
|
2495 | ): | |
2365 | if skip_bytes >= read_size: |
|
2496 | if skip_bytes >= read_size: | |
2366 | raise ValueError("skip_bytes must be smaller than read_size") |
|
2497 | raise ValueError("skip_bytes must be smaller than read_size") | |
2367 |
|
2498 | |||
2368 | if hasattr(reader, "read"): |
|
2499 | if hasattr(reader, "read"): | |
2369 | have_read = True |
|
2500 | have_read = True | |
2370 | elif hasattr(reader, "__getitem__"): |
|
2501 | elif hasattr(reader, "__getitem__"): | |
2371 | have_read = False |
|
2502 | have_read = False | |
2372 | buffer_offset = 0 |
|
2503 | buffer_offset = 0 | |
2373 | size = len(reader) |
|
2504 | size = len(reader) | |
2374 | else: |
|
2505 | else: | |
2375 | raise ValueError( |
|
2506 | raise ValueError( | |
2376 | "must pass an object with a read() method or " |
|
2507 | "must pass an object with a read() method or " | |
2377 | "conforms to buffer protocol" |
|
2508 | "conforms to buffer protocol" | |
2378 | ) |
|
2509 | ) | |
2379 |
|
2510 | |||
2380 | if skip_bytes: |
|
2511 | if skip_bytes: | |
2381 | if have_read: |
|
2512 | if have_read: | |
2382 | reader.read(skip_bytes) |
|
2513 | reader.read(skip_bytes) | |
2383 | else: |
|
2514 | else: | |
2384 | if skip_bytes > size: |
|
2515 | if skip_bytes > size: | |
2385 | raise ValueError("skip_bytes larger than first input chunk") |
|
2516 | raise ValueError("skip_bytes larger than first input chunk") | |
2386 |
|
2517 | |||
2387 | buffer_offset = skip_bytes |
|
2518 | buffer_offset = skip_bytes | |
2388 |
|
2519 | |||
2389 | self._ensure_dctx() |
|
2520 | self._ensure_dctx() | |
2390 |
|
2521 | |||
2391 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
2522 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
2392 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2523 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2393 |
|
2524 | |||
2394 | dst_buffer = ffi.new("char[]", write_size) |
|
2525 | dst_buffer = ffi.new("char[]", write_size) | |
2395 | out_buffer.dst = dst_buffer |
|
2526 | out_buffer.dst = dst_buffer | |
2396 | out_buffer.size = len(dst_buffer) |
|
2527 | out_buffer.size = len(dst_buffer) | |
2397 | out_buffer.pos = 0 |
|
2528 | out_buffer.pos = 0 | |
2398 |
|
2529 | |||
2399 | while True: |
|
2530 | while True: | |
2400 | assert out_buffer.pos == 0 |
|
2531 | assert out_buffer.pos == 0 | |
2401 |
|
2532 | |||
2402 | if have_read: |
|
2533 | if have_read: | |
2403 | read_result = reader.read(read_size) |
|
2534 | read_result = reader.read(read_size) | |
2404 | else: |
|
2535 | else: | |
2405 | remaining = size - buffer_offset |
|
2536 | remaining = size - buffer_offset | |
2406 | slice_size = min(remaining, read_size) |
|
2537 | slice_size = min(remaining, read_size) | |
2407 | read_result = reader[buffer_offset : buffer_offset + slice_size] |
|
2538 | read_result = reader[buffer_offset : buffer_offset + slice_size] | |
2408 | buffer_offset += slice_size |
|
2539 | buffer_offset += slice_size | |
2409 |
|
2540 | |||
2410 | # No new input. Break out of read loop. |
|
2541 | # No new input. Break out of read loop. | |
2411 | if not read_result: |
|
2542 | if not read_result: | |
2412 | break |
|
2543 | break | |
2413 |
|
2544 | |||
2414 | # Feed all read data into decompressor and emit output until |
|
2545 | # Feed all read data into decompressor and emit output until | |
2415 | # exhausted. |
|
2546 | # exhausted. | |
2416 | read_buffer = ffi.from_buffer(read_result) |
|
2547 | read_buffer = ffi.from_buffer(read_result) | |
2417 | in_buffer.src = read_buffer |
|
2548 | in_buffer.src = read_buffer | |
2418 | in_buffer.size = len(read_buffer) |
|
2549 | in_buffer.size = len(read_buffer) | |
2419 | in_buffer.pos = 0 |
|
2550 | in_buffer.pos = 0 | |
2420 |
|
2551 | |||
2421 | while in_buffer.pos < in_buffer.size: |
|
2552 | while in_buffer.pos < in_buffer.size: | |
2422 | assert out_buffer.pos == 0 |
|
2553 | assert out_buffer.pos == 0 | |
2423 |
|
2554 | |||
2424 |
zresult = lib.ZSTD_decompressStream( |
|
2555 | zresult = lib.ZSTD_decompressStream( | |
|
2556 | self._dctx, out_buffer, in_buffer | |||
|
2557 | ) | |||
2425 | if lib.ZSTD_isError(zresult): |
|
2558 | if lib.ZSTD_isError(zresult): | |
2426 | raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) |
|
2559 | raise ZstdError( | |
|
2560 | "zstd decompress error: %s" % _zstd_error(zresult) | |||
|
2561 | ) | |||
2427 |
|
2562 | |||
2428 | if out_buffer.pos: |
|
2563 | if out_buffer.pos: | |
2429 | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] |
|
2564 | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |
2430 | out_buffer.pos = 0 |
|
2565 | out_buffer.pos = 0 | |
2431 | yield data |
|
2566 | yield data | |
2432 |
|
2567 | |||
2433 | if zresult == 0: |
|
2568 | if zresult == 0: | |
2434 | return |
|
2569 | return | |
2435 |
|
2570 | |||
2436 | # Repeat loop to collect more input data. |
|
2571 | # Repeat loop to collect more input data. | |
2437 | continue |
|
2572 | continue | |
2438 |
|
2573 | |||
2439 | # If we get here, input is exhausted. |
|
2574 | # If we get here, input is exhausted. | |
2440 |
|
2575 | |||
2441 | read_from = read_to_iter |
|
2576 | read_from = read_to_iter | |
2442 |
|
2577 | |||
2443 | def stream_writer( |
|
2578 | def stream_writer( | |
2444 | self, |
|
2579 | self, | |
2445 | writer, |
|
2580 | writer, | |
2446 | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, |
|
2581 | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |
2447 | write_return_read=False, |
|
2582 | write_return_read=False, | |
2448 | ): |
|
2583 | ): | |
2449 | if not hasattr(writer, "write"): |
|
2584 | if not hasattr(writer, "write"): | |
2450 | raise ValueError("must pass an object with a write() method") |
|
2585 | raise ValueError("must pass an object with a write() method") | |
2451 |
|
2586 | |||
2452 |
return ZstdDecompressionWriter( |
|
2587 | return ZstdDecompressionWriter( | |
|
2588 | self, writer, write_size, write_return_read | |||
|
2589 | ) | |||
2453 |
|
2590 | |||
2454 | write_to = stream_writer |
|
2591 | write_to = stream_writer | |
2455 |
|
2592 | |||
2456 | def copy_stream( |
|
2593 | def copy_stream( | |
2457 | self, |
|
2594 | self, | |
2458 | ifh, |
|
2595 | ifh, | |
2459 | ofh, |
|
2596 | ofh, | |
2460 | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, |
|
2597 | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | |
2461 | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, |
|
2598 | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |
2462 | ): |
|
2599 | ): | |
2463 | if not hasattr(ifh, "read"): |
|
2600 | if not hasattr(ifh, "read"): | |
2464 | raise ValueError("first argument must have a read() method") |
|
2601 | raise ValueError("first argument must have a read() method") | |
2465 | if not hasattr(ofh, "write"): |
|
2602 | if not hasattr(ofh, "write"): | |
2466 | raise ValueError("second argument must have a write() method") |
|
2603 | raise ValueError("second argument must have a write() method") | |
2467 |
|
2604 | |||
2468 | self._ensure_dctx() |
|
2605 | self._ensure_dctx() | |
2469 |
|
2606 | |||
2470 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
2607 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
2471 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2608 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2472 |
|
2609 | |||
2473 | dst_buffer = ffi.new("char[]", write_size) |
|
2610 | dst_buffer = ffi.new("char[]", write_size) | |
2474 | out_buffer.dst = dst_buffer |
|
2611 | out_buffer.dst = dst_buffer | |
2475 | out_buffer.size = write_size |
|
2612 | out_buffer.size = write_size | |
2476 | out_buffer.pos = 0 |
|
2613 | out_buffer.pos = 0 | |
2477 |
|
2614 | |||
2478 | total_read, total_write = 0, 0 |
|
2615 | total_read, total_write = 0, 0 | |
2479 |
|
2616 | |||
2480 | # Read all available input. |
|
2617 | # Read all available input. | |
2481 | while True: |
|
2618 | while True: | |
2482 | data = ifh.read(read_size) |
|
2619 | data = ifh.read(read_size) | |
2483 | if not data: |
|
2620 | if not data: | |
2484 | break |
|
2621 | break | |
2485 |
|
2622 | |||
2486 | data_buffer = ffi.from_buffer(data) |
|
2623 | data_buffer = ffi.from_buffer(data) | |
2487 | total_read += len(data_buffer) |
|
2624 | total_read += len(data_buffer) | |
2488 | in_buffer.src = data_buffer |
|
2625 | in_buffer.src = data_buffer | |
2489 | in_buffer.size = len(data_buffer) |
|
2626 | in_buffer.size = len(data_buffer) | |
2490 | in_buffer.pos = 0 |
|
2627 | in_buffer.pos = 0 | |
2491 |
|
2628 | |||
2492 | # Flush all read data to output. |
|
2629 | # Flush all read data to output. | |
2493 | while in_buffer.pos < in_buffer.size: |
|
2630 | while in_buffer.pos < in_buffer.size: | |
2494 |
zresult = lib.ZSTD_decompressStream( |
|
2631 | zresult = lib.ZSTD_decompressStream( | |
|
2632 | self._dctx, out_buffer, in_buffer | |||
|
2633 | ) | |||
2495 | if lib.ZSTD_isError(zresult): |
|
2634 | if lib.ZSTD_isError(zresult): | |
2496 | raise ZstdError( |
|
2635 | raise ZstdError( | |
2497 | "zstd decompressor error: %s" % _zstd_error(zresult) |
|
2636 | "zstd decompressor error: %s" % _zstd_error(zresult) | |
2498 | ) |
|
2637 | ) | |
2499 |
|
2638 | |||
2500 | if out_buffer.pos: |
|
2639 | if out_buffer.pos: | |
2501 | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) |
|
2640 | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | |
2502 | total_write += out_buffer.pos |
|
2641 | total_write += out_buffer.pos | |
2503 | out_buffer.pos = 0 |
|
2642 | out_buffer.pos = 0 | |
2504 |
|
2643 | |||
2505 | # Continue loop to keep reading. |
|
2644 | # Continue loop to keep reading. | |
2506 |
|
2645 | |||
2507 | return total_read, total_write |
|
2646 | return total_read, total_write | |
2508 |
|
2647 | |||
2509 | def decompress_content_dict_chain(self, frames): |
|
2648 | def decompress_content_dict_chain(self, frames): | |
2510 | if not isinstance(frames, list): |
|
2649 | if not isinstance(frames, list): | |
2511 | raise TypeError("argument must be a list") |
|
2650 | raise TypeError("argument must be a list") | |
2512 |
|
2651 | |||
2513 | if not frames: |
|
2652 | if not frames: | |
2514 | raise ValueError("empty input chain") |
|
2653 | raise ValueError("empty input chain") | |
2515 |
|
2654 | |||
2516 | # First chunk should not be using a dictionary. We handle it specially. |
|
2655 | # First chunk should not be using a dictionary. We handle it specially. | |
2517 | chunk = frames[0] |
|
2656 | chunk = frames[0] | |
2518 | if not isinstance(chunk, bytes_type): |
|
2657 | if not isinstance(chunk, bytes_type): | |
2519 | raise ValueError("chunk 0 must be bytes") |
|
2658 | raise ValueError("chunk 0 must be bytes") | |
2520 |
|
2659 | |||
2521 | # All chunks should be zstd frames and should have content size set. |
|
2660 | # All chunks should be zstd frames and should have content size set. | |
2522 | chunk_buffer = ffi.from_buffer(chunk) |
|
2661 | chunk_buffer = ffi.from_buffer(chunk) | |
2523 | params = ffi.new("ZSTD_frameHeader *") |
|
2662 | params = ffi.new("ZSTD_frameHeader *") | |
2524 |
zresult = lib.ZSTD_getFrameHeader( |
|
2663 | zresult = lib.ZSTD_getFrameHeader( | |
|
2664 | params, chunk_buffer, len(chunk_buffer) | |||
|
2665 | ) | |||
2525 | if lib.ZSTD_isError(zresult): |
|
2666 | if lib.ZSTD_isError(zresult): | |
2526 | raise ValueError("chunk 0 is not a valid zstd frame") |
|
2667 | raise ValueError("chunk 0 is not a valid zstd frame") | |
2527 | elif zresult: |
|
2668 | elif zresult: | |
2528 | raise ValueError("chunk 0 is too small to contain a zstd frame") |
|
2669 | raise ValueError("chunk 0 is too small to contain a zstd frame") | |
2529 |
|
2670 | |||
2530 | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: |
|
2671 | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | |
2531 | raise ValueError("chunk 0 missing content size in frame") |
|
2672 | raise ValueError("chunk 0 missing content size in frame") | |
2532 |
|
2673 | |||
2533 | self._ensure_dctx(load_dict=False) |
|
2674 | self._ensure_dctx(load_dict=False) | |
2534 |
|
2675 | |||
2535 | last_buffer = ffi.new("char[]", params.frameContentSize) |
|
2676 | last_buffer = ffi.new("char[]", params.frameContentSize) | |
2536 |
|
2677 | |||
2537 | out_buffer = ffi.new("ZSTD_outBuffer *") |
|
2678 | out_buffer = ffi.new("ZSTD_outBuffer *") | |
2538 | out_buffer.dst = last_buffer |
|
2679 | out_buffer.dst = last_buffer | |
2539 | out_buffer.size = len(last_buffer) |
|
2680 | out_buffer.size = len(last_buffer) | |
2540 | out_buffer.pos = 0 |
|
2681 | out_buffer.pos = 0 | |
2541 |
|
2682 | |||
2542 | in_buffer = ffi.new("ZSTD_inBuffer *") |
|
2683 | in_buffer = ffi.new("ZSTD_inBuffer *") | |
2543 | in_buffer.src = chunk_buffer |
|
2684 | in_buffer.src = chunk_buffer | |
2544 | in_buffer.size = len(chunk_buffer) |
|
2685 | in_buffer.size = len(chunk_buffer) | |
2545 | in_buffer.pos = 0 |
|
2686 | in_buffer.pos = 0 | |
2546 |
|
2687 | |||
2547 | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) |
|
2688 | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | |
2548 | if lib.ZSTD_isError(zresult): |
|
2689 | if lib.ZSTD_isError(zresult): | |
2549 | raise ZstdError("could not decompress chunk 0: %s" % _zstd_error(zresult)) |
|
2690 | raise ZstdError( | |
|
2691 | "could not decompress chunk 0: %s" % _zstd_error(zresult) | |||
|
2692 | ) | |||
2550 | elif zresult: |
|
2693 | elif zresult: | |
2551 | raise ZstdError("chunk 0 did not decompress full frame") |
|
2694 | raise ZstdError("chunk 0 did not decompress full frame") | |
2552 |
|
2695 | |||
2553 | # Special case of chain length of 1 |
|
2696 | # Special case of chain length of 1 | |
2554 | if len(frames) == 1: |
|
2697 | if len(frames) == 1: | |
2555 | return ffi.buffer(last_buffer, len(last_buffer))[:] |
|
2698 | return ffi.buffer(last_buffer, len(last_buffer))[:] | |
2556 |
|
2699 | |||
2557 | i = 1 |
|
2700 | i = 1 | |
2558 | while i < len(frames): |
|
2701 | while i < len(frames): | |
2559 | chunk = frames[i] |
|
2702 | chunk = frames[i] | |
2560 | if not isinstance(chunk, bytes_type): |
|
2703 | if not isinstance(chunk, bytes_type): | |
2561 | raise ValueError("chunk %d must be bytes" % i) |
|
2704 | raise ValueError("chunk %d must be bytes" % i) | |
2562 |
|
2705 | |||
2563 | chunk_buffer = ffi.from_buffer(chunk) |
|
2706 | chunk_buffer = ffi.from_buffer(chunk) | |
2564 |
zresult = lib.ZSTD_getFrameHeader( |
|
2707 | zresult = lib.ZSTD_getFrameHeader( | |
|
2708 | params, chunk_buffer, len(chunk_buffer) | |||
|
2709 | ) | |||
2565 | if lib.ZSTD_isError(zresult): |
|
2710 | if lib.ZSTD_isError(zresult): | |
2566 | raise ValueError("chunk %d is not a valid zstd frame" % i) |
|
2711 | raise ValueError("chunk %d is not a valid zstd frame" % i) | |
2567 | elif zresult: |
|
2712 | elif zresult: | |
2568 | raise ValueError("chunk %d is too small to contain a zstd frame" % i) |
|
2713 | raise ValueError( | |
|
2714 | "chunk %d is too small to contain a zstd frame" % i | |||
|
2715 | ) | |||
2569 |
|
2716 | |||
2570 | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: |
|
2717 | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | |
2571 | raise ValueError("chunk %d missing content size in frame" % i) |
|
2718 | raise ValueError("chunk %d missing content size in frame" % i) | |
2572 |
|
2719 | |||
2573 | dest_buffer = ffi.new("char[]", params.frameContentSize) |
|
2720 | dest_buffer = ffi.new("char[]", params.frameContentSize) | |
2574 |
|
2721 | |||
2575 | out_buffer.dst = dest_buffer |
|
2722 | out_buffer.dst = dest_buffer | |
2576 | out_buffer.size = len(dest_buffer) |
|
2723 | out_buffer.size = len(dest_buffer) | |
2577 | out_buffer.pos = 0 |
|
2724 | out_buffer.pos = 0 | |
2578 |
|
2725 | |||
2579 | in_buffer.src = chunk_buffer |
|
2726 | in_buffer.src = chunk_buffer | |
2580 | in_buffer.size = len(chunk_buffer) |
|
2727 | in_buffer.size = len(chunk_buffer) | |
2581 | in_buffer.pos = 0 |
|
2728 | in_buffer.pos = 0 | |
2582 |
|
2729 | |||
2583 |
zresult = lib.ZSTD_decompressStream( |
|
2730 | zresult = lib.ZSTD_decompressStream( | |
|
2731 | self._dctx, out_buffer, in_buffer | |||
|
2732 | ) | |||
2584 | if lib.ZSTD_isError(zresult): |
|
2733 | if lib.ZSTD_isError(zresult): | |
2585 | raise ZstdError( |
|
2734 | raise ZstdError( | |
2586 | "could not decompress chunk %d: %s" % _zstd_error(zresult) |
|
2735 | "could not decompress chunk %d: %s" % _zstd_error(zresult) | |
2587 | ) |
|
2736 | ) | |
2588 | elif zresult: |
|
2737 | elif zresult: | |
2589 | raise ZstdError("chunk %d did not decompress full frame" % i) |
|
2738 | raise ZstdError("chunk %d did not decompress full frame" % i) | |
2590 |
|
2739 | |||
2591 | last_buffer = dest_buffer |
|
2740 | last_buffer = dest_buffer | |
2592 | i += 1 |
|
2741 | i += 1 | |
2593 |
|
2742 | |||
2594 | return ffi.buffer(last_buffer, len(last_buffer))[:] |
|
2743 | return ffi.buffer(last_buffer, len(last_buffer))[:] | |
2595 |
|
2744 | |||
2596 | def _ensure_dctx(self, load_dict=True): |
|
2745 | def _ensure_dctx(self, load_dict=True): | |
2597 | lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only) |
|
2746 | lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only) | |
2598 |
|
2747 | |||
2599 | if self._max_window_size: |
|
2748 | if self._max_window_size: | |
2600 |
zresult = lib.ZSTD_DCtx_setMaxWindowSize( |
|
2749 | zresult = lib.ZSTD_DCtx_setMaxWindowSize( | |
|
2750 | self._dctx, self._max_window_size | |||
|
2751 | ) | |||
2601 | if lib.ZSTD_isError(zresult): |
|
2752 | if lib.ZSTD_isError(zresult): | |
2602 | raise ZstdError( |
|
2753 | raise ZstdError( | |
2603 | "unable to set max window size: %s" % _zstd_error(zresult) |
|
2754 | "unable to set max window size: %s" % _zstd_error(zresult) | |
2604 | ) |
|
2755 | ) | |
2605 |
|
2756 | |||
2606 | zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) |
|
2757 | zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) | |
2607 | if lib.ZSTD_isError(zresult): |
|
2758 | if lib.ZSTD_isError(zresult): | |
2608 | raise ZstdError("unable to set decoding format: %s" % _zstd_error(zresult)) |
|
2759 | raise ZstdError( | |
|
2760 | "unable to set decoding format: %s" % _zstd_error(zresult) | |||
|
2761 | ) | |||
2609 |
|
2762 | |||
2610 | if self._dict_data and load_dict: |
|
2763 | if self._dict_data and load_dict: | |
2611 | zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) |
|
2764 | zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) | |
2612 | if lib.ZSTD_isError(zresult): |
|
2765 | if lib.ZSTD_isError(zresult): | |
2613 | raise ZstdError( |
|
2766 | raise ZstdError( | |
2614 |
"unable to reference prepared dictionary: %s" |
|
2767 | "unable to reference prepared dictionary: %s" | |
|
2768 | % _zstd_error(zresult) | |||
2615 | ) |
|
2769 | ) |
@@ -1,5 +1,5 b'' | |||||
1 | #require black |
|
1 | #require black | |
2 |
|
2 | |||
3 | $ cd $RUNTESTDIR/.. |
|
3 | $ cd $RUNTESTDIR/.. | |
4 |
$ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/** |
|
4 | $ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/**'` | |
5 |
|
5 |
General Comments 0
You need to be logged in to leave comments.
Login now