|
|
import io
|
|
|
import os
|
|
|
import random
|
|
|
import struct
|
|
|
import sys
|
|
|
import tempfile
|
|
|
import unittest
|
|
|
|
|
|
import zstandard as zstd
|
|
|
|
|
|
from .common import (
|
|
|
generate_samples,
|
|
|
make_cffi,
|
|
|
NonClosingBytesIO,
|
|
|
OpCountingBytesIO,
|
|
|
)
|
|
|
|
|
|
|
|
|
# Py2/Py3 shim: expose a module-level ``next`` that invokes the iterator
# protocol method directly on the object (``__next__`` on Python 3,
# ``next`` on Python 2).
next = (
    (lambda it: it.__next__())
    if sys.version_info[0] >= 3
    else (lambda it: it.next())
)
|
|
|
|
|
|
|
|
|
@make_cffi
class TestFrameHeaderSize(unittest.TestCase):
    """Checks for ``zstd.frame_header_size()``."""

    def test_empty(self):
        # An empty buffer must be rejected with a ZstdError.
        expected = ('could not determine frame header size: Src size '
                    'is incorrect')
        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            zstd.frame_header_size(b'')

    def test_too_small(self):
        # Four bytes of input must be rejected with the same error.
        expected = ('could not determine frame header size: Src size '
                    'is incorrect')
        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            zstd.frame_header_size(b'foob')

    def test_basic(self):
        # The input does not have to be a valid frame for a size to be
        # reported.
        header_size = zstd.frame_header_size(b'long enough but no magic')
        self.assertEqual(header_size, 6)
