"""Tests for ``zstd.ZstdDecompressor``."""

import io
import random
import struct
import sys

try:
    import unittest2 as unittest
except ImportError:
    import unittest

import zstd

from .common import OpCountingBytesIO


# Packs an integer in [0, 255] into a one-byte bytes object; used instead of
# bytes([i]) so the tests behave the same on Python 2 and Python 3.
_BYTE = struct.Struct('>B')


class _TestCase(unittest.TestCase):
    """Base class smoothing over unittest API differences between versions."""

    if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
        # Python 2.7's unittest only has the old ``assertRaisesRegexp``
        # spelling, while Python 3.12 removed that alias. The tests below use
        # the modern name, and this shim maps it back where it is missing.
        def assertRaisesRegex(self, *args, **kwargs):
            return self.assertRaisesRegexp(*args, **kwargs)


def _train_test_dictionary():
    """Train the 8192-byte dictionary shared by the dictionary tests."""
    samples = []
    for _ in range(128):
        samples.append(b'foo' * 64)
        samples.append(b'bar' * 64)
        samples.append(b'foobar' * 64)
    return zstd.train_dictionary(8192, samples)


class TestDecompressor_decompress(_TestCase):
    def test_empty_input(self):
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegex(zstd.ZstdError, 'input data invalid'):
            dctx.decompress(b'')

    def test_invalid_input(self):
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegex(zstd.ZstdError, 'input data invalid'):
            dctx.decompress(b'foobar')

    def test_no_content_size_in_frame(self):
        cctx = zstd.ZstdCompressor(write_content_size=False)
        compressed = cctx.compress(b'foobar')

        # Without a content size in the frame header and without
        # max_output_size, decompress() cannot size its output buffer.
        dctx = zstd.ZstdDecompressor()
        with self.assertRaisesRegex(zstd.ZstdError, 'input data invalid'):
            dctx.decompress(compressed)

    def test_content_size_present(self):
        cctx = zstd.ZstdCompressor(write_content_size=True)
        compressed = cctx.compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        decompressed = dctx.decompress(compressed)
        self.assertEqual(decompressed, b'foobar')

    def test_max_output_size(self):
        cctx = zstd.ZstdCompressor(write_content_size=False)
        source = b'foobar' * 256
        compressed = cctx.compress(source)

        dctx = zstd.ZstdDecompressor()
        # Will fit into buffer exactly the size of input.
        decompressed = dctx.decompress(compressed, max_output_size=len(source))
        self.assertEqual(decompressed, source)

        # Input size - 1 fails.
        with self.assertRaisesRegex(zstd.ZstdError,
                                    'Destination buffer is too small'):
            dctx.decompress(compressed, max_output_size=len(source) - 1)

        # Input size + 1 works.
        decompressed = dctx.decompress(compressed,
                                       max_output_size=len(source) + 1)
        self.assertEqual(decompressed, source)

        # A much larger buffer works.
        decompressed = dctx.decompress(compressed,
                                       max_output_size=len(source) * 64)
        self.assertEqual(decompressed, source)

    def test_stupidly_large_output_buffer(self):
        cctx = zstd.ZstdCompressor(write_content_size=False)
        compressed = cctx.compress(b'foobar' * 256)
        dctx = zstd.ZstdDecompressor()

        # Will get OverflowError on some Python distributions that can't
        # handle really large integers.
        with self.assertRaises((MemoryError, OverflowError)):
            dctx.decompress(compressed, max_output_size=2**62)

    def test_dictionary(self):
        d = _train_test_dictionary()

        orig = b'foobar' * 16384
        cctx = zstd.ZstdCompressor(level=1, dict_data=d,
                                   write_content_size=True)
        compressed = cctx.compress(orig)

        dctx = zstd.ZstdDecompressor(dict_data=d)
        decompressed = dctx.decompress(compressed)

        self.assertEqual(decompressed, orig)

    def test_dictionary_multiple(self):
        d = _train_test_dictionary()

        sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192)
        cctx = zstd.ZstdCompressor(level=1, dict_data=d,
                                   write_content_size=True)
        compressed = [cctx.compress(source) for source in sources]

        # A single decompressor (and its dictionary) must be reusable across
        # several independent frames.
        dctx = zstd.ZstdDecompressor(dict_data=d)
        for source, frame in zip(sources, compressed):
            self.assertEqual(dctx.decompress(frame), source)


