# test_decompressor.py | 478 lines | 15.0 KiB | text/x-python | PythonLexer | Gregory Szorc
r30435 | import io | ||
import random | ||||
import struct | ||||
import sys | ||||
try: | ||||
import unittest2 as unittest | ||||
except ImportError: | ||||
import unittest | ||||
import zstd | ||||
from .common import OpCountingBytesIO | ||||
# Iterator-advance shim used by the tests below. Python 3 iterators expose
# ``__next__`` while Python 2 iterators expose ``next``; pick the matching
# spelling once at import time. Real ``def`` functions are used instead of
# lambdas bound to a name so tracebacks show a meaningful function name.
if sys.version_info[0] >= 3:
    def next(it):
        """Return the next item of iterator ``it`` via ``it.__next__()``.

        Propagates whatever ``it.__next__()`` raises (``StopIteration``
        when the iterator is exhausted).
        """
        return it.__next__()
else:
    def next(it):
        """Return the next item of iterator ``it`` via ``it.next()``.

        Propagates whatever ``it.next()`` raises (``StopIteration`` when
        the iterator is exhausted).
        """
        return it.next()
class TestDecompressor_decompress(unittest.TestCase):
    """Tests for the one-shot ``ZstdDecompressor.decompress()`` call."""

    @staticmethod
    def _train_sample_dictionary():
        """Return ``zstd.train_dictionary(8192, samples)`` for the shared samples.

        The sample list is 128 repetitions of the triple
        ``b'foo' * 64``, ``b'bar' * 64``, ``b'foobar' * 64`` in that order.
        """
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
        return zstd.train_dictionary(8192, samples)

    def test_empty_input(self):
        """Decompressing an empty byte string is reported as invalid input."""
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
            dctx.decompress(b'')

    def test_invalid_input(self):
        """Decompressing bytes that are not a zstd frame is an error."""
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
            dctx.decompress(b'foobar')

    def test_no_content_size_in_frame(self):
        """A frame written with ``write_content_size=False`` is rejected
        when ``decompress()`` is called without ``max_output_size``."""
        cctx = zstd.ZstdCompressor(write_content_size=False)
        compressed = cctx.compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
            dctx.decompress(compressed)

    def test_content_size_present(self):
        """A frame written with ``write_content_size=True`` round-trips."""
        cctx = zstd.ZstdCompressor(write_content_size=True)
        compressed = cctx.compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        decompressed = dctx.decompress(compressed)
        self.assertEqual(decompressed, b'foobar')

    def test_max_output_size(self):
        """``max_output_size`` must be at least the decompressed length."""
        cctx = zstd.ZstdCompressor(write_content_size=False)
        source = b'foobar' * 256
        compressed = cctx.compress(source)

        dctx = zstd.ZstdDecompressor()

        # Will fit into buffer exactly the size of input.
        decompressed = dctx.decompress(compressed, max_output_size=len(source))
        self.assertEqual(decompressed, source)

        # Input size - 1 fails
        with self.assertRaisesRegexp(zstd.ZstdError, 'Destination buffer is too small'):
            dctx.decompress(compressed, max_output_size=len(source) - 1)

        # Input size + 1 works
        decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
        self.assertEqual(decompressed, source)

        # A much larger buffer works.
        decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
        self.assertEqual(decompressed, source)

    def test_stupidly_large_output_buffer(self):
        """An absurd ``max_output_size`` fails instead of succeeding."""
        cctx = zstd.ZstdCompressor(write_content_size=False)
        compressed = cctx.compress(b'foobar' * 256)
        dctx = zstd.ZstdDecompressor()

        # Will get OverflowError on some Python distributions that can't
        # handle really large integers.
        with self.assertRaises((MemoryError, OverflowError)):
            dctx.decompress(compressed, max_output_size=2**62)

    def test_dictionary(self):
        """Data compressed with a trained dictionary round-trips with it."""
        d = self._train_sample_dictionary()

        orig = b'foobar' * 16384
        cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True)
        compressed = cctx.compress(orig)

        dctx = zstd.ZstdDecompressor(dict_data=d)
        decompressed = dctx.decompress(compressed)

        self.assertEqual(decompressed, orig)

    def test_dictionary_multiple(self):
        """One dictionary-backed decompressor handles several frames in turn."""
        d = self._train_sample_dictionary()

        sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192)
        cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True)
        compressed = [cctx.compress(source) for source in sources]

        # The same decompressor instance is reused for every frame.
        dctx = zstd.ZstdDecompressor(dict_data=d)
        for source, frame in zip(sources, compressed):
            decompressed = dctx.decompress(frame)
            self.assertEqual(decompressed, source)