|
|
|
|
|
|
|
|
|
@make_cffi
class TestFrameContentSize(unittest.TestCase):
    """Checks for ``zstd.frame_content_size()``.

    The test for a frame holding empty content is named
    ``test_empty_frame`` so that it does not share a name with
    ``test_empty`` (empty *input*). With two methods of the same name the
    later definition replaces the earlier one in the class namespace and
    the earlier test never runs.
    """

    def test_empty(self):
        # Empty input (no frame at all) must raise.
        with self.assertRaisesRegexp(zstd.ZstdError,
                                     'error when determining content size'):
            zstd.frame_content_size(b'')

    def test_too_small(self):
        # Four bytes of input must raise.
        with self.assertRaisesRegexp(zstd.ZstdError,
                                     'error when determining content size'):
            zstd.frame_content_size(b'foob')

    def test_bad_frame(self):
        # Input that is long enough but not a frame must raise.
        with self.assertRaisesRegexp(zstd.ZstdError,
                                     'error when determining content size'):
            zstd.frame_content_size(b'invalid frame header')

    def test_unknown(self):
        # A frame written without a content size reports -1.
        cctx = zstd.ZstdCompressor(write_content_size=False)
        frame = cctx.compress(b'foobar')

        self.assertEqual(zstd.frame_content_size(frame), -1)

    def test_empty_frame(self):
        # A frame produced from empty content reports a size of 0.
        cctx = zstd.ZstdCompressor()
        frame = cctx.compress(b'')

        self.assertEqual(zstd.frame_content_size(frame), 0)

    def test_basic(self):
        # A frame produced from six bytes of content reports 6.
        cctx = zstd.ZstdCompressor()
        frame = cctx.compress(b'foobar')

        self.assertEqual(zstd.frame_content_size(frame), 6)
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor(unittest.TestCase):
    """Checks on ``ZstdDecompressor`` itself."""

    def test_memory_size(self):
        # A freshly created decompressor must report a size above 100.
        reported = zstd.ZstdDecompressor().memory_size()
        self.assertGreater(reported, 100)
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor_decompress(unittest.TestCase):
    """Exercises the one-shot ``ZstdDecompressor.decompress()`` API."""

    def test_empty_input(self):
        decompressor = zstd.ZstdDecompressor()
        expected = 'error determining content size from frame header'

        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            decompressor.decompress(b'')

    def test_invalid_input(self):
        decompressor = zstd.ZstdDecompressor()
        expected = 'error determining content size from frame header'

        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            decompressor.decompress(b'foobar')

    def test_input_types(self):
        frame = zstd.ZstdCompressor(level=1).compress(b'foo')

        # A pre-sized bytearray that is filled through slice assignment.
        filled = bytearray(len(frame))
        filled[:] = frame

        candidates = (memoryview(frame), bytearray(frame), filled)

        decompressor = zstd.ZstdDecompressor()
        for candidate in candidates:
            self.assertEqual(decompressor.decompress(candidate), b'foo')

    def test_no_content_size_in_frame(self):
        # Without a content size in the frame and without
        # max_output_size, decompress() must raise.
        frame = zstd.ZstdCompressor(write_content_size=False).compress(
            b'foobar')

        decompressor = zstd.ZstdDecompressor()
        expected = 'could not determine content size in frame header'
        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            decompressor.decompress(frame)

    def test_content_size_present(self):
        frame = zstd.ZstdCompressor().compress(b'foobar')

        decompressor = zstd.ZstdDecompressor()
        self.assertEqual(decompressor.decompress(frame), b'foobar')

    def test_empty_roundtrip(self):
        frame = zstd.ZstdCompressor().compress(b'')

        decompressor = zstd.ZstdDecompressor()
        restored = decompressor.decompress(frame)

        self.assertEqual(restored, b'')

    def test_max_output_size(self):
        payload = b'foobar' * 256
        frame = zstd.ZstdCompressor(write_content_size=False).compress(
            payload)
        exact = len(payload)

        decompressor = zstd.ZstdDecompressor()

        # An output buffer exactly the size of the input is sufficient.
        restored = decompressor.decompress(frame, max_output_size=exact)
        self.assertEqual(restored, payload)

        # One byte fewer than the input size is rejected.
        expected = 'decompression error: did not decompress full frame'
        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            decompressor.decompress(frame, max_output_size=exact - 1)

        # One byte more than the input size is accepted.
        restored = decompressor.decompress(frame, max_output_size=exact + 1)
        self.assertEqual(restored, payload)

        # So is a much larger buffer.
        restored = decompressor.decompress(frame, max_output_size=exact * 64)
        self.assertEqual(restored, payload)

    def test_stupidly_large_output_buffer(self):
        frame = zstd.ZstdCompressor(write_content_size=False).compress(
            b'foobar' * 256)
        decompressor = zstd.ZstdDecompressor()

        # Some Python distributions cannot handle really large integers
        # and raise OverflowError instead of MemoryError.
        with self.assertRaises((MemoryError, OverflowError)):
            decompressor.decompress(frame, max_output_size=2**62)

    def test_dictionary(self):
        # 128 repetitions of the (foo, bar, foobar) sample triple.
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128

        dictionary = zstd.train_dictionary(8192, samples)

        payload = b'foobar' * 16384
        compressor = zstd.ZstdCompressor(level=1, dict_data=dictionary)
        frame = compressor.compress(payload)

        decompressor = zstd.ZstdDecompressor(dict_data=dictionary)
        restored = decompressor.decompress(frame)

        self.assertEqual(restored, payload)

    def test_dictionary_multiple(self):
        # 128 repetitions of the (foo, bar, foobar) sample triple.
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128

        dictionary = zstd.train_dictionary(8192, samples)

        payloads = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192)
        compressor = zstd.ZstdCompressor(level=1, dict_data=dictionary)
        frames = [compressor.compress(payload) for payload in payloads]

        # One decompressor is reused across all of the frames.
        decompressor = zstd.ZstdDecompressor(dict_data=dictionary)
        for frame, payload in zip(frames, payloads):
            self.assertEqual(decompressor.decompress(frame), payload)

    def test_max_window_size(self):
        with open(__file__, 'rb') as fh:
            payload = fh.read()

        # If a content size were written, the decompressor would engage
        # single pass mode and the window size would not come into play.
        compressor = zstd.ZstdCompressor(write_content_size=False)
        frame = compressor.compress(payload)

        decompressor = zstd.ZstdDecompressor(
            max_window_size=2**zstd.WINDOWLOG_MIN)

        expected = 'decompression error: Frame requires too much memory'
        with self.assertRaisesRegexp(zstd.ZstdError, expected):
            decompressor.decompress(frame, max_output_size=len(payload))
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor_copy_stream(unittest.TestCase):
    """Exercises ``ZstdDecompressor.copy_stream()``."""

    def test_no_read(self):
        # A source that is a plain object() must be rejected.
        decompressor = zstd.ZstdDecompressor()
        with self.assertRaises(ValueError):
            decompressor.copy_stream(object(), io.BytesIO())

    def test_no_write(self):
        # A destination that is a plain object() must be rejected.
        decompressor = zstd.ZstdDecompressor()
        with self.assertRaises(ValueError):
            decompressor.copy_stream(io.BytesIO(), object())

    def test_empty(self):
        src = io.BytesIO()
        dst = io.BytesIO()

        decompressor = zstd.ZstdDecompressor()
        # TODO should this raise an error?
        bytes_read, bytes_written = decompressor.copy_stream(src, dst)

        self.assertEqual(bytes_read, 0)
        self.assertEqual(bytes_written, 0)
        self.assertEqual(dst.getvalue(), b'')

    def test_large_data(self):
        # 255 runs of 16384 identical bytes, one run per byte value.
        plain = io.BytesIO(
            b''.join(struct.pack('>B', value) * 16384
                     for value in range(255)))

        packed = io.BytesIO()
        zstd.ZstdCompressor().copy_stream(plain, packed)
        packed.seek(0)

        restored = io.BytesIO()
        decompressor = zstd.ZstdDecompressor()
        bytes_read, bytes_written = decompressor.copy_stream(packed, restored)

        self.assertEqual(bytes_read, len(packed.getvalue()))
        self.assertEqual(bytes_written, len(plain.getvalue()))

    def test_read_write_size(self):
        frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
        src = OpCountingBytesIO(frame)
        dst = OpCountingBytesIO()

        decompressor = zstd.ZstdDecompressor()
        bytes_read, bytes_written = decompressor.copy_stream(
            src, dst, read_size=1, write_size=1)

        self.assertEqual(bytes_read, len(src.getvalue()))
        self.assertEqual(bytes_written, len(b'foobarfoobar'))
        # The expected read count is one more than the number of source
        # bytes; the expected write count equals the output length.
        self.assertEqual(src._read_count, len(src.getvalue()) + 1)
        self.assertEqual(dst._write_count, len(dst.getvalue()))
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor_stream_reader(unittest.TestCase):
    """Exercises the object returned by ``ZstdDecompressor.stream_reader()``."""

    def test_context_manager(self):
        # Entering an already-entered reader must raise.
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(b'foo') as reader:
            with self.assertRaisesRegexp(ValueError,
                                         'cannot __enter__ multiple times'):
                with reader:
                    pass

    def test_not_implemented(self):
        dctx = zstd.ZstdDecompressor()
        unsupported = io.UnsupportedOperation

        with dctx.stream_reader(b'foo') as reader:
            self.assertRaises(unsupported, reader.readline)
            self.assertRaises(unsupported, reader.readlines)
            self.assertRaises(unsupported, iter, reader)
            # ``next`` is the module-level compatibility shim.
            self.assertRaises(unsupported, next, reader)
            self.assertRaises(unsupported, reader.write, b'foo')
            self.assertRaises(unsupported, reader.writelines, [])

    def test_constant_methods(self):
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(b'foo') as reader:
            self.assertFalse(reader.closed)
            self.assertTrue(reader.readable())
            self.assertFalse(reader.writable())
            self.assertTrue(reader.seekable())
            self.assertFalse(reader.isatty())
            self.assertFalse(reader.closed)
            # flush() returns None and must leave the reader open.
            self.assertIsNone(reader.flush())
            self.assertFalse(reader.closed)

        # Leaving the context manager closes the reader.
        self.assertTrue(reader.closed)

    def test_read_closed(self):
        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(b'foo') as reader:
            reader.close()
            self.assertTrue(reader.closed)
            with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                reader.read(1)

    def test_read_sizes(self):
        frame = zstd.ZstdCompressor().compress(b'foo')

        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(frame) as reader:
            expected = 'cannot read negative amounts less than -1'
            with self.assertRaisesRegexp(ValueError, expected):
                reader.read(-2)

            self.assertEqual(reader.read(0), b'')
            self.assertEqual(reader.read(), b'foo')

    def test_read_buffer(self):
        payload = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
        frame = zstd.ZstdCompressor().compress(payload)

        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(frame) as reader:
            self.assertEqual(reader.tell(), 0)

            # The whole frame is expected from a single read.
            data = reader.read(8192)
            self.assertEqual(data, payload)
            self.assertEqual(reader.tell(), len(payload))

            # Reading past EOF yields empty bytes and does not move tell().
            self.assertEqual(reader.read(1), b'')
            self.assertEqual(reader.tell(), len(data))

        self.assertTrue(reader.closed)

    def test_read_buffer_small_chunks(self):
        payload = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
        frame = zstd.ZstdCompressor().compress(payload)

        dctx = zstd.ZstdDecompressor()
        pieces = []
        consumed = 0

        with dctx.stream_reader(frame, read_size=1) as reader:
            # Pull one byte at a time until an empty read signals EOF.
            for piece in iter(lambda: reader.read(1), b''):
                pieces.append(piece)
                consumed += len(piece)
                self.assertEqual(reader.tell(), consumed)

        self.assertEqual(b''.join(pieces), payload)

    def test_read_stream(self):
        payload = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
        frame = zstd.ZstdCompressor().compress(payload)

        dctx = zstd.ZstdDecompressor()
        with dctx.stream_reader(io.BytesIO(frame)) as reader:
            self.assertEqual(reader.tell(), 0)

            data = reader.read(8192)
            self.assertEqual(data, payload)
            self.assertEqual(reader.tell(), len(payload))
            self.assertEqual(reader.read(1), b'')
            self.assertEqual(reader.tell(), len(payload))
            self.assertFalse(reader.closed)

        self.assertTrue(reader.closed)

    def test_read_stream_small_chunks(self):
        payload = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
        frame = zstd.ZstdCompressor().compress(payload)

        dctx = zstd.ZstdDecompressor()
        pieces = []
        consumed = 0

        with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader:
            # Pull one byte at a time until an empty read signals EOF.
            for piece in iter(lambda: reader.read(1), b''):
                pieces.append(piece)
                consumed += len(piece)
                self.assertEqual(reader.tell(), consumed)

        self.assertEqual(b''.join(pieces), payload)

    def test_read_after_exit(self):
        frame = zstd.ZstdCompressor().compress(b'foo' * 60)

        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(frame) as reader:
            # Drain the stream.
            while reader.read(16):
                pass

        self.assertTrue(reader.closed)

        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
            reader.read(10)

    def test_illegal_seeks(self):
        frame = zstd.ZstdCompressor().compress(b'foo' * 60)

        dctx = zstd.ZstdDecompressor()
        backwards = 'cannot seek zstd decompression stream backwards'

        with dctx.stream_reader(frame) as reader:
            with self.assertRaisesRegexp(ValueError,
                                         'cannot seek to negative position'):
                reader.seek(-1, os.SEEK_SET)

            # Advance by one byte so that offset 0 lies behind the cursor.
            reader.read(1)

            with self.assertRaisesRegexp(ValueError, backwards):
                reader.seek(0, os.SEEK_SET)

            with self.assertRaisesRegexp(ValueError, backwards):
                reader.seek(-1, os.SEEK_CUR)

            with self.assertRaisesRegexp(
                    ValueError,
                    'zstd decompression streams cannot be seeked with '
                    'SEEK_END'):
                reader.seek(0, os.SEEK_END)

            reader.close()

            with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                reader.seek(4, os.SEEK_SET)

        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
            reader.seek(0)

    def test_seek(self):
        payload = b'foobar' * 60
        frame = zstd.ZstdCompressor().compress(payload)

        dctx = zstd.ZstdDecompressor()

        with dctx.stream_reader(frame) as reader:
            # Absolute seek forward, then read the following 3 bytes.
            reader.seek(3)
            self.assertEqual(reader.read(3), b'bar')

            # Relative seek forward from the current position.
            reader.seek(4, os.SEEK_CUR)
            self.assertEqual(reader.read(2), b'ar')

    def test_no_context_manager(self):
        payload = b'foobar' * 60
        frame = zstd.ZstdCompressor().compress(payload)

        dctx = zstd.ZstdDecompressor()
        reader = dctx.stream_reader(frame)

        self.assertEqual(reader.read(6), b'foobar')
        self.assertEqual(reader.read(18), b'foobar' * 3)
        self.assertFalse(reader.closed)

        # After close() the reader can no longer be used.
        reader.close()
        self.assertTrue(reader.closed)

        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
            reader.read(6)

    def test_read_after_error(self):
        dctx = zstd.ZstdDecompressor()
        reader = dctx.stream_reader(io.BytesIO(b''))

        with reader:
            reader.read(0)

        # Re-entering the reader after the first context has exited:
        # reads must report a closed stream.
        with reader:
            with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                reader.read(100)

    def test_partial_read(self):
        # Inspired by https://github.com/indygreg/python-zstandard/issues/71.
        stream = io.BytesIO()
        cctx = zstd.ZstdCompressor()
        writer = cctx.stream_writer(stream)
        writer.write(bytearray(os.urandom(1000000)))
        writer.flush(zstd.FLUSH_FRAME)
        stream.seek(0)

        dctx = zstd.ZstdDecompressor()
        reader = dctx.stream_reader(stream)

        # No assertions: the test only requires that draining the reader
        # in 8192-byte reads terminates.
        while reader.read(8192):
            pass

    def test_read_multiple_frames(self):
        cctx = zstd.ZstdCompressor()
        source = io.BytesIO()
        writer = cctx.stream_writer(source)
        writer.write(b'foo')
        writer.flush(zstd.FLUSH_FRAME)
        writer.write(b'bar')
        writer.flush(zstd.FLUSH_FRAME)

        dctx = zstd.ZstdDecompressor()

        def readers(**kwargs):
            # First a reader over the raw bytes, then (after rewinding)
            # a reader over the stream object itself.
            yield dctx.stream_reader(source.getvalue(), **kwargs)
            source.seek(0)
            yield dctx.stream_reader(source, **kwargs)

        # Default mode: each read() call stays within one frame.
        per_frame_cases = (
            (2, (b'fo', b'o', b'ba', b'r')),
            (3, (b'foo', b'bar')),
            (4, (b'foo', b'bar')),
            (128, (b'foo', b'bar')),
        )
        for size, expected_reads in per_frame_cases:
            for reader in readers():
                for expected in expected_reads:
                    self.assertEqual(reader.read(size), expected)

        # Now tests for reads spanning frames.
        spanning_cases = (
            (3, (b'foo', b'bar')),
            (6, (b'foobar',)),
            (7, (b'foobar',)),
            (128, (b'foobar',)),
        )
        for size, expected_reads in spanning_cases:
            for reader in readers(read_across_frames=True):
                for expected in expected_reads:
                    self.assertEqual(reader.read(size), expected)

    def test_readinto(self):
        frame = zstd.ZstdCompressor().compress(b'foo')

        dctx = zstd.ZstdDecompressor()

        # Attempting to readinto() a non-writable buffer fails.
        # The exact exception varies based on the backend.
        reader = dctx.stream_reader(frame)
        with self.assertRaises(Exception):
            reader.readinto(b'foobar')

        # Destination larger than the decompressed data.
        target = bytearray(1024)
        reader = dctx.stream_reader(frame)
        self.assertEqual(reader.readinto(target), 3)
        self.assertEqual(target[0:3], b'foo')
        self.assertEqual(reader.readinto(target), 0)
        self.assertEqual(target[0:3], b'foo')

        # Same, but the reader uses read_size=1.
        target = bytearray(1024)
        reader = dctx.stream_reader(frame, read_size=1)
        self.assertEqual(reader.readinto(target), 3)
        self.assertEqual(target[0:3], b'foo')

        # Destination smaller than the decompressed data.
        target = bytearray(2)
        reader = dctx.stream_reader(frame)
        self.assertEqual(reader.readinto(target), 2)
        self.assertEqual(target[:], b'fo')

    def test_readinto1(self):
        frame = zstd.ZstdCompressor().compress(b'foo')

        dctx = zstd.ZstdDecompressor()

        # A bytes object is not a writable destination.
        reader = dctx.stream_reader(frame)
        with self.assertRaises(Exception):
            reader.readinto1(b'foobar')

        # Destination larger than the decompressed data.
        target = bytearray(1024)
        reader = dctx.stream_reader(frame)
        self.assertEqual(reader.readinto1(target), 3)
        self.assertEqual(target[0:3], b'foo')
        self.assertEqual(reader.readinto1(target), 0)
        self.assertEqual(target[0:3], b'foo')

        # Same, but the reader uses read_size=1.
        target = bytearray(1024)
        reader = dctx.stream_reader(frame, read_size=1)
        self.assertEqual(reader.readinto1(target), 3)
        self.assertEqual(target[0:3], b'foo')

        # Destination smaller than the decompressed data.
        target = bytearray(2)
        reader = dctx.stream_reader(frame)
        self.assertEqual(reader.readinto1(target), 2)
        self.assertEqual(target[:], b'fo')

    def test_readall(self):
        frame = zstd.ZstdCompressor().compress(b'foo')

        dctx = zstd.ZstdDecompressor()
        reader = dctx.stream_reader(frame)

        self.assertEqual(reader.readall(), b'foo')

    def test_read1(self):
        frame = zstd.ZstdCompressor().compress(b'foo')

        dctx = zstd.ZstdDecompressor()

        counted = OpCountingBytesIO(frame)
        reader = dctx.stream_reader(counted)

        self.assertEqual(reader.read1(), b'foo')
        self.assertEqual(counted._read_count, 1)

        counted = OpCountingBytesIO(frame)
        reader = dctx.stream_reader(counted)

        # The source read count is expected to stay at 1 until the
        # final read1() call, which returns empty bytes.
        self.assertEqual(reader.read1(0), b'')
        self.assertEqual(reader.read1(2), b'fo')
        self.assertEqual(counted._read_count, 1)
        self.assertEqual(reader.read1(1), b'o')
        self.assertEqual(counted._read_count, 1)
        self.assertEqual(reader.read1(1), b'')
        self.assertEqual(counted._read_count, 2)

    def test_read_lines(self):
        cctx = zstd.ZstdCompressor()
        payload = b'\n'.join(
            ('line %d' % i).encode('ascii') for i in range(1024))

        frame = cctx.compress(payload)

        dctx = zstd.ZstdDecompressor()

        # Iterating over the text wrapper.
        reader = dctx.stream_reader(frame)
        tr = io.TextIOWrapper(reader, encoding='utf-8')

        lines = [line.encode('utf-8') for line in tr]

        self.assertEqual(len(lines), 1024)
        self.assertEqual(b''.join(lines), payload)

        # readlines() on the text wrapper.
        reader = dctx.stream_reader(frame)
        tr = io.TextIOWrapper(reader, encoding='utf-8')

        lines = tr.readlines()
        self.assertEqual(len(lines), 1024)
        self.assertEqual(''.join(lines).encode('utf-8'), payload)

        # Repeated readline() until an empty string is returned.
        reader = dctx.stream_reader(frame)
        tr = io.TextIOWrapper(reader, encoding='utf-8')

        lines = []
        for line in iter(tr.readline, ''):
            lines.append(line.encode('utf-8'))

        self.assertEqual(len(lines), 1024)
        self.assertEqual(b''.join(lines), payload)
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor_decompressobj(unittest.TestCase):
    """Exercises the object returned by ``ZstdDecompressor.decompressobj()``."""

    def test_simple(self):
        data = zstd.ZstdCompressor(level=1).compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()
        self.assertEqual(dobj.decompress(data), b'foobar')
        # flush() returns None with no argument, a positional argument,
        # or a ``length`` keyword.
        self.assertIsNone(dobj.flush())
        self.assertIsNone(dobj.flush(10))
        self.assertIsNone(dobj.flush(length=100))

    def test_input_types(self):
        compressed = zstd.ZstdCompressor(level=1).compress(b'foo')

        dctx = zstd.ZstdDecompressor()

        # A pre-sized bytearray that is filled through slice assignment.
        mutable_array = bytearray(len(compressed))
        mutable_array[:] = compressed

        sources = [
            memoryview(compressed),
            bytearray(compressed),
            mutable_array,
        ]

        for source in sources:
            dobj = dctx.decompressobj()
            # flush() before any input is also expected to return None.
            self.assertIsNone(dobj.flush())
            self.assertIsNone(dobj.flush(10))
            self.assertIsNone(dobj.flush(length=100))
            self.assertEqual(dobj.decompress(source), b'foo')
            self.assertIsNone(dobj.flush())

    def test_reuse(self):
        data = zstd.ZstdCompressor(level=1).compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()
        dobj.decompress(data)

        with self.assertRaisesRegexp(zstd.ZstdError,
                                     'cannot use a decompressobj'):
            dobj.decompress(data)

        # Kept outside the assertRaisesRegexp block: a statement placed
        # after the raising call inside that block would never execute.
        self.assertIsNone(dobj.flush())

    def test_bad_write_size(self):
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(ValueError,
                                     'write_size must be positive'):
            dctx.decompressobj(write_size=0)

    def test_write_size(self):
        source = b'foo' * 64 + b'bar' * 128
        data = zstd.ZstdCompressor(level=1).compress(source)

        dctx = zstd.ZstdDecompressor()

        # Every write_size from 1 through 128 must reproduce the input.
        for write_size in range(1, 129):
            dobj = dctx.decompressobj(write_size=write_size)
            self.assertEqual(dobj.decompress(data), source)