class TestDecompressor_copy_stream(_TestCase):
    def test_no_read(self):
        source = object()
        dest = io.BytesIO()

        dctx = zstd.ZstdDecompressor()
        with self.assertRaises(ValueError):
            dctx.copy_stream(source, dest)

    def test_no_write(self):
        source = io.BytesIO()
        dest = object()

        dctx = zstd.ZstdDecompressor()
        with self.assertRaises(ValueError):
            dctx.copy_stream(source, dest)

    def test_empty(self):
        source = io.BytesIO()
        dest = io.BytesIO()

        dctx = zstd.ZstdDecompressor()
        # TODO should this raise an error?
        r, w = dctx.copy_stream(source, dest)

        self.assertEqual(r, 0)
        self.assertEqual(w, 0)
        self.assertEqual(dest.getvalue(), b'')

    def test_large_data(self):
        source = io.BytesIO()
        for i in range(255):
            source.write(_BYTE.pack(i) * 16384)
        source.seek(0)

        compressed = io.BytesIO()
        cctx = zstd.ZstdCompressor()
        cctx.copy_stream(source, compressed)
        compressed.seek(0)

        dest = io.BytesIO()
        dctx = zstd.ZstdDecompressor()
        r, w = dctx.copy_stream(compressed, dest)

        # copy_stream() reports (bytes read, bytes written).
        self.assertEqual(r, len(compressed.getvalue()))
        self.assertEqual(w, len(source.getvalue()))
        # Byte counts alone do not prove a correct roundtrip; check content.
        self.assertEqual(dest.getvalue(), source.getvalue())

    def test_read_write_size(self):
        source = OpCountingBytesIO(zstd.ZstdCompressor().compress(
            b'foobarfoobar'))
        dest = OpCountingBytesIO()
        dctx = zstd.ZstdDecompressor()
        r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)

        self.assertEqual(r, len(source.getvalue()))
        self.assertEqual(w, len(b'foobarfoobar'))
        # One read per input byte plus a final read that observes EOF.
        self.assertEqual(source._read_count, len(source.getvalue()) + 1)
        self.assertEqual(dest._write_count, len(dest.getvalue()))


class TestDecompressor_decompressobj(_TestCase):
    def test_simple(self):
        data = zstd.ZstdCompressor(level=1).compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()
        self.assertEqual(dobj.decompress(data), b'foobar')

    def test_reuse(self):
        data = zstd.ZstdCompressor(level=1).compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()
        dobj.decompress(data)

        # A decompressobj is single-use once a frame has been fully consumed.
        with self.assertRaisesRegex(zstd.ZstdError,
                                    'cannot use a decompressobj'):
            dobj.decompress(data)


def decompress_via_writer(data):
    """Decompress ``data`` through ``ZstdDecompressor.write_to()``.

    Returns the bytes the decompressor wrote into an in-memory buffer.
    """
    buffer = io.BytesIO()
    dctx = zstd.ZstdDecompressor()
    with dctx.write_to(buffer) as decompressor:
        decompressor.write(data)
    return buffer.getvalue()