class TestDecompressor_copy_stream(unittest.TestCase):
    """Tests for ``ZstdDecompressor.copy_stream(source, dest)``."""

    def test_no_read(self):
        """A plain ``object()`` as the source raises ``ValueError``."""
        reader = object()
        writer = io.BytesIO()
        dctx = zstd.ZstdDecompressor()

        with self.assertRaises(ValueError):
            dctx.copy_stream(reader, writer)

    def test_no_write(self):
        """A plain ``object()`` as the destination raises ``ValueError``."""
        reader = io.BytesIO()
        writer = object()
        dctx = zstd.ZstdDecompressor()

        with self.assertRaises(ValueError):
            dctx.copy_stream(reader, writer)

    def test_empty(self):
        """Copying from an empty stream reads and writes zero bytes."""
        reader = io.BytesIO()
        writer = io.BytesIO()
        dctx = zstd.ZstdDecompressor()

        # TODO should this raise an error?
        counts = dctx.copy_stream(reader, writer)
        bytes_read, bytes_written = counts

        self.assertEqual(bytes_read, 0)
        self.assertEqual(bytes_written, 0)
        self.assertEqual(writer.getvalue(), b'')

    def test_large_data(self):
        """Reported read/write counts match the compressed/original sizes."""
        # 255 runs of 16384 identical bytes, one run per byte value 0..254.
        payload = b''.join(
            struct.Struct('>B').pack(value) * 16384 for value in range(255))
        source = io.BytesIO(payload)

        compressed = io.BytesIO()
        zstd.ZstdCompressor().copy_stream(source, compressed)
        compressed.seek(0)

        dest = io.BytesIO()
        dctx = zstd.ZstdDecompressor()
        bytes_read, bytes_written = dctx.copy_stream(compressed, dest)

        self.assertEqual(bytes_read, len(compressed.getvalue()))
        self.assertEqual(bytes_written, len(source.getvalue()))

    def test_read_write_size(self):
        """``read_size=1`` / ``write_size=1`` drive one I/O call per byte."""
        frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
        source = OpCountingBytesIO(frame)
        dest = OpCountingBytesIO()
        dctx = zstd.ZstdDecompressor()

        bytes_read, bytes_written = dctx.copy_stream(
            source, dest, read_size=1, write_size=1)

        compressed_len = len(source.getvalue())
        self.assertEqual(bytes_read, compressed_len)
        self.assertEqual(bytes_written, len(b'foobarfoobar'))
        # The test expects one read() beyond the number of compressed bytes.
        self.assertEqual(source._read_count, compressed_len + 1)
        self.assertEqual(dest._write_count, len(dest.getvalue()))
class TestDecompressor_decompressobj(unittest.TestCase):
    """Tests for the object returned by ``ZstdDecompressor.decompressobj()``."""

    def test_simple(self):
        """A single ``decompress()`` call returns the original payload."""
        frame = zstd.ZstdCompressor(level=1).compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        decompressor = dctx.decompressobj()
        result = decompressor.decompress(frame)

        self.assertEqual(result, b'foobar')

    def test_reuse(self):
        """A second ``decompress()`` call on the same object is rejected."""
        frame = zstd.ZstdCompressor(level=1).compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        decompressor = dctx.decompressobj()

        # First call consumes the frame; its result is not inspected here.
        decompressor.decompress(frame)

        with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'):
            decompressor.decompress(frame)
def decompress_via_writer(data):
    """Feed ``data`` through ``ZstdDecompressor.write_to()`` in one write.

    Returns the bytes that ended up in the in-memory destination stream
    once the ``write_to()`` context manager has exited.
    """
    sink = io.BytesIO()
    dctx = zstd.ZstdDecompressor()

    with dctx.write_to(sink) as writer:
        writer.write(data)

    return sink.getvalue()