|
|
|
|
|
|
|
|
|
def decompress_via_writer(data):
    """Decompress ``data`` through a ``stream_writer`` and return the output.

    ``data`` is written in a single ``write()`` call to a writer wrapping
    an in-memory ``io.BytesIO``; the bytes accumulated in that buffer are
    returned.
    """
    sink = io.BytesIO()
    dctx = zstd.ZstdDecompressor()
    # The writer stays bound to a local so that it remains alive until
    # the sink has been read.
    writer = dctx.stream_writer(sink)
    writer.write(data)

    return sink.getvalue()
|
|
|
|
|
|
|
|
|
@make_cffi
|
|
|
class TestDecompressor_stream_writer(unittest.TestCase):
|
|
|
def test_io_api(self):
    # Walks the io-style surface of the decompression writer: constant
    # predicates, plus every operation expected to raise
    # io.UnsupportedOperation.
    sink = io.BytesIO()
    dctx = zstd.ZstdDecompressor()
    writer = dctx.stream_writer(sink)
    unsupported = io.UnsupportedOperation

    self.assertFalse(writer.closed)
    self.assertFalse(writer.isatty())
    self.assertFalse(writer.readable())

    # Line-oriented reads.
    self.assertRaises(unsupported, writer.readline)
    self.assertRaises(unsupported, writer.readline, 42)
    self.assertRaises(unsupported, writer.readline, size=42)
    self.assertRaises(unsupported, writer.readlines)
    self.assertRaises(unsupported, writer.readlines, 42)
    self.assertRaises(unsupported, writer.readlines, hint=42)

    # Seeking.
    self.assertRaises(unsupported, writer.seek, 0)
    self.assertRaises(unsupported, writer.seek, 10, os.SEEK_SET)

    self.assertFalse(writer.seekable())

    # Position and truncation.
    self.assertRaises(unsupported, writer.tell)
    self.assertRaises(unsupported, writer.truncate)
    self.assertRaises(unsupported, writer.truncate, 42)
    self.assertRaises(unsupported, writer.truncate, size=42)

    self.assertTrue(writer.writable())

    self.assertRaises(unsupported, writer.writelines, [])

    # Plain reads.
    self.assertRaises(unsupported, writer.read)
    self.assertRaises(unsupported, writer.read, 42)
    self.assertRaises(unsupported, writer.read, size=42)
    self.assertRaises(unsupported, writer.readall)
    self.assertRaises(unsupported, writer.readinto, None)

    # The wrapped BytesIO has no file descriptor.
    self.assertRaises(unsupported, writer.fileno)
