|
|
import imp
|
|
|
import inspect
|
|
|
import io
|
|
|
import os
|
|
|
import types
|
|
|
import unittest
|
|
|
|
|
|
try:
|
|
|
import hypothesis
|
|
|
except ImportError:
|
|
|
hypothesis = None
|
|
|
|
|
|
|
|
|
class TestCase(unittest.TestCase):
|
|
|
if not getattr(unittest.TestCase, "assertRaisesRegex", False):
|
|
|
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
|
|
|
|
|
|
|
|
|
def make_cffi(cls):
|
|
|
"""Decorator to add CFFI versions of each test method."""
|
|
|
|
|
|
# The module containing this class definition should
|
|
|
# `import zstandard as zstd`. Otherwise things may blow up.
|
|
|
mod = inspect.getmodule(cls)
|
|
|
if not hasattr(mod, "zstd"):
|
|
|
raise Exception('test module does not contain "zstd" symbol')
|
|
|
|
|
|
if not hasattr(mod.zstd, "backend"):
|
|
|
raise Exception(
|
|
|
'zstd symbol does not have "backend" attribute; did '
|
|
|
"you `import zstandard as zstd`?"
|
|
|
)
|
|
|
|
|
|
# If `import zstandard` already chose the cffi backend, there is nothing
|
|
|
# for us to do: we only add the cffi variation if the default backend
|
|
|
# is the C extension.
|
|
|
if mod.zstd.backend == "cffi":
|
|
|
return cls
|
|
|
|
|
|
old_env = dict(os.environ)
|
|
|
os.environ["PYTHON_ZSTANDARD_IMPORT_POLICY"] = "cffi"
|
|
|
try:
|
|
|
try:
|
|
|
mod_info = imp.find_module("zstandard")
|
|
|
mod = imp.load_module("zstandard_cffi", *mod_info)
|
|
|
except ImportError:
|
|
|
return cls
|
|
|
finally:
|
|
|
os.environ.clear()
|
|
|
os.environ.update(old_env)
|
|
|
|
|
|
if mod.backend != "cffi":
|
|
|
raise Exception("got the zstandard %s backend instead of cffi" % mod.backend)
|
|
|
|
|
|
# If CFFI version is available, dynamically construct test methods
|
|
|
# that use it.
|
|
|
|
|
|
for attr in dir(cls):
|
|
|
fn = getattr(cls, attr)
|
|
|
if not inspect.ismethod(fn) and not inspect.isfunction(fn):
|
|
|
continue
|
|
|
|
|
|
if not fn.__name__.startswith("test_"):
|
|
|
continue
|
|
|
|
|
|
name = "%s_cffi" % fn.__name__
|
|
|
|
|
|
# Replace the "zstd" symbol with the CFFI module instance. Then copy
|
|
|
# the function object and install it in a new attribute.
|
|
|
if isinstance(fn, types.FunctionType):
|
|
|
globs = dict(fn.__globals__)
|
|
|
globs["zstd"] = mod
|
|
|
new_fn = types.FunctionType(
|
|
|
fn.__code__, globs, name, fn.__defaults__, fn.__closure__
|
|
|
)
|
|
|
new_method = new_fn
|
|
|
else:
|
|
|
globs = dict(fn.__func__.func_globals)
|
|
|
globs["zstd"] = mod
|
|
|
new_fn = types.FunctionType(
|
|
|
fn.__func__.func_code,
|
|
|
globs,
|
|
|
name,
|
|
|
fn.__func__.func_defaults,
|
|
|
fn.__func__.func_closure,
|
|
|
)
|
|
|
new_method = types.UnboundMethodType(new_fn, fn.im_self, fn.im_class)
|
|
|
|
|
|
setattr(cls, name, new_method)
|
|
|
|
|
|
return cls
|
|
|
|
|
|
|
|
|
class NonClosingBytesIO(io.BytesIO):
|
|
|
"""BytesIO that saves the underlying buffer on close().
|
|
|
|
|
|
This allows us to access written data after close().
|
|
|
"""
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
super(NonClosingBytesIO, self).__init__(*args, **kwargs)
|
|
|
self._saved_buffer = None
|
|
|
|
|
|
def close(self):
|
|
|
self._saved_buffer = self.getvalue()
|
|
|
return super(NonClosingBytesIO, self).close()
|
|
|
|
|
|
def getvalue(self):
|
|
|
if self.closed:
|
|
|
return self._saved_buffer
|
|
|
else:
|
|
|
return super(NonClosingBytesIO, self).getvalue()
|
|
|
|
|
|
|
|
|
class OpCountingBytesIO(NonClosingBytesIO):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
self._flush_count = 0
|
|
|
self._read_count = 0
|
|
|
self._write_count = 0
|
|
|
return super(OpCountingBytesIO, self).__init__(*args, **kwargs)
|
|
|
|
|
|
def flush(self):
|
|
|
self._flush_count += 1
|
|
|
return super(OpCountingBytesIO, self).flush()
|
|
|
|
|
|
def read(self, *args):
|
|
|
self._read_count += 1
|
|
|
return super(OpCountingBytesIO, self).read(*args)
|
|
|
|
|
|
def write(self, data):
|
|
|
self._write_count += 1
|
|
|
return super(OpCountingBytesIO, self).write(data)
|
|
|
|
|
|
|
|
|
_source_files = []
|
|
|
|
|
|
|
|
|
def random_input_data():
|
|
|
"""Obtain the raw content of source files.
|
|
|
|
|
|
This is used for generating "random" data to feed into fuzzing, since it is
|
|
|
faster than random content generation.
|
|
|
"""
|
|
|
if _source_files:
|
|
|
return _source_files
|
|
|
|
|
|
for root, dirs, files in os.walk(os.path.dirname(__file__)):
|
|
|
dirs[:] = list(sorted(dirs))
|
|
|
for f in sorted(files):
|
|
|
try:
|
|
|
with open(os.path.join(root, f), "rb") as fh:
|
|
|
data = fh.read()
|
|
|
if data:
|
|
|
_source_files.append(data)
|
|
|
except OSError:
|
|
|
pass
|
|
|
|
|
|
# Also add some actual random data.
|
|
|
_source_files.append(os.urandom(100))
|
|
|
_source_files.append(os.urandom(1000))
|
|
|
_source_files.append(os.urandom(10000))
|
|
|
_source_files.append(os.urandom(100000))
|
|
|
_source_files.append(os.urandom(1000000))
|
|
|
|
|
|
return _source_files
|
|
|
|
|
|
|
|
|
def generate_samples():
|
|
|
inputs = [
|
|
|
b"foo",
|
|
|
b"bar",
|
|
|
b"abcdef",
|
|
|
b"sometext",
|
|
|
b"baz",
|
|
|
]
|
|
|
|
|
|
samples = []
|
|
|
|
|
|
for i in range(128):
|
|
|
samples.append(inputs[i % 5])
|
|
|
samples.append(inputs[i % 5] * (i + 3))
|
|
|
samples.append(inputs[-(i % 5)] * (i + 2))
|
|
|
|
|
|
return samples
|
|
|
|
|
|
|
|
|
if hypothesis:
|
|
|
default_settings = hypothesis.settings(deadline=10000)
|
|
|
hypothesis.settings.register_profile("default", default_settings)
|
|
|
|
|
|
ci_settings = hypothesis.settings(deadline=20000, max_examples=1000)
|
|
|
hypothesis.settings.register_profile("ci", ci_settings)
|
|
|
|
|
|
expensive_settings = hypothesis.settings(deadline=None, max_examples=10000)
|
|
|
hypothesis.settings.register_profile("expensive", expensive_settings)
|
|
|
|
|
|
hypothesis.settings.load_profile(os.environ.get("HYPOTHESIS_PROFILE", "default"))
|
|
|
|