class TestDecompressor_write_to(unittest.TestCase):
    """Tests for the streaming writer from ``ZstdDecompressor.write_to()``."""

    def test_empty_roundtrip(self):
        """A compressed empty payload decompresses to an empty payload."""
        cctx = zstd.ZstdCompressor()
        empty = cctx.compress(b'')
        self.assertEqual(decompress_via_writer(empty), b'')

    def test_large_roundtrip(self):
        """255 runs of 16384 identical bytes survive a round trip."""
        # Build the packer once instead of once per loop iteration.
        pack = struct.Struct('>B').pack
        orig = b''.join(pack(i) * 16384 for i in range(255))

        cctx = zstd.ZstdCompressor()
        compressed = cctx.compress(orig)

        self.assertEqual(decompress_via_writer(compressed), orig)

    def test_multiple_calls(self):
        """Feeding the frame in 8192-byte slices yields the original data."""
        # Build the packer once; the generator below runs 255 * 255 times.
        pack = struct.Struct('>B').pack
        # Outer loop is the run length ``i``, inner loop the byte value ``j``.
        orig = b''.join(pack(j) * i for i in range(255) for j in range(255))

        cctx = zstd.ZstdCompressor()
        compressed = cctx.compress(orig)

        sink = io.BytesIO()
        dctx = zstd.ZstdDecompressor()
        with dctx.write_to(sink) as decompressor:
            for pos in range(0, len(compressed), 8192):
                decompressor.write(compressed[pos:pos + 8192])

        self.assertEqual(sink.getvalue(), orig)

    def test_dictionary(self):
        """Streaming compress and decompress agree when sharing a dictionary."""
        # 128 repetitions of the foo / bar / foobar sample triple, in order.
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
        d = zstd.train_dictionary(8192, samples)

        orig = b'foobar' * 16384

        compressed_sink = io.BytesIO()
        cctx = zstd.ZstdCompressor(dict_data=d)
        with cctx.write_to(compressed_sink) as compressor:
            compressor.write(orig)

        compressed = compressed_sink.getvalue()

        decompressed_sink = io.BytesIO()
        dctx = zstd.ZstdDecompressor(dict_data=d)
        with dctx.write_to(decompressed_sink) as decompressor:
            decompressor.write(compressed)

        self.assertEqual(decompressed_sink.getvalue(), orig)

    def test_memory_size(self):
        """``memory_size()`` on the writer reports more than 100000."""
        dctx = zstd.ZstdDecompressor()
        sink = io.BytesIO()
        with dctx.write_to(sink) as decompressor:
            size = decompressor.memory_size()

        self.assertGreater(size, 100000)

    def test_write_size(self):
        """``write_size=1`` produces one destination write per output byte."""
        source = zstd.ZstdCompressor().compress(b'foobarfoobar')
        dest = OpCountingBytesIO()
        dctx = zstd.ZstdDecompressor()
        with dctx.write_to(dest, write_size=1) as decompressor:
            s = struct.Struct('>B')
            for c in source:
                # Iterating bytes yields ints on Python 3 and 1-char strs on
                # Python 2; only the int form needs packing back into a byte.
                if not isinstance(c, str):
                    c = s.pack(c)
                decompressor.write(c)

        self.assertEqual(dest.getvalue(), b'foobarfoobar')
        self.assertEqual(dest._write_count, len(dest.getvalue()))