|
|
|
|
|
|
def test_fileno_file(self):
    # With a real file underneath, fileno() must match the file's own
    # descriptor.
    with tempfile.TemporaryFile('wb') as handle:
        dctx = zstd.ZstdDecompressor()
        writer = dctx.stream_writer(handle)

        expected_fd = handle.fileno()
        self.assertEqual(writer.fileno(), expected_fd)
|
|
|
|
|
|
def test_close(self):
    frame = zstd.ZstdCompressor().compress(b'foo')

    sink = NonClosingBytesIO()
    dctx = zstd.ZstdDecompressor()
    writer = dctx.stream_writer(sink)

    writer.write(frame)
    self.assertFalse(writer.closed)
    self.assertFalse(sink.closed)
    # Closing the writer is expected to close the wrapped stream too.
    writer.close()
    self.assertTrue(writer.closed)
    self.assertTrue(sink.closed)

    closed_message = 'stream is closed'

    with self.assertRaisesRegexp(ValueError, closed_message):
        writer.write(b'')

    with self.assertRaisesRegexp(ValueError, closed_message):
        writer.flush()

    # Entering a closed writer as a context manager must raise too.
    with self.assertRaisesRegexp(ValueError, closed_message):
        with writer:
            pass

    self.assertEqual(sink.getvalue(), b'foo')

    # Leaving the context manager should close the writer.
    sink = NonClosingBytesIO()
    writer = dctx.stream_writer(sink)

    with writer:
        writer.write(frame)

    self.assertTrue(writer.closed)
    self.assertEqual(sink.getvalue(), b'foo')