class TestDecompressor_write_to(_TestCase):
    def test_empty_roundtrip(self):
        cctx = zstd.ZstdCompressor()
        empty = cctx.compress(b'')
        self.assertEqual(decompress_via_writer(empty), b'')

    def test_large_roundtrip(self):
        orig = b''.join(_BYTE.pack(i) * 16384 for i in range(255))

        cctx = zstd.ZstdCompressor()
        compressed = cctx.compress(orig)

        self.assertEqual(decompress_via_writer(compressed), orig)

    def test_multiple_calls(self):
        orig = b''.join(_BYTE.pack(j) * i
                        for i in range(255)
                        for j in range(255))

        cctx = zstd.ZstdCompressor()
        compressed = cctx.compress(orig)

        # Feed the compressed stream in 8192-byte pieces to exercise
        # decompression state carried across write() calls.
        buffer = io.BytesIO()
        dctx = zstd.ZstdDecompressor()
        with dctx.write_to(buffer) as decompressor:
            for pos in range(0, len(compressed), 8192):
                decompressor.write(compressed[pos:pos + 8192])

        self.assertEqual(buffer.getvalue(), orig)

    def test_dictionary(self):
        d = _train_test_dictionary()

        orig = b'foobar' * 16384
        buffer = io.BytesIO()
        cctx = zstd.ZstdCompressor(dict_data=d)
        with cctx.write_to(buffer) as compressor:
            compressor.write(orig)

        compressed = buffer.getvalue()
        buffer = io.BytesIO()

        dctx = zstd.ZstdDecompressor(dict_data=d)
        with dctx.write_to(buffer) as decompressor:
            decompressor.write(compressed)

        self.assertEqual(buffer.getvalue(), orig)

    def test_memory_size(self):
        dctx = zstd.ZstdDecompressor()
        buffer = io.BytesIO()
        with dctx.write_to(buffer) as decompressor:
            size = decompressor.memory_size()

        self.assertGreater(size, 100000)

    def test_write_size(self):
        source = zstd.ZstdCompressor().compress(b'foobarfoobar')
        dest = OpCountingBytesIO()
        dctx = zstd.ZstdDecompressor()
        with dctx.write_to(dest, write_size=1) as decompressor:
            for c in source:
                # Iterating bytes yields ints on Python 3 but one-character
                # str objects on Python 2; normalize to a one-byte string.
                if not isinstance(c, str):
                    c = _BYTE.pack(c)
                decompressor.write(c)

        self.assertEqual(dest.getvalue(), b'foobarfoobar')
        # write_size=1 forces one write() on the destination per output byte.
        self.assertEqual(dest._write_count, len(dest.getvalue()))