class TestDecompressor_read_from(unittest.TestCase):
    """Tests for the chunk iterator from ``ZstdDecompressor.read_from()``."""

    def _drain_exactly(self, it, count):
        """Pull ``count`` chunks from ``it`` and assert it is then exhausted.

        Returns the pulled chunks joined into one byte string.
        """
        chunks = [next(it) for _ in range(count)]

        with self.assertRaises(StopIteration):
            next(it)

        return b''.join(chunks)

    def test_type_validation(self):
        """Readable objects and byte strings are accepted; ``True`` is not."""
        dctx = zstd.ZstdDecompressor()

        # Object with read() works.
        dctx.read_from(io.BytesIO())

        # Buffer protocol works.
        dctx.read_from(b'foobar')

        with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
            dctx.read_from(True)

    def test_empty_input(self):
        """Empty input yields no chunks for both stream and bytes sources."""
        dctx = zstd.ZstdDecompressor()

        source = io.BytesIO()
        it = dctx.read_from(source)
        # TODO this is arguably wrong. Should get an error about missing frame foo.
        with self.assertRaises(StopIteration):
            next(it)

        it = dctx.read_from(b'')
        with self.assertRaises(StopIteration):
            next(it)

    def test_invalid_input(self):
        """Non-zstd input fails on the first ``next()`` call."""
        dctx = zstd.ZstdDecompressor()

        source = io.BytesIO(b'foobar')
        it = dctx.read_from(source)
        with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'):
            next(it)

        it = dctx.read_from(b'foobar')
        with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'):
            next(it)

    def test_empty_roundtrip(self):
        """A compressed empty payload produces no output chunks."""
        cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
        empty = cctx.compress(b'')

        source = io.BytesIO(empty)
        source.seek(0)

        dctx = zstd.ZstdDecompressor()
        it = dctx.read_from(source)

        # No chunks should be emitted since there is no data.
        with self.assertRaises(StopIteration):
            next(it)

        # Again for good measure.
        with self.assertRaises(StopIteration):
            next(it)

    def test_skip_bytes_too_large(self):
        """Oversized ``skip_bytes`` values are rejected with ``ValueError``."""
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
            dctx.read_from(b'', skip_bytes=1, read_size=1)

        # This error only surfaces once the iterator is consumed.
        with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
            b''.join(dctx.read_from(b'foobar', skip_bytes=10))

    def test_skip_bytes(self):
        """``skip_bytes`` skips a leading non-zstd header before the frame."""
        cctx = zstd.ZstdCompressor(write_content_size=False)
        compressed = cctx.compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        output = b''.join(dctx.read_from(b'hdr' + compressed, skip_bytes=3))
        self.assertEqual(output, b'foobar')

    def test_large_output(self):
        """Output one byte past the recommended size arrives in two chunks."""
        source = io.BytesIO()
        source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
        source.write(b'o')
        source.seek(0)

        cctx = zstd.ZstdCompressor(level=1)
        compressed = io.BytesIO(cctx.compress(source.getvalue()))
        compressed.seek(0)

        dctx = zstd.ZstdDecompressor()
        decompressed = self._drain_exactly(dctx.read_from(compressed), 2)
        self.assertEqual(decompressed, source.getvalue())

        # And again with buffer protocol.
        decompressed = self._drain_exactly(
            dctx.read_from(compressed.getvalue()), 2)
        self.assertEqual(decompressed, source.getvalue())

    def test_large_input(self):
        """Input past the recommended sizes decompresses in three chunks."""
        byte_values = [struct.Struct('>B').pack(i) for i in range(256)]
        compressed = io.BytesIO()
        input_size = 0
        cctx = zstd.ZstdCompressor(level=1)
        with cctx.write_to(compressed) as compressor:
            while True:
                compressor.write(random.choice(byte_values))
                input_size += 1

                # Test the cheap counter first: getvalue() materialises the
                # whole buffer, so only call it once enough raw input exists.
                if input_size <= zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2:
                    continue
                if len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE:
                    break

        compressed.seek(0)
        self.assertGreater(len(compressed.getvalue()),
                           zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)

        dctx = zstd.ZstdDecompressor()
        decompressed = self._drain_exactly(dctx.read_from(compressed), 3)
        self.assertEqual(len(decompressed), input_size)

        # And again with buffer protocol.
        decompressed = self._drain_exactly(
            dctx.read_from(compressed.getvalue()), 3)
        self.assertEqual(len(decompressed), input_size)

    def test_interesting(self):
        """One-shot and streaming decompression agree on 256 KiB of zeros."""
        # Found this edge case via fuzzing.
        cctx = zstd.ZstdCompressor(level=1)

        source = io.BytesIO()
        compressed = io.BytesIO()
        with cctx.write_to(compressed) as compressor:
            chunk = b'\0' * 1024
            for _ in range(256):
                compressor.write(chunk)
                source.write(chunk)

        dctx = zstd.ZstdDecompressor()

        simple = dctx.decompress(compressed.getvalue(),
                                 max_output_size=len(source.getvalue()))
        self.assertEqual(simple, source.getvalue())

        compressed.seek(0)
        streamed = b''.join(dctx.read_from(compressed))
        self.assertEqual(streamed, source.getvalue())

    def test_read_write_size(self):
        """``read_size=1`` / ``write_size=1`` give 1-byte chunks and reads."""
        source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar'))
        dctx = zstd.ZstdDecompressor()
        for chunk in dctx.read_from(source, read_size=1, write_size=1):
            self.assertEqual(len(chunk), 1)

        self.assertEqual(source._read_count, len(source.getvalue()))