|
|
|
|
|
|
def test_flush(self):
    # Each writer.flush() must raise the wrapped stream's flush count
    # by one.
    sink = OpCountingBytesIO()
    dctx = zstd.ZstdDecompressor()
    writer = dctx.stream_writer(sink)

    for expected_count in (1, 2):
        writer.flush()
        self.assertEqual(sink._flush_count, expected_count)
|
|
|
|
|
|
def test_empty_roundtrip(self):
    # A frame built from empty content decompresses to empty bytes.
    frame = zstd.ZstdCompressor().compress(b'')
    restored = decompress_via_writer(frame)
    self.assertEqual(restored, b'')
|
|
|
|
|
|
def test_input_types(self):
    frame = zstd.ZstdCompressor(level=1).compress(b'foo')

    # A pre-sized bytearray that is filled through slice assignment.
    filled = bytearray(len(frame))
    filled[:] = frame

    candidates = (memoryview(frame), bytearray(frame), filled)

    dctx = zstd.ZstdDecompressor()
    for candidate in candidates:
        # Writer used directly, without a context manager.
        sink = io.BytesIO()

        decompressor = dctx.stream_writer(sink)
        decompressor.write(candidate)
        self.assertEqual(sink.getvalue(), b'foo')

        # As a context manager, write() is expected to return 3,
        # the length of b'foo'.
        sink = NonClosingBytesIO()

        with dctx.stream_writer(sink) as decompressor:
            self.assertEqual(decompressor.write(candidate), 3)

        self.assertEqual(sink.getvalue(), b'foo')

        # With write_return_read=True, write() is expected to return
        # the length of the input.
        sink = io.BytesIO()
        writer = dctx.stream_writer(sink, write_return_read=True)
        self.assertEqual(writer.write(candidate), len(candidate))
        self.assertEqual(sink.getvalue(), b'foo')
|
|
|
|
|
|
def test_large_roundtrip(self):
    # 255 runs of 16384 identical bytes, one run per byte value.
    payload = b''.join(
        struct.pack('>B', value) * 16384 for value in range(255))
    frame = zstd.ZstdCompressor().compress(payload)

    self.assertEqual(decompress_via_writer(frame), payload)
|
|
|
|
|
|
def test_multiple_calls(self):
    # Decompress one frame by handing it to write() in 8192-byte slices:
    # first through a context-managed writer, then through a writer
    # created with write_return_read=True.
    to_byte = struct.Struct('>B').pack
    pieces = [to_byte(value) * repeat
              for repeat in range(255)
              for value in range(255)]

    original = b''.join(pieces)
    frame = zstd.ZstdCompressor().compress(original)
    step = 8192

    sink = NonClosingBytesIO()
    dctx = zstd.ZstdDecompressor()
    with dctx.stream_writer(sink) as decompressor:
        for start in range(0, len(frame), step):
            decompressor.write(frame[start:start + step])
    self.assertEqual(sink.getvalue(), original)

    # Again with write_return_read=True: every write() is asserted to
    # return the size of the slice it was given.
    sink = io.BytesIO()
    writer = dctx.stream_writer(sink, write_return_read=True)
    for start in range(0, len(frame), step):
        piece = frame[start:start + step]
        self.assertEqual(writer.write(piece), len(piece))
    self.assertEqual(sink.getvalue(), original)
|
|
|
|
|
|
def test_dictionary(self):
    # Train a dictionary, compress with it through a stream writer, and
    # check that a decompressor given the same dictionary restores the
    # input both with and without a context manager.
    samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
    dictionary = zstd.train_dictionary(8192, samples)

    original = b'foobar' * 16384
    compressed_sink = NonClosingBytesIO()
    cctx = zstd.ZstdCompressor(dict_data=dictionary)
    with cctx.stream_writer(compressed_sink) as compressor:
        self.assertEqual(compressor.write(original), 0)

    frame = compressed_sink.getvalue()
    expected_size = len(original)

    # Plain writer: write() is asserted to return the decompressed size.
    sink = io.BytesIO()
    dctx = zstd.ZstdDecompressor(dict_data=dictionary)
    plain_writer = dctx.stream_writer(sink)
    self.assertEqual(plain_writer.write(frame), expected_size)
    self.assertEqual(sink.getvalue(), original)

    # Same check through the context manager protocol.
    sink = NonClosingBytesIO()
    with dctx.stream_writer(sink) as managed_writer:
        self.assertEqual(managed_writer.write(frame), expected_size)

    self.assertEqual(sink.getvalue(), original)
|
|
|
|
|
|
def test_memory_size(self):
    # memory_size() must report more than 100000 for a stream writer,
    # whether it is queried directly or inside a `with` block.
    dctx = zstd.ZstdDecompressor()
    sink = io.BytesIO()

    plain_writer = dctx.stream_writer(sink)
    self.assertGreater(plain_writer.memory_size(), 100000)

    with dctx.stream_writer(sink) as managed_writer:
        managed_size = managed_writer.memory_size()

    self.assertGreater(managed_size, 100000)