class TestDecompressor_read_from(_TestCase):
    def test_type_validation(self):
        dctx = zstd.ZstdDecompressor()

        # Object with read() works.
        dctx.read_from(io.BytesIO())

        # Buffer protocol works.
        dctx.read_from(b'foobar')

        with self.assertRaisesRegex(ValueError,
                                    'must pass an object with a read'):
            dctx.read_from(True)

    def test_empty_input(self):
        dctx = zstd.ZstdDecompressor()

        source = io.BytesIO()
        it = dctx.read_from(source)
        # TODO this is arguably wrong: empty input should probably raise an
        # error about a missing frame rather than yield nothing.
        with self.assertRaises(StopIteration):
            next(it)

        it = dctx.read_from(b'')
        with self.assertRaises(StopIteration):
            next(it)

    def test_invalid_input(self):
        dctx = zstd.ZstdDecompressor()

        source = io.BytesIO(b'foobar')
        it = dctx.read_from(source)
        with self.assertRaisesRegex(zstd.ZstdError,
                                    'Unknown frame descriptor'):
            next(it)

        it = dctx.read_from(b'foobar')
        with self.assertRaisesRegex(zstd.ZstdError,
                                    'Unknown frame descriptor'):
            next(it)

    def test_empty_roundtrip(self):
        cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
        empty = cctx.compress(b'')

        # A freshly constructed BytesIO is already positioned at offset 0.
        source = io.BytesIO(empty)

        dctx = zstd.ZstdDecompressor()
        it = dctx.read_from(source)

        # No chunks should be emitted since there is no data.
        with self.assertRaises(StopIteration):
            next(it)

        # Again for good measure.
        with self.assertRaises(StopIteration):
            next(it)

    def test_skip_bytes_too_large(self):
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegex(
                ValueError, 'skip_bytes must be smaller than read_size'):
            dctx.read_from(b'', skip_bytes=1, read_size=1)

        # This error is only detected once iteration reads the first chunk.
        with self.assertRaisesRegex(
                ValueError, 'skip_bytes larger than first input chunk'):
            b''.join(dctx.read_from(b'foobar', skip_bytes=10))

    def test_skip_bytes(self):
        cctx = zstd.ZstdCompressor(write_content_size=False)
        compressed = cctx.compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        output = b''.join(dctx.read_from(b'hdr' + compressed, skip_bytes=3))
        self.assertEqual(output, b'foobar')

    def test_large_output(self):
        source = io.BytesIO()
        source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
        source.write(b'o')
        source.seek(0)

        cctx = zstd.ZstdCompressor(level=1)
        compressed = io.BytesIO(cctx.compress(source.getvalue()))

        dctx = zstd.ZstdDecompressor()
        # Exercise both the read() interface and the buffer protocol.
        for reader in (compressed, compressed.getvalue()):
            it = dctx.read_from(reader)

            # The output is one byte larger than the recommended output
            # size, so it is expected to arrive as exactly two chunks.
            chunks = [next(it), next(it)]
            with self.assertRaises(StopIteration):
                next(it)

            self.assertEqual(b''.join(chunks), source.getvalue())

    def test_large_input(self):
        byte_values = [_BYTE.pack(i) for i in range(256)]
        compressed = io.BytesIO()
        input_size = 0
        cctx = zstd.ZstdCompressor(level=1)
        with cctx.write_to(compressed) as compressor:
            while True:
                compressor.write(random.choice(byte_values))
                input_size += 1

                # Nothing seeks ``compressed`` while writing, so tell() is its
                # current length. Calling getvalue() here instead would copy
                # the buffer after every one-byte write (quadratic time).
                have_compressed = (compressed.tell() >
                                   zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
                have_raw = (input_size >
                            zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2)
                if have_compressed and have_raw:
                    break

        compressed.seek(0)
        self.assertGreater(len(compressed.getvalue()),
                           zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)

        dctx = zstd.ZstdDecompressor()
        # Exercise both the read() interface and the buffer protocol.
        for reader in (compressed, compressed.getvalue()):
            it = dctx.read_from(reader)

            chunks = [next(it), next(it), next(it)]
            with self.assertRaises(StopIteration):
                next(it)

            self.assertEqual(len(b''.join(chunks)), input_size)

    def test_interesting(self):
        # Found this edge case via fuzzing.
        cctx = zstd.ZstdCompressor(level=1)

        source = io.BytesIO()

        compressed = io.BytesIO()
        with cctx.write_to(compressed) as compressor:
            for _ in range(256):
                chunk = b'\0' * 1024
                compressor.write(chunk)
                source.write(chunk)

        dctx = zstd.ZstdDecompressor()

        simple = dctx.decompress(compressed.getvalue(),
                                 max_output_size=len(source.getvalue()))
        self.assertEqual(simple, source.getvalue())

        compressed.seek(0)
        streamed = b''.join(dctx.read_from(compressed))
        self.assertEqual(streamed, source.getvalue())

    def test_read_write_size(self):
        source = OpCountingBytesIO(
            zstd.ZstdCompressor().compress(b'foobarfoobar'))
        dctx = zstd.ZstdDecompressor()
        for chunk in dctx.read_from(source, read_size=1, write_size=1):
            self.assertEqual(len(chunk), 1)

        # read_size=1 forces one read() per compressed byte.
        self.assertEqual(source._read_count, len(source.getvalue()))