|
|
|
|
|
|
def test_write_size(self):
    # With write_size=1, feed the frame one byte at a time and check the
    # destination saw exactly one write() per decompressed byte.
    frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
    dest = OpCountingBytesIO()
    to_byte = struct.Struct('>B').pack
    dctx = zstd.ZstdDecompressor()
    with dctx.stream_writer(dest, write_size=1) as decompressor:
        for item in frame:
            # Items that are not already str (e.g. the ints produced by
            # iterating bytes on Python 3) are packed into one byte.
            single = item if isinstance(item, str) else to_byte(item)
            decompressor.write(single)

    self.assertEqual(dest.getvalue(), b'foobarfoobar')
    self.assertEqual(dest._write_count, len(dest.getvalue()))
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor_read_to_iter(unittest.TestCase):
    # Tests for ZstdDecompressor.read_to_iter(), fed either a file-like
    # object or a bytes object.

    def test_type_validation(self):
        dctx = zstd.ZstdDecompressor()

        # An object with read() works, and so does the buffer protocol.
        for acceptable in (io.BytesIO(), b'foobar'):
            dctx.read_to_iter(acceptable)

        # Anything else must raise ValueError.
        with self.assertRaisesRegexp(ValueError,
                                     'must pass an object with a read'):
            b''.join(dctx.read_to_iter(True))

    def test_empty_input(self):
        dctx = zstd.ZstdDecompressor()

        # TODO this is arguably wrong. Should get an error about missing
        # frame foo.
        for empty_source in (io.BytesIO(), b''):
            chunks_iter = dctx.read_to_iter(empty_source)
            with self.assertRaises(StopIteration):
                next(chunks_iter)

    def test_invalid_input(self):
        dctx = zstd.ZstdDecompressor()

        # Data that is not a zstd frame must fail on the first next(),
        # for a file-like source as well as a bytes source.
        for bad_source in (io.BytesIO(b'foobar'), b'foobar'):
            chunks_iter = dctx.read_to_iter(bad_source)
            with self.assertRaisesRegexp(zstd.ZstdError,
                                         'Unknown frame descriptor'):
                next(chunks_iter)

    def test_empty_roundtrip(self):
        cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
        empty_frame = cctx.compress(b'')

        source = io.BytesIO(empty_frame)
        source.seek(0)

        dctx = zstd.ZstdDecompressor()
        chunks_iter = dctx.read_to_iter(source)

        # No chunks should be emitted since there is no data. The check
        # is made twice for good measure.
        for _ in range(2):
            with self.assertRaises(StopIteration):
                next(chunks_iter)

    def test_skip_bytes_too_large(self):
        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(
                ValueError, 'skip_bytes must be smaller than read_size'):
            b''.join(dctx.read_to_iter(b'', skip_bytes=1, read_size=1))

        with self.assertRaisesRegexp(
                ValueError, 'skip_bytes larger than first input chunk'):
            b''.join(dctx.read_to_iter(b'foobar', skip_bytes=10))

    def test_skip_bytes(self):
        # Prefix the frame with three junk bytes and ask read_to_iter()
        # to skip them.
        cctx = zstd.ZstdCompressor(write_content_size=False)
        frame = cctx.compress(b'foobar')

        dctx = zstd.ZstdDecompressor()
        chunks_iter = dctx.read_to_iter(b'hdr' + frame, skip_bytes=3)
        self.assertEqual(b''.join(chunks_iter), b'foobar')

    def test_large_output(self):
        # The payload is one byte longer than the recommended output
        # size; the test expects it to come back as exactly two chunks.
        expected = b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE + b'o'
        frame = zstd.ZstdCompressor(level=1).compress(expected)

        dctx = zstd.ZstdDecompressor()

        # First from a file-like object, and again with buffer protocol.
        for source in (io.BytesIO(frame), frame):
            chunks_iter = dctx.read_to_iter(source)
            chunks = [next(chunks_iter), next(chunks_iter)]

            with self.assertRaises(StopIteration):
                next(chunks_iter)

            self.assertEqual(b''.join(chunks), expected)

    @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ,
                         'ZSTD_SLOW_TESTS not set')
    def test_large_input(self):
        # Compress random single bytes until both the compressed and the
        # raw size thresholds below are exceeded, then expect the data
        # back as exactly three chunks.
        byte_choices = [struct.Struct('>B').pack(value)
                        for value in range(256)]
        sink = NonClosingBytesIO()
        input_size = 0
        min_compressed = zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
        min_raw = zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
        cctx = zstd.ZstdCompressor(level=1)
        with cctx.stream_writer(sink) as compressor:
            while True:
                compressor.write(random.choice(byte_choices))
                input_size += 1

                enough_compressed = len(sink.getvalue()) > min_compressed
                enough_raw = input_size > min_raw
                if enough_compressed and enough_raw:
                    break

        frame = sink.getvalue()
        self.assertGreater(len(frame),
                           zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)

        dctx = zstd.ZstdDecompressor()

        # First from a file-like object, and again with buffer protocol.
        for source in (io.BytesIO(frame), frame):
            chunks_iter = dctx.read_to_iter(source)
            chunks = [next(chunks_iter) for _ in range(3)]

            with self.assertRaises(StopIteration):
                next(chunks_iter)

            self.assertEqual(len(b''.join(chunks)), input_size)

    def test_interesting(self):
        # Found this edge case via fuzzing.
        cctx = zstd.ZstdCompressor(level=1)
        zeros = b'\0' * 1024
        write_count = 256

        # The input is deliberately written in 256 separate 1024-byte
        # pieces rather than in one call.
        sink = NonClosingBytesIO()
        with cctx.stream_writer(sink) as compressor:
            for _ in range(write_count):
                compressor.write(zeros)

        frame = sink.getvalue()
        expected = zeros * write_count

        dctx = zstd.ZstdDecompressor()

        simple = dctx.decompress(frame, max_output_size=len(expected))
        self.assertEqual(simple, expected)

        streamed = b''.join(dctx.read_to_iter(io.BytesIO(frame)))
        self.assertEqual(streamed, expected)

    def test_read_write_size(self):
        # With read_size=1 and write_size=1 every emitted chunk must be a
        # single byte, and the source's _read_count must equal the number
        # of compressed bytes.
        frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
        source = OpCountingBytesIO(frame)
        dctx = zstd.ZstdDecompressor()
        for piece in dctx.read_to_iter(source, read_size=1, write_size=1):
            self.assertEqual(len(piece), 1)

        self.assertEqual(source._read_count, len(source.getvalue()))

    def test_magic_less(self):
        magicless = zstd.FORMAT_ZSTD1_MAGICLESS
        params = zstd.CompressionParameters.from_level(1, format=magicless)
        cctx = zstd.ZstdCompressor(compression_params=params)
        frame = cctx.compress(b'foobar')

        # The frame must not start with the standard zstd magic number.
        self.assertNotEqual(frame[0:4], b'\x28\xb5\x2f\xfd')

        # A decompressor using the default format rejects the frame.
        default_dctx = zstd.ZstdDecompressor()
        with self.assertRaisesRegexp(
                zstd.ZstdError,
                'error determining content size from frame header'):
            default_dctx.decompress(frame)

        # One configured for the magicless format decodes it.
        magicless_dctx = zstd.ZstdDecompressor(format=magicless)
        decoded = b''.join(magicless_dctx.read_to_iter(frame))
        self.assertEqual(decoded, b'foobar')
|
|
|
|
|
|
|
|
|
@make_cffi
class TestDecompressor_content_dict_chain(unittest.TestCase):
    # Tests for ZstdDecompressor.decompress_content_dict_chain(), which is
    # given a list of frames where each frame after the first was
    # compressed using the previous payload as its dictionary.

    def test_bad_inputs_simple(self):
        # Argument validation and failures attributed to chunk 0.
        dctx = zstd.ZstdDecompressor()

        # The chain must be a list: bytes and tuples are rejected.
        with self.assertRaises(TypeError):
            dctx.decompress_content_dict_chain(b'foo')

        with self.assertRaises(TypeError):
            dctx.decompress_content_dict_chain((b'foo', b'bar'))

        with self.assertRaisesRegexp(ValueError, 'empty input chain'):
            dctx.decompress_content_dict_chain([])

        with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
            dctx.decompress_content_dict_chain([u'foo'])

        with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
            dctx.decompress_content_dict_chain([True])

        with self.assertRaisesRegexp(
                ValueError, 'chunk 0 is too small to contain a zstd frame'):
            dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])

        with self.assertRaisesRegexp(
                ValueError, 'chunk 0 is not a valid zstd frame'):
            dctx.decompress_content_dict_chain([b'foo' * 8])

        # A frame written without the content size is rejected.
        no_size = zstd.ZstdCompressor(
            write_content_size=False).compress(b'foo' * 64)

        with self.assertRaisesRegexp(
                ValueError, 'chunk 0 missing content size in frame'):
            dctx.decompress_content_dict_chain([no_size])

        # Corrupt first frame by dropping bytes 12..14.
        frame = zstd.ZstdCompressor().compress(b'foo' * 64)
        frame = frame[0:12] + frame[15:]
        with self.assertRaisesRegexp(
                zstd.ZstdError, 'chunk 0 did not decompress full frame'):
            dctx.decompress_content_dict_chain([frame])

    def test_bad_subsequent_input(self):
        # Same failures as above, but triggered by the second chunk of a
        # chain whose first chunk is valid.
        initial = zstd.ZstdCompressor().compress(b'foo' * 64)

        dctx = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
            dctx.decompress_content_dict_chain([initial, u'foo'])

        with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
            dctx.decompress_content_dict_chain([initial, None])

        with self.assertRaisesRegexp(
                ValueError, 'chunk 1 is too small to contain a zstd frame'):
            dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])

        with self.assertRaisesRegexp(
                ValueError, 'chunk 1 is not a valid zstd frame'):
            dctx.decompress_content_dict_chain([initial, b'foo' * 8])

        no_size = zstd.ZstdCompressor(
            write_content_size=False).compress(b'foo' * 64)

        with self.assertRaisesRegexp(
                ValueError, 'chunk 1 missing content size in frame'):
            dctx.decompress_content_dict_chain([initial, no_size])

        # Corrupt second frame by dropping bytes 12..14.
        cctx = zstd.ZstdCompressor(
            dict_data=zstd.ZstdCompressionDict(b'foo' * 64))
        frame = cctx.compress(b'bar' * 64)
        frame = frame[0:12] + frame[15:]

        with self.assertRaisesRegexp(
                zstd.ZstdError, 'chunk 1 did not decompress full frame'):
            dctx.decompress_content_dict_chain([initial, frame])

    def test_simple(self):
        # Build a chain in which chunk N is compressed with payload N-1
        # as its dictionary, then check that every prefix of the chain
        # decompresses to the payload of its last chunk.
        original = [
            b'foo' * 64,
            b'foobar' * 64,
            b'baz' * 64,
            b'foobaz' * 64,
            b'foobarbaz' * 64,
        ]

        chunks = []
        chunks.append(zstd.ZstdCompressor().compress(original[0]))
        for i, chunk in enumerate(original[1:]):
            # enumerate() starts at 0, so original[i] is the payload that
            # precedes `chunk`.
            d = zstd.ZstdCompressionDict(original[i])
            cctx = zstd.ZstdCompressor(dict_data=d)
            chunks.append(cctx.compress(chunk))

        # The upper bound is len(original) + 1 so that the complete chain
        # (and therefore the last payload) is exercised too; stopping at
        # len(original) would leave the final chunk untested.
        for i in range(1, len(original) + 1):
            chain = chunks[0:i]
            expected = original[i - 1]
            dctx = zstd.ZstdDecompressor()
            decompressed = dctx.decompress_content_dict_chain(chain)
            self.assertEqual(decompressed, expected)
|
|
|
|
|
|
|
|
|
# TODO enable for CFFI
|
|
|
class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase):
|
|
|
def test_invalid_inputs(self):
    # Arguments that multi_decompress_to_buffer() must reject.
    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    # Neither a bool nor a tuple of ints is an acceptable argument.
    for bad_argument in (True, (1, 2)):
        with self.assertRaises(TypeError):
            dctx.multi_decompress_to_buffer(bad_argument)

    with self.assertRaisesRegexp(
            TypeError, 'item 0 not a bytes like object'):
        dctx.multi_decompress_to_buffer([u'foo'])

    with self.assertRaisesRegexp(
            ValueError, 'could not determine decompressed size of item 0'):
        dctx.multi_decompress_to_buffer([b'foobarbaz'])
|
|
|
|
|
|
def test_list_input(self):
    # Decompress a plain list of frames and check the result's length,
    # total size, contents and per-item offset/length.
    payloads = [b'foo' * 4, b'bar' * 6]
    cctx = zstd.ZstdCompressor()
    frames = list(map(cctx.compress, payloads))

    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    result = dctx.multi_decompress_to_buffer(frames)

    self.assertEqual(len(result), len(frames))
    self.assertEqual(result.size(), sum(len(data) for data in payloads))

    for index, data in enumerate(payloads):
        self.assertEqual(result[index].tobytes(), data)

    # Expected (offset, length) of each decompressed item.
    for index, (offset, length) in enumerate([(0, 12), (12, 18)]):
        self.assertEqual(result[index].offset, offset)
        self.assertEqual(len(result[index]), length)
|
|
|
|
|
|
def test_list_input_frame_sizes(self):
    # Same as decompressing a list of frames, but the decompressed sizes
    # are passed explicitly as packed native-order 64-bit unsigned ints.
    payloads = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
    cctx = zstd.ZstdCompressor()
    frames = list(map(cctx.compress, payloads))
    payload_lengths = [len(data) for data in payloads]
    sizes = struct.pack('=' + 'Q' * len(payloads), *payload_lengths)

    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    result = dctx.multi_decompress_to_buffer(frames,
                                             decompressed_sizes=sizes)

    self.assertEqual(len(result), len(frames))
    self.assertEqual(result.size(), sum(payload_lengths))

    for index, data in enumerate(payloads):
        self.assertEqual(result[index].tobytes(), data)
|
|
|
|
|
|
def test_buffer_with_segments_input(self):
    # Pass both frames as one BufferWithSegments instead of a list.
    cctx = zstd.ZstdCompressor()
    first, second = [cctx.compress(data)
                     for data in (b'foo' * 4, b'bar' * 6)]

    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    # Four 64-bit values; they read as an (offset, length) pair for each
    # frame within the joined buffer.
    segments = struct.pack('=QQQQ',
                           0, len(first),
                           len(first), len(second))
    packed = zstd.BufferWithSegments(b''.join([first, second]), segments)

    result = dctx.multi_decompress_to_buffer(packed)

    self.assertEqual(len(result), 2)

    # Expected (offset, length) of each decompressed item.
    for index, (offset, length) in enumerate([(0, 12), (12, 18)]):
        self.assertEqual(result[index].offset, offset)
        self.assertEqual(len(result[index]), length)
|
|
|
|
|
|
def test_buffer_with_segments_sizes(self):
    # Frames are written without their content size, so the decompressed
    # sizes are supplied explicitly alongside the BufferWithSegments.
    payloads = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
    cctx = zstd.ZstdCompressor(write_content_size=False)
    frames = list(map(cctx.compress, payloads))
    payload_lengths = [len(data) for data in payloads]
    sizes = struct.pack('=' + 'Q' * len(payloads), *payload_lengths)

    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    # Six 64-bit values; they read as an (offset, length) pair for each
    # of the three frames within the joined buffer.
    frame_lengths = [len(frame) for frame in frames]
    segments = struct.pack('=QQQQQQ',
                           0, frame_lengths[0],
                           frame_lengths[0], frame_lengths[1],
                           frame_lengths[0] + frame_lengths[1],
                           frame_lengths[2])
    packed = zstd.BufferWithSegments(b''.join(frames), segments)

    result = dctx.multi_decompress_to_buffer(packed,
                                             decompressed_sizes=sizes)

    self.assertEqual(len(result), len(frames))
    self.assertEqual(result.size(), sum(payload_lengths))

    for index, data in enumerate(payloads):
        self.assertEqual(result[index].tobytes(), data)
|
|
|
|
|
|
def test_buffer_with_segments_collection_input(self):
    # Round-trip the output of multi_compress_to_buffer() through
    # multi_decompress_to_buffer(), first directly and then via a
    # hand-built BufferWithSegmentsCollection.
    cctx = zstd.ZstdCompressor()

    original = [
        b'foo0' * 2,
        b'foo1' * 3,
        b'foo2' * 4,
        b'foo3' * 5,
        b'foo4' * 6,
    ]

    if not hasattr(cctx, 'multi_compress_to_buffer'):
        self.skipTest('multi_compress_to_buffer not available')

    dctx = zstd.ZstdDecompressor()

    # Skip, like the other tests in this class do, when the decompressor
    # lacks the API instead of failing with AttributeError.
    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    frames = cctx.multi_compress_to_buffer(original)

    # Check round trip.
    decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)

    self.assertEqual(len(decompressed), len(original))

    for i, data in enumerate(original):
        self.assertEqual(data, decompressed[i].tobytes())

    # And a manual mode: split the five frames over two buffers, each
    # described by packed 64-bit values that read as (offset, length)
    # pairs relative to its own buffer.
    b = b''.join([frames[0].tobytes(), frames[1].tobytes()])
    b1 = zstd.BufferWithSegments(b, struct.pack('=QQQQ',
                                                0, len(frames[0]),
                                                len(frames[0]), len(frames[1])))

    b = b''.join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()])
    b2 = zstd.BufferWithSegments(b, struct.pack('=QQQQQQ',
                                                0, len(frames[2]),
                                                len(frames[2]), len(frames[3]),
                                                len(frames[2]) + len(frames[3]), len(frames[4])))

    c = zstd.BufferWithSegmentsCollection(b1, b2)

    dctx = zstd.ZstdDecompressor()
    decompressed = dctx.multi_decompress_to_buffer(c)

    self.assertEqual(len(decompressed), 5)
    for i in range(5):
        self.assertEqual(decompressed[i].tobytes(), original[i])
|
|
|
|
|
|
def test_dict(self):
    # Train a dictionary on generate_samples(), compress every sample
    # with it, and check that batch decompression with the same
    # dictionary returns the samples unchanged.
    dictionary = zstd.train_dictionary(16384, generate_samples(),
                                       k=64, d=16)

    cctx = zstd.ZstdCompressor(dict_data=dictionary, level=1)
    frames = list(map(cctx.compress, generate_samples()))

    dctx = zstd.ZstdDecompressor(dict_data=dictionary)

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    result = dctx.multi_decompress_to_buffer(frames)

    decoded = [item.tobytes() for item in result]
    self.assertEqual(decoded, generate_samples())
|
|
|
|
|
|
def test_multiple_threads(self):
    # Decompress 512 small frames (256 of b'x' * 64 followed by 256 of
    # b'y' * 64) with threads=-1 and spot-check the first item of each
    # half.
    cctx = zstd.ZstdCompressor()

    frames = [cctx.compress(b'x' * 64) for _ in range(256)]
    frames += [cctx.compress(b'y' * 64) for _ in range(256)]

    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    result = dctx.multi_decompress_to_buffer(frames, threads=-1)

    self.assertEqual(len(result), len(frames))
    self.assertEqual(result.size(), 2 * 64 * 256)
    self.assertEqual(result[0].tobytes(), b'x' * 64)
    self.assertEqual(result[256].tobytes(), b'y' * 64)
|
|
|
|
|
|
def test_item_failure(self):
    # Corrupt the second frame and check that the error names item 1,
    # both with the default call and with threads=2.
    cctx = zstd.ZstdCompressor()
    good_frame = cctx.compress(b'x' * 128)
    bad_frame = cctx.compress(b'y' * 128)

    # Insert junk bytes at offset 15 of the second frame.
    bad_frame = bad_frame[0:15] + b'extra' + bad_frame[15:]
    frames = [good_frame, bad_frame]

    dctx = zstd.ZstdDecompressor()

    if not hasattr(dctx, 'multi_decompress_to_buffer'):
        self.skipTest('multi_decompress_to_buffer not available')

    # Either of the two messages is accepted.
    expected_message = ('error decompressing item 1: ('
                        'Corrupted block|'
                        'Destination buffer is too small)')

    for extra_kwargs in ({}, {'threads': 2}):
        with self.assertRaisesRegexp(zstd.ZstdError, expected_message):
            dctx.multi_decompress_to_buffer(frames, **extra_kwargs)
|
|
|
|
|
|
|