##// END OF EJS Templates
merge default into stable for 5.4 release
Pulkit Goyal -
r45608:26ce8e75 merge 5.4rc0 stable
parent child Browse files
Show More

The requested changes are too big and content was truncated. Show full diff

@@ -0,0 +1,60 b''
1 # Instructions:
2 #
3 # 1. cargo install --version 0.5.0 pyoxidizer
4 # 2. cd /path/to/hg
5 # 3. pyoxidizer build --path contrib/packaging [--release]
6 # 4. Run build/pyoxidizer/<arch>/<debug|release>/app/hg
7 #
8 # If you need to build again, you need to remove the build/lib.* and
9 # build/temp.* directories, otherwise PyOxidizer fails to pick up C
10 # extensions. This is a bug in PyOxidizer.
11
12 ROOT = CWD + "/../.."
13
14 set_build_path(ROOT + "/build/pyoxidizer")
15
16 def make_exe():
17 dist = default_python_distribution()
18
19 code = "import hgdemandimport; hgdemandimport.enable(); from mercurial import dispatch; dispatch.run()"
20
21 config = PythonInterpreterConfig(
22 raw_allocator = "system",
23 run_eval = code,
24 # We want to let the user load extensions from the file system
25 filesystem_importer = True,
26 # We need this to make resourceutil happy, since it looks for sys.frozen.
27 sys_frozen = True,
28 legacy_windows_stdio = True,
29 )
30
31 exe = dist.to_python_executable(
32 name = "hg",
33 config = config,
34 )
35
36 # Use setup.py install to build Mercurial and collect Python resources to
37 # embed in the executable.
38 resources = dist.setup_py_install(ROOT)
39 exe.add_python_resources(resources)
40
41 return exe
42
43 def make_install(exe):
44 m = FileManifest()
45
46 # `hg` goes in root directory.
47 m.add_python_resource(".", exe)
48
49 templates = glob(
50 include=[ROOT + "/mercurial/templates/**/*"],
51 strip_prefix = ROOT + "/mercurial/",
52 )
53 m.add_manifest(templates)
54
55 return m
56
57 register_target("exe", make_exe)
58 register_target("app", make_install, depends = ["exe"], default = True)
59
60 resolve_targets()
@@ -0,0 +1,93 b''
1 #!/usr/bin/env python
2 #
3 # A small script to automatically reject idle Diffs
4 #
5 # you need to set the PHABBOT_USER and PHABBOT_TOKEN environment variable for authentication
6 from __future__ import absolute_import, print_function
7
8 import datetime
9 import os
10 import sys
11
12 import phabricator
13
14 MESSAGE = """There seems to have been no activities on this Diff for the past 3 Months.
15
16 By policy, we are automatically moving it out of the `need-review` state.
17
18 Please, move it back to `need-review` without hesitation if this diff should still be discussed.
19
20 :baymax:need-review-idle:
21 """
22
23
24 PHAB_URL = "https://phab.mercurial-scm.org/api/"
25 USER = os.environ.get("PHABBOT_USER", "baymax")
26 TOKEN = os.environ.get("PHABBOT_TOKEN")
27
28
29 NOW = datetime.datetime.now()
30
31 # 3 months in seconds
32 DELAY = 60 * 60 * 24 * 30 * 3
33
34
35 def get_all_diff(phab):
36 """Fetch all the diff that the need review"""
37 return phab.differential.query(
38 status="status-needs-review",
39 order="order-modified",
40 paths=[('HG', None)],
41 )
42
43
44 def filter_diffs(diffs, older_than):
45 """filter diffs to only keep the one unmodified sin <older_than> seconds"""
46 olds = []
47 for d in diffs:
48 modified = int(d['dateModified'])
49 modified = datetime.datetime.fromtimestamp(modified)
50 d["idleFor"] = idle_for = NOW - modified
51 if idle_for.total_seconds() > older_than:
52 olds.append(d)
53 return olds
54
55
56 def nudge_diff(phab, diff):
57 """Comment on the idle diff and reject it"""
58 diff_id = int(d['id'])
59 phab.differential.createcomment(
60 revision_id=diff_id, message=MESSAGE, action="reject"
61 )
62
63
64 if not USER:
65 print(
66 "not user specified please set PHABBOT_USER and PHABBOT_TOKEN",
67 file=sys.stderr,
68 )
69 elif not TOKEN:
70 print(
71 "not api-token specified please set PHABBOT_USER and PHABBOT_TOKEN",
72 file=sys.stderr,
73 )
74 sys.exit(1)
75
76 phab = phabricator.Phabricator(USER, host=PHAB_URL, token=TOKEN)
77 phab.connect()
78 phab.update_interfaces()
79 print('Hello "%s".' % phab.user.whoami()['realName'])
80
81 diffs = get_all_diff(phab)
82 print("Found %d Diffs" % len(diffs))
83 olds = filter_diffs(diffs, DELAY)
84 print("Found %d old Diffs" % len(olds))
85 for d in olds:
86 diff_id = d['id']
87 status = d['statusName']
88 modified = int(d['dateModified'])
89 idle_for = d["idleFor"]
90 msg = 'nudging D%s in "%s" state for %s'
91 print(msg % (diff_id, status, idle_for))
92 # uncomment to actually affect phab
93 nudge_diff(phab, d)
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100644
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: new file 100755
The requested commit or file is too big and content was truncated. Show full diff
@@ -1,258 +1,259 b''
1 1 # If you want to change PREFIX, do not just edit it below. The changed
2 2 # value wont get passed on to recursive make calls. You should instead
3 3 # override the variable on the command like:
4 4 #
5 5 # % make PREFIX=/opt/ install
6 6
7 7 export PREFIX=/usr/local
8 8 PYTHON?=python
9 9 $(eval HGROOT := $(shell pwd))
10 10 HGPYTHONS ?= $(HGROOT)/build/pythons
11 11 PURE=
12 12 PYFILESCMD=find mercurial hgext doc -name '*.py'
13 13 PYFILES:=$(shell $(PYFILESCMD))
14 14 DOCFILES=mercurial/helptext/*.txt
15 15 export LANGUAGE=C
16 16 export LC_ALL=C
17 17 TESTFLAGS ?= $(shell echo $$HGTESTFLAGS)
18 18 OSXVERSIONFLAGS ?= $(shell echo $$OSXVERSIONFLAGS)
19 19 CARGO = cargo
20 20
21 21 # Set this to e.g. "mingw32" to use a non-default compiler.
22 22 COMPILER=
23 23
24 24 COMPILERFLAG_tmp_ =
25 25 COMPILERFLAG_tmp_${COMPILER} ?= -c $(COMPILER)
26 26 COMPILERFLAG=${COMPILERFLAG_tmp_${COMPILER}}
27 27
28 28 help:
29 29 @echo 'Commonly used make targets:'
30 30 @echo ' all - build program and documentation'
31 31 @echo ' install - install program and man pages to $$PREFIX ($(PREFIX))'
32 32 @echo ' install-home - install with setup.py install --home=$$HOME ($(HOME))'
33 33 @echo ' local - build for inplace usage'
34 34 @echo ' tests - run all tests in the automatic test suite'
35 35 @echo ' test-foo - run only specified tests (e.g. test-merge1.t)'
36 36 @echo ' dist - run all tests and create a source tarball in dist/'
37 37 @echo ' clean - remove files created by other targets'
38 38 @echo ' (except installed files or dist source tarball)'
39 39 @echo ' update-pot - update i18n/hg.pot'
40 40 @echo
41 41 @echo 'Example for a system-wide installation under /usr/local:'
42 42 @echo ' make all && su -c "make install" && hg version'
43 43 @echo
44 44 @echo 'Example for a local installation (usable in this directory):'
45 45 @echo ' make local && ./hg version'
46 46
47 47 all: build doc
48 48
49 49 local:
50 50 $(PYTHON) setup.py $(PURE) \
51 51 build_py -c -d . \
52 52 build_ext $(COMPILERFLAG) -i \
53 53 build_hgexe $(COMPILERFLAG) -i \
54 54 build_mo
55 55 env HGRCPATH= $(PYTHON) hg version
56 56
57 57 build:
58 58 $(PYTHON) setup.py $(PURE) build $(COMPILERFLAG)
59 59
60 60 wheel:
61 61 FORCE_SETUPTOOLS=1 $(PYTHON) setup.py $(PURE) bdist_wheel $(COMPILERFLAG)
62 62
63 63 doc:
64 64 $(MAKE) -C doc
65 65
66 66 cleanbutpackages:
67 rm -f hg.exe
67 68 -$(PYTHON) setup.py clean --all # ignore errors from this command
68 69 find contrib doc hgext hgext3rd i18n mercurial tests hgdemandimport \
69 70 \( -name '*.py[cdo]' -o -name '*.so' \) -exec rm -f '{}' ';'
70 71 rm -f MANIFEST MANIFEST.in hgext/__index__.py tests/*.err
71 72 rm -f mercurial/__modulepolicy__.py
72 73 if test -d .hg; then rm -f mercurial/__version__.py; fi
73 74 rm -rf build mercurial/locale
74 75 $(MAKE) -C doc clean
75 76 $(MAKE) -C contrib/chg distclean
76 77 rm -rf rust/target
77 78 rm -f mercurial/rustext.so
78 79
79 80 clean: cleanbutpackages
80 81 rm -rf packages
81 82
82 83 install: install-bin install-doc
83 84
84 85 install-bin: build
85 86 $(PYTHON) setup.py $(PURE) install --root="$(DESTDIR)/" --prefix="$(PREFIX)" --force
86 87
87 88 install-doc: doc
88 89 cd doc && $(MAKE) $(MFLAGS) install
89 90
90 91 install-home: install-home-bin install-home-doc
91 92
92 93 install-home-bin: build
93 94 $(PYTHON) setup.py $(PURE) install --home="$(HOME)" --prefix="" --force
94 95
95 96 install-home-doc: doc
96 97 cd doc && $(MAKE) $(MFLAGS) PREFIX="$(HOME)" install
97 98
98 99 MANIFEST-doc:
99 100 $(MAKE) -C doc MANIFEST
100 101
101 102 MANIFEST.in: MANIFEST-doc
102 103 hg manifest | sed -e 's/^/include /' > MANIFEST.in
103 104 echo include mercurial/__version__.py >> MANIFEST.in
104 105 sed -e 's/^/include /' < doc/MANIFEST >> MANIFEST.in
105 106
106 107 dist: tests dist-notests
107 108
108 109 dist-notests: doc MANIFEST.in
109 110 TAR_OPTIONS="--owner=root --group=root --mode=u+w,go-w,a+rX-s" $(PYTHON) setup.py -q sdist
110 111
111 112 check: tests
112 113
113 114 tests:
114 115 # Run Rust tests if cargo is installed
115 116 if command -v $(CARGO) >/dev/null 2>&1; then \
116 117 $(MAKE) rust-tests; \
117 118 fi
118 119 cd tests && $(PYTHON) run-tests.py $(TESTFLAGS)
119 120
120 121 test-%:
121 122 cd tests && $(PYTHON) run-tests.py $(TESTFLAGS) $@
122 123
123 124 testpy-%:
124 125 @echo Looking for Python $* in $(HGPYTHONS)
125 126 [ -e $(HGPYTHONS)/$*/bin/python ] || ( \
126 127 cd $$(mktemp --directory --tmpdir) && \
127 128 $(MAKE) -f $(HGROOT)/contrib/Makefile.python PYTHONVER=$* PREFIX=$(HGPYTHONS)/$* python )
128 129 cd tests && $(HGPYTHONS)/$*/bin/python run-tests.py $(TESTFLAGS)
129 130
130 131 rust-tests: py_feature = $(shell $(PYTHON) -c \
131 132 'import sys; print(["python27-bin", "python3-bin"][sys.version_info[0] >= 3])')
132 133 rust-tests:
133 134 cd $(HGROOT)/rust/hg-cpython \
134 135 && $(CARGO) test --quiet --all \
135 136 --no-default-features --features "$(py_feature)"
136 137
137 138 check-code:
138 139 hg manifest | xargs python contrib/check-code.py
139 140
140 141 format-c:
141 142 clang-format --style file -i \
142 143 `hg files 'set:(**.c or **.cc or **.h) and not "listfile:contrib/clang-format-ignorelist"'`
143 144
144 145 update-pot: i18n/hg.pot
145 146
146 147 i18n/hg.pot: $(PYFILES) $(DOCFILES) i18n/posplit i18n/hggettext
147 148 $(PYTHON) i18n/hggettext mercurial/commands.py \
148 149 hgext/*.py hgext/*/__init__.py \
149 150 mercurial/fileset.py mercurial/revset.py \
150 151 mercurial/templatefilters.py \
151 152 mercurial/templatefuncs.py \
152 153 mercurial/templatekw.py \
153 154 mercurial/filemerge.py \
154 155 mercurial/hgweb/webcommands.py \
155 156 mercurial/util.py \
156 157 $(DOCFILES) > i18n/hg.pot.tmp
157 158 # All strings marked for translation in Mercurial contain
158 159 # ASCII characters only. But some files contain string
159 160 # literals like this '\037\213'. xgettext thinks it has to
160 161 # parse them even though they are not marked for translation.
161 162 # Extracting with an explicit encoding of ISO-8859-1 will make
162 163 # xgettext "parse" and ignore them.
163 164 $(PYFILESCMD) | xargs \
164 165 xgettext --package-name "Mercurial" \
165 166 --msgid-bugs-address "<mercurial-devel@mercurial-scm.org>" \
166 167 --copyright-holder "Matt Mackall <mpm@selenic.com> and others" \
167 168 --from-code ISO-8859-1 --join --sort-by-file --add-comments=i18n: \
168 169 -d hg -p i18n -o hg.pot.tmp
169 170 $(PYTHON) i18n/posplit i18n/hg.pot.tmp
170 171 # The target file is not created before the last step. So it never is in
171 172 # an intermediate state.
172 173 mv -f i18n/hg.pot.tmp i18n/hg.pot
173 174
174 175 %.po: i18n/hg.pot
175 176 # work on a temporary copy for never having a half completed target
176 177 cp $@ $@.tmp
177 178 msgmerge --no-location --update $@.tmp $^
178 179 mv -f $@.tmp $@
179 180
180 181 # Packaging targets
181 182
182 183 packaging_targets := \
183 184 centos5 \
184 185 centos6 \
185 186 centos7 \
186 187 centos8 \
187 188 deb \
188 189 docker-centos5 \
189 190 docker-centos6 \
190 191 docker-centos7 \
191 192 docker-centos8 \
192 193 docker-debian-bullseye \
193 194 docker-debian-buster \
194 195 docker-debian-stretch \
195 196 docker-fedora \
196 197 docker-ubuntu-trusty \
197 198 docker-ubuntu-trusty-ppa \
198 199 docker-ubuntu-xenial \
199 200 docker-ubuntu-xenial-ppa \
200 201 docker-ubuntu-artful \
201 202 docker-ubuntu-artful-ppa \
202 203 docker-ubuntu-bionic \
203 204 docker-ubuntu-bionic-ppa \
204 205 fedora \
205 206 linux-wheels \
206 207 linux-wheels-x86_64 \
207 208 linux-wheels-i686 \
208 209 ppa
209 210
210 211 # Forward packaging targets for convenience.
211 212 $(packaging_targets):
212 213 $(MAKE) -C contrib/packaging $@
213 214
214 215 osx:
215 216 rm -rf build/mercurial
216 217 /usr/bin/python2.7 setup.py install --optimize=1 \
217 218 --root=build/mercurial/ --prefix=/usr/local/ \
218 219 --install-lib=/Library/Python/2.7/site-packages/
219 220 make -C doc all install DESTDIR="$(PWD)/build/mercurial/"
220 221 # Place a bogon .DS_Store file in the target dir so we can be
221 222 # sure it doesn't get included in the final package.
222 223 touch build/mercurial/.DS_Store
223 224 # install zsh completions - this location appears to be
224 225 # searched by default as of macOS Sierra.
225 226 install -d build/mercurial/usr/local/share/zsh/site-functions/
226 227 install -m 0644 contrib/zsh_completion build/mercurial/usr/local/share/zsh/site-functions/_hg
227 228 # install bash completions - there doesn't appear to be a
228 229 # place that's searched by default for bash, so we'll follow
229 230 # the lead of Apple's git install and just put it in a
230 231 # location of our own.
231 232 install -d build/mercurial/usr/local/hg/contrib/
232 233 install -m 0644 contrib/bash_completion build/mercurial/usr/local/hg/contrib/hg-completion.bash
233 234 make -C contrib/chg \
234 235 HGPATH=/usr/local/bin/hg \
235 236 PYTHON=/usr/bin/python2.7 \
236 237 HGEXTDIR=/Library/Python/2.7/site-packages/hgext \
237 238 DESTDIR=../../build/mercurial \
238 239 PREFIX=/usr/local \
239 240 clean install
240 241 mkdir -p $${OUTPUTDIR:-dist}
241 242 HGVER=$$(python contrib/genosxversion.py $(OSXVERSIONFLAGS) build/mercurial/Library/Python/2.7/site-packages/mercurial/__version__.py) && \
242 243 OSXVER=$$(sw_vers -productVersion | cut -d. -f1,2) && \
243 244 pkgbuild --filter \\.DS_Store --root build/mercurial/ \
244 245 --identifier org.mercurial-scm.mercurial \
245 246 --version "$${HGVER}" \
246 247 build/mercurial.pkg && \
247 248 productbuild --distribution contrib/packaging/macosx/distribution.xml \
248 249 --package-path build/ \
249 250 --version "$${HGVER}" \
250 251 --resources contrib/packaging/macosx/ \
251 252 "$${OUTPUTDIR:-dist/}"/Mercurial-"$${HGVER}"-macosx"$${OSXVER}".pkg
252 253
253 254 .PHONY: help all local build doc cleanbutpackages clean install install-bin \
254 255 install-doc install-home install-home-bin install-home-doc \
255 256 dist dist-notests check tests rust-tests check-code format-c \
256 257 update-pot \
257 258 $(packaging_targets) \
258 259 osx
@@ -1,15 +1,14 b''
1 1 [tool.black]
2 2 line-length = 80
3 3 exclude = '''
4 4 build/
5 5 | wheelhouse/
6 6 | dist/
7 7 | packages/
8 8 | \.hg/
9 9 | \.mypy_cache/
10 10 | \.venv/
11 11 | mercurial/thirdparty/
12 | contrib/python-zstandard/
13 12 '''
14 13 skip-string-normalization = true
15 14 quiet = true
@@ -1,126 +1,126 b''
1 1 # __init__.py - asv benchmark suite
2 2 #
3 3 # Copyright 2016 Logilab SA <contact@logilab.fr>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 # "historical portability" policy of contrib/benchmarks:
9 9 #
10 10 # We have to make this code work correctly with current mercurial stable branch
11 11 # and if possible with reasonable cost with early Mercurial versions.
12 12
13 13 '''ASV (https://asv.readthedocs.io) benchmark suite
14 14
15 15 Benchmark are parameterized against reference repositories found in the
16 16 directory pointed by the REPOS_DIR environment variable.
17 17
18 18 Invocation example:
19 19
20 20 $ export REPOS_DIR=~/hgperf/repos
21 21 # run suite on given revision
22 22 $ asv --config contrib/asv.conf.json run REV
23 23 # run suite on new changesets found in stable and default branch
24 24 $ asv --config contrib/asv.conf.json run NEW
25 25 # display a comparative result table of benchmark results between two given
26 26 # revisions
27 27 $ asv --config contrib/asv.conf.json compare REV1 REV2
28 28 # compute regression detection and generate ASV static website
29 29 $ asv --config contrib/asv.conf.json publish
30 30 # serve the static website
31 31 $ asv --config contrib/asv.conf.json preview
32 32 '''
33 33
34 34 from __future__ import absolute_import
35 35
36 36 import functools
37 37 import os
38 38 import re
39 39
40 40 from mercurial import (
41 41 extensions,
42 42 hg,
43 43 ui as uimod,
44 44 util,
45 45 )
46 46
47 47 basedir = os.path.abspath(
48 48 os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
49 49 )
50 50 reposdir = os.environ['REPOS_DIR']
51 51 reposnames = [
52 52 name
53 53 for name in os.listdir(reposdir)
54 54 if os.path.isdir(os.path.join(reposdir, name, ".hg"))
55 55 ]
56 56 if not reposnames:
57 57 raise ValueError("No repositories found in $REPO_DIR")
58 58 outputre = re.compile(
59 59 (
60 60 r'! wall (\d+.\d+) comb \d+.\d+ user \d+.\d+ sys '
61 61 r'\d+.\d+ \(best of \d+\)'
62 62 )
63 63 )
64 64
65 65
66 66 def runperfcommand(reponame, command, *args, **kwargs):
67 67 os.environ["HGRCPATH"] = os.environ.get("ASVHGRCPATH", "")
68 68 # for "historical portability"
69 69 # ui.load() has been available since d83ca85
70 70 if util.safehasattr(uimod.ui, "load"):
71 71 ui = uimod.ui.load()
72 72 else:
73 73 ui = uimod.ui()
74 74 repo = hg.repository(ui, os.path.join(reposdir, reponame))
75 75 perfext = extensions.load(
76 76 ui, 'perfext', os.path.join(basedir, 'contrib', 'perf.py')
77 77 )
78 78 cmd = getattr(perfext, command)
79 79 ui.pushbuffer()
80 80 cmd(ui, repo, *args, **kwargs)
81 81 output = ui.popbuffer()
82 82 match = outputre.search(output)
83 83 if not match:
84 raise ValueError("Invalid output {0}".format(output))
84 raise ValueError("Invalid output {}".format(output))
85 85 return float(match.group(1))
86 86
87 87
88 88 def perfbench(repos=reposnames, name=None, params=None):
89 89 """decorator to declare ASV benchmark based on contrib/perf.py extension
90 90
91 91 An ASV benchmark is a python function with the given attributes:
92 92
93 93 __name__: should start with track_, time_ or mem_ to be collected by ASV
94 94 params and param_name: parameter matrix to display multiple graphs on the
95 95 same page.
96 96 pretty_name: If defined it's displayed in web-ui instead of __name__
97 97 (useful for revsets)
98 98 the module name is prepended to the benchmark name and displayed as
99 99 "category" in webui.
100 100
101 101 Benchmarks are automatically parameterized with repositories found in the
102 102 REPOS_DIR environment variable.
103 103
104 104 `params` is the param matrix in the form of a list of tuple
105 105 (param_name, [value0, value1])
106 106
107 107 For example [(x, [a, b]), (y, [c, d])] declare benchmarks for
108 108 (a, c), (a, d), (b, c) and (b, d).
109 109 """
110 110 params = list(params or [])
111 111 params.insert(0, ("repo", repos))
112 112
113 113 def decorator(func):
114 114 @functools.wraps(func)
115 115 def wrapped(repo, *args):
116 116 def perf(command, *a, **kw):
117 117 return runperfcommand(repo, command, *a, **kw)
118 118
119 119 return func(perf, *args)
120 120
121 121 wrapped.params = [p[1] for p in params]
122 122 wrapped.param_names = [p[0] for p in params]
123 123 wrapped.pretty_name = name
124 124 return wrapped
125 125
126 126 return decorator
@@ -1,113 +1,113 b''
1 1 #!/usr/bin/env python
2 2 #
3 3 # check-py3-compat - check Python 3 compatibility of Mercurial files
4 4 #
5 5 # Copyright 2015 Gregory Szorc <gregory.szorc@gmail.com>
6 6 #
7 7 # This software may be used and distributed according to the terms of the
8 8 # GNU General Public License version 2 or any later version.
9 9
10 10 from __future__ import absolute_import, print_function
11 11
12 12 import ast
13 13 import importlib
14 14 import os
15 15 import sys
16 16 import traceback
17 17 import warnings
18 18
19 19
20 20 def check_compat_py2(f):
21 21 """Check Python 3 compatibility for a file with Python 2"""
22 22 with open(f, 'rb') as fh:
23 23 content = fh.read()
24 24 root = ast.parse(content)
25 25
26 26 # Ignore empty files.
27 27 if not root.body:
28 28 return
29 29
30 30 futures = set()
31 31 haveprint = False
32 32 for node in ast.walk(root):
33 33 if isinstance(node, ast.ImportFrom):
34 34 if node.module == '__future__':
35 futures |= set(n.name for n in node.names)
35 futures |= {n.name for n in node.names}
36 36 elif isinstance(node, ast.Print):
37 37 haveprint = True
38 38
39 39 if 'absolute_import' not in futures:
40 40 print('%s not using absolute_import' % f)
41 41 if haveprint and 'print_function' not in futures:
42 42 print('%s requires print_function' % f)
43 43
44 44
45 45 def check_compat_py3(f):
46 46 """Check Python 3 compatibility of a file with Python 3."""
47 47 with open(f, 'rb') as fh:
48 48 content = fh.read()
49 49
50 50 try:
51 51 ast.parse(content, filename=f)
52 52 except SyntaxError as e:
53 53 print('%s: invalid syntax: %s' % (f, e))
54 54 return
55 55
56 56 # Try to import the module.
57 57 # For now we only support modules in packages because figuring out module
58 58 # paths for things not in a package can be confusing.
59 59 if f.startswith(
60 60 ('hgdemandimport/', 'hgext/', 'mercurial/')
61 61 ) and not f.endswith('__init__.py'):
62 62 assert f.endswith('.py')
63 63 name = f.replace('/', '.')[:-3]
64 64 try:
65 65 importlib.import_module(name)
66 66 except Exception as e:
67 67 exc_type, exc_value, tb = sys.exc_info()
68 68 # We walk the stack and ignore frames from our custom importer,
69 69 # import mechanisms, and stdlib modules. This kinda/sorta
70 70 # emulates CPython behavior in import.c while also attempting
71 71 # to pin blame on a Mercurial file.
72 72 for frame in reversed(traceback.extract_tb(tb)):
73 73 if frame.name == '_call_with_frames_removed':
74 74 continue
75 75 if 'importlib' in frame.filename:
76 76 continue
77 77 if 'mercurial/__init__.py' in frame.filename:
78 78 continue
79 79 if frame.filename.startswith(sys.prefix):
80 80 continue
81 81 break
82 82
83 83 if frame.filename:
84 84 filename = os.path.basename(frame.filename)
85 85 print(
86 86 '%s: error importing: <%s> %s (error at %s:%d)'
87 87 % (f, type(e).__name__, e, filename, frame.lineno)
88 88 )
89 89 else:
90 90 print(
91 91 '%s: error importing module: <%s> %s (line %d)'
92 92 % (f, type(e).__name__, e, frame.lineno)
93 93 )
94 94
95 95
96 96 if __name__ == '__main__':
97 97 if sys.version_info[0] == 2:
98 98 fn = check_compat_py2
99 99 else:
100 100 fn = check_compat_py3
101 101
102 102 for f in sys.argv[1:]:
103 103 with warnings.catch_warnings(record=True) as warns:
104 104 fn(f)
105 105
106 106 for w in warns:
107 107 print(
108 108 warnings.formatwarning(
109 109 w.message, w.category, w.filename, w.lineno
110 110 ).rstrip()
111 111 )
112 112
113 113 sys.exit(0)
@@ -1,456 +1,470 b''
1 1 /*
2 2 * A fast client for Mercurial command server
3 3 *
4 4 * Copyright (c) 2011 Yuya Nishihara <yuya@tcha.org>
5 5 *
6 6 * This software may be used and distributed according to the terms of the
7 7 * GNU General Public License version 2 or any later version.
8 8 */
9 9
10 10 #include <assert.h>
11 11 #include <errno.h>
12 12 #include <fcntl.h>
13 13 #include <signal.h>
14 14 #include <stdio.h>
15 15 #include <stdlib.h>
16 16 #include <string.h>
17 17 #include <sys/file.h>
18 18 #include <sys/stat.h>
19 19 #include <sys/types.h>
20 20 #include <sys/un.h>
21 21 #include <sys/wait.h>
22 22 #include <time.h>
23 23 #include <unistd.h>
24 24
25 25 #include "hgclient.h"
26 26 #include "procutil.h"
27 27 #include "util.h"
28 28
29 29 #ifndef PATH_MAX
30 30 #define PATH_MAX 4096
31 31 #endif
32 32
33 33 struct cmdserveropts {
34 34 char sockname[PATH_MAX];
35 35 char initsockname[PATH_MAX];
36 36 char redirectsockname[PATH_MAX];
37 37 size_t argsize;
38 38 const char **args;
39 39 };
40 40
41 41 static void initcmdserveropts(struct cmdserveropts *opts)
42 42 {
43 43 memset(opts, 0, sizeof(struct cmdserveropts));
44 44 }
45 45
46 46 static void freecmdserveropts(struct cmdserveropts *opts)
47 47 {
48 48 free(opts->args);
49 49 opts->args = NULL;
50 50 opts->argsize = 0;
51 51 }
52 52
53 53 /*
54 54 * Test if an argument is a sensitive flag that should be passed to the server.
55 55 * Return 0 if not, otherwise the number of arguments starting from the current
56 56 * one that should be passed to the server.
57 57 */
58 58 static size_t testsensitiveflag(const char *arg)
59 59 {
60 60 static const struct {
61 61 const char *name;
62 62 size_t narg;
63 63 } flags[] = {
64 64 {"--config", 1}, {"--cwd", 1}, {"--repo", 1},
65 65 {"--repository", 1}, {"--traceback", 0}, {"-R", 1},
66 66 };
67 67 size_t i;
68 68 for (i = 0; i < sizeof(flags) / sizeof(flags[0]); ++i) {
69 69 size_t len = strlen(flags[i].name);
70 70 size_t narg = flags[i].narg;
71 71 if (memcmp(arg, flags[i].name, len) == 0) {
72 72 if (arg[len] == '\0') {
73 73 /* --flag (value) */
74 74 return narg + 1;
75 75 } else if (arg[len] == '=' && narg > 0) {
76 76 /* --flag=value */
77 77 return 1;
78 78 } else if (flags[i].name[1] != '-') {
79 79 /* short flag */
80 80 return 1;
81 81 }
82 82 }
83 83 }
84 84 return 0;
85 85 }
86 86
87 87 /*
88 88 * Parse argv[] and put sensitive flags to opts->args
89 89 */
90 90 static void setcmdserverargs(struct cmdserveropts *opts, int argc,
91 91 const char *argv[])
92 92 {
93 93 size_t i, step;
94 94 opts->argsize = 0;
95 95 for (i = 0, step = 1; i < (size_t)argc; i += step, step = 1) {
96 96 if (!argv[i])
97 97 continue; /* pass clang-analyse */
98 98 if (strcmp(argv[i], "--") == 0)
99 99 break;
100 100 size_t n = testsensitiveflag(argv[i]);
101 101 if (n == 0 || i + n > (size_t)argc)
102 102 continue;
103 103 opts->args =
104 104 reallocx(opts->args, (n + opts->argsize) * sizeof(char *));
105 105 memcpy(opts->args + opts->argsize, argv + i,
106 106 sizeof(char *) * n);
107 107 opts->argsize += n;
108 108 step = n;
109 109 }
110 110 }
111 111
112 112 static void preparesockdir(const char *sockdir)
113 113 {
114 114 int r;
115 115 r = mkdir(sockdir, 0700);
116 116 if (r < 0 && errno != EEXIST)
117 117 abortmsgerrno("cannot create sockdir %s", sockdir);
118 118
119 119 struct stat st;
120 120 r = lstat(sockdir, &st);
121 121 if (r < 0)
122 122 abortmsgerrno("cannot stat %s", sockdir);
123 123 if (!S_ISDIR(st.st_mode))
124 124 abortmsg("cannot create sockdir %s (file exists)", sockdir);
125 125 if (st.st_uid != geteuid() || st.st_mode & 0077)
126 126 abortmsg("insecure sockdir %s", sockdir);
127 127 }
128 128
129 129 /*
130 130 * Check if a socket directory exists and is only owned by the current user.
131 131 * Return 1 if so, 0 if not. This is used to check if XDG_RUNTIME_DIR can be
132 132 * used or not. According to the specification [1], XDG_RUNTIME_DIR should be
133 133 * ignored if the directory is not owned by the user with mode 0700.
134 134 * [1]: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
135 135 */
136 136 static int checkruntimedir(const char *sockdir)
137 137 {
138 138 struct stat st;
139 139 int r = lstat(sockdir, &st);
140 140 if (r < 0) /* ex. does not exist */
141 141 return 0;
142 142 if (!S_ISDIR(st.st_mode)) /* ex. is a file, not a directory */
143 143 return 0;
144 144 return st.st_uid == geteuid() && (st.st_mode & 0777) == 0700;
145 145 }
146 146
147 147 static void getdefaultsockdir(char sockdir[], size_t size)
148 148 {
149 149 /* by default, put socket file in secure directory
150 150 * (${XDG_RUNTIME_DIR}/chg, or /${TMPDIR:-tmp}/chg$UID)
151 151 * (permission of socket file may be ignored on some Unices) */
152 152 const char *runtimedir = getenv("XDG_RUNTIME_DIR");
153 153 int r;
154 154 if (runtimedir && checkruntimedir(runtimedir)) {
155 155 r = snprintf(sockdir, size, "%s/chg", runtimedir);
156 156 } else {
157 157 const char *tmpdir = getenv("TMPDIR");
158 158 if (!tmpdir)
159 159 tmpdir = "/tmp";
160 160 r = snprintf(sockdir, size, "%s/chg%d", tmpdir, geteuid());
161 161 }
162 162 if (r < 0 || (size_t)r >= size)
163 163 abortmsg("too long TMPDIR (r = %d)", r);
164 164 }
165 165
166 166 static void setcmdserveropts(struct cmdserveropts *opts)
167 167 {
168 168 int r;
169 169 char sockdir[PATH_MAX];
170 170 const char *envsockname = getenv("CHGSOCKNAME");
171 171 if (!envsockname) {
172 172 getdefaultsockdir(sockdir, sizeof(sockdir));
173 173 preparesockdir(sockdir);
174 174 }
175 175
176 176 const char *basename = (envsockname) ? envsockname : sockdir;
177 177 const char *sockfmt = (envsockname) ? "%s" : "%s/server";
178 178 r = snprintf(opts->sockname, sizeof(opts->sockname), sockfmt, basename);
179 179 if (r < 0 || (size_t)r >= sizeof(opts->sockname))
180 180 abortmsg("too long TMPDIR or CHGSOCKNAME (r = %d)", r);
181 181 r = snprintf(opts->initsockname, sizeof(opts->initsockname), "%s.%u",
182 182 opts->sockname, (unsigned)getpid());
183 183 if (r < 0 || (size_t)r >= sizeof(opts->initsockname))
184 184 abortmsg("too long TMPDIR or CHGSOCKNAME (r = %d)", r);
185 185 }
186 186
187 187 static const char *gethgcmd(void)
188 188 {
189 189 static const char *hgcmd = NULL;
190 190 if (!hgcmd) {
191 191 hgcmd = getenv("CHGHG");
192 192 if (!hgcmd || hgcmd[0] == '\0')
193 193 hgcmd = getenv("HG");
194 194 if (!hgcmd || hgcmd[0] == '\0')
195 195 #ifdef HGPATH
196 196 hgcmd = (HGPATH);
197 197 #else
198 198 hgcmd = "hg";
199 199 #endif
200 200 }
201 201 return hgcmd;
202 202 }
203 203
204 204 static void execcmdserver(const struct cmdserveropts *opts)
205 205 {
206 206 const char *hgcmd = gethgcmd();
207 207
208 208 const char *baseargv[] = {
209 209 hgcmd,
210 210 "serve",
211 211 "--cmdserver",
212 212 "chgunix",
213 213 "--address",
214 214 opts->initsockname,
215 215 "--daemon-postexec",
216 216 "chdir:/",
217 217 };
218 218 size_t baseargvsize = sizeof(baseargv) / sizeof(baseargv[0]);
219 219 size_t argsize = baseargvsize + opts->argsize + 1;
220 220
221 221 const char **argv = mallocx(sizeof(char *) * argsize);
222 222 memcpy(argv, baseargv, sizeof(baseargv));
223 223 if (opts->args) {
224 224 size_t size = sizeof(char *) * opts->argsize;
225 225 memcpy(argv + baseargvsize, opts->args, size);
226 226 }
227 227 argv[argsize - 1] = NULL;
228 228
229 const char *lc_ctype_env = getenv("LC_CTYPE");
230 if (lc_ctype_env == NULL) {
231 if (putenv("CHG_CLEAR_LC_CTYPE=") != 0)
232 abortmsgerrno("failed to putenv CHG_CLEAR_LC_CTYPE");
233 } else {
234 if (setenv("CHGORIG_LC_CTYPE", lc_ctype_env, 1) != 0) {
235 abortmsgerrno("failed to setenv CHGORIG_LC_CTYYPE");
236 }
237 }
238
229 239 if (putenv("CHGINTERNALMARK=") != 0)
230 240 abortmsgerrno("failed to putenv");
231 241 if (execvp(hgcmd, (char **)argv) < 0)
232 242 abortmsgerrno("failed to exec cmdserver");
233 243 free(argv);
234 244 }
235 245
236 246 /* Retry until we can connect to the server. Give up after some time. */
237 247 static hgclient_t *retryconnectcmdserver(struct cmdserveropts *opts, pid_t pid)
238 248 {
239 249 static const struct timespec sleepreq = {0, 10 * 1000000};
240 250 int pst = 0;
241 251
242 252 debugmsg("try connect to %s repeatedly", opts->initsockname);
243 253
244 254 unsigned int timeoutsec = 60; /* default: 60 seconds */
245 255 const char *timeoutenv = getenv("CHGTIMEOUT");
246 256 if (timeoutenv)
247 257 sscanf(timeoutenv, "%u", &timeoutsec);
248 258
249 259 for (unsigned int i = 0; !timeoutsec || i < timeoutsec * 100; i++) {
250 260 hgclient_t *hgc = hgc_open(opts->initsockname);
251 261 if (hgc) {
252 262 debugmsg("rename %s to %s", opts->initsockname,
253 263 opts->sockname);
254 264 int r = rename(opts->initsockname, opts->sockname);
255 265 if (r != 0)
256 266 abortmsgerrno("cannot rename");
257 267 return hgc;
258 268 }
259 269
260 270 if (pid > 0) {
261 271 /* collect zombie if child process fails to start */
262 272 int r = waitpid(pid, &pst, WNOHANG);
263 273 if (r != 0)
264 274 goto cleanup;
265 275 }
266 276
267 277 nanosleep(&sleepreq, NULL);
268 278 }
269 279
270 280 abortmsg("timed out waiting for cmdserver %s", opts->initsockname);
271 281 return NULL;
272 282
273 283 cleanup:
274 284 if (WIFEXITED(pst)) {
275 285 if (WEXITSTATUS(pst) == 0)
276 286 abortmsg("could not connect to cmdserver "
277 287 "(exited with status 0)");
278 288 debugmsg("cmdserver exited with status %d", WEXITSTATUS(pst));
279 289 exit(WEXITSTATUS(pst));
280 290 } else if (WIFSIGNALED(pst)) {
281 291 abortmsg("cmdserver killed by signal %d", WTERMSIG(pst));
282 292 } else {
283 293 abortmsg("error while waiting for cmdserver");
284 294 }
285 295 return NULL;
286 296 }
287 297
288 298 /* Connect to a cmdserver. Will start a new server on demand. */
289 299 static hgclient_t *connectcmdserver(struct cmdserveropts *opts)
290 300 {
291 301 const char *sockname =
292 302 opts->redirectsockname[0] ? opts->redirectsockname : opts->sockname;
293 303 debugmsg("try connect to %s", sockname);
294 304 hgclient_t *hgc = hgc_open(sockname);
295 305 if (hgc)
296 306 return hgc;
297 307
298 308 /* prevent us from being connected to an outdated server: we were
299 309 * told by a server to redirect to opts->redirectsockname and that
300 310 * address does not work. we do not want to connect to the server
301 311 * again because it will probably tell us the same thing. */
302 312 if (sockname == opts->redirectsockname)
303 313 unlink(opts->sockname);
304 314
305 315 debugmsg("start cmdserver at %s", opts->initsockname);
306 316
307 317 pid_t pid = fork();
308 318 if (pid < 0)
309 319 abortmsg("failed to fork cmdserver process");
310 320 if (pid == 0) {
311 321 execcmdserver(opts);
312 322 } else {
313 323 hgc = retryconnectcmdserver(opts, pid);
314 324 }
315 325
316 326 return hgc;
317 327 }
318 328
319 329 static void killcmdserver(const struct cmdserveropts *opts)
320 330 {
321 331 /* resolve config hash */
322 332 char *resolvedpath = realpath(opts->sockname, NULL);
323 333 if (resolvedpath) {
324 334 unlink(resolvedpath);
325 335 free(resolvedpath);
326 336 }
327 337 }
328 338
329 339 /* Run instructions sent from the server like unlink and set redirect path
330 340 * Return 1 if reconnect is needed, otherwise 0 */
331 341 static int runinstructions(struct cmdserveropts *opts, const char **insts)
332 342 {
333 343 int needreconnect = 0;
334 344 if (!insts)
335 345 return needreconnect;
336 346
337 347 assert(insts);
338 348 opts->redirectsockname[0] = '\0';
339 349 const char **pinst;
340 350 for (pinst = insts; *pinst; pinst++) {
341 351 debugmsg("instruction: %s", *pinst);
342 352 if (strncmp(*pinst, "unlink ", 7) == 0) {
343 353 unlink(*pinst + 7);
344 354 } else if (strncmp(*pinst, "redirect ", 9) == 0) {
345 355 int r = snprintf(opts->redirectsockname,
346 356 sizeof(opts->redirectsockname), "%s",
347 357 *pinst + 9);
348 358 if (r < 0 || r >= (int)sizeof(opts->redirectsockname))
349 359 abortmsg("redirect path is too long (%d)", r);
350 360 needreconnect = 1;
351 361 } else if (strncmp(*pinst, "exit ", 5) == 0) {
352 362 int n = 0;
353 363 if (sscanf(*pinst + 5, "%d", &n) != 1)
354 364 abortmsg("cannot read the exit code");
355 365 exit(n);
356 366 } else if (strcmp(*pinst, "reconnect") == 0) {
357 367 needreconnect = 1;
358 368 } else {
359 369 abortmsg("unknown instruction: %s", *pinst);
360 370 }
361 371 }
362 372 return needreconnect;
363 373 }
364 374
365 375 /*
366 376 * Test whether the command is unsupported or not. This is not designed to
367 * cover all cases. But it's fast, does not depend on the server and does
368 * not return false positives.
377 * cover all cases. But it's fast, does not depend on the server.
369 378 */
370 379 static int isunsupported(int argc, const char *argv[])
371 380 {
372 381 enum { SERVE = 1,
373 382 DAEMON = 2,
374 383 SERVEDAEMON = SERVE | DAEMON,
375 384 };
376 385 unsigned int state = 0;
377 386 int i;
378 387 for (i = 0; i < argc; ++i) {
379 388 if (strcmp(argv[i], "--") == 0)
380 389 break;
381 if (i == 0 && strcmp("serve", argv[i]) == 0)
390 /*
391 * there can be false positives but no false negative
392 * we cannot assume `serve` will always be first argument
393 * because global options can be passed before the command name
394 */
395 if (strcmp("serve", argv[i]) == 0)
382 396 state |= SERVE;
383 397 else if (strcmp("-d", argv[i]) == 0 ||
384 398 strcmp("--daemon", argv[i]) == 0)
385 399 state |= DAEMON;
386 400 }
387 401 return (state & SERVEDAEMON) == SERVEDAEMON;
388 402 }
389 403
390 404 static void execoriginalhg(const char *argv[])
391 405 {
392 406 debugmsg("execute original hg");
393 407 if (execvp(gethgcmd(), (char **)argv) < 0)
394 408 abortmsgerrno("failed to exec original hg");
395 409 }
396 410
397 411 int main(int argc, const char *argv[], const char *envp[])
398 412 {
399 413 if (getenv("CHGDEBUG"))
400 414 enabledebugmsg();
401 415
402 416 if (!getenv("HGPLAIN") && isatty(fileno(stderr)))
403 417 enablecolor();
404 418
405 419 if (getenv("CHGINTERNALMARK"))
406 420 abortmsg("chg started by chg detected.\n"
407 421 "Please make sure ${HG:-hg} is not a symlink or "
408 422 "wrapper to chg. Alternatively, set $CHGHG to the "
409 423 "path of real hg.");
410 424
411 425 if (isunsupported(argc - 1, argv + 1))
412 426 execoriginalhg(argv);
413 427
414 428 struct cmdserveropts opts;
415 429 initcmdserveropts(&opts);
416 430 setcmdserveropts(&opts);
417 431 setcmdserverargs(&opts, argc, argv);
418 432
419 433 if (argc == 2) {
420 434 if (strcmp(argv[1], "--kill-chg-daemon") == 0) {
421 435 killcmdserver(&opts);
422 436 return 0;
423 437 }
424 438 }
425 439
426 440 hgclient_t *hgc;
427 441 size_t retry = 0;
428 442 while (1) {
429 443 hgc = connectcmdserver(&opts);
430 444 if (!hgc)
431 445 abortmsg("cannot open hg client");
432 446 hgc_setenv(hgc, envp);
433 447 const char **insts = hgc_validate(hgc, argv + 1, argc - 1);
434 448 int needreconnect = runinstructions(&opts, insts);
435 449 free(insts);
436 450 if (!needreconnect)
437 451 break;
438 452 hgc_close(hgc);
439 453 if (++retry > 10)
440 454 abortmsg("too many redirections.\n"
441 455 "Please make sure %s is not a wrapper which "
442 456 "changes sensitive environment variables "
443 457 "before executing hg. If you have to use a "
444 458 "wrapper, wrap chg instead of hg.",
445 459 gethgcmd());
446 460 }
447 461
448 462 setupsignalhandler(hgc_peerpid(hgc), hgc_peerpgid(hgc));
449 463 atexit(waitpager);
450 464 int exitcode = hgc_runcommand(hgc, argv + 1, argc - 1);
451 465 restoresignalhandler();
452 466 hgc_close(hgc);
453 467 freecmdserveropts(&opts);
454 468
455 469 return exitcode;
456 470 }
@@ -1,14 +1,14 b''
1 1 [fix]
2 2 clang-format:command = clang-format --style file
3 3 clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist"
4 4
5 5 rustfmt:command = rustfmt +nightly
6 6 rustfmt:pattern = set:**.rs
7 7
8 8 black:command = black --config=black.toml -
9 black:pattern = set:**.py - mercurial/thirdparty/** - "contrib/python-zstandard/**"
9 black:pattern = set:**.py - mercurial/thirdparty/**
10 10
11 11 # Mercurial doesn't have any Go code, but if we did this is how we
12 12 # would configure `hg fix` for Go:
13 13 go:command = gofmt
14 14 go:pattern = set:**.go
@@ -1,81 +1,81 b''
1 1 image: octobus/ci-mercurial-core
2 2
3 3 # The runner made a clone as root.
4 4 # We make a new clone owned by user used to run the step.
5 5 before_script:
6 6 - hg clone . /tmp/mercurial-ci/ --noupdate
7 7 - hg -R /tmp/mercurial-ci/ update `hg log --rev '.' --template '{node}'`
8 8 - cd /tmp/mercurial-ci/
9 9 - ls -1 tests/test-check-*.* > /tmp/check-tests.txt
10 10
11 11 variables:
12 12 PYTHON: python
13 13 TEST_HGMODULEPOLICY: "allow"
14 14
15 15 .runtests_template: &runtests
16 16 script:
17 17 - echo "python used, $PYTHON"
18 18 - echo "$RUNTEST_ARGS"
19 19 - HGMODULEPOLICY="$TEST_HGMODULEPOLICY" "$PYTHON" tests/run-tests.py --color=always $RUNTEST_ARGS
20 20
21 21 checks-py2:
22 22 <<: *runtests
23 23 variables:
24 24 RUNTEST_ARGS: "--time --test-list /tmp/check-tests.txt"
25 25
26 26 checks-py3:
27 27 <<: *runtests
28 28 variables:
29 29 RUNTEST_ARGS: "--time --test-list /tmp/check-tests.txt"
30 30 PYTHON: python3
31 31
32 32 rust-cargo-test-py2: &rust_cargo_test
33 33 script:
34 34 - echo "python used, $PYTHON"
35 35 - make rust-tests
36 36
37 37 rust-cargo-test-py3:
38 38 <<: *rust_cargo_test
39 39 variables:
40 40 PYTHON: python3
41 41
42 42 test-py2:
43 43 <<: *runtests
44 44 variables:
45 RUNTEST_ARGS: "--blacklist /tmp/check-tests.txt"
45 RUNTEST_ARGS: " --no-rust --blacklist /tmp/check-tests.txt"
46 46 TEST_HGMODULEPOLICY: "c"
47 47
48 48 test-py3:
49 49 <<: *runtests
50 50 variables:
51 RUNTEST_ARGS: "--blacklist /tmp/check-tests.txt"
51 RUNTEST_ARGS: " --no-rust --blacklist /tmp/check-tests.txt"
52 52 PYTHON: python3
53 53 TEST_HGMODULEPOLICY: "c"
54 54
55 55 test-py2-pure:
56 56 <<: *runtests
57 57 variables:
58 58 RUNTEST_ARGS: "--pure --blacklist /tmp/check-tests.txt"
59 59 TEST_HGMODULEPOLICY: "py"
60 60
61 61 test-py3-pure:
62 62 <<: *runtests
63 63 variables:
64 64 RUNTEST_ARGS: "--pure --blacklist /tmp/check-tests.txt"
65 65 PYTHON: python3
66 66 TEST_HGMODULEPOLICY: "py"
67 67
68 68 test-py2-rust:
69 69 <<: *runtests
70 70 variables:
71 71 HGWITHRUSTEXT: cpython
72 RUNTEST_ARGS: "--blacklist /tmp/check-tests.txt"
72 RUNTEST_ARGS: "--rust --blacklist /tmp/check-tests.txt"
73 73 TEST_HGMODULEPOLICY: "rust+c"
74 74
75 75 test-py3-rust:
76 76 <<: *runtests
77 77 variables:
78 78 HGWITHRUSTEXT: cpython
79 RUNTEST_ARGS: "--blacklist /tmp/check-tests.txt"
79 RUNTEST_ARGS: "--rust --blacklist /tmp/check-tests.txt"
80 80 PYTHON: python3
81 81 TEST_HGMODULEPOLICY: "rust+c"
@@ -1,820 +1,821 b''
1 1 #!/usr/bin/env python
2 2
3 3 from __future__ import absolute_import, print_function
4 4
5 5 import ast
6 6 import collections
7 7 import io
8 8 import os
9 9 import sys
10 10
11 11 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
12 12 # to work when run from a virtualenv. The modules were chosen empirically
13 13 # so that the return value matches the return value without virtualenv.
14 14 if True: # disable lexical sorting checks
15 15 try:
16 16 import BaseHTTPServer as basehttpserver
17 17 except ImportError:
18 18 basehttpserver = None
19 19 import zlib
20 20
21 21 import testparseutil
22 22
23 23 # Whitelist of modules that symbols can be directly imported from.
24 24 allowsymbolimports = (
25 25 '__future__',
26 26 'bzrlib',
27 27 'hgclient',
28 28 'mercurial',
29 29 'mercurial.hgweb.common',
30 30 'mercurial.hgweb.request',
31 31 'mercurial.i18n',
32 32 'mercurial.interfaces',
33 33 'mercurial.node',
34 34 'mercurial.pycompat',
35 35 # for revlog to re-export constant to extensions
36 36 'mercurial.revlogutils.constants',
37 37 'mercurial.revlogutils.flagutil',
38 38 # for cffi modules to re-export pure functions
39 39 'mercurial.pure.base85',
40 40 'mercurial.pure.bdiff',
41 41 'mercurial.pure.mpatch',
42 42 'mercurial.pure.osutil',
43 43 'mercurial.pure.parsers',
44 44 # third-party imports should be directly imported
45 45 'mercurial.thirdparty',
46 46 'mercurial.thirdparty.attr',
47 47 'mercurial.thirdparty.zope',
48 48 'mercurial.thirdparty.zope.interface',
49 49 )
50 50
51 51 # Whitelist of symbols that can be directly imported.
52 52 directsymbols = ('demandimport',)
53 53
54 54 # Modules that must be aliased because they are commonly confused with
55 55 # common variables and can create aliasing and readability issues.
56 56 requirealias = {
57 57 'ui': 'uimod',
58 58 }
59 59
60 60
61 61 def usingabsolute(root):
62 62 """Whether absolute imports are being used."""
63 63 if sys.version_info[0] >= 3:
64 64 return True
65 65
66 66 for node in ast.walk(root):
67 67 if isinstance(node, ast.ImportFrom):
68 68 if node.module == '__future__':
69 69 for n in node.names:
70 70 if n.name == 'absolute_import':
71 71 return True
72 72
73 73 return False
74 74
75 75
76 76 def walklocal(root):
77 77 """Recursively yield all descendant nodes but not in a different scope"""
78 78 todo = collections.deque(ast.iter_child_nodes(root))
79 79 yield root, False
80 80 while todo:
81 81 node = todo.popleft()
82 82 newscope = isinstance(node, ast.FunctionDef)
83 83 if not newscope:
84 84 todo.extend(ast.iter_child_nodes(node))
85 85 yield node, newscope
86 86
87 87
88 88 def dotted_name_of_path(path):
89 89 """Given a relative path to a source file, return its dotted module name.
90 90
91 91 >>> dotted_name_of_path('mercurial/error.py')
92 92 'mercurial.error'
93 93 >>> dotted_name_of_path('zlibmodule.so')
94 94 'zlib'
95 95 """
96 96 parts = path.replace(os.sep, '/').split('/')
97 97 parts[-1] = parts[-1].split('.', 1)[0] # remove .py and .so and .ARCH.so
98 98 if parts[-1].endswith('module'):
99 99 parts[-1] = parts[-1][:-6]
100 100 return '.'.join(parts)
101 101
102 102
103 103 def fromlocalfunc(modulename, localmods):
104 104 """Get a function to examine which locally defined module the
105 105 target source imports via a specified name.
106 106
107 107 `modulename` is an `dotted_name_of_path()`-ed source file path,
108 108 which may have `.__init__` at the end of it, of the target source.
109 109
110 110 `localmods` is a set of absolute `dotted_name_of_path()`-ed source file
111 111 paths of locally defined (= Mercurial specific) modules.
112 112
113 113 This function assumes that module names not existing in
114 114 `localmods` are from the Python standard library.
115 115
116 116 This function returns the function, which takes `name` argument,
117 117 and returns `(absname, dottedpath, hassubmod)` tuple if `name`
118 118 matches against locally defined module. Otherwise, it returns
119 119 False.
120 120
121 121 It is assumed that `name` doesn't have `.__init__`.
122 122
123 123 `absname` is an absolute module name of specified `name`
124 124 (e.g. "hgext.convert"). This can be used to compose prefix for sub
125 125 modules or so.
126 126
127 127 `dottedpath` is a `dotted_name_of_path()`-ed source file path
128 128 (e.g. "hgext.convert.__init__") of `name`. This is used to look
129 129 module up in `localmods` again.
130 130
131 131 `hassubmod` is whether it may have sub modules under it (for
132 132 convenient, even though this is also equivalent to "absname !=
133 133 dottednpath")
134 134
135 135 >>> localmods = {'foo.__init__', 'foo.foo1',
136 136 ... 'foo.bar.__init__', 'foo.bar.bar1',
137 137 ... 'baz.__init__', 'baz.baz1'}
138 138 >>> fromlocal = fromlocalfunc('foo.xxx', localmods)
139 139 >>> # relative
140 140 >>> fromlocal('foo1')
141 141 ('foo.foo1', 'foo.foo1', False)
142 142 >>> fromlocal('bar')
143 143 ('foo.bar', 'foo.bar.__init__', True)
144 144 >>> fromlocal('bar.bar1')
145 145 ('foo.bar.bar1', 'foo.bar.bar1', False)
146 146 >>> # absolute
147 147 >>> fromlocal('baz')
148 148 ('baz', 'baz.__init__', True)
149 149 >>> fromlocal('baz.baz1')
150 150 ('baz.baz1', 'baz.baz1', False)
151 151 >>> # unknown = maybe standard library
152 152 >>> fromlocal('os')
153 153 False
154 154 >>> fromlocal(None, 1)
155 155 ('foo', 'foo.__init__', True)
156 156 >>> fromlocal('foo1', 1)
157 157 ('foo.foo1', 'foo.foo1', False)
158 158 >>> fromlocal2 = fromlocalfunc('foo.xxx.yyy', localmods)
159 159 >>> fromlocal2(None, 2)
160 160 ('foo', 'foo.__init__', True)
161 161 >>> fromlocal2('bar2', 1)
162 162 False
163 163 >>> fromlocal2('bar', 2)
164 164 ('foo.bar', 'foo.bar.__init__', True)
165 165 """
166 166 if not isinstance(modulename, str):
167 167 modulename = modulename.decode('ascii')
168 168 prefix = '.'.join(modulename.split('.')[:-1])
169 169 if prefix:
170 170 prefix += '.'
171 171
172 172 def fromlocal(name, level=0):
173 173 # name is false value when relative imports are used.
174 174 if not name:
175 175 # If relative imports are used, level must not be absolute.
176 176 assert level > 0
177 177 candidates = ['.'.join(modulename.split('.')[:-level])]
178 178 else:
179 179 if not level:
180 180 # Check relative name first.
181 181 candidates = [prefix + name, name]
182 182 else:
183 183 candidates = [
184 184 '.'.join(modulename.split('.')[:-level]) + '.' + name
185 185 ]
186 186
187 187 for n in candidates:
188 188 if n in localmods:
189 189 return (n, n, False)
190 190 dottedpath = n + '.__init__'
191 191 if dottedpath in localmods:
192 192 return (n, dottedpath, True)
193 193 return False
194 194
195 195 return fromlocal
196 196
197 197
198 198 def populateextmods(localmods):
199 199 """Populate C extension modules based on pure modules"""
200 200 newlocalmods = set(localmods)
201 201 for n in localmods:
202 202 if n.startswith('mercurial.pure.'):
203 203 m = n[len('mercurial.pure.') :]
204 204 newlocalmods.add('mercurial.cext.' + m)
205 205 newlocalmods.add('mercurial.cffi._' + m)
206 206 return newlocalmods
207 207
208 208
209 209 def list_stdlib_modules():
210 210 """List the modules present in the stdlib.
211 211
212 212 >>> py3 = sys.version_info[0] >= 3
213 213 >>> mods = set(list_stdlib_modules())
214 214 >>> 'BaseHTTPServer' in mods or py3
215 215 True
216 216
217 217 os.path isn't really a module, so it's missing:
218 218
219 219 >>> 'os.path' in mods
220 220 False
221 221
222 222 sys requires special treatment, because it's baked into the
223 223 interpreter, but it should still appear:
224 224
225 225 >>> 'sys' in mods
226 226 True
227 227
228 228 >>> 'collections' in mods
229 229 True
230 230
231 231 >>> 'cStringIO' in mods or py3
232 232 True
233 233
234 234 >>> 'cffi' in mods
235 235 True
236 236 """
237 237 for m in sys.builtin_module_names:
238 238 yield m
239 239 # These modules only exist on windows, but we should always
240 240 # consider them stdlib.
241 241 for m in ['msvcrt', '_winreg']:
242 242 yield m
243 243 yield '__builtin__'
244 244 yield 'builtins' # python3 only
245 245 yield 'importlib.abc' # python3 only
246 246 yield 'importlib.machinery' # python3 only
247 247 yield 'importlib.util' # python3 only
248 248 for m in 'fcntl', 'grp', 'pwd', 'termios': # Unix only
249 249 yield m
250 250 for m in 'cPickle', 'datetime': # in Python (not C) on PyPy
251 251 yield m
252 252 for m in ['cffi']:
253 253 yield m
254 254 stdlib_prefixes = {sys.prefix, sys.exec_prefix}
255 255 # We need to supplement the list of prefixes for the search to work
256 256 # when run from within a virtualenv.
257 257 for mod in (basehttpserver, zlib):
258 258 if mod is None:
259 259 continue
260 260 try:
261 261 # Not all module objects have a __file__ attribute.
262 262 filename = mod.__file__
263 263 except AttributeError:
264 264 continue
265 265 dirname = os.path.dirname(filename)
266 266 for prefix in stdlib_prefixes:
267 267 if dirname.startswith(prefix):
268 268 # Then this directory is redundant.
269 269 break
270 270 else:
271 271 stdlib_prefixes.add(dirname)
272 272 sourceroot = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
273 273 for libpath in sys.path:
274 274 # We want to walk everything in sys.path that starts with something in
275 275 # stdlib_prefixes, but not directories from the hg sources.
276 276 if os.path.abspath(libpath).startswith(sourceroot) or not any(
277 277 libpath.startswith(p) for p in stdlib_prefixes
278 278 ):
279 279 continue
280 280 for top, dirs, files in os.walk(libpath):
281 281 for i, d in reversed(list(enumerate(dirs))):
282 282 if (
283 283 not os.path.exists(os.path.join(top, d, '__init__.py'))
284 284 or top == libpath
285 285 and d in ('hgdemandimport', 'hgext', 'mercurial')
286 286 ):
287 287 del dirs[i]
288 288 for name in files:
289 289 if not name.endswith(('.py', '.so', '.pyc', '.pyo', '.pyd')):
290 290 continue
291 291 if name.startswith('__init__.py'):
292 292 full_path = top
293 293 else:
294 294 full_path = os.path.join(top, name)
295 295 rel_path = full_path[len(libpath) + 1 :]
296 296 mod = dotted_name_of_path(rel_path)
297 297 yield mod
298 298
299 299
300 300 stdlib_modules = set(list_stdlib_modules())
301 301
302 302
303 303 def imported_modules(source, modulename, f, localmods, ignore_nested=False):
304 304 """Given the source of a file as a string, yield the names
305 305 imported by that file.
306 306
307 307 Args:
308 308 source: The python source to examine as a string.
309 309 modulename: of specified python source (may have `__init__`)
310 310 localmods: set of locally defined module names (may have `__init__`)
311 311 ignore_nested: If true, import statements that do not start in
312 312 column zero will be ignored.
313 313
314 314 Returns:
315 315 A list of absolute module names imported by the given source.
316 316
317 317 >>> f = 'foo/xxx.py'
318 318 >>> modulename = 'foo.xxx'
319 319 >>> localmods = {'foo.__init__': True,
320 320 ... 'foo.foo1': True, 'foo.foo2': True,
321 321 ... 'foo.bar.__init__': True, 'foo.bar.bar1': True,
322 322 ... 'baz.__init__': True, 'baz.baz1': True }
323 323 >>> # standard library (= not locally defined ones)
324 324 >>> sorted(imported_modules(
325 325 ... 'from stdlib1 import foo, bar; import stdlib2',
326 326 ... modulename, f, localmods))
327 327 []
328 328 >>> # relative importing
329 329 >>> sorted(imported_modules(
330 330 ... 'import foo1; from bar import bar1',
331 331 ... modulename, f, localmods))
332 332 ['foo.bar.bar1', 'foo.foo1']
333 333 >>> sorted(imported_modules(
334 334 ... 'from bar.bar1 import name1, name2, name3',
335 335 ... modulename, f, localmods))
336 336 ['foo.bar.bar1']
337 337 >>> # absolute importing
338 338 >>> sorted(imported_modules(
339 339 ... 'from baz import baz1, name1',
340 340 ... modulename, f, localmods))
341 341 ['baz.__init__', 'baz.baz1']
342 342 >>> # mixed importing, even though it shouldn't be recommended
343 343 >>> sorted(imported_modules(
344 344 ... 'import stdlib, foo1, baz',
345 345 ... modulename, f, localmods))
346 346 ['baz.__init__', 'foo.foo1']
347 347 >>> # ignore_nested
348 348 >>> sorted(imported_modules(
349 349 ... '''import foo
350 350 ... def wat():
351 351 ... import bar
352 352 ... ''', modulename, f, localmods))
353 353 ['foo.__init__', 'foo.bar.__init__']
354 354 >>> sorted(imported_modules(
355 355 ... '''import foo
356 356 ... def wat():
357 357 ... import bar
358 358 ... ''', modulename, f, localmods, ignore_nested=True))
359 359 ['foo.__init__']
360 360 """
361 361 fromlocal = fromlocalfunc(modulename, localmods)
362 362 for node in ast.walk(ast.parse(source, f)):
363 363 if ignore_nested and getattr(node, 'col_offset', 0) > 0:
364 364 continue
365 365 if isinstance(node, ast.Import):
366 366 for n in node.names:
367 367 found = fromlocal(n.name)
368 368 if not found:
369 369 # this should import standard library
370 370 continue
371 371 yield found[1]
372 372 elif isinstance(node, ast.ImportFrom):
373 373 found = fromlocal(node.module, node.level)
374 374 if not found:
375 375 # this should import standard library
376 376 continue
377 377
378 378 absname, dottedpath, hassubmod = found
379 379 if not hassubmod:
380 380 # "dottedpath" is not a package; must be imported
381 381 yield dottedpath
382 382 # examination of "node.names" should be redundant
383 383 # e.g.: from mercurial.node import nullid, nullrev
384 384 continue
385 385
386 386 modnotfound = False
387 387 prefix = absname + '.'
388 388 for n in node.names:
389 389 found = fromlocal(prefix + n.name)
390 390 if not found:
391 391 # this should be a function or a property of "node.module"
392 392 modnotfound = True
393 393 continue
394 394 yield found[1]
395 if modnotfound:
395 if modnotfound and dottedpath != modulename:
396 396 # "dottedpath" is a package, but imported because of non-module
397 397 # lookup
398 # specifically allow "from . import foo" from __init__.py
398 399 yield dottedpath
399 400
400 401
401 402 def verify_import_convention(module, source, localmods):
402 403 """Verify imports match our established coding convention.
403 404
404 405 We have 2 conventions: legacy and modern. The modern convention is in
405 406 effect when using absolute imports.
406 407
407 408 The legacy convention only looks for mixed imports. The modern convention
408 409 is much more thorough.
409 410 """
410 411 root = ast.parse(source)
411 412 absolute = usingabsolute(root)
412 413
413 414 if absolute:
414 415 return verify_modern_convention(module, root, localmods)
415 416 else:
416 417 return verify_stdlib_on_own_line(root)
417 418
418 419
419 420 def verify_modern_convention(module, root, localmods, root_col_offset=0):
420 421 """Verify a file conforms to the modern import convention rules.
421 422
422 423 The rules of the modern convention are:
423 424
424 425 * Ordering is stdlib followed by local imports. Each group is lexically
425 426 sorted.
426 427 * Importing multiple modules via "import X, Y" is not allowed: use
427 428 separate import statements.
428 429 * Importing multiple modules via "from X import ..." is allowed if using
429 430 parenthesis and one entry per line.
430 431 * Only 1 relative import statement per import level ("from .", "from ..")
431 432 is allowed.
432 433 * Relative imports from higher levels must occur before lower levels. e.g.
433 434 "from .." must be before "from .".
434 435 * Imports from peer packages should use relative import (e.g. do not
435 436 "import mercurial.foo" from a "mercurial.*" module).
436 437 * Symbols can only be imported from specific modules (see
437 438 `allowsymbolimports`). For other modules, first import the module then
438 439 assign the symbol to a module-level variable. In addition, these imports
439 440 must be performed before other local imports. This rule only
440 441 applies to import statements outside of any blocks.
441 442 * Relative imports from the standard library are not allowed, unless that
442 443 library is also a local module.
443 444 * Certain modules must be aliased to alternate names to avoid aliasing
444 445 and readability problems. See `requirealias`.
445 446 """
446 447 if not isinstance(module, str):
447 448 module = module.decode('ascii')
448 449 topmodule = module.split('.')[0]
449 450 fromlocal = fromlocalfunc(module, localmods)
450 451
451 452 # Whether a local/non-stdlib import has been performed.
452 453 seenlocal = None
453 454 # Whether a local/non-stdlib, non-symbol import has been seen.
454 455 seennonsymbollocal = False
455 456 # The last name to be imported (for sorting).
456 457 lastname = None
457 458 laststdlib = None
458 459 # Relative import levels encountered so far.
459 460 seenlevels = set()
460 461
461 462 for node, newscope in walklocal(root):
462 463
463 464 def msg(fmt, *args):
464 465 return (fmt % args, node.lineno)
465 466
466 467 if newscope:
467 468 # Check for local imports in function
468 469 for r in verify_modern_convention(
469 470 module, node, localmods, node.col_offset + 4
470 471 ):
471 472 yield r
472 473 elif isinstance(node, ast.Import):
473 474 # Disallow "import foo, bar" and require separate imports
474 475 # for each module.
475 476 if len(node.names) > 1:
476 477 yield msg(
477 478 'multiple imported names: %s',
478 479 ', '.join(n.name for n in node.names),
479 480 )
480 481
481 482 name = node.names[0].name
482 483 asname = node.names[0].asname
483 484
484 485 stdlib = name in stdlib_modules
485 486
486 487 # Ignore sorting rules on imports inside blocks.
487 488 if node.col_offset == root_col_offset:
488 489 if lastname and name < lastname and laststdlib == stdlib:
489 490 yield msg(
490 491 'imports not lexically sorted: %s < %s', name, lastname
491 492 )
492 493
493 494 lastname = name
494 495 laststdlib = stdlib
495 496
496 497 # stdlib imports should be before local imports.
497 498 if stdlib and seenlocal and node.col_offset == root_col_offset:
498 499 yield msg(
499 500 'stdlib import "%s" follows local import: %s',
500 501 name,
501 502 seenlocal,
502 503 )
503 504
504 505 if not stdlib:
505 506 seenlocal = name
506 507
507 508 # Import of sibling modules should use relative imports.
508 509 topname = name.split('.')[0]
509 510 if topname == topmodule:
510 511 yield msg('import should be relative: %s', name)
511 512
512 513 if name in requirealias and asname != requirealias[name]:
513 514 yield msg(
514 515 '%s module must be "as" aliased to %s',
515 516 name,
516 517 requirealias[name],
517 518 )
518 519
519 520 elif isinstance(node, ast.ImportFrom):
520 521 # Resolve the full imported module name.
521 522 if node.level > 0:
522 523 fullname = '.'.join(module.split('.')[: -node.level])
523 524 if node.module:
524 525 fullname += '.%s' % node.module
525 526 else:
526 527 assert node.module
527 528 fullname = node.module
528 529
529 530 topname = fullname.split('.')[0]
530 531 if topname == topmodule:
531 532 yield msg('import should be relative: %s', fullname)
532 533
533 534 # __future__ is special since it needs to come first and use
534 535 # symbol import.
535 536 if fullname != '__future__':
536 537 if not fullname or (
537 538 fullname in stdlib_modules
538 539 # allow standard 'from typing import ...' style
539 540 and fullname.startswith('.')
540 541 and fullname not in localmods
541 542 and fullname + '.__init__' not in localmods
542 543 ):
543 544 yield msg('relative import of stdlib module')
544 545 else:
545 546 seenlocal = fullname
546 547
547 548 # Direct symbol import is only allowed from certain modules and
548 549 # must occur before non-symbol imports.
549 550 found = fromlocal(node.module, node.level)
550 551 if found and found[2]: # node.module is a package
551 552 prefix = found[0] + '.'
552 553 symbols = (
553 554 n.name for n in node.names if not fromlocal(prefix + n.name)
554 555 )
555 556 else:
556 557 symbols = (n.name for n in node.names)
557 558 symbols = [sym for sym in symbols if sym not in directsymbols]
558 559 if node.module and node.col_offset == root_col_offset:
559 560 if symbols and fullname not in allowsymbolimports:
560 561 yield msg(
561 562 'direct symbol import %s from %s',
562 563 ', '.join(symbols),
563 564 fullname,
564 565 )
565 566
566 567 if symbols and seennonsymbollocal:
567 568 yield msg(
568 569 'symbol import follows non-symbol import: %s', fullname
569 570 )
570 571 if not symbols and fullname not in stdlib_modules:
571 572 seennonsymbollocal = True
572 573
573 574 if not node.module:
574 575 assert node.level
575 576
576 577 # Only allow 1 group per level.
577 578 if (
578 579 node.level in seenlevels
579 580 and node.col_offset == root_col_offset
580 581 ):
581 582 yield msg(
582 583 'multiple "from %s import" statements', '.' * node.level
583 584 )
584 585
585 586 # Higher-level groups come before lower-level groups.
586 587 if any(node.level > l for l in seenlevels):
587 588 yield msg(
588 589 'higher-level import should come first: %s', fullname
589 590 )
590 591
591 592 seenlevels.add(node.level)
592 593
593 594 # Entries in "from .X import ( ... )" lists must be lexically
594 595 # sorted.
595 596 lastentryname = None
596 597
597 598 for n in node.names:
598 599 if lastentryname and n.name < lastentryname:
599 600 yield msg(
600 601 'imports from %s not lexically sorted: %s < %s',
601 602 fullname,
602 603 n.name,
603 604 lastentryname,
604 605 )
605 606
606 607 lastentryname = n.name
607 608
608 609 if n.name in requirealias and n.asname != requirealias[n.name]:
609 610 yield msg(
610 611 '%s from %s must be "as" aliased to %s',
611 612 n.name,
612 613 fullname,
613 614 requirealias[n.name],
614 615 )
615 616
616 617
617 618 def verify_stdlib_on_own_line(root):
618 619 """Given some python source, verify that stdlib imports are done
619 620 in separate statements from relative local module imports.
620 621
621 622 >>> list(verify_stdlib_on_own_line(ast.parse('import sys, foo')))
622 623 [('mixed imports\\n stdlib: sys\\n relative: foo', 1)]
623 624 >>> list(verify_stdlib_on_own_line(ast.parse('import sys, os')))
624 625 []
625 626 >>> list(verify_stdlib_on_own_line(ast.parse('import foo, bar')))
626 627 []
627 628 """
628 629 for node in ast.walk(root):
629 630 if isinstance(node, ast.Import):
630 631 from_stdlib = {False: [], True: []}
631 632 for n in node.names:
632 633 from_stdlib[n.name in stdlib_modules].append(n.name)
633 634 if from_stdlib[True] and from_stdlib[False]:
634 635 yield (
635 636 'mixed imports\n stdlib: %s\n relative: %s'
636 637 % (
637 638 ', '.join(sorted(from_stdlib[True])),
638 639 ', '.join(sorted(from_stdlib[False])),
639 640 ),
640 641 node.lineno,
641 642 )
642 643
643 644
644 645 class CircularImport(Exception):
645 646 pass
646 647
647 648
648 649 def checkmod(mod, imports):
649 650 shortest = {}
650 651 visit = [[mod]]
651 652 while visit:
652 653 path = visit.pop(0)
653 654 for i in sorted(imports.get(path[-1], [])):
654 655 if len(path) < shortest.get(i, 1000):
655 656 shortest[i] = len(path)
656 657 if i in path:
657 658 if i == path[0]:
658 659 raise CircularImport(path)
659 660 continue
660 661 visit.append(path + [i])
661 662
662 663
663 664 def rotatecycle(cycle):
664 665 """arrange a cycle so that the lexicographically first module listed first
665 666
666 667 >>> rotatecycle(['foo', 'bar'])
667 668 ['bar', 'foo', 'bar']
668 669 """
669 670 lowest = min(cycle)
670 671 idx = cycle.index(lowest)
671 672 return cycle[idx:] + cycle[:idx] + [lowest]
672 673
673 674
674 675 def find_cycles(imports):
675 676 """Find cycles in an already-loaded import graph.
676 677
677 678 All module names recorded in `imports` should be absolute one.
678 679
679 680 >>> from __future__ import print_function
680 681 >>> imports = {'top.foo': ['top.bar', 'os.path', 'top.qux'],
681 682 ... 'top.bar': ['top.baz', 'sys'],
682 683 ... 'top.baz': ['top.foo'],
683 684 ... 'top.qux': ['top.foo']}
684 685 >>> print('\\n'.join(sorted(find_cycles(imports))))
685 686 top.bar -> top.baz -> top.foo -> top.bar
686 687 top.foo -> top.qux -> top.foo
687 688 """
688 689 cycles = set()
689 690 for mod in sorted(imports.keys()):
690 691 try:
691 692 checkmod(mod, imports)
692 693 except CircularImport as e:
693 694 cycle = e.args[0]
694 695 cycles.add(" -> ".join(rotatecycle(cycle)))
695 696 return cycles
696 697
697 698
698 699 def _cycle_sortkey(c):
699 700 return len(c), c
700 701
701 702
702 703 def embedded(f, modname, src):
703 704 """Extract embedded python code
704 705
705 706 >>> def _forcestr(thing):
706 707 ... if not isinstance(thing, str):
707 708 ... return thing.decode('ascii')
708 709 ... return thing
709 710 >>> def test(fn, lines):
710 711 ... for s, m, f, l in embedded(fn, b"example", lines):
711 712 ... print("%s %s %d" % (_forcestr(m), _forcestr(f), l))
712 713 ... print(repr(_forcestr(s)))
713 714 >>> lines = [
714 715 ... 'comment',
715 716 ... ' >>> from __future__ import print_function',
716 717 ... " >>> ' multiline",
717 718 ... " ... string'",
718 719 ... ' ',
719 720 ... 'comment',
720 721 ... ' $ cat > foo.py <<EOF',
721 722 ... ' > from __future__ import print_function',
722 723 ... ' > EOF',
723 724 ... ]
724 725 >>> test(b"example.t", lines)
725 726 example[2] doctest.py 1
726 727 "from __future__ import print_function\\n' multiline\\nstring'\\n\\n"
727 728 example[8] foo.py 7
728 729 'from __future__ import print_function\\n'
729 730 """
730 731 errors = []
731 732 for name, starts, ends, code in testparseutil.pyembedded(f, src, errors):
732 733 if not name:
733 734 # use 'doctest.py', in order to make already existing
734 735 # doctest above pass instantly
735 736 name = 'doctest.py'
736 737 # "starts" is "line number" (1-origin), but embedded() is
737 738 # expected to return "line offset" (0-origin). Therefore, this
738 739 # yields "starts - 1".
739 740 if not isinstance(modname, str):
740 741 modname = modname.decode('utf8')
741 742 yield code, "%s[%d]" % (modname, starts), name, starts - 1
742 743
743 744
744 745 def sources(f, modname):
745 746 """Yields possibly multiple sources from a filepath
746 747
747 748 input: filepath, modulename
748 749 yields: script(string), modulename, filepath, linenumber
749 750
750 751 For embedded scripts, the modulename and filepath will be different
751 752 from the function arguments. linenumber is an offset relative to
752 753 the input file.
753 754 """
754 755 py = False
755 756 if not f.endswith('.t'):
756 757 with open(f, 'rb') as src:
757 758 yield src.read(), modname, f, 0
758 759 py = True
759 760 if py or f.endswith('.t'):
760 761 # Strictly speaking we should sniff for the magic header that denotes
761 762 # Python source file encoding. But in reality we don't use anything
762 763 # other than ASCII (mainly) and UTF-8 (in a few exceptions), so
763 764 # simplicity is fine.
764 765 with io.open(f, 'r', encoding='utf-8') as src:
765 766 for script, modname, t, line in embedded(f, modname, src):
766 767 yield script, modname.encode('utf8'), t, line
767 768
768 769
769 770 def main(argv):
770 771 if len(argv) < 2 or (argv[1] == '-' and len(argv) > 2):
771 772 print('Usage: %s {-|file [file] [file] ...}')
772 773 return 1
773 774 if argv[1] == '-':
774 775 argv = argv[:1]
775 776 argv.extend(l.rstrip() for l in sys.stdin.readlines())
776 777 localmodpaths = {}
777 778 used_imports = {}
778 779 any_errors = False
779 780 for source_path in argv[1:]:
780 781 modname = dotted_name_of_path(source_path)
781 782 localmodpaths[modname] = source_path
782 783 localmods = populateextmods(localmodpaths)
783 784 for localmodname, source_path in sorted(localmodpaths.items()):
784 785 if not isinstance(localmodname, bytes):
785 786 # This is only safe because all hg's files are ascii
786 787 localmodname = localmodname.encode('ascii')
787 788 for src, modname, name, line in sources(source_path, localmodname):
788 789 try:
789 790 used_imports[modname] = sorted(
790 791 imported_modules(
791 792 src, modname, name, localmods, ignore_nested=True
792 793 )
793 794 )
794 795 for error, lineno in verify_import_convention(
795 796 modname, src, localmods
796 797 ):
797 798 any_errors = True
798 799 print('%s:%d: %s' % (source_path, lineno + line, error))
799 800 except SyntaxError as e:
800 801 print(
801 802 '%s:%d: SyntaxError: %s' % (source_path, e.lineno + line, e)
802 803 )
803 804 cycles = find_cycles(used_imports)
804 805 if cycles:
805 806 firstmods = set()
806 807 for c in sorted(cycles, key=_cycle_sortkey):
807 808 first = c.split()[0]
808 809 # As a rough cut, ignore any cycle that starts with the
809 810 # same module as some other cycle. Otherwise we see lots
810 811 # of cycles that are effectively duplicates.
811 812 if first in firstmods:
812 813 continue
813 814 print('Import cycle:', c)
814 815 firstmods.add(first)
815 816 any_errors = True
816 817 return any_errors != 0
817 818
818 819
819 820 if __name__ == '__main__':
820 821 sys.exit(int(main(sys.argv)))
@@ -1,3853 +1,3854 b''
1 1 # perf.py - performance test routines
2 2 '''helper extension to measure performance
3 3
4 4 Configurations
5 5 ==============
6 6
7 7 ``perf``
8 8 --------
9 9
10 10 ``all-timing``
11 11 When set, additional statistics will be reported for each benchmark: best,
12 12 worst, median average. If not set only the best timing is reported
13 13 (default: off).
14 14
15 15 ``presleep``
16 16 number of second to wait before any group of runs (default: 1)
17 17
18 18 ``pre-run``
19 19 number of run to perform before starting measurement.
20 20
21 21 ``profile-benchmark``
22 22 Enable profiling for the benchmarked section.
23 23 (The first iteration is benchmarked)
24 24
25 25 ``run-limits``
26 26 Control the number of runs each benchmark will perform. The option value
27 27 should be a list of `<time>-<numberofrun>` pairs. After each run the
28 28 conditions are considered in order with the following logic:
29 29
30 30 If benchmark has been running for <time> seconds, and we have performed
31 31 <numberofrun> iterations, stop the benchmark,
32 32
33 33 The default value is: `3.0-100, 10.0-3`
34 34
35 35 ``stub``
36 36 When set, benchmarks will only be run once, useful for testing
37 37 (default: off)
38 38 '''
39 39
40 40 # "historical portability" policy of perf.py:
41 41 #
42 42 # We have to do:
43 43 # - make perf.py "loadable" with as wide Mercurial version as possible
44 44 # This doesn't mean that perf commands work correctly with that Mercurial.
45 45 # BTW, perf.py itself has been available since 1.1 (or eb240755386d).
46 46 # - make historical perf command work correctly with as wide Mercurial
47 47 # version as possible
48 48 #
49 49 # We have to do, if possible with reasonable cost:
50 50 # - make recent perf command for historical feature work correctly
51 51 # with early Mercurial
52 52 #
53 53 # We don't have to do:
54 54 # - make perf command for recent feature work correctly with early
55 55 # Mercurial
56 56
57 57 from __future__ import absolute_import
58 58 import contextlib
59 59 import functools
60 60 import gc
61 61 import os
62 62 import random
63 63 import shutil
64 64 import struct
65 65 import sys
66 66 import tempfile
67 67 import threading
68 68 import time
69 69 from mercurial import (
70 70 changegroup,
71 71 cmdutil,
72 72 commands,
73 73 copies,
74 74 error,
75 75 extensions,
76 76 hg,
77 77 mdiff,
78 78 merge,
79 79 revlog,
80 80 util,
81 81 )
82 82
83 83 # for "historical portability":
84 84 # try to import modules separately (in dict order), and ignore
85 85 # failure, because these aren't available with early Mercurial
86 86 try:
87 87 from mercurial import branchmap # since 2.5 (or bcee63733aad)
88 88 except ImportError:
89 89 pass
90 90 try:
91 91 from mercurial import obsolete # since 2.3 (or ad0d6c2b3279)
92 92 except ImportError:
93 93 pass
94 94 try:
95 95 from mercurial import registrar # since 3.7 (or 37d50250b696)
96 96
97 97 dir(registrar) # forcibly load it
98 98 except ImportError:
99 99 registrar = None
100 100 try:
101 101 from mercurial import repoview # since 2.5 (or 3a6ddacb7198)
102 102 except ImportError:
103 103 pass
104 104 try:
105 105 from mercurial.utils import repoviewutil # since 5.0
106 106 except ImportError:
107 107 repoviewutil = None
108 108 try:
109 109 from mercurial import scmutil # since 1.9 (or 8b252e826c68)
110 110 except ImportError:
111 111 pass
112 112 try:
113 113 from mercurial import setdiscovery # since 1.9 (or cb98fed52495)
114 114 except ImportError:
115 115 pass
116 116
117 117 try:
118 118 from mercurial import profiling
119 119 except ImportError:
120 120 profiling = None
121 121
122 122
123 123 def identity(a):
124 124 return a
125 125
126 126
127 127 try:
128 128 from mercurial import pycompat
129 129
130 130 getargspec = pycompat.getargspec # added to module after 4.5
131 131 _byteskwargs = pycompat.byteskwargs # since 4.1 (or fbc3f73dc802)
132 132 _sysstr = pycompat.sysstr # since 4.0 (or 2219f4f82ede)
133 133 _bytestr = pycompat.bytestr # since 4.2 (or b70407bd84d5)
134 134 _xrange = pycompat.xrange # since 4.8 (or 7eba8f83129b)
135 135 fsencode = pycompat.fsencode # since 3.9 (or f4a5e0e86a7e)
136 136 if pycompat.ispy3:
137 137 _maxint = sys.maxsize # per py3 docs for replacing maxint
138 138 else:
139 139 _maxint = sys.maxint
140 140 except (NameError, ImportError, AttributeError):
141 141 import inspect
142 142
143 143 getargspec = inspect.getargspec
144 144 _byteskwargs = identity
145 145 _bytestr = str
146 146 fsencode = identity # no py3 support
147 147 _maxint = sys.maxint # no py3 support
148 148 _sysstr = lambda x: x # no py3 support
149 149 _xrange = xrange
150 150
151 151 try:
152 152 # 4.7+
153 153 queue = pycompat.queue.Queue
154 154 except (NameError, AttributeError, ImportError):
155 155 # <4.7.
156 156 try:
157 157 queue = pycompat.queue
158 158 except (NameError, AttributeError, ImportError):
159 159 import Queue as queue
160 160
161 161 try:
162 162 from mercurial import logcmdutil
163 163
164 164 makelogtemplater = logcmdutil.maketemplater
165 165 except (AttributeError, ImportError):
166 166 try:
167 167 makelogtemplater = cmdutil.makelogtemplater
168 168 except (AttributeError, ImportError):
169 169 makelogtemplater = None
170 170
171 171 # for "historical portability":
172 172 # define util.safehasattr forcibly, because util.safehasattr has been
173 173 # available since 1.9.3 (or 94b200a11cf7)
174 174 _undefined = object()
175 175
176 176
177 177 def safehasattr(thing, attr):
178 178 return getattr(thing, _sysstr(attr), _undefined) is not _undefined
179 179
180 180
181 181 setattr(util, 'safehasattr', safehasattr)
182 182
183 183 # for "historical portability":
184 184 # define util.timer forcibly, because util.timer has been available
185 185 # since ae5d60bb70c9
186 186 if safehasattr(time, 'perf_counter'):
187 187 util.timer = time.perf_counter
188 188 elif os.name == b'nt':
189 189 util.timer = time.clock
190 190 else:
191 191 util.timer = time.time
192 192
193 193 # for "historical portability":
194 194 # use locally defined empty option list, if formatteropts isn't
195 195 # available, because commands.formatteropts has been available since
196 196 # 3.2 (or 7a7eed5176a4), even though formatting itself has been
197 197 # available since 2.2 (or ae5f92e154d3)
198 198 formatteropts = getattr(
199 199 cmdutil, "formatteropts", getattr(commands, "formatteropts", [])
200 200 )
201 201
202 202 # for "historical portability":
203 203 # use locally defined option list, if debugrevlogopts isn't available,
204 204 # because commands.debugrevlogopts has been available since 3.7 (or
205 205 # 5606f7d0d063), even though cmdutil.openrevlog() has been available
206 206 # since 1.9 (or a79fea6b3e77).
207 207 revlogopts = getattr(
208 208 cmdutil,
209 209 "debugrevlogopts",
210 210 getattr(
211 211 commands,
212 212 "debugrevlogopts",
213 213 [
214 214 (b'c', b'changelog', False, b'open changelog'),
215 215 (b'm', b'manifest', False, b'open manifest'),
216 216 (b'', b'dir', False, b'open directory manifest'),
217 217 ],
218 218 ),
219 219 )
220 220
221 221 cmdtable = {}
222 222
223 223 # for "historical portability":
224 224 # define parsealiases locally, because cmdutil.parsealiases has been
225 225 # available since 1.5 (or 6252852b4332)
226 226 def parsealiases(cmd):
227 227 return cmd.split(b"|")
228 228
229 229
230 230 if safehasattr(registrar, 'command'):
231 231 command = registrar.command(cmdtable)
232 232 elif safehasattr(cmdutil, 'command'):
233 233 command = cmdutil.command(cmdtable)
234 234 if b'norepo' not in getargspec(command).args:
235 235 # for "historical portability":
236 236 # wrap original cmdutil.command, because "norepo" option has
237 237 # been available since 3.1 (or 75a96326cecb)
238 238 _command = command
239 239
240 240 def command(name, options=(), synopsis=None, norepo=False):
241 241 if norepo:
242 242 commands.norepo += b' %s' % b' '.join(parsealiases(name))
243 243 return _command(name, list(options), synopsis)
244 244
245 245
246 246 else:
247 247 # for "historical portability":
248 248 # define "@command" annotation locally, because cmdutil.command
249 249 # has been available since 1.9 (or 2daa5179e73f)
250 250 def command(name, options=(), synopsis=None, norepo=False):
251 251 def decorator(func):
252 252 if synopsis:
253 253 cmdtable[name] = func, list(options), synopsis
254 254 else:
255 255 cmdtable[name] = func, list(options)
256 256 if norepo:
257 257 commands.norepo += b' %s' % b' '.join(parsealiases(name))
258 258 return func
259 259
260 260 return decorator
261 261
262 262
263 263 try:
264 264 import mercurial.registrar
265 265 import mercurial.configitems
266 266
267 267 configtable = {}
268 268 configitem = mercurial.registrar.configitem(configtable)
269 269 configitem(
270 270 b'perf',
271 271 b'presleep',
272 272 default=mercurial.configitems.dynamicdefault,
273 273 experimental=True,
274 274 )
275 275 configitem(
276 276 b'perf',
277 277 b'stub',
278 278 default=mercurial.configitems.dynamicdefault,
279 279 experimental=True,
280 280 )
281 281 configitem(
282 282 b'perf',
283 283 b'parentscount',
284 284 default=mercurial.configitems.dynamicdefault,
285 285 experimental=True,
286 286 )
287 287 configitem(
288 288 b'perf',
289 289 b'all-timing',
290 290 default=mercurial.configitems.dynamicdefault,
291 291 experimental=True,
292 292 )
293 293 configitem(
294 294 b'perf', b'pre-run', default=mercurial.configitems.dynamicdefault,
295 295 )
296 296 configitem(
297 297 b'perf',
298 298 b'profile-benchmark',
299 299 default=mercurial.configitems.dynamicdefault,
300 300 )
301 301 configitem(
302 302 b'perf',
303 303 b'run-limits',
304 304 default=mercurial.configitems.dynamicdefault,
305 305 experimental=True,
306 306 )
307 307 except (ImportError, AttributeError):
308 308 pass
309 309 except TypeError:
310 310 # compatibility fix for a11fd395e83f
311 311 # hg version: 5.2
312 312 configitem(
313 313 b'perf', b'presleep', default=mercurial.configitems.dynamicdefault,
314 314 )
315 315 configitem(
316 316 b'perf', b'stub', default=mercurial.configitems.dynamicdefault,
317 317 )
318 318 configitem(
319 319 b'perf', b'parentscount', default=mercurial.configitems.dynamicdefault,
320 320 )
321 321 configitem(
322 322 b'perf', b'all-timing', default=mercurial.configitems.dynamicdefault,
323 323 )
324 324 configitem(
325 325 b'perf', b'pre-run', default=mercurial.configitems.dynamicdefault,
326 326 )
327 327 configitem(
328 328 b'perf',
329 329 b'profile-benchmark',
330 330 default=mercurial.configitems.dynamicdefault,
331 331 )
332 332 configitem(
333 333 b'perf', b'run-limits', default=mercurial.configitems.dynamicdefault,
334 334 )
335 335
336 336
337 337 def getlen(ui):
338 338 if ui.configbool(b"perf", b"stub", False):
339 339 return lambda x: 1
340 340 return len
341 341
342 342
343 343 class noop(object):
344 344 """dummy context manager"""
345 345
346 346 def __enter__(self):
347 347 pass
348 348
349 349 def __exit__(self, *args):
350 350 pass
351 351
352 352
353 353 NOOPCTX = noop()
354 354
355 355
356 356 def gettimer(ui, opts=None):
357 357 """return a timer function and formatter: (timer, formatter)
358 358
359 359 This function exists to gather the creation of formatter in a single
360 360 place instead of duplicating it in all performance commands."""
361 361
362 362 # enforce an idle period before execution to counteract power management
363 363 # experimental config: perf.presleep
364 364 time.sleep(getint(ui, b"perf", b"presleep", 1))
365 365
366 366 if opts is None:
367 367 opts = {}
368 368 # redirect all to stderr unless buffer api is in use
369 369 if not ui._buffers:
370 370 ui = ui.copy()
371 371 uifout = safeattrsetter(ui, b'fout', ignoremissing=True)
372 372 if uifout:
373 373 # for "historical portability":
374 374 # ui.fout/ferr have been available since 1.9 (or 4e1ccd4c2b6d)
375 375 uifout.set(ui.ferr)
376 376
377 377 # get a formatter
378 378 uiformatter = getattr(ui, 'formatter', None)
379 379 if uiformatter:
380 380 fm = uiformatter(b'perf', opts)
381 381 else:
382 382 # for "historical portability":
383 383 # define formatter locally, because ui.formatter has been
384 384 # available since 2.2 (or ae5f92e154d3)
385 385 from mercurial import node
386 386
387 387 class defaultformatter(object):
388 388 """Minimized composition of baseformatter and plainformatter
389 389 """
390 390
391 391 def __init__(self, ui, topic, opts):
392 392 self._ui = ui
393 393 if ui.debugflag:
394 394 self.hexfunc = node.hex
395 395 else:
396 396 self.hexfunc = node.short
397 397
398 398 def __nonzero__(self):
399 399 return False
400 400
401 401 __bool__ = __nonzero__
402 402
403 403 def startitem(self):
404 404 pass
405 405
406 406 def data(self, **data):
407 407 pass
408 408
409 409 def write(self, fields, deftext, *fielddata, **opts):
410 410 self._ui.write(deftext % fielddata, **opts)
411 411
412 412 def condwrite(self, cond, fields, deftext, *fielddata, **opts):
413 413 if cond:
414 414 self._ui.write(deftext % fielddata, **opts)
415 415
416 416 def plain(self, text, **opts):
417 417 self._ui.write(text, **opts)
418 418
419 419 def end(self):
420 420 pass
421 421
422 422 fm = defaultformatter(ui, b'perf', opts)
423 423
424 424 # stub function, runs code only once instead of in a loop
425 425 # experimental config: perf.stub
426 426 if ui.configbool(b"perf", b"stub", False):
427 427 return functools.partial(stub_timer, fm), fm
428 428
429 429 # experimental config: perf.all-timing
430 430 displayall = ui.configbool(b"perf", b"all-timing", False)
431 431
432 432 # experimental config: perf.run-limits
433 433 limitspec = ui.configlist(b"perf", b"run-limits", [])
434 434 limits = []
435 435 for item in limitspec:
436 436 parts = item.split(b'-', 1)
437 437 if len(parts) < 2:
438 438 ui.warn((b'malformatted run limit entry, missing "-": %s\n' % item))
439 439 continue
440 440 try:
441 441 time_limit = float(_sysstr(parts[0]))
442 442 except ValueError as e:
443 443 ui.warn(
444 444 (
445 445 b'malformatted run limit entry, %s: %s\n'
446 446 % (_bytestr(e), item)
447 447 )
448 448 )
449 449 continue
450 450 try:
451 451 run_limit = int(_sysstr(parts[1]))
452 452 except ValueError as e:
453 453 ui.warn(
454 454 (
455 455 b'malformatted run limit entry, %s: %s\n'
456 456 % (_bytestr(e), item)
457 457 )
458 458 )
459 459 continue
460 460 limits.append((time_limit, run_limit))
461 461 if not limits:
462 462 limits = DEFAULTLIMITS
463 463
464 464 profiler = None
465 465 if profiling is not None:
466 466 if ui.configbool(b"perf", b"profile-benchmark", False):
467 467 profiler = profiling.profile(ui)
468 468
469 469 prerun = getint(ui, b"perf", b"pre-run", 0)
470 470 t = functools.partial(
471 471 _timer,
472 472 fm,
473 473 displayall=displayall,
474 474 limits=limits,
475 475 prerun=prerun,
476 476 profiler=profiler,
477 477 )
478 478 return t, fm
479 479
480 480
481 481 def stub_timer(fm, func, setup=None, title=None):
482 482 if setup is not None:
483 483 setup()
484 484 func()
485 485
486 486
487 487 @contextlib.contextmanager
488 488 def timeone():
489 489 r = []
490 490 ostart = os.times()
491 491 cstart = util.timer()
492 492 yield r
493 493 cstop = util.timer()
494 494 ostop = os.times()
495 495 a, b = ostart, ostop
496 496 r.append((cstop - cstart, b[0] - a[0], b[1] - a[1]))
497 497
498 498
499 499 # list of stop condition (elapsed time, minimal run count)
500 500 DEFAULTLIMITS = (
501 501 (3.0, 100),
502 502 (10.0, 3),
503 503 )
504 504
505 505
506 506 def _timer(
507 507 fm,
508 508 func,
509 509 setup=None,
510 510 title=None,
511 511 displayall=False,
512 512 limits=DEFAULTLIMITS,
513 513 prerun=0,
514 514 profiler=None,
515 515 ):
516 516 gc.collect()
517 517 results = []
518 518 begin = util.timer()
519 519 count = 0
520 520 if profiler is None:
521 521 profiler = NOOPCTX
522 522 for i in range(prerun):
523 523 if setup is not None:
524 524 setup()
525 525 func()
526 526 keepgoing = True
527 527 while keepgoing:
528 528 if setup is not None:
529 529 setup()
530 530 with profiler:
531 531 with timeone() as item:
532 532 r = func()
533 533 profiler = NOOPCTX
534 534 count += 1
535 535 results.append(item[0])
536 536 cstop = util.timer()
537 537 # Look for a stop condition.
538 538 elapsed = cstop - begin
539 539 for t, mincount in limits:
540 540 if elapsed >= t and count >= mincount:
541 541 keepgoing = False
542 542 break
543 543
544 544 formatone(fm, results, title=title, result=r, displayall=displayall)
545 545
546 546
547 547 def formatone(fm, timings, title=None, result=None, displayall=False):
548 548
549 549 count = len(timings)
550 550
551 551 fm.startitem()
552 552
553 553 if title:
554 554 fm.write(b'title', b'! %s\n', title)
555 555 if result:
556 556 fm.write(b'result', b'! result: %s\n', result)
557 557
558 558 def display(role, entry):
559 559 prefix = b''
560 560 if role != b'best':
561 561 prefix = b'%s.' % role
562 562 fm.plain(b'!')
563 563 fm.write(prefix + b'wall', b' wall %f', entry[0])
564 564 fm.write(prefix + b'comb', b' comb %f', entry[1] + entry[2])
565 565 fm.write(prefix + b'user', b' user %f', entry[1])
566 566 fm.write(prefix + b'sys', b' sys %f', entry[2])
567 567 fm.write(prefix + b'count', b' (%s of %%d)' % role, count)
568 568 fm.plain(b'\n')
569 569
570 570 timings.sort()
571 571 min_val = timings[0]
572 572 display(b'best', min_val)
573 573 if displayall:
574 574 max_val = timings[-1]
575 575 display(b'max', max_val)
576 576 avg = tuple([sum(x) / count for x in zip(*timings)])
577 577 display(b'avg', avg)
578 578 median = timings[len(timings) // 2]
579 579 display(b'median', median)
580 580
581 581
582 582 # utilities for historical portability
583 583
584 584
585 585 def getint(ui, section, name, default):
586 586 # for "historical portability":
587 587 # ui.configint has been available since 1.9 (or fa2b596db182)
588 588 v = ui.config(section, name, None)
589 589 if v is None:
590 590 return default
591 591 try:
592 592 return int(v)
593 593 except ValueError:
594 594 raise error.ConfigError(
595 595 b"%s.%s is not an integer ('%s')" % (section, name, v)
596 596 )
597 597
598 598
599 599 def safeattrsetter(obj, name, ignoremissing=False):
600 600 """Ensure that 'obj' has 'name' attribute before subsequent setattr
601 601
602 602 This function is aborted, if 'obj' doesn't have 'name' attribute
603 603 at runtime. This avoids overlooking removal of an attribute, which
604 604 breaks assumption of performance measurement, in the future.
605 605
606 606 This function returns the object to (1) assign a new value, and
607 607 (2) restore an original value to the attribute.
608 608
609 609 If 'ignoremissing' is true, missing 'name' attribute doesn't cause
610 610 abortion, and this function returns None. This is useful to
611 611 examine an attribute, which isn't ensured in all Mercurial
612 612 versions.
613 613 """
614 614 if not util.safehasattr(obj, name):
615 615 if ignoremissing:
616 616 return None
617 617 raise error.Abort(
618 618 (
619 619 b"missing attribute %s of %s might break assumption"
620 620 b" of performance measurement"
621 621 )
622 622 % (name, obj)
623 623 )
624 624
625 625 origvalue = getattr(obj, _sysstr(name))
626 626
627 627 class attrutil(object):
628 628 def set(self, newvalue):
629 629 setattr(obj, _sysstr(name), newvalue)
630 630
631 631 def restore(self):
632 632 setattr(obj, _sysstr(name), origvalue)
633 633
634 634 return attrutil()
635 635
636 636
637 637 # utilities to examine each internal API changes
638 638
639 639
640 640 def getbranchmapsubsettable():
641 641 # for "historical portability":
642 642 # subsettable is defined in:
643 643 # - branchmap since 2.9 (or 175c6fd8cacc)
644 644 # - repoview since 2.5 (or 59a9f18d4587)
645 645 # - repoviewutil since 5.0
646 646 for mod in (branchmap, repoview, repoviewutil):
647 647 subsettable = getattr(mod, 'subsettable', None)
648 648 if subsettable:
649 649 return subsettable
650 650
651 651 # bisecting in bcee63733aad::59a9f18d4587 can reach here (both
652 652 # branchmap and repoview modules exist, but subsettable attribute
653 653 # doesn't)
654 654 raise error.Abort(
655 655 b"perfbranchmap not available with this Mercurial",
656 656 hint=b"use 2.5 or later",
657 657 )
658 658
659 659
660 660 def getsvfs(repo):
661 661 """Return appropriate object to access files under .hg/store
662 662 """
663 663 # for "historical portability":
664 664 # repo.svfs has been available since 2.3 (or 7034365089bf)
665 665 svfs = getattr(repo, 'svfs', None)
666 666 if svfs:
667 667 return svfs
668 668 else:
669 669 return getattr(repo, 'sopener')
670 670
671 671
672 672 def getvfs(repo):
673 673 """Return appropriate object to access files under .hg
674 674 """
675 675 # for "historical portability":
676 676 # repo.vfs has been available since 2.3 (or 7034365089bf)
677 677 vfs = getattr(repo, 'vfs', None)
678 678 if vfs:
679 679 return vfs
680 680 else:
681 681 return getattr(repo, 'opener')
682 682
683 683
684 684 def repocleartagscachefunc(repo):
685 685 """Return the function to clear tags cache according to repo internal API
686 686 """
687 687 if util.safehasattr(repo, b'_tagscache'): # since 2.0 (or 9dca7653b525)
688 688 # in this case, setattr(repo, '_tagscache', None) or so isn't
689 689 # correct way to clear tags cache, because existing code paths
690 690 # expect _tagscache to be a structured object.
691 691 def clearcache():
692 692 # _tagscache has been filteredpropertycache since 2.5 (or
693 693 # 98c867ac1330), and delattr() can't work in such case
694 694 if '_tagscache' in vars(repo):
695 695 del repo.__dict__['_tagscache']
696 696
697 697 return clearcache
698 698
699 699 repotags = safeattrsetter(repo, b'_tags', ignoremissing=True)
700 700 if repotags: # since 1.4 (or 5614a628d173)
701 701 return lambda: repotags.set(None)
702 702
703 703 repotagscache = safeattrsetter(repo, b'tagscache', ignoremissing=True)
704 704 if repotagscache: # since 0.6 (or d7df759d0e97)
705 705 return lambda: repotagscache.set(None)
706 706
707 707 # Mercurial earlier than 0.6 (or d7df759d0e97) logically reaches
708 708 # this point, but it isn't so problematic, because:
709 709 # - repo.tags of such Mercurial isn't "callable", and repo.tags()
710 710 # in perftags() causes failure soon
711 711 # - perf.py itself has been available since 1.1 (or eb240755386d)
712 712 raise error.Abort(b"tags API of this hg command is unknown")
713 713
714 714
715 715 # utilities to clear cache
716 716
717 717
718 718 def clearfilecache(obj, attrname):
719 719 unfiltered = getattr(obj, 'unfiltered', None)
720 720 if unfiltered is not None:
721 721 obj = obj.unfiltered()
722 722 if attrname in vars(obj):
723 723 delattr(obj, attrname)
724 724 obj._filecache.pop(attrname, None)
725 725
726 726
727 727 def clearchangelog(repo):
728 728 if repo is not repo.unfiltered():
729 729 object.__setattr__(repo, '_clcachekey', None)
730 730 object.__setattr__(repo, '_clcache', None)
731 731 clearfilecache(repo.unfiltered(), 'changelog')
732 732
733 733
734 734 # perf commands
735 735
736 736
737 737 @command(b'perfwalk', formatteropts)
738 738 def perfwalk(ui, repo, *pats, **opts):
739 739 opts = _byteskwargs(opts)
740 740 timer, fm = gettimer(ui, opts)
741 741 m = scmutil.match(repo[None], pats, {})
742 742 timer(
743 743 lambda: len(
744 744 list(
745 745 repo.dirstate.walk(m, subrepos=[], unknown=True, ignored=False)
746 746 )
747 747 )
748 748 )
749 749 fm.end()
750 750
751 751
752 752 @command(b'perfannotate', formatteropts)
753 753 def perfannotate(ui, repo, f, **opts):
754 754 opts = _byteskwargs(opts)
755 755 timer, fm = gettimer(ui, opts)
756 756 fc = repo[b'.'][f]
757 757 timer(lambda: len(fc.annotate(True)))
758 758 fm.end()
759 759
760 760
761 761 @command(
762 762 b'perfstatus',
763 763 [
764 764 (b'u', b'unknown', False, b'ask status to look for unknown files'),
765 765 (b'', b'dirstate', False, b'benchmark the internal dirstate call'),
766 766 ]
767 767 + formatteropts,
768 768 )
769 769 def perfstatus(ui, repo, **opts):
770 770 """benchmark the performance of a single status call
771 771
772 772 The repository data are preserved between each call.
773 773
774 774 By default, only the status of the tracked file are requested. If
775 775 `--unknown` is passed, the "unknown" files are also tracked.
776 776 """
777 777 opts = _byteskwargs(opts)
778 778 # m = match.always(repo.root, repo.getcwd())
779 779 # timer(lambda: sum(map(len, repo.dirstate.status(m, [], False, False,
780 780 # False))))
781 781 timer, fm = gettimer(ui, opts)
782 782 if opts[b'dirstate']:
783 783 dirstate = repo.dirstate
784 784 m = scmutil.matchall(repo)
785 785 unknown = opts[b'unknown']
786 786
787 787 def status_dirstate():
788 788 s = dirstate.status(
789 789 m, subrepos=[], ignored=False, clean=False, unknown=unknown
790 790 )
791 791 sum(map(bool, s))
792 792
793 793 timer(status_dirstate)
794 794 else:
795 795 timer(lambda: sum(map(len, repo.status(unknown=opts[b'unknown']))))
796 796 fm.end()
797 797
798 798
799 799 @command(b'perfaddremove', formatteropts)
800 800 def perfaddremove(ui, repo, **opts):
801 801 opts = _byteskwargs(opts)
802 802 timer, fm = gettimer(ui, opts)
803 803 try:
804 804 oldquiet = repo.ui.quiet
805 805 repo.ui.quiet = True
806 806 matcher = scmutil.match(repo[None])
807 807 opts[b'dry_run'] = True
808 808 if b'uipathfn' in getargspec(scmutil.addremove).args:
809 809 uipathfn = scmutil.getuipathfn(repo)
810 810 timer(lambda: scmutil.addremove(repo, matcher, b"", uipathfn, opts))
811 811 else:
812 812 timer(lambda: scmutil.addremove(repo, matcher, b"", opts))
813 813 finally:
814 814 repo.ui.quiet = oldquiet
815 815 fm.end()
816 816
817 817
818 818 def clearcaches(cl):
819 819 # behave somewhat consistently across internal API changes
820 820 if util.safehasattr(cl, b'clearcaches'):
821 821 cl.clearcaches()
822 822 elif util.safehasattr(cl, b'_nodecache'):
823 823 # <= hg-5.2
824 824 from mercurial.node import nullid, nullrev
825 825
826 826 cl._nodecache = {nullid: nullrev}
827 827 cl._nodepos = None
828 828
829 829
830 830 @command(b'perfheads', formatteropts)
831 831 def perfheads(ui, repo, **opts):
832 832 """benchmark the computation of a changelog heads"""
833 833 opts = _byteskwargs(opts)
834 834 timer, fm = gettimer(ui, opts)
835 835 cl = repo.changelog
836 836
837 837 def s():
838 838 clearcaches(cl)
839 839
840 840 def d():
841 841 len(cl.headrevs())
842 842
843 843 timer(d, setup=s)
844 844 fm.end()
845 845
846 846
847 847 @command(
848 848 b'perftags',
849 849 formatteropts
850 850 + [(b'', b'clear-revlogs', False, b'refresh changelog and manifest'),],
851 851 )
852 852 def perftags(ui, repo, **opts):
853 853 opts = _byteskwargs(opts)
854 854 timer, fm = gettimer(ui, opts)
855 855 repocleartagscache = repocleartagscachefunc(repo)
856 856 clearrevlogs = opts[b'clear_revlogs']
857 857
858 858 def s():
859 859 if clearrevlogs:
860 860 clearchangelog(repo)
861 861 clearfilecache(repo.unfiltered(), 'manifest')
862 862 repocleartagscache()
863 863
864 864 def t():
865 865 return len(repo.tags())
866 866
867 867 timer(t, setup=s)
868 868 fm.end()
869 869
870 870
871 871 @command(b'perfancestors', formatteropts)
872 872 def perfancestors(ui, repo, **opts):
873 873 opts = _byteskwargs(opts)
874 874 timer, fm = gettimer(ui, opts)
875 875 heads = repo.changelog.headrevs()
876 876
877 877 def d():
878 878 for a in repo.changelog.ancestors(heads):
879 879 pass
880 880
881 881 timer(d)
882 882 fm.end()
883 883
884 884
885 885 @command(b'perfancestorset', formatteropts)
886 886 def perfancestorset(ui, repo, revset, **opts):
887 887 opts = _byteskwargs(opts)
888 888 timer, fm = gettimer(ui, opts)
889 889 revs = repo.revs(revset)
890 890 heads = repo.changelog.headrevs()
891 891
892 892 def d():
893 893 s = repo.changelog.ancestors(heads)
894 894 for rev in revs:
895 895 rev in s
896 896
897 897 timer(d)
898 898 fm.end()
899 899
900 900
901 901 @command(b'perfdiscovery', formatteropts, b'PATH')
902 902 def perfdiscovery(ui, repo, path, **opts):
903 903 """benchmark discovery between local repo and the peer at given path
904 904 """
905 905 repos = [repo, None]
906 906 timer, fm = gettimer(ui, opts)
907 907 path = ui.expandpath(path)
908 908
909 909 def s():
910 910 repos[1] = hg.peer(ui, opts, path)
911 911
912 912 def d():
913 913 setdiscovery.findcommonheads(ui, *repos)
914 914
915 915 timer(d, setup=s)
916 916 fm.end()
917 917
918 918
919 919 @command(
920 920 b'perfbookmarks',
921 921 formatteropts
922 922 + [(b'', b'clear-revlogs', False, b'refresh changelog and manifest'),],
923 923 )
924 924 def perfbookmarks(ui, repo, **opts):
925 925 """benchmark parsing bookmarks from disk to memory"""
926 926 opts = _byteskwargs(opts)
927 927 timer, fm = gettimer(ui, opts)
928 928
929 929 clearrevlogs = opts[b'clear_revlogs']
930 930
931 931 def s():
932 932 if clearrevlogs:
933 933 clearchangelog(repo)
934 934 clearfilecache(repo, b'_bookmarks')
935 935
936 936 def d():
937 937 repo._bookmarks
938 938
939 939 timer(d, setup=s)
940 940 fm.end()
941 941
942 942
943 943 @command(b'perfbundleread', formatteropts, b'BUNDLE')
944 944 def perfbundleread(ui, repo, bundlepath, **opts):
945 945 """Benchmark reading of bundle files.
946 946
947 947 This command is meant to isolate the I/O part of bundle reading as
948 948 much as possible.
949 949 """
950 950 from mercurial import (
951 951 bundle2,
952 952 exchange,
953 953 streamclone,
954 954 )
955 955
956 956 opts = _byteskwargs(opts)
957 957
958 958 def makebench(fn):
959 959 def run():
960 960 with open(bundlepath, b'rb') as fh:
961 961 bundle = exchange.readbundle(ui, fh, bundlepath)
962 962 fn(bundle)
963 963
964 964 return run
965 965
966 966 def makereadnbytes(size):
967 967 def run():
968 968 with open(bundlepath, b'rb') as fh:
969 969 bundle = exchange.readbundle(ui, fh, bundlepath)
970 970 while bundle.read(size):
971 971 pass
972 972
973 973 return run
974 974
975 975 def makestdioread(size):
976 976 def run():
977 977 with open(bundlepath, b'rb') as fh:
978 978 while fh.read(size):
979 979 pass
980 980
981 981 return run
982 982
983 983 # bundle1
984 984
985 985 def deltaiter(bundle):
986 986 for delta in bundle.deltaiter():
987 987 pass
988 988
989 989 def iterchunks(bundle):
990 990 for chunk in bundle.getchunks():
991 991 pass
992 992
993 993 # bundle2
994 994
995 995 def forwardchunks(bundle):
996 996 for chunk in bundle._forwardchunks():
997 997 pass
998 998
999 999 def iterparts(bundle):
1000 1000 for part in bundle.iterparts():
1001 1001 pass
1002 1002
1003 1003 def iterpartsseekable(bundle):
1004 1004 for part in bundle.iterparts(seekable=True):
1005 1005 pass
1006 1006
1007 1007 def seek(bundle):
1008 1008 for part in bundle.iterparts(seekable=True):
1009 1009 part.seek(0, os.SEEK_END)
1010 1010
1011 1011 def makepartreadnbytes(size):
1012 1012 def run():
1013 1013 with open(bundlepath, b'rb') as fh:
1014 1014 bundle = exchange.readbundle(ui, fh, bundlepath)
1015 1015 for part in bundle.iterparts():
1016 1016 while part.read(size):
1017 1017 pass
1018 1018
1019 1019 return run
1020 1020
1021 1021 benches = [
1022 1022 (makestdioread(8192), b'read(8k)'),
1023 1023 (makestdioread(16384), b'read(16k)'),
1024 1024 (makestdioread(32768), b'read(32k)'),
1025 1025 (makestdioread(131072), b'read(128k)'),
1026 1026 ]
1027 1027
1028 1028 with open(bundlepath, b'rb') as fh:
1029 1029 bundle = exchange.readbundle(ui, fh, bundlepath)
1030 1030
1031 1031 if isinstance(bundle, changegroup.cg1unpacker):
1032 1032 benches.extend(
1033 1033 [
1034 1034 (makebench(deltaiter), b'cg1 deltaiter()'),
1035 1035 (makebench(iterchunks), b'cg1 getchunks()'),
1036 1036 (makereadnbytes(8192), b'cg1 read(8k)'),
1037 1037 (makereadnbytes(16384), b'cg1 read(16k)'),
1038 1038 (makereadnbytes(32768), b'cg1 read(32k)'),
1039 1039 (makereadnbytes(131072), b'cg1 read(128k)'),
1040 1040 ]
1041 1041 )
1042 1042 elif isinstance(bundle, bundle2.unbundle20):
1043 1043 benches.extend(
1044 1044 [
1045 1045 (makebench(forwardchunks), b'bundle2 forwardchunks()'),
1046 1046 (makebench(iterparts), b'bundle2 iterparts()'),
1047 1047 (
1048 1048 makebench(iterpartsseekable),
1049 1049 b'bundle2 iterparts() seekable',
1050 1050 ),
1051 1051 (makebench(seek), b'bundle2 part seek()'),
1052 1052 (makepartreadnbytes(8192), b'bundle2 part read(8k)'),
1053 1053 (makepartreadnbytes(16384), b'bundle2 part read(16k)'),
1054 1054 (makepartreadnbytes(32768), b'bundle2 part read(32k)'),
1055 1055 (makepartreadnbytes(131072), b'bundle2 part read(128k)'),
1056 1056 ]
1057 1057 )
1058 1058 elif isinstance(bundle, streamclone.streamcloneapplier):
1059 1059 raise error.Abort(b'stream clone bundles not supported')
1060 1060 else:
1061 1061 raise error.Abort(b'unhandled bundle type: %s' % type(bundle))
1062 1062
1063 1063 for fn, title in benches:
1064 1064 timer, fm = gettimer(ui, opts)
1065 1065 timer(fn, title=title)
1066 1066 fm.end()
1067 1067
1068 1068
1069 1069 @command(
1070 1070 b'perfchangegroupchangelog',
1071 1071 formatteropts
1072 1072 + [
1073 1073 (b'', b'cgversion', b'02', b'changegroup version'),
1074 1074 (b'r', b'rev', b'', b'revisions to add to changegroup'),
1075 1075 ],
1076 1076 )
1077 1077 def perfchangegroupchangelog(ui, repo, cgversion=b'02', rev=None, **opts):
1078 1078 """Benchmark producing a changelog group for a changegroup.
1079 1079
1080 1080 This measures the time spent processing the changelog during a
1081 1081 bundle operation. This occurs during `hg bundle` and on a server
1082 1082 processing a `getbundle` wire protocol request (handles clones
1083 1083 and pull requests).
1084 1084
1085 1085 By default, all revisions are added to the changegroup.
1086 1086 """
1087 1087 opts = _byteskwargs(opts)
1088 1088 cl = repo.changelog
1089 1089 nodes = [cl.lookup(r) for r in repo.revs(rev or b'all()')]
1090 1090 bundler = changegroup.getbundler(cgversion, repo)
1091 1091
1092 1092 def d():
1093 1093 state, chunks = bundler._generatechangelog(cl, nodes)
1094 1094 for chunk in chunks:
1095 1095 pass
1096 1096
1097 1097 timer, fm = gettimer(ui, opts)
1098 1098
1099 1099 # Terminal printing can interfere with timing. So disable it.
1100 1100 with ui.configoverride({(b'progress', b'disable'): True}):
1101 1101 timer(d)
1102 1102
1103 1103 fm.end()
1104 1104
1105 1105
1106 1106 @command(b'perfdirs', formatteropts)
1107 1107 def perfdirs(ui, repo, **opts):
1108 1108 opts = _byteskwargs(opts)
1109 1109 timer, fm = gettimer(ui, opts)
1110 1110 dirstate = repo.dirstate
1111 1111 b'a' in dirstate
1112 1112
1113 1113 def d():
1114 1114 dirstate.hasdir(b'a')
1115 1115 del dirstate._map._dirs
1116 1116
1117 1117 timer(d)
1118 1118 fm.end()
1119 1119
1120 1120
1121 1121 @command(
1122 1122 b'perfdirstate',
1123 1123 [
1124 1124 (
1125 1125 b'',
1126 1126 b'iteration',
1127 1127 None,
1128 1128 b'benchmark a full iteration for the dirstate',
1129 1129 ),
1130 1130 (
1131 1131 b'',
1132 1132 b'contains',
1133 1133 None,
1134 1134 b'benchmark a large amount of `nf in dirstate` calls',
1135 1135 ),
1136 1136 ]
1137 1137 + formatteropts,
1138 1138 )
1139 1139 def perfdirstate(ui, repo, **opts):
1140 1140 """benchmap the time of various distate operations
1141 1141
1142 1142 By default benchmark the time necessary to load a dirstate from scratch.
1143 1143 The dirstate is loaded to the point were a "contains" request can be
1144 1144 answered.
1145 1145 """
1146 1146 opts = _byteskwargs(opts)
1147 1147 timer, fm = gettimer(ui, opts)
1148 1148 b"a" in repo.dirstate
1149 1149
1150 1150 if opts[b'iteration'] and opts[b'contains']:
1151 1151 msg = b'only specify one of --iteration or --contains'
1152 1152 raise error.Abort(msg)
1153 1153
1154 1154 if opts[b'iteration']:
1155 1155 setup = None
1156 1156 dirstate = repo.dirstate
1157 1157
1158 1158 def d():
1159 1159 for f in dirstate:
1160 1160 pass
1161 1161
1162 1162 elif opts[b'contains']:
1163 1163 setup = None
1164 1164 dirstate = repo.dirstate
1165 1165 allfiles = list(dirstate)
1166 1166 # also add file path that will be "missing" from the dirstate
1167 1167 allfiles.extend([f[::-1] for f in allfiles])
1168 1168
1169 1169 def d():
1170 1170 for f in allfiles:
1171 1171 f in dirstate
1172 1172
1173 1173 else:
1174 1174
1175 1175 def setup():
1176 1176 repo.dirstate.invalidate()
1177 1177
1178 1178 def d():
1179 1179 b"a" in repo.dirstate
1180 1180
1181 1181 timer(d, setup=setup)
1182 1182 fm.end()
1183 1183
1184 1184
1185 1185 @command(b'perfdirstatedirs', formatteropts)
1186 1186 def perfdirstatedirs(ui, repo, **opts):
1187 1187 """benchmap a 'dirstate.hasdir' call from an empty `dirs` cache
1188 1188 """
1189 1189 opts = _byteskwargs(opts)
1190 1190 timer, fm = gettimer(ui, opts)
1191 1191 repo.dirstate.hasdir(b"a")
1192 1192
1193 1193 def setup():
1194 1194 del repo.dirstate._map._dirs
1195 1195
1196 1196 def d():
1197 1197 repo.dirstate.hasdir(b"a")
1198 1198
1199 1199 timer(d, setup=setup)
1200 1200 fm.end()
1201 1201
1202 1202
1203 1203 @command(b'perfdirstatefoldmap', formatteropts)
1204 1204 def perfdirstatefoldmap(ui, repo, **opts):
1205 1205 """benchmap a `dirstate._map.filefoldmap.get()` request
1206 1206
1207 1207 The dirstate filefoldmap cache is dropped between every request.
1208 1208 """
1209 1209 opts = _byteskwargs(opts)
1210 1210 timer, fm = gettimer(ui, opts)
1211 1211 dirstate = repo.dirstate
1212 1212 dirstate._map.filefoldmap.get(b'a')
1213 1213
1214 1214 def setup():
1215 1215 del dirstate._map.filefoldmap
1216 1216
1217 1217 def d():
1218 1218 dirstate._map.filefoldmap.get(b'a')
1219 1219
1220 1220 timer(d, setup=setup)
1221 1221 fm.end()
1222 1222
1223 1223
1224 1224 @command(b'perfdirfoldmap', formatteropts)
1225 1225 def perfdirfoldmap(ui, repo, **opts):
1226 1226 """benchmap a `dirstate._map.dirfoldmap.get()` request
1227 1227
1228 1228 The dirstate dirfoldmap cache is dropped between every request.
1229 1229 """
1230 1230 opts = _byteskwargs(opts)
1231 1231 timer, fm = gettimer(ui, opts)
1232 1232 dirstate = repo.dirstate
1233 1233 dirstate._map.dirfoldmap.get(b'a')
1234 1234
1235 1235 def setup():
1236 1236 del dirstate._map.dirfoldmap
1237 1237 del dirstate._map._dirs
1238 1238
1239 1239 def d():
1240 1240 dirstate._map.dirfoldmap.get(b'a')
1241 1241
1242 1242 timer(d, setup=setup)
1243 1243 fm.end()
1244 1244
1245 1245
1246 1246 @command(b'perfdirstatewrite', formatteropts)
1247 1247 def perfdirstatewrite(ui, repo, **opts):
1248 1248 """benchmap the time it take to write a dirstate on disk
1249 1249 """
1250 1250 opts = _byteskwargs(opts)
1251 1251 timer, fm = gettimer(ui, opts)
1252 1252 ds = repo.dirstate
1253 1253 b"a" in ds
1254 1254
1255 1255 def setup():
1256 1256 ds._dirty = True
1257 1257
1258 1258 def d():
1259 1259 ds.write(repo.currenttransaction())
1260 1260
1261 1261 timer(d, setup=setup)
1262 1262 fm.end()
1263 1263
1264 1264
1265 1265 def _getmergerevs(repo, opts):
1266 1266 """parse command argument to return rev involved in merge
1267 1267
1268 1268 input: options dictionnary with `rev`, `from` and `bse`
1269 1269 output: (localctx, otherctx, basectx)
1270 1270 """
1271 1271 if opts[b'from']:
1272 1272 fromrev = scmutil.revsingle(repo, opts[b'from'])
1273 1273 wctx = repo[fromrev]
1274 1274 else:
1275 1275 wctx = repo[None]
1276 1276 # we don't want working dir files to be stat'd in the benchmark, so
1277 1277 # prime that cache
1278 1278 wctx.dirty()
1279 1279 rctx = scmutil.revsingle(repo, opts[b'rev'], opts[b'rev'])
1280 1280 if opts[b'base']:
1281 1281 fromrev = scmutil.revsingle(repo, opts[b'base'])
1282 1282 ancestor = repo[fromrev]
1283 1283 else:
1284 1284 ancestor = wctx.ancestor(rctx)
1285 1285 return (wctx, rctx, ancestor)
1286 1286
1287 1287
1288 1288 @command(
1289 1289 b'perfmergecalculate',
1290 1290 [
1291 1291 (b'r', b'rev', b'.', b'rev to merge against'),
1292 1292 (b'', b'from', b'', b'rev to merge from'),
1293 1293 (b'', b'base', b'', b'the revision to use as base'),
1294 1294 ]
1295 1295 + formatteropts,
1296 1296 )
1297 1297 def perfmergecalculate(ui, repo, **opts):
1298 1298 opts = _byteskwargs(opts)
1299 1299 timer, fm = gettimer(ui, opts)
1300 1300
1301 1301 wctx, rctx, ancestor = _getmergerevs(repo, opts)
1302 1302
1303 1303 def d():
1304 1304 # acceptremote is True because we don't want prompts in the middle of
1305 1305 # our benchmark
1306 1306 merge.calculateupdates(
1307 1307 repo,
1308 1308 wctx,
1309 1309 rctx,
1310 1310 [ancestor],
1311 1311 branchmerge=False,
1312 1312 force=False,
1313 1313 acceptremote=True,
1314 1314 followcopies=True,
1315 1315 )
1316 1316
1317 1317 timer(d)
1318 1318 fm.end()
1319 1319
1320 1320
1321 1321 @command(
1322 1322 b'perfmergecopies',
1323 1323 [
1324 1324 (b'r', b'rev', b'.', b'rev to merge against'),
1325 1325 (b'', b'from', b'', b'rev to merge from'),
1326 1326 (b'', b'base', b'', b'the revision to use as base'),
1327 1327 ]
1328 1328 + formatteropts,
1329 1329 )
1330 1330 def perfmergecopies(ui, repo, **opts):
1331 1331 """measure runtime of `copies.mergecopies`"""
1332 1332 opts = _byteskwargs(opts)
1333 1333 timer, fm = gettimer(ui, opts)
1334 1334 wctx, rctx, ancestor = _getmergerevs(repo, opts)
1335 1335
1336 1336 def d():
1337 1337 # acceptremote is True because we don't want prompts in the middle of
1338 1338 # our benchmark
1339 1339 copies.mergecopies(repo, wctx, rctx, ancestor)
1340 1340
1341 1341 timer(d)
1342 1342 fm.end()
1343 1343
1344 1344
1345 1345 @command(b'perfpathcopies', [], b"REV REV")
1346 1346 def perfpathcopies(ui, repo, rev1, rev2, **opts):
1347 1347 """benchmark the copy tracing logic"""
1348 1348 opts = _byteskwargs(opts)
1349 1349 timer, fm = gettimer(ui, opts)
1350 1350 ctx1 = scmutil.revsingle(repo, rev1, rev1)
1351 1351 ctx2 = scmutil.revsingle(repo, rev2, rev2)
1352 1352
1353 1353 def d():
1354 1354 copies.pathcopies(ctx1, ctx2)
1355 1355
1356 1356 timer(d)
1357 1357 fm.end()
1358 1358
1359 1359
1360 1360 @command(
1361 1361 b'perfphases',
1362 1362 [(b'', b'full', False, b'include file reading time too'),],
1363 1363 b"",
1364 1364 )
1365 1365 def perfphases(ui, repo, **opts):
1366 1366 """benchmark phasesets computation"""
1367 1367 opts = _byteskwargs(opts)
1368 1368 timer, fm = gettimer(ui, opts)
1369 1369 _phases = repo._phasecache
1370 1370 full = opts.get(b'full')
1371 1371
1372 1372 def d():
1373 1373 phases = _phases
1374 1374 if full:
1375 1375 clearfilecache(repo, b'_phasecache')
1376 1376 phases = repo._phasecache
1377 1377 phases.invalidate()
1378 1378 phases.loadphaserevs(repo)
1379 1379
1380 1380 timer(d)
1381 1381 fm.end()
1382 1382
1383 1383
1384 1384 @command(b'perfphasesremote', [], b"[DEST]")
1385 1385 def perfphasesremote(ui, repo, dest=None, **opts):
1386 1386 """benchmark time needed to analyse phases of the remote server"""
1387 1387 from mercurial.node import bin
1388 1388 from mercurial import (
1389 1389 exchange,
1390 1390 hg,
1391 1391 phases,
1392 1392 )
1393 1393
1394 1394 opts = _byteskwargs(opts)
1395 1395 timer, fm = gettimer(ui, opts)
1396 1396
1397 1397 path = ui.paths.getpath(dest, default=(b'default-push', b'default'))
1398 1398 if not path:
1399 1399 raise error.Abort(
1400 1400 b'default repository not configured!',
1401 1401 hint=b"see 'hg help config.paths'",
1402 1402 )
1403 1403 dest = path.pushloc or path.loc
1404 1404 ui.statusnoi18n(b'analysing phase of %s\n' % util.hidepassword(dest))
1405 1405 other = hg.peer(repo, opts, dest)
1406 1406
1407 1407 # easier to perform discovery through the operation
1408 1408 op = exchange.pushoperation(repo, other)
1409 1409 exchange._pushdiscoverychangeset(op)
1410 1410
1411 1411 remotesubset = op.fallbackheads
1412 1412
1413 1413 with other.commandexecutor() as e:
1414 1414 remotephases = e.callcommand(
1415 1415 b'listkeys', {b'namespace': b'phases'}
1416 1416 ).result()
1417 1417 del other
1418 1418 publishing = remotephases.get(b'publishing', False)
1419 1419 if publishing:
1420 1420 ui.statusnoi18n(b'publishing: yes\n')
1421 1421 else:
1422 1422 ui.statusnoi18n(b'publishing: no\n')
1423 1423
1424 1424 has_node = getattr(repo.changelog.index, 'has_node', None)
1425 1425 if has_node is None:
1426 1426 has_node = repo.changelog.nodemap.__contains__
1427 1427 nonpublishroots = 0
1428 1428 for nhex, phase in remotephases.iteritems():
1429 1429 if nhex == b'publishing': # ignore data related to publish option
1430 1430 continue
1431 1431 node = bin(nhex)
1432 1432 if has_node(node) and int(phase):
1433 1433 nonpublishroots += 1
1434 1434 ui.statusnoi18n(b'number of roots: %d\n' % len(remotephases))
1435 1435 ui.statusnoi18n(b'number of known non public roots: %d\n' % nonpublishroots)
1436 1436
1437 1437 def d():
1438 1438 phases.remotephasessummary(repo, remotesubset, remotephases)
1439 1439
1440 1440 timer(d)
1441 1441 fm.end()
1442 1442
1443 1443
1444 1444 @command(
1445 1445 b'perfmanifest',
1446 1446 [
1447 1447 (b'm', b'manifest-rev', False, b'Look up a manifest node revision'),
1448 1448 (b'', b'clear-disk', False, b'clear on-disk caches too'),
1449 1449 ]
1450 1450 + formatteropts,
1451 1451 b'REV|NODE',
1452 1452 )
1453 1453 def perfmanifest(ui, repo, rev, manifest_rev=False, clear_disk=False, **opts):
1454 1454 """benchmark the time to read a manifest from disk and return a usable
1455 1455 dict-like object
1456 1456
1457 1457 Manifest caches are cleared before retrieval."""
1458 1458 opts = _byteskwargs(opts)
1459 1459 timer, fm = gettimer(ui, opts)
1460 1460 if not manifest_rev:
1461 1461 ctx = scmutil.revsingle(repo, rev, rev)
1462 1462 t = ctx.manifestnode()
1463 1463 else:
1464 1464 from mercurial.node import bin
1465 1465
1466 1466 if len(rev) == 40:
1467 1467 t = bin(rev)
1468 1468 else:
1469 1469 try:
1470 1470 rev = int(rev)
1471 1471
1472 1472 if util.safehasattr(repo.manifestlog, b'getstorage'):
1473 1473 t = repo.manifestlog.getstorage(b'').node(rev)
1474 1474 else:
1475 1475 t = repo.manifestlog._revlog.lookup(rev)
1476 1476 except ValueError:
1477 1477 raise error.Abort(
1478 1478 b'manifest revision must be integer or full node'
1479 1479 )
1480 1480
1481 1481 def d():
1482 1482 repo.manifestlog.clearcaches(clear_persisted_data=clear_disk)
1483 1483 repo.manifestlog[t].read()
1484 1484
1485 1485 timer(d)
1486 1486 fm.end()
1487 1487
1488 1488
1489 1489 @command(b'perfchangeset', formatteropts)
1490 1490 def perfchangeset(ui, repo, rev, **opts):
1491 1491 opts = _byteskwargs(opts)
1492 1492 timer, fm = gettimer(ui, opts)
1493 1493 n = scmutil.revsingle(repo, rev).node()
1494 1494
1495 1495 def d():
1496 1496 repo.changelog.read(n)
1497 1497 # repo.changelog._cache = None
1498 1498
1499 1499 timer(d)
1500 1500 fm.end()
1501 1501
1502 1502
1503 1503 @command(b'perfignore', formatteropts)
1504 1504 def perfignore(ui, repo, **opts):
1505 1505 """benchmark operation related to computing ignore"""
1506 1506 opts = _byteskwargs(opts)
1507 1507 timer, fm = gettimer(ui, opts)
1508 1508 dirstate = repo.dirstate
1509 1509
1510 1510 def setupone():
1511 1511 dirstate.invalidate()
1512 1512 clearfilecache(dirstate, b'_ignore')
1513 1513
1514 1514 def runone():
1515 1515 dirstate._ignore
1516 1516
1517 1517 timer(runone, setup=setupone, title=b"load")
1518 1518 fm.end()
1519 1519
1520 1520
1521 1521 @command(
1522 1522 b'perfindex',
1523 1523 [
1524 1524 (b'', b'rev', [], b'revision to be looked up (default tip)'),
1525 1525 (b'', b'no-lookup', None, b'do not revision lookup post creation'),
1526 1526 ]
1527 1527 + formatteropts,
1528 1528 )
1529 1529 def perfindex(ui, repo, **opts):
1530 1530 """benchmark index creation time followed by a lookup
1531 1531
1532 1532 The default is to look `tip` up. Depending on the index implementation,
1533 1533 the revision looked up can matters. For example, an implementation
1534 1534 scanning the index will have a faster lookup time for `--rev tip` than for
1535 1535 `--rev 0`. The number of looked up revisions and their order can also
1536 1536 matters.
1537 1537
1538 1538 Example of useful set to test:
1539
1539 1540 * tip
1540 1541 * 0
1541 1542 * -10:
1542 1543 * :10
1543 1544 * -10: + :10
1544 1545 * :10: + -10:
1545 1546 * -10000:
1546 1547 * -10000: + 0
1547 1548
1548 1549 It is not currently possible to check for lookup of a missing node. For
1549 1550 deeper lookup benchmarking, checkout the `perfnodemap` command."""
1550 1551 import mercurial.revlog
1551 1552
1552 1553 opts = _byteskwargs(opts)
1553 1554 timer, fm = gettimer(ui, opts)
1554 1555 mercurial.revlog._prereadsize = 2 ** 24 # disable lazy parser in old hg
1555 1556 if opts[b'no_lookup']:
1556 1557 if opts['rev']:
1557 1558 raise error.Abort('--no-lookup and --rev are mutually exclusive')
1558 1559 nodes = []
1559 1560 elif not opts[b'rev']:
1560 1561 nodes = [repo[b"tip"].node()]
1561 1562 else:
1562 1563 revs = scmutil.revrange(repo, opts[b'rev'])
1563 1564 cl = repo.changelog
1564 1565 nodes = [cl.node(r) for r in revs]
1565 1566
1566 1567 unfi = repo.unfiltered()
1567 1568 # find the filecache func directly
1568 1569 # This avoid polluting the benchmark with the filecache logic
1569 1570 makecl = unfi.__class__.changelog.func
1570 1571
1571 1572 def setup():
1572 1573 # probably not necessary, but for good measure
1573 1574 clearchangelog(unfi)
1574 1575
1575 1576 def d():
1576 1577 cl = makecl(unfi)
1577 1578 for n in nodes:
1578 1579 cl.rev(n)
1579 1580
1580 1581 timer(d, setup=setup)
1581 1582 fm.end()
1582 1583
1583 1584
1584 1585 @command(
1585 1586 b'perfnodemap',
1586 1587 [
1587 1588 (b'', b'rev', [], b'revision to be looked up (default tip)'),
1588 1589 (b'', b'clear-caches', True, b'clear revlog cache between calls'),
1589 1590 ]
1590 1591 + formatteropts,
1591 1592 )
1592 1593 def perfnodemap(ui, repo, **opts):
1593 1594 """benchmark the time necessary to look up revision from a cold nodemap
1594 1595
1595 1596 Depending on the implementation, the amount and order of revision we look
1596 1597 up can varies. Example of useful set to test:
1597 1598 * tip
1598 1599 * 0
1599 1600 * -10:
1600 1601 * :10
1601 1602 * -10: + :10
1602 1603 * :10: + -10:
1603 1604 * -10000:
1604 1605 * -10000: + 0
1605 1606
1606 1607 The command currently focus on valid binary lookup. Benchmarking for
1607 1608 hexlookup, prefix lookup and missing lookup would also be valuable.
1608 1609 """
1609 1610 import mercurial.revlog
1610 1611
1611 1612 opts = _byteskwargs(opts)
1612 1613 timer, fm = gettimer(ui, opts)
1613 1614 mercurial.revlog._prereadsize = 2 ** 24 # disable lazy parser in old hg
1614 1615
1615 1616 unfi = repo.unfiltered()
1616 1617 clearcaches = opts['clear_caches']
1617 1618 # find the filecache func directly
1618 1619 # This avoid polluting the benchmark with the filecache logic
1619 1620 makecl = unfi.__class__.changelog.func
1620 1621 if not opts[b'rev']:
1621 1622 raise error.Abort('use --rev to specify revisions to look up')
1622 1623 revs = scmutil.revrange(repo, opts[b'rev'])
1623 1624 cl = repo.changelog
1624 1625 nodes = [cl.node(r) for r in revs]
1625 1626
1626 1627 # use a list to pass reference to a nodemap from one closure to the next
1627 1628 nodeget = [None]
1628 1629
1629 1630 def setnodeget():
1630 1631 # probably not necessary, but for good measure
1631 1632 clearchangelog(unfi)
1632 1633 cl = makecl(unfi)
1633 1634 if util.safehasattr(cl.index, 'get_rev'):
1634 1635 nodeget[0] = cl.index.get_rev
1635 1636 else:
1636 1637 nodeget[0] = cl.nodemap.get
1637 1638
1638 1639 def d():
1639 1640 get = nodeget[0]
1640 1641 for n in nodes:
1641 1642 get(n)
1642 1643
1643 1644 setup = None
1644 1645 if clearcaches:
1645 1646
1646 1647 def setup():
1647 1648 setnodeget()
1648 1649
1649 1650 else:
1650 1651 setnodeget()
1651 1652 d() # prewarm the data structure
1652 1653 timer(d, setup=setup)
1653 1654 fm.end()
1654 1655
1655 1656
1656 1657 @command(b'perfstartup', formatteropts)
1657 1658 def perfstartup(ui, repo, **opts):
1658 1659 opts = _byteskwargs(opts)
1659 1660 timer, fm = gettimer(ui, opts)
1660 1661
1661 1662 def d():
1662 1663 if os.name != 'nt':
1663 1664 os.system(
1664 1665 b"HGRCPATH= %s version -q > /dev/null" % fsencode(sys.argv[0])
1665 1666 )
1666 1667 else:
1667 1668 os.environ['HGRCPATH'] = r' '
1668 1669 os.system("%s version -q > NUL" % sys.argv[0])
1669 1670
1670 1671 timer(d)
1671 1672 fm.end()
1672 1673
1673 1674
1674 1675 @command(b'perfparents', formatteropts)
1675 1676 def perfparents(ui, repo, **opts):
1676 1677 """benchmark the time necessary to fetch one changeset's parents.
1677 1678
1678 1679 The fetch is done using the `node identifier`, traversing all object layers
1679 1680 from the repository object. The first N revisions will be used for this
1680 1681 benchmark. N is controlled by the ``perf.parentscount`` config option
1681 1682 (default: 1000).
1682 1683 """
1683 1684 opts = _byteskwargs(opts)
1684 1685 timer, fm = gettimer(ui, opts)
1685 1686 # control the number of commits perfparents iterates over
1686 1687 # experimental config: perf.parentscount
1687 1688 count = getint(ui, b"perf", b"parentscount", 1000)
1688 1689 if len(repo.changelog) < count:
1689 1690 raise error.Abort(b"repo needs %d commits for this test" % count)
1690 1691 repo = repo.unfiltered()
1691 1692 nl = [repo.changelog.node(i) for i in _xrange(count)]
1692 1693
1693 1694 def d():
1694 1695 for n in nl:
1695 1696 repo.changelog.parents(n)
1696 1697
1697 1698 timer(d)
1698 1699 fm.end()
1699 1700
1700 1701
1701 1702 @command(b'perfctxfiles', formatteropts)
1702 1703 def perfctxfiles(ui, repo, x, **opts):
1703 1704 opts = _byteskwargs(opts)
1704 1705 x = int(x)
1705 1706 timer, fm = gettimer(ui, opts)
1706 1707
1707 1708 def d():
1708 1709 len(repo[x].files())
1709 1710
1710 1711 timer(d)
1711 1712 fm.end()
1712 1713
1713 1714
1714 1715 @command(b'perfrawfiles', formatteropts)
1715 1716 def perfrawfiles(ui, repo, x, **opts):
1716 1717 opts = _byteskwargs(opts)
1717 1718 x = int(x)
1718 1719 timer, fm = gettimer(ui, opts)
1719 1720 cl = repo.changelog
1720 1721
1721 1722 def d():
1722 1723 len(cl.read(x)[3])
1723 1724
1724 1725 timer(d)
1725 1726 fm.end()
1726 1727
1727 1728
1728 1729 @command(b'perflookup', formatteropts)
1729 1730 def perflookup(ui, repo, rev, **opts):
1730 1731 opts = _byteskwargs(opts)
1731 1732 timer, fm = gettimer(ui, opts)
1732 1733 timer(lambda: len(repo.lookup(rev)))
1733 1734 fm.end()
1734 1735
1735 1736
1736 1737 @command(
1737 1738 b'perflinelogedits',
1738 1739 [
1739 1740 (b'n', b'edits', 10000, b'number of edits'),
1740 1741 (b'', b'max-hunk-lines', 10, b'max lines in a hunk'),
1741 1742 ],
1742 1743 norepo=True,
1743 1744 )
1744 1745 def perflinelogedits(ui, **opts):
1745 1746 from mercurial import linelog
1746 1747
1747 1748 opts = _byteskwargs(opts)
1748 1749
1749 1750 edits = opts[b'edits']
1750 1751 maxhunklines = opts[b'max_hunk_lines']
1751 1752
1752 1753 maxb1 = 100000
1753 1754 random.seed(0)
1754 1755 randint = random.randint
1755 1756 currentlines = 0
1756 1757 arglist = []
1757 1758 for rev in _xrange(edits):
1758 1759 a1 = randint(0, currentlines)
1759 1760 a2 = randint(a1, min(currentlines, a1 + maxhunklines))
1760 1761 b1 = randint(0, maxb1)
1761 1762 b2 = randint(b1, b1 + maxhunklines)
1762 1763 currentlines += (b2 - b1) - (a2 - a1)
1763 1764 arglist.append((rev, a1, a2, b1, b2))
1764 1765
1765 1766 def d():
1766 1767 ll = linelog.linelog()
1767 1768 for args in arglist:
1768 1769 ll.replacelines(*args)
1769 1770
1770 1771 timer, fm = gettimer(ui, opts)
1771 1772 timer(d)
1772 1773 fm.end()
1773 1774
1774 1775
1775 1776 @command(b'perfrevrange', formatteropts)
1776 1777 def perfrevrange(ui, repo, *specs, **opts):
1777 1778 opts = _byteskwargs(opts)
1778 1779 timer, fm = gettimer(ui, opts)
1779 1780 revrange = scmutil.revrange
1780 1781 timer(lambda: len(revrange(repo, specs)))
1781 1782 fm.end()
1782 1783
1783 1784
1784 1785 @command(b'perfnodelookup', formatteropts)
1785 1786 def perfnodelookup(ui, repo, rev, **opts):
1786 1787 opts = _byteskwargs(opts)
1787 1788 timer, fm = gettimer(ui, opts)
1788 1789 import mercurial.revlog
1789 1790
1790 1791 mercurial.revlog._prereadsize = 2 ** 24 # disable lazy parser in old hg
1791 1792 n = scmutil.revsingle(repo, rev).node()
1792 1793 cl = mercurial.revlog.revlog(getsvfs(repo), b"00changelog.i")
1793 1794
1794 1795 def d():
1795 1796 cl.rev(n)
1796 1797 clearcaches(cl)
1797 1798
1798 1799 timer(d)
1799 1800 fm.end()
1800 1801
1801 1802
1802 1803 @command(
1803 1804 b'perflog',
1804 1805 [(b'', b'rename', False, b'ask log to follow renames')] + formatteropts,
1805 1806 )
1806 1807 def perflog(ui, repo, rev=None, **opts):
1807 1808 opts = _byteskwargs(opts)
1808 1809 if rev is None:
1809 1810 rev = []
1810 1811 timer, fm = gettimer(ui, opts)
1811 1812 ui.pushbuffer()
1812 1813 timer(
1813 1814 lambda: commands.log(
1814 1815 ui, repo, rev=rev, date=b'', user=b'', copies=opts.get(b'rename')
1815 1816 )
1816 1817 )
1817 1818 ui.popbuffer()
1818 1819 fm.end()
1819 1820
1820 1821
1821 1822 @command(b'perfmoonwalk', formatteropts)
1822 1823 def perfmoonwalk(ui, repo, **opts):
1823 1824 """benchmark walking the changelog backwards
1824 1825
1825 1826 This also loads the changelog data for each revision in the changelog.
1826 1827 """
1827 1828 opts = _byteskwargs(opts)
1828 1829 timer, fm = gettimer(ui, opts)
1829 1830
1830 1831 def moonwalk():
1831 1832 for i in repo.changelog.revs(start=(len(repo) - 1), stop=-1):
1832 1833 ctx = repo[i]
1833 1834 ctx.branch() # read changelog data (in addition to the index)
1834 1835
1835 1836 timer(moonwalk)
1836 1837 fm.end()
1837 1838
1838 1839
1839 1840 @command(
1840 1841 b'perftemplating',
1841 1842 [(b'r', b'rev', [], b'revisions to run the template on'),] + formatteropts,
1842 1843 )
1843 1844 def perftemplating(ui, repo, testedtemplate=None, **opts):
1844 1845 """test the rendering time of a given template"""
1845 1846 if makelogtemplater is None:
1846 1847 raise error.Abort(
1847 1848 b"perftemplating not available with this Mercurial",
1848 1849 hint=b"use 4.3 or later",
1849 1850 )
1850 1851
1851 1852 opts = _byteskwargs(opts)
1852 1853
1853 1854 nullui = ui.copy()
1854 1855 nullui.fout = open(os.devnull, 'wb')
1855 1856 nullui.disablepager()
1856 1857 revs = opts.get(b'rev')
1857 1858 if not revs:
1858 1859 revs = [b'all()']
1859 1860 revs = list(scmutil.revrange(repo, revs))
1860 1861
1861 1862 defaulttemplate = (
1862 1863 b'{date|shortdate} [{rev}:{node|short}]'
1863 1864 b' {author|person}: {desc|firstline}\n'
1864 1865 )
1865 1866 if testedtemplate is None:
1866 1867 testedtemplate = defaulttemplate
1867 1868 displayer = makelogtemplater(nullui, repo, testedtemplate)
1868 1869
1869 1870 def format():
1870 1871 for r in revs:
1871 1872 ctx = repo[r]
1872 1873 displayer.show(ctx)
1873 1874 displayer.flush(ctx)
1874 1875
1875 1876 timer, fm = gettimer(ui, opts)
1876 1877 timer(format)
1877 1878 fm.end()
1878 1879
1879 1880
1880 1881 def _displaystats(ui, opts, entries, data):
1881 1882 # use a second formatter because the data are quite different, not sure
1882 1883 # how it flies with the templater.
1883 1884 fm = ui.formatter(b'perf-stats', opts)
1884 1885 for key, title in entries:
1885 1886 values = data[key]
1886 1887 nbvalues = len(data)
1887 1888 values.sort()
1888 1889 stats = {
1889 1890 'key': key,
1890 1891 'title': title,
1891 1892 'nbitems': len(values),
1892 1893 'min': values[0][0],
1893 1894 '10%': values[(nbvalues * 10) // 100][0],
1894 1895 '25%': values[(nbvalues * 25) // 100][0],
1895 1896 '50%': values[(nbvalues * 50) // 100][0],
1896 1897 '75%': values[(nbvalues * 75) // 100][0],
1897 1898 '80%': values[(nbvalues * 80) // 100][0],
1898 1899 '85%': values[(nbvalues * 85) // 100][0],
1899 1900 '90%': values[(nbvalues * 90) // 100][0],
1900 1901 '95%': values[(nbvalues * 95) // 100][0],
1901 1902 '99%': values[(nbvalues * 99) // 100][0],
1902 1903 'max': values[-1][0],
1903 1904 }
1904 1905 fm.startitem()
1905 1906 fm.data(**stats)
1906 1907 # make node pretty for the human output
1907 1908 fm.plain('### %s (%d items)\n' % (title, len(values)))
1908 1909 lines = [
1909 1910 'min',
1910 1911 '10%',
1911 1912 '25%',
1912 1913 '50%',
1913 1914 '75%',
1914 1915 '80%',
1915 1916 '85%',
1916 1917 '90%',
1917 1918 '95%',
1918 1919 '99%',
1919 1920 'max',
1920 1921 ]
1921 1922 for l in lines:
1922 1923 fm.plain('%s: %s\n' % (l, stats[l]))
1923 1924 fm.end()
1924 1925
1925 1926
1926 1927 @command(
1927 1928 b'perfhelper-mergecopies',
1928 1929 formatteropts
1929 1930 + [
1930 1931 (b'r', b'revs', [], b'restrict search to these revisions'),
1931 1932 (b'', b'timing', False, b'provides extra data (costly)'),
1932 1933 (b'', b'stats', False, b'provides statistic about the measured data'),
1933 1934 ],
1934 1935 )
1935 1936 def perfhelpermergecopies(ui, repo, revs=[], **opts):
1936 1937 """find statistics about potential parameters for `perfmergecopies`
1937 1938
1938 1939 This command find (base, p1, p2) triplet relevant for copytracing
1939 1940 benchmarking in the context of a merge. It reports values for some of the
1940 1941 parameters that impact merge copy tracing time during merge.
1941 1942
1942 1943 If `--timing` is set, rename detection is run and the associated timing
1943 1944 will be reported. The extra details come at the cost of slower command
1944 1945 execution.
1945 1946
1946 1947 Since rename detection is only run once, other factors might easily
1947 1948 affect the precision of the timing. However it should give a good
1948 1949 approximation of which revision triplets are very costly.
1949 1950 """
1950 1951 opts = _byteskwargs(opts)
1951 1952 fm = ui.formatter(b'perf', opts)
1952 1953 dotiming = opts[b'timing']
1953 1954 dostats = opts[b'stats']
1954 1955
1955 1956 output_template = [
1956 1957 ("base", "%(base)12s"),
1957 1958 ("p1", "%(p1.node)12s"),
1958 1959 ("p2", "%(p2.node)12s"),
1959 1960 ("p1.nb-revs", "%(p1.nbrevs)12d"),
1960 1961 ("p1.nb-files", "%(p1.nbmissingfiles)12d"),
1961 1962 ("p1.renames", "%(p1.renamedfiles)12d"),
1962 1963 ("p1.time", "%(p1.time)12.3f"),
1963 1964 ("p2.nb-revs", "%(p2.nbrevs)12d"),
1964 1965 ("p2.nb-files", "%(p2.nbmissingfiles)12d"),
1965 1966 ("p2.renames", "%(p2.renamedfiles)12d"),
1966 1967 ("p2.time", "%(p2.time)12.3f"),
1967 1968 ("renames", "%(nbrenamedfiles)12d"),
1968 1969 ("total.time", "%(time)12.3f"),
1969 1970 ]
1970 1971 if not dotiming:
1971 1972 output_template = [
1972 1973 i
1973 1974 for i in output_template
1974 1975 if not ('time' in i[0] or 'renames' in i[0])
1975 1976 ]
1976 1977 header_names = [h for (h, v) in output_template]
1977 1978 output = ' '.join([v for (h, v) in output_template]) + '\n'
1978 1979 header = ' '.join(['%12s'] * len(header_names)) + '\n'
1979 1980 fm.plain(header % tuple(header_names))
1980 1981
1981 1982 if not revs:
1982 1983 revs = ['all()']
1983 1984 revs = scmutil.revrange(repo, revs)
1984 1985
1985 1986 if dostats:
1986 1987 alldata = {
1987 1988 'nbrevs': [],
1988 1989 'nbmissingfiles': [],
1989 1990 }
1990 1991 if dotiming:
1991 1992 alldata['parentnbrenames'] = []
1992 1993 alldata['totalnbrenames'] = []
1993 1994 alldata['parenttime'] = []
1994 1995 alldata['totaltime'] = []
1995 1996
1996 1997 roi = repo.revs('merge() and %ld', revs)
1997 1998 for r in roi:
1998 1999 ctx = repo[r]
1999 2000 p1 = ctx.p1()
2000 2001 p2 = ctx.p2()
2001 2002 bases = repo.changelog._commonancestorsheads(p1.rev(), p2.rev())
2002 2003 for b in bases:
2003 2004 b = repo[b]
2004 2005 p1missing = copies._computeforwardmissing(b, p1)
2005 2006 p2missing = copies._computeforwardmissing(b, p2)
2006 2007 data = {
2007 2008 b'base': b.hex(),
2008 2009 b'p1.node': p1.hex(),
2009 2010 b'p1.nbrevs': len(repo.revs('only(%d, %d)', p1.rev(), b.rev())),
2010 2011 b'p1.nbmissingfiles': len(p1missing),
2011 2012 b'p2.node': p2.hex(),
2012 2013 b'p2.nbrevs': len(repo.revs('only(%d, %d)', p2.rev(), b.rev())),
2013 2014 b'p2.nbmissingfiles': len(p2missing),
2014 2015 }
2015 2016 if dostats:
2016 2017 if p1missing:
2017 2018 alldata['nbrevs'].append(
2018 2019 (data['p1.nbrevs'], b.hex(), p1.hex())
2019 2020 )
2020 2021 alldata['nbmissingfiles'].append(
2021 2022 (data['p1.nbmissingfiles'], b.hex(), p1.hex())
2022 2023 )
2023 2024 if p2missing:
2024 2025 alldata['nbrevs'].append(
2025 2026 (data['p2.nbrevs'], b.hex(), p2.hex())
2026 2027 )
2027 2028 alldata['nbmissingfiles'].append(
2028 2029 (data['p2.nbmissingfiles'], b.hex(), p2.hex())
2029 2030 )
2030 2031 if dotiming:
2031 2032 begin = util.timer()
2032 2033 mergedata = copies.mergecopies(repo, p1, p2, b)
2033 2034 end = util.timer()
2034 2035 # not very stable timing since we did only one run
2035 2036 data['time'] = end - begin
2036 2037 # mergedata contains five dicts: "copy", "movewithdir",
2037 2038 # "diverge", "renamedelete" and "dirmove".
2038 2039 # The first 4 are about renamed file so lets count that.
2039 2040 renames = len(mergedata[0])
2040 2041 renames += len(mergedata[1])
2041 2042 renames += len(mergedata[2])
2042 2043 renames += len(mergedata[3])
2043 2044 data['nbrenamedfiles'] = renames
2044 2045 begin = util.timer()
2045 2046 p1renames = copies.pathcopies(b, p1)
2046 2047 end = util.timer()
2047 2048 data['p1.time'] = end - begin
2048 2049 begin = util.timer()
2049 2050 p2renames = copies.pathcopies(b, p2)
2050 2051 end = util.timer()
2051 2052 data['p2.time'] = end - begin
2052 2053 data['p1.renamedfiles'] = len(p1renames)
2053 2054 data['p2.renamedfiles'] = len(p2renames)
2054 2055
2055 2056 if dostats:
2056 2057 if p1missing:
2057 2058 alldata['parentnbrenames'].append(
2058 2059 (data['p1.renamedfiles'], b.hex(), p1.hex())
2059 2060 )
2060 2061 alldata['parenttime'].append(
2061 2062 (data['p1.time'], b.hex(), p1.hex())
2062 2063 )
2063 2064 if p2missing:
2064 2065 alldata['parentnbrenames'].append(
2065 2066 (data['p2.renamedfiles'], b.hex(), p2.hex())
2066 2067 )
2067 2068 alldata['parenttime'].append(
2068 2069 (data['p2.time'], b.hex(), p2.hex())
2069 2070 )
2070 2071 if p1missing or p2missing:
2071 2072 alldata['totalnbrenames'].append(
2072 2073 (
2073 2074 data['nbrenamedfiles'],
2074 2075 b.hex(),
2075 2076 p1.hex(),
2076 2077 p2.hex(),
2077 2078 )
2078 2079 )
2079 2080 alldata['totaltime'].append(
2080 2081 (data['time'], b.hex(), p1.hex(), p2.hex())
2081 2082 )
2082 2083 fm.startitem()
2083 2084 fm.data(**data)
2084 2085 # make node pretty for the human output
2085 2086 out = data.copy()
2086 2087 out['base'] = fm.hexfunc(b.node())
2087 2088 out['p1.node'] = fm.hexfunc(p1.node())
2088 2089 out['p2.node'] = fm.hexfunc(p2.node())
2089 2090 fm.plain(output % out)
2090 2091
2091 2092 fm.end()
2092 2093 if dostats:
2093 2094 # use a second formatter because the data are quite different, not sure
2094 2095 # how it flies with the templater.
2095 2096 entries = [
2096 2097 ('nbrevs', 'number of revision covered'),
2097 2098 ('nbmissingfiles', 'number of missing files at head'),
2098 2099 ]
2099 2100 if dotiming:
2100 2101 entries.append(
2101 2102 ('parentnbrenames', 'rename from one parent to base')
2102 2103 )
2103 2104 entries.append(('totalnbrenames', 'total number of renames'))
2104 2105 entries.append(('parenttime', 'time for one parent'))
2105 2106 entries.append(('totaltime', 'time for both parents'))
2106 2107 _displaystats(ui, opts, entries, alldata)
2107 2108
2108 2109
2109 2110 @command(
2110 2111 b'perfhelper-pathcopies',
2111 2112 formatteropts
2112 2113 + [
2113 2114 (b'r', b'revs', [], b'restrict search to these revisions'),
2114 2115 (b'', b'timing', False, b'provides extra data (costly)'),
2115 2116 (b'', b'stats', False, b'provides statistic about the measured data'),
2116 2117 ],
2117 2118 )
2118 2119 def perfhelperpathcopies(ui, repo, revs=[], **opts):
2119 2120 """find statistic about potential parameters for the `perftracecopies`
2120 2121
2121 2122 This command find source-destination pair relevant for copytracing testing.
2122 2123 It report value for some of the parameters that impact copy tracing time.
2123 2124
2124 2125 If `--timing` is set, rename detection is run and the associated timing
2125 2126 will be reported. The extra details comes at the cost of a slower command
2126 2127 execution.
2127 2128
2128 2129 Since the rename detection is only run once, other factors might easily
2129 2130 affect the precision of the timing. However it should give a good
2130 2131 approximation of which revision pairs are very costly.
2131 2132 """
2132 2133 opts = _byteskwargs(opts)
2133 2134 fm = ui.formatter(b'perf', opts)
2134 2135 dotiming = opts[b'timing']
2135 2136 dostats = opts[b'stats']
2136 2137
2137 2138 if dotiming:
2138 2139 header = '%12s %12s %12s %12s %12s %12s\n'
2139 2140 output = (
2140 2141 "%(source)12s %(destination)12s "
2141 2142 "%(nbrevs)12d %(nbmissingfiles)12d "
2142 2143 "%(nbrenamedfiles)12d %(time)18.5f\n"
2143 2144 )
2144 2145 header_names = (
2145 2146 "source",
2146 2147 "destination",
2147 2148 "nb-revs",
2148 2149 "nb-files",
2149 2150 "nb-renames",
2150 2151 "time",
2151 2152 )
2152 2153 fm.plain(header % header_names)
2153 2154 else:
2154 2155 header = '%12s %12s %12s %12s\n'
2155 2156 output = (
2156 2157 "%(source)12s %(destination)12s "
2157 2158 "%(nbrevs)12d %(nbmissingfiles)12d\n"
2158 2159 )
2159 2160 fm.plain(header % ("source", "destination", "nb-revs", "nb-files"))
2160 2161
2161 2162 if not revs:
2162 2163 revs = ['all()']
2163 2164 revs = scmutil.revrange(repo, revs)
2164 2165
2165 2166 if dostats:
2166 2167 alldata = {
2167 2168 'nbrevs': [],
2168 2169 'nbmissingfiles': [],
2169 2170 }
2170 2171 if dotiming:
2171 2172 alldata['nbrenames'] = []
2172 2173 alldata['time'] = []
2173 2174
2174 2175 roi = repo.revs('merge() and %ld', revs)
2175 2176 for r in roi:
2176 2177 ctx = repo[r]
2177 2178 p1 = ctx.p1().rev()
2178 2179 p2 = ctx.p2().rev()
2179 2180 bases = repo.changelog._commonancestorsheads(p1, p2)
2180 2181 for p in (p1, p2):
2181 2182 for b in bases:
2182 2183 base = repo[b]
2183 2184 parent = repo[p]
2184 2185 missing = copies._computeforwardmissing(base, parent)
2185 2186 if not missing:
2186 2187 continue
2187 2188 data = {
2188 2189 b'source': base.hex(),
2189 2190 b'destination': parent.hex(),
2190 2191 b'nbrevs': len(repo.revs('only(%d, %d)', p, b)),
2191 2192 b'nbmissingfiles': len(missing),
2192 2193 }
2193 2194 if dostats:
2194 2195 alldata['nbrevs'].append(
2195 2196 (data['nbrevs'], base.hex(), parent.hex(),)
2196 2197 )
2197 2198 alldata['nbmissingfiles'].append(
2198 2199 (data['nbmissingfiles'], base.hex(), parent.hex(),)
2199 2200 )
2200 2201 if dotiming:
2201 2202 begin = util.timer()
2202 2203 renames = copies.pathcopies(base, parent)
2203 2204 end = util.timer()
2204 2205 # not very stable timing since we did only one run
2205 2206 data['time'] = end - begin
2206 2207 data['nbrenamedfiles'] = len(renames)
2207 2208 if dostats:
2208 2209 alldata['time'].append(
2209 2210 (data['time'], base.hex(), parent.hex(),)
2210 2211 )
2211 2212 alldata['nbrenames'].append(
2212 2213 (data['nbrenamedfiles'], base.hex(), parent.hex(),)
2213 2214 )
2214 2215 fm.startitem()
2215 2216 fm.data(**data)
2216 2217 out = data.copy()
2217 2218 out['source'] = fm.hexfunc(base.node())
2218 2219 out['destination'] = fm.hexfunc(parent.node())
2219 2220 fm.plain(output % out)
2220 2221
2221 2222 fm.end()
2222 2223 if dostats:
2223 2224 entries = [
2224 2225 ('nbrevs', 'number of revision covered'),
2225 2226 ('nbmissingfiles', 'number of missing files at head'),
2226 2227 ]
2227 2228 if dotiming:
2228 2229 entries.append(('nbrenames', 'renamed files'))
2229 2230 entries.append(('time', 'time'))
2230 2231 _displaystats(ui, opts, entries, alldata)
2231 2232
2232 2233
2233 2234 @command(b'perfcca', formatteropts)
2234 2235 def perfcca(ui, repo, **opts):
2235 2236 opts = _byteskwargs(opts)
2236 2237 timer, fm = gettimer(ui, opts)
2237 2238 timer(lambda: scmutil.casecollisionauditor(ui, False, repo.dirstate))
2238 2239 fm.end()
2239 2240
2240 2241
2241 2242 @command(b'perffncacheload', formatteropts)
2242 2243 def perffncacheload(ui, repo, **opts):
2243 2244 opts = _byteskwargs(opts)
2244 2245 timer, fm = gettimer(ui, opts)
2245 2246 s = repo.store
2246 2247
2247 2248 def d():
2248 2249 s.fncache._load()
2249 2250
2250 2251 timer(d)
2251 2252 fm.end()
2252 2253
2253 2254
2254 2255 @command(b'perffncachewrite', formatteropts)
2255 2256 def perffncachewrite(ui, repo, **opts):
2256 2257 opts = _byteskwargs(opts)
2257 2258 timer, fm = gettimer(ui, opts)
2258 2259 s = repo.store
2259 2260 lock = repo.lock()
2260 2261 s.fncache._load()
2261 2262 tr = repo.transaction(b'perffncachewrite')
2262 2263 tr.addbackup(b'fncache')
2263 2264
2264 2265 def d():
2265 2266 s.fncache._dirty = True
2266 2267 s.fncache.write(tr)
2267 2268
2268 2269 timer(d)
2269 2270 tr.close()
2270 2271 lock.release()
2271 2272 fm.end()
2272 2273
2273 2274
2274 2275 @command(b'perffncacheencode', formatteropts)
2275 2276 def perffncacheencode(ui, repo, **opts):
2276 2277 opts = _byteskwargs(opts)
2277 2278 timer, fm = gettimer(ui, opts)
2278 2279 s = repo.store
2279 2280 s.fncache._load()
2280 2281
2281 2282 def d():
2282 2283 for p in s.fncache.entries:
2283 2284 s.encode(p)
2284 2285
2285 2286 timer(d)
2286 2287 fm.end()
2287 2288
2288 2289
2289 2290 def _bdiffworker(q, blocks, xdiff, ready, done):
2290 2291 while not done.is_set():
2291 2292 pair = q.get()
2292 2293 while pair is not None:
2293 2294 if xdiff:
2294 2295 mdiff.bdiff.xdiffblocks(*pair)
2295 2296 elif blocks:
2296 2297 mdiff.bdiff.blocks(*pair)
2297 2298 else:
2298 2299 mdiff.textdiff(*pair)
2299 2300 q.task_done()
2300 2301 pair = q.get()
2301 2302 q.task_done() # for the None one
2302 2303 with ready:
2303 2304 ready.wait()
2304 2305
2305 2306
2306 2307 def _manifestrevision(repo, mnode):
2307 2308 ml = repo.manifestlog
2308 2309
2309 2310 if util.safehasattr(ml, b'getstorage'):
2310 2311 store = ml.getstorage(b'')
2311 2312 else:
2312 2313 store = ml._revlog
2313 2314
2314 2315 return store.revision(mnode)
2315 2316
2316 2317
2317 2318 @command(
2318 2319 b'perfbdiff',
2319 2320 revlogopts
2320 2321 + formatteropts
2321 2322 + [
2322 2323 (
2323 2324 b'',
2324 2325 b'count',
2325 2326 1,
2326 2327 b'number of revisions to test (when using --startrev)',
2327 2328 ),
2328 2329 (b'', b'alldata', False, b'test bdiffs for all associated revisions'),
2329 2330 (b'', b'threads', 0, b'number of thread to use (disable with 0)'),
2330 2331 (b'', b'blocks', False, b'test computing diffs into blocks'),
2331 2332 (b'', b'xdiff', False, b'use xdiff algorithm'),
2332 2333 ],
2333 2334 b'-c|-m|FILE REV',
2334 2335 )
2335 2336 def perfbdiff(ui, repo, file_, rev=None, count=None, threads=0, **opts):
2336 2337 """benchmark a bdiff between revisions
2337 2338
2338 2339 By default, benchmark a bdiff between its delta parent and itself.
2339 2340
2340 2341 With ``--count``, benchmark bdiffs between delta parents and self for N
2341 2342 revisions starting at the specified revision.
2342 2343
2343 2344 With ``--alldata``, assume the requested revision is a changeset and
2344 2345 measure bdiffs for all changes related to that changeset (manifest
2345 2346 and filelogs).
2346 2347 """
2347 2348 opts = _byteskwargs(opts)
2348 2349
2349 2350 if opts[b'xdiff'] and not opts[b'blocks']:
2350 2351 raise error.CommandError(b'perfbdiff', b'--xdiff requires --blocks')
2351 2352
2352 2353 if opts[b'alldata']:
2353 2354 opts[b'changelog'] = True
2354 2355
2355 2356 if opts.get(b'changelog') or opts.get(b'manifest'):
2356 2357 file_, rev = None, file_
2357 2358 elif rev is None:
2358 2359 raise error.CommandError(b'perfbdiff', b'invalid arguments')
2359 2360
2360 2361 blocks = opts[b'blocks']
2361 2362 xdiff = opts[b'xdiff']
2362 2363 textpairs = []
2363 2364
2364 2365 r = cmdutil.openrevlog(repo, b'perfbdiff', file_, opts)
2365 2366
2366 2367 startrev = r.rev(r.lookup(rev))
2367 2368 for rev in range(startrev, min(startrev + count, len(r) - 1)):
2368 2369 if opts[b'alldata']:
2369 2370 # Load revisions associated with changeset.
2370 2371 ctx = repo[rev]
2371 2372 mtext = _manifestrevision(repo, ctx.manifestnode())
2372 2373 for pctx in ctx.parents():
2373 2374 pman = _manifestrevision(repo, pctx.manifestnode())
2374 2375 textpairs.append((pman, mtext))
2375 2376
2376 2377 # Load filelog revisions by iterating manifest delta.
2377 2378 man = ctx.manifest()
2378 2379 pman = ctx.p1().manifest()
2379 2380 for filename, change in pman.diff(man).items():
2380 2381 fctx = repo.file(filename)
2381 2382 f1 = fctx.revision(change[0][0] or -1)
2382 2383 f2 = fctx.revision(change[1][0] or -1)
2383 2384 textpairs.append((f1, f2))
2384 2385 else:
2385 2386 dp = r.deltaparent(rev)
2386 2387 textpairs.append((r.revision(dp), r.revision(rev)))
2387 2388
2388 2389 withthreads = threads > 0
2389 2390 if not withthreads:
2390 2391
2391 2392 def d():
2392 2393 for pair in textpairs:
2393 2394 if xdiff:
2394 2395 mdiff.bdiff.xdiffblocks(*pair)
2395 2396 elif blocks:
2396 2397 mdiff.bdiff.blocks(*pair)
2397 2398 else:
2398 2399 mdiff.textdiff(*pair)
2399 2400
2400 2401 else:
2401 2402 q = queue()
2402 2403 for i in _xrange(threads):
2403 2404 q.put(None)
2404 2405 ready = threading.Condition()
2405 2406 done = threading.Event()
2406 2407 for i in _xrange(threads):
2407 2408 threading.Thread(
2408 2409 target=_bdiffworker, args=(q, blocks, xdiff, ready, done)
2409 2410 ).start()
2410 2411 q.join()
2411 2412
2412 2413 def d():
2413 2414 for pair in textpairs:
2414 2415 q.put(pair)
2415 2416 for i in _xrange(threads):
2416 2417 q.put(None)
2417 2418 with ready:
2418 2419 ready.notify_all()
2419 2420 q.join()
2420 2421
2421 2422 timer, fm = gettimer(ui, opts)
2422 2423 timer(d)
2423 2424 fm.end()
2424 2425
2425 2426 if withthreads:
2426 2427 done.set()
2427 2428 for i in _xrange(threads):
2428 2429 q.put(None)
2429 2430 with ready:
2430 2431 ready.notify_all()
2431 2432
2432 2433
2433 2434 @command(
2434 2435 b'perfunidiff',
2435 2436 revlogopts
2436 2437 + formatteropts
2437 2438 + [
2438 2439 (
2439 2440 b'',
2440 2441 b'count',
2441 2442 1,
2442 2443 b'number of revisions to test (when using --startrev)',
2443 2444 ),
2444 2445 (b'', b'alldata', False, b'test unidiffs for all associated revisions'),
2445 2446 ],
2446 2447 b'-c|-m|FILE REV',
2447 2448 )
2448 2449 def perfunidiff(ui, repo, file_, rev=None, count=None, **opts):
2449 2450 """benchmark a unified diff between revisions
2450 2451
2451 2452 This doesn't include any copy tracing - it's just a unified diff
2452 2453 of the texts.
2453 2454
2454 2455 By default, benchmark a diff between its delta parent and itself.
2455 2456
2456 2457 With ``--count``, benchmark diffs between delta parents and self for N
2457 2458 revisions starting at the specified revision.
2458 2459
2459 2460 With ``--alldata``, assume the requested revision is a changeset and
2460 2461 measure diffs for all changes related to that changeset (manifest
2461 2462 and filelogs).
2462 2463 """
2463 2464 opts = _byteskwargs(opts)
2464 2465 if opts[b'alldata']:
2465 2466 opts[b'changelog'] = True
2466 2467
2467 2468 if opts.get(b'changelog') or opts.get(b'manifest'):
2468 2469 file_, rev = None, file_
2469 2470 elif rev is None:
2470 2471 raise error.CommandError(b'perfunidiff', b'invalid arguments')
2471 2472
2472 2473 textpairs = []
2473 2474
2474 2475 r = cmdutil.openrevlog(repo, b'perfunidiff', file_, opts)
2475 2476
2476 2477 startrev = r.rev(r.lookup(rev))
2477 2478 for rev in range(startrev, min(startrev + count, len(r) - 1)):
2478 2479 if opts[b'alldata']:
2479 2480 # Load revisions associated with changeset.
2480 2481 ctx = repo[rev]
2481 2482 mtext = _manifestrevision(repo, ctx.manifestnode())
2482 2483 for pctx in ctx.parents():
2483 2484 pman = _manifestrevision(repo, pctx.manifestnode())
2484 2485 textpairs.append((pman, mtext))
2485 2486
2486 2487 # Load filelog revisions by iterating manifest delta.
2487 2488 man = ctx.manifest()
2488 2489 pman = ctx.p1().manifest()
2489 2490 for filename, change in pman.diff(man).items():
2490 2491 fctx = repo.file(filename)
2491 2492 f1 = fctx.revision(change[0][0] or -1)
2492 2493 f2 = fctx.revision(change[1][0] or -1)
2493 2494 textpairs.append((f1, f2))
2494 2495 else:
2495 2496 dp = r.deltaparent(rev)
2496 2497 textpairs.append((r.revision(dp), r.revision(rev)))
2497 2498
2498 2499 def d():
2499 2500 for left, right in textpairs:
2500 2501 # The date strings don't matter, so we pass empty strings.
2501 2502 headerlines, hunks = mdiff.unidiff(
2502 2503 left, b'', right, b'', b'left', b'right', binary=False
2503 2504 )
2504 2505 # consume iterators in roughly the way patch.py does
2505 2506 b'\n'.join(headerlines)
2506 2507 b''.join(sum((list(hlines) for hrange, hlines in hunks), []))
2507 2508
2508 2509 timer, fm = gettimer(ui, opts)
2509 2510 timer(d)
2510 2511 fm.end()
2511 2512
2512 2513
2513 2514 @command(b'perfdiffwd', formatteropts)
2514 2515 def perfdiffwd(ui, repo, **opts):
2515 2516 """Profile diff of working directory changes"""
2516 2517 opts = _byteskwargs(opts)
2517 2518 timer, fm = gettimer(ui, opts)
2518 2519 options = {
2519 2520 'w': 'ignore_all_space',
2520 2521 'b': 'ignore_space_change',
2521 2522 'B': 'ignore_blank_lines',
2522 2523 }
2523 2524
2524 2525 for diffopt in ('', 'w', 'b', 'B', 'wB'):
2525 opts = dict((options[c], b'1') for c in diffopt)
2526 opts = {options[c]: b'1' for c in diffopt}
2526 2527
2527 2528 def d():
2528 2529 ui.pushbuffer()
2529 2530 commands.diff(ui, repo, **opts)
2530 2531 ui.popbuffer()
2531 2532
2532 2533 diffopt = diffopt.encode('ascii')
2533 2534 title = b'diffopts: %s' % (diffopt and (b'-' + diffopt) or b'none')
2534 2535 timer(d, title=title)
2535 2536 fm.end()
2536 2537
2537 2538
2538 2539 @command(b'perfrevlogindex', revlogopts + formatteropts, b'-c|-m|FILE')
2539 2540 def perfrevlogindex(ui, repo, file_=None, **opts):
2540 2541 """Benchmark operations against a revlog index.
2541 2542
2542 2543 This tests constructing a revlog instance, reading index data,
2543 2544 parsing index data, and performing various operations related to
2544 2545 index data.
2545 2546 """
2546 2547
2547 2548 opts = _byteskwargs(opts)
2548 2549
2549 2550 rl = cmdutil.openrevlog(repo, b'perfrevlogindex', file_, opts)
2550 2551
2551 2552 opener = getattr(rl, 'opener') # trick linter
2552 2553 indexfile = rl.indexfile
2553 2554 data = opener.read(indexfile)
2554 2555
2555 2556 header = struct.unpack(b'>I', data[0:4])[0]
2556 2557 version = header & 0xFFFF
2557 2558 if version == 1:
2558 2559 revlogio = revlog.revlogio()
2559 2560 inline = header & (1 << 16)
2560 2561 else:
2561 2562 raise error.Abort(b'unsupported revlog version: %d' % version)
2562 2563
2563 2564 rllen = len(rl)
2564 2565
2565 2566 node0 = rl.node(0)
2566 2567 node25 = rl.node(rllen // 4)
2567 2568 node50 = rl.node(rllen // 2)
2568 2569 node75 = rl.node(rllen // 4 * 3)
2569 2570 node100 = rl.node(rllen - 1)
2570 2571
2571 2572 allrevs = range(rllen)
2572 2573 allrevsrev = list(reversed(allrevs))
2573 2574 allnodes = [rl.node(rev) for rev in range(rllen)]
2574 2575 allnodesrev = list(reversed(allnodes))
2575 2576
2576 2577 def constructor():
2577 2578 revlog.revlog(opener, indexfile)
2578 2579
2579 2580 def read():
2580 2581 with opener(indexfile) as fh:
2581 2582 fh.read()
2582 2583
2583 2584 def parseindex():
2584 2585 revlogio.parseindex(data, inline)
2585 2586
2586 2587 def getentry(revornode):
2587 2588 index = revlogio.parseindex(data, inline)[0]
2588 2589 index[revornode]
2589 2590
2590 2591 def getentries(revs, count=1):
2591 2592 index = revlogio.parseindex(data, inline)[0]
2592 2593
2593 2594 for i in range(count):
2594 2595 for rev in revs:
2595 2596 index[rev]
2596 2597
2597 2598 def resolvenode(node):
2598 2599 index = revlogio.parseindex(data, inline)[0]
2599 2600 rev = getattr(index, 'rev', None)
2600 2601 if rev is None:
2601 2602 nodemap = getattr(
2602 2603 revlogio.parseindex(data, inline)[0], 'nodemap', None
2603 2604 )
2604 2605 # This only works for the C code.
2605 2606 if nodemap is None:
2606 2607 return
2607 2608 rev = nodemap.__getitem__
2608 2609
2609 2610 try:
2610 2611 rev(node)
2611 2612 except error.RevlogError:
2612 2613 pass
2613 2614
2614 2615 def resolvenodes(nodes, count=1):
2615 2616 index = revlogio.parseindex(data, inline)[0]
2616 2617 rev = getattr(index, 'rev', None)
2617 2618 if rev is None:
2618 2619 nodemap = getattr(
2619 2620 revlogio.parseindex(data, inline)[0], 'nodemap', None
2620 2621 )
2621 2622 # This only works for the C code.
2622 2623 if nodemap is None:
2623 2624 return
2624 2625 rev = nodemap.__getitem__
2625 2626
2626 2627 for i in range(count):
2627 2628 for node in nodes:
2628 2629 try:
2629 2630 rev(node)
2630 2631 except error.RevlogError:
2631 2632 pass
2632 2633
2633 2634 benches = [
2634 2635 (constructor, b'revlog constructor'),
2635 2636 (read, b'read'),
2636 2637 (parseindex, b'create index object'),
2637 2638 (lambda: getentry(0), b'retrieve index entry for rev 0'),
2638 2639 (lambda: resolvenode(b'a' * 20), b'look up missing node'),
2639 2640 (lambda: resolvenode(node0), b'look up node at rev 0'),
2640 2641 (lambda: resolvenode(node25), b'look up node at 1/4 len'),
2641 2642 (lambda: resolvenode(node50), b'look up node at 1/2 len'),
2642 2643 (lambda: resolvenode(node75), b'look up node at 3/4 len'),
2643 2644 (lambda: resolvenode(node100), b'look up node at tip'),
2644 2645 # 2x variation is to measure caching impact.
2645 2646 (lambda: resolvenodes(allnodes), b'look up all nodes (forward)'),
2646 2647 (lambda: resolvenodes(allnodes, 2), b'look up all nodes 2x (forward)'),
2647 2648 (lambda: resolvenodes(allnodesrev), b'look up all nodes (reverse)'),
2648 2649 (
2649 2650 lambda: resolvenodes(allnodesrev, 2),
2650 2651 b'look up all nodes 2x (reverse)',
2651 2652 ),
2652 2653 (lambda: getentries(allrevs), b'retrieve all index entries (forward)'),
2653 2654 (
2654 2655 lambda: getentries(allrevs, 2),
2655 2656 b'retrieve all index entries 2x (forward)',
2656 2657 ),
2657 2658 (
2658 2659 lambda: getentries(allrevsrev),
2659 2660 b'retrieve all index entries (reverse)',
2660 2661 ),
2661 2662 (
2662 2663 lambda: getentries(allrevsrev, 2),
2663 2664 b'retrieve all index entries 2x (reverse)',
2664 2665 ),
2665 2666 ]
2666 2667
2667 2668 for fn, title in benches:
2668 2669 timer, fm = gettimer(ui, opts)
2669 2670 timer(fn, title=title)
2670 2671 fm.end()
2671 2672
2672 2673
2673 2674 @command(
2674 2675 b'perfrevlogrevisions',
2675 2676 revlogopts
2676 2677 + formatteropts
2677 2678 + [
2678 2679 (b'd', b'dist', 100, b'distance between the revisions'),
2679 2680 (b's', b'startrev', 0, b'revision to start reading at'),
2680 2681 (b'', b'reverse', False, b'read in reverse'),
2681 2682 ],
2682 2683 b'-c|-m|FILE',
2683 2684 )
2684 2685 def perfrevlogrevisions(
2685 2686 ui, repo, file_=None, startrev=0, reverse=False, **opts
2686 2687 ):
2687 2688 """Benchmark reading a series of revisions from a revlog.
2688 2689
2689 2690 By default, we read every ``-d/--dist`` revision from 0 to tip of
2690 2691 the specified revlog.
2691 2692
2692 2693 The start revision can be defined via ``-s/--startrev``.
2693 2694 """
2694 2695 opts = _byteskwargs(opts)
2695 2696
2696 2697 rl = cmdutil.openrevlog(repo, b'perfrevlogrevisions', file_, opts)
2697 2698 rllen = getlen(ui)(rl)
2698 2699
2699 2700 if startrev < 0:
2700 2701 startrev = rllen + startrev
2701 2702
2702 2703 def d():
2703 2704 rl.clearcaches()
2704 2705
2705 2706 beginrev = startrev
2706 2707 endrev = rllen
2707 2708 dist = opts[b'dist']
2708 2709
2709 2710 if reverse:
2710 2711 beginrev, endrev = endrev - 1, beginrev - 1
2711 2712 dist = -1 * dist
2712 2713
2713 2714 for x in _xrange(beginrev, endrev, dist):
2714 2715 # Old revisions don't support passing int.
2715 2716 n = rl.node(x)
2716 2717 rl.revision(n)
2717 2718
2718 2719 timer, fm = gettimer(ui, opts)
2719 2720 timer(d)
2720 2721 fm.end()
2721 2722
2722 2723
2723 2724 @command(
2724 2725 b'perfrevlogwrite',
2725 2726 revlogopts
2726 2727 + formatteropts
2727 2728 + [
2728 2729 (b's', b'startrev', 1000, b'revision to start writing at'),
2729 2730 (b'', b'stoprev', -1, b'last revision to write'),
2730 2731 (b'', b'count', 3, b'number of passes to perform'),
2731 2732 (b'', b'details', False, b'print timing for every revisions tested'),
2732 2733 (b'', b'source', b'full', b'the kind of data feed in the revlog'),
2733 2734 (b'', b'lazydeltabase', True, b'try the provided delta first'),
2734 2735 (b'', b'clear-caches', True, b'clear revlog cache between calls'),
2735 2736 ],
2736 2737 b'-c|-m|FILE',
2737 2738 )
2738 2739 def perfrevlogwrite(ui, repo, file_=None, startrev=1000, stoprev=-1, **opts):
2739 2740 """Benchmark writing a series of revisions to a revlog.
2740 2741
2741 2742 Possible source values are:
2742 2743 * `full`: add from a full text (default).
2743 2744 * `parent-1`: add from a delta to the first parent
2744 2745 * `parent-2`: add from a delta to the second parent if it exists
2745 2746 (use a delta from the first parent otherwise)
2746 2747 * `parent-smallest`: add from the smallest delta (either p1 or p2)
2747 2748 * `storage`: add from the existing precomputed deltas
2748 2749
2749 2750 Note: This performance command measures performance in a custom way. As a
2750 2751 result some of the global configuration of the 'perf' command does not
2751 2752 apply to it:
2752 2753
2753 2754 * ``pre-run``: disabled
2754 2755
2755 2756 * ``profile-benchmark``: disabled
2756 2757
2757 2758 * ``run-limits``: disabled use --count instead
2758 2759 """
2759 2760 opts = _byteskwargs(opts)
2760 2761
2761 2762 rl = cmdutil.openrevlog(repo, b'perfrevlogwrite', file_, opts)
2762 2763 rllen = getlen(ui)(rl)
2763 2764 if startrev < 0:
2764 2765 startrev = rllen + startrev
2765 2766 if stoprev < 0:
2766 2767 stoprev = rllen + stoprev
2767 2768
2768 2769 lazydeltabase = opts['lazydeltabase']
2769 2770 source = opts['source']
2770 2771 clearcaches = opts['clear_caches']
2771 2772 validsource = (
2772 2773 b'full',
2773 2774 b'parent-1',
2774 2775 b'parent-2',
2775 2776 b'parent-smallest',
2776 2777 b'storage',
2777 2778 )
2778 2779 if source not in validsource:
2779 2780 raise error.Abort('invalid source type: %s' % source)
2780 2781
2781 2782 ### actually gather results
2782 2783 count = opts['count']
2783 2784 if count <= 0:
2784 2785 raise error.Abort('invalide run count: %d' % count)
2785 2786 allresults = []
2786 2787 for c in range(count):
2787 2788 timing = _timeonewrite(
2788 2789 ui,
2789 2790 rl,
2790 2791 source,
2791 2792 startrev,
2792 2793 stoprev,
2793 2794 c + 1,
2794 2795 lazydeltabase=lazydeltabase,
2795 2796 clearcaches=clearcaches,
2796 2797 )
2797 2798 allresults.append(timing)
2798 2799
2799 2800 ### consolidate the results in a single list
2800 2801 results = []
2801 2802 for idx, (rev, t) in enumerate(allresults[0]):
2802 2803 ts = [t]
2803 2804 for other in allresults[1:]:
2804 2805 orev, ot = other[idx]
2805 2806 assert orev == rev
2806 2807 ts.append(ot)
2807 2808 results.append((rev, ts))
2808 2809 resultcount = len(results)
2809 2810
2810 2811 ### Compute and display relevant statistics
2811 2812
2812 2813 # get a formatter
2813 2814 fm = ui.formatter(b'perf', opts)
2814 2815 displayall = ui.configbool(b"perf", b"all-timing", False)
2815 2816
2816 2817 # print individual details if requested
2817 2818 if opts['details']:
2818 2819 for idx, item in enumerate(results, 1):
2819 2820 rev, data = item
2820 2821 title = 'revisions #%d of %d, rev %d' % (idx, resultcount, rev)
2821 2822 formatone(fm, data, title=title, displayall=displayall)
2822 2823
2823 2824 # sorts results by median time
2824 2825 results.sort(key=lambda x: sorted(x[1])[len(x[1]) // 2])
2825 2826 # list of (name, index) to display)
2826 2827 relevants = [
2827 2828 ("min", 0),
2828 2829 ("10%", resultcount * 10 // 100),
2829 2830 ("25%", resultcount * 25 // 100),
2830 2831 ("50%", resultcount * 70 // 100),
2831 2832 ("75%", resultcount * 75 // 100),
2832 2833 ("90%", resultcount * 90 // 100),
2833 2834 ("95%", resultcount * 95 // 100),
2834 2835 ("99%", resultcount * 99 // 100),
2835 2836 ("99.9%", resultcount * 999 // 1000),
2836 2837 ("99.99%", resultcount * 9999 // 10000),
2837 2838 ("99.999%", resultcount * 99999 // 100000),
2838 2839 ("max", -1),
2839 2840 ]
2840 2841 if not ui.quiet:
2841 2842 for name, idx in relevants:
2842 2843 data = results[idx]
2843 2844 title = '%s of %d, rev %d' % (name, resultcount, data[0])
2844 2845 formatone(fm, data[1], title=title, displayall=displayall)
2845 2846
2846 2847 # XXX summing that many float will not be very precise, we ignore this fact
2847 2848 # for now
2848 2849 totaltime = []
2849 2850 for item in allresults:
2850 2851 totaltime.append(
2851 2852 (
2852 2853 sum(x[1][0] for x in item),
2853 2854 sum(x[1][1] for x in item),
2854 2855 sum(x[1][2] for x in item),
2855 2856 )
2856 2857 )
2857 2858 formatone(
2858 2859 fm,
2859 2860 totaltime,
2860 2861 title="total time (%d revs)" % resultcount,
2861 2862 displayall=displayall,
2862 2863 )
2863 2864 fm.end()
2864 2865
2865 2866
2866 2867 class _faketr(object):
2867 2868 def add(s, x, y, z=None):
2868 2869 return None
2869 2870
2870 2871
2871 2872 def _timeonewrite(
2872 2873 ui,
2873 2874 orig,
2874 2875 source,
2875 2876 startrev,
2876 2877 stoprev,
2877 2878 runidx=None,
2878 2879 lazydeltabase=True,
2879 2880 clearcaches=True,
2880 2881 ):
2881 2882 timings = []
2882 2883 tr = _faketr()
2883 2884 with _temprevlog(ui, orig, startrev) as dest:
2884 2885 dest._lazydeltabase = lazydeltabase
2885 2886 revs = list(orig.revs(startrev, stoprev))
2886 2887 total = len(revs)
2887 2888 topic = 'adding'
2888 2889 if runidx is not None:
2889 2890 topic += ' (run #%d)' % runidx
2890 2891 # Support both old and new progress API
2891 2892 if util.safehasattr(ui, 'makeprogress'):
2892 2893 progress = ui.makeprogress(topic, unit='revs', total=total)
2893 2894
2894 2895 def updateprogress(pos):
2895 2896 progress.update(pos)
2896 2897
2897 2898 def completeprogress():
2898 2899 progress.complete()
2899 2900
2900 2901 else:
2901 2902
2902 2903 def updateprogress(pos):
2903 2904 ui.progress(topic, pos, unit='revs', total=total)
2904 2905
2905 2906 def completeprogress():
2906 2907 ui.progress(topic, None, unit='revs', total=total)
2907 2908
2908 2909 for idx, rev in enumerate(revs):
2909 2910 updateprogress(idx)
2910 2911 addargs, addkwargs = _getrevisionseed(orig, rev, tr, source)
2911 2912 if clearcaches:
2912 2913 dest.index.clearcaches()
2913 2914 dest.clearcaches()
2914 2915 with timeone() as r:
2915 2916 dest.addrawrevision(*addargs, **addkwargs)
2916 2917 timings.append((rev, r[0]))
2917 2918 updateprogress(total)
2918 2919 completeprogress()
2919 2920 return timings
2920 2921
2921 2922
2922 2923 def _getrevisionseed(orig, rev, tr, source):
2923 2924 from mercurial.node import nullid
2924 2925
2925 2926 linkrev = orig.linkrev(rev)
2926 2927 node = orig.node(rev)
2927 2928 p1, p2 = orig.parents(node)
2928 2929 flags = orig.flags(rev)
2929 2930 cachedelta = None
2930 2931 text = None
2931 2932
2932 2933 if source == b'full':
2933 2934 text = orig.revision(rev)
2934 2935 elif source == b'parent-1':
2935 2936 baserev = orig.rev(p1)
2936 2937 cachedelta = (baserev, orig.revdiff(p1, rev))
2937 2938 elif source == b'parent-2':
2938 2939 parent = p2
2939 2940 if p2 == nullid:
2940 2941 parent = p1
2941 2942 baserev = orig.rev(parent)
2942 2943 cachedelta = (baserev, orig.revdiff(parent, rev))
2943 2944 elif source == b'parent-smallest':
2944 2945 p1diff = orig.revdiff(p1, rev)
2945 2946 parent = p1
2946 2947 diff = p1diff
2947 2948 if p2 != nullid:
2948 2949 p2diff = orig.revdiff(p2, rev)
2949 2950 if len(p1diff) > len(p2diff):
2950 2951 parent = p2
2951 2952 diff = p2diff
2952 2953 baserev = orig.rev(parent)
2953 2954 cachedelta = (baserev, diff)
2954 2955 elif source == b'storage':
2955 2956 baserev = orig.deltaparent(rev)
2956 2957 cachedelta = (baserev, orig.revdiff(orig.node(baserev), rev))
2957 2958
2958 2959 return (
2959 2960 (text, tr, linkrev, p1, p2),
2960 2961 {'node': node, 'flags': flags, 'cachedelta': cachedelta},
2961 2962 )
2962 2963
2963 2964
2964 2965 @contextlib.contextmanager
2965 2966 def _temprevlog(ui, orig, truncaterev):
2966 2967 from mercurial import vfs as vfsmod
2967 2968
2968 2969 if orig._inline:
2969 2970 raise error.Abort('not supporting inline revlog (yet)')
2970 2971 revlogkwargs = {}
2971 2972 k = 'upperboundcomp'
2972 2973 if util.safehasattr(orig, k):
2973 2974 revlogkwargs[k] = getattr(orig, k)
2974 2975
2975 2976 origindexpath = orig.opener.join(orig.indexfile)
2976 2977 origdatapath = orig.opener.join(orig.datafile)
2977 2978 indexname = 'revlog.i'
2978 2979 dataname = 'revlog.d'
2979 2980
2980 2981 tmpdir = tempfile.mkdtemp(prefix='tmp-hgperf-')
2981 2982 try:
2982 2983 # copy the data file in a temporary directory
2983 2984 ui.debug('copying data in %s\n' % tmpdir)
2984 2985 destindexpath = os.path.join(tmpdir, 'revlog.i')
2985 2986 destdatapath = os.path.join(tmpdir, 'revlog.d')
2986 2987 shutil.copyfile(origindexpath, destindexpath)
2987 2988 shutil.copyfile(origdatapath, destdatapath)
2988 2989
2989 2990 # remove the data we want to add again
2990 2991 ui.debug('truncating data to be rewritten\n')
2991 2992 with open(destindexpath, 'ab') as index:
2992 2993 index.seek(0)
2993 2994 index.truncate(truncaterev * orig._io.size)
2994 2995 with open(destdatapath, 'ab') as data:
2995 2996 data.seek(0)
2996 2997 data.truncate(orig.start(truncaterev))
2997 2998
2998 2999 # instantiate a new revlog from the temporary copy
2999 3000 ui.debug('truncating adding to be rewritten\n')
3000 3001 vfs = vfsmod.vfs(tmpdir)
3001 3002 vfs.options = getattr(orig.opener, 'options', None)
3002 3003
3003 3004 dest = revlog.revlog(
3004 3005 vfs, indexfile=indexname, datafile=dataname, **revlogkwargs
3005 3006 )
3006 3007 if dest._inline:
3007 3008 raise error.Abort('not supporting inline revlog (yet)')
3008 3009 # make sure internals are initialized
3009 3010 dest.revision(len(dest) - 1)
3010 3011 yield dest
3011 3012 del dest, vfs
3012 3013 finally:
3013 3014 shutil.rmtree(tmpdir, True)
3014 3015
3015 3016
3016 3017 @command(
3017 3018 b'perfrevlogchunks',
3018 3019 revlogopts
3019 3020 + formatteropts
3020 3021 + [
3021 3022 (b'e', b'engines', b'', b'compression engines to use'),
3022 3023 (b's', b'startrev', 0, b'revision to start at'),
3023 3024 ],
3024 3025 b'-c|-m|FILE',
3025 3026 )
3026 3027 def perfrevlogchunks(ui, repo, file_=None, engines=None, startrev=0, **opts):
3027 3028 """Benchmark operations on revlog chunks.
3028 3029
3029 3030 Logically, each revlog is a collection of fulltext revisions. However,
3030 3031 stored within each revlog are "chunks" of possibly compressed data. This
3031 3032 data needs to be read and decompressed or compressed and written.
3032 3033
3033 3034 This command measures the time it takes to read+decompress and recompress
3034 3035 chunks in a revlog. It effectively isolates I/O and compression performance.
3035 3036 For measurements of higher-level operations like resolving revisions,
3036 3037 see ``perfrevlogrevisions`` and ``perfrevlogrevision``.
3037 3038 """
3038 3039 opts = _byteskwargs(opts)
3039 3040
3040 3041 rl = cmdutil.openrevlog(repo, b'perfrevlogchunks', file_, opts)
3041 3042
3042 3043 # _chunkraw was renamed to _getsegmentforrevs.
3043 3044 try:
3044 3045 segmentforrevs = rl._getsegmentforrevs
3045 3046 except AttributeError:
3046 3047 segmentforrevs = rl._chunkraw
3047 3048
3048 3049 # Verify engines argument.
3049 3050 if engines:
3050 engines = set(e.strip() for e in engines.split(b','))
3051 engines = {e.strip() for e in engines.split(b',')}
3051 3052 for engine in engines:
3052 3053 try:
3053 3054 util.compressionengines[engine]
3054 3055 except KeyError:
3055 3056 raise error.Abort(b'unknown compression engine: %s' % engine)
3056 3057 else:
3057 3058 engines = []
3058 3059 for e in util.compengines:
3059 3060 engine = util.compengines[e]
3060 3061 try:
3061 3062 if engine.available():
3062 3063 engine.revlogcompressor().compress(b'dummy')
3063 3064 engines.append(e)
3064 3065 except NotImplementedError:
3065 3066 pass
3066 3067
3067 3068 revs = list(rl.revs(startrev, len(rl) - 1))
3068 3069
3069 3070 def rlfh(rl):
3070 3071 if rl._inline:
3071 3072 return getsvfs(repo)(rl.indexfile)
3072 3073 else:
3073 3074 return getsvfs(repo)(rl.datafile)
3074 3075
3075 3076 def doread():
3076 3077 rl.clearcaches()
3077 3078 for rev in revs:
3078 3079 segmentforrevs(rev, rev)
3079 3080
3080 3081 def doreadcachedfh():
3081 3082 rl.clearcaches()
3082 3083 fh = rlfh(rl)
3083 3084 for rev in revs:
3084 3085 segmentforrevs(rev, rev, df=fh)
3085 3086
3086 3087 def doreadbatch():
3087 3088 rl.clearcaches()
3088 3089 segmentforrevs(revs[0], revs[-1])
3089 3090
3090 3091 def doreadbatchcachedfh():
3091 3092 rl.clearcaches()
3092 3093 fh = rlfh(rl)
3093 3094 segmentforrevs(revs[0], revs[-1], df=fh)
3094 3095
3095 3096 def dochunk():
3096 3097 rl.clearcaches()
3097 3098 fh = rlfh(rl)
3098 3099 for rev in revs:
3099 3100 rl._chunk(rev, df=fh)
3100 3101
3101 3102 chunks = [None]
3102 3103
3103 3104 def dochunkbatch():
3104 3105 rl.clearcaches()
3105 3106 fh = rlfh(rl)
3106 3107 # Save chunks as a side-effect.
3107 3108 chunks[0] = rl._chunks(revs, df=fh)
3108 3109
3109 3110 def docompress(compressor):
3110 3111 rl.clearcaches()
3111 3112
3112 3113 try:
3113 3114 # Swap in the requested compression engine.
3114 3115 oldcompressor = rl._compressor
3115 3116 rl._compressor = compressor
3116 3117 for chunk in chunks[0]:
3117 3118 rl.compress(chunk)
3118 3119 finally:
3119 3120 rl._compressor = oldcompressor
3120 3121
3121 3122 benches = [
3122 3123 (lambda: doread(), b'read'),
3123 3124 (lambda: doreadcachedfh(), b'read w/ reused fd'),
3124 3125 (lambda: doreadbatch(), b'read batch'),
3125 3126 (lambda: doreadbatchcachedfh(), b'read batch w/ reused fd'),
3126 3127 (lambda: dochunk(), b'chunk'),
3127 3128 (lambda: dochunkbatch(), b'chunk batch'),
3128 3129 ]
3129 3130
3130 3131 for engine in sorted(engines):
3131 3132 compressor = util.compengines[engine].revlogcompressor()
3132 3133 benches.append(
3133 3134 (
3134 3135 functools.partial(docompress, compressor),
3135 3136 b'compress w/ %s' % engine,
3136 3137 )
3137 3138 )
3138 3139
3139 3140 for fn, title in benches:
3140 3141 timer, fm = gettimer(ui, opts)
3141 3142 timer(fn, title=title)
3142 3143 fm.end()
3143 3144
3144 3145
3145 3146 @command(
3146 3147 b'perfrevlogrevision',
3147 3148 revlogopts
3148 3149 + formatteropts
3149 3150 + [(b'', b'cache', False, b'use caches instead of clearing')],
3150 3151 b'-c|-m|FILE REV',
3151 3152 )
3152 3153 def perfrevlogrevision(ui, repo, file_, rev=None, cache=None, **opts):
3153 3154 """Benchmark obtaining a revlog revision.
3154 3155
3155 3156 Obtaining a revlog revision consists of roughly the following steps:
3156 3157
3157 3158 1. Compute the delta chain
3158 3159 2. Slice the delta chain if applicable
3159 3160 3. Obtain the raw chunks for that delta chain
3160 3161 4. Decompress each raw chunk
3161 3162 5. Apply binary patches to obtain fulltext
3162 3163 6. Verify hash of fulltext
3163 3164
3164 3165 This command measures the time spent in each of these phases.
3165 3166 """
3166 3167 opts = _byteskwargs(opts)
3167 3168
3168 3169 if opts.get(b'changelog') or opts.get(b'manifest'):
3169 3170 file_, rev = None, file_
3170 3171 elif rev is None:
3171 3172 raise error.CommandError(b'perfrevlogrevision', b'invalid arguments')
3172 3173
3173 3174 r = cmdutil.openrevlog(repo, b'perfrevlogrevision', file_, opts)
3174 3175
3175 3176 # _chunkraw was renamed to _getsegmentforrevs.
3176 3177 try:
3177 3178 segmentforrevs = r._getsegmentforrevs
3178 3179 except AttributeError:
3179 3180 segmentforrevs = r._chunkraw
3180 3181
3181 3182 node = r.lookup(rev)
3182 3183 rev = r.rev(node)
3183 3184
3184 3185 def getrawchunks(data, chain):
3185 3186 start = r.start
3186 3187 length = r.length
3187 3188 inline = r._inline
3188 3189 iosize = r._io.size
3189 3190 buffer = util.buffer
3190 3191
3191 3192 chunks = []
3192 3193 ladd = chunks.append
3193 3194 for idx, item in enumerate(chain):
3194 3195 offset = start(item[0])
3195 3196 bits = data[idx]
3196 3197 for rev in item:
3197 3198 chunkstart = start(rev)
3198 3199 if inline:
3199 3200 chunkstart += (rev + 1) * iosize
3200 3201 chunklength = length(rev)
3201 3202 ladd(buffer(bits, chunkstart - offset, chunklength))
3202 3203
3203 3204 return chunks
3204 3205
3205 3206 def dodeltachain(rev):
3206 3207 if not cache:
3207 3208 r.clearcaches()
3208 3209 r._deltachain(rev)
3209 3210
3210 3211 def doread(chain):
3211 3212 if not cache:
3212 3213 r.clearcaches()
3213 3214 for item in slicedchain:
3214 3215 segmentforrevs(item[0], item[-1])
3215 3216
3216 3217 def doslice(r, chain, size):
3217 3218 for s in slicechunk(r, chain, targetsize=size):
3218 3219 pass
3219 3220
3220 3221 def dorawchunks(data, chain):
3221 3222 if not cache:
3222 3223 r.clearcaches()
3223 3224 getrawchunks(data, chain)
3224 3225
3225 3226 def dodecompress(chunks):
3226 3227 decomp = r.decompress
3227 3228 for chunk in chunks:
3228 3229 decomp(chunk)
3229 3230
3230 3231 def dopatch(text, bins):
3231 3232 if not cache:
3232 3233 r.clearcaches()
3233 3234 mdiff.patches(text, bins)
3234 3235
3235 3236 def dohash(text):
3236 3237 if not cache:
3237 3238 r.clearcaches()
3238 3239 r.checkhash(text, node, rev=rev)
3239 3240
3240 3241 def dorevision():
3241 3242 if not cache:
3242 3243 r.clearcaches()
3243 3244 r.revision(node)
3244 3245
3245 3246 try:
3246 3247 from mercurial.revlogutils.deltas import slicechunk
3247 3248 except ImportError:
3248 3249 slicechunk = getattr(revlog, '_slicechunk', None)
3249 3250
3250 3251 size = r.length(rev)
3251 3252 chain = r._deltachain(rev)[0]
3252 3253 if not getattr(r, '_withsparseread', False):
3253 3254 slicedchain = (chain,)
3254 3255 else:
3255 3256 slicedchain = tuple(slicechunk(r, chain, targetsize=size))
3256 3257 data = [segmentforrevs(seg[0], seg[-1])[1] for seg in slicedchain]
3257 3258 rawchunks = getrawchunks(data, slicedchain)
3258 3259 bins = r._chunks(chain)
3259 3260 text = bytes(bins[0])
3260 3261 bins = bins[1:]
3261 3262 text = mdiff.patches(text, bins)
3262 3263
3263 3264 benches = [
3264 3265 (lambda: dorevision(), b'full'),
3265 3266 (lambda: dodeltachain(rev), b'deltachain'),
3266 3267 (lambda: doread(chain), b'read'),
3267 3268 ]
3268 3269
3269 3270 if getattr(r, '_withsparseread', False):
3270 3271 slicing = (lambda: doslice(r, chain, size), b'slice-sparse-chain')
3271 3272 benches.append(slicing)
3272 3273
3273 3274 benches.extend(
3274 3275 [
3275 3276 (lambda: dorawchunks(data, slicedchain), b'rawchunks'),
3276 3277 (lambda: dodecompress(rawchunks), b'decompress'),
3277 3278 (lambda: dopatch(text, bins), b'patch'),
3278 3279 (lambda: dohash(text), b'hash'),
3279 3280 ]
3280 3281 )
3281 3282
3282 3283 timer, fm = gettimer(ui, opts)
3283 3284 for fn, title in benches:
3284 3285 timer(fn, title=title)
3285 3286 fm.end()
3286 3287
3287 3288
3288 3289 @command(
3289 3290 b'perfrevset',
3290 3291 [
3291 3292 (b'C', b'clear', False, b'clear volatile cache between each call.'),
3292 3293 (b'', b'contexts', False, b'obtain changectx for each revision'),
3293 3294 ]
3294 3295 + formatteropts,
3295 3296 b"REVSET",
3296 3297 )
3297 3298 def perfrevset(ui, repo, expr, clear=False, contexts=False, **opts):
3298 3299 """benchmark the execution time of a revset
3299 3300
3300 3301 Use the --clean option if need to evaluate the impact of build volatile
3301 3302 revisions set cache on the revset execution. Volatile cache hold filtered
3302 3303 and obsolete related cache."""
3303 3304 opts = _byteskwargs(opts)
3304 3305
3305 3306 timer, fm = gettimer(ui, opts)
3306 3307
3307 3308 def d():
3308 3309 if clear:
3309 3310 repo.invalidatevolatilesets()
3310 3311 if contexts:
3311 3312 for ctx in repo.set(expr):
3312 3313 pass
3313 3314 else:
3314 3315 for r in repo.revs(expr):
3315 3316 pass
3316 3317
3317 3318 timer(d)
3318 3319 fm.end()
3319 3320
3320 3321
3321 3322 @command(
3322 3323 b'perfvolatilesets',
3323 3324 [(b'', b'clear-obsstore', False, b'drop obsstore between each call.'),]
3324 3325 + formatteropts,
3325 3326 )
3326 3327 def perfvolatilesets(ui, repo, *names, **opts):
3327 3328 """benchmark the computation of various volatile set
3328 3329
3329 3330 Volatile set computes element related to filtering and obsolescence."""
3330 3331 opts = _byteskwargs(opts)
3331 3332 timer, fm = gettimer(ui, opts)
3332 3333 repo = repo.unfiltered()
3333 3334
3334 3335 def getobs(name):
3335 3336 def d():
3336 3337 repo.invalidatevolatilesets()
3337 3338 if opts[b'clear_obsstore']:
3338 3339 clearfilecache(repo, b'obsstore')
3339 3340 obsolete.getrevs(repo, name)
3340 3341
3341 3342 return d
3342 3343
3343 3344 allobs = sorted(obsolete.cachefuncs)
3344 3345 if names:
3345 3346 allobs = [n for n in allobs if n in names]
3346 3347
3347 3348 for name in allobs:
3348 3349 timer(getobs(name), title=name)
3349 3350
3350 3351 def getfiltered(name):
3351 3352 def d():
3352 3353 repo.invalidatevolatilesets()
3353 3354 if opts[b'clear_obsstore']:
3354 3355 clearfilecache(repo, b'obsstore')
3355 3356 repoview.filterrevs(repo, name)
3356 3357
3357 3358 return d
3358 3359
3359 3360 allfilter = sorted(repoview.filtertable)
3360 3361 if names:
3361 3362 allfilter = [n for n in allfilter if n in names]
3362 3363
3363 3364 for name in allfilter:
3364 3365 timer(getfiltered(name), title=name)
3365 3366 fm.end()
3366 3367
3367 3368
3368 3369 @command(
3369 3370 b'perfbranchmap',
3370 3371 [
3371 3372 (b'f', b'full', False, b'Includes build time of subset'),
3372 3373 (
3373 3374 b'',
3374 3375 b'clear-revbranch',
3375 3376 False,
3376 3377 b'purge the revbranch cache between computation',
3377 3378 ),
3378 3379 ]
3379 3380 + formatteropts,
3380 3381 )
3381 3382 def perfbranchmap(ui, repo, *filternames, **opts):
3382 3383 """benchmark the update of a branchmap
3383 3384
3384 3385 This benchmarks the full repo.branchmap() call with read and write disabled
3385 3386 """
3386 3387 opts = _byteskwargs(opts)
3387 3388 full = opts.get(b"full", False)
3388 3389 clear_revbranch = opts.get(b"clear_revbranch", False)
3389 3390 timer, fm = gettimer(ui, opts)
3390 3391
3391 3392 def getbranchmap(filtername):
3392 3393 """generate a benchmark function for the filtername"""
3393 3394 if filtername is None:
3394 3395 view = repo
3395 3396 else:
3396 3397 view = repo.filtered(filtername)
3397 3398 if util.safehasattr(view._branchcaches, '_per_filter'):
3398 3399 filtered = view._branchcaches._per_filter
3399 3400 else:
3400 3401 # older versions
3401 3402 filtered = view._branchcaches
3402 3403
3403 3404 def d():
3404 3405 if clear_revbranch:
3405 3406 repo.revbranchcache()._clear()
3406 3407 if full:
3407 3408 view._branchcaches.clear()
3408 3409 else:
3409 3410 filtered.pop(filtername, None)
3410 3411 view.branchmap()
3411 3412
3412 3413 return d
3413 3414
3414 3415 # add filter in smaller subset to bigger subset
3415 3416 possiblefilters = set(repoview.filtertable)
3416 3417 if filternames:
3417 3418 possiblefilters &= set(filternames)
3418 3419 subsettable = getbranchmapsubsettable()
3419 3420 allfilters = []
3420 3421 while possiblefilters:
3421 3422 for name in possiblefilters:
3422 3423 subset = subsettable.get(name)
3423 3424 if subset not in possiblefilters:
3424 3425 break
3425 3426 else:
3426 3427 assert False, b'subset cycle %s!' % possiblefilters
3427 3428 allfilters.append(name)
3428 3429 possiblefilters.remove(name)
3429 3430
3430 3431 # warm the cache
3431 3432 if not full:
3432 3433 for name in allfilters:
3433 3434 repo.filtered(name).branchmap()
3434 3435 if not filternames or b'unfiltered' in filternames:
3435 3436 # add unfiltered
3436 3437 allfilters.append(None)
3437 3438
3438 3439 if util.safehasattr(branchmap.branchcache, 'fromfile'):
3439 3440 branchcacheread = safeattrsetter(branchmap.branchcache, b'fromfile')
3440 3441 branchcacheread.set(classmethod(lambda *args: None))
3441 3442 else:
3442 3443 # older versions
3443 3444 branchcacheread = safeattrsetter(branchmap, b'read')
3444 3445 branchcacheread.set(lambda *args: None)
3445 3446 branchcachewrite = safeattrsetter(branchmap.branchcache, b'write')
3446 3447 branchcachewrite.set(lambda *args: None)
3447 3448 try:
3448 3449 for name in allfilters:
3449 3450 printname = name
3450 3451 if name is None:
3451 3452 printname = b'unfiltered'
3452 3453 timer(getbranchmap(name), title=str(printname))
3453 3454 finally:
3454 3455 branchcacheread.restore()
3455 3456 branchcachewrite.restore()
3456 3457 fm.end()
3457 3458
3458 3459
3459 3460 @command(
3460 3461 b'perfbranchmapupdate',
3461 3462 [
3462 3463 (b'', b'base', [], b'subset of revision to start from'),
3463 3464 (b'', b'target', [], b'subset of revision to end with'),
3464 3465 (b'', b'clear-caches', False, b'clear cache between each runs'),
3465 3466 ]
3466 3467 + formatteropts,
3467 3468 )
3468 3469 def perfbranchmapupdate(ui, repo, base=(), target=(), **opts):
3469 3470 """benchmark branchmap update from for <base> revs to <target> revs
3470 3471
3471 3472 If `--clear-caches` is passed, the following items will be reset before
3472 3473 each update:
3473 3474 * the changelog instance and associated indexes
3474 3475 * the rev-branch-cache instance
3475 3476
3476 3477 Examples:
3477 3478
3478 3479 # update for the one last revision
3479 3480 $ hg perfbranchmapupdate --base 'not tip' --target 'tip'
3480 3481
3481 3482 $ update for change coming with a new branch
3482 3483 $ hg perfbranchmapupdate --base 'stable' --target 'default'
3483 3484 """
3484 3485 from mercurial import branchmap
3485 3486 from mercurial import repoview
3486 3487
3487 3488 opts = _byteskwargs(opts)
3488 3489 timer, fm = gettimer(ui, opts)
3489 3490 clearcaches = opts[b'clear_caches']
3490 3491 unfi = repo.unfiltered()
3491 3492 x = [None] # used to pass data between closure
3492 3493
3493 3494 # we use a `list` here to avoid possible side effect from smartset
3494 3495 baserevs = list(scmutil.revrange(repo, base))
3495 3496 targetrevs = list(scmutil.revrange(repo, target))
3496 3497 if not baserevs:
3497 3498 raise error.Abort(b'no revisions selected for --base')
3498 3499 if not targetrevs:
3499 3500 raise error.Abort(b'no revisions selected for --target')
3500 3501
3501 3502 # make sure the target branchmap also contains the one in the base
3502 3503 targetrevs = list(set(baserevs) | set(targetrevs))
3503 3504 targetrevs.sort()
3504 3505
3505 3506 cl = repo.changelog
3506 3507 allbaserevs = list(cl.ancestors(baserevs, inclusive=True))
3507 3508 allbaserevs.sort()
3508 3509 alltargetrevs = frozenset(cl.ancestors(targetrevs, inclusive=True))
3509 3510
3510 3511 newrevs = list(alltargetrevs.difference(allbaserevs))
3511 3512 newrevs.sort()
3512 3513
3513 3514 allrevs = frozenset(unfi.changelog.revs())
3514 3515 basefilterrevs = frozenset(allrevs.difference(allbaserevs))
3515 3516 targetfilterrevs = frozenset(allrevs.difference(alltargetrevs))
3516 3517
3517 3518 def basefilter(repo, visibilityexceptions=None):
3518 3519 return basefilterrevs
3519 3520
3520 3521 def targetfilter(repo, visibilityexceptions=None):
3521 3522 return targetfilterrevs
3522 3523
3523 3524 msg = b'benchmark of branchmap with %d revisions with %d new ones\n'
3524 3525 ui.status(msg % (len(allbaserevs), len(newrevs)))
3525 3526 if targetfilterrevs:
3526 3527 msg = b'(%d revisions still filtered)\n'
3527 3528 ui.status(msg % len(targetfilterrevs))
3528 3529
3529 3530 try:
3530 3531 repoview.filtertable[b'__perf_branchmap_update_base'] = basefilter
3531 3532 repoview.filtertable[b'__perf_branchmap_update_target'] = targetfilter
3532 3533
3533 3534 baserepo = repo.filtered(b'__perf_branchmap_update_base')
3534 3535 targetrepo = repo.filtered(b'__perf_branchmap_update_target')
3535 3536
3536 3537 # try to find an existing branchmap to reuse
3537 3538 subsettable = getbranchmapsubsettable()
3538 3539 candidatefilter = subsettable.get(None)
3539 3540 while candidatefilter is not None:
3540 3541 candidatebm = repo.filtered(candidatefilter).branchmap()
3541 3542 if candidatebm.validfor(baserepo):
3542 3543 filtered = repoview.filterrevs(repo, candidatefilter)
3543 3544 missing = [r for r in allbaserevs if r in filtered]
3544 3545 base = candidatebm.copy()
3545 3546 base.update(baserepo, missing)
3546 3547 break
3547 3548 candidatefilter = subsettable.get(candidatefilter)
3548 3549 else:
3549 3550 # no suitable subset where found
3550 3551 base = branchmap.branchcache()
3551 3552 base.update(baserepo, allbaserevs)
3552 3553
3553 3554 def setup():
3554 3555 x[0] = base.copy()
3555 3556 if clearcaches:
3556 3557 unfi._revbranchcache = None
3557 3558 clearchangelog(repo)
3558 3559
3559 3560 def bench():
3560 3561 x[0].update(targetrepo, newrevs)
3561 3562
3562 3563 timer(bench, setup=setup)
3563 3564 fm.end()
3564 3565 finally:
3565 3566 repoview.filtertable.pop(b'__perf_branchmap_update_base', None)
3566 3567 repoview.filtertable.pop(b'__perf_branchmap_update_target', None)
3567 3568
3568 3569
3569 3570 @command(
3570 3571 b'perfbranchmapload',
3571 3572 [
3572 3573 (b'f', b'filter', b'', b'Specify repoview filter'),
3573 3574 (b'', b'list', False, b'List brachmap filter caches'),
3574 3575 (b'', b'clear-revlogs', False, b'refresh changelog and manifest'),
3575 3576 ]
3576 3577 + formatteropts,
3577 3578 )
3578 3579 def perfbranchmapload(ui, repo, filter=b'', list=False, **opts):
3579 3580 """benchmark reading the branchmap"""
3580 3581 opts = _byteskwargs(opts)
3581 3582 clearrevlogs = opts[b'clear_revlogs']
3582 3583
3583 3584 if list:
3584 3585 for name, kind, st in repo.cachevfs.readdir(stat=True):
3585 3586 if name.startswith(b'branch2'):
3586 3587 filtername = name.partition(b'-')[2] or b'unfiltered'
3587 3588 ui.status(
3588 3589 b'%s - %s\n' % (filtername, util.bytecount(st.st_size))
3589 3590 )
3590 3591 return
3591 3592 if not filter:
3592 3593 filter = None
3593 3594 subsettable = getbranchmapsubsettable()
3594 3595 if filter is None:
3595 3596 repo = repo.unfiltered()
3596 3597 else:
3597 3598 repo = repoview.repoview(repo, filter)
3598 3599
3599 3600 repo.branchmap() # make sure we have a relevant, up to date branchmap
3600 3601
3601 3602 try:
3602 3603 fromfile = branchmap.branchcache.fromfile
3603 3604 except AttributeError:
3604 3605 # older versions
3605 3606 fromfile = branchmap.read
3606 3607
3607 3608 currentfilter = filter
3608 3609 # try once without timer, the filter may not be cached
3609 3610 while fromfile(repo) is None:
3610 3611 currentfilter = subsettable.get(currentfilter)
3611 3612 if currentfilter is None:
3612 3613 raise error.Abort(
3613 3614 b'No branchmap cached for %s repo' % (filter or b'unfiltered')
3614 3615 )
3615 3616 repo = repo.filtered(currentfilter)
3616 3617 timer, fm = gettimer(ui, opts)
3617 3618
3618 3619 def setup():
3619 3620 if clearrevlogs:
3620 3621 clearchangelog(repo)
3621 3622
3622 3623 def bench():
3623 3624 fromfile(repo)
3624 3625
3625 3626 timer(bench, setup=setup)
3626 3627 fm.end()
3627 3628
3628 3629
3629 3630 @command(b'perfloadmarkers')
3630 3631 def perfloadmarkers(ui, repo):
3631 3632 """benchmark the time to parse the on-disk markers for a repo
3632 3633
3633 3634 Result is the number of markers in the repo."""
3634 3635 timer, fm = gettimer(ui)
3635 3636 svfs = getsvfs(repo)
3636 3637 timer(lambda: len(obsolete.obsstore(svfs)))
3637 3638 fm.end()
3638 3639
3639 3640
3640 3641 @command(
3641 3642 b'perflrucachedict',
3642 3643 formatteropts
3643 3644 + [
3644 3645 (b'', b'costlimit', 0, b'maximum total cost of items in cache'),
3645 3646 (b'', b'mincost', 0, b'smallest cost of items in cache'),
3646 3647 (b'', b'maxcost', 100, b'maximum cost of items in cache'),
3647 3648 (b'', b'size', 4, b'size of cache'),
3648 3649 (b'', b'gets', 10000, b'number of key lookups'),
3649 3650 (b'', b'sets', 10000, b'number of key sets'),
3650 3651 (b'', b'mixed', 10000, b'number of mixed mode operations'),
3651 3652 (
3652 3653 b'',
3653 3654 b'mixedgetfreq',
3654 3655 50,
3655 3656 b'frequency of get vs set ops in mixed mode',
3656 3657 ),
3657 3658 ],
3658 3659 norepo=True,
3659 3660 )
3660 3661 def perflrucache(
3661 3662 ui,
3662 3663 mincost=0,
3663 3664 maxcost=100,
3664 3665 costlimit=0,
3665 3666 size=4,
3666 3667 gets=10000,
3667 3668 sets=10000,
3668 3669 mixed=10000,
3669 3670 mixedgetfreq=50,
3670 3671 **opts
3671 3672 ):
3672 3673 opts = _byteskwargs(opts)
3673 3674
3674 3675 def doinit():
3675 3676 for i in _xrange(10000):
3676 3677 util.lrucachedict(size)
3677 3678
3678 3679 costrange = list(range(mincost, maxcost + 1))
3679 3680
3680 3681 values = []
3681 3682 for i in _xrange(size):
3682 3683 values.append(random.randint(0, _maxint))
3683 3684
3684 3685 # Get mode fills the cache and tests raw lookup performance with no
3685 3686 # eviction.
3686 3687 getseq = []
3687 3688 for i in _xrange(gets):
3688 3689 getseq.append(random.choice(values))
3689 3690
3690 3691 def dogets():
3691 3692 d = util.lrucachedict(size)
3692 3693 for v in values:
3693 3694 d[v] = v
3694 3695 for key in getseq:
3695 3696 value = d[key]
3696 3697 value # silence pyflakes warning
3697 3698
3698 3699 def dogetscost():
3699 3700 d = util.lrucachedict(size, maxcost=costlimit)
3700 3701 for i, v in enumerate(values):
3701 3702 d.insert(v, v, cost=costs[i])
3702 3703 for key in getseq:
3703 3704 try:
3704 3705 value = d[key]
3705 3706 value # silence pyflakes warning
3706 3707 except KeyError:
3707 3708 pass
3708 3709
3709 3710 # Set mode tests insertion speed with cache eviction.
3710 3711 setseq = []
3711 3712 costs = []
3712 3713 for i in _xrange(sets):
3713 3714 setseq.append(random.randint(0, _maxint))
3714 3715 costs.append(random.choice(costrange))
3715 3716
3716 3717 def doinserts():
3717 3718 d = util.lrucachedict(size)
3718 3719 for v in setseq:
3719 3720 d.insert(v, v)
3720 3721
3721 3722 def doinsertscost():
3722 3723 d = util.lrucachedict(size, maxcost=costlimit)
3723 3724 for i, v in enumerate(setseq):
3724 3725 d.insert(v, v, cost=costs[i])
3725 3726
3726 3727 def dosets():
3727 3728 d = util.lrucachedict(size)
3728 3729 for v in setseq:
3729 3730 d[v] = v
3730 3731
3731 3732 # Mixed mode randomly performs gets and sets with eviction.
3732 3733 mixedops = []
3733 3734 for i in _xrange(mixed):
3734 3735 r = random.randint(0, 100)
3735 3736 if r < mixedgetfreq:
3736 3737 op = 0
3737 3738 else:
3738 3739 op = 1
3739 3740
3740 3741 mixedops.append(
3741 3742 (op, random.randint(0, size * 2), random.choice(costrange))
3742 3743 )
3743 3744
3744 3745 def domixed():
3745 3746 d = util.lrucachedict(size)
3746 3747
3747 3748 for op, v, cost in mixedops:
3748 3749 if op == 0:
3749 3750 try:
3750 3751 d[v]
3751 3752 except KeyError:
3752 3753 pass
3753 3754 else:
3754 3755 d[v] = v
3755 3756
3756 3757 def domixedcost():
3757 3758 d = util.lrucachedict(size, maxcost=costlimit)
3758 3759
3759 3760 for op, v, cost in mixedops:
3760 3761 if op == 0:
3761 3762 try:
3762 3763 d[v]
3763 3764 except KeyError:
3764 3765 pass
3765 3766 else:
3766 3767 d.insert(v, v, cost=cost)
3767 3768
3768 3769 benches = [
3769 3770 (doinit, b'init'),
3770 3771 ]
3771 3772
3772 3773 if costlimit:
3773 3774 benches.extend(
3774 3775 [
3775 3776 (dogetscost, b'gets w/ cost limit'),
3776 3777 (doinsertscost, b'inserts w/ cost limit'),
3777 3778 (domixedcost, b'mixed w/ cost limit'),
3778 3779 ]
3779 3780 )
3780 3781 else:
3781 3782 benches.extend(
3782 3783 [
3783 3784 (dogets, b'gets'),
3784 3785 (doinserts, b'inserts'),
3785 3786 (dosets, b'sets'),
3786 3787 (domixed, b'mixed'),
3787 3788 ]
3788 3789 )
3789 3790
3790 3791 for fn, title in benches:
3791 3792 timer, fm = gettimer(ui, opts)
3792 3793 timer(fn, title=title)
3793 3794 fm.end()
3794 3795
3795 3796
3796 3797 @command(b'perfwrite', formatteropts)
3797 3798 def perfwrite(ui, repo, **opts):
3798 3799 """microbenchmark ui.write
3799 3800 """
3800 3801 opts = _byteskwargs(opts)
3801 3802
3802 3803 timer, fm = gettimer(ui, opts)
3803 3804
3804 3805 def write():
3805 3806 for i in range(100000):
3806 3807 ui.writenoi18n(b'Testing write performance\n')
3807 3808
3808 3809 timer(write)
3809 3810 fm.end()
3810 3811
3811 3812
3812 3813 def uisetup(ui):
3813 3814 if util.safehasattr(cmdutil, b'openrevlog') and not util.safehasattr(
3814 3815 commands, b'debugrevlogopts'
3815 3816 ):
3816 3817 # for "historical portability":
3817 3818 # In this case, Mercurial should be 1.9 (or a79fea6b3e77) -
3818 3819 # 3.7 (or 5606f7d0d063). Therefore, '--dir' option for
3819 3820 # openrevlog() should cause failure, because it has been
3820 3821 # available since 3.5 (or 49c583ca48c4).
3821 3822 def openrevlog(orig, repo, cmd, file_, opts):
3822 3823 if opts.get(b'dir') and not util.safehasattr(repo, b'dirlog'):
3823 3824 raise error.Abort(
3824 3825 b"This version doesn't support --dir option",
3825 3826 hint=b"use 3.5 or later",
3826 3827 )
3827 3828 return orig(repo, cmd, file_, opts)
3828 3829
3829 3830 extensions.wrapfunction(cmdutil, b'openrevlog', openrevlog)
3830 3831
3831 3832
3832 3833 @command(
3833 3834 b'perfprogress',
3834 3835 formatteropts
3835 3836 + [
3836 3837 (b'', b'topic', b'topic', b'topic for progress messages'),
3837 3838 (b'c', b'total', 1000000, b'total value we are progressing to'),
3838 3839 ],
3839 3840 norepo=True,
3840 3841 )
3841 3842 def perfprogress(ui, topic=None, total=None, **opts):
3842 3843 """printing of progress bars"""
3843 3844 opts = _byteskwargs(opts)
3844 3845
3845 3846 timer, fm = gettimer(ui, opts)
3846 3847
3847 3848 def doprogress():
3848 3849 with ui.makeprogress(topic, total=total) as progress:
3849 3850 for i in _xrange(total):
3850 3851 progress.increment()
3851 3852
3852 3853 timer(doprogress)
3853 3854 fm.end()
@@ -1,225 +1,228 b''
1 1 # Copyright (c) 2016-present, Gregory Szorc
2 2 # All rights reserved.
3 3 #
4 4 # This software may be modified and distributed under the terms
5 5 # of the BSD license. See the LICENSE file for details.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import cffi
10 10 import distutils.ccompiler
11 11 import os
12 12 import re
13 13 import subprocess
14 14 import tempfile
15 15
16 16
17 17 HERE = os.path.abspath(os.path.dirname(__file__))
18 18
19 19 SOURCES = [
20 20 "zstd/%s" % p
21 21 for p in (
22 22 "common/debug.c",
23 23 "common/entropy_common.c",
24 24 "common/error_private.c",
25 25 "common/fse_decompress.c",
26 26 "common/pool.c",
27 27 "common/threading.c",
28 28 "common/xxhash.c",
29 29 "common/zstd_common.c",
30 30 "compress/fse_compress.c",
31 31 "compress/hist.c",
32 32 "compress/huf_compress.c",
33 33 "compress/zstd_compress.c",
34 34 "compress/zstd_compress_literals.c",
35 35 "compress/zstd_compress_sequences.c",
36 36 "compress/zstd_double_fast.c",
37 37 "compress/zstd_fast.c",
38 38 "compress/zstd_lazy.c",
39 39 "compress/zstd_ldm.c",
40 40 "compress/zstd_opt.c",
41 41 "compress/zstdmt_compress.c",
42 42 "decompress/huf_decompress.c",
43 43 "decompress/zstd_ddict.c",
44 44 "decompress/zstd_decompress.c",
45 45 "decompress/zstd_decompress_block.c",
46 46 "dictBuilder/cover.c",
47 47 "dictBuilder/fastcover.c",
48 48 "dictBuilder/divsufsort.c",
49 49 "dictBuilder/zdict.c",
50 50 )
51 51 ]
52 52
53 53 # Headers whose preprocessed output will be fed into cdef().
54 54 HEADERS = [
55 os.path.join(HERE, "zstd", *p) for p in (("zstd.h",), ("dictBuilder", "zdict.h"),)
55 os.path.join(HERE, "zstd", *p)
56 for p in (("zstd.h",), ("dictBuilder", "zdict.h"),)
56 57 ]
57 58
58 59 INCLUDE_DIRS = [
59 60 os.path.join(HERE, d)
60 61 for d in (
61 62 "zstd",
62 63 "zstd/common",
63 64 "zstd/compress",
64 65 "zstd/decompress",
65 66 "zstd/dictBuilder",
66 67 )
67 68 ]
68 69
69 70 # cffi can't parse some of the primitives in zstd.h. So we invoke the
70 71 # preprocessor and feed its output into cffi.
71 72 compiler = distutils.ccompiler.new_compiler()
72 73
73 74 # Needed for MSVC.
74 75 if hasattr(compiler, "initialize"):
75 76 compiler.initialize()
76 77
77 78 # Distutils doesn't set compiler.preprocessor, so invoke the preprocessor
78 79 # manually.
79 80 if compiler.compiler_type == "unix":
80 81 args = list(compiler.executables["compiler"])
81 82 args.extend(
82 83 ["-E", "-DZSTD_STATIC_LINKING_ONLY", "-DZDICT_STATIC_LINKING_ONLY",]
83 84 )
84 85 elif compiler.compiler_type == "msvc":
85 86 args = [compiler.cc]
86 87 args.extend(
87 88 ["/EP", "/DZSTD_STATIC_LINKING_ONLY", "/DZDICT_STATIC_LINKING_ONLY",]
88 89 )
89 90 else:
90 91 raise Exception("unsupported compiler type: %s" % compiler.compiler_type)
91 92
92 93
93 94 def preprocess(path):
94 95 with open(path, "rb") as fh:
95 96 lines = []
96 97 it = iter(fh)
97 98
98 99 for l in it:
99 100 # zstd.h includes <stddef.h>, which is also included by cffi's
100 101 # boilerplate. This can lead to duplicate declarations. So we strip
101 102 # this include from the preprocessor invocation.
102 103 #
103 104 # The same things happens for including zstd.h, so give it the same
104 105 # treatment.
105 106 #
106 107 # We define ZSTD_STATIC_LINKING_ONLY, which is redundant with the inline
107 108 # #define in zstdmt_compress.h and results in a compiler warning. So drop
108 109 # the inline #define.
109 110 if l.startswith(
110 111 (
111 112 b"#include <stddef.h>",
112 113 b'#include "zstd.h"',
113 114 b"#define ZSTD_STATIC_LINKING_ONLY",
114 115 )
115 116 ):
116 117 continue
117 118
118 119 # The preprocessor environment on Windows doesn't define include
119 120 # paths, so the #include of limits.h fails. We work around this
120 121 # by removing that import and defining INT_MAX ourselves. This is
121 122 # a bit hacky. But it gets the job done.
122 123 # TODO make limits.h work on Windows so we ensure INT_MAX is
123 124 # correct.
124 125 if l.startswith(b"#include <limits.h>"):
125 126 l = b"#define INT_MAX 2147483647\n"
126 127
127 128 # ZSTDLIB_API may not be defined if we dropped zstd.h. It isn't
128 129 # important so just filter it out.
129 130 if l.startswith(b"ZSTDLIB_API"):
130 131 l = l[len(b"ZSTDLIB_API ") :]
131 132
132 133 lines.append(l)
133 134
134 135 fd, input_file = tempfile.mkstemp(suffix=".h")
135 136 os.write(fd, b"".join(lines))
136 137 os.close(fd)
137 138
138 139 try:
139 140 env = dict(os.environ)
140 141 if getattr(compiler, "_paths", None):
141 142 env["PATH"] = compiler._paths
142 process = subprocess.Popen(args + [input_file], stdout=subprocess.PIPE, env=env)
143 process = subprocess.Popen(
144 args + [input_file], stdout=subprocess.PIPE, env=env
145 )
143 146 output = process.communicate()[0]
144 147 ret = process.poll()
145 148 if ret:
146 149 raise Exception("preprocessor exited with error")
147 150
148 151 return output
149 152 finally:
150 153 os.unlink(input_file)
151 154
152 155
153 156 def normalize_output(output):
154 157 lines = []
155 158 for line in output.splitlines():
156 159 # CFFI's parser doesn't like __attribute__ on UNIX compilers.
157 160 if line.startswith(b'__attribute__ ((visibility ("default"))) '):
158 161 line = line[len(b'__attribute__ ((visibility ("default"))) ') :]
159 162
160 163 if line.startswith(b"__attribute__((deprecated("):
161 164 continue
162 165 elif b"__declspec(deprecated(" in line:
163 166 continue
164 167
165 168 lines.append(line)
166 169
167 170 return b"\n".join(lines)
168 171
169 172
170 173 ffi = cffi.FFI()
171 174 # zstd.h uses a possible undefined MIN(). Define it until
172 175 # https://github.com/facebook/zstd/issues/976 is fixed.
173 176 # *_DISABLE_DEPRECATE_WARNINGS prevents the compiler from emitting a warning
174 177 # when cffi uses the function. Since we statically link against zstd, even
175 178 # if we use the deprecated functions it shouldn't be a huge problem.
176 179 ffi.set_source(
177 180 "_zstd_cffi",
178 181 """
179 182 #define MIN(a,b) ((a)<(b) ? (a) : (b))
180 183 #define ZSTD_STATIC_LINKING_ONLY
181 184 #include <zstd.h>
182 185 #define ZDICT_STATIC_LINKING_ONLY
183 186 #define ZDICT_DISABLE_DEPRECATE_WARNINGS
184 187 #include <zdict.h>
185 188 """,
186 189 sources=SOURCES,
187 190 include_dirs=INCLUDE_DIRS,
188 191 extra_compile_args=["-DZSTD_MULTITHREAD"],
189 192 )
190 193
191 194 DEFINE = re.compile(b"^\\#define ([a-zA-Z0-9_]+) ")
192 195
193 196 sources = []
194 197
195 198 # Feed normalized preprocessor output for headers into the cdef parser.
196 199 for header in HEADERS:
197 200 preprocessed = preprocess(header)
198 201 sources.append(normalize_output(preprocessed))
199 202
200 203 # #define's are effectively erased as part of going through preprocessor.
201 204 # So perform a manual pass to re-add those to the cdef source.
202 205 with open(header, "rb") as fh:
203 206 for line in fh:
204 207 line = line.strip()
205 208 m = DEFINE.match(line)
206 209 if not m:
207 210 continue
208 211
209 212 if m.group(1) == b"ZSTD_STATIC_LINKING_ONLY":
210 213 continue
211 214
212 215 # The parser doesn't like some constants with complex values.
213 216 if m.group(1) in (b"ZSTD_LIB_VERSION", b"ZSTD_VERSION_STRING"):
214 217 continue
215 218
216 219 # The ... is magic syntax by the cdef parser to resolve the
217 220 # value at compile time.
218 221 sources.append(m.group(0) + b" ...")
219 222
220 223 cdeflines = b"\n".join(sources).splitlines()
221 224 cdeflines = [l for l in cdeflines if l.strip()]
222 225 ffi.cdef(b"\n".join(cdeflines).decode("latin1"))
223 226
224 227 if __name__ == "__main__":
225 228 ffi.compile()
@@ -1,118 +1,120 b''
1 1 #!/usr/bin/env python
2 2 # Copyright (c) 2016-present, Gregory Szorc
3 3 # All rights reserved.
4 4 #
5 5 # This software may be modified and distributed under the terms
6 6 # of the BSD license. See the LICENSE file for details.
7 7
8 8 from __future__ import print_function
9 9
10 10 from distutils.version import LooseVersion
11 11 import os
12 12 import sys
13 13 from setuptools import setup
14 14
15 15 # Need change in 1.10 for ffi.from_buffer() to handle all buffer types
16 16 # (like memoryview).
17 17 # Need feature in 1.11 for ffi.gc() to declare size of objects so we avoid
18 18 # garbage collection pitfalls.
19 19 MINIMUM_CFFI_VERSION = "1.11"
20 20
21 21 try:
22 22 import cffi
23 23
24 24 # PyPy (and possibly other distros) have CFFI distributed as part of
25 25 # them. The install_requires for CFFI below won't work. We need to sniff
26 26 # out the CFFI version here and reject CFFI if it is too old.
27 27 cffi_version = LooseVersion(cffi.__version__)
28 28 if cffi_version < LooseVersion(MINIMUM_CFFI_VERSION):
29 29 print(
30 30 "CFFI 1.11 or newer required (%s found); "
31 31 "not building CFFI backend" % cffi_version,
32 32 file=sys.stderr,
33 33 )
34 34 cffi = None
35 35
36 36 except ImportError:
37 37 cffi = None
38 38
39 39 import setup_zstd
40 40
41 41 SUPPORT_LEGACY = False
42 42 SYSTEM_ZSTD = False
43 43 WARNINGS_AS_ERRORS = False
44 44
45 45 if os.environ.get("ZSTD_WARNINGS_AS_ERRORS", ""):
46 46 WARNINGS_AS_ERRORS = True
47 47
48 48 if "--legacy" in sys.argv:
49 49 SUPPORT_LEGACY = True
50 50 sys.argv.remove("--legacy")
51 51
52 52 if "--system-zstd" in sys.argv:
53 53 SYSTEM_ZSTD = True
54 54 sys.argv.remove("--system-zstd")
55 55
56 56 if "--warnings-as-errors" in sys.argv:
57 57 WARNINGS_AS_ERRORS = True
58 58 sys.argv.remove("--warning-as-errors")
59 59
60 60 # Code for obtaining the Extension instance is in its own module to
61 61 # facilitate reuse in other projects.
62 62 extensions = [
63 63 setup_zstd.get_c_extension(
64 64 name="zstd",
65 65 support_legacy=SUPPORT_LEGACY,
66 66 system_zstd=SYSTEM_ZSTD,
67 67 warnings_as_errors=WARNINGS_AS_ERRORS,
68 68 ),
69 69 ]
70 70
71 71 install_requires = []
72 72
73 73 if cffi:
74 74 import make_cffi
75 75
76 76 extensions.append(make_cffi.ffi.distutils_extension())
77 77 install_requires.append("cffi>=%s" % MINIMUM_CFFI_VERSION)
78 78
79 79 version = None
80 80
81 81 with open("c-ext/python-zstandard.h", "r") as fh:
82 82 for line in fh:
83 83 if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"):
84 84 continue
85 85
86 86 version = line.split()[2][1:-1]
87 87 break
88 88
89 89 if not version:
90 raise Exception("could not resolve package version; " "this should never happen")
90 raise Exception(
91 "could not resolve package version; " "this should never happen"
92 )
91 93
92 94 setup(
93 95 name="zstandard",
94 96 version=version,
95 97 description="Zstandard bindings for Python",
96 98 long_description=open("README.rst", "r").read(),
97 99 url="https://github.com/indygreg/python-zstandard",
98 100 author="Gregory Szorc",
99 101 author_email="gregory.szorc@gmail.com",
100 102 license="BSD",
101 103 classifiers=[
102 104 "Development Status :: 4 - Beta",
103 105 "Intended Audience :: Developers",
104 106 "License :: OSI Approved :: BSD License",
105 107 "Programming Language :: C",
106 108 "Programming Language :: Python :: 2.7",
107 109 "Programming Language :: Python :: 3.5",
108 110 "Programming Language :: Python :: 3.6",
109 111 "Programming Language :: Python :: 3.7",
110 112 "Programming Language :: Python :: 3.8",
111 113 ],
112 114 keywords="zstandard zstd compression",
113 115 packages=["zstandard"],
114 116 ext_modules=extensions,
115 117 test_suite="tests",
116 118 install_requires=install_requires,
117 119 tests_require=["hypothesis"],
118 120 )
@@ -1,206 +1,210 b''
1 1 # Copyright (c) 2016-present, Gregory Szorc
2 2 # All rights reserved.
3 3 #
4 4 # This software may be modified and distributed under the terms
5 5 # of the BSD license. See the LICENSE file for details.
6 6
7 7 import distutils.ccompiler
8 8 import os
9 9
10 10 from distutils.extension import Extension
11 11
12 12
13 13 zstd_sources = [
14 14 "zstd/%s" % p
15 15 for p in (
16 16 "common/debug.c",
17 17 "common/entropy_common.c",
18 18 "common/error_private.c",
19 19 "common/fse_decompress.c",
20 20 "common/pool.c",
21 21 "common/threading.c",
22 22 "common/xxhash.c",
23 23 "common/zstd_common.c",
24 24 "compress/fse_compress.c",
25 25 "compress/hist.c",
26 26 "compress/huf_compress.c",
27 27 "compress/zstd_compress_literals.c",
28 28 "compress/zstd_compress_sequences.c",
29 29 "compress/zstd_compress.c",
30 30 "compress/zstd_double_fast.c",
31 31 "compress/zstd_fast.c",
32 32 "compress/zstd_lazy.c",
33 33 "compress/zstd_ldm.c",
34 34 "compress/zstd_opt.c",
35 35 "compress/zstdmt_compress.c",
36 36 "decompress/huf_decompress.c",
37 37 "decompress/zstd_ddict.c",
38 38 "decompress/zstd_decompress.c",
39 39 "decompress/zstd_decompress_block.c",
40 40 "dictBuilder/cover.c",
41 41 "dictBuilder/divsufsort.c",
42 42 "dictBuilder/fastcover.c",
43 43 "dictBuilder/zdict.c",
44 44 )
45 45 ]
46 46
47 47 zstd_sources_legacy = [
48 48 "zstd/%s" % p
49 49 for p in (
50 50 "deprecated/zbuff_common.c",
51 51 "deprecated/zbuff_compress.c",
52 52 "deprecated/zbuff_decompress.c",
53 53 "legacy/zstd_v01.c",
54 54 "legacy/zstd_v02.c",
55 55 "legacy/zstd_v03.c",
56 56 "legacy/zstd_v04.c",
57 57 "legacy/zstd_v05.c",
58 58 "legacy/zstd_v06.c",
59 59 "legacy/zstd_v07.c",
60 60 )
61 61 ]
62 62
63 63 zstd_includes = [
64 64 "zstd",
65 65 "zstd/common",
66 66 "zstd/compress",
67 67 "zstd/decompress",
68 68 "zstd/dictBuilder",
69 69 ]
70 70
71 71 zstd_includes_legacy = [
72 72 "zstd/deprecated",
73 73 "zstd/legacy",
74 74 ]
75 75
76 76 ext_includes = [
77 77 "c-ext",
78 78 "zstd/common",
79 79 ]
80 80
81 81 ext_sources = [
82 82 "zstd/common/error_private.c",
83 83 "zstd/common/pool.c",
84 84 "zstd/common/threading.c",
85 85 "zstd/common/zstd_common.c",
86 86 "zstd.c",
87 87 "c-ext/bufferutil.c",
88 88 "c-ext/compressiondict.c",
89 89 "c-ext/compressobj.c",
90 90 "c-ext/compressor.c",
91 91 "c-ext/compressoriterator.c",
92 92 "c-ext/compressionchunker.c",
93 93 "c-ext/compressionparams.c",
94 94 "c-ext/compressionreader.c",
95 95 "c-ext/compressionwriter.c",
96 96 "c-ext/constants.c",
97 97 "c-ext/decompressobj.c",
98 98 "c-ext/decompressor.c",
99 99 "c-ext/decompressoriterator.c",
100 100 "c-ext/decompressionreader.c",
101 101 "c-ext/decompressionwriter.c",
102 102 "c-ext/frameparams.c",
103 103 ]
104 104
105 105 zstd_depends = [
106 106 "c-ext/python-zstandard.h",
107 107 ]
108 108
109 109
110 110 def get_c_extension(
111 111 support_legacy=False,
112 112 system_zstd=False,
113 113 name="zstd",
114 114 warnings_as_errors=False,
115 115 root=None,
116 116 ):
117 117 """Obtain a distutils.extension.Extension for the C extension.
118 118
119 119 ``support_legacy`` controls whether to compile in legacy zstd format support.
120 120
121 121 ``system_zstd`` controls whether to compile against the system zstd library.
122 122 For this to work, the system zstd library and headers must match what
123 123 python-zstandard is coded against exactly.
124 124
125 125 ``name`` is the module name of the C extension to produce.
126 126
127 127 ``warnings_as_errors`` controls whether compiler warnings are turned into
128 128 compiler errors.
129 129
130 130 ``root`` defines a root path that source should be computed as relative
131 131 to. This should be the directory with the main ``setup.py`` that is
132 132 being invoked. If not defined, paths will be relative to this file.
133 133 """
134 134 actual_root = os.path.abspath(os.path.dirname(__file__))
135 135 root = root or actual_root
136 136
137 137 sources = set([os.path.join(actual_root, p) for p in ext_sources])
138 138 if not system_zstd:
139 139 sources.update([os.path.join(actual_root, p) for p in zstd_sources])
140 140 if support_legacy:
141 sources.update([os.path.join(actual_root, p) for p in zstd_sources_legacy])
141 sources.update(
142 [os.path.join(actual_root, p) for p in zstd_sources_legacy]
143 )
142 144 sources = list(sources)
143 145
144 146 include_dirs = set([os.path.join(actual_root, d) for d in ext_includes])
145 147 if not system_zstd:
146 include_dirs.update([os.path.join(actual_root, d) for d in zstd_includes])
148 include_dirs.update(
149 [os.path.join(actual_root, d) for d in zstd_includes]
150 )
147 151 if support_legacy:
148 152 include_dirs.update(
149 153 [os.path.join(actual_root, d) for d in zstd_includes_legacy]
150 154 )
151 155 include_dirs = list(include_dirs)
152 156
153 157 depends = [os.path.join(actual_root, p) for p in zstd_depends]
154 158
155 159 compiler = distutils.ccompiler.new_compiler()
156 160
157 161 # Needed for MSVC.
158 162 if hasattr(compiler, "initialize"):
159 163 compiler.initialize()
160 164
161 165 if compiler.compiler_type == "unix":
162 166 compiler_type = "unix"
163 167 elif compiler.compiler_type == "msvc":
164 168 compiler_type = "msvc"
165 169 elif compiler.compiler_type == "mingw32":
166 170 compiler_type = "mingw32"
167 171 else:
168 172 raise Exception("unhandled compiler type: %s" % compiler.compiler_type)
169 173
170 174 extra_args = ["-DZSTD_MULTITHREAD"]
171 175
172 176 if not system_zstd:
173 177 extra_args.append("-DZSTDLIB_VISIBILITY=")
174 178 extra_args.append("-DZDICTLIB_VISIBILITY=")
175 179 extra_args.append("-DZSTDERRORLIB_VISIBILITY=")
176 180
177 181 if compiler_type == "unix":
178 182 extra_args.append("-fvisibility=hidden")
179 183
180 184 if not system_zstd and support_legacy:
181 185 extra_args.append("-DZSTD_LEGACY_SUPPORT=1")
182 186
183 187 if warnings_as_errors:
184 188 if compiler_type in ("unix", "mingw32"):
185 189 extra_args.append("-Werror")
186 190 elif compiler_type == "msvc":
187 191 extra_args.append("/WX")
188 192 else:
189 193 assert False
190 194
191 195 libraries = ["zstd"] if system_zstd else []
192 196
193 197 # Python 3.7 doesn't like absolute paths. So normalize to relative.
194 198 sources = [os.path.relpath(p, root) for p in sources]
195 199 include_dirs = [os.path.relpath(p, root) for p in include_dirs]
196 200 depends = [os.path.relpath(p, root) for p in depends]
197 201
198 202 # TODO compile with optimizations.
199 203 return Extension(
200 204 name,
201 205 sources,
202 206 include_dirs=include_dirs,
203 207 depends=depends,
204 208 extra_compile_args=extra_args,
205 209 libraries=libraries,
206 210 )
@@ -1,197 +1,203 b''
1 1 import imp
2 2 import inspect
3 3 import io
4 4 import os
5 5 import types
6 6 import unittest
7 7
8 8 try:
9 9 import hypothesis
10 10 except ImportError:
11 11 hypothesis = None
12 12
13 13
14 14 class TestCase(unittest.TestCase):
15 15 if not getattr(unittest.TestCase, "assertRaisesRegex", False):
16 16 assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
17 17
18 18
19 19 def make_cffi(cls):
20 20 """Decorator to add CFFI versions of each test method."""
21 21
22 22 # The module containing this class definition should
23 23 # `import zstandard as zstd`. Otherwise things may blow up.
24 24 mod = inspect.getmodule(cls)
25 25 if not hasattr(mod, "zstd"):
26 26 raise Exception('test module does not contain "zstd" symbol')
27 27
28 28 if not hasattr(mod.zstd, "backend"):
29 29 raise Exception(
30 30 'zstd symbol does not have "backend" attribute; did '
31 31 "you `import zstandard as zstd`?"
32 32 )
33 33
34 34 # If `import zstandard` already chose the cffi backend, there is nothing
35 35 # for us to do: we only add the cffi variation if the default backend
36 36 # is the C extension.
37 37 if mod.zstd.backend == "cffi":
38 38 return cls
39 39
40 40 old_env = dict(os.environ)
41 41 os.environ["PYTHON_ZSTANDARD_IMPORT_POLICY"] = "cffi"
42 42 try:
43 43 try:
44 44 mod_info = imp.find_module("zstandard")
45 45 mod = imp.load_module("zstandard_cffi", *mod_info)
46 46 except ImportError:
47 47 return cls
48 48 finally:
49 49 os.environ.clear()
50 50 os.environ.update(old_env)
51 51
52 52 if mod.backend != "cffi":
53 raise Exception("got the zstandard %s backend instead of cffi" % mod.backend)
53 raise Exception(
54 "got the zstandard %s backend instead of cffi" % mod.backend
55 )
54 56
55 57 # If CFFI version is available, dynamically construct test methods
56 58 # that use it.
57 59
58 60 for attr in dir(cls):
59 61 fn = getattr(cls, attr)
60 62 if not inspect.ismethod(fn) and not inspect.isfunction(fn):
61 63 continue
62 64
63 65 if not fn.__name__.startswith("test_"):
64 66 continue
65 67
66 68 name = "%s_cffi" % fn.__name__
67 69
68 70 # Replace the "zstd" symbol with the CFFI module instance. Then copy
69 71 # the function object and install it in a new attribute.
70 72 if isinstance(fn, types.FunctionType):
71 73 globs = dict(fn.__globals__)
72 74 globs["zstd"] = mod
73 75 new_fn = types.FunctionType(
74 76 fn.__code__, globs, name, fn.__defaults__, fn.__closure__
75 77 )
76 78 new_method = new_fn
77 79 else:
78 80 globs = dict(fn.__func__.func_globals)
79 81 globs["zstd"] = mod
80 82 new_fn = types.FunctionType(
81 83 fn.__func__.func_code,
82 84 globs,
83 85 name,
84 86 fn.__func__.func_defaults,
85 87 fn.__func__.func_closure,
86 88 )
87 new_method = types.UnboundMethodType(new_fn, fn.im_self, fn.im_class)
89 new_method = types.UnboundMethodType(
90 new_fn, fn.im_self, fn.im_class
91 )
88 92
89 93 setattr(cls, name, new_method)
90 94
91 95 return cls
92 96
93 97
94 98 class NonClosingBytesIO(io.BytesIO):
95 99 """BytesIO that saves the underlying buffer on close().
96 100
97 101 This allows us to access written data after close().
98 102 """
99 103
100 104 def __init__(self, *args, **kwargs):
101 105 super(NonClosingBytesIO, self).__init__(*args, **kwargs)
102 106 self._saved_buffer = None
103 107
104 108 def close(self):
105 109 self._saved_buffer = self.getvalue()
106 110 return super(NonClosingBytesIO, self).close()
107 111
108 112 def getvalue(self):
109 113 if self.closed:
110 114 return self._saved_buffer
111 115 else:
112 116 return super(NonClosingBytesIO, self).getvalue()
113 117
114 118
115 119 class OpCountingBytesIO(NonClosingBytesIO):
116 120 def __init__(self, *args, **kwargs):
117 121 self._flush_count = 0
118 122 self._read_count = 0
119 123 self._write_count = 0
120 124 return super(OpCountingBytesIO, self).__init__(*args, **kwargs)
121 125
122 126 def flush(self):
123 127 self._flush_count += 1
124 128 return super(OpCountingBytesIO, self).flush()
125 129
126 130 def read(self, *args):
127 131 self._read_count += 1
128 132 return super(OpCountingBytesIO, self).read(*args)
129 133
130 134 def write(self, data):
131 135 self._write_count += 1
132 136 return super(OpCountingBytesIO, self).write(data)
133 137
134 138
135 139 _source_files = []
136 140
137 141
138 142 def random_input_data():
139 143 """Obtain the raw content of source files.
140 144
141 145 This is used for generating "random" data to feed into fuzzing, since it is
142 146 faster than random content generation.
143 147 """
144 148 if _source_files:
145 149 return _source_files
146 150
147 151 for root, dirs, files in os.walk(os.path.dirname(__file__)):
148 152 dirs[:] = list(sorted(dirs))
149 153 for f in sorted(files):
150 154 try:
151 155 with open(os.path.join(root, f), "rb") as fh:
152 156 data = fh.read()
153 157 if data:
154 158 _source_files.append(data)
155 159 except OSError:
156 160 pass
157 161
158 162 # Also add some actual random data.
159 163 _source_files.append(os.urandom(100))
160 164 _source_files.append(os.urandom(1000))
161 165 _source_files.append(os.urandom(10000))
162 166 _source_files.append(os.urandom(100000))
163 167 _source_files.append(os.urandom(1000000))
164 168
165 169 return _source_files
166 170
167 171
168 172 def generate_samples():
169 173 inputs = [
170 174 b"foo",
171 175 b"bar",
172 176 b"abcdef",
173 177 b"sometext",
174 178 b"baz",
175 179 ]
176 180
177 181 samples = []
178 182
179 183 for i in range(128):
180 184 samples.append(inputs[i % 5])
181 185 samples.append(inputs[i % 5] * (i + 3))
182 186 samples.append(inputs[-(i % 5)] * (i + 2))
183 187
184 188 return samples
185 189
186 190
187 191 if hypothesis:
188 192 default_settings = hypothesis.settings(deadline=10000)
189 193 hypothesis.settings.register_profile("default", default_settings)
190 194
191 195 ci_settings = hypothesis.settings(deadline=20000, max_examples=1000)
192 196 hypothesis.settings.register_profile("ci", ci_settings)
193 197
194 198 expensive_settings = hypothesis.settings(deadline=None, max_examples=10000)
195 199 hypothesis.settings.register_profile("expensive", expensive_settings)
196 200
197 hypothesis.settings.load_profile(os.environ.get("HYPOTHESIS_PROFILE", "default"))
201 hypothesis.settings.load_profile(
202 os.environ.get("HYPOTHESIS_PROFILE", "default")
203 )
@@ -1,146 +1,153 b''
1 1 import struct
2 2 import unittest
3 3
4 4 import zstandard as zstd
5 5
6 6 from .common import TestCase
7 7
8 8 ss = struct.Struct("=QQ")
9 9
10 10
11 11 class TestBufferWithSegments(TestCase):
12 12 def test_arguments(self):
13 13 if not hasattr(zstd, "BufferWithSegments"):
14 14 self.skipTest("BufferWithSegments not available")
15 15
16 16 with self.assertRaises(TypeError):
17 17 zstd.BufferWithSegments()
18 18
19 19 with self.assertRaises(TypeError):
20 20 zstd.BufferWithSegments(b"foo")
21 21
22 22 # Segments data should be a multiple of 16.
23 23 with self.assertRaisesRegex(
24 24 ValueError, "segments array size is not a multiple of 16"
25 25 ):
26 26 zstd.BufferWithSegments(b"foo", b"\x00\x00")
27 27
28 28 def test_invalid_offset(self):
29 29 if not hasattr(zstd, "BufferWithSegments"):
30 30 self.skipTest("BufferWithSegments not available")
31 31
32 32 with self.assertRaisesRegex(
33 33 ValueError, "offset within segments array references memory"
34 34 ):
35 35 zstd.BufferWithSegments(b"foo", ss.pack(0, 4))
36 36
37 37 def test_invalid_getitem(self):
38 38 if not hasattr(zstd, "BufferWithSegments"):
39 39 self.skipTest("BufferWithSegments not available")
40 40
41 41 b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
42 42
43 43 with self.assertRaisesRegex(IndexError, "offset must be non-negative"):
44 44 test = b[-10]
45 45
46 46 with self.assertRaisesRegex(IndexError, "offset must be less than 1"):
47 47 test = b[1]
48 48
49 49 with self.assertRaisesRegex(IndexError, "offset must be less than 1"):
50 50 test = b[2]
51 51
52 52 def test_single(self):
53 53 if not hasattr(zstd, "BufferWithSegments"):
54 54 self.skipTest("BufferWithSegments not available")
55 55
56 56 b = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
57 57 self.assertEqual(len(b), 1)
58 58 self.assertEqual(b.size, 3)
59 59 self.assertEqual(b.tobytes(), b"foo")
60 60
61 61 self.assertEqual(len(b[0]), 3)
62 62 self.assertEqual(b[0].offset, 0)
63 63 self.assertEqual(b[0].tobytes(), b"foo")
64 64
65 65 def test_multiple(self):
66 66 if not hasattr(zstd, "BufferWithSegments"):
67 67 self.skipTest("BufferWithSegments not available")
68 68
69 69 b = zstd.BufferWithSegments(
70 b"foofooxfooxy", b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)])
70 b"foofooxfooxy",
71 b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]),
71 72 )
72 73 self.assertEqual(len(b), 3)
73 74 self.assertEqual(b.size, 12)
74 75 self.assertEqual(b.tobytes(), b"foofooxfooxy")
75 76
76 77 self.assertEqual(b[0].tobytes(), b"foo")
77 78 self.assertEqual(b[1].tobytes(), b"foox")
78 79 self.assertEqual(b[2].tobytes(), b"fooxy")
79 80
80 81
81 82 class TestBufferWithSegmentsCollection(TestCase):
82 83 def test_empty_constructor(self):
83 84 if not hasattr(zstd, "BufferWithSegmentsCollection"):
84 85 self.skipTest("BufferWithSegmentsCollection not available")
85 86
86 with self.assertRaisesRegex(ValueError, "must pass at least 1 argument"):
87 with self.assertRaisesRegex(
88 ValueError, "must pass at least 1 argument"
89 ):
87 90 zstd.BufferWithSegmentsCollection()
88 91
89 92 def test_argument_validation(self):
90 93 if not hasattr(zstd, "BufferWithSegmentsCollection"):
91 94 self.skipTest("BufferWithSegmentsCollection not available")
92 95
93 with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"):
96 with self.assertRaisesRegex(
97 TypeError, "arguments must be BufferWithSegments"
98 ):
94 99 zstd.BufferWithSegmentsCollection(None)
95 100
96 with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"):
101 with self.assertRaisesRegex(
102 TypeError, "arguments must be BufferWithSegments"
103 ):
97 104 zstd.BufferWithSegmentsCollection(
98 105 zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None
99 106 )
100 107
101 108 with self.assertRaisesRegex(
102 109 ValueError, "ZstdBufferWithSegments cannot be empty"
103 110 ):
104 111 zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b""))
105 112
106 113 def test_length(self):
107 114 if not hasattr(zstd, "BufferWithSegmentsCollection"):
108 115 self.skipTest("BufferWithSegmentsCollection not available")
109 116
110 117 b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
111 118 b2 = zstd.BufferWithSegments(
112 119 b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)])
113 120 )
114 121
115 122 c = zstd.BufferWithSegmentsCollection(b1)
116 123 self.assertEqual(len(c), 1)
117 124 self.assertEqual(c.size(), 3)
118 125
119 126 c = zstd.BufferWithSegmentsCollection(b2)
120 127 self.assertEqual(len(c), 2)
121 128 self.assertEqual(c.size(), 6)
122 129
123 130 c = zstd.BufferWithSegmentsCollection(b1, b2)
124 131 self.assertEqual(len(c), 3)
125 132 self.assertEqual(c.size(), 9)
126 133
127 134 def test_getitem(self):
128 135 if not hasattr(zstd, "BufferWithSegmentsCollection"):
129 136 self.skipTest("BufferWithSegmentsCollection not available")
130 137
131 138 b1 = zstd.BufferWithSegments(b"foo", ss.pack(0, 3))
132 139 b2 = zstd.BufferWithSegments(
133 140 b"barbaz", b"".join([ss.pack(0, 3), ss.pack(3, 3)])
134 141 )
135 142
136 143 c = zstd.BufferWithSegmentsCollection(b1, b2)
137 144
138 145 with self.assertRaisesRegex(IndexError, "offset must be less than 3"):
139 146 c[3]
140 147
141 148 with self.assertRaisesRegex(IndexError, "offset must be less than 3"):
142 149 c[4]
143 150
144 151 self.assertEqual(c[0].tobytes(), b"foo")
145 152 self.assertEqual(c[1].tobytes(), b"bar")
146 153 self.assertEqual(c[2].tobytes(), b"baz")
@@ -1,1770 +1,1803 b''
1 1 import hashlib
2 2 import io
3 3 import os
4 4 import struct
5 5 import sys
6 6 import tarfile
7 7 import tempfile
8 8 import unittest
9 9
10 10 import zstandard as zstd
11 11
12 12 from .common import (
13 13 make_cffi,
14 14 NonClosingBytesIO,
15 15 OpCountingBytesIO,
16 16 TestCase,
17 17 )
18 18
19 19
20 20 if sys.version_info[0] >= 3:
21 21 next = lambda it: it.__next__()
22 22 else:
23 23 next = lambda it: it.next()
24 24
25 25
26 26 def multithreaded_chunk_size(level, source_size=0):
27 params = zstd.ZstdCompressionParameters.from_level(level, source_size=source_size)
27 params = zstd.ZstdCompressionParameters.from_level(
28 level, source_size=source_size
29 )
28 30
29 31 return 1 << (params.window_log + 2)
30 32
31 33
32 34 @make_cffi
33 35 class TestCompressor(TestCase):
34 36 def test_level_bounds(self):
35 37 with self.assertRaises(ValueError):
36 38 zstd.ZstdCompressor(level=23)
37 39
38 40 def test_memory_size(self):
39 41 cctx = zstd.ZstdCompressor(level=1)
40 42 self.assertGreater(cctx.memory_size(), 100)
41 43
42 44
43 45 @make_cffi
44 46 class TestCompressor_compress(TestCase):
45 47 def test_compress_empty(self):
46 48 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
47 49 result = cctx.compress(b"")
48 50 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
49 51 params = zstd.get_frame_parameters(result)
50 52 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
51 53 self.assertEqual(params.window_size, 524288)
52 54 self.assertEqual(params.dict_id, 0)
53 55 self.assertFalse(params.has_checksum, 0)
54 56
55 57 cctx = zstd.ZstdCompressor()
56 58 result = cctx.compress(b"")
57 59 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00")
58 60 params = zstd.get_frame_parameters(result)
59 61 self.assertEqual(params.content_size, 0)
60 62
61 63 def test_input_types(self):
62 64 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
63 65 expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f"
64 66
65 67 mutable_array = bytearray(3)
66 68 mutable_array[:] = b"foo"
67 69
68 70 sources = [
69 71 memoryview(b"foo"),
70 72 bytearray(b"foo"),
71 73 mutable_array,
72 74 ]
73 75
74 76 for source in sources:
75 77 self.assertEqual(cctx.compress(source), expected)
76 78
77 79 def test_compress_large(self):
78 80 chunks = []
79 81 for i in range(255):
80 82 chunks.append(struct.Struct(">B").pack(i) * 16384)
81 83
82 84 cctx = zstd.ZstdCompressor(level=3, write_content_size=False)
83 85 result = cctx.compress(b"".join(chunks))
84 86 self.assertEqual(len(result), 999)
85 87 self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
86 88
87 89 # This matches the test for read_to_iter() below.
88 90 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
89 result = cctx.compress(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o")
91 result = cctx.compress(
92 b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o"
93 )
90 94 self.assertEqual(
91 95 result,
92 96 b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00"
93 97 b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0"
94 98 b"\x02\x09\x00\x00\x6f",
95 99 )
96 100
97 101 def test_negative_level(self):
98 102 cctx = zstd.ZstdCompressor(level=-4)
99 103 result = cctx.compress(b"foo" * 256)
100 104
101 105 def test_no_magic(self):
102 params = zstd.ZstdCompressionParameters.from_level(1, format=zstd.FORMAT_ZSTD1)
106 params = zstd.ZstdCompressionParameters.from_level(
107 1, format=zstd.FORMAT_ZSTD1
108 )
103 109 cctx = zstd.ZstdCompressor(compression_params=params)
104 110 magic = cctx.compress(b"foobar")
105 111
106 112 params = zstd.ZstdCompressionParameters.from_level(
107 113 1, format=zstd.FORMAT_ZSTD1_MAGICLESS
108 114 )
109 115 cctx = zstd.ZstdCompressor(compression_params=params)
110 116 no_magic = cctx.compress(b"foobar")
111 117
112 118 self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd")
113 119 self.assertEqual(magic[4:], no_magic)
114 120
115 121 def test_write_checksum(self):
116 122 cctx = zstd.ZstdCompressor(level=1)
117 123 no_checksum = cctx.compress(b"foobar")
118 124 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
119 125 with_checksum = cctx.compress(b"foobar")
120 126
121 127 self.assertEqual(len(with_checksum), len(no_checksum) + 4)
122 128
123 129 no_params = zstd.get_frame_parameters(no_checksum)
124 130 with_params = zstd.get_frame_parameters(with_checksum)
125 131
126 132 self.assertFalse(no_params.has_checksum)
127 133 self.assertTrue(with_params.has_checksum)
128 134
129 135 def test_write_content_size(self):
130 136 cctx = zstd.ZstdCompressor(level=1)
131 137 with_size = cctx.compress(b"foobar" * 256)
132 138 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
133 139 no_size = cctx.compress(b"foobar" * 256)
134 140
135 141 self.assertEqual(len(with_size), len(no_size) + 1)
136 142
137 143 no_params = zstd.get_frame_parameters(no_size)
138 144 with_params = zstd.get_frame_parameters(with_size)
139 145 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
140 146 self.assertEqual(with_params.content_size, 1536)
141 147
142 148 def test_no_dict_id(self):
143 149 samples = []
144 150 for i in range(128):
145 151 samples.append(b"foo" * 64)
146 152 samples.append(b"bar" * 64)
147 153 samples.append(b"foobar" * 64)
148 154
149 155 d = zstd.train_dictionary(1024, samples)
150 156
151 157 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
152 158 with_dict_id = cctx.compress(b"foobarfoobar")
153 159
154 160 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
155 161 no_dict_id = cctx.compress(b"foobarfoobar")
156 162
157 163 self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
158 164
159 165 no_params = zstd.get_frame_parameters(no_dict_id)
160 166 with_params = zstd.get_frame_parameters(with_dict_id)
161 167 self.assertEqual(no_params.dict_id, 0)
162 168 self.assertEqual(with_params.dict_id, 1880053135)
163 169
164 170 def test_compress_dict_multiple(self):
165 171 samples = []
166 172 for i in range(128):
167 173 samples.append(b"foo" * 64)
168 174 samples.append(b"bar" * 64)
169 175 samples.append(b"foobar" * 64)
170 176
171 177 d = zstd.train_dictionary(8192, samples)
172 178
173 179 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
174 180
175 181 for i in range(32):
176 182 cctx.compress(b"foo bar foobar foo bar foobar")
177 183
178 184 def test_dict_precompute(self):
179 185 samples = []
180 186 for i in range(128):
181 187 samples.append(b"foo" * 64)
182 188 samples.append(b"bar" * 64)
183 189 samples.append(b"foobar" * 64)
184 190
185 191 d = zstd.train_dictionary(8192, samples)
186 192 d.precompute_compress(level=1)
187 193
188 194 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
189 195
190 196 for i in range(32):
191 197 cctx.compress(b"foo bar foobar foo bar foobar")
192 198
193 199 def test_multithreaded(self):
194 200 chunk_size = multithreaded_chunk_size(1)
195 201 source = b"".join([b"x" * chunk_size, b"y" * chunk_size])
196 202
197 203 cctx = zstd.ZstdCompressor(level=1, threads=2)
198 204 compressed = cctx.compress(source)
199 205
200 206 params = zstd.get_frame_parameters(compressed)
201 207 self.assertEqual(params.content_size, chunk_size * 2)
202 208 self.assertEqual(params.dict_id, 0)
203 209 self.assertFalse(params.has_checksum)
204 210
205 211 dctx = zstd.ZstdDecompressor()
206 212 self.assertEqual(dctx.decompress(compressed), source)
207 213
208 214 def test_multithreaded_dict(self):
209 215 samples = []
210 216 for i in range(128):
211 217 samples.append(b"foo" * 64)
212 218 samples.append(b"bar" * 64)
213 219 samples.append(b"foobar" * 64)
214 220
215 221 d = zstd.train_dictionary(1024, samples)
216 222
217 223 cctx = zstd.ZstdCompressor(dict_data=d, threads=2)
218 224
219 225 result = cctx.compress(b"foo")
220 226 params = zstd.get_frame_parameters(result)
221 227 self.assertEqual(params.content_size, 3)
222 228 self.assertEqual(params.dict_id, d.dict_id())
223 229
224 230 self.assertEqual(
225 231 result,
226 b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" b"\x66\x6f\x6f",
232 b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00"
233 b"\x66\x6f\x6f",
227 234 )
228 235
229 236 def test_multithreaded_compression_params(self):
230 237 params = zstd.ZstdCompressionParameters.from_level(0, threads=2)
231 238 cctx = zstd.ZstdCompressor(compression_params=params)
232 239
233 240 result = cctx.compress(b"foo")
234 241 params = zstd.get_frame_parameters(result)
235 242 self.assertEqual(params.content_size, 3)
236 243
237 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f")
244 self.assertEqual(
245 result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f"
246 )
238 247
239 248
240 249 @make_cffi
241 250 class TestCompressor_compressobj(TestCase):
242 251 def test_compressobj_empty(self):
243 252 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
244 253 cobj = cctx.compressobj()
245 254 self.assertEqual(cobj.compress(b""), b"")
246 255 self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
247 256
248 257 def test_input_types(self):
249 258 expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
250 259 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
251 260
252 261 mutable_array = bytearray(3)
253 262 mutable_array[:] = b"foo"
254 263
255 264 sources = [
256 265 memoryview(b"foo"),
257 266 bytearray(b"foo"),
258 267 mutable_array,
259 268 ]
260 269
261 270 for source in sources:
262 271 cobj = cctx.compressobj()
263 272 self.assertEqual(cobj.compress(source), b"")
264 273 self.assertEqual(cobj.flush(), expected)
265 274
266 275 def test_compressobj_large(self):
267 276 chunks = []
268 277 for i in range(255):
269 278 chunks.append(struct.Struct(">B").pack(i) * 16384)
270 279
271 280 cctx = zstd.ZstdCompressor(level=3)
272 281 cobj = cctx.compressobj()
273 282
274 283 result = cobj.compress(b"".join(chunks)) + cobj.flush()
275 284 self.assertEqual(len(result), 999)
276 285 self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
277 286
278 287 params = zstd.get_frame_parameters(result)
279 288 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
280 289 self.assertEqual(params.window_size, 2097152)
281 290 self.assertEqual(params.dict_id, 0)
282 291 self.assertFalse(params.has_checksum)
283 292
284 293 def test_write_checksum(self):
285 294 cctx = zstd.ZstdCompressor(level=1)
286 295 cobj = cctx.compressobj()
287 296 no_checksum = cobj.compress(b"foobar") + cobj.flush()
288 297 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
289 298 cobj = cctx.compressobj()
290 299 with_checksum = cobj.compress(b"foobar") + cobj.flush()
291 300
292 301 no_params = zstd.get_frame_parameters(no_checksum)
293 302 with_params = zstd.get_frame_parameters(with_checksum)
294 303 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
295 304 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
296 305 self.assertEqual(no_params.dict_id, 0)
297 306 self.assertEqual(with_params.dict_id, 0)
298 307 self.assertFalse(no_params.has_checksum)
299 308 self.assertTrue(with_params.has_checksum)
300 309
301 310 self.assertEqual(len(with_checksum), len(no_checksum) + 4)
302 311
303 312 def test_write_content_size(self):
304 313 cctx = zstd.ZstdCompressor(level=1)
305 314 cobj = cctx.compressobj(size=len(b"foobar" * 256))
306 315 with_size = cobj.compress(b"foobar" * 256) + cobj.flush()
307 316 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
308 317 cobj = cctx.compressobj(size=len(b"foobar" * 256))
309 318 no_size = cobj.compress(b"foobar" * 256) + cobj.flush()
310 319
311 320 no_params = zstd.get_frame_parameters(no_size)
312 321 with_params = zstd.get_frame_parameters(with_size)
313 322 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
314 323 self.assertEqual(with_params.content_size, 1536)
315 324 self.assertEqual(no_params.dict_id, 0)
316 325 self.assertEqual(with_params.dict_id, 0)
317 326 self.assertFalse(no_params.has_checksum)
318 327 self.assertFalse(with_params.has_checksum)
319 328
320 329 self.assertEqual(len(with_size), len(no_size) + 1)
321 330
322 331 def test_compress_after_finished(self):
323 332 cctx = zstd.ZstdCompressor()
324 333 cobj = cctx.compressobj()
325 334
326 335 cobj.compress(b"foo")
327 336 cobj.flush()
328 337
329 338 with self.assertRaisesRegex(
330 339 zstd.ZstdError, r"cannot call compress\(\) after compressor"
331 340 ):
332 341 cobj.compress(b"foo")
333 342
334 343 with self.assertRaisesRegex(
335 344 zstd.ZstdError, "compressor object already finished"
336 345 ):
337 346 cobj.flush()
338 347
339 348 def test_flush_block_repeated(self):
340 349 cctx = zstd.ZstdCompressor(level=1)
341 350 cobj = cctx.compressobj()
342 351
343 352 self.assertEqual(cobj.compress(b"foo"), b"")
344 353 self.assertEqual(
345 354 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
346 355 b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo",
347 356 )
348 357 self.assertEqual(cobj.compress(b"bar"), b"")
349 358 # 3 byte header plus content.
350 self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar")
359 self.assertEqual(
360 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar"
361 )
351 362 self.assertEqual(cobj.flush(), b"\x01\x00\x00")
352 363
353 364 def test_flush_empty_block(self):
354 365 cctx = zstd.ZstdCompressor(write_checksum=True)
355 366 cobj = cctx.compressobj()
356 367
357 368 cobj.compress(b"foobar")
358 369 cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
359 370 # No-op if no block is active (this is internal to zstd).
360 371 self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"")
361 372
362 373 trailing = cobj.flush()
363 374 # 3 bytes block header + 4 bytes frame checksum
364 375 self.assertEqual(len(trailing), 7)
365 376 header = trailing[0:3]
366 377 self.assertEqual(header, b"\x01\x00\x00")
367 378
368 379 def test_multithreaded(self):
369 380 source = io.BytesIO()
370 381 source.write(b"a" * 1048576)
371 382 source.write(b"b" * 1048576)
372 383 source.write(b"c" * 1048576)
373 384 source.seek(0)
374 385
375 386 cctx = zstd.ZstdCompressor(level=1, threads=2)
376 387 cobj = cctx.compressobj()
377 388
378 389 chunks = []
379 390 while True:
380 391 d = source.read(8192)
381 392 if not d:
382 393 break
383 394
384 395 chunks.append(cobj.compress(d))
385 396
386 397 chunks.append(cobj.flush())
387 398
388 399 compressed = b"".join(chunks)
389 400
390 401 self.assertEqual(len(compressed), 119)
391 402
392 403 def test_frame_progression(self):
393 404 cctx = zstd.ZstdCompressor()
394 405
395 406 self.assertEqual(cctx.frame_progression(), (0, 0, 0))
396 407
397 408 cobj = cctx.compressobj()
398 409
399 410 cobj.compress(b"foobar")
400 411 self.assertEqual(cctx.frame_progression(), (6, 0, 0))
401 412
402 413 cobj.flush()
403 414 self.assertEqual(cctx.frame_progression(), (6, 6, 15))
404 415
405 416 def test_bad_size(self):
406 417 cctx = zstd.ZstdCompressor()
407 418
408 419 cobj = cctx.compressobj(size=2)
409 420 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
410 421 cobj.compress(b"foo")
411 422
412 423 # Try another operation on this instance.
413 424 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
414 425 cobj.compress(b"aa")
415 426
416 427 # Try another operation on the compressor.
417 428 cctx.compressobj(size=4)
418 429 cctx.compress(b"foobar")
419 430
420 431
421 432 @make_cffi
422 433 class TestCompressor_copy_stream(TestCase):
423 434 def test_no_read(self):
424 435 source = object()
425 436 dest = io.BytesIO()
426 437
427 438 cctx = zstd.ZstdCompressor()
428 439 with self.assertRaises(ValueError):
429 440 cctx.copy_stream(source, dest)
430 441
431 442 def test_no_write(self):
432 443 source = io.BytesIO()
433 444 dest = object()
434 445
435 446 cctx = zstd.ZstdCompressor()
436 447 with self.assertRaises(ValueError):
437 448 cctx.copy_stream(source, dest)
438 449
439 450 def test_empty(self):
440 451 source = io.BytesIO()
441 452 dest = io.BytesIO()
442 453
443 454 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
444 455 r, w = cctx.copy_stream(source, dest)
445 456 self.assertEqual(int(r), 0)
446 457 self.assertEqual(w, 9)
447 458
448 self.assertEqual(dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
459 self.assertEqual(
460 dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00"
461 )
449 462
450 463 def test_large_data(self):
451 464 source = io.BytesIO()
452 465 for i in range(255):
453 466 source.write(struct.Struct(">B").pack(i) * 16384)
454 467 source.seek(0)
455 468
456 469 dest = io.BytesIO()
457 470 cctx = zstd.ZstdCompressor()
458 471 r, w = cctx.copy_stream(source, dest)
459 472
460 473 self.assertEqual(r, 255 * 16384)
461 474 self.assertEqual(w, 999)
462 475
463 476 params = zstd.get_frame_parameters(dest.getvalue())
464 477 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
465 478 self.assertEqual(params.window_size, 2097152)
466 479 self.assertEqual(params.dict_id, 0)
467 480 self.assertFalse(params.has_checksum)
468 481
469 482 def test_write_checksum(self):
470 483 source = io.BytesIO(b"foobar")
471 484 no_checksum = io.BytesIO()
472 485
473 486 cctx = zstd.ZstdCompressor(level=1)
474 487 cctx.copy_stream(source, no_checksum)
475 488
476 489 source.seek(0)
477 490 with_checksum = io.BytesIO()
478 491 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
479 492 cctx.copy_stream(source, with_checksum)
480 493
481 self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4)
494 self.assertEqual(
495 len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4
496 )
482 497
483 498 no_params = zstd.get_frame_parameters(no_checksum.getvalue())
484 499 with_params = zstd.get_frame_parameters(with_checksum.getvalue())
485 500 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
486 501 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
487 502 self.assertEqual(no_params.dict_id, 0)
488 503 self.assertEqual(with_params.dict_id, 0)
489 504 self.assertFalse(no_params.has_checksum)
490 505 self.assertTrue(with_params.has_checksum)
491 506
492 507 def test_write_content_size(self):
493 508 source = io.BytesIO(b"foobar" * 256)
494 509 no_size = io.BytesIO()
495 510
496 511 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
497 512 cctx.copy_stream(source, no_size)
498 513
499 514 source.seek(0)
500 515 with_size = io.BytesIO()
501 516 cctx = zstd.ZstdCompressor(level=1)
502 517 cctx.copy_stream(source, with_size)
503 518
504 519 # Source content size is unknown, so no content size written.
505 520 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
506 521
507 522 source.seek(0)
508 523 with_size = io.BytesIO()
509 524 cctx.copy_stream(source, with_size, size=len(source.getvalue()))
510 525
511 526 # We specified source size, so content size header is present.
512 527 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
513 528
514 529 no_params = zstd.get_frame_parameters(no_size.getvalue())
515 530 with_params = zstd.get_frame_parameters(with_size.getvalue())
516 531 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
517 532 self.assertEqual(with_params.content_size, 1536)
518 533 self.assertEqual(no_params.dict_id, 0)
519 534 self.assertEqual(with_params.dict_id, 0)
520 535 self.assertFalse(no_params.has_checksum)
521 536 self.assertFalse(with_params.has_checksum)
522 537
523 538 def test_read_write_size(self):
524 539 source = OpCountingBytesIO(b"foobarfoobar")
525 540 dest = OpCountingBytesIO()
526 541 cctx = zstd.ZstdCompressor()
527 542 r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1)
528 543
529 544 self.assertEqual(r, len(source.getvalue()))
530 545 self.assertEqual(w, 21)
531 546 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
532 547 self.assertEqual(dest._write_count, len(dest.getvalue()))
533 548
534 549 def test_multithreaded(self):
535 550 source = io.BytesIO()
536 551 source.write(b"a" * 1048576)
537 552 source.write(b"b" * 1048576)
538 553 source.write(b"c" * 1048576)
539 554 source.seek(0)
540 555
541 556 dest = io.BytesIO()
542 557 cctx = zstd.ZstdCompressor(threads=2, write_content_size=False)
543 558 r, w = cctx.copy_stream(source, dest)
544 559 self.assertEqual(r, 3145728)
545 560 self.assertEqual(w, 111)
546 561
547 562 params = zstd.get_frame_parameters(dest.getvalue())
548 563 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
549 564 self.assertEqual(params.dict_id, 0)
550 565 self.assertFalse(params.has_checksum)
551 566
552 567 # Writing content size and checksum works.
553 568 cctx = zstd.ZstdCompressor(threads=2, write_checksum=True)
554 569 dest = io.BytesIO()
555 570 source.seek(0)
556 571 cctx.copy_stream(source, dest, size=len(source.getvalue()))
557 572
558 573 params = zstd.get_frame_parameters(dest.getvalue())
559 574 self.assertEqual(params.content_size, 3145728)
560 575 self.assertEqual(params.dict_id, 0)
561 576 self.assertTrue(params.has_checksum)
562 577
563 578 def test_bad_size(self):
564 579 source = io.BytesIO()
565 580 source.write(b"a" * 32768)
566 581 source.write(b"b" * 32768)
567 582 source.seek(0)
568 583
569 584 dest = io.BytesIO()
570 585
571 586 cctx = zstd.ZstdCompressor()
572 587
573 588 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
574 589 cctx.copy_stream(source, dest, size=42)
575 590
576 591 # Try another operation on this compressor.
577 592 source.seek(0)
578 593 dest = io.BytesIO()
579 594 cctx.copy_stream(source, dest)
580 595
581 596
582 597 @make_cffi
583 598 class TestCompressor_stream_reader(TestCase):
584 599 def test_context_manager(self):
585 600 cctx = zstd.ZstdCompressor()
586 601
587 602 with cctx.stream_reader(b"foo") as reader:
588 with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
603 with self.assertRaisesRegex(
604 ValueError, "cannot __enter__ multiple times"
605 ):
589 606 with reader as reader2:
590 607 pass
591 608
592 609 def test_no_context_manager(self):
593 610 cctx = zstd.ZstdCompressor()
594 611
595 612 reader = cctx.stream_reader(b"foo")
596 613 reader.read(4)
597 614 self.assertFalse(reader.closed)
598 615
599 616 reader.close()
600 617 self.assertTrue(reader.closed)
601 618 with self.assertRaisesRegex(ValueError, "stream is closed"):
602 619 reader.read(1)
603 620
604 621 def test_not_implemented(self):
605 622 cctx = zstd.ZstdCompressor()
606 623
607 624 with cctx.stream_reader(b"foo" * 60) as reader:
608 625 with self.assertRaises(io.UnsupportedOperation):
609 626 reader.readline()
610 627
611 628 with self.assertRaises(io.UnsupportedOperation):
612 629 reader.readlines()
613 630
614 631 with self.assertRaises(io.UnsupportedOperation):
615 632 iter(reader)
616 633
617 634 with self.assertRaises(io.UnsupportedOperation):
618 635 next(reader)
619 636
620 637 with self.assertRaises(OSError):
621 638 reader.writelines([])
622 639
623 640 with self.assertRaises(OSError):
624 641 reader.write(b"foo")
625 642
626 643 def test_constant_methods(self):
627 644 cctx = zstd.ZstdCompressor()
628 645
629 646 with cctx.stream_reader(b"boo") as reader:
630 647 self.assertTrue(reader.readable())
631 648 self.assertFalse(reader.writable())
632 649 self.assertFalse(reader.seekable())
633 650 self.assertFalse(reader.isatty())
634 651 self.assertFalse(reader.closed)
635 652 self.assertIsNone(reader.flush())
636 653 self.assertFalse(reader.closed)
637 654
638 655 self.assertTrue(reader.closed)
639 656
640 657 def test_read_closed(self):
641 658 cctx = zstd.ZstdCompressor()
642 659
643 660 with cctx.stream_reader(b"foo" * 60) as reader:
644 661 reader.close()
645 662 self.assertTrue(reader.closed)
646 663 with self.assertRaisesRegex(ValueError, "stream is closed"):
647 664 reader.read(10)
648 665
649 666 def test_read_sizes(self):
650 667 cctx = zstd.ZstdCompressor()
651 668 foo = cctx.compress(b"foo")
652 669
653 670 with cctx.stream_reader(b"foo") as reader:
654 671 with self.assertRaisesRegex(
655 672 ValueError, "cannot read negative amounts less than -1"
656 673 ):
657 674 reader.read(-2)
658 675
659 676 self.assertEqual(reader.read(0), b"")
660 677 self.assertEqual(reader.read(), foo)
661 678
662 679 def test_read_buffer(self):
663 680 cctx = zstd.ZstdCompressor()
664 681
665 682 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
666 683 frame = cctx.compress(source)
667 684
668 685 with cctx.stream_reader(source) as reader:
669 686 self.assertEqual(reader.tell(), 0)
670 687
671 688 # We should get entire frame in one read.
672 689 result = reader.read(8192)
673 690 self.assertEqual(result, frame)
674 691 self.assertEqual(reader.tell(), len(result))
675 692 self.assertEqual(reader.read(), b"")
676 693 self.assertEqual(reader.tell(), len(result))
677 694
678 695 def test_read_buffer_small_chunks(self):
679 696 cctx = zstd.ZstdCompressor()
680 697
681 698 source = b"foo" * 60
682 699 chunks = []
683 700
684 701 with cctx.stream_reader(source) as reader:
685 702 self.assertEqual(reader.tell(), 0)
686 703
687 704 while True:
688 705 chunk = reader.read(1)
689 706 if not chunk:
690 707 break
691 708
692 709 chunks.append(chunk)
693 710 self.assertEqual(reader.tell(), sum(map(len, chunks)))
694 711
695 712 self.assertEqual(b"".join(chunks), cctx.compress(source))
696 713
697 714 def test_read_stream(self):
698 715 cctx = zstd.ZstdCompressor()
699 716
700 717 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
701 718 frame = cctx.compress(source)
702 719
703 720 with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
704 721 self.assertEqual(reader.tell(), 0)
705 722
706 723 chunk = reader.read(8192)
707 724 self.assertEqual(chunk, frame)
708 725 self.assertEqual(reader.tell(), len(chunk))
709 726 self.assertEqual(reader.read(), b"")
710 727 self.assertEqual(reader.tell(), len(chunk))
711 728
712 729 def test_read_stream_small_chunks(self):
713 730 cctx = zstd.ZstdCompressor()
714 731
715 732 source = b"foo" * 60
716 733 chunks = []
717 734
718 735 with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
719 736 self.assertEqual(reader.tell(), 0)
720 737
721 738 while True:
722 739 chunk = reader.read(1)
723 740 if not chunk:
724 741 break
725 742
726 743 chunks.append(chunk)
727 744 self.assertEqual(reader.tell(), sum(map(len, chunks)))
728 745
729 746 self.assertEqual(b"".join(chunks), cctx.compress(source))
730 747
731 748 def test_read_after_exit(self):
732 749 cctx = zstd.ZstdCompressor()
733 750
734 751 with cctx.stream_reader(b"foo" * 60) as reader:
735 752 while reader.read(8192):
736 753 pass
737 754
738 755 with self.assertRaisesRegex(ValueError, "stream is closed"):
739 756 reader.read(10)
740 757
741 758 def test_bad_size(self):
742 759 cctx = zstd.ZstdCompressor()
743 760
744 761 source = io.BytesIO(b"foobar")
745 762
746 763 with cctx.stream_reader(source, size=2) as reader:
747 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
764 with self.assertRaisesRegex(
765 zstd.ZstdError, "Src size is incorrect"
766 ):
748 767 reader.read(10)
749 768
750 769 # Try another compression operation.
751 770 with cctx.stream_reader(source, size=42):
752 771 pass
753 772
754 773 def test_readall(self):
755 774 cctx = zstd.ZstdCompressor()
756 775 frame = cctx.compress(b"foo" * 1024)
757 776
758 777 reader = cctx.stream_reader(b"foo" * 1024)
759 778 self.assertEqual(reader.readall(), frame)
760 779
761 780 def test_readinto(self):
762 781 cctx = zstd.ZstdCompressor()
763 782 foo = cctx.compress(b"foo")
764 783
765 784 reader = cctx.stream_reader(b"foo")
766 785 with self.assertRaises(Exception):
767 786 reader.readinto(b"foobar")
768 787
769 788 # readinto() with sufficiently large destination.
770 789 b = bytearray(1024)
771 790 reader = cctx.stream_reader(b"foo")
772 791 self.assertEqual(reader.readinto(b), len(foo))
773 792 self.assertEqual(b[0 : len(foo)], foo)
774 793 self.assertEqual(reader.readinto(b), 0)
775 794 self.assertEqual(b[0 : len(foo)], foo)
776 795
777 796 # readinto() with small reads.
778 797 b = bytearray(1024)
779 798 reader = cctx.stream_reader(b"foo", read_size=1)
780 799 self.assertEqual(reader.readinto(b), len(foo))
781 800 self.assertEqual(b[0 : len(foo)], foo)
782 801
783 802 # Too small destination buffer.
784 803 b = bytearray(2)
785 804 reader = cctx.stream_reader(b"foo")
786 805 self.assertEqual(reader.readinto(b), 2)
787 806 self.assertEqual(b[:], foo[0:2])
788 807 self.assertEqual(reader.readinto(b), 2)
789 808 self.assertEqual(b[:], foo[2:4])
790 809 self.assertEqual(reader.readinto(b), 2)
791 810 self.assertEqual(b[:], foo[4:6])
792 811
793 812 def test_readinto1(self):
794 813 cctx = zstd.ZstdCompressor()
795 814 foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
796 815
797 816 reader = cctx.stream_reader(b"foo")
798 817 with self.assertRaises(Exception):
799 818 reader.readinto1(b"foobar")
800 819
801 820 b = bytearray(1024)
802 821 source = OpCountingBytesIO(b"foo")
803 822 reader = cctx.stream_reader(source)
804 823 self.assertEqual(reader.readinto1(b), len(foo))
805 824 self.assertEqual(b[0 : len(foo)], foo)
806 825 self.assertEqual(source._read_count, 2)
807 826
808 827 # readinto1() with small reads.
809 828 b = bytearray(1024)
810 829 source = OpCountingBytesIO(b"foo")
811 830 reader = cctx.stream_reader(source, read_size=1)
812 831 self.assertEqual(reader.readinto1(b), len(foo))
813 832 self.assertEqual(b[0 : len(foo)], foo)
814 833 self.assertEqual(source._read_count, 4)
815 834
816 835 def test_read1(self):
817 836 cctx = zstd.ZstdCompressor()
818 837 foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
819 838
820 839 b = OpCountingBytesIO(b"foo")
821 840 reader = cctx.stream_reader(b)
822 841
823 842 self.assertEqual(reader.read1(), foo)
824 843 self.assertEqual(b._read_count, 2)
825 844
826 845 b = OpCountingBytesIO(b"foo")
827 846 reader = cctx.stream_reader(b)
828 847
829 848 self.assertEqual(reader.read1(0), b"")
830 849 self.assertEqual(reader.read1(2), foo[0:2])
831 850 self.assertEqual(b._read_count, 2)
832 851 self.assertEqual(reader.read1(2), foo[2:4])
833 852 self.assertEqual(reader.read1(1024), foo[4:])
834 853
835 854
836 855 @make_cffi
837 856 class TestCompressor_stream_writer(TestCase):
838 857 def test_io_api(self):
839 858 buffer = io.BytesIO()
840 859 cctx = zstd.ZstdCompressor()
841 860 writer = cctx.stream_writer(buffer)
842 861
843 862 self.assertFalse(writer.isatty())
844 863 self.assertFalse(writer.readable())
845 864
846 865 with self.assertRaises(io.UnsupportedOperation):
847 866 writer.readline()
848 867
849 868 with self.assertRaises(io.UnsupportedOperation):
850 869 writer.readline(42)
851 870
852 871 with self.assertRaises(io.UnsupportedOperation):
853 872 writer.readline(size=42)
854 873
855 874 with self.assertRaises(io.UnsupportedOperation):
856 875 writer.readlines()
857 876
858 877 with self.assertRaises(io.UnsupportedOperation):
859 878 writer.readlines(42)
860 879
861 880 with self.assertRaises(io.UnsupportedOperation):
862 881 writer.readlines(hint=42)
863 882
864 883 with self.assertRaises(io.UnsupportedOperation):
865 884 writer.seek(0)
866 885
867 886 with self.assertRaises(io.UnsupportedOperation):
868 887 writer.seek(10, os.SEEK_SET)
869 888
870 889 self.assertFalse(writer.seekable())
871 890
872 891 with self.assertRaises(io.UnsupportedOperation):
873 892 writer.truncate()
874 893
875 894 with self.assertRaises(io.UnsupportedOperation):
876 895 writer.truncate(42)
877 896
878 897 with self.assertRaises(io.UnsupportedOperation):
879 898 writer.truncate(size=42)
880 899
881 900 self.assertTrue(writer.writable())
882 901
883 902 with self.assertRaises(NotImplementedError):
884 903 writer.writelines([])
885 904
886 905 with self.assertRaises(io.UnsupportedOperation):
887 906 writer.read()
888 907
889 908 with self.assertRaises(io.UnsupportedOperation):
890 909 writer.read(42)
891 910
892 911 with self.assertRaises(io.UnsupportedOperation):
893 912 writer.read(size=42)
894 913
895 914 with self.assertRaises(io.UnsupportedOperation):
896 915 writer.readall()
897 916
898 917 with self.assertRaises(io.UnsupportedOperation):
899 918 writer.readinto(None)
900 919
901 920 with self.assertRaises(io.UnsupportedOperation):
902 921 writer.fileno()
903 922
904 923 self.assertFalse(writer.closed)
905 924
906 925 def test_fileno_file(self):
907 926 with tempfile.TemporaryFile("wb") as tf:
908 927 cctx = zstd.ZstdCompressor()
909 928 writer = cctx.stream_writer(tf)
910 929
911 930 self.assertEqual(writer.fileno(), tf.fileno())
912 931
913 932 def test_close(self):
914 933 buffer = NonClosingBytesIO()
915 934 cctx = zstd.ZstdCompressor(level=1)
916 935 writer = cctx.stream_writer(buffer)
917 936
918 937 writer.write(b"foo" * 1024)
919 938 self.assertFalse(writer.closed)
920 939 self.assertFalse(buffer.closed)
921 940 writer.close()
922 941 self.assertTrue(writer.closed)
923 942 self.assertTrue(buffer.closed)
924 943
925 944 with self.assertRaisesRegex(ValueError, "stream is closed"):
926 945 writer.write(b"foo")
927 946
928 947 with self.assertRaisesRegex(ValueError, "stream is closed"):
929 948 writer.flush()
930 949
931 950 with self.assertRaisesRegex(ValueError, "stream is closed"):
932 951 with writer:
933 952 pass
934 953
935 954 self.assertEqual(
936 955 buffer.getvalue(),
937 956 b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f"
938 957 b"\x6f\x01\x00\xfa\xd3\x77\x43",
939 958 )
940 959
941 960 # Context manager exit should close stream.
942 961 buffer = io.BytesIO()
943 962 writer = cctx.stream_writer(buffer)
944 963
945 964 with writer:
946 965 writer.write(b"foo")
947 966
948 967 self.assertTrue(writer.closed)
949 968
950 969 def test_empty(self):
951 970 buffer = NonClosingBytesIO()
952 971 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
953 972 with cctx.stream_writer(buffer) as compressor:
954 973 compressor.write(b"")
955 974
956 975 result = buffer.getvalue()
957 976 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
958 977
959 978 params = zstd.get_frame_parameters(result)
960 979 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
961 980 self.assertEqual(params.window_size, 524288)
962 981 self.assertEqual(params.dict_id, 0)
963 982 self.assertFalse(params.has_checksum)
964 983
965 984 # Test without context manager.
966 985 buffer = io.BytesIO()
967 986 compressor = cctx.stream_writer(buffer)
968 987 self.assertEqual(compressor.write(b""), 0)
969 988 self.assertEqual(buffer.getvalue(), b"")
970 989 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9)
971 990 result = buffer.getvalue()
972 991 self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
973 992
974 993 params = zstd.get_frame_parameters(result)
975 994 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
976 995 self.assertEqual(params.window_size, 524288)
977 996 self.assertEqual(params.dict_id, 0)
978 997 self.assertFalse(params.has_checksum)
979 998
980 999 # Test write_return_read=True
981 1000 compressor = cctx.stream_writer(buffer, write_return_read=True)
982 1001 self.assertEqual(compressor.write(b""), 0)
983 1002
984 1003 def test_input_types(self):
985 1004 expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
986 1005 cctx = zstd.ZstdCompressor(level=1)
987 1006
988 1007 mutable_array = bytearray(3)
989 1008 mutable_array[:] = b"foo"
990 1009
991 1010 sources = [
992 1011 memoryview(b"foo"),
993 1012 bytearray(b"foo"),
994 1013 mutable_array,
995 1014 ]
996 1015
997 1016 for source in sources:
998 1017 buffer = NonClosingBytesIO()
999 1018 with cctx.stream_writer(buffer) as compressor:
1000 1019 compressor.write(source)
1001 1020
1002 1021 self.assertEqual(buffer.getvalue(), expected)
1003 1022
1004 1023 compressor = cctx.stream_writer(buffer, write_return_read=True)
1005 1024 self.assertEqual(compressor.write(source), len(source))
1006 1025
1007 1026 def test_multiple_compress(self):
1008 1027 buffer = NonClosingBytesIO()
1009 1028 cctx = zstd.ZstdCompressor(level=5)
1010 1029 with cctx.stream_writer(buffer) as compressor:
1011 1030 self.assertEqual(compressor.write(b"foo"), 0)
1012 1031 self.assertEqual(compressor.write(b"bar"), 0)
1013 1032 self.assertEqual(compressor.write(b"x" * 8192), 0)
1014 1033
1015 1034 result = buffer.getvalue()
1016 1035 self.assertEqual(
1017 1036 result,
1018 1037 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
1019 1038 b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
1020 1039 )
1021 1040
1022 1041 # Test without context manager.
1023 1042 buffer = io.BytesIO()
1024 1043 compressor = cctx.stream_writer(buffer)
1025 1044 self.assertEqual(compressor.write(b"foo"), 0)
1026 1045 self.assertEqual(compressor.write(b"bar"), 0)
1027 1046 self.assertEqual(compressor.write(b"x" * 8192), 0)
1028 1047 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
1029 1048 result = buffer.getvalue()
1030 1049 self.assertEqual(
1031 1050 result,
1032 1051 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
1033 1052 b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
1034 1053 )
1035 1054
1036 1055 # Test with write_return_read=True.
1037 1056 compressor = cctx.stream_writer(buffer, write_return_read=True)
1038 1057 self.assertEqual(compressor.write(b"foo"), 3)
1039 1058 self.assertEqual(compressor.write(b"barbiz"), 6)
1040 1059 self.assertEqual(compressor.write(b"x" * 8192), 8192)
1041 1060
1042 1061 def test_dictionary(self):
1043 1062 samples = []
1044 1063 for i in range(128):
1045 1064 samples.append(b"foo" * 64)
1046 1065 samples.append(b"bar" * 64)
1047 1066 samples.append(b"foobar" * 64)
1048 1067
1049 1068 d = zstd.train_dictionary(8192, samples)
1050 1069
1051 1070 h = hashlib.sha1(d.as_bytes()).hexdigest()
1052 1071 self.assertEqual(h, "7a2e59a876db958f74257141045af8f912e00d4e")
1053 1072
1054 1073 buffer = NonClosingBytesIO()
1055 1074 cctx = zstd.ZstdCompressor(level=9, dict_data=d)
1056 1075 with cctx.stream_writer(buffer) as compressor:
1057 1076 self.assertEqual(compressor.write(b"foo"), 0)
1058 1077 self.assertEqual(compressor.write(b"bar"), 0)
1059 1078 self.assertEqual(compressor.write(b"foo" * 16384), 0)
1060 1079
1061 1080 compressed = buffer.getvalue()
1062 1081
1063 1082 params = zstd.get_frame_parameters(compressed)
1064 1083 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1065 1084 self.assertEqual(params.window_size, 2097152)
1066 1085 self.assertEqual(params.dict_id, d.dict_id())
1067 1086 self.assertFalse(params.has_checksum)
1068 1087
1069 1088 h = hashlib.sha1(compressed).hexdigest()
1070 1089 self.assertEqual(h, "0a7c05635061f58039727cdbe76388c6f4cfef06")
1071 1090
1072 1091 source = b"foo" + b"bar" + (b"foo" * 16384)
1073 1092
1074 1093 dctx = zstd.ZstdDecompressor(dict_data=d)
1075 1094
1076 1095 self.assertEqual(
1077 1096 dctx.decompress(compressed, max_output_size=len(source)), source
1078 1097 )
1079 1098
1080 1099 def test_compression_params(self):
1081 1100 params = zstd.ZstdCompressionParameters(
1082 1101 window_log=20,
1083 1102 chain_log=6,
1084 1103 hash_log=12,
1085 1104 min_match=5,
1086 1105 search_log=4,
1087 1106 target_length=10,
1088 1107 strategy=zstd.STRATEGY_FAST,
1089 1108 )
1090 1109
1091 1110 buffer = NonClosingBytesIO()
1092 1111 cctx = zstd.ZstdCompressor(compression_params=params)
1093 1112 with cctx.stream_writer(buffer) as compressor:
1094 1113 self.assertEqual(compressor.write(b"foo"), 0)
1095 1114 self.assertEqual(compressor.write(b"bar"), 0)
1096 1115 self.assertEqual(compressor.write(b"foobar" * 16384), 0)
1097 1116
1098 1117 compressed = buffer.getvalue()
1099 1118
1100 1119 params = zstd.get_frame_parameters(compressed)
1101 1120 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1102 1121 self.assertEqual(params.window_size, 1048576)
1103 1122 self.assertEqual(params.dict_id, 0)
1104 1123 self.assertFalse(params.has_checksum)
1105 1124
1106 1125 h = hashlib.sha1(compressed).hexdigest()
1107 1126 self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b")
1108 1127
1109 1128 def test_write_checksum(self):
1110 1129 no_checksum = NonClosingBytesIO()
1111 1130 cctx = zstd.ZstdCompressor(level=1)
1112 1131 with cctx.stream_writer(no_checksum) as compressor:
1113 1132 self.assertEqual(compressor.write(b"foobar"), 0)
1114 1133
1115 1134 with_checksum = NonClosingBytesIO()
1116 1135 cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
1117 1136 with cctx.stream_writer(with_checksum) as compressor:
1118 1137 self.assertEqual(compressor.write(b"foobar"), 0)
1119 1138
1120 1139 no_params = zstd.get_frame_parameters(no_checksum.getvalue())
1121 1140 with_params = zstd.get_frame_parameters(with_checksum.getvalue())
1122 1141 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1123 1142 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1124 1143 self.assertEqual(no_params.dict_id, 0)
1125 1144 self.assertEqual(with_params.dict_id, 0)
1126 1145 self.assertFalse(no_params.has_checksum)
1127 1146 self.assertTrue(with_params.has_checksum)
1128 1147
1129 self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4)
1148 self.assertEqual(
1149 len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4
1150 )
1130 1151
1131 1152 def test_write_content_size(self):
1132 1153 no_size = NonClosingBytesIO()
1133 1154 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1134 1155 with cctx.stream_writer(no_size) as compressor:
1135 1156 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1136 1157
1137 1158 with_size = NonClosingBytesIO()
1138 1159 cctx = zstd.ZstdCompressor(level=1)
1139 1160 with cctx.stream_writer(with_size) as compressor:
1140 1161 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1141 1162
1142 1163 # Source size is not known in streaming mode, so header not
1143 1164 # written.
1144 1165 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
1145 1166
1146 1167 # Declaring size will write the header.
1147 1168 with_size = NonClosingBytesIO()
1148 with cctx.stream_writer(with_size, size=len(b"foobar" * 256)) as compressor:
1169 with cctx.stream_writer(
1170 with_size, size=len(b"foobar" * 256)
1171 ) as compressor:
1149 1172 self.assertEqual(compressor.write(b"foobar" * 256), 0)
1150 1173
1151 1174 no_params = zstd.get_frame_parameters(no_size.getvalue())
1152 1175 with_params = zstd.get_frame_parameters(with_size.getvalue())
1153 1176 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1154 1177 self.assertEqual(with_params.content_size, 1536)
1155 1178 self.assertEqual(no_params.dict_id, 0)
1156 1179 self.assertEqual(with_params.dict_id, 0)
1157 1180 self.assertFalse(no_params.has_checksum)
1158 1181 self.assertFalse(with_params.has_checksum)
1159 1182
1160 1183 self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
1161 1184
1162 1185 def test_no_dict_id(self):
1163 1186 samples = []
1164 1187 for i in range(128):
1165 1188 samples.append(b"foo" * 64)
1166 1189 samples.append(b"bar" * 64)
1167 1190 samples.append(b"foobar" * 64)
1168 1191
1169 1192 d = zstd.train_dictionary(1024, samples)
1170 1193
1171 1194 with_dict_id = NonClosingBytesIO()
1172 1195 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
1173 1196 with cctx.stream_writer(with_dict_id) as compressor:
1174 1197 self.assertEqual(compressor.write(b"foobarfoobar"), 0)
1175 1198
1176 1199 self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03")
1177 1200
1178 1201 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
1179 1202 no_dict_id = NonClosingBytesIO()
1180 1203 with cctx.stream_writer(no_dict_id) as compressor:
1181 1204 self.assertEqual(compressor.write(b"foobarfoobar"), 0)
1182 1205
1183 1206 self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00")
1184 1207
1185 1208 no_params = zstd.get_frame_parameters(no_dict_id.getvalue())
1186 1209 with_params = zstd.get_frame_parameters(with_dict_id.getvalue())
1187 1210 self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1188 1211 self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1189 1212 self.assertEqual(no_params.dict_id, 0)
1190 1213 self.assertEqual(with_params.dict_id, d.dict_id())
1191 1214 self.assertFalse(no_params.has_checksum)
1192 1215 self.assertFalse(with_params.has_checksum)
1193 1216
1194 self.assertEqual(len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4)
1217 self.assertEqual(
1218 len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4
1219 )
1195 1220
1196 1221 def test_memory_size(self):
1197 1222 cctx = zstd.ZstdCompressor(level=3)
1198 1223 buffer = io.BytesIO()
1199 1224 with cctx.stream_writer(buffer) as compressor:
1200 1225 compressor.write(b"foo")
1201 1226 size = compressor.memory_size()
1202 1227
1203 1228 self.assertGreater(size, 100000)
1204 1229
1205 1230 def test_write_size(self):
1206 1231 cctx = zstd.ZstdCompressor(level=3)
1207 1232 dest = OpCountingBytesIO()
1208 1233 with cctx.stream_writer(dest, write_size=1) as compressor:
1209 1234 self.assertEqual(compressor.write(b"foo"), 0)
1210 1235 self.assertEqual(compressor.write(b"bar"), 0)
1211 1236 self.assertEqual(compressor.write(b"foobar"), 0)
1212 1237
1213 1238 self.assertEqual(len(dest.getvalue()), dest._write_count)
1214 1239
1215 1240 def test_flush_repeated(self):
1216 1241 cctx = zstd.ZstdCompressor(level=3)
1217 1242 dest = OpCountingBytesIO()
1218 1243 with cctx.stream_writer(dest) as compressor:
1219 1244 self.assertEqual(compressor.write(b"foo"), 0)
1220 1245 self.assertEqual(dest._write_count, 0)
1221 1246 self.assertEqual(compressor.flush(), 12)
1222 1247 self.assertEqual(dest._write_count, 1)
1223 1248 self.assertEqual(compressor.write(b"bar"), 0)
1224 1249 self.assertEqual(dest._write_count, 1)
1225 1250 self.assertEqual(compressor.flush(), 6)
1226 1251 self.assertEqual(dest._write_count, 2)
1227 1252 self.assertEqual(compressor.write(b"baz"), 0)
1228 1253
1229 1254 self.assertEqual(dest._write_count, 3)
1230 1255
1231 1256 def test_flush_empty_block(self):
1232 1257 cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
1233 1258 dest = OpCountingBytesIO()
1234 1259 with cctx.stream_writer(dest) as compressor:
1235 1260 self.assertEqual(compressor.write(b"foobar" * 8192), 0)
1236 1261 count = dest._write_count
1237 1262 offset = dest.tell()
1238 1263 self.assertEqual(compressor.flush(), 23)
1239 1264 self.assertGreater(dest._write_count, count)
1240 1265 self.assertGreater(dest.tell(), offset)
1241 1266 offset = dest.tell()
1242 1267 # Ending the write here should cause an empty block to be written
1243 1268 # to denote end of frame.
1244 1269
1245 1270 trailing = dest.getvalue()[offset:]
1246 1271 # 3 bytes block header + 4 bytes frame checksum
1247 1272 self.assertEqual(len(trailing), 7)
1248 1273
1249 1274 header = trailing[0:3]
1250 1275 self.assertEqual(header, b"\x01\x00\x00")
1251 1276
1252 1277 def test_flush_frame(self):
1253 1278 cctx = zstd.ZstdCompressor(level=3)
1254 1279 dest = OpCountingBytesIO()
1255 1280
1256 1281 with cctx.stream_writer(dest) as compressor:
1257 1282 self.assertEqual(compressor.write(b"foobar" * 8192), 0)
1258 1283 self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
1259 1284 compressor.write(b"biz" * 16384)
1260 1285
1261 1286 self.assertEqual(
1262 1287 dest.getvalue(),
1263 1288 # Frame 1.
1264 1289 b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f"
1265 1290 b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08"
1266 1291 # Frame 2.
1267 1292 b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a"
1268 1293 b"\x01\x00\xfa\x3f\x75\x37\x04",
1269 1294 )
1270 1295
1271 1296 def test_bad_flush_mode(self):
1272 1297 cctx = zstd.ZstdCompressor()
1273 1298 dest = io.BytesIO()
1274 1299 with cctx.stream_writer(dest) as compressor:
1275 1300 with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"):
1276 1301 compressor.flush(flush_mode=42)
1277 1302
1278 1303 def test_multithreaded(self):
1279 1304 dest = NonClosingBytesIO()
1280 1305 cctx = zstd.ZstdCompressor(threads=2)
1281 1306 with cctx.stream_writer(dest) as compressor:
1282 1307 compressor.write(b"a" * 1048576)
1283 1308 compressor.write(b"b" * 1048576)
1284 1309 compressor.write(b"c" * 1048576)
1285 1310
1286 1311 self.assertEqual(len(dest.getvalue()), 111)
1287 1312
1288 1313 def test_tell(self):
1289 1314 dest = io.BytesIO()
1290 1315 cctx = zstd.ZstdCompressor()
1291 1316 with cctx.stream_writer(dest) as compressor:
1292 1317 self.assertEqual(compressor.tell(), 0)
1293 1318
1294 1319 for i in range(256):
1295 1320 compressor.write(b"foo" * (i + 1))
1296 1321 self.assertEqual(compressor.tell(), dest.tell())
1297 1322
1298 1323 def test_bad_size(self):
1299 1324 cctx = zstd.ZstdCompressor()
1300 1325
1301 1326 dest = io.BytesIO()
1302 1327
1303 1328 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
1304 1329 with cctx.stream_writer(dest, size=2) as compressor:
1305 1330 compressor.write(b"foo")
1306 1331
1307 1332 # Test another operation.
1308 1333 with cctx.stream_writer(dest, size=42):
1309 1334 pass
1310 1335
1311 1336 def test_tarfile_compat(self):
1312 1337 dest = NonClosingBytesIO()
1313 1338 cctx = zstd.ZstdCompressor()
1314 1339 with cctx.stream_writer(dest) as compressor:
1315 1340 with tarfile.open("tf", mode="w|", fileobj=compressor) as tf:
1316 1341 tf.add(__file__, "test_compressor.py")
1317 1342
1318 1343 dest = io.BytesIO(dest.getvalue())
1319 1344
1320 1345 dctx = zstd.ZstdDecompressor()
1321 1346 with dctx.stream_reader(dest) as reader:
1322 1347 with tarfile.open(mode="r|", fileobj=reader) as tf:
1323 1348 for member in tf:
1324 1349 self.assertEqual(member.name, "test_compressor.py")
1325 1350
1326 1351
1327 1352 @make_cffi
1328 1353 class TestCompressor_read_to_iter(TestCase):
1329 1354 def test_type_validation(self):
1330 1355 cctx = zstd.ZstdCompressor()
1331 1356
1332 1357 # Object with read() works.
1333 1358 for chunk in cctx.read_to_iter(io.BytesIO()):
1334 1359 pass
1335 1360
1336 1361 # Buffer protocol works.
1337 1362 for chunk in cctx.read_to_iter(b"foobar"):
1338 1363 pass
1339 1364
1340 with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
1365 with self.assertRaisesRegex(
1366 ValueError, "must pass an object with a read"
1367 ):
1341 1368 for chunk in cctx.read_to_iter(True):
1342 1369 pass
1343 1370
1344 1371 def test_read_empty(self):
1345 1372 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1346 1373
1347 1374 source = io.BytesIO()
1348 1375 it = cctx.read_to_iter(source)
1349 1376 chunks = list(it)
1350 1377 self.assertEqual(len(chunks), 1)
1351 1378 compressed = b"".join(chunks)
1352 1379 self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
1353 1380
1354 1381 # And again with the buffer protocol.
1355 1382 it = cctx.read_to_iter(b"")
1356 1383 chunks = list(it)
1357 1384 self.assertEqual(len(chunks), 1)
1358 1385 compressed2 = b"".join(chunks)
1359 1386 self.assertEqual(compressed2, compressed)
1360 1387
1361 1388 def test_read_large(self):
1362 1389 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1363 1390
1364 1391 source = io.BytesIO()
1365 1392 source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
1366 1393 source.write(b"o")
1367 1394 source.seek(0)
1368 1395
1369 1396 # Creating an iterator should not perform any compression until
1370 1397 # first read.
1371 1398 it = cctx.read_to_iter(source, size=len(source.getvalue()))
1372 1399 self.assertEqual(source.tell(), 0)
1373 1400
1374 1401 # We should have exactly 2 output chunks.
1375 1402 chunks = []
1376 1403 chunk = next(it)
1377 1404 self.assertIsNotNone(chunk)
1378 1405 self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
1379 1406 chunks.append(chunk)
1380 1407 chunk = next(it)
1381 1408 self.assertIsNotNone(chunk)
1382 1409 chunks.append(chunk)
1383 1410
1384 1411 self.assertEqual(source.tell(), len(source.getvalue()))
1385 1412
1386 1413 with self.assertRaises(StopIteration):
1387 1414 next(it)
1388 1415
1389 1416 # And again for good measure.
1390 1417 with self.assertRaises(StopIteration):
1391 1418 next(it)
1392 1419
1393 1420 # We should get the same output as the one-shot compression mechanism.
1394 1421 self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
1395 1422
1396 1423 params = zstd.get_frame_parameters(b"".join(chunks))
1397 1424 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1398 1425 self.assertEqual(params.window_size, 262144)
1399 1426 self.assertEqual(params.dict_id, 0)
1400 1427 self.assertFalse(params.has_checksum)
1401 1428
1402 1429 # Now check the buffer protocol.
1403 1430 it = cctx.read_to_iter(source.getvalue())
1404 1431 chunks = list(it)
1405 1432 self.assertEqual(len(chunks), 2)
1406 1433
1407 1434 params = zstd.get_frame_parameters(b"".join(chunks))
1408 1435 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
1409 1436 # self.assertEqual(params.window_size, 262144)
1410 1437 self.assertEqual(params.dict_id, 0)
1411 1438 self.assertFalse(params.has_checksum)
1412 1439
1413 1440 self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
1414 1441
1415 1442 def test_read_write_size(self):
1416 1443 source = OpCountingBytesIO(b"foobarfoobar")
1417 1444 cctx = zstd.ZstdCompressor(level=3)
1418 1445 for chunk in cctx.read_to_iter(source, read_size=1, write_size=1):
1419 1446 self.assertEqual(len(chunk), 1)
1420 1447
1421 1448 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
1422 1449
1423 1450 def test_multithreaded(self):
1424 1451 source = io.BytesIO()
1425 1452 source.write(b"a" * 1048576)
1426 1453 source.write(b"b" * 1048576)
1427 1454 source.write(b"c" * 1048576)
1428 1455 source.seek(0)
1429 1456
1430 1457 cctx = zstd.ZstdCompressor(threads=2)
1431 1458
1432 1459 compressed = b"".join(cctx.read_to_iter(source))
1433 1460 self.assertEqual(len(compressed), 111)
1434 1461
1435 1462 def test_bad_size(self):
1436 1463 cctx = zstd.ZstdCompressor()
1437 1464
1438 1465 source = io.BytesIO(b"a" * 42)
1439 1466
1440 1467 with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
1441 1468 b"".join(cctx.read_to_iter(source, size=2))
1442 1469
1443 1470 # Test another operation on errored compressor.
1444 1471 b"".join(cctx.read_to_iter(source))
1445 1472
1446 1473
1447 1474 @make_cffi
1448 1475 class TestCompressor_chunker(TestCase):
1449 1476 def test_empty(self):
1450 1477 cctx = zstd.ZstdCompressor(write_content_size=False)
1451 1478 chunker = cctx.chunker()
1452 1479
1453 1480 it = chunker.compress(b"")
1454 1481
1455 1482 with self.assertRaises(StopIteration):
1456 1483 next(it)
1457 1484
1458 1485 it = chunker.finish()
1459 1486
1460 1487 self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00")
1461 1488
1462 1489 with self.assertRaises(StopIteration):
1463 1490 next(it)
1464 1491
1465 1492 def test_simple_input(self):
1466 1493 cctx = zstd.ZstdCompressor()
1467 1494 chunker = cctx.chunker()
1468 1495
1469 1496 it = chunker.compress(b"foobar")
1470 1497
1471 1498 with self.assertRaises(StopIteration):
1472 1499 next(it)
1473 1500
1474 1501 it = chunker.compress(b"baz" * 30)
1475 1502
1476 1503 with self.assertRaises(StopIteration):
1477 1504 next(it)
1478 1505
1479 1506 it = chunker.finish()
1480 1507
1481 1508 self.assertEqual(
1482 1509 next(it),
1483 1510 b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f"
1484 1511 b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e",
1485 1512 )
1486 1513
1487 1514 with self.assertRaises(StopIteration):
1488 1515 next(it)
1489 1516
1490 1517 def test_input_size(self):
1491 1518 cctx = zstd.ZstdCompressor()
1492 1519 chunker = cctx.chunker(size=1024)
1493 1520
1494 1521 it = chunker.compress(b"x" * 1000)
1495 1522
1496 1523 with self.assertRaises(StopIteration):
1497 1524 next(it)
1498 1525
1499 1526 it = chunker.compress(b"y" * 24)
1500 1527
1501 1528 with self.assertRaises(StopIteration):
1502 1529 next(it)
1503 1530
1504 1531 chunks = list(chunker.finish())
1505 1532
1506 1533 self.assertEqual(
1507 1534 chunks,
1508 1535 [
1509 1536 b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00"
1510 1537 b"\xa0\x16\xe3\x2b\x80\x05"
1511 1538 ],
1512 1539 )
1513 1540
1514 1541 dctx = zstd.ZstdDecompressor()
1515 1542
1516 self.assertEqual(dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24))
1543 self.assertEqual(
1544 dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24)
1545 )
1517 1546
1518 1547 def test_small_chunk_size(self):
1519 1548 cctx = zstd.ZstdCompressor()
1520 1549 chunker = cctx.chunker(chunk_size=1)
1521 1550
1522 1551 chunks = list(chunker.compress(b"foo" * 1024))
1523 1552 self.assertEqual(chunks, [])
1524 1553
1525 1554 chunks = list(chunker.finish())
1526 1555 self.assertTrue(all(len(chunk) == 1 for chunk in chunks))
1527 1556
1528 1557 self.assertEqual(
1529 1558 b"".join(chunks),
1530 1559 b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00"
1531 1560 b"\xfa\xd3\x77\x43",
1532 1561 )
1533 1562
1534 1563 dctx = zstd.ZstdDecompressor()
1535 1564 self.assertEqual(
1536 dctx.decompress(b"".join(chunks), max_output_size=10000), b"foo" * 1024
1565 dctx.decompress(b"".join(chunks), max_output_size=10000),
1566 b"foo" * 1024,
1537 1567 )
1538 1568
1539 1569 def test_input_types(self):
1540 1570 cctx = zstd.ZstdCompressor()
1541 1571
1542 1572 mutable_array = bytearray(3)
1543 1573 mutable_array[:] = b"foo"
1544 1574
1545 1575 sources = [
1546 1576 memoryview(b"foo"),
1547 1577 bytearray(b"foo"),
1548 1578 mutable_array,
1549 1579 ]
1550 1580
1551 1581 for source in sources:
1552 1582 chunker = cctx.chunker()
1553 1583
1554 1584 self.assertEqual(list(chunker.compress(source)), [])
1555 1585 self.assertEqual(
1556 1586 list(chunker.finish()),
1557 1587 [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"],
1558 1588 )
1559 1589
1560 1590 def test_flush(self):
1561 1591 cctx = zstd.ZstdCompressor()
1562 1592 chunker = cctx.chunker()
1563 1593
1564 1594 self.assertEqual(list(chunker.compress(b"foo" * 1024)), [])
1565 1595 self.assertEqual(list(chunker.compress(b"bar" * 1024)), [])
1566 1596
1567 1597 chunks1 = list(chunker.flush())
1568 1598
1569 1599 self.assertEqual(
1570 1600 chunks1,
1571 1601 [
1572 1602 b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72"
1573 1603 b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02"
1574 1604 ],
1575 1605 )
1576 1606
1577 1607 self.assertEqual(list(chunker.flush()), [])
1578 1608 self.assertEqual(list(chunker.flush()), [])
1579 1609
1580 1610 self.assertEqual(list(chunker.compress(b"baz" * 1024)), [])
1581 1611
1582 1612 chunks2 = list(chunker.flush())
1583 1613 self.assertEqual(len(chunks2), 1)
1584 1614
1585 1615 chunks3 = list(chunker.finish())
1586 1616 self.assertEqual(len(chunks2), 1)
1587 1617
1588 1618 dctx = zstd.ZstdDecompressor()
1589 1619
1590 1620 self.assertEqual(
1591 1621 dctx.decompress(
1592 1622 b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000
1593 1623 ),
1594 1624 (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024),
1595 1625 )
1596 1626
1597 1627 def test_compress_after_finish(self):
1598 1628 cctx = zstd.ZstdCompressor()
1599 1629 chunker = cctx.chunker()
1600 1630
1601 1631 list(chunker.compress(b"foo"))
1602 1632 list(chunker.finish())
1603 1633
1604 1634 with self.assertRaisesRegex(
1605 zstd.ZstdError, r"cannot call compress\(\) after compression finished"
1635 zstd.ZstdError,
1636 r"cannot call compress\(\) after compression finished",
1606 1637 ):
1607 1638 list(chunker.compress(b"foo"))
1608 1639
1609 1640 def test_flush_after_finish(self):
1610 1641 cctx = zstd.ZstdCompressor()
1611 1642 chunker = cctx.chunker()
1612 1643
1613 1644 list(chunker.compress(b"foo"))
1614 1645 list(chunker.finish())
1615 1646
1616 1647 with self.assertRaisesRegex(
1617 1648 zstd.ZstdError, r"cannot call flush\(\) after compression finished"
1618 1649 ):
1619 1650 list(chunker.flush())
1620 1651
1621 1652 def test_finish_after_finish(self):
1622 1653 cctx = zstd.ZstdCompressor()
1623 1654 chunker = cctx.chunker()
1624 1655
1625 1656 list(chunker.compress(b"foo"))
1626 1657 list(chunker.finish())
1627 1658
1628 1659 with self.assertRaisesRegex(
1629 1660 zstd.ZstdError, r"cannot call finish\(\) after compression finished"
1630 1661 ):
1631 1662 list(chunker.finish())
1632 1663
1633 1664
1634 1665 class TestCompressor_multi_compress_to_buffer(TestCase):
1635 1666 def test_invalid_inputs(self):
1636 1667 cctx = zstd.ZstdCompressor()
1637 1668
1638 1669 if not hasattr(cctx, "multi_compress_to_buffer"):
1639 1670 self.skipTest("multi_compress_to_buffer not available")
1640 1671
1641 1672 with self.assertRaises(TypeError):
1642 1673 cctx.multi_compress_to_buffer(True)
1643 1674
1644 1675 with self.assertRaises(TypeError):
1645 1676 cctx.multi_compress_to_buffer((1, 2))
1646 1677
1647 with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
1678 with self.assertRaisesRegex(
1679 TypeError, "item 0 not a bytes like object"
1680 ):
1648 1681 cctx.multi_compress_to_buffer([u"foo"])
1649 1682
1650 1683 def test_empty_input(self):
1651 1684 cctx = zstd.ZstdCompressor()
1652 1685
1653 1686 if not hasattr(cctx, "multi_compress_to_buffer"):
1654 1687 self.skipTest("multi_compress_to_buffer not available")
1655 1688
1656 1689 with self.assertRaisesRegex(ValueError, "no source elements found"):
1657 1690 cctx.multi_compress_to_buffer([])
1658 1691
1659 1692 with self.assertRaisesRegex(ValueError, "source elements are empty"):
1660 1693 cctx.multi_compress_to_buffer([b"", b"", b""])
1661 1694
1662 1695 def test_list_input(self):
1663 1696 cctx = zstd.ZstdCompressor(write_checksum=True)
1664 1697
1665 1698 if not hasattr(cctx, "multi_compress_to_buffer"):
1666 1699 self.skipTest("multi_compress_to_buffer not available")
1667 1700
1668 1701 original = [b"foo" * 12, b"bar" * 6]
1669 1702 frames = [cctx.compress(c) for c in original]
1670 1703 b = cctx.multi_compress_to_buffer(original)
1671 1704
1672 1705 self.assertIsInstance(b, zstd.BufferWithSegmentsCollection)
1673 1706
1674 1707 self.assertEqual(len(b), 2)
1675 1708 self.assertEqual(b.size(), 44)
1676 1709
1677 1710 self.assertEqual(b[0].tobytes(), frames[0])
1678 1711 self.assertEqual(b[1].tobytes(), frames[1])
1679 1712
1680 1713 def test_buffer_with_segments_input(self):
1681 1714 cctx = zstd.ZstdCompressor(write_checksum=True)
1682 1715
1683 1716 if not hasattr(cctx, "multi_compress_to_buffer"):
1684 1717 self.skipTest("multi_compress_to_buffer not available")
1685 1718
1686 1719 original = [b"foo" * 4, b"bar" * 6]
1687 1720 frames = [cctx.compress(c) for c in original]
1688 1721
1689 1722 offsets = struct.pack(
1690 1723 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
1691 1724 )
1692 1725 segments = zstd.BufferWithSegments(b"".join(original), offsets)
1693 1726
1694 1727 result = cctx.multi_compress_to_buffer(segments)
1695 1728
1696 1729 self.assertEqual(len(result), 2)
1697 1730 self.assertEqual(result.size(), 47)
1698 1731
1699 1732 self.assertEqual(result[0].tobytes(), frames[0])
1700 1733 self.assertEqual(result[1].tobytes(), frames[1])
1701 1734
1702 1735 def test_buffer_with_segments_collection_input(self):
1703 1736 cctx = zstd.ZstdCompressor(write_checksum=True)
1704 1737
1705 1738 if not hasattr(cctx, "multi_compress_to_buffer"):
1706 1739 self.skipTest("multi_compress_to_buffer not available")
1707 1740
1708 1741 original = [
1709 1742 b"foo1",
1710 1743 b"foo2" * 2,
1711 1744 b"foo3" * 3,
1712 1745 b"foo4" * 4,
1713 1746 b"foo5" * 5,
1714 1747 ]
1715 1748
1716 1749 frames = [cctx.compress(c) for c in original]
1717 1750
1718 1751 b = b"".join([original[0], original[1]])
1719 1752 b1 = zstd.BufferWithSegments(
1720 1753 b,
1721 1754 struct.pack(
1722 1755 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
1723 1756 ),
1724 1757 )
1725 1758 b = b"".join([original[2], original[3], original[4]])
1726 1759 b2 = zstd.BufferWithSegments(
1727 1760 b,
1728 1761 struct.pack(
1729 1762 "=QQQQQQ",
1730 1763 0,
1731 1764 len(original[2]),
1732 1765 len(original[2]),
1733 1766 len(original[3]),
1734 1767 len(original[2]) + len(original[3]),
1735 1768 len(original[4]),
1736 1769 ),
1737 1770 )
1738 1771
1739 1772 c = zstd.BufferWithSegmentsCollection(b1, b2)
1740 1773
1741 1774 result = cctx.multi_compress_to_buffer(c)
1742 1775
1743 1776 self.assertEqual(len(result), len(frames))
1744 1777
1745 1778 for i, frame in enumerate(frames):
1746 1779 self.assertEqual(result[i].tobytes(), frame)
1747 1780
1748 1781 def test_multiple_threads(self):
1749 1782 # threads argument will cause multi-threaded ZSTD APIs to be used, which will
1750 1783 # make output different.
1751 1784 refcctx = zstd.ZstdCompressor(write_checksum=True)
1752 1785 reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)]
1753 1786
1754 1787 cctx = zstd.ZstdCompressor(write_checksum=True)
1755 1788
1756 1789 if not hasattr(cctx, "multi_compress_to_buffer"):
1757 1790 self.skipTest("multi_compress_to_buffer not available")
1758 1791
1759 1792 frames = []
1760 1793 frames.extend(b"x" * 64 for i in range(256))
1761 1794 frames.extend(b"y" * 64 for i in range(256))
1762 1795
1763 1796 result = cctx.multi_compress_to_buffer(frames, threads=-1)
1764 1797
1765 1798 self.assertEqual(len(result), 512)
1766 1799 for i in range(512):
1767 1800 if i < 256:
1768 1801 self.assertEqual(result[i].tobytes(), reference[0])
1769 1802 else:
1770 1803 self.assertEqual(result[i].tobytes(), reference[1])
@@ -1,836 +1,884 b''
1 1 import io
2 2 import os
3 3 import unittest
4 4
5 5 try:
6 6 import hypothesis
7 7 import hypothesis.strategies as strategies
8 8 except ImportError:
9 9 raise unittest.SkipTest("hypothesis not available")
10 10
11 11 import zstandard as zstd
12 12
13 13 from .common import (
14 14 make_cffi,
15 15 NonClosingBytesIO,
16 16 random_input_data,
17 17 TestCase,
18 18 )
19 19
20 20
21 21 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
22 22 @make_cffi
23 23 class TestCompressor_stream_reader_fuzzing(TestCase):
24 24 @hypothesis.settings(
25 25 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
26 26 )
27 27 @hypothesis.given(
28 28 original=strategies.sampled_from(random_input_data()),
29 29 level=strategies.integers(min_value=1, max_value=5),
30 30 source_read_size=strategies.integers(1, 16384),
31 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
31 read_size=strategies.integers(
32 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
33 ),
32 34 )
33 def test_stream_source_read(self, original, level, source_read_size, read_size):
35 def test_stream_source_read(
36 self, original, level, source_read_size, read_size
37 ):
34 38 if read_size == 0:
35 39 read_size = -1
36 40
37 41 refctx = zstd.ZstdCompressor(level=level)
38 42 ref_frame = refctx.compress(original)
39 43
40 44 cctx = zstd.ZstdCompressor(level=level)
41 45 with cctx.stream_reader(
42 46 io.BytesIO(original), size=len(original), read_size=source_read_size
43 47 ) as reader:
44 48 chunks = []
45 49 while True:
46 50 chunk = reader.read(read_size)
47 51 if not chunk:
48 52 break
49 53
50 54 chunks.append(chunk)
51 55
52 56 self.assertEqual(b"".join(chunks), ref_frame)
53 57
54 58 @hypothesis.settings(
55 59 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
56 60 )
57 61 @hypothesis.given(
58 62 original=strategies.sampled_from(random_input_data()),
59 63 level=strategies.integers(min_value=1, max_value=5),
60 64 source_read_size=strategies.integers(1, 16384),
61 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
65 read_size=strategies.integers(
66 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
67 ),
62 68 )
63 def test_buffer_source_read(self, original, level, source_read_size, read_size):
69 def test_buffer_source_read(
70 self, original, level, source_read_size, read_size
71 ):
64 72 if read_size == 0:
65 73 read_size = -1
66 74
67 75 refctx = zstd.ZstdCompressor(level=level)
68 76 ref_frame = refctx.compress(original)
69 77
70 78 cctx = zstd.ZstdCompressor(level=level)
71 79 with cctx.stream_reader(
72 80 original, size=len(original), read_size=source_read_size
73 81 ) as reader:
74 82 chunks = []
75 83 while True:
76 84 chunk = reader.read(read_size)
77 85 if not chunk:
78 86 break
79 87
80 88 chunks.append(chunk)
81 89
82 90 self.assertEqual(b"".join(chunks), ref_frame)
83 91
84 92 @hypothesis.settings(
85 93 suppress_health_check=[
86 94 hypothesis.HealthCheck.large_base_example,
87 95 hypothesis.HealthCheck.too_slow,
88 96 ]
89 97 )
90 98 @hypothesis.given(
91 99 original=strategies.sampled_from(random_input_data()),
92 100 level=strategies.integers(min_value=1, max_value=5),
93 101 source_read_size=strategies.integers(1, 16384),
94 102 read_sizes=strategies.data(),
95 103 )
96 104 def test_stream_source_read_variance(
97 105 self, original, level, source_read_size, read_sizes
98 106 ):
99 107 refctx = zstd.ZstdCompressor(level=level)
100 108 ref_frame = refctx.compress(original)
101 109
102 110 cctx = zstd.ZstdCompressor(level=level)
103 111 with cctx.stream_reader(
104 112 io.BytesIO(original), size=len(original), read_size=source_read_size
105 113 ) as reader:
106 114 chunks = []
107 115 while True:
108 116 read_size = read_sizes.draw(strategies.integers(-1, 16384))
109 117 chunk = reader.read(read_size)
110 118 if not chunk and read_size:
111 119 break
112 120
113 121 chunks.append(chunk)
114 122
115 123 self.assertEqual(b"".join(chunks), ref_frame)
116 124
117 125 @hypothesis.settings(
118 126 suppress_health_check=[
119 127 hypothesis.HealthCheck.large_base_example,
120 128 hypothesis.HealthCheck.too_slow,
121 129 ]
122 130 )
123 131 @hypothesis.given(
124 132 original=strategies.sampled_from(random_input_data()),
125 133 level=strategies.integers(min_value=1, max_value=5),
126 134 source_read_size=strategies.integers(1, 16384),
127 135 read_sizes=strategies.data(),
128 136 )
129 137 def test_buffer_source_read_variance(
130 138 self, original, level, source_read_size, read_sizes
131 139 ):
132 140
133 141 refctx = zstd.ZstdCompressor(level=level)
134 142 ref_frame = refctx.compress(original)
135 143
136 144 cctx = zstd.ZstdCompressor(level=level)
137 145 with cctx.stream_reader(
138 146 original, size=len(original), read_size=source_read_size
139 147 ) as reader:
140 148 chunks = []
141 149 while True:
142 150 read_size = read_sizes.draw(strategies.integers(-1, 16384))
143 151 chunk = reader.read(read_size)
144 152 if not chunk and read_size:
145 153 break
146 154
147 155 chunks.append(chunk)
148 156
149 157 self.assertEqual(b"".join(chunks), ref_frame)
150 158
151 159 @hypothesis.settings(
152 160 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
153 161 )
154 162 @hypothesis.given(
155 163 original=strategies.sampled_from(random_input_data()),
156 164 level=strategies.integers(min_value=1, max_value=5),
157 165 source_read_size=strategies.integers(1, 16384),
158 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
166 read_size=strategies.integers(
167 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
168 ),
159 169 )
160 def test_stream_source_readinto(self, original, level, source_read_size, read_size):
170 def test_stream_source_readinto(
171 self, original, level, source_read_size, read_size
172 ):
161 173 refctx = zstd.ZstdCompressor(level=level)
162 174 ref_frame = refctx.compress(original)
163 175
164 176 cctx = zstd.ZstdCompressor(level=level)
165 177 with cctx.stream_reader(
166 178 io.BytesIO(original), size=len(original), read_size=source_read_size
167 179 ) as reader:
168 180 chunks = []
169 181 while True:
170 182 b = bytearray(read_size)
171 183 count = reader.readinto(b)
172 184
173 185 if not count:
174 186 break
175 187
176 188 chunks.append(bytes(b[0:count]))
177 189
178 190 self.assertEqual(b"".join(chunks), ref_frame)
179 191
180 192 @hypothesis.settings(
181 193 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
182 194 )
183 195 @hypothesis.given(
184 196 original=strategies.sampled_from(random_input_data()),
185 197 level=strategies.integers(min_value=1, max_value=5),
186 198 source_read_size=strategies.integers(1, 16384),
187 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
199 read_size=strategies.integers(
200 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
201 ),
188 202 )
189 def test_buffer_source_readinto(self, original, level, source_read_size, read_size):
203 def test_buffer_source_readinto(
204 self, original, level, source_read_size, read_size
205 ):
190 206
191 207 refctx = zstd.ZstdCompressor(level=level)
192 208 ref_frame = refctx.compress(original)
193 209
194 210 cctx = zstd.ZstdCompressor(level=level)
195 211 with cctx.stream_reader(
196 212 original, size=len(original), read_size=source_read_size
197 213 ) as reader:
198 214 chunks = []
199 215 while True:
200 216 b = bytearray(read_size)
201 217 count = reader.readinto(b)
202 218
203 219 if not count:
204 220 break
205 221
206 222 chunks.append(bytes(b[0:count]))
207 223
208 224 self.assertEqual(b"".join(chunks), ref_frame)
209 225
210 226 @hypothesis.settings(
211 227 suppress_health_check=[
212 228 hypothesis.HealthCheck.large_base_example,
213 229 hypothesis.HealthCheck.too_slow,
214 230 ]
215 231 )
216 232 @hypothesis.given(
217 233 original=strategies.sampled_from(random_input_data()),
218 234 level=strategies.integers(min_value=1, max_value=5),
219 235 source_read_size=strategies.integers(1, 16384),
220 236 read_sizes=strategies.data(),
221 237 )
222 238 def test_stream_source_readinto_variance(
223 239 self, original, level, source_read_size, read_sizes
224 240 ):
225 241 refctx = zstd.ZstdCompressor(level=level)
226 242 ref_frame = refctx.compress(original)
227 243
228 244 cctx = zstd.ZstdCompressor(level=level)
229 245 with cctx.stream_reader(
230 246 io.BytesIO(original), size=len(original), read_size=source_read_size
231 247 ) as reader:
232 248 chunks = []
233 249 while True:
234 250 read_size = read_sizes.draw(strategies.integers(1, 16384))
235 251 b = bytearray(read_size)
236 252 count = reader.readinto(b)
237 253
238 254 if not count:
239 255 break
240 256
241 257 chunks.append(bytes(b[0:count]))
242 258
243 259 self.assertEqual(b"".join(chunks), ref_frame)
244 260
245 261 @hypothesis.settings(
246 262 suppress_health_check=[
247 263 hypothesis.HealthCheck.large_base_example,
248 264 hypothesis.HealthCheck.too_slow,
249 265 ]
250 266 )
251 267 @hypothesis.given(
252 268 original=strategies.sampled_from(random_input_data()),
253 269 level=strategies.integers(min_value=1, max_value=5),
254 270 source_read_size=strategies.integers(1, 16384),
255 271 read_sizes=strategies.data(),
256 272 )
257 273 def test_buffer_source_readinto_variance(
258 274 self, original, level, source_read_size, read_sizes
259 275 ):
260 276
261 277 refctx = zstd.ZstdCompressor(level=level)
262 278 ref_frame = refctx.compress(original)
263 279
264 280 cctx = zstd.ZstdCompressor(level=level)
265 281 with cctx.stream_reader(
266 282 original, size=len(original), read_size=source_read_size
267 283 ) as reader:
268 284 chunks = []
269 285 while True:
270 286 read_size = read_sizes.draw(strategies.integers(1, 16384))
271 287 b = bytearray(read_size)
272 288 count = reader.readinto(b)
273 289
274 290 if not count:
275 291 break
276 292
277 293 chunks.append(bytes(b[0:count]))
278 294
279 295 self.assertEqual(b"".join(chunks), ref_frame)
280 296
281 297 @hypothesis.settings(
282 298 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
283 299 )
284 300 @hypothesis.given(
285 301 original=strategies.sampled_from(random_input_data()),
286 302 level=strategies.integers(min_value=1, max_value=5),
287 303 source_read_size=strategies.integers(1, 16384),
288 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
304 read_size=strategies.integers(
305 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
306 ),
289 307 )
290 def test_stream_source_read1(self, original, level, source_read_size, read_size):
308 def test_stream_source_read1(
309 self, original, level, source_read_size, read_size
310 ):
291 311 if read_size == 0:
292 312 read_size = -1
293 313
294 314 refctx = zstd.ZstdCompressor(level=level)
295 315 ref_frame = refctx.compress(original)
296 316
297 317 cctx = zstd.ZstdCompressor(level=level)
298 318 with cctx.stream_reader(
299 319 io.BytesIO(original), size=len(original), read_size=source_read_size
300 320 ) as reader:
301 321 chunks = []
302 322 while True:
303 323 chunk = reader.read1(read_size)
304 324 if not chunk:
305 325 break
306 326
307 327 chunks.append(chunk)
308 328
309 329 self.assertEqual(b"".join(chunks), ref_frame)
310 330
311 331 @hypothesis.settings(
312 332 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
313 333 )
314 334 @hypothesis.given(
315 335 original=strategies.sampled_from(random_input_data()),
316 336 level=strategies.integers(min_value=1, max_value=5),
317 337 source_read_size=strategies.integers(1, 16384),
318 read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
338 read_size=strategies.integers(
339 -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
340 ),
319 341 )
320 def test_buffer_source_read1(self, original, level, source_read_size, read_size):
342 def test_buffer_source_read1(
343 self, original, level, source_read_size, read_size
344 ):
321 345 if read_size == 0:
322 346 read_size = -1
323 347
324 348 refctx = zstd.ZstdCompressor(level=level)
325 349 ref_frame = refctx.compress(original)
326 350
327 351 cctx = zstd.ZstdCompressor(level=level)
328 352 with cctx.stream_reader(
329 353 original, size=len(original), read_size=source_read_size
330 354 ) as reader:
331 355 chunks = []
332 356 while True:
333 357 chunk = reader.read1(read_size)
334 358 if not chunk:
335 359 break
336 360
337 361 chunks.append(chunk)
338 362
339 363 self.assertEqual(b"".join(chunks), ref_frame)
340 364
341 365 @hypothesis.settings(
342 366 suppress_health_check=[
343 367 hypothesis.HealthCheck.large_base_example,
344 368 hypothesis.HealthCheck.too_slow,
345 369 ]
346 370 )
347 371 @hypothesis.given(
348 372 original=strategies.sampled_from(random_input_data()),
349 373 level=strategies.integers(min_value=1, max_value=5),
350 374 source_read_size=strategies.integers(1, 16384),
351 375 read_sizes=strategies.data(),
352 376 )
353 377 def test_stream_source_read1_variance(
354 378 self, original, level, source_read_size, read_sizes
355 379 ):
356 380 refctx = zstd.ZstdCompressor(level=level)
357 381 ref_frame = refctx.compress(original)
358 382
359 383 cctx = zstd.ZstdCompressor(level=level)
360 384 with cctx.stream_reader(
361 385 io.BytesIO(original), size=len(original), read_size=source_read_size
362 386 ) as reader:
363 387 chunks = []
364 388 while True:
365 389 read_size = read_sizes.draw(strategies.integers(-1, 16384))
366 390 chunk = reader.read1(read_size)
367 391 if not chunk and read_size:
368 392 break
369 393
370 394 chunks.append(chunk)
371 395
372 396 self.assertEqual(b"".join(chunks), ref_frame)
373 397
374 398 @hypothesis.settings(
375 399 suppress_health_check=[
376 400 hypothesis.HealthCheck.large_base_example,
377 401 hypothesis.HealthCheck.too_slow,
378 402 ]
379 403 )
380 404 @hypothesis.given(
381 405 original=strategies.sampled_from(random_input_data()),
382 406 level=strategies.integers(min_value=1, max_value=5),
383 407 source_read_size=strategies.integers(1, 16384),
384 408 read_sizes=strategies.data(),
385 409 )
386 410 def test_buffer_source_read1_variance(
387 411 self, original, level, source_read_size, read_sizes
388 412 ):
389 413
390 414 refctx = zstd.ZstdCompressor(level=level)
391 415 ref_frame = refctx.compress(original)
392 416
393 417 cctx = zstd.ZstdCompressor(level=level)
394 418 with cctx.stream_reader(
395 419 original, size=len(original), read_size=source_read_size
396 420 ) as reader:
397 421 chunks = []
398 422 while True:
399 423 read_size = read_sizes.draw(strategies.integers(-1, 16384))
400 424 chunk = reader.read1(read_size)
401 425 if not chunk and read_size:
402 426 break
403 427
404 428 chunks.append(chunk)
405 429
406 430 self.assertEqual(b"".join(chunks), ref_frame)
407 431
408 432 @hypothesis.settings(
409 433 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
410 434 )
411 435 @hypothesis.given(
412 436 original=strategies.sampled_from(random_input_data()),
413 437 level=strategies.integers(min_value=1, max_value=5),
414 438 source_read_size=strategies.integers(1, 16384),
415 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
439 read_size=strategies.integers(
440 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
441 ),
416 442 )
417 443 def test_stream_source_readinto1(
418 444 self, original, level, source_read_size, read_size
419 445 ):
420 446 if read_size == 0:
421 447 read_size = -1
422 448
423 449 refctx = zstd.ZstdCompressor(level=level)
424 450 ref_frame = refctx.compress(original)
425 451
426 452 cctx = zstd.ZstdCompressor(level=level)
427 453 with cctx.stream_reader(
428 454 io.BytesIO(original), size=len(original), read_size=source_read_size
429 455 ) as reader:
430 456 chunks = []
431 457 while True:
432 458 b = bytearray(read_size)
433 459 count = reader.readinto1(b)
434 460
435 461 if not count:
436 462 break
437 463
438 464 chunks.append(bytes(b[0:count]))
439 465
440 466 self.assertEqual(b"".join(chunks), ref_frame)
441 467
442 468 @hypothesis.settings(
443 469 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
444 470 )
445 471 @hypothesis.given(
446 472 original=strategies.sampled_from(random_input_data()),
447 473 level=strategies.integers(min_value=1, max_value=5),
448 474 source_read_size=strategies.integers(1, 16384),
449 read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE),
475 read_size=strategies.integers(
476 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE
477 ),
450 478 )
451 479 def test_buffer_source_readinto1(
452 480 self, original, level, source_read_size, read_size
453 481 ):
454 482 if read_size == 0:
455 483 read_size = -1
456 484
457 485 refctx = zstd.ZstdCompressor(level=level)
458 486 ref_frame = refctx.compress(original)
459 487
460 488 cctx = zstd.ZstdCompressor(level=level)
461 489 with cctx.stream_reader(
462 490 original, size=len(original), read_size=source_read_size
463 491 ) as reader:
464 492 chunks = []
465 493 while True:
466 494 b = bytearray(read_size)
467 495 count = reader.readinto1(b)
468 496
469 497 if not count:
470 498 break
471 499
472 500 chunks.append(bytes(b[0:count]))
473 501
474 502 self.assertEqual(b"".join(chunks), ref_frame)
475 503
476 504 @hypothesis.settings(
477 505 suppress_health_check=[
478 506 hypothesis.HealthCheck.large_base_example,
479 507 hypothesis.HealthCheck.too_slow,
480 508 ]
481 509 )
482 510 @hypothesis.given(
483 511 original=strategies.sampled_from(random_input_data()),
484 512 level=strategies.integers(min_value=1, max_value=5),
485 513 source_read_size=strategies.integers(1, 16384),
486 514 read_sizes=strategies.data(),
487 515 )
488 516 def test_stream_source_readinto1_variance(
489 517 self, original, level, source_read_size, read_sizes
490 518 ):
491 519 refctx = zstd.ZstdCompressor(level=level)
492 520 ref_frame = refctx.compress(original)
493 521
494 522 cctx = zstd.ZstdCompressor(level=level)
495 523 with cctx.stream_reader(
496 524 io.BytesIO(original), size=len(original), read_size=source_read_size
497 525 ) as reader:
498 526 chunks = []
499 527 while True:
500 528 read_size = read_sizes.draw(strategies.integers(1, 16384))
501 529 b = bytearray(read_size)
502 530 count = reader.readinto1(b)
503 531
504 532 if not count:
505 533 break
506 534
507 535 chunks.append(bytes(b[0:count]))
508 536
509 537 self.assertEqual(b"".join(chunks), ref_frame)
510 538
511 539 @hypothesis.settings(
512 540 suppress_health_check=[
513 541 hypothesis.HealthCheck.large_base_example,
514 542 hypothesis.HealthCheck.too_slow,
515 543 ]
516 544 )
517 545 @hypothesis.given(
518 546 original=strategies.sampled_from(random_input_data()),
519 547 level=strategies.integers(min_value=1, max_value=5),
520 548 source_read_size=strategies.integers(1, 16384),
521 549 read_sizes=strategies.data(),
522 550 )
523 551 def test_buffer_source_readinto1_variance(
524 552 self, original, level, source_read_size, read_sizes
525 553 ):
526 554
527 555 refctx = zstd.ZstdCompressor(level=level)
528 556 ref_frame = refctx.compress(original)
529 557
530 558 cctx = zstd.ZstdCompressor(level=level)
531 559 with cctx.stream_reader(
532 560 original, size=len(original), read_size=source_read_size
533 561 ) as reader:
534 562 chunks = []
535 563 while True:
536 564 read_size = read_sizes.draw(strategies.integers(1, 16384))
537 565 b = bytearray(read_size)
538 566 count = reader.readinto1(b)
539 567
540 568 if not count:
541 569 break
542 570
543 571 chunks.append(bytes(b[0:count]))
544 572
545 573 self.assertEqual(b"".join(chunks), ref_frame)
546 574
547 575
548 576 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
549 577 @make_cffi
550 578 class TestCompressor_stream_writer_fuzzing(TestCase):
551 579 @hypothesis.given(
552 580 original=strategies.sampled_from(random_input_data()),
553 581 level=strategies.integers(min_value=1, max_value=5),
554 582 write_size=strategies.integers(min_value=1, max_value=1048576),
555 583 )
556 584 def test_write_size_variance(self, original, level, write_size):
557 585 refctx = zstd.ZstdCompressor(level=level)
558 586 ref_frame = refctx.compress(original)
559 587
560 588 cctx = zstd.ZstdCompressor(level=level)
561 589 b = NonClosingBytesIO()
562 590 with cctx.stream_writer(
563 591 b, size=len(original), write_size=write_size
564 592 ) as compressor:
565 593 compressor.write(original)
566 594
567 595 self.assertEqual(b.getvalue(), ref_frame)
568 596
569 597
570 598 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
571 599 @make_cffi
572 600 class TestCompressor_copy_stream_fuzzing(TestCase):
573 601 @hypothesis.given(
574 602 original=strategies.sampled_from(random_input_data()),
575 603 level=strategies.integers(min_value=1, max_value=5),
576 604 read_size=strategies.integers(min_value=1, max_value=1048576),
577 605 write_size=strategies.integers(min_value=1, max_value=1048576),
578 606 )
579 def test_read_write_size_variance(self, original, level, read_size, write_size):
607 def test_read_write_size_variance(
608 self, original, level, read_size, write_size
609 ):
580 610 refctx = zstd.ZstdCompressor(level=level)
581 611 ref_frame = refctx.compress(original)
582 612
583 613 cctx = zstd.ZstdCompressor(level=level)
584 614 source = io.BytesIO(original)
585 615 dest = io.BytesIO()
586 616
587 617 cctx.copy_stream(
588 source, dest, size=len(original), read_size=read_size, write_size=write_size
618 source,
619 dest,
620 size=len(original),
621 read_size=read_size,
622 write_size=write_size,
589 623 )
590 624
591 625 self.assertEqual(dest.getvalue(), ref_frame)
592 626
593 627
594 628 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
595 629 @make_cffi
596 630 class TestCompressor_compressobj_fuzzing(TestCase):
597 631 @hypothesis.settings(
598 632 suppress_health_check=[
599 633 hypothesis.HealthCheck.large_base_example,
600 634 hypothesis.HealthCheck.too_slow,
601 635 ]
602 636 )
603 637 @hypothesis.given(
604 638 original=strategies.sampled_from(random_input_data()),
605 639 level=strategies.integers(min_value=1, max_value=5),
606 640 chunk_sizes=strategies.data(),
607 641 )
608 642 def test_random_input_sizes(self, original, level, chunk_sizes):
609 643 refctx = zstd.ZstdCompressor(level=level)
610 644 ref_frame = refctx.compress(original)
611 645
612 646 cctx = zstd.ZstdCompressor(level=level)
613 647 cobj = cctx.compressobj(size=len(original))
614 648
615 649 chunks = []
616 650 i = 0
617 651 while True:
618 652 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
619 653 source = original[i : i + chunk_size]
620 654 if not source:
621 655 break
622 656
623 657 chunks.append(cobj.compress(source))
624 658 i += chunk_size
625 659
626 660 chunks.append(cobj.flush())
627 661
628 662 self.assertEqual(b"".join(chunks), ref_frame)
629 663
630 664 @hypothesis.settings(
631 665 suppress_health_check=[
632 666 hypothesis.HealthCheck.large_base_example,
633 667 hypothesis.HealthCheck.too_slow,
634 668 ]
635 669 )
636 670 @hypothesis.given(
637 671 original=strategies.sampled_from(random_input_data()),
638 672 level=strategies.integers(min_value=1, max_value=5),
639 673 chunk_sizes=strategies.data(),
640 674 flushes=strategies.data(),
641 675 )
642 676 def test_flush_block(self, original, level, chunk_sizes, flushes):
643 677 cctx = zstd.ZstdCompressor(level=level)
644 678 cobj = cctx.compressobj()
645 679
646 680 dctx = zstd.ZstdDecompressor()
647 681 dobj = dctx.decompressobj()
648 682
649 683 compressed_chunks = []
650 684 decompressed_chunks = []
651 685 i = 0
652 686 while True:
653 687 input_size = chunk_sizes.draw(strategies.integers(1, 4096))
654 688 source = original[i : i + input_size]
655 689 if not source:
656 690 break
657 691
658 692 i += input_size
659 693
660 694 chunk = cobj.compress(source)
661 695 compressed_chunks.append(chunk)
662 696 decompressed_chunks.append(dobj.decompress(chunk))
663 697
664 698 if not flushes.draw(strategies.booleans()):
665 699 continue
666 700
667 701 chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
668 702 compressed_chunks.append(chunk)
669 703 decompressed_chunks.append(dobj.decompress(chunk))
670 704
671 705 self.assertEqual(b"".join(decompressed_chunks), original[0:i])
672 706
673 707 chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH)
674 708 compressed_chunks.append(chunk)
675 709 decompressed_chunks.append(dobj.decompress(chunk))
676 710
677 711 self.assertEqual(
678 dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)),
712 dctx.decompress(
713 b"".join(compressed_chunks), max_output_size=len(original)
714 ),
679 715 original,
680 716 )
681 717 self.assertEqual(b"".join(decompressed_chunks), original)
682 718
683 719
684 720 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
685 721 @make_cffi
686 722 class TestCompressor_read_to_iter_fuzzing(TestCase):
687 723 @hypothesis.given(
688 724 original=strategies.sampled_from(random_input_data()),
689 725 level=strategies.integers(min_value=1, max_value=5),
690 726 read_size=strategies.integers(min_value=1, max_value=4096),
691 727 write_size=strategies.integers(min_value=1, max_value=4096),
692 728 )
693 def test_read_write_size_variance(self, original, level, read_size, write_size):
729 def test_read_write_size_variance(
730 self, original, level, read_size, write_size
731 ):
694 732 refcctx = zstd.ZstdCompressor(level=level)
695 733 ref_frame = refcctx.compress(original)
696 734
697 735 source = io.BytesIO(original)
698 736
699 737 cctx = zstd.ZstdCompressor(level=level)
700 738 chunks = list(
701 739 cctx.read_to_iter(
702 source, size=len(original), read_size=read_size, write_size=write_size
740 source,
741 size=len(original),
742 read_size=read_size,
743 write_size=write_size,
703 744 )
704 745 )
705 746
706 747 self.assertEqual(b"".join(chunks), ref_frame)
707 748
708 749
709 750 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
710 751 class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase):
711 752 @hypothesis.given(
712 753 original=strategies.lists(
713 strategies.sampled_from(random_input_data()), min_size=1, max_size=1024
754 strategies.sampled_from(random_input_data()),
755 min_size=1,
756 max_size=1024,
714 757 ),
715 758 threads=strategies.integers(min_value=1, max_value=8),
716 759 use_dict=strategies.booleans(),
717 760 )
718 761 def test_data_equivalence(self, original, threads, use_dict):
719 762 kwargs = {}
720 763
721 764 # Use a content dictionary because it is cheap to create.
722 765 if use_dict:
723 766 kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])
724 767
725 768 cctx = zstd.ZstdCompressor(level=1, write_checksum=True, **kwargs)
726 769
727 770 if not hasattr(cctx, "multi_compress_to_buffer"):
728 771 self.skipTest("multi_compress_to_buffer not available")
729 772
730 773 result = cctx.multi_compress_to_buffer(original, threads=-1)
731 774
732 775 self.assertEqual(len(result), len(original))
733 776
734 777 # The frame produced via the batch APIs may not be bit identical to that
735 778 # produced by compress() because compression parameters are adjusted
736 779 # from the first input in batch mode. So the only thing we can do is
737 780 # verify the decompressed data matches the input.
738 781 dctx = zstd.ZstdDecompressor(**kwargs)
739 782
740 783 for i, frame in enumerate(result):
741 784 self.assertEqual(dctx.decompress(frame), original[i])
742 785
743 786
744 787 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
745 788 @make_cffi
746 789 class TestCompressor_chunker_fuzzing(TestCase):
747 790 @hypothesis.settings(
748 791 suppress_health_check=[
749 792 hypothesis.HealthCheck.large_base_example,
750 793 hypothesis.HealthCheck.too_slow,
751 794 ]
752 795 )
753 796 @hypothesis.given(
754 797 original=strategies.sampled_from(random_input_data()),
755 798 level=strategies.integers(min_value=1, max_value=5),
756 799 chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576),
757 800 input_sizes=strategies.data(),
758 801 )
759 802 def test_random_input_sizes(self, original, level, chunk_size, input_sizes):
760 803 cctx = zstd.ZstdCompressor(level=level)
761 804 chunker = cctx.chunker(chunk_size=chunk_size)
762 805
763 806 chunks = []
764 807 i = 0
765 808 while True:
766 809 input_size = input_sizes.draw(strategies.integers(1, 4096))
767 810 source = original[i : i + input_size]
768 811 if not source:
769 812 break
770 813
771 814 chunks.extend(chunker.compress(source))
772 815 i += input_size
773 816
774 817 chunks.extend(chunker.finish())
775 818
776 819 dctx = zstd.ZstdDecompressor()
777 820
778 821 self.assertEqual(
779 dctx.decompress(b"".join(chunks), max_output_size=len(original)), original
822 dctx.decompress(b"".join(chunks), max_output_size=len(original)),
823 original,
780 824 )
781 825
782 826 self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1]))
783 827
784 828 @hypothesis.settings(
785 829 suppress_health_check=[
786 830 hypothesis.HealthCheck.large_base_example,
787 831 hypothesis.HealthCheck.too_slow,
788 832 ]
789 833 )
790 834 @hypothesis.given(
791 835 original=strategies.sampled_from(random_input_data()),
792 836 level=strategies.integers(min_value=1, max_value=5),
793 837 chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576),
794 838 input_sizes=strategies.data(),
795 839 flushes=strategies.data(),
796 840 )
797 def test_flush_block(self, original, level, chunk_size, input_sizes, flushes):
841 def test_flush_block(
842 self, original, level, chunk_size, input_sizes, flushes
843 ):
798 844 cctx = zstd.ZstdCompressor(level=level)
799 845 chunker = cctx.chunker(chunk_size=chunk_size)
800 846
801 847 dctx = zstd.ZstdDecompressor()
802 848 dobj = dctx.decompressobj()
803 849
804 850 compressed_chunks = []
805 851 decompressed_chunks = []
806 852 i = 0
807 853 while True:
808 854 input_size = input_sizes.draw(strategies.integers(1, 4096))
809 855 source = original[i : i + input_size]
810 856 if not source:
811 857 break
812 858
813 859 i += input_size
814 860
815 861 chunks = list(chunker.compress(source))
816 862 compressed_chunks.extend(chunks)
817 863 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
818 864
819 865 if not flushes.draw(strategies.booleans()):
820 866 continue
821 867
822 868 chunks = list(chunker.flush())
823 869 compressed_chunks.extend(chunks)
824 870 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
825 871
826 872 self.assertEqual(b"".join(decompressed_chunks), original[0:i])
827 873
828 874 chunks = list(chunker.finish())
829 875 compressed_chunks.extend(chunks)
830 876 decompressed_chunks.append(dobj.decompress(b"".join(chunks)))
831 877
832 878 self.assertEqual(
833 dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)),
879 dctx.decompress(
880 b"".join(compressed_chunks), max_output_size=len(original)
881 ),
834 882 original,
835 883 )
836 884 self.assertEqual(b"".join(decompressed_chunks), original)
@@ -1,241 +1,255 b''
1 1 import sys
2 2 import unittest
3 3
4 4 import zstandard as zstd
5 5
6 6 from .common import (
7 7 make_cffi,
8 8 TestCase,
9 9 )
10 10
11 11
12 12 @make_cffi
13 13 class TestCompressionParameters(TestCase):
14 14 def test_bounds(self):
15 15 zstd.ZstdCompressionParameters(
16 16 window_log=zstd.WINDOWLOG_MIN,
17 17 chain_log=zstd.CHAINLOG_MIN,
18 18 hash_log=zstd.HASHLOG_MIN,
19 19 search_log=zstd.SEARCHLOG_MIN,
20 20 min_match=zstd.MINMATCH_MIN + 1,
21 21 target_length=zstd.TARGETLENGTH_MIN,
22 22 strategy=zstd.STRATEGY_FAST,
23 23 )
24 24
25 25 zstd.ZstdCompressionParameters(
26 26 window_log=zstd.WINDOWLOG_MAX,
27 27 chain_log=zstd.CHAINLOG_MAX,
28 28 hash_log=zstd.HASHLOG_MAX,
29 29 search_log=zstd.SEARCHLOG_MAX,
30 30 min_match=zstd.MINMATCH_MAX - 1,
31 31 target_length=zstd.TARGETLENGTH_MAX,
32 32 strategy=zstd.STRATEGY_BTULTRA2,
33 33 )
34 34
35 35 def test_from_level(self):
36 36 p = zstd.ZstdCompressionParameters.from_level(1)
37 37 self.assertIsInstance(p, zstd.CompressionParameters)
38 38
39 39 self.assertEqual(p.window_log, 19)
40 40
41 41 p = zstd.ZstdCompressionParameters.from_level(-4)
42 42 self.assertEqual(p.window_log, 19)
43 43
44 44 def test_members(self):
45 45 p = zstd.ZstdCompressionParameters(
46 46 window_log=10,
47 47 chain_log=6,
48 48 hash_log=7,
49 49 search_log=4,
50 50 min_match=5,
51 51 target_length=8,
52 52 strategy=1,
53 53 )
54 54 self.assertEqual(p.window_log, 10)
55 55 self.assertEqual(p.chain_log, 6)
56 56 self.assertEqual(p.hash_log, 7)
57 57 self.assertEqual(p.search_log, 4)
58 58 self.assertEqual(p.min_match, 5)
59 59 self.assertEqual(p.target_length, 8)
60 60 self.assertEqual(p.compression_strategy, 1)
61 61
62 62 p = zstd.ZstdCompressionParameters(compression_level=2)
63 63 self.assertEqual(p.compression_level, 2)
64 64
65 65 p = zstd.ZstdCompressionParameters(threads=4)
66 66 self.assertEqual(p.threads, 4)
67 67
68 p = zstd.ZstdCompressionParameters(threads=2, job_size=1048576, overlap_log=6)
68 p = zstd.ZstdCompressionParameters(
69 threads=2, job_size=1048576, overlap_log=6
70 )
69 71 self.assertEqual(p.threads, 2)
70 72 self.assertEqual(p.job_size, 1048576)
71 73 self.assertEqual(p.overlap_log, 6)
72 74 self.assertEqual(p.overlap_size_log, 6)
73 75
74 76 p = zstd.ZstdCompressionParameters(compression_level=-1)
75 77 self.assertEqual(p.compression_level, -1)
76 78
77 79 p = zstd.ZstdCompressionParameters(compression_level=-2)
78 80 self.assertEqual(p.compression_level, -2)
79 81
80 82 p = zstd.ZstdCompressionParameters(force_max_window=True)
81 83 self.assertEqual(p.force_max_window, 1)
82 84
83 85 p = zstd.ZstdCompressionParameters(enable_ldm=True)
84 86 self.assertEqual(p.enable_ldm, 1)
85 87
86 88 p = zstd.ZstdCompressionParameters(ldm_hash_log=7)
87 89 self.assertEqual(p.ldm_hash_log, 7)
88 90
89 91 p = zstd.ZstdCompressionParameters(ldm_min_match=6)
90 92 self.assertEqual(p.ldm_min_match, 6)
91 93
92 94 p = zstd.ZstdCompressionParameters(ldm_bucket_size_log=7)
93 95 self.assertEqual(p.ldm_bucket_size_log, 7)
94 96
95 97 p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8)
96 98 self.assertEqual(p.ldm_hash_every_log, 8)
97 99 self.assertEqual(p.ldm_hash_rate_log, 8)
98 100
99 101 def test_estimated_compression_context_size(self):
100 102 p = zstd.ZstdCompressionParameters(
101 103 window_log=20,
102 104 chain_log=16,
103 105 hash_log=17,
104 106 search_log=1,
105 107 min_match=5,
106 108 target_length=16,
107 109 strategy=zstd.STRATEGY_DFAST,
108 110 )
109 111
110 112 # 32-bit has slightly different values from 64-bit.
111 113 self.assertAlmostEqual(
112 114 p.estimated_compression_context_size(), 1294464, delta=400
113 115 )
114 116
115 117 def test_strategy(self):
116 118 with self.assertRaisesRegex(
117 119 ValueError, "cannot specify both compression_strategy"
118 120 ):
119 121 zstd.ZstdCompressionParameters(strategy=0, compression_strategy=0)
120 122
121 123 p = zstd.ZstdCompressionParameters(strategy=2)
122 124 self.assertEqual(p.compression_strategy, 2)
123 125
124 126 p = zstd.ZstdCompressionParameters(strategy=3)
125 127 self.assertEqual(p.compression_strategy, 3)
126 128
127 129 def test_ldm_hash_rate_log(self):
128 130 with self.assertRaisesRegex(
129 131 ValueError, "cannot specify both ldm_hash_rate_log"
130 132 ):
131 zstd.ZstdCompressionParameters(ldm_hash_rate_log=8, ldm_hash_every_log=4)
133 zstd.ZstdCompressionParameters(
134 ldm_hash_rate_log=8, ldm_hash_every_log=4
135 )
132 136
133 137 p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8)
134 138 self.assertEqual(p.ldm_hash_every_log, 8)
135 139
136 140 p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16)
137 141 self.assertEqual(p.ldm_hash_every_log, 16)
138 142
139 143 def test_overlap_log(self):
140 with self.assertRaisesRegex(ValueError, "cannot specify both overlap_log"):
144 with self.assertRaisesRegex(
145 ValueError, "cannot specify both overlap_log"
146 ):
141 147 zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9)
142 148
143 149 p = zstd.ZstdCompressionParameters(overlap_log=2)
144 150 self.assertEqual(p.overlap_log, 2)
145 151 self.assertEqual(p.overlap_size_log, 2)
146 152
147 153 p = zstd.ZstdCompressionParameters(overlap_size_log=4)
148 154 self.assertEqual(p.overlap_log, 4)
149 155 self.assertEqual(p.overlap_size_log, 4)
150 156
151 157
152 158 @make_cffi
153 159 class TestFrameParameters(TestCase):
154 160 def test_invalid_type(self):
155 161 with self.assertRaises(TypeError):
156 162 zstd.get_frame_parameters(None)
157 163
158 164 # Python 3 doesn't appear to convert unicode to Py_buffer.
159 165 if sys.version_info[0] >= 3:
160 166 with self.assertRaises(TypeError):
161 167 zstd.get_frame_parameters(u"foobarbaz")
162 168 else:
163 169 # CPython will convert unicode to Py_buffer. But CFFI won't.
164 170 if zstd.backend == "cffi":
165 171 with self.assertRaises(TypeError):
166 172 zstd.get_frame_parameters(u"foobarbaz")
167 173 else:
168 174 with self.assertRaises(zstd.ZstdError):
169 175 zstd.get_frame_parameters(u"foobarbaz")
170 176
171 177 def test_invalid_input_sizes(self):
172 with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"):
178 with self.assertRaisesRegex(
179 zstd.ZstdError, "not enough data for frame"
180 ):
173 181 zstd.get_frame_parameters(b"")
174 182
175 with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"):
183 with self.assertRaisesRegex(
184 zstd.ZstdError, "not enough data for frame"
185 ):
176 186 zstd.get_frame_parameters(zstd.FRAME_HEADER)
177 187
178 188 def test_invalid_frame(self):
179 189 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
180 190 zstd.get_frame_parameters(b"foobarbaz")
181 191
182 192 def test_attributes(self):
183 193 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00")
184 194 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
185 195 self.assertEqual(params.window_size, 1024)
186 196 self.assertEqual(params.dict_id, 0)
187 197 self.assertFalse(params.has_checksum)
188 198
189 199 # Lowest 2 bits indicate a dictionary and length. Here, the dict id is 1 byte.
190 200 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x01\x00\xff")
191 201 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
192 202 self.assertEqual(params.window_size, 1024)
193 203 self.assertEqual(params.dict_id, 255)
194 204 self.assertFalse(params.has_checksum)
195 205
196 206 # Lowest 3rd bit indicates if checksum is present.
197 207 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00")
198 208 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
199 209 self.assertEqual(params.window_size, 1024)
200 210 self.assertEqual(params.dict_id, 0)
201 211 self.assertTrue(params.has_checksum)
202 212
203 213 # Upper 2 bits indicate content size.
204 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x40\x00\xff\x00")
214 params = zstd.get_frame_parameters(
215 zstd.FRAME_HEADER + b"\x40\x00\xff\x00"
216 )
205 217 self.assertEqual(params.content_size, 511)
206 218 self.assertEqual(params.window_size, 1024)
207 219 self.assertEqual(params.dict_id, 0)
208 220 self.assertFalse(params.has_checksum)
209 221
210 222 # Window descriptor is 2nd byte after frame header.
211 223 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40")
212 224 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
213 225 self.assertEqual(params.window_size, 262144)
214 226 self.assertEqual(params.dict_id, 0)
215 227 self.assertFalse(params.has_checksum)
216 228
217 229 # Set multiple things.
218 params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00")
230 params = zstd.get_frame_parameters(
231 zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00"
232 )
219 233 self.assertEqual(params.content_size, 272)
220 234 self.assertEqual(params.window_size, 262144)
221 235 self.assertEqual(params.dict_id, 15)
222 236 self.assertTrue(params.has_checksum)
223 237
224 238 def test_input_types(self):
225 239 v = zstd.FRAME_HEADER + b"\x00\x00"
226 240
227 241 mutable_array = bytearray(len(v))
228 242 mutable_array[:] = v
229 243
230 244 sources = [
231 245 memoryview(v),
232 246 bytearray(v),
233 247 mutable_array,
234 248 ]
235 249
236 250 for source in sources:
237 251 params = zstd.get_frame_parameters(source)
238 252 self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
239 253 self.assertEqual(params.window_size, 1024)
240 254 self.assertEqual(params.dict_id, 0)
241 255 self.assertFalse(params.has_checksum)
@@ -1,105 +1,121 b''
1 1 import io
2 2 import os
3 3 import sys
4 4 import unittest
5 5
6 6 try:
7 7 import hypothesis
8 8 import hypothesis.strategies as strategies
9 9 except ImportError:
10 10 raise unittest.SkipTest("hypothesis not available")
11 11
12 12 import zstandard as zstd
13 13
14 14 from .common import (
15 15 make_cffi,
16 16 TestCase,
17 17 )
18 18
19 19
20 20 s_windowlog = strategies.integers(
21 21 min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX
22 22 )
23 23 s_chainlog = strategies.integers(
24 24 min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX
25 25 )
26 s_hashlog = strategies.integers(min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX)
26 s_hashlog = strategies.integers(
27 min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX
28 )
27 29 s_searchlog = strategies.integers(
28 30 min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX
29 31 )
30 32 s_minmatch = strategies.integers(
31 33 min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX
32 34 )
33 35 s_targetlength = strategies.integers(
34 36 min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX
35 37 )
36 38 s_strategy = strategies.sampled_from(
37 39 (
38 40 zstd.STRATEGY_FAST,
39 41 zstd.STRATEGY_DFAST,
40 42 zstd.STRATEGY_GREEDY,
41 43 zstd.STRATEGY_LAZY,
42 44 zstd.STRATEGY_LAZY2,
43 45 zstd.STRATEGY_BTLAZY2,
44 46 zstd.STRATEGY_BTOPT,
45 47 zstd.STRATEGY_BTULTRA,
46 48 zstd.STRATEGY_BTULTRA2,
47 49 )
48 50 )
49 51
50 52
51 53 @make_cffi
52 54 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
53 55 class TestCompressionParametersHypothesis(TestCase):
54 56 @hypothesis.given(
55 57 s_windowlog,
56 58 s_chainlog,
57 59 s_hashlog,
58 60 s_searchlog,
59 61 s_minmatch,
60 62 s_targetlength,
61 63 s_strategy,
62 64 )
63 65 def test_valid_init(
64 self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy
66 self,
67 windowlog,
68 chainlog,
69 hashlog,
70 searchlog,
71 minmatch,
72 targetlength,
73 strategy,
65 74 ):
66 75 zstd.ZstdCompressionParameters(
67 76 window_log=windowlog,
68 77 chain_log=chainlog,
69 78 hash_log=hashlog,
70 79 search_log=searchlog,
71 80 min_match=minmatch,
72 81 target_length=targetlength,
73 82 strategy=strategy,
74 83 )
75 84
76 85 @hypothesis.given(
77 86 s_windowlog,
78 87 s_chainlog,
79 88 s_hashlog,
80 89 s_searchlog,
81 90 s_minmatch,
82 91 s_targetlength,
83 92 s_strategy,
84 93 )
85 94 def test_estimated_compression_context_size(
86 self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy
95 self,
96 windowlog,
97 chainlog,
98 hashlog,
99 searchlog,
100 minmatch,
101 targetlength,
102 strategy,
87 103 ):
88 104 if minmatch == zstd.MINMATCH_MIN and strategy in (
89 105 zstd.STRATEGY_FAST,
90 106 zstd.STRATEGY_GREEDY,
91 107 ):
92 108 minmatch += 1
93 109 elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST:
94 110 minmatch -= 1
95 111
96 112 p = zstd.ZstdCompressionParameters(
97 113 window_log=windowlog,
98 114 chain_log=chainlog,
99 115 hash_log=hashlog,
100 116 search_log=searchlog,
101 117 min_match=minmatch,
102 118 target_length=targetlength,
103 119 strategy=strategy,
104 120 )
105 121 size = p.estimated_compression_context_size()
@@ -1,1670 +1,1714 b''
1 1 import io
2 2 import os
3 3 import random
4 4 import struct
5 5 import sys
6 6 import tempfile
7 7 import unittest
8 8
9 9 import zstandard as zstd
10 10
11 11 from .common import (
12 12 generate_samples,
13 13 make_cffi,
14 14 NonClosingBytesIO,
15 15 OpCountingBytesIO,
16 16 TestCase,
17 17 )
18 18
19 19
20 20 if sys.version_info[0] >= 3:
21 21 next = lambda it: it.__next__()
22 22 else:
23 23 next = lambda it: it.next()
24 24
25 25
26 26 @make_cffi
27 27 class TestFrameHeaderSize(TestCase):
28 28 def test_empty(self):
29 29 with self.assertRaisesRegex(
30 30 zstd.ZstdError,
31 31 "could not determine frame header size: Src size " "is incorrect",
32 32 ):
33 33 zstd.frame_header_size(b"")
34 34
35 35 def test_too_small(self):
36 36 with self.assertRaisesRegex(
37 37 zstd.ZstdError,
38 38 "could not determine frame header size: Src size " "is incorrect",
39 39 ):
40 40 zstd.frame_header_size(b"foob")
41 41
42 42 def test_basic(self):
43 43 # It doesn't matter that it isn't a valid frame.
44 44 self.assertEqual(zstd.frame_header_size(b"long enough but no magic"), 6)
45 45
46 46
47 47 @make_cffi
48 48 class TestFrameContentSize(TestCase):
49 49 def test_empty(self):
50 50 with self.assertRaisesRegex(
51 51 zstd.ZstdError, "error when determining content size"
52 52 ):
53 53 zstd.frame_content_size(b"")
54 54
55 55 def test_too_small(self):
56 56 with self.assertRaisesRegex(
57 57 zstd.ZstdError, "error when determining content size"
58 58 ):
59 59 zstd.frame_content_size(b"foob")
60 60
61 61 def test_bad_frame(self):
62 62 with self.assertRaisesRegex(
63 63 zstd.ZstdError, "error when determining content size"
64 64 ):
65 65 zstd.frame_content_size(b"invalid frame header")
66 66
67 67 def test_unknown(self):
68 68 cctx = zstd.ZstdCompressor(write_content_size=False)
69 69 frame = cctx.compress(b"foobar")
70 70
71 71 self.assertEqual(zstd.frame_content_size(frame), -1)
72 72
73 73 def test_empty(self):
74 74 cctx = zstd.ZstdCompressor()
75 75 frame = cctx.compress(b"")
76 76
77 77 self.assertEqual(zstd.frame_content_size(frame), 0)
78 78
79 79 def test_basic(self):
80 80 cctx = zstd.ZstdCompressor()
81 81 frame = cctx.compress(b"foobar")
82 82
83 83 self.assertEqual(zstd.frame_content_size(frame), 6)
84 84
85 85
86 86 @make_cffi
87 87 class TestDecompressor(TestCase):
88 88 def test_memory_size(self):
89 89 dctx = zstd.ZstdDecompressor()
90 90
91 91 self.assertGreater(dctx.memory_size(), 100)
92 92
93 93
94 94 @make_cffi
95 95 class TestDecompressor_decompress(TestCase):
96 96 def test_empty_input(self):
97 97 dctx = zstd.ZstdDecompressor()
98 98
99 99 with self.assertRaisesRegex(
100 100 zstd.ZstdError, "error determining content size from frame header"
101 101 ):
102 102 dctx.decompress(b"")
103 103
104 104 def test_invalid_input(self):
105 105 dctx = zstd.ZstdDecompressor()
106 106
107 107 with self.assertRaisesRegex(
108 108 zstd.ZstdError, "error determining content size from frame header"
109 109 ):
110 110 dctx.decompress(b"foobar")
111 111
112 112 def test_input_types(self):
113 113 cctx = zstd.ZstdCompressor(level=1)
114 114 compressed = cctx.compress(b"foo")
115 115
116 116 mutable_array = bytearray(len(compressed))
117 117 mutable_array[:] = compressed
118 118
119 119 sources = [
120 120 memoryview(compressed),
121 121 bytearray(compressed),
122 122 mutable_array,
123 123 ]
124 124
125 125 dctx = zstd.ZstdDecompressor()
126 126 for source in sources:
127 127 self.assertEqual(dctx.decompress(source), b"foo")
128 128
129 129 def test_no_content_size_in_frame(self):
130 130 cctx = zstd.ZstdCompressor(write_content_size=False)
131 131 compressed = cctx.compress(b"foobar")
132 132
133 133 dctx = zstd.ZstdDecompressor()
134 134 with self.assertRaisesRegex(
135 135 zstd.ZstdError, "could not determine content size in frame header"
136 136 ):
137 137 dctx.decompress(compressed)
138 138
139 139 def test_content_size_present(self):
140 140 cctx = zstd.ZstdCompressor()
141 141 compressed = cctx.compress(b"foobar")
142 142
143 143 dctx = zstd.ZstdDecompressor()
144 144 decompressed = dctx.decompress(compressed)
145 145 self.assertEqual(decompressed, b"foobar")
146 146
147 147 def test_empty_roundtrip(self):
148 148 cctx = zstd.ZstdCompressor()
149 149 compressed = cctx.compress(b"")
150 150
151 151 dctx = zstd.ZstdDecompressor()
152 152 decompressed = dctx.decompress(compressed)
153 153
154 154 self.assertEqual(decompressed, b"")
155 155
156 156 def test_max_output_size(self):
157 157 cctx = zstd.ZstdCompressor(write_content_size=False)
158 158 source = b"foobar" * 256
159 159 compressed = cctx.compress(source)
160 160
161 161 dctx = zstd.ZstdDecompressor()
162 162 # Will fit into buffer exactly the size of input.
163 163 decompressed = dctx.decompress(compressed, max_output_size=len(source))
164 164 self.assertEqual(decompressed, source)
165 165
166 166 # Input size - 1 fails
167 167 with self.assertRaisesRegex(
168 168 zstd.ZstdError, "decompression error: did not decompress full frame"
169 169 ):
170 170 dctx.decompress(compressed, max_output_size=len(source) - 1)
171 171
172 172 # Input size + 1 works
173 decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
173 decompressed = dctx.decompress(
174 compressed, max_output_size=len(source) + 1
175 )
174 176 self.assertEqual(decompressed, source)
175 177
176 178 # A much larger buffer works.
177 decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
179 decompressed = dctx.decompress(
180 compressed, max_output_size=len(source) * 64
181 )
178 182 self.assertEqual(decompressed, source)
179 183
180 184 def test_stupidly_large_output_buffer(self):
181 185 cctx = zstd.ZstdCompressor(write_content_size=False)
182 186 compressed = cctx.compress(b"foobar" * 256)
183 187 dctx = zstd.ZstdDecompressor()
184 188
185 189 # Will get OverflowError on some Python distributions that can't
186 190 # handle really large integers.
187 191 with self.assertRaises((MemoryError, OverflowError)):
188 192 dctx.decompress(compressed, max_output_size=2 ** 62)
189 193
190 194 def test_dictionary(self):
191 195 samples = []
192 196 for i in range(128):
193 197 samples.append(b"foo" * 64)
194 198 samples.append(b"bar" * 64)
195 199 samples.append(b"foobar" * 64)
196 200
197 201 d = zstd.train_dictionary(8192, samples)
198 202
199 203 orig = b"foobar" * 16384
200 204 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
201 205 compressed = cctx.compress(orig)
202 206
203 207 dctx = zstd.ZstdDecompressor(dict_data=d)
204 208 decompressed = dctx.decompress(compressed)
205 209
206 210 self.assertEqual(decompressed, orig)
207 211
208 212 def test_dictionary_multiple(self):
209 213 samples = []
210 214 for i in range(128):
211 215 samples.append(b"foo" * 64)
212 216 samples.append(b"bar" * 64)
213 217 samples.append(b"foobar" * 64)
214 218
215 219 d = zstd.train_dictionary(8192, samples)
216 220
217 221 sources = (b"foobar" * 8192, b"foo" * 8192, b"bar" * 8192)
218 222 compressed = []
219 223 cctx = zstd.ZstdCompressor(level=1, dict_data=d)
220 224 for source in sources:
221 225 compressed.append(cctx.compress(source))
222 226
223 227 dctx = zstd.ZstdDecompressor(dict_data=d)
224 228 for i in range(len(sources)):
225 229 decompressed = dctx.decompress(compressed[i])
226 230 self.assertEqual(decompressed, sources[i])
227 231
228 232 def test_max_window_size(self):
229 233 with open(__file__, "rb") as fh:
230 234 source = fh.read()
231 235
232 236 # If we write a content size, the decompressor engages single pass
233 237 # mode and the window size doesn't come into play.
234 238 cctx = zstd.ZstdCompressor(write_content_size=False)
235 239 frame = cctx.compress(source)
236 240
237 241 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
238 242
239 243 with self.assertRaisesRegex(
240 zstd.ZstdError, "decompression error: Frame requires too much memory"
244 zstd.ZstdError,
245 "decompression error: Frame requires too much memory",
241 246 ):
242 247 dctx.decompress(frame, max_output_size=len(source))
243 248
244 249
245 250 @make_cffi
246 251 class TestDecompressor_copy_stream(TestCase):
247 252 def test_no_read(self):
248 253 source = object()
249 254 dest = io.BytesIO()
250 255
251 256 dctx = zstd.ZstdDecompressor()
252 257 with self.assertRaises(ValueError):
253 258 dctx.copy_stream(source, dest)
254 259
255 260 def test_no_write(self):
256 261 source = io.BytesIO()
257 262 dest = object()
258 263
259 264 dctx = zstd.ZstdDecompressor()
260 265 with self.assertRaises(ValueError):
261 266 dctx.copy_stream(source, dest)
262 267
263 268 def test_empty(self):
264 269 source = io.BytesIO()
265 270 dest = io.BytesIO()
266 271
267 272 dctx = zstd.ZstdDecompressor()
268 273 # TODO should this raise an error?
269 274 r, w = dctx.copy_stream(source, dest)
270 275
271 276 self.assertEqual(r, 0)
272 277 self.assertEqual(w, 0)
273 278 self.assertEqual(dest.getvalue(), b"")
274 279
275 280 def test_large_data(self):
276 281 source = io.BytesIO()
277 282 for i in range(255):
278 283 source.write(struct.Struct(">B").pack(i) * 16384)
279 284 source.seek(0)
280 285
281 286 compressed = io.BytesIO()
282 287 cctx = zstd.ZstdCompressor()
283 288 cctx.copy_stream(source, compressed)
284 289
285 290 compressed.seek(0)
286 291 dest = io.BytesIO()
287 292 dctx = zstd.ZstdDecompressor()
288 293 r, w = dctx.copy_stream(compressed, dest)
289 294
290 295 self.assertEqual(r, len(compressed.getvalue()))
291 296 self.assertEqual(w, len(source.getvalue()))
292 297
293 298 def test_read_write_size(self):
294 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
299 source = OpCountingBytesIO(
300 zstd.ZstdCompressor().compress(b"foobarfoobar")
301 )
295 302
296 303 dest = OpCountingBytesIO()
297 304 dctx = zstd.ZstdDecompressor()
298 305 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
299 306
300 307 self.assertEqual(r, len(source.getvalue()))
301 308 self.assertEqual(w, len(b"foobarfoobar"))
302 309 self.assertEqual(source._read_count, len(source.getvalue()) + 1)
303 310 self.assertEqual(dest._write_count, len(dest.getvalue()))
304 311
305 312
306 313 @make_cffi
307 314 class TestDecompressor_stream_reader(TestCase):
308 315 def test_context_manager(self):
309 316 dctx = zstd.ZstdDecompressor()
310 317
311 318 with dctx.stream_reader(b"foo") as reader:
312 with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
319 with self.assertRaisesRegex(
320 ValueError, "cannot __enter__ multiple times"
321 ):
313 322 with reader as reader2:
314 323 pass
315 324
316 325 def test_not_implemented(self):
317 326 dctx = zstd.ZstdDecompressor()
318 327
319 328 with dctx.stream_reader(b"foo") as reader:
320 329 with self.assertRaises(io.UnsupportedOperation):
321 330 reader.readline()
322 331
323 332 with self.assertRaises(io.UnsupportedOperation):
324 333 reader.readlines()
325 334
326 335 with self.assertRaises(io.UnsupportedOperation):
327 336 iter(reader)
328 337
329 338 with self.assertRaises(io.UnsupportedOperation):
330 339 next(reader)
331 340
332 341 with self.assertRaises(io.UnsupportedOperation):
333 342 reader.write(b"foo")
334 343
335 344 with self.assertRaises(io.UnsupportedOperation):
336 345 reader.writelines([])
337 346
338 347 def test_constant_methods(self):
339 348 dctx = zstd.ZstdDecompressor()
340 349
341 350 with dctx.stream_reader(b"foo") as reader:
342 351 self.assertFalse(reader.closed)
343 352 self.assertTrue(reader.readable())
344 353 self.assertFalse(reader.writable())
345 354 self.assertTrue(reader.seekable())
346 355 self.assertFalse(reader.isatty())
347 356 self.assertFalse(reader.closed)
348 357 self.assertIsNone(reader.flush())
349 358 self.assertFalse(reader.closed)
350 359
351 360 self.assertTrue(reader.closed)
352 361
353 362 def test_read_closed(self):
354 363 dctx = zstd.ZstdDecompressor()
355 364
356 365 with dctx.stream_reader(b"foo") as reader:
357 366 reader.close()
358 367 self.assertTrue(reader.closed)
359 368 with self.assertRaisesRegex(ValueError, "stream is closed"):
360 369 reader.read(1)
361 370
362 371 def test_read_sizes(self):
363 372 cctx = zstd.ZstdCompressor()
364 373 foo = cctx.compress(b"foo")
365 374
366 375 dctx = zstd.ZstdDecompressor()
367 376
368 377 with dctx.stream_reader(foo) as reader:
369 378 with self.assertRaisesRegex(
370 379 ValueError, "cannot read negative amounts less than -1"
371 380 ):
372 381 reader.read(-2)
373 382
374 383 self.assertEqual(reader.read(0), b"")
375 384 self.assertEqual(reader.read(), b"foo")
376 385
377 386 def test_read_buffer(self):
378 387 cctx = zstd.ZstdCompressor()
379 388
380 389 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
381 390 frame = cctx.compress(source)
382 391
383 392 dctx = zstd.ZstdDecompressor()
384 393
385 394 with dctx.stream_reader(frame) as reader:
386 395 self.assertEqual(reader.tell(), 0)
387 396
388 397 # We should get entire frame in one read.
389 398 result = reader.read(8192)
390 399 self.assertEqual(result, source)
391 400 self.assertEqual(reader.tell(), len(source))
392 401
393 402 # Read after EOF should return empty bytes.
394 403 self.assertEqual(reader.read(1), b"")
395 404 self.assertEqual(reader.tell(), len(result))
396 405
397 406 self.assertTrue(reader.closed)
398 407
399 408 def test_read_buffer_small_chunks(self):
400 409 cctx = zstd.ZstdCompressor()
401 410 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
402 411 frame = cctx.compress(source)
403 412
404 413 dctx = zstd.ZstdDecompressor()
405 414 chunks = []
406 415
407 416 with dctx.stream_reader(frame, read_size=1) as reader:
408 417 while True:
409 418 chunk = reader.read(1)
410 419 if not chunk:
411 420 break
412 421
413 422 chunks.append(chunk)
414 423 self.assertEqual(reader.tell(), sum(map(len, chunks)))
415 424
416 425 self.assertEqual(b"".join(chunks), source)
417 426
418 427 def test_read_stream(self):
419 428 cctx = zstd.ZstdCompressor()
420 429 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
421 430 frame = cctx.compress(source)
422 431
423 432 dctx = zstd.ZstdDecompressor()
424 433 with dctx.stream_reader(io.BytesIO(frame)) as reader:
425 434 self.assertEqual(reader.tell(), 0)
426 435
427 436 chunk = reader.read(8192)
428 437 self.assertEqual(chunk, source)
429 438 self.assertEqual(reader.tell(), len(source))
430 439 self.assertEqual(reader.read(1), b"")
431 440 self.assertEqual(reader.tell(), len(source))
432 441 self.assertFalse(reader.closed)
433 442
434 443 self.assertTrue(reader.closed)
435 444
436 445 def test_read_stream_small_chunks(self):
437 446 cctx = zstd.ZstdCompressor()
438 447 source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
439 448 frame = cctx.compress(source)
440 449
441 450 dctx = zstd.ZstdDecompressor()
442 451 chunks = []
443 452
444 453 with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader:
445 454 while True:
446 455 chunk = reader.read(1)
447 456 if not chunk:
448 457 break
449 458
450 459 chunks.append(chunk)
451 460 self.assertEqual(reader.tell(), sum(map(len, chunks)))
452 461
453 462 self.assertEqual(b"".join(chunks), source)
454 463
455 464 def test_read_after_exit(self):
456 465 cctx = zstd.ZstdCompressor()
457 466 frame = cctx.compress(b"foo" * 60)
458 467
459 468 dctx = zstd.ZstdDecompressor()
460 469
461 470 with dctx.stream_reader(frame) as reader:
462 471 while reader.read(16):
463 472 pass
464 473
465 474 self.assertTrue(reader.closed)
466 475
467 476 with self.assertRaisesRegex(ValueError, "stream is closed"):
468 477 reader.read(10)
469 478
470 479 def test_illegal_seeks(self):
471 480 cctx = zstd.ZstdCompressor()
472 481 frame = cctx.compress(b"foo" * 60)
473 482
474 483 dctx = zstd.ZstdDecompressor()
475 484
476 485 with dctx.stream_reader(frame) as reader:
477 with self.assertRaisesRegex(ValueError, "cannot seek to negative position"):
486 with self.assertRaisesRegex(
487 ValueError, "cannot seek to negative position"
488 ):
478 489 reader.seek(-1, os.SEEK_SET)
479 490
480 491 reader.read(1)
481 492
482 493 with self.assertRaisesRegex(
483 494 ValueError, "cannot seek zstd decompression stream backwards"
484 495 ):
485 496 reader.seek(0, os.SEEK_SET)
486 497
487 498 with self.assertRaisesRegex(
488 499 ValueError, "cannot seek zstd decompression stream backwards"
489 500 ):
490 501 reader.seek(-1, os.SEEK_CUR)
491 502
492 503 with self.assertRaisesRegex(
493 ValueError, "zstd decompression streams cannot be seeked with SEEK_END"
504 ValueError,
505 "zstd decompression streams cannot be seeked with SEEK_END",
494 506 ):
495 507 reader.seek(0, os.SEEK_END)
496 508
497 509 reader.close()
498 510
499 511 with self.assertRaisesRegex(ValueError, "stream is closed"):
500 512 reader.seek(4, os.SEEK_SET)
501 513
502 514 with self.assertRaisesRegex(ValueError, "stream is closed"):
503 515 reader.seek(0)
504 516
505 517 def test_seek(self):
506 518 source = b"foobar" * 60
507 519 cctx = zstd.ZstdCompressor()
508 520 frame = cctx.compress(source)
509 521
510 522 dctx = zstd.ZstdDecompressor()
511 523
512 524 with dctx.stream_reader(frame) as reader:
513 525 reader.seek(3)
514 526 self.assertEqual(reader.read(3), b"bar")
515 527
516 528 reader.seek(4, os.SEEK_CUR)
517 529 self.assertEqual(reader.read(2), b"ar")
518 530
519 531 def test_no_context_manager(self):
520 532 source = b"foobar" * 60
521 533 cctx = zstd.ZstdCompressor()
522 534 frame = cctx.compress(source)
523 535
524 536 dctx = zstd.ZstdDecompressor()
525 537 reader = dctx.stream_reader(frame)
526 538
527 539 self.assertEqual(reader.read(6), b"foobar")
528 540 self.assertEqual(reader.read(18), b"foobar" * 3)
529 541 self.assertFalse(reader.closed)
530 542
531 543 # Calling close prevents subsequent use.
532 544 reader.close()
533 545 self.assertTrue(reader.closed)
534 546
535 547 with self.assertRaisesRegex(ValueError, "stream is closed"):
536 548 reader.read(6)
537 549
538 550 def test_read_after_error(self):
539 551 source = io.BytesIO(b"")
540 552 dctx = zstd.ZstdDecompressor()
541 553
542 554 reader = dctx.stream_reader(source)
543 555
544 556 with reader:
545 557 reader.read(0)
546 558
547 559 with reader:
548 560 with self.assertRaisesRegex(ValueError, "stream is closed"):
549 561 reader.read(100)
550 562
551 563 def test_partial_read(self):
552 564 # Inspired by https://github.com/indygreg/python-zstandard/issues/71.
553 565 buffer = io.BytesIO()
554 566 cctx = zstd.ZstdCompressor()
555 567 writer = cctx.stream_writer(buffer)
556 568 writer.write(bytearray(os.urandom(1000000)))
557 569 writer.flush(zstd.FLUSH_FRAME)
558 570 buffer.seek(0)
559 571
560 572 dctx = zstd.ZstdDecompressor()
561 573 reader = dctx.stream_reader(buffer)
562 574
563 575 while True:
564 576 chunk = reader.read(8192)
565 577 if not chunk:
566 578 break
567 579
568 580 def test_read_multiple_frames(self):
569 581 cctx = zstd.ZstdCompressor()
570 582 source = io.BytesIO()
571 583 writer = cctx.stream_writer(source)
572 584 writer.write(b"foo")
573 585 writer.flush(zstd.FLUSH_FRAME)
574 586 writer.write(b"bar")
575 587 writer.flush(zstd.FLUSH_FRAME)
576 588
577 589 dctx = zstd.ZstdDecompressor()
578 590
579 591 reader = dctx.stream_reader(source.getvalue())
580 592 self.assertEqual(reader.read(2), b"fo")
581 593 self.assertEqual(reader.read(2), b"o")
582 594 self.assertEqual(reader.read(2), b"ba")
583 595 self.assertEqual(reader.read(2), b"r")
584 596
585 597 source.seek(0)
586 598 reader = dctx.stream_reader(source)
587 599 self.assertEqual(reader.read(2), b"fo")
588 600 self.assertEqual(reader.read(2), b"o")
589 601 self.assertEqual(reader.read(2), b"ba")
590 602 self.assertEqual(reader.read(2), b"r")
591 603
592 604 reader = dctx.stream_reader(source.getvalue())
593 605 self.assertEqual(reader.read(3), b"foo")
594 606 self.assertEqual(reader.read(3), b"bar")
595 607
596 608 source.seek(0)
597 609 reader = dctx.stream_reader(source)
598 610 self.assertEqual(reader.read(3), b"foo")
599 611 self.assertEqual(reader.read(3), b"bar")
600 612
601 613 reader = dctx.stream_reader(source.getvalue())
602 614 self.assertEqual(reader.read(4), b"foo")
603 615 self.assertEqual(reader.read(4), b"bar")
604 616
605 617 source.seek(0)
606 618 reader = dctx.stream_reader(source)
607 619 self.assertEqual(reader.read(4), b"foo")
608 620 self.assertEqual(reader.read(4), b"bar")
609 621
610 622 reader = dctx.stream_reader(source.getvalue())
611 623 self.assertEqual(reader.read(128), b"foo")
612 624 self.assertEqual(reader.read(128), b"bar")
613 625
614 626 source.seek(0)
615 627 reader = dctx.stream_reader(source)
616 628 self.assertEqual(reader.read(128), b"foo")
617 629 self.assertEqual(reader.read(128), b"bar")
618 630
619 631 # Now tests for reads spanning frames.
620 632 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
621 633 self.assertEqual(reader.read(3), b"foo")
622 634 self.assertEqual(reader.read(3), b"bar")
623 635
624 636 source.seek(0)
625 637 reader = dctx.stream_reader(source, read_across_frames=True)
626 638 self.assertEqual(reader.read(3), b"foo")
627 639 self.assertEqual(reader.read(3), b"bar")
628 640
629 641 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
630 642 self.assertEqual(reader.read(6), b"foobar")
631 643
632 644 source.seek(0)
633 645 reader = dctx.stream_reader(source, read_across_frames=True)
634 646 self.assertEqual(reader.read(6), b"foobar")
635 647
636 648 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
637 649 self.assertEqual(reader.read(7), b"foobar")
638 650
639 651 source.seek(0)
640 652 reader = dctx.stream_reader(source, read_across_frames=True)
641 653 self.assertEqual(reader.read(7), b"foobar")
642 654
643 655 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
644 656 self.assertEqual(reader.read(128), b"foobar")
645 657
646 658 source.seek(0)
647 659 reader = dctx.stream_reader(source, read_across_frames=True)
648 660 self.assertEqual(reader.read(128), b"foobar")
649 661
650 662 def test_readinto(self):
651 663 cctx = zstd.ZstdCompressor()
652 664 foo = cctx.compress(b"foo")
653 665
654 666 dctx = zstd.ZstdDecompressor()
655 667
656 668 # Attempting to readinto() a non-writable buffer fails.
657 669 # The exact exception varies based on the backend.
658 670 reader = dctx.stream_reader(foo)
659 671 with self.assertRaises(Exception):
660 672 reader.readinto(b"foobar")
661 673
662 674 # readinto() with sufficiently large destination.
663 675 b = bytearray(1024)
664 676 reader = dctx.stream_reader(foo)
665 677 self.assertEqual(reader.readinto(b), 3)
666 678 self.assertEqual(b[0:3], b"foo")
667 679 self.assertEqual(reader.readinto(b), 0)
668 680 self.assertEqual(b[0:3], b"foo")
669 681
670 682 # readinto() with small reads.
671 683 b = bytearray(1024)
672 684 reader = dctx.stream_reader(foo, read_size=1)
673 685 self.assertEqual(reader.readinto(b), 3)
674 686 self.assertEqual(b[0:3], b"foo")
675 687
676 688 # Too small destination buffer.
677 689 b = bytearray(2)
678 690 reader = dctx.stream_reader(foo)
679 691 self.assertEqual(reader.readinto(b), 2)
680 692 self.assertEqual(b[:], b"fo")
681 693
682 694 def test_readinto1(self):
683 695 cctx = zstd.ZstdCompressor()
684 696 foo = cctx.compress(b"foo")
685 697
686 698 dctx = zstd.ZstdDecompressor()
687 699
688 700 reader = dctx.stream_reader(foo)
689 701 with self.assertRaises(Exception):
690 702 reader.readinto1(b"foobar")
691 703
692 704 # Sufficiently large destination.
693 705 b = bytearray(1024)
694 706 reader = dctx.stream_reader(foo)
695 707 self.assertEqual(reader.readinto1(b), 3)
696 708 self.assertEqual(b[0:3], b"foo")
697 709 self.assertEqual(reader.readinto1(b), 0)
698 710 self.assertEqual(b[0:3], b"foo")
699 711
700 712 # readinto() with small reads.
701 713 b = bytearray(1024)
702 714 reader = dctx.stream_reader(foo, read_size=1)
703 715 self.assertEqual(reader.readinto1(b), 3)
704 716 self.assertEqual(b[0:3], b"foo")
705 717
706 718 # Too small destination buffer.
707 719 b = bytearray(2)
708 720 reader = dctx.stream_reader(foo)
709 721 self.assertEqual(reader.readinto1(b), 2)
710 722 self.assertEqual(b[:], b"fo")
711 723
712 724 def test_readall(self):
713 725 cctx = zstd.ZstdCompressor()
714 726 foo = cctx.compress(b"foo")
715 727
716 728 dctx = zstd.ZstdDecompressor()
717 729 reader = dctx.stream_reader(foo)
718 730
719 731 self.assertEqual(reader.readall(), b"foo")
720 732
721 733 def test_read1(self):
722 734 cctx = zstd.ZstdCompressor()
723 735 foo = cctx.compress(b"foo")
724 736
725 737 dctx = zstd.ZstdDecompressor()
726 738
727 739 b = OpCountingBytesIO(foo)
728 740 reader = dctx.stream_reader(b)
729 741
730 742 self.assertEqual(reader.read1(), b"foo")
731 743 self.assertEqual(b._read_count, 1)
732 744
733 745 b = OpCountingBytesIO(foo)
734 746 reader = dctx.stream_reader(b)
735 747
736 748 self.assertEqual(reader.read1(0), b"")
737 749 self.assertEqual(reader.read1(2), b"fo")
738 750 self.assertEqual(b._read_count, 1)
739 751 self.assertEqual(reader.read1(1), b"o")
740 752 self.assertEqual(b._read_count, 1)
741 753 self.assertEqual(reader.read1(1), b"")
742 754 self.assertEqual(b._read_count, 2)
743 755
744 756 def test_read_lines(self):
745 757 cctx = zstd.ZstdCompressor()
746 source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024))
758 source = b"\n".join(
759 ("line %d" % i).encode("ascii") for i in range(1024)
760 )
747 761
748 762 frame = cctx.compress(source)
749 763
750 764 dctx = zstd.ZstdDecompressor()
751 765 reader = dctx.stream_reader(frame)
752 766 tr = io.TextIOWrapper(reader, encoding="utf-8")
753 767
754 768 lines = []
755 769 for line in tr:
756 770 lines.append(line.encode("utf-8"))
757 771
758 772 self.assertEqual(len(lines), 1024)
759 773 self.assertEqual(b"".join(lines), source)
760 774
761 775 reader = dctx.stream_reader(frame)
762 776 tr = io.TextIOWrapper(reader, encoding="utf-8")
763 777
764 778 lines = tr.readlines()
765 779 self.assertEqual(len(lines), 1024)
766 780 self.assertEqual("".join(lines).encode("utf-8"), source)
767 781
768 782 reader = dctx.stream_reader(frame)
769 783 tr = io.TextIOWrapper(reader, encoding="utf-8")
770 784
771 785 lines = []
772 786 while True:
773 787 line = tr.readline()
774 788 if not line:
775 789 break
776 790
777 791 lines.append(line.encode("utf-8"))
778 792
779 793 self.assertEqual(len(lines), 1024)
780 794 self.assertEqual(b"".join(lines), source)
781 795
782 796
783 797 @make_cffi
784 798 class TestDecompressor_decompressobj(TestCase):
785 799 def test_simple(self):
786 800 data = zstd.ZstdCompressor(level=1).compress(b"foobar")
787 801
788 802 dctx = zstd.ZstdDecompressor()
789 803 dobj = dctx.decompressobj()
790 804 self.assertEqual(dobj.decompress(data), b"foobar")
791 805 self.assertIsNone(dobj.flush())
792 806 self.assertIsNone(dobj.flush(10))
793 807 self.assertIsNone(dobj.flush(length=100))
794 808
795 809 def test_input_types(self):
796 810 compressed = zstd.ZstdCompressor(level=1).compress(b"foo")
797 811
798 812 dctx = zstd.ZstdDecompressor()
799 813
800 814 mutable_array = bytearray(len(compressed))
801 815 mutable_array[:] = compressed
802 816
803 817 sources = [
804 818 memoryview(compressed),
805 819 bytearray(compressed),
806 820 mutable_array,
807 821 ]
808 822
809 823 for source in sources:
810 824 dobj = dctx.decompressobj()
811 825 self.assertIsNone(dobj.flush())
812 826 self.assertIsNone(dobj.flush(10))
813 827 self.assertIsNone(dobj.flush(length=100))
814 828 self.assertEqual(dobj.decompress(source), b"foo")
815 829 self.assertIsNone(dobj.flush())
816 830
817 831 def test_reuse(self):
818 832 data = zstd.ZstdCompressor(level=1).compress(b"foobar")
819 833
820 834 dctx = zstd.ZstdDecompressor()
821 835 dobj = dctx.decompressobj()
822 836 dobj.decompress(data)
823 837
824 with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"):
838 with self.assertRaisesRegex(
839 zstd.ZstdError, "cannot use a decompressobj"
840 ):
825 841 dobj.decompress(data)
826 842 self.assertIsNone(dobj.flush())
827 843
828 844 def test_bad_write_size(self):
829 845 dctx = zstd.ZstdDecompressor()
830 846
831 847 with self.assertRaisesRegex(ValueError, "write_size must be positive"):
832 848 dctx.decompressobj(write_size=0)
833 849
834 850 def test_write_size(self):
835 851 source = b"foo" * 64 + b"bar" * 128
836 852 data = zstd.ZstdCompressor(level=1).compress(source)
837 853
838 854 dctx = zstd.ZstdDecompressor()
839 855
840 856 for i in range(128):
841 857 dobj = dctx.decompressobj(write_size=i + 1)
842 858 self.assertEqual(dobj.decompress(data), source)
843 859
844 860
845 861 def decompress_via_writer(data):
846 862 buffer = io.BytesIO()
847 863 dctx = zstd.ZstdDecompressor()
848 864 decompressor = dctx.stream_writer(buffer)
849 865 decompressor.write(data)
850 866
851 867 return buffer.getvalue()
852 868
853 869
854 870 @make_cffi
855 871 class TestDecompressor_stream_writer(TestCase):
856 872 def test_io_api(self):
857 873 buffer = io.BytesIO()
858 874 dctx = zstd.ZstdDecompressor()
859 875 writer = dctx.stream_writer(buffer)
860 876
861 877 self.assertFalse(writer.closed)
862 878 self.assertFalse(writer.isatty())
863 879 self.assertFalse(writer.readable())
864 880
865 881 with self.assertRaises(io.UnsupportedOperation):
866 882 writer.readline()
867 883
868 884 with self.assertRaises(io.UnsupportedOperation):
869 885 writer.readline(42)
870 886
871 887 with self.assertRaises(io.UnsupportedOperation):
872 888 writer.readline(size=42)
873 889
874 890 with self.assertRaises(io.UnsupportedOperation):
875 891 writer.readlines()
876 892
877 893 with self.assertRaises(io.UnsupportedOperation):
878 894 writer.readlines(42)
879 895
880 896 with self.assertRaises(io.UnsupportedOperation):
881 897 writer.readlines(hint=42)
882 898
883 899 with self.assertRaises(io.UnsupportedOperation):
884 900 writer.seek(0)
885 901
886 902 with self.assertRaises(io.UnsupportedOperation):
887 903 writer.seek(10, os.SEEK_SET)
888 904
889 905 self.assertFalse(writer.seekable())
890 906
891 907 with self.assertRaises(io.UnsupportedOperation):
892 908 writer.tell()
893 909
894 910 with self.assertRaises(io.UnsupportedOperation):
895 911 writer.truncate()
896 912
897 913 with self.assertRaises(io.UnsupportedOperation):
898 914 writer.truncate(42)
899 915
900 916 with self.assertRaises(io.UnsupportedOperation):
901 917 writer.truncate(size=42)
902 918
903 919 self.assertTrue(writer.writable())
904 920
905 921 with self.assertRaises(io.UnsupportedOperation):
906 922 writer.writelines([])
907 923
908 924 with self.assertRaises(io.UnsupportedOperation):
909 925 writer.read()
910 926
911 927 with self.assertRaises(io.UnsupportedOperation):
912 928 writer.read(42)
913 929
914 930 with self.assertRaises(io.UnsupportedOperation):
915 931 writer.read(size=42)
916 932
917 933 with self.assertRaises(io.UnsupportedOperation):
918 934 writer.readall()
919 935
920 936 with self.assertRaises(io.UnsupportedOperation):
921 937 writer.readinto(None)
922 938
923 939 with self.assertRaises(io.UnsupportedOperation):
924 940 writer.fileno()
925 941
926 942 def test_fileno_file(self):
927 943 with tempfile.TemporaryFile("wb") as tf:
928 944 dctx = zstd.ZstdDecompressor()
929 945 writer = dctx.stream_writer(tf)
930 946
931 947 self.assertEqual(writer.fileno(), tf.fileno())
932 948
933 949 def test_close(self):
934 950 foo = zstd.ZstdCompressor().compress(b"foo")
935 951
936 952 buffer = NonClosingBytesIO()
937 953 dctx = zstd.ZstdDecompressor()
938 954 writer = dctx.stream_writer(buffer)
939 955
940 956 writer.write(foo)
941 957 self.assertFalse(writer.closed)
942 958 self.assertFalse(buffer.closed)
943 959 writer.close()
944 960 self.assertTrue(writer.closed)
945 961 self.assertTrue(buffer.closed)
946 962
947 963 with self.assertRaisesRegex(ValueError, "stream is closed"):
948 964 writer.write(b"")
949 965
950 966 with self.assertRaisesRegex(ValueError, "stream is closed"):
951 967 writer.flush()
952 968
953 969 with self.assertRaisesRegex(ValueError, "stream is closed"):
954 970 with writer:
955 971 pass
956 972
957 973 self.assertEqual(buffer.getvalue(), b"foo")
958 974
959 975 # Context manager exit should close stream.
960 976 buffer = NonClosingBytesIO()
961 977 writer = dctx.stream_writer(buffer)
962 978
963 979 with writer:
964 980 writer.write(foo)
965 981
966 982 self.assertTrue(writer.closed)
967 983 self.assertEqual(buffer.getvalue(), b"foo")
968 984
969 985 def test_flush(self):
970 986 buffer = OpCountingBytesIO()
971 987 dctx = zstd.ZstdDecompressor()
972 988 writer = dctx.stream_writer(buffer)
973 989
974 990 writer.flush()
975 991 self.assertEqual(buffer._flush_count, 1)
976 992 writer.flush()
977 993 self.assertEqual(buffer._flush_count, 2)
978 994
979 995 def test_empty_roundtrip(self):
980 996 cctx = zstd.ZstdCompressor()
981 997 empty = cctx.compress(b"")
982 998 self.assertEqual(decompress_via_writer(empty), b"")
983 999
984 1000 def test_input_types(self):
985 1001 cctx = zstd.ZstdCompressor(level=1)
986 1002 compressed = cctx.compress(b"foo")
987 1003
988 1004 mutable_array = bytearray(len(compressed))
989 1005 mutable_array[:] = compressed
990 1006
991 1007 sources = [
992 1008 memoryview(compressed),
993 1009 bytearray(compressed),
994 1010 mutable_array,
995 1011 ]
996 1012
997 1013 dctx = zstd.ZstdDecompressor()
998 1014 for source in sources:
999 1015 buffer = io.BytesIO()
1000 1016
1001 1017 decompressor = dctx.stream_writer(buffer)
1002 1018 decompressor.write(source)
1003 1019 self.assertEqual(buffer.getvalue(), b"foo")
1004 1020
1005 1021 buffer = NonClosingBytesIO()
1006 1022
1007 1023 with dctx.stream_writer(buffer) as decompressor:
1008 1024 self.assertEqual(decompressor.write(source), 3)
1009 1025
1010 1026 self.assertEqual(buffer.getvalue(), b"foo")
1011 1027
1012 1028 buffer = io.BytesIO()
1013 1029 writer = dctx.stream_writer(buffer, write_return_read=True)
1014 1030 self.assertEqual(writer.write(source), len(source))
1015 1031 self.assertEqual(buffer.getvalue(), b"foo")
1016 1032
1017 1033 def test_large_roundtrip(self):
1018 1034 chunks = []
1019 1035 for i in range(255):
1020 1036 chunks.append(struct.Struct(">B").pack(i) * 16384)
1021 1037 orig = b"".join(chunks)
1022 1038 cctx = zstd.ZstdCompressor()
1023 1039 compressed = cctx.compress(orig)
1024 1040
1025 1041 self.assertEqual(decompress_via_writer(compressed), orig)
1026 1042
1027 1043 def test_multiple_calls(self):
1028 1044 chunks = []
1029 1045 for i in range(255):
1030 1046 for j in range(255):
1031 1047 chunks.append(struct.Struct(">B").pack(j) * i)
1032 1048
1033 1049 orig = b"".join(chunks)
1034 1050 cctx = zstd.ZstdCompressor()
1035 1051 compressed = cctx.compress(orig)
1036 1052
1037 1053 buffer = NonClosingBytesIO()
1038 1054 dctx = zstd.ZstdDecompressor()
1039 1055 with dctx.stream_writer(buffer) as decompressor:
1040 1056 pos = 0
1041 1057 while pos < len(compressed):
1042 1058 pos2 = pos + 8192
1043 1059 decompressor.write(compressed[pos:pos2])
1044 1060 pos += 8192
1045 1061 self.assertEqual(buffer.getvalue(), orig)
1046 1062
1047 1063 # Again with write_return_read=True
1048 1064 buffer = io.BytesIO()
1049 1065 writer = dctx.stream_writer(buffer, write_return_read=True)
1050 1066 pos = 0
1051 1067 while pos < len(compressed):
1052 1068 pos2 = pos + 8192
1053 1069 chunk = compressed[pos:pos2]
1054 1070 self.assertEqual(writer.write(chunk), len(chunk))
1055 1071 pos += 8192
1056 1072 self.assertEqual(buffer.getvalue(), orig)
1057 1073
1058 1074 def test_dictionary(self):
1059 1075 samples = []
1060 1076 for i in range(128):
1061 1077 samples.append(b"foo" * 64)
1062 1078 samples.append(b"bar" * 64)
1063 1079 samples.append(b"foobar" * 64)
1064 1080
1065 1081 d = zstd.train_dictionary(8192, samples)
1066 1082
1067 1083 orig = b"foobar" * 16384
1068 1084 buffer = NonClosingBytesIO()
1069 1085 cctx = zstd.ZstdCompressor(dict_data=d)
1070 1086 with cctx.stream_writer(buffer) as compressor:
1071 1087 self.assertEqual(compressor.write(orig), 0)
1072 1088
1073 1089 compressed = buffer.getvalue()
1074 1090 buffer = io.BytesIO()
1075 1091
1076 1092 dctx = zstd.ZstdDecompressor(dict_data=d)
1077 1093 decompressor = dctx.stream_writer(buffer)
1078 1094 self.assertEqual(decompressor.write(compressed), len(orig))
1079 1095 self.assertEqual(buffer.getvalue(), orig)
1080 1096
1081 1097 buffer = NonClosingBytesIO()
1082 1098
1083 1099 with dctx.stream_writer(buffer) as decompressor:
1084 1100 self.assertEqual(decompressor.write(compressed), len(orig))
1085 1101
1086 1102 self.assertEqual(buffer.getvalue(), orig)
1087 1103
1088 1104 def test_memory_size(self):
1089 1105 dctx = zstd.ZstdDecompressor()
1090 1106 buffer = io.BytesIO()
1091 1107
1092 1108 decompressor = dctx.stream_writer(buffer)
1093 1109 size = decompressor.memory_size()
1094 1110 self.assertGreater(size, 100000)
1095 1111
1096 1112 with dctx.stream_writer(buffer) as decompressor:
1097 1113 size = decompressor.memory_size()
1098 1114
1099 1115 self.assertGreater(size, 100000)
1100 1116
1101 1117 def test_write_size(self):
1102 1118 source = zstd.ZstdCompressor().compress(b"foobarfoobar")
1103 1119 dest = OpCountingBytesIO()
1104 1120 dctx = zstd.ZstdDecompressor()
1105 1121 with dctx.stream_writer(dest, write_size=1) as decompressor:
1106 1122 s = struct.Struct(">B")
1107 1123 for c in source:
1108 1124 if not isinstance(c, str):
1109 1125 c = s.pack(c)
1110 1126 decompressor.write(c)
1111 1127
1112 1128 self.assertEqual(dest.getvalue(), b"foobarfoobar")
1113 1129 self.assertEqual(dest._write_count, len(dest.getvalue()))
1114 1130
1115 1131
1116 1132 @make_cffi
1117 1133 class TestDecompressor_read_to_iter(TestCase):
1118 1134 def test_type_validation(self):
1119 1135 dctx = zstd.ZstdDecompressor()
1120 1136
1121 1137 # Object with read() works.
1122 1138 dctx.read_to_iter(io.BytesIO())
1123 1139
1124 1140 # Buffer protocol works.
1125 1141 dctx.read_to_iter(b"foobar")
1126 1142
1127 with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
1143 with self.assertRaisesRegex(
1144 ValueError, "must pass an object with a read"
1145 ):
1128 1146 b"".join(dctx.read_to_iter(True))
1129 1147
1130 1148 def test_empty_input(self):
1131 1149 dctx = zstd.ZstdDecompressor()
1132 1150
1133 1151 source = io.BytesIO()
1134 1152 it = dctx.read_to_iter(source)
1135 1153 # TODO this is arguably wrong. Should get an error about missing frame foo.
1136 1154 with self.assertRaises(StopIteration):
1137 1155 next(it)
1138 1156
1139 1157 it = dctx.read_to_iter(b"")
1140 1158 with self.assertRaises(StopIteration):
1141 1159 next(it)
1142 1160
1143 1161 def test_invalid_input(self):
1144 1162 dctx = zstd.ZstdDecompressor()
1145 1163
1146 1164 source = io.BytesIO(b"foobar")
1147 1165 it = dctx.read_to_iter(source)
1148 1166 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
1149 1167 next(it)
1150 1168
1151 1169 it = dctx.read_to_iter(b"foobar")
1152 1170 with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"):
1153 1171 next(it)
1154 1172
1155 1173 def test_empty_roundtrip(self):
1156 1174 cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
1157 1175 empty = cctx.compress(b"")
1158 1176
1159 1177 source = io.BytesIO(empty)
1160 1178 source.seek(0)
1161 1179
1162 1180 dctx = zstd.ZstdDecompressor()
1163 1181 it = dctx.read_to_iter(source)
1164 1182
1165 1183 # No chunks should be emitted since there is no data.
1166 1184 with self.assertRaises(StopIteration):
1167 1185 next(it)
1168 1186
1169 1187 # Again for good measure.
1170 1188 with self.assertRaises(StopIteration):
1171 1189 next(it)
1172 1190
1173 1191 def test_skip_bytes_too_large(self):
1174 1192 dctx = zstd.ZstdDecompressor()
1175 1193
1176 1194 with self.assertRaisesRegex(
1177 1195 ValueError, "skip_bytes must be smaller than read_size"
1178 1196 ):
1179 1197 b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1))
1180 1198
1181 1199 with self.assertRaisesRegex(
1182 1200 ValueError, "skip_bytes larger than first input chunk"
1183 1201 ):
1184 1202 b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10))
1185 1203
1186 1204 def test_skip_bytes(self):
1187 1205 cctx = zstd.ZstdCompressor(write_content_size=False)
1188 1206 compressed = cctx.compress(b"foobar")
1189 1207
1190 1208 dctx = zstd.ZstdDecompressor()
1191 1209 output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3))
1192 1210 self.assertEqual(output, b"foobar")
1193 1211
1194 1212 def test_large_output(self):
1195 1213 source = io.BytesIO()
1196 1214 source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
1197 1215 source.write(b"o")
1198 1216 source.seek(0)
1199 1217
1200 1218 cctx = zstd.ZstdCompressor(level=1)
1201 1219 compressed = io.BytesIO(cctx.compress(source.getvalue()))
1202 1220 compressed.seek(0)
1203 1221
1204 1222 dctx = zstd.ZstdDecompressor()
1205 1223 it = dctx.read_to_iter(compressed)
1206 1224
1207 1225 chunks = []
1208 1226 chunks.append(next(it))
1209 1227 chunks.append(next(it))
1210 1228
1211 1229 with self.assertRaises(StopIteration):
1212 1230 next(it)
1213 1231
1214 1232 decompressed = b"".join(chunks)
1215 1233 self.assertEqual(decompressed, source.getvalue())
1216 1234
1217 1235 # And again with buffer protocol.
1218 1236 it = dctx.read_to_iter(compressed.getvalue())
1219 1237 chunks = []
1220 1238 chunks.append(next(it))
1221 1239 chunks.append(next(it))
1222 1240
1223 1241 with self.assertRaises(StopIteration):
1224 1242 next(it)
1225 1243
1226 1244 decompressed = b"".join(chunks)
1227 1245 self.assertEqual(decompressed, source.getvalue())
1228 1246
1229 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
1247 @unittest.skipUnless(
1248 "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set"
1249 )
1230 1250 def test_large_input(self):
1231 1251 bytes = list(struct.Struct(">B").pack(i) for i in range(256))
1232 1252 compressed = NonClosingBytesIO()
1233 1253 input_size = 0
1234 1254 cctx = zstd.ZstdCompressor(level=1)
1235 1255 with cctx.stream_writer(compressed) as compressor:
1236 1256 while True:
1237 1257 compressor.write(random.choice(bytes))
1238 1258 input_size += 1
1239 1259
1240 1260 have_compressed = (
1241 1261 len(compressed.getvalue())
1242 1262 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
1243 1263 )
1244 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
1264 have_raw = (
1265 input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
1266 )
1245 1267 if have_compressed and have_raw:
1246 1268 break
1247 1269
1248 1270 compressed = io.BytesIO(compressed.getvalue())
1249 1271 self.assertGreater(
1250 len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
1272 len(compressed.getvalue()),
1273 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
1251 1274 )
1252 1275
1253 1276 dctx = zstd.ZstdDecompressor()
1254 1277 it = dctx.read_to_iter(compressed)
1255 1278
1256 1279 chunks = []
1257 1280 chunks.append(next(it))
1258 1281 chunks.append(next(it))
1259 1282 chunks.append(next(it))
1260 1283
1261 1284 with self.assertRaises(StopIteration):
1262 1285 next(it)
1263 1286
1264 1287 decompressed = b"".join(chunks)
1265 1288 self.assertEqual(len(decompressed), input_size)
1266 1289
1267 1290 # And again with buffer protocol.
1268 1291 it = dctx.read_to_iter(compressed.getvalue())
1269 1292
1270 1293 chunks = []
1271 1294 chunks.append(next(it))
1272 1295 chunks.append(next(it))
1273 1296 chunks.append(next(it))
1274 1297
1275 1298 with self.assertRaises(StopIteration):
1276 1299 next(it)
1277 1300
1278 1301 decompressed = b"".join(chunks)
1279 1302 self.assertEqual(len(decompressed), input_size)
1280 1303
1281 1304 def test_interesting(self):
1282 1305 # Found this edge case via fuzzing.
1283 1306 cctx = zstd.ZstdCompressor(level=1)
1284 1307
1285 1308 source = io.BytesIO()
1286 1309
1287 1310 compressed = NonClosingBytesIO()
1288 1311 with cctx.stream_writer(compressed) as compressor:
1289 1312 for i in range(256):
1290 1313 chunk = b"\0" * 1024
1291 1314 compressor.write(chunk)
1292 1315 source.write(chunk)
1293 1316
1294 1317 dctx = zstd.ZstdDecompressor()
1295 1318
1296 1319 simple = dctx.decompress(
1297 1320 compressed.getvalue(), max_output_size=len(source.getvalue())
1298 1321 )
1299 1322 self.assertEqual(simple, source.getvalue())
1300 1323
1301 1324 compressed = io.BytesIO(compressed.getvalue())
1302 1325 streamed = b"".join(dctx.read_to_iter(compressed))
1303 1326 self.assertEqual(streamed, source.getvalue())
1304 1327
1305 1328 def test_read_write_size(self):
1306 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
1329 source = OpCountingBytesIO(
1330 zstd.ZstdCompressor().compress(b"foobarfoobar")
1331 )
1307 1332 dctx = zstd.ZstdDecompressor()
1308 1333 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
1309 1334 self.assertEqual(len(chunk), 1)
1310 1335
1311 1336 self.assertEqual(source._read_count, len(source.getvalue()))
1312 1337
1313 1338 def test_magic_less(self):
1314 1339 params = zstd.CompressionParameters.from_level(
1315 1340 1, format=zstd.FORMAT_ZSTD1_MAGICLESS
1316 1341 )
1317 1342 cctx = zstd.ZstdCompressor(compression_params=params)
1318 1343 frame = cctx.compress(b"foobar")
1319 1344
1320 1345 self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd")
1321 1346
1322 1347 dctx = zstd.ZstdDecompressor()
1323 1348 with self.assertRaisesRegex(
1324 1349 zstd.ZstdError, "error determining content size from frame header"
1325 1350 ):
1326 1351 dctx.decompress(frame)
1327 1352
1328 1353 dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS)
1329 1354 res = b"".join(dctx.read_to_iter(frame))
1330 1355 self.assertEqual(res, b"foobar")
1331 1356
1332 1357
1333 1358 @make_cffi
1334 1359 class TestDecompressor_content_dict_chain(TestCase):
1335 1360 def test_bad_inputs_simple(self):
1336 1361 dctx = zstd.ZstdDecompressor()
1337 1362
1338 1363 with self.assertRaises(TypeError):
1339 1364 dctx.decompress_content_dict_chain(b"foo")
1340 1365
1341 1366 with self.assertRaises(TypeError):
1342 1367 dctx.decompress_content_dict_chain((b"foo", b"bar"))
1343 1368
1344 1369 with self.assertRaisesRegex(ValueError, "empty input chain"):
1345 1370 dctx.decompress_content_dict_chain([])
1346 1371
1347 1372 with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"):
1348 1373 dctx.decompress_content_dict_chain([u"foo"])
1349 1374
1350 1375 with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"):
1351 1376 dctx.decompress_content_dict_chain([True])
1352 1377
1353 1378 with self.assertRaisesRegex(
1354 1379 ValueError, "chunk 0 is too small to contain a zstd frame"
1355 1380 ):
1356 1381 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
1357 1382
1358 with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"):
1383 with self.assertRaisesRegex(
1384 ValueError, "chunk 0 is not a valid zstd frame"
1385 ):
1359 1386 dctx.decompress_content_dict_chain([b"foo" * 8])
1360 1387
1361 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
1388 no_size = zstd.ZstdCompressor(write_content_size=False).compress(
1389 b"foo" * 64
1390 )
1362 1391
1363 1392 with self.assertRaisesRegex(
1364 1393 ValueError, "chunk 0 missing content size in frame"
1365 1394 ):
1366 1395 dctx.decompress_content_dict_chain([no_size])
1367 1396
1368 1397 # Corrupt first frame.
1369 1398 frame = zstd.ZstdCompressor().compress(b"foo" * 64)
1370 1399 frame = frame[0:12] + frame[15:]
1371 1400 with self.assertRaisesRegex(
1372 1401 zstd.ZstdError, "chunk 0 did not decompress full frame"
1373 1402 ):
1374 1403 dctx.decompress_content_dict_chain([frame])
1375 1404
1376 1405 def test_bad_subsequent_input(self):
1377 1406 initial = zstd.ZstdCompressor().compress(b"foo" * 64)
1378 1407
1379 1408 dctx = zstd.ZstdDecompressor()
1380 1409
1381 1410 with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"):
1382 1411 dctx.decompress_content_dict_chain([initial, u"foo"])
1383 1412
1384 1413 with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"):
1385 1414 dctx.decompress_content_dict_chain([initial, None])
1386 1415
1387 1416 with self.assertRaisesRegex(
1388 1417 ValueError, "chunk 1 is too small to contain a zstd frame"
1389 1418 ):
1390 1419 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
1391 1420
1392 with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"):
1421 with self.assertRaisesRegex(
1422 ValueError, "chunk 1 is not a valid zstd frame"
1423 ):
1393 1424 dctx.decompress_content_dict_chain([initial, b"foo" * 8])
1394 1425
1395 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
1426 no_size = zstd.ZstdCompressor(write_content_size=False).compress(
1427 b"foo" * 64
1428 )
1396 1429
1397 1430 with self.assertRaisesRegex(
1398 1431 ValueError, "chunk 1 missing content size in frame"
1399 1432 ):
1400 1433 dctx.decompress_content_dict_chain([initial, no_size])
1401 1434
1402 1435 # Corrupt second frame.
1403 cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64))
1436 cctx = zstd.ZstdCompressor(
1437 dict_data=zstd.ZstdCompressionDict(b"foo" * 64)
1438 )
1404 1439 frame = cctx.compress(b"bar" * 64)
1405 1440 frame = frame[0:12] + frame[15:]
1406 1441
1407 1442 with self.assertRaisesRegex(
1408 1443 zstd.ZstdError, "chunk 1 did not decompress full frame"
1409 1444 ):
1410 1445 dctx.decompress_content_dict_chain([initial, frame])
1411 1446
1412 1447 def test_simple(self):
1413 1448 original = [
1414 1449 b"foo" * 64,
1415 1450 b"foobar" * 64,
1416 1451 b"baz" * 64,
1417 1452 b"foobaz" * 64,
1418 1453 b"foobarbaz" * 64,
1419 1454 ]
1420 1455
1421 1456 chunks = []
1422 1457 chunks.append(zstd.ZstdCompressor().compress(original[0]))
1423 1458 for i, chunk in enumerate(original[1:]):
1424 1459 d = zstd.ZstdCompressionDict(original[i])
1425 1460 cctx = zstd.ZstdCompressor(dict_data=d)
1426 1461 chunks.append(cctx.compress(chunk))
1427 1462
1428 1463 for i in range(1, len(original)):
1429 1464 chain = chunks[0:i]
1430 1465 expected = original[i - 1]
1431 1466 dctx = zstd.ZstdDecompressor()
1432 1467 decompressed = dctx.decompress_content_dict_chain(chain)
1433 1468 self.assertEqual(decompressed, expected)
1434 1469
1435 1470
1436 1471 # TODO enable for CFFI
1437 1472 class TestDecompressor_multi_decompress_to_buffer(TestCase):
1438 1473 def test_invalid_inputs(self):
1439 1474 dctx = zstd.ZstdDecompressor()
1440 1475
1441 1476 if not hasattr(dctx, "multi_decompress_to_buffer"):
1442 1477 self.skipTest("multi_decompress_to_buffer not available")
1443 1478
1444 1479 with self.assertRaises(TypeError):
1445 1480 dctx.multi_decompress_to_buffer(True)
1446 1481
1447 1482 with self.assertRaises(TypeError):
1448 1483 dctx.multi_decompress_to_buffer((1, 2))
1449 1484
1450 with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
1485 with self.assertRaisesRegex(
1486 TypeError, "item 0 not a bytes like object"
1487 ):
1451 1488 dctx.multi_decompress_to_buffer([u"foo"])
1452 1489
1453 1490 with self.assertRaisesRegex(
1454 1491 ValueError, "could not determine decompressed size of item 0"
1455 1492 ):
1456 1493 dctx.multi_decompress_to_buffer([b"foobarbaz"])
1457 1494
1458 1495 def test_list_input(self):
1459 1496 cctx = zstd.ZstdCompressor()
1460 1497
1461 1498 original = [b"foo" * 4, b"bar" * 6]
1462 1499 frames = [cctx.compress(d) for d in original]
1463 1500
1464 1501 dctx = zstd.ZstdDecompressor()
1465 1502
1466 1503 if not hasattr(dctx, "multi_decompress_to_buffer"):
1467 1504 self.skipTest("multi_decompress_to_buffer not available")
1468 1505
1469 1506 result = dctx.multi_decompress_to_buffer(frames)
1470 1507
1471 1508 self.assertEqual(len(result), len(frames))
1472 1509 self.assertEqual(result.size(), sum(map(len, original)))
1473 1510
1474 1511 for i, data in enumerate(original):
1475 1512 self.assertEqual(result[i].tobytes(), data)
1476 1513
1477 1514 self.assertEqual(result[0].offset, 0)
1478 1515 self.assertEqual(len(result[0]), 12)
1479 1516 self.assertEqual(result[1].offset, 12)
1480 1517 self.assertEqual(len(result[1]), 18)
1481 1518
1482 1519 def test_list_input_frame_sizes(self):
1483 1520 cctx = zstd.ZstdCompressor()
1484 1521
1485 1522 original = [b"foo" * 4, b"bar" * 6, b"baz" * 8]
1486 1523 frames = [cctx.compress(d) for d in original]
1487 1524 sizes = struct.pack("=" + "Q" * len(original), *map(len, original))
1488 1525
1489 1526 dctx = zstd.ZstdDecompressor()
1490 1527
1491 1528 if not hasattr(dctx, "multi_decompress_to_buffer"):
1492 1529 self.skipTest("multi_decompress_to_buffer not available")
1493 1530
1494 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
1531 result = dctx.multi_decompress_to_buffer(
1532 frames, decompressed_sizes=sizes
1533 )
1495 1534
1496 1535 self.assertEqual(len(result), len(frames))
1497 1536 self.assertEqual(result.size(), sum(map(len, original)))
1498 1537
1499 1538 for i, data in enumerate(original):
1500 1539 self.assertEqual(result[i].tobytes(), data)
1501 1540
1502 1541 def test_buffer_with_segments_input(self):
1503 1542 cctx = zstd.ZstdCompressor()
1504 1543
1505 1544 original = [b"foo" * 4, b"bar" * 6]
1506 1545 frames = [cctx.compress(d) for d in original]
1507 1546
1508 1547 dctx = zstd.ZstdDecompressor()
1509 1548
1510 1549 if not hasattr(dctx, "multi_decompress_to_buffer"):
1511 1550 self.skipTest("multi_decompress_to_buffer not available")
1512 1551
1513 1552 segments = struct.pack(
1514 1553 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
1515 1554 )
1516 1555 b = zstd.BufferWithSegments(b"".join(frames), segments)
1517 1556
1518 1557 result = dctx.multi_decompress_to_buffer(b)
1519 1558
1520 1559 self.assertEqual(len(result), len(frames))
1521 1560 self.assertEqual(result[0].offset, 0)
1522 1561 self.assertEqual(len(result[0]), 12)
1523 1562 self.assertEqual(result[1].offset, 12)
1524 1563 self.assertEqual(len(result[1]), 18)
1525 1564
1526 1565 def test_buffer_with_segments_sizes(self):
1527 1566 cctx = zstd.ZstdCompressor(write_content_size=False)
1528 1567 original = [b"foo" * 4, b"bar" * 6, b"baz" * 8]
1529 1568 frames = [cctx.compress(d) for d in original]
1530 1569 sizes = struct.pack("=" + "Q" * len(original), *map(len, original))
1531 1570
1532 1571 dctx = zstd.ZstdDecompressor()
1533 1572
1534 1573 if not hasattr(dctx, "multi_decompress_to_buffer"):
1535 1574 self.skipTest("multi_decompress_to_buffer not available")
1536 1575
1537 1576 segments = struct.pack(
1538 1577 "=QQQQQQ",
1539 1578 0,
1540 1579 len(frames[0]),
1541 1580 len(frames[0]),
1542 1581 len(frames[1]),
1543 1582 len(frames[0]) + len(frames[1]),
1544 1583 len(frames[2]),
1545 1584 )
1546 1585 b = zstd.BufferWithSegments(b"".join(frames), segments)
1547 1586
1548 1587 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
1549 1588
1550 1589 self.assertEqual(len(result), len(frames))
1551 1590 self.assertEqual(result.size(), sum(map(len, original)))
1552 1591
1553 1592 for i, data in enumerate(original):
1554 1593 self.assertEqual(result[i].tobytes(), data)
1555 1594
1556 1595 def test_buffer_with_segments_collection_input(self):
1557 1596 cctx = zstd.ZstdCompressor()
1558 1597
1559 1598 original = [
1560 1599 b"foo0" * 2,
1561 1600 b"foo1" * 3,
1562 1601 b"foo2" * 4,
1563 1602 b"foo3" * 5,
1564 1603 b"foo4" * 6,
1565 1604 ]
1566 1605
1567 1606 if not hasattr(cctx, "multi_compress_to_buffer"):
1568 1607 self.skipTest("multi_compress_to_buffer not available")
1569 1608
1570 1609 frames = cctx.multi_compress_to_buffer(original)
1571 1610
1572 1611 # Check round trip.
1573 1612 dctx = zstd.ZstdDecompressor()
1574 1613
1575 1614 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
1576 1615
1577 1616 self.assertEqual(len(decompressed), len(original))
1578 1617
1579 1618 for i, data in enumerate(original):
1580 1619 self.assertEqual(data, decompressed[i].tobytes())
1581 1620
1582 1621 # And a manual mode.
1583 1622 b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
1584 1623 b1 = zstd.BufferWithSegments(
1585 b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]))
1624 b,
1625 struct.pack(
1626 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
1627 ),
1586 1628 )
1587 1629
1588 b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()])
1630 b = b"".join(
1631 [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]
1632 )
1589 1633 b2 = zstd.BufferWithSegments(
1590 1634 b,
1591 1635 struct.pack(
1592 1636 "=QQQQQQ",
1593 1637 0,
1594 1638 len(frames[2]),
1595 1639 len(frames[2]),
1596 1640 len(frames[3]),
1597 1641 len(frames[2]) + len(frames[3]),
1598 1642 len(frames[4]),
1599 1643 ),
1600 1644 )
1601 1645
1602 1646 c = zstd.BufferWithSegmentsCollection(b1, b2)
1603 1647
1604 1648 dctx = zstd.ZstdDecompressor()
1605 1649 decompressed = dctx.multi_decompress_to_buffer(c)
1606 1650
1607 1651 self.assertEqual(len(decompressed), 5)
1608 1652 for i in range(5):
1609 1653 self.assertEqual(decompressed[i].tobytes(), original[i])
1610 1654
1611 1655 def test_dict(self):
1612 1656 d = zstd.train_dictionary(16384, generate_samples(), k=64, d=16)
1613 1657
1614 1658 cctx = zstd.ZstdCompressor(dict_data=d, level=1)
1615 1659 frames = [cctx.compress(s) for s in generate_samples()]
1616 1660
1617 1661 dctx = zstd.ZstdDecompressor(dict_data=d)
1618 1662
1619 1663 if not hasattr(dctx, "multi_decompress_to_buffer"):
1620 1664 self.skipTest("multi_decompress_to_buffer not available")
1621 1665
1622 1666 result = dctx.multi_decompress_to_buffer(frames)
1623 1667
1624 1668 self.assertEqual([o.tobytes() for o in result], generate_samples())
1625 1669
1626 1670 def test_multiple_threads(self):
1627 1671 cctx = zstd.ZstdCompressor()
1628 1672
1629 1673 frames = []
1630 1674 frames.extend(cctx.compress(b"x" * 64) for i in range(256))
1631 1675 frames.extend(cctx.compress(b"y" * 64) for i in range(256))
1632 1676
1633 1677 dctx = zstd.ZstdDecompressor()
1634 1678
1635 1679 if not hasattr(dctx, "multi_decompress_to_buffer"):
1636 1680 self.skipTest("multi_decompress_to_buffer not available")
1637 1681
1638 1682 result = dctx.multi_decompress_to_buffer(frames, threads=-1)
1639 1683
1640 1684 self.assertEqual(len(result), len(frames))
1641 1685 self.assertEqual(result.size(), 2 * 64 * 256)
1642 1686 self.assertEqual(result[0].tobytes(), b"x" * 64)
1643 1687 self.assertEqual(result[256].tobytes(), b"y" * 64)
1644 1688
1645 1689 def test_item_failure(self):
1646 1690 cctx = zstd.ZstdCompressor()
1647 1691 frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)]
1648 1692
1649 1693 frames[1] = frames[1][0:15] + b"extra" + frames[1][15:]
1650 1694
1651 1695 dctx = zstd.ZstdDecompressor()
1652 1696
1653 1697 if not hasattr(dctx, "multi_decompress_to_buffer"):
1654 1698 self.skipTest("multi_decompress_to_buffer not available")
1655 1699
1656 1700 with self.assertRaisesRegex(
1657 1701 zstd.ZstdError,
1658 1702 "error decompressing item 1: ("
1659 1703 "Corrupted block|"
1660 1704 "Destination buffer is too small)",
1661 1705 ):
1662 1706 dctx.multi_decompress_to_buffer(frames)
1663 1707
1664 1708 with self.assertRaisesRegex(
1665 1709 zstd.ZstdError,
1666 1710 "error decompressing item 1: ("
1667 1711 "Corrupted block|"
1668 1712 "Destination buffer is too small)",
1669 1713 ):
1670 1714 dctx.multi_decompress_to_buffer(frames, threads=2)
@@ -1,576 +1,593 b''
1 1 import io
2 2 import os
3 3 import unittest
4 4
5 5 try:
6 6 import hypothesis
7 7 import hypothesis.strategies as strategies
8 8 except ImportError:
9 9 raise unittest.SkipTest("hypothesis not available")
10 10
11 11 import zstandard as zstd
12 12
13 13 from .common import (
14 14 make_cffi,
15 15 NonClosingBytesIO,
16 16 random_input_data,
17 17 TestCase,
18 18 )
19 19
20 20
21 21 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
22 22 @make_cffi
23 23 class TestDecompressor_stream_reader_fuzzing(TestCase):
24 24 @hypothesis.settings(
25 25 suppress_health_check=[
26 26 hypothesis.HealthCheck.large_base_example,
27 27 hypothesis.HealthCheck.too_slow,
28 28 ]
29 29 )
30 30 @hypothesis.given(
31 31 original=strategies.sampled_from(random_input_data()),
32 32 level=strategies.integers(min_value=1, max_value=5),
33 33 streaming=strategies.booleans(),
34 34 source_read_size=strategies.integers(1, 1048576),
35 35 read_sizes=strategies.data(),
36 36 )
37 37 def test_stream_source_read_variance(
38 38 self, original, level, streaming, source_read_size, read_sizes
39 39 ):
40 40 cctx = zstd.ZstdCompressor(level=level)
41 41
42 42 if streaming:
43 43 source = io.BytesIO()
44 44 writer = cctx.stream_writer(source)
45 45 writer.write(original)
46 46 writer.flush(zstd.FLUSH_FRAME)
47 47 source.seek(0)
48 48 else:
49 49 frame = cctx.compress(original)
50 50 source = io.BytesIO(frame)
51 51
52 52 dctx = zstd.ZstdDecompressor()
53 53
54 54 chunks = []
55 55 with dctx.stream_reader(source, read_size=source_read_size) as reader:
56 56 while True:
57 57 read_size = read_sizes.draw(strategies.integers(-1, 131072))
58 58 chunk = reader.read(read_size)
59 59 if not chunk and read_size:
60 60 break
61 61
62 62 chunks.append(chunk)
63 63
64 64 self.assertEqual(b"".join(chunks), original)
65 65
66 66 # Similar to above except we have a constant read() size.
67 67 @hypothesis.settings(
68 68 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
69 69 )
70 70 @hypothesis.given(
71 71 original=strategies.sampled_from(random_input_data()),
72 72 level=strategies.integers(min_value=1, max_value=5),
73 73 streaming=strategies.booleans(),
74 74 source_read_size=strategies.integers(1, 1048576),
75 75 read_size=strategies.integers(-1, 131072),
76 76 )
77 77 def test_stream_source_read_size(
78 78 self, original, level, streaming, source_read_size, read_size
79 79 ):
80 80 if read_size == 0:
81 81 read_size = 1
82 82
83 83 cctx = zstd.ZstdCompressor(level=level)
84 84
85 85 if streaming:
86 86 source = io.BytesIO()
87 87 writer = cctx.stream_writer(source)
88 88 writer.write(original)
89 89 writer.flush(zstd.FLUSH_FRAME)
90 90 source.seek(0)
91 91 else:
92 92 frame = cctx.compress(original)
93 93 source = io.BytesIO(frame)
94 94
95 95 dctx = zstd.ZstdDecompressor()
96 96
97 97 chunks = []
98 98 reader = dctx.stream_reader(source, read_size=source_read_size)
99 99 while True:
100 100 chunk = reader.read(read_size)
101 101 if not chunk and read_size:
102 102 break
103 103
104 104 chunks.append(chunk)
105 105
106 106 self.assertEqual(b"".join(chunks), original)
107 107
108 108 @hypothesis.settings(
109 109 suppress_health_check=[
110 110 hypothesis.HealthCheck.large_base_example,
111 111 hypothesis.HealthCheck.too_slow,
112 112 ]
113 113 )
114 114 @hypothesis.given(
115 115 original=strategies.sampled_from(random_input_data()),
116 116 level=strategies.integers(min_value=1, max_value=5),
117 117 streaming=strategies.booleans(),
118 118 source_read_size=strategies.integers(1, 1048576),
119 119 read_sizes=strategies.data(),
120 120 )
121 121 def test_buffer_source_read_variance(
122 122 self, original, level, streaming, source_read_size, read_sizes
123 123 ):
124 124 cctx = zstd.ZstdCompressor(level=level)
125 125
126 126 if streaming:
127 127 source = io.BytesIO()
128 128 writer = cctx.stream_writer(source)
129 129 writer.write(original)
130 130 writer.flush(zstd.FLUSH_FRAME)
131 131 frame = source.getvalue()
132 132 else:
133 133 frame = cctx.compress(original)
134 134
135 135 dctx = zstd.ZstdDecompressor()
136 136 chunks = []
137 137
138 138 with dctx.stream_reader(frame, read_size=source_read_size) as reader:
139 139 while True:
140 140 read_size = read_sizes.draw(strategies.integers(-1, 131072))
141 141 chunk = reader.read(read_size)
142 142 if not chunk and read_size:
143 143 break
144 144
145 145 chunks.append(chunk)
146 146
147 147 self.assertEqual(b"".join(chunks), original)
148 148
149 149 # Similar to above except we have a constant read() size.
150 150 @hypothesis.settings(
151 151 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
152 152 )
153 153 @hypothesis.given(
154 154 original=strategies.sampled_from(random_input_data()),
155 155 level=strategies.integers(min_value=1, max_value=5),
156 156 streaming=strategies.booleans(),
157 157 source_read_size=strategies.integers(1, 1048576),
158 158 read_size=strategies.integers(-1, 131072),
159 159 )
160 160 def test_buffer_source_constant_read_size(
161 161 self, original, level, streaming, source_read_size, read_size
162 162 ):
163 163 if read_size == 0:
164 164 read_size = -1
165 165
166 166 cctx = zstd.ZstdCompressor(level=level)
167 167
168 168 if streaming:
169 169 source = io.BytesIO()
170 170 writer = cctx.stream_writer(source)
171 171 writer.write(original)
172 172 writer.flush(zstd.FLUSH_FRAME)
173 173 frame = source.getvalue()
174 174 else:
175 175 frame = cctx.compress(original)
176 176
177 177 dctx = zstd.ZstdDecompressor()
178 178 chunks = []
179 179
180 180 reader = dctx.stream_reader(frame, read_size=source_read_size)
181 181 while True:
182 182 chunk = reader.read(read_size)
183 183 if not chunk and read_size:
184 184 break
185 185
186 186 chunks.append(chunk)
187 187
188 188 self.assertEqual(b"".join(chunks), original)
189 189
190 190 @hypothesis.settings(
191 191 suppress_health_check=[hypothesis.HealthCheck.large_base_example]
192 192 )
193 193 @hypothesis.given(
194 194 original=strategies.sampled_from(random_input_data()),
195 195 level=strategies.integers(min_value=1, max_value=5),
196 196 streaming=strategies.booleans(),
197 197 source_read_size=strategies.integers(1, 1048576),
198 198 )
199 def test_stream_source_readall(self, original, level, streaming, source_read_size):
199 def test_stream_source_readall(
200 self, original, level, streaming, source_read_size
201 ):
200 202 cctx = zstd.ZstdCompressor(level=level)
201 203
202 204 if streaming:
203 205 source = io.BytesIO()
204 206 writer = cctx.stream_writer(source)
205 207 writer.write(original)
206 208 writer.flush(zstd.FLUSH_FRAME)
207 209 source.seek(0)
208 210 else:
209 211 frame = cctx.compress(original)
210 212 source = io.BytesIO(frame)
211 213
212 214 dctx = zstd.ZstdDecompressor()
213 215
214 216 data = dctx.stream_reader(source, read_size=source_read_size).readall()
215 217 self.assertEqual(data, original)
216 218
217 219 @hypothesis.settings(
218 220 suppress_health_check=[
219 221 hypothesis.HealthCheck.large_base_example,
220 222 hypothesis.HealthCheck.too_slow,
221 223 ]
222 224 )
223 225 @hypothesis.given(
224 226 original=strategies.sampled_from(random_input_data()),
225 227 level=strategies.integers(min_value=1, max_value=5),
226 228 streaming=strategies.booleans(),
227 229 source_read_size=strategies.integers(1, 1048576),
228 230 read_sizes=strategies.data(),
229 231 )
230 232 def test_stream_source_read1_variance(
231 233 self, original, level, streaming, source_read_size, read_sizes
232 234 ):
233 235 cctx = zstd.ZstdCompressor(level=level)
234 236
235 237 if streaming:
236 238 source = io.BytesIO()
237 239 writer = cctx.stream_writer(source)
238 240 writer.write(original)
239 241 writer.flush(zstd.FLUSH_FRAME)
240 242 source.seek(0)
241 243 else:
242 244 frame = cctx.compress(original)
243 245 source = io.BytesIO(frame)
244 246
245 247 dctx = zstd.ZstdDecompressor()
246 248
247 249 chunks = []
248 250 with dctx.stream_reader(source, read_size=source_read_size) as reader:
249 251 while True:
250 252 read_size = read_sizes.draw(strategies.integers(-1, 131072))
251 253 chunk = reader.read1(read_size)
252 254 if not chunk and read_size:
253 255 break
254 256
255 257 chunks.append(chunk)
256 258
257 259 self.assertEqual(b"".join(chunks), original)
258 260
259 261 @hypothesis.settings(
260 262 suppress_health_check=[
261 263 hypothesis.HealthCheck.large_base_example,
262 264 hypothesis.HealthCheck.too_slow,
263 265 ]
264 266 )
265 267 @hypothesis.given(
266 268 original=strategies.sampled_from(random_input_data()),
267 269 level=strategies.integers(min_value=1, max_value=5),
268 270 streaming=strategies.booleans(),
269 271 source_read_size=strategies.integers(1, 1048576),
270 272 read_sizes=strategies.data(),
271 273 )
272 274 def test_stream_source_readinto1_variance(
273 275 self, original, level, streaming, source_read_size, read_sizes
274 276 ):
275 277 cctx = zstd.ZstdCompressor(level=level)
276 278
277 279 if streaming:
278 280 source = io.BytesIO()
279 281 writer = cctx.stream_writer(source)
280 282 writer.write(original)
281 283 writer.flush(zstd.FLUSH_FRAME)
282 284 source.seek(0)
283 285 else:
284 286 frame = cctx.compress(original)
285 287 source = io.BytesIO(frame)
286 288
287 289 dctx = zstd.ZstdDecompressor()
288 290
289 291 chunks = []
290 292 with dctx.stream_reader(source, read_size=source_read_size) as reader:
291 293 while True:
292 294 read_size = read_sizes.draw(strategies.integers(1, 131072))
293 295 b = bytearray(read_size)
294 296 count = reader.readinto1(b)
295 297
296 298 if not count:
297 299 break
298 300
299 301 chunks.append(bytes(b[0:count]))
300 302
301 303 self.assertEqual(b"".join(chunks), original)
302 304
303 305 @hypothesis.settings(
304 306 suppress_health_check=[
305 307 hypothesis.HealthCheck.large_base_example,
306 308 hypothesis.HealthCheck.too_slow,
307 309 ]
308 310 )
309 311 @hypothesis.given(
310 312 original=strategies.sampled_from(random_input_data()),
311 313 level=strategies.integers(min_value=1, max_value=5),
312 314 source_read_size=strategies.integers(1, 1048576),
313 315 seek_amounts=strategies.data(),
314 316 read_sizes=strategies.data(),
315 317 )
316 318 def test_relative_seeks(
317 319 self, original, level, source_read_size, seek_amounts, read_sizes
318 320 ):
319 321 cctx = zstd.ZstdCompressor(level=level)
320 322 frame = cctx.compress(original)
321 323
322 324 dctx = zstd.ZstdDecompressor()
323 325
324 326 with dctx.stream_reader(frame, read_size=source_read_size) as reader:
325 327 while True:
326 328 amount = seek_amounts.draw(strategies.integers(0, 16384))
327 329 reader.seek(amount, os.SEEK_CUR)
328 330
329 331 offset = reader.tell()
330 332 read_amount = read_sizes.draw(strategies.integers(1, 16384))
331 333 chunk = reader.read(read_amount)
332 334
333 335 if not chunk:
334 336 break
335 337
336 338 self.assertEqual(original[offset : offset + len(chunk)], chunk)
337 339
338 340 @hypothesis.settings(
339 341 suppress_health_check=[
340 342 hypothesis.HealthCheck.large_base_example,
341 343 hypothesis.HealthCheck.too_slow,
342 344 ]
343 345 )
344 346 @hypothesis.given(
345 347 originals=strategies.data(),
346 348 frame_count=strategies.integers(min_value=2, max_value=10),
347 349 level=strategies.integers(min_value=1, max_value=5),
348 350 source_read_size=strategies.integers(1, 1048576),
349 351 read_sizes=strategies.data(),
350 352 )
351 353 def test_multiple_frames(
352 354 self, originals, frame_count, level, source_read_size, read_sizes
353 355 ):
354 356
355 357 cctx = zstd.ZstdCompressor(level=level)
356 358 source = io.BytesIO()
357 359 buffer = io.BytesIO()
358 360 writer = cctx.stream_writer(buffer)
359 361
360 362 for i in range(frame_count):
361 363 data = originals.draw(strategies.sampled_from(random_input_data()))
362 364 source.write(data)
363 365 writer.write(data)
364 366 writer.flush(zstd.FLUSH_FRAME)
365 367
366 368 dctx = zstd.ZstdDecompressor()
367 369 buffer.seek(0)
368 370 reader = dctx.stream_reader(
369 371 buffer, read_size=source_read_size, read_across_frames=True
370 372 )
371 373
372 374 chunks = []
373 375
374 376 while True:
375 377 read_amount = read_sizes.draw(strategies.integers(-1, 16384))
376 378 chunk = reader.read(read_amount)
377 379
378 380 if not chunk and read_amount:
379 381 break
380 382
381 383 chunks.append(chunk)
382 384
383 385 self.assertEqual(source.getvalue(), b"".join(chunks))
384 386
385 387
386 388 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
387 389 @make_cffi
388 390 class TestDecompressor_stream_writer_fuzzing(TestCase):
389 391 @hypothesis.settings(
390 392 suppress_health_check=[
391 393 hypothesis.HealthCheck.large_base_example,
392 394 hypothesis.HealthCheck.too_slow,
393 395 ]
394 396 )
395 397 @hypothesis.given(
396 398 original=strategies.sampled_from(random_input_data()),
397 399 level=strategies.integers(min_value=1, max_value=5),
398 400 write_size=strategies.integers(min_value=1, max_value=8192),
399 401 input_sizes=strategies.data(),
400 402 )
401 def test_write_size_variance(self, original, level, write_size, input_sizes):
403 def test_write_size_variance(
404 self, original, level, write_size, input_sizes
405 ):
402 406 cctx = zstd.ZstdCompressor(level=level)
403 407 frame = cctx.compress(original)
404 408
405 409 dctx = zstd.ZstdDecompressor()
406 410 source = io.BytesIO(frame)
407 411 dest = NonClosingBytesIO()
408 412
409 413 with dctx.stream_writer(dest, write_size=write_size) as decompressor:
410 414 while True:
411 415 input_size = input_sizes.draw(strategies.integers(1, 4096))
412 416 chunk = source.read(input_size)
413 417 if not chunk:
414 418 break
415 419
416 420 decompressor.write(chunk)
417 421
418 422 self.assertEqual(dest.getvalue(), original)
419 423
420 424
421 425 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
422 426 @make_cffi
423 427 class TestDecompressor_copy_stream_fuzzing(TestCase):
424 428 @hypothesis.settings(
425 429 suppress_health_check=[
426 430 hypothesis.HealthCheck.large_base_example,
427 431 hypothesis.HealthCheck.too_slow,
428 432 ]
429 433 )
430 434 @hypothesis.given(
431 435 original=strategies.sampled_from(random_input_data()),
432 436 level=strategies.integers(min_value=1, max_value=5),
433 437 read_size=strategies.integers(min_value=1, max_value=8192),
434 438 write_size=strategies.integers(min_value=1, max_value=8192),
435 439 )
436 def test_read_write_size_variance(self, original, level, read_size, write_size):
440 def test_read_write_size_variance(
441 self, original, level, read_size, write_size
442 ):
437 443 cctx = zstd.ZstdCompressor(level=level)
438 444 frame = cctx.compress(original)
439 445
440 446 source = io.BytesIO(frame)
441 447 dest = io.BytesIO()
442 448
443 449 dctx = zstd.ZstdDecompressor()
444 dctx.copy_stream(source, dest, read_size=read_size, write_size=write_size)
450 dctx.copy_stream(
451 source, dest, read_size=read_size, write_size=write_size
452 )
445 453
446 454 self.assertEqual(dest.getvalue(), original)
447 455
448 456
449 457 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
450 458 @make_cffi
451 459 class TestDecompressor_decompressobj_fuzzing(TestCase):
452 460 @hypothesis.settings(
453 461 suppress_health_check=[
454 462 hypothesis.HealthCheck.large_base_example,
455 463 hypothesis.HealthCheck.too_slow,
456 464 ]
457 465 )
458 466 @hypothesis.given(
459 467 original=strategies.sampled_from(random_input_data()),
460 468 level=strategies.integers(min_value=1, max_value=5),
461 469 chunk_sizes=strategies.data(),
462 470 )
463 471 def test_random_input_sizes(self, original, level, chunk_sizes):
464 472 cctx = zstd.ZstdCompressor(level=level)
465 473 frame = cctx.compress(original)
466 474
467 475 source = io.BytesIO(frame)
468 476
469 477 dctx = zstd.ZstdDecompressor()
470 478 dobj = dctx.decompressobj()
471 479
472 480 chunks = []
473 481 while True:
474 482 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
475 483 chunk = source.read(chunk_size)
476 484 if not chunk:
477 485 break
478 486
479 487 chunks.append(dobj.decompress(chunk))
480 488
481 489 self.assertEqual(b"".join(chunks), original)
482 490
483 491 @hypothesis.settings(
484 492 suppress_health_check=[
485 493 hypothesis.HealthCheck.large_base_example,
486 494 hypothesis.HealthCheck.too_slow,
487 495 ]
488 496 )
489 497 @hypothesis.given(
490 498 original=strategies.sampled_from(random_input_data()),
491 499 level=strategies.integers(min_value=1, max_value=5),
492 500 write_size=strategies.integers(
493 min_value=1, max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE
501 min_value=1,
502 max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
494 503 ),
495 504 chunk_sizes=strategies.data(),
496 505 )
497 def test_random_output_sizes(self, original, level, write_size, chunk_sizes):
506 def test_random_output_sizes(
507 self, original, level, write_size, chunk_sizes
508 ):
498 509 cctx = zstd.ZstdCompressor(level=level)
499 510 frame = cctx.compress(original)
500 511
501 512 source = io.BytesIO(frame)
502 513
503 514 dctx = zstd.ZstdDecompressor()
504 515 dobj = dctx.decompressobj(write_size=write_size)
505 516
506 517 chunks = []
507 518 while True:
508 519 chunk_size = chunk_sizes.draw(strategies.integers(1, 4096))
509 520 chunk = source.read(chunk_size)
510 521 if not chunk:
511 522 break
512 523
513 524 chunks.append(dobj.decompress(chunk))
514 525
515 526 self.assertEqual(b"".join(chunks), original)
516 527
517 528
518 529 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
519 530 @make_cffi
520 531 class TestDecompressor_read_to_iter_fuzzing(TestCase):
521 532 @hypothesis.given(
522 533 original=strategies.sampled_from(random_input_data()),
523 534 level=strategies.integers(min_value=1, max_value=5),
524 535 read_size=strategies.integers(min_value=1, max_value=4096),
525 536 write_size=strategies.integers(min_value=1, max_value=4096),
526 537 )
527 def test_read_write_size_variance(self, original, level, read_size, write_size):
538 def test_read_write_size_variance(
539 self, original, level, read_size, write_size
540 ):
528 541 cctx = zstd.ZstdCompressor(level=level)
529 542 frame = cctx.compress(original)
530 543
531 544 source = io.BytesIO(frame)
532 545
533 546 dctx = zstd.ZstdDecompressor()
534 547 chunks = list(
535 dctx.read_to_iter(source, read_size=read_size, write_size=write_size)
548 dctx.read_to_iter(
549 source, read_size=read_size, write_size=write_size
550 )
536 551 )
537 552
538 553 self.assertEqual(b"".join(chunks), original)
539 554
540 555
541 556 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
542 557 class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase):
543 558 @hypothesis.given(
544 559 original=strategies.lists(
545 strategies.sampled_from(random_input_data()), min_size=1, max_size=1024
560 strategies.sampled_from(random_input_data()),
561 min_size=1,
562 max_size=1024,
546 563 ),
547 564 threads=strategies.integers(min_value=1, max_value=8),
548 565 use_dict=strategies.booleans(),
549 566 )
550 567 def test_data_equivalence(self, original, threads, use_dict):
551 568 kwargs = {}
552 569 if use_dict:
553 570 kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0])
554 571
555 572 cctx = zstd.ZstdCompressor(
556 573 level=1, write_content_size=True, write_checksum=True, **kwargs
557 574 )
558 575
559 576 if not hasattr(cctx, "multi_compress_to_buffer"):
560 577 self.skipTest("multi_compress_to_buffer not available")
561 578
562 579 frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1)
563 580
564 581 dctx = zstd.ZstdDecompressor(**kwargs)
565 582 result = dctx.multi_decompress_to_buffer(frames_buffer)
566 583
567 584 self.assertEqual(len(result), len(original))
568 585 for i, frame in enumerate(result):
569 586 self.assertEqual(frame.tobytes(), original[i])
570 587
571 588 frames_list = [f.tobytes() for f in frames_buffer]
572 589 result = dctx.multi_decompress_to_buffer(frames_list)
573 590
574 591 self.assertEqual(len(result), len(original))
575 592 for i, frame in enumerate(result):
576 593 self.assertEqual(frame.tobytes(), original[i])
@@ -1,92 +1,102 b''
1 1 import struct
2 2 import sys
3 3 import unittest
4 4
5 5 import zstandard as zstd
6 6
7 7 from .common import (
8 8 generate_samples,
9 9 make_cffi,
10 10 random_input_data,
11 11 TestCase,
12 12 )
13 13
14 14 if sys.version_info[0] >= 3:
15 15 int_type = int
16 16 else:
17 17 int_type = long
18 18
19 19
20 20 @make_cffi
21 21 class TestTrainDictionary(TestCase):
22 22 def test_no_args(self):
23 23 with self.assertRaises(TypeError):
24 24 zstd.train_dictionary()
25 25
26 26 def test_bad_args(self):
27 27 with self.assertRaises(TypeError):
28 28 zstd.train_dictionary(8192, u"foo")
29 29
30 30 with self.assertRaises(ValueError):
31 31 zstd.train_dictionary(8192, [u"foo"])
32 32
33 33 def test_no_params(self):
34 34 d = zstd.train_dictionary(8192, random_input_data())
35 35 self.assertIsInstance(d.dict_id(), int_type)
36 36
37 37 # The dictionary ID may be different across platforms.
38 38 expected = b"\x37\xa4\x30\xec" + struct.pack("<I", d.dict_id())
39 39
40 40 data = d.as_bytes()
41 41 self.assertEqual(data[0:8], expected)
42 42
43 43 def test_basic(self):
44 44 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
45 45 self.assertIsInstance(d.dict_id(), int_type)
46 46
47 47 data = d.as_bytes()
48 48 self.assertEqual(data[0:4], b"\x37\xa4\x30\xec")
49 49
50 50 self.assertEqual(d.k, 64)
51 51 self.assertEqual(d.d, 16)
52 52
53 53 def test_set_dict_id(self):
54 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16, dict_id=42)
54 d = zstd.train_dictionary(
55 8192, generate_samples(), k=64, d=16, dict_id=42
56 )
55 57 self.assertEqual(d.dict_id(), 42)
56 58
57 59 def test_optimize(self):
58 d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1, d=16)
60 d = zstd.train_dictionary(
61 8192, generate_samples(), threads=-1, steps=1, d=16
62 )
59 63
60 64 # This varies by platform.
61 65 self.assertIn(d.k, (50, 2000))
62 66 self.assertEqual(d.d, 16)
63 67
64 68
65 69 @make_cffi
66 70 class TestCompressionDict(TestCase):
67 71 def test_bad_mode(self):
68 72 with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"):
69 73 zstd.ZstdCompressionDict(b"foo", dict_type=42)
70 74
71 75 def test_bad_precompute_compress(self):
72 76 d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
73 77
74 with self.assertRaisesRegex(ValueError, "must specify one of level or "):
78 with self.assertRaisesRegex(
79 ValueError, "must specify one of level or "
80 ):
75 81 d.precompute_compress()
76 82
77 with self.assertRaisesRegex(ValueError, "must only specify one of level or "):
83 with self.assertRaisesRegex(
84 ValueError, "must only specify one of level or "
85 ):
78 86 d.precompute_compress(
79 87 level=3, compression_params=zstd.CompressionParameters()
80 88 )
81 89
82 90 def test_precompute_compress_rawcontent(self):
83 91 d = zstd.ZstdCompressionDict(
84 92 b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT
85 93 )
86 94 d.precompute_compress(level=1)
87 95
88 96 d = zstd.ZstdCompressionDict(
89 97 b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT
90 98 )
91 with self.assertRaisesRegex(zstd.ZstdError, "unable to precompute dictionary"):
99 with self.assertRaisesRegex(
100 zstd.ZstdError, "unable to precompute dictionary"
101 ):
92 102 d.precompute_compress(level=1)
@@ -1,2615 +1,2769 b''
1 1 # Copyright (c) 2016-present, Gregory Szorc
2 2 # All rights reserved.
3 3 #
4 4 # This software may be modified and distributed under the terms
5 5 # of the BSD license. See the LICENSE file for details.
6 6
7 7 """Python interface to the Zstandard (zstd) compression library."""
8 8
9 9 from __future__ import absolute_import, unicode_literals
10 10
11 11 # This should match what the C extension exports.
12 12 __all__ = [
13 13 #'BufferSegment',
14 14 #'BufferSegments',
15 15 #'BufferWithSegments',
16 16 #'BufferWithSegmentsCollection',
17 17 "CompressionParameters",
18 18 "ZstdCompressionDict",
19 19 "ZstdCompressionParameters",
20 20 "ZstdCompressor",
21 21 "ZstdError",
22 22 "ZstdDecompressor",
23 23 "FrameParameters",
24 24 "estimate_decompression_context_size",
25 25 "frame_content_size",
26 26 "frame_header_size",
27 27 "get_frame_parameters",
28 28 "train_dictionary",
29 29 # Constants.
30 30 "FLUSH_BLOCK",
31 31 "FLUSH_FRAME",
32 32 "COMPRESSOBJ_FLUSH_FINISH",
33 33 "COMPRESSOBJ_FLUSH_BLOCK",
34 34 "ZSTD_VERSION",
35 35 "FRAME_HEADER",
36 36 "CONTENTSIZE_UNKNOWN",
37 37 "CONTENTSIZE_ERROR",
38 38 "MAX_COMPRESSION_LEVEL",
39 39 "COMPRESSION_RECOMMENDED_INPUT_SIZE",
40 40 "COMPRESSION_RECOMMENDED_OUTPUT_SIZE",
41 41 "DECOMPRESSION_RECOMMENDED_INPUT_SIZE",
42 42 "DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE",
43 43 "MAGIC_NUMBER",
44 44 "BLOCKSIZELOG_MAX",
45 45 "BLOCKSIZE_MAX",
46 46 "WINDOWLOG_MIN",
47 47 "WINDOWLOG_MAX",
48 48 "CHAINLOG_MIN",
49 49 "CHAINLOG_MAX",
50 50 "HASHLOG_MIN",
51 51 "HASHLOG_MAX",
52 52 "HASHLOG3_MAX",
53 53 "MINMATCH_MIN",
54 54 "MINMATCH_MAX",
55 55 "SEARCHLOG_MIN",
56 56 "SEARCHLOG_MAX",
57 57 "SEARCHLENGTH_MIN",
58 58 "SEARCHLENGTH_MAX",
59 59 "TARGETLENGTH_MIN",
60 60 "TARGETLENGTH_MAX",
61 61 "LDM_MINMATCH_MIN",
62 62 "LDM_MINMATCH_MAX",
63 63 "LDM_BUCKETSIZELOG_MAX",
64 64 "STRATEGY_FAST",
65 65 "STRATEGY_DFAST",
66 66 "STRATEGY_GREEDY",
67 67 "STRATEGY_LAZY",
68 68 "STRATEGY_LAZY2",
69 69 "STRATEGY_BTLAZY2",
70 70 "STRATEGY_BTOPT",
71 71 "STRATEGY_BTULTRA",
72 72 "STRATEGY_BTULTRA2",
73 73 "DICT_TYPE_AUTO",
74 74 "DICT_TYPE_RAWCONTENT",
75 75 "DICT_TYPE_FULLDICT",
76 76 "FORMAT_ZSTD1",
77 77 "FORMAT_ZSTD1_MAGICLESS",
78 78 ]
79 79
80 80 import io
81 81 import os
82 82 import sys
83 83
84 84 from _zstd_cffi import (
85 85 ffi,
86 86 lib,
87 87 )
88 88
89 89 if sys.version_info[0] == 2:
90 90 bytes_type = str
91 91 int_type = long
92 92 else:
93 93 bytes_type = bytes
94 94 int_type = int
95 95
96 96
97 97 COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize()
98 98 COMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_CStreamOutSize()
99 99 DECOMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_DStreamInSize()
100 100 DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_DStreamOutSize()
101 101
102 102 new_nonzero = ffi.new_allocator(should_clear_after_alloc=False)
103 103
104 104
105 105 MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel()
106 106 MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER
107 107 FRAME_HEADER = b"\x28\xb5\x2f\xfd"
108 108 CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN
109 109 CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR
110 110 ZSTD_VERSION = (
111 111 lib.ZSTD_VERSION_MAJOR,
112 112 lib.ZSTD_VERSION_MINOR,
113 113 lib.ZSTD_VERSION_RELEASE,
114 114 )
115 115
116 116 BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX
117 117 BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX
118 118 WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN
119 119 WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX
120 120 CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN
121 121 CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX
122 122 HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN
123 123 HASHLOG_MAX = lib.ZSTD_HASHLOG_MAX
124 124 HASHLOG3_MAX = lib.ZSTD_HASHLOG3_MAX
125 125 MINMATCH_MIN = lib.ZSTD_MINMATCH_MIN
126 126 MINMATCH_MAX = lib.ZSTD_MINMATCH_MAX
127 127 SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN
128 128 SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX
129 129 SEARCHLENGTH_MIN = lib.ZSTD_MINMATCH_MIN
130 130 SEARCHLENGTH_MAX = lib.ZSTD_MINMATCH_MAX
131 131 TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN
132 132 TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX
133 133 LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN
134 134 LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX
135 135 LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX
136 136
137 137 STRATEGY_FAST = lib.ZSTD_fast
138 138 STRATEGY_DFAST = lib.ZSTD_dfast
139 139 STRATEGY_GREEDY = lib.ZSTD_greedy
140 140 STRATEGY_LAZY = lib.ZSTD_lazy
141 141 STRATEGY_LAZY2 = lib.ZSTD_lazy2
142 142 STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2
143 143 STRATEGY_BTOPT = lib.ZSTD_btopt
144 144 STRATEGY_BTULTRA = lib.ZSTD_btultra
145 145 STRATEGY_BTULTRA2 = lib.ZSTD_btultra2
146 146
147 147 DICT_TYPE_AUTO = lib.ZSTD_dct_auto
148 148 DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent
149 149 DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict
150 150
151 151 FORMAT_ZSTD1 = lib.ZSTD_f_zstd1
152 152 FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless
153 153
154 154 FLUSH_BLOCK = 0
155 155 FLUSH_FRAME = 1
156 156
157 157 COMPRESSOBJ_FLUSH_FINISH = 0
158 158 COMPRESSOBJ_FLUSH_BLOCK = 1
159 159
160 160
161 161 def _cpu_count():
162 162 # os.cpu_count() was introducd in Python 3.4.
163 163 try:
164 164 return os.cpu_count() or 0
165 165 except AttributeError:
166 166 pass
167 167
168 168 # Linux.
169 169 try:
170 170 if sys.version_info[0] == 2:
171 171 return os.sysconf(b"SC_NPROCESSORS_ONLN")
172 172 else:
173 173 return os.sysconf("SC_NPROCESSORS_ONLN")
174 174 except (AttributeError, ValueError):
175 175 pass
176 176
177 177 # TODO implement on other platforms.
178 178 return 0
179 179
180 180
181 181 class ZstdError(Exception):
182 182 pass
183 183
184 184
185 185 def _zstd_error(zresult):
186 186 # Resolves to bytes on Python 2 and 3. We use the string for formatting
187 187 # into error messages, which will be literal unicode. So convert it to
188 188 # unicode.
189 189 return ffi.string(lib.ZSTD_getErrorName(zresult)).decode("utf-8")
190 190
191 191
192 192 def _make_cctx_params(params):
193 193 res = lib.ZSTD_createCCtxParams()
194 194 if res == ffi.NULL:
195 195 raise MemoryError()
196 196
197 197 res = ffi.gc(res, lib.ZSTD_freeCCtxParams)
198 198
199 199 attrs = [
200 200 (lib.ZSTD_c_format, params.format),
201 201 (lib.ZSTD_c_compressionLevel, params.compression_level),
202 202 (lib.ZSTD_c_windowLog, params.window_log),
203 203 (lib.ZSTD_c_hashLog, params.hash_log),
204 204 (lib.ZSTD_c_chainLog, params.chain_log),
205 205 (lib.ZSTD_c_searchLog, params.search_log),
206 206 (lib.ZSTD_c_minMatch, params.min_match),
207 207 (lib.ZSTD_c_targetLength, params.target_length),
208 208 (lib.ZSTD_c_strategy, params.compression_strategy),
209 209 (lib.ZSTD_c_contentSizeFlag, params.write_content_size),
210 210 (lib.ZSTD_c_checksumFlag, params.write_checksum),
211 211 (lib.ZSTD_c_dictIDFlag, params.write_dict_id),
212 212 (lib.ZSTD_c_nbWorkers, params.threads),
213 213 (lib.ZSTD_c_jobSize, params.job_size),
214 214 (lib.ZSTD_c_overlapLog, params.overlap_log),
215 215 (lib.ZSTD_c_forceMaxWindow, params.force_max_window),
216 216 (lib.ZSTD_c_enableLongDistanceMatching, params.enable_ldm),
217 217 (lib.ZSTD_c_ldmHashLog, params.ldm_hash_log),
218 218 (lib.ZSTD_c_ldmMinMatch, params.ldm_min_match),
219 219 (lib.ZSTD_c_ldmBucketSizeLog, params.ldm_bucket_size_log),
220 220 (lib.ZSTD_c_ldmHashRateLog, params.ldm_hash_rate_log),
221 221 ]
222 222
223 223 for param, value in attrs:
224 224 _set_compression_parameter(res, param, value)
225 225
226 226 return res
227 227
228 228
229 229 class ZstdCompressionParameters(object):
230 230 @staticmethod
231 231 def from_level(level, source_size=0, dict_size=0, **kwargs):
232 232 params = lib.ZSTD_getCParams(level, source_size, dict_size)
233 233
234 234 args = {
235 235 "window_log": "windowLog",
236 236 "chain_log": "chainLog",
237 237 "hash_log": "hashLog",
238 238 "search_log": "searchLog",
239 239 "min_match": "minMatch",
240 240 "target_length": "targetLength",
241 241 "compression_strategy": "strategy",
242 242 }
243 243
244 244 for arg, attr in args.items():
245 245 if arg not in kwargs:
246 246 kwargs[arg] = getattr(params, attr)
247 247
248 248 return ZstdCompressionParameters(**kwargs)
249 249
250 250 def __init__(
251 251 self,
252 252 format=0,
253 253 compression_level=0,
254 254 window_log=0,
255 255 hash_log=0,
256 256 chain_log=0,
257 257 search_log=0,
258 258 min_match=0,
259 259 target_length=0,
260 260 strategy=-1,
261 261 compression_strategy=-1,
262 262 write_content_size=1,
263 263 write_checksum=0,
264 264 write_dict_id=0,
265 265 job_size=0,
266 266 overlap_log=-1,
267 267 overlap_size_log=-1,
268 268 force_max_window=0,
269 269 enable_ldm=0,
270 270 ldm_hash_log=0,
271 271 ldm_min_match=0,
272 272 ldm_bucket_size_log=0,
273 273 ldm_hash_rate_log=-1,
274 274 ldm_hash_every_log=-1,
275 275 threads=0,
276 276 ):
277 277
278 278 params = lib.ZSTD_createCCtxParams()
279 279 if params == ffi.NULL:
280 280 raise MemoryError()
281 281
282 282 params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
283 283
284 284 self._params = params
285 285
286 286 if threads < 0:
287 287 threads = _cpu_count()
288 288
289 289 # We need to set ZSTD_c_nbWorkers before ZSTD_c_jobSize and ZSTD_c_overlapLog
290 290 # because setting ZSTD_c_nbWorkers resets the other parameters.
291 291 _set_compression_parameter(params, lib.ZSTD_c_nbWorkers, threads)
292 292
293 293 _set_compression_parameter(params, lib.ZSTD_c_format, format)
294 294 _set_compression_parameter(
295 295 params, lib.ZSTD_c_compressionLevel, compression_level
296 296 )
297 297 _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log)
298 298 _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log)
299 299 _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log)
300 300 _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log)
301 301 _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match)
302 _set_compression_parameter(params, lib.ZSTD_c_targetLength, target_length)
302 _set_compression_parameter(
303 params, lib.ZSTD_c_targetLength, target_length
304 )
303 305
304 306 if strategy != -1 and compression_strategy != -1:
305 raise ValueError("cannot specify both compression_strategy and strategy")
307 raise ValueError(
308 "cannot specify both compression_strategy and strategy"
309 )
306 310
307 311 if compression_strategy != -1:
308 312 strategy = compression_strategy
309 313 elif strategy == -1:
310 314 strategy = 0
311 315
312 316 _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy)
313 317 _set_compression_parameter(
314 318 params, lib.ZSTD_c_contentSizeFlag, write_content_size
315 319 )
316 _set_compression_parameter(params, lib.ZSTD_c_checksumFlag, write_checksum)
320 _set_compression_parameter(
321 params, lib.ZSTD_c_checksumFlag, write_checksum
322 )
317 323 _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id)
318 324 _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size)
319 325
320 326 if overlap_log != -1 and overlap_size_log != -1:
321 raise ValueError("cannot specify both overlap_log and overlap_size_log")
327 raise ValueError(
328 "cannot specify both overlap_log and overlap_size_log"
329 )
322 330
323 331 if overlap_size_log != -1:
324 332 overlap_log = overlap_size_log
325 333 elif overlap_log == -1:
326 334 overlap_log = 0
327 335
328 336 _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log)
329 _set_compression_parameter(params, lib.ZSTD_c_forceMaxWindow, force_max_window)
337 _set_compression_parameter(
338 params, lib.ZSTD_c_forceMaxWindow, force_max_window
339 )
330 340 _set_compression_parameter(
331 341 params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm
332 342 )
333 343 _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log)
334 _set_compression_parameter(params, lib.ZSTD_c_ldmMinMatch, ldm_min_match)
344 _set_compression_parameter(
345 params, lib.ZSTD_c_ldmMinMatch, ldm_min_match
346 )
335 347 _set_compression_parameter(
336 348 params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log
337 349 )
338 350
339 351 if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1:
340 352 raise ValueError(
341 353 "cannot specify both ldm_hash_rate_log and ldm_hash_every_log"
342 354 )
343 355
344 356 if ldm_hash_every_log != -1:
345 357 ldm_hash_rate_log = ldm_hash_every_log
346 358 elif ldm_hash_rate_log == -1:
347 359 ldm_hash_rate_log = 0
348 360
349 _set_compression_parameter(params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log)
361 _set_compression_parameter(
362 params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log
363 )
350 364
351 365 @property
352 366 def format(self):
353 367 return _get_compression_parameter(self._params, lib.ZSTD_c_format)
354 368
355 369 @property
356 370 def compression_level(self):
357 return _get_compression_parameter(self._params, lib.ZSTD_c_compressionLevel)
371 return _get_compression_parameter(
372 self._params, lib.ZSTD_c_compressionLevel
373 )
358 374
359 375 @property
360 376 def window_log(self):
361 377 return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog)
362 378
363 379 @property
364 380 def hash_log(self):
365 381 return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog)
366 382
367 383 @property
368 384 def chain_log(self):
369 385 return _get_compression_parameter(self._params, lib.ZSTD_c_chainLog)
370 386
371 387 @property
372 388 def search_log(self):
373 389 return _get_compression_parameter(self._params, lib.ZSTD_c_searchLog)
374 390
375 391 @property
376 392 def min_match(self):
377 393 return _get_compression_parameter(self._params, lib.ZSTD_c_minMatch)
378 394
379 395 @property
380 396 def target_length(self):
381 397 return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength)
382 398
383 399 @property
384 400 def compression_strategy(self):
385 401 return _get_compression_parameter(self._params, lib.ZSTD_c_strategy)
386 402
387 403 @property
388 404 def write_content_size(self):
389 return _get_compression_parameter(self._params, lib.ZSTD_c_contentSizeFlag)
405 return _get_compression_parameter(
406 self._params, lib.ZSTD_c_contentSizeFlag
407 )
390 408
391 409 @property
392 410 def write_checksum(self):
393 411 return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag)
394 412
395 413 @property
396 414 def write_dict_id(self):
397 415 return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag)
398 416
399 417 @property
400 418 def job_size(self):
401 419 return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize)
402 420
403 421 @property
404 422 def overlap_log(self):
405 423 return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog)
406 424
407 425 @property
408 426 def overlap_size_log(self):
409 427 return self.overlap_log
410 428
411 429 @property
412 430 def force_max_window(self):
413 return _get_compression_parameter(self._params, lib.ZSTD_c_forceMaxWindow)
431 return _get_compression_parameter(
432 self._params, lib.ZSTD_c_forceMaxWindow
433 )
414 434
415 435 @property
416 436 def enable_ldm(self):
417 437 return _get_compression_parameter(
418 438 self._params, lib.ZSTD_c_enableLongDistanceMatching
419 439 )
420 440
421 441 @property
422 442 def ldm_hash_log(self):
423 443 return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog)
424 444
425 445 @property
426 446 def ldm_min_match(self):
427 447 return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch)
428 448
429 449 @property
430 450 def ldm_bucket_size_log(self):
431 return _get_compression_parameter(self._params, lib.ZSTD_c_ldmBucketSizeLog)
451 return _get_compression_parameter(
452 self._params, lib.ZSTD_c_ldmBucketSizeLog
453 )
432 454
433 455 @property
434 456 def ldm_hash_rate_log(self):
435 return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashRateLog)
457 return _get_compression_parameter(
458 self._params, lib.ZSTD_c_ldmHashRateLog
459 )
436 460
437 461 @property
438 462 def ldm_hash_every_log(self):
439 463 return self.ldm_hash_rate_log
440 464
441 465 @property
442 466 def threads(self):
443 467 return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers)
444 468
445 469 def estimated_compression_context_size(self):
446 470 return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params)
447 471
448 472
449 473 CompressionParameters = ZstdCompressionParameters
450 474
451 475
452 476 def estimate_decompression_context_size():
453 477 return lib.ZSTD_estimateDCtxSize()
454 478
455 479
456 480 def _set_compression_parameter(params, param, value):
457 481 zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value)
458 482 if lib.ZSTD_isError(zresult):
459 483 raise ZstdError(
460 "unable to set compression context parameter: %s" % _zstd_error(zresult)
484 "unable to set compression context parameter: %s"
485 % _zstd_error(zresult)
461 486 )
462 487
463 488
464 489 def _get_compression_parameter(params, param):
465 490 result = ffi.new("int *")
466 491
467 492 zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result)
468 493 if lib.ZSTD_isError(zresult):
469 494 raise ZstdError(
470 "unable to get compression context parameter: %s" % _zstd_error(zresult)
495 "unable to get compression context parameter: %s"
496 % _zstd_error(zresult)
471 497 )
472 498
473 499 return result[0]
474 500
475 501
476 502 class ZstdCompressionWriter(object):
477 def __init__(self, compressor, writer, source_size, write_size, write_return_read):
503 def __init__(
504 self, compressor, writer, source_size, write_size, write_return_read
505 ):
478 506 self._compressor = compressor
479 507 self._writer = writer
480 508 self._write_size = write_size
481 509 self._write_return_read = bool(write_return_read)
482 510 self._entered = False
483 511 self._closed = False
484 512 self._bytes_compressed = 0
485 513
486 514 self._dst_buffer = ffi.new("char[]", write_size)
487 515 self._out_buffer = ffi.new("ZSTD_outBuffer *")
488 516 self._out_buffer.dst = self._dst_buffer
489 517 self._out_buffer.size = len(self._dst_buffer)
490 518 self._out_buffer.pos = 0
491 519
492 520 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size)
493 521 if lib.ZSTD_isError(zresult):
494 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
522 raise ZstdError(
523 "error setting source size: %s" % _zstd_error(zresult)
524 )
495 525
496 526 def __enter__(self):
497 527 if self._closed:
498 528 raise ValueError("stream is closed")
499 529
500 530 if self._entered:
501 531 raise ZstdError("cannot __enter__ multiple times")
502 532
503 533 self._entered = True
504 534 return self
505 535
506 536 def __exit__(self, exc_type, exc_value, exc_tb):
507 537 self._entered = False
508 538
509 539 if not exc_type and not exc_value and not exc_tb:
510 540 self.close()
511 541
512 542 self._compressor = None
513 543
514 544 return False
515 545
516 546 def memory_size(self):
517 547 return lib.ZSTD_sizeof_CCtx(self._compressor._cctx)
518 548
519 549 def fileno(self):
520 550 f = getattr(self._writer, "fileno", None)
521 551 if f:
522 552 return f()
523 553 else:
524 554 raise OSError("fileno not available on underlying writer")
525 555
526 556 def close(self):
527 557 if self._closed:
528 558 return
529 559
530 560 try:
531 561 self.flush(FLUSH_FRAME)
532 562 finally:
533 563 self._closed = True
534 564
535 565 # Call close() on underlying stream as well.
536 566 f = getattr(self._writer, "close", None)
537 567 if f:
538 568 f()
539 569
540 570 @property
541 571 def closed(self):
542 572 return self._closed
543 573
544 574 def isatty(self):
545 575 return False
546 576
547 577 def readable(self):
548 578 return False
549 579
550 580 def readline(self, size=-1):
551 581 raise io.UnsupportedOperation()
552 582
553 583 def readlines(self, hint=-1):
554 584 raise io.UnsupportedOperation()
555 585
556 586 def seek(self, offset, whence=None):
557 587 raise io.UnsupportedOperation()
558 588
559 589 def seekable(self):
560 590 return False
561 591
562 592 def truncate(self, size=None):
563 593 raise io.UnsupportedOperation()
564 594
565 595 def writable(self):
566 596 return True
567 597
568 598 def writelines(self, lines):
569 599 raise NotImplementedError("writelines() is not yet implemented")
570 600
571 601 def read(self, size=-1):
572 602 raise io.UnsupportedOperation()
573 603
574 604 def readall(self):
575 605 raise io.UnsupportedOperation()
576 606
577 607 def readinto(self, b):
578 608 raise io.UnsupportedOperation()
579 609
580 610 def write(self, data):
581 611 if self._closed:
582 612 raise ValueError("stream is closed")
583 613
584 614 total_write = 0
585 615
586 616 data_buffer = ffi.from_buffer(data)
587 617
588 618 in_buffer = ffi.new("ZSTD_inBuffer *")
589 619 in_buffer.src = data_buffer
590 620 in_buffer.size = len(data_buffer)
591 621 in_buffer.pos = 0
592 622
593 623 out_buffer = self._out_buffer
594 624 out_buffer.pos = 0
595 625
596 626 while in_buffer.pos < in_buffer.size:
597 627 zresult = lib.ZSTD_compressStream2(
598 self._compressor._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
628 self._compressor._cctx,
629 out_buffer,
630 in_buffer,
631 lib.ZSTD_e_continue,
599 632 )
600 633 if lib.ZSTD_isError(zresult):
601 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
634 raise ZstdError(
635 "zstd compress error: %s" % _zstd_error(zresult)
636 )
602 637
603 638 if out_buffer.pos:
604 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
639 self._writer.write(
640 ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
641 )
605 642 total_write += out_buffer.pos
606 643 self._bytes_compressed += out_buffer.pos
607 644 out_buffer.pos = 0
608 645
609 646 if self._write_return_read:
610 647 return in_buffer.pos
611 648 else:
612 649 return total_write
613 650
614 651 def flush(self, flush_mode=FLUSH_BLOCK):
615 652 if flush_mode == FLUSH_BLOCK:
616 653 flush = lib.ZSTD_e_flush
617 654 elif flush_mode == FLUSH_FRAME:
618 655 flush = lib.ZSTD_e_end
619 656 else:
620 657 raise ValueError("unknown flush_mode: %r" % flush_mode)
621 658
622 659 if self._closed:
623 660 raise ValueError("stream is closed")
624 661
625 662 total_write = 0
626 663
627 664 out_buffer = self._out_buffer
628 665 out_buffer.pos = 0
629 666
630 667 in_buffer = ffi.new("ZSTD_inBuffer *")
631 668 in_buffer.src = ffi.NULL
632 669 in_buffer.size = 0
633 670 in_buffer.pos = 0
634 671
635 672 while True:
636 673 zresult = lib.ZSTD_compressStream2(
637 674 self._compressor._cctx, out_buffer, in_buffer, flush
638 675 )
639 676 if lib.ZSTD_isError(zresult):
640 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
677 raise ZstdError(
678 "zstd compress error: %s" % _zstd_error(zresult)
679 )
641 680
642 681 if out_buffer.pos:
643 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
682 self._writer.write(
683 ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
684 )
644 685 total_write += out_buffer.pos
645 686 self._bytes_compressed += out_buffer.pos
646 687 out_buffer.pos = 0
647 688
648 689 if not zresult:
649 690 break
650 691
651 692 return total_write
652 693
653 694 def tell(self):
654 695 return self._bytes_compressed
655 696
656 697
657 698 class ZstdCompressionObj(object):
658 699 def compress(self, data):
659 700 if self._finished:
660 701 raise ZstdError("cannot call compress() after compressor finished")
661 702
662 703 data_buffer = ffi.from_buffer(data)
663 704 source = ffi.new("ZSTD_inBuffer *")
664 705 source.src = data_buffer
665 706 source.size = len(data_buffer)
666 707 source.pos = 0
667 708
668 709 chunks = []
669 710
670 711 while source.pos < len(data):
671 712 zresult = lib.ZSTD_compressStream2(
672 713 self._compressor._cctx, self._out, source, lib.ZSTD_e_continue
673 714 )
674 715 if lib.ZSTD_isError(zresult):
675 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
716 raise ZstdError(
717 "zstd compress error: %s" % _zstd_error(zresult)
718 )
676 719
677 720 if self._out.pos:
678 721 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
679 722 self._out.pos = 0
680 723
681 724 return b"".join(chunks)
682 725
683 726 def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH):
684 if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK):
727 if flush_mode not in (
728 COMPRESSOBJ_FLUSH_FINISH,
729 COMPRESSOBJ_FLUSH_BLOCK,
730 ):
685 731 raise ValueError("flush mode not recognized")
686 732
687 733 if self._finished:
688 734 raise ZstdError("compressor object already finished")
689 735
690 736 if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
691 737 z_flush_mode = lib.ZSTD_e_flush
692 738 elif flush_mode == COMPRESSOBJ_FLUSH_FINISH:
693 739 z_flush_mode = lib.ZSTD_e_end
694 740 self._finished = True
695 741 else:
696 742 raise ZstdError("unhandled flush mode")
697 743
698 744 assert self._out.pos == 0
699 745
700 746 in_buffer = ffi.new("ZSTD_inBuffer *")
701 747 in_buffer.src = ffi.NULL
702 748 in_buffer.size = 0
703 749 in_buffer.pos = 0
704 750
705 751 chunks = []
706 752
707 753 while True:
708 754 zresult = lib.ZSTD_compressStream2(
709 755 self._compressor._cctx, self._out, in_buffer, z_flush_mode
710 756 )
711 757 if lib.ZSTD_isError(zresult):
712 758 raise ZstdError(
713 759 "error ending compression stream: %s" % _zstd_error(zresult)
714 760 )
715 761
716 762 if self._out.pos:
717 763 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
718 764 self._out.pos = 0
719 765
720 766 if not zresult:
721 767 break
722 768
723 769 return b"".join(chunks)
724 770
725 771
726 772 class ZstdCompressionChunker(object):
727 773 def __init__(self, compressor, chunk_size):
728 774 self._compressor = compressor
729 775 self._out = ffi.new("ZSTD_outBuffer *")
730 776 self._dst_buffer = ffi.new("char[]", chunk_size)
731 777 self._out.dst = self._dst_buffer
732 778 self._out.size = chunk_size
733 779 self._out.pos = 0
734 780
735 781 self._in = ffi.new("ZSTD_inBuffer *")
736 782 self._in.src = ffi.NULL
737 783 self._in.size = 0
738 784 self._in.pos = 0
739 785 self._finished = False
740 786
741 787 def compress(self, data):
742 788 if self._finished:
743 789 raise ZstdError("cannot call compress() after compression finished")
744 790
745 791 if self._in.src != ffi.NULL:
746 792 raise ZstdError(
747 793 "cannot perform operation before consuming output "
748 794 "from previous operation"
749 795 )
750 796
751 797 data_buffer = ffi.from_buffer(data)
752 798
753 799 if not len(data_buffer):
754 800 return
755 801
756 802 self._in.src = data_buffer
757 803 self._in.size = len(data_buffer)
758 804 self._in.pos = 0
759 805
760 806 while self._in.pos < self._in.size:
761 807 zresult = lib.ZSTD_compressStream2(
762 808 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_continue
763 809 )
764 810
765 811 if self._in.pos == self._in.size:
766 812 self._in.src = ffi.NULL
767 813 self._in.size = 0
768 814 self._in.pos = 0
769 815
770 816 if lib.ZSTD_isError(zresult):
771 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
817 raise ZstdError(
818 "zstd compress error: %s" % _zstd_error(zresult)
819 )
772 820
773 821 if self._out.pos == self._out.size:
774 822 yield ffi.buffer(self._out.dst, self._out.pos)[:]
775 823 self._out.pos = 0
776 824
777 825 def flush(self):
778 826 if self._finished:
779 827 raise ZstdError("cannot call flush() after compression finished")
780 828
781 829 if self._in.src != ffi.NULL:
782 830 raise ZstdError(
783 "cannot call flush() before consuming output from " "previous operation"
831 "cannot call flush() before consuming output from "
832 "previous operation"
784 833 )
785 834
786 835 while True:
787 836 zresult = lib.ZSTD_compressStream2(
788 837 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush
789 838 )
790 839 if lib.ZSTD_isError(zresult):
791 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
840 raise ZstdError(
841 "zstd compress error: %s" % _zstd_error(zresult)
842 )
792 843
793 844 if self._out.pos:
794 845 yield ffi.buffer(self._out.dst, self._out.pos)[:]
795 846 self._out.pos = 0
796 847
797 848 if not zresult:
798 849 return
799 850
800 851 def finish(self):
801 852 if self._finished:
802 853 raise ZstdError("cannot call finish() after compression finished")
803 854
804 855 if self._in.src != ffi.NULL:
805 856 raise ZstdError(
806 857 "cannot call finish() before consuming output from "
807 858 "previous operation"
808 859 )
809 860
810 861 while True:
811 862 zresult = lib.ZSTD_compressStream2(
812 863 self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end
813 864 )
814 865 if lib.ZSTD_isError(zresult):
815 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
866 raise ZstdError(
867 "zstd compress error: %s" % _zstd_error(zresult)
868 )
816 869
817 870 if self._out.pos:
818 871 yield ffi.buffer(self._out.dst, self._out.pos)[:]
819 872 self._out.pos = 0
820 873
821 874 if not zresult:
822 875 self._finished = True
823 876 return
824 877
825 878
826 879 class ZstdCompressionReader(object):
827 880 def __init__(self, compressor, source, read_size):
828 881 self._compressor = compressor
829 882 self._source = source
830 883 self._read_size = read_size
831 884 self._entered = False
832 885 self._closed = False
833 886 self._bytes_compressed = 0
834 887 self._finished_input = False
835 888 self._finished_output = False
836 889
837 890 self._in_buffer = ffi.new("ZSTD_inBuffer *")
838 891 # Holds a ref so backing bytes in self._in_buffer stay alive.
839 892 self._source_buffer = None
840 893
841 894 def __enter__(self):
842 895 if self._entered:
843 896 raise ValueError("cannot __enter__ multiple times")
844 897
845 898 self._entered = True
846 899 return self
847 900
848 901 def __exit__(self, exc_type, exc_value, exc_tb):
849 902 self._entered = False
850 903 self._closed = True
851 904 self._source = None
852 905 self._compressor = None
853 906
854 907 return False
855 908
856 909 def readable(self):
857 910 return True
858 911
859 912 def writable(self):
860 913 return False
861 914
862 915 def seekable(self):
863 916 return False
864 917
865 918 def readline(self):
866 919 raise io.UnsupportedOperation()
867 920
868 921 def readlines(self):
869 922 raise io.UnsupportedOperation()
870 923
871 924 def write(self, data):
872 925 raise OSError("stream is not writable")
873 926
874 927 def writelines(self, ignored):
875 928 raise OSError("stream is not writable")
876 929
877 930 def isatty(self):
878 931 return False
879 932
880 933 def flush(self):
881 934 return None
882 935
883 936 def close(self):
884 937 self._closed = True
885 938 return None
886 939
887 940 @property
888 941 def closed(self):
889 942 return self._closed
890 943
891 944 def tell(self):
892 945 return self._bytes_compressed
893 946
894 947 def readall(self):
895 948 chunks = []
896 949
897 950 while True:
898 951 chunk = self.read(1048576)
899 952 if not chunk:
900 953 break
901 954
902 955 chunks.append(chunk)
903 956
904 957 return b"".join(chunks)
905 958
906 959 def __iter__(self):
907 960 raise io.UnsupportedOperation()
908 961
909 962 def __next__(self):
910 963 raise io.UnsupportedOperation()
911 964
912 965 next = __next__
913 966
914 967 def _read_input(self):
915 968 if self._finished_input:
916 969 return
917 970
918 971 if hasattr(self._source, "read"):
919 972 data = self._source.read(self._read_size)
920 973
921 974 if not data:
922 975 self._finished_input = True
923 976 return
924 977
925 978 self._source_buffer = ffi.from_buffer(data)
926 979 self._in_buffer.src = self._source_buffer
927 980 self._in_buffer.size = len(self._source_buffer)
928 981 self._in_buffer.pos = 0
929 982 else:
930 983 self._source_buffer = ffi.from_buffer(self._source)
931 984 self._in_buffer.src = self._source_buffer
932 985 self._in_buffer.size = len(self._source_buffer)
933 986 self._in_buffer.pos = 0
934 987
935 988 def _compress_into_buffer(self, out_buffer):
936 989 if self._in_buffer.pos >= self._in_buffer.size:
937 990 return
938 991
939 992 old_pos = out_buffer.pos
940 993
941 994 zresult = lib.ZSTD_compressStream2(
942 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_continue
995 self._compressor._cctx,
996 out_buffer,
997 self._in_buffer,
998 lib.ZSTD_e_continue,
943 999 )
944 1000
945 1001 self._bytes_compressed += out_buffer.pos - old_pos
946 1002
947 1003 if self._in_buffer.pos == self._in_buffer.size:
948 1004 self._in_buffer.src = ffi.NULL
949 1005 self._in_buffer.pos = 0
950 1006 self._in_buffer.size = 0
951 1007 self._source_buffer = None
952 1008
953 1009 if not hasattr(self._source, "read"):
954 1010 self._finished_input = True
955 1011
956 1012 if lib.ZSTD_isError(zresult):
957 1013 raise ZstdError("zstd compress error: %s", _zstd_error(zresult))
958 1014
959 1015 return out_buffer.pos and out_buffer.pos == out_buffer.size
960 1016
961 1017 def read(self, size=-1):
962 1018 if self._closed:
963 1019 raise ValueError("stream is closed")
964 1020
965 1021 if size < -1:
966 1022 raise ValueError("cannot read negative amounts less than -1")
967 1023
968 1024 if size == -1:
969 1025 return self.readall()
970 1026
971 1027 if self._finished_output or size == 0:
972 1028 return b""
973 1029
974 1030 # Need a dedicated ref to dest buffer otherwise it gets collected.
975 1031 dst_buffer = ffi.new("char[]", size)
976 1032 out_buffer = ffi.new("ZSTD_outBuffer *")
977 1033 out_buffer.dst = dst_buffer
978 1034 out_buffer.size = size
979 1035 out_buffer.pos = 0
980 1036
981 1037 if self._compress_into_buffer(out_buffer):
982 1038 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
983 1039
984 1040 while not self._finished_input:
985 1041 self._read_input()
986 1042
987 1043 if self._compress_into_buffer(out_buffer):
988 1044 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
989 1045
990 1046 # EOF
991 1047 old_pos = out_buffer.pos
992 1048
993 1049 zresult = lib.ZSTD_compressStream2(
994 1050 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
995 1051 )
996 1052
997 1053 self._bytes_compressed += out_buffer.pos - old_pos
998 1054
999 1055 if lib.ZSTD_isError(zresult):
1000 raise ZstdError("error ending compression stream: %s", _zstd_error(zresult))
1056 raise ZstdError(
1057 "error ending compression stream: %s", _zstd_error(zresult)
1058 )
1001 1059
1002 1060 if zresult == 0:
1003 1061 self._finished_output = True
1004 1062
1005 1063 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1006 1064
1007 1065 def read1(self, size=-1):
1008 1066 if self._closed:
1009 1067 raise ValueError("stream is closed")
1010 1068
1011 1069 if size < -1:
1012 1070 raise ValueError("cannot read negative amounts less than -1")
1013 1071
1014 1072 if self._finished_output or size == 0:
1015 1073 return b""
1016 1074
1017 1075 # -1 returns arbitrary number of bytes.
1018 1076 if size == -1:
1019 1077 size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1020 1078
1021 1079 dst_buffer = ffi.new("char[]", size)
1022 1080 out_buffer = ffi.new("ZSTD_outBuffer *")
1023 1081 out_buffer.dst = dst_buffer
1024 1082 out_buffer.size = size
1025 1083 out_buffer.pos = 0
1026 1084
1027 1085 # read1() dictates that we can perform at most 1 call to the
1028 1086 # underlying stream to get input. However, we can't satisfy this
1029 1087 # restriction with compression because not all input generates output.
1030 1088 # It is possible to perform a block flush in order to ensure output.
1031 1089 # But this may not be desirable behavior. So we allow multiple read()
1032 1090 # to the underlying stream. But unlike read(), we stop once we have
1033 1091 # any output.
1034 1092
1035 1093 self._compress_into_buffer(out_buffer)
1036 1094 if out_buffer.pos:
1037 1095 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1038 1096
1039 1097 while not self._finished_input:
1040 1098 self._read_input()
1041 1099
1042 1100 # If we've filled the output buffer, return immediately.
1043 1101 if self._compress_into_buffer(out_buffer):
1044 1102 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1045 1103
1046 1104 # If we've populated the output buffer and we're not at EOF,
1047 1105 # also return, as we've satisfied the read1() limits.
1048 1106 if out_buffer.pos and not self._finished_input:
1049 1107 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1050 1108
1051 1109 # Else if we're at EOS and we have room left in the buffer,
1052 1110 # fall through to below and try to add more data to the output.
1053 1111
1054 1112 # EOF.
1055 1113 old_pos = out_buffer.pos
1056 1114
1057 1115 zresult = lib.ZSTD_compressStream2(
1058 1116 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1059 1117 )
1060 1118
1061 1119 self._bytes_compressed += out_buffer.pos - old_pos
1062 1120
1063 1121 if lib.ZSTD_isError(zresult):
1064 1122 raise ZstdError(
1065 1123 "error ending compression stream: %s" % _zstd_error(zresult)
1066 1124 )
1067 1125
1068 1126 if zresult == 0:
1069 1127 self._finished_output = True
1070 1128
1071 1129 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1072 1130
1073 1131 def readinto(self, b):
1074 1132 if self._closed:
1075 1133 raise ValueError("stream is closed")
1076 1134
1077 1135 if self._finished_output:
1078 1136 return 0
1079 1137
1080 1138 # TODO use writable=True once we require CFFI >= 1.12.
1081 1139 dest_buffer = ffi.from_buffer(b)
1082 1140 ffi.memmove(b, b"", 0)
1083 1141 out_buffer = ffi.new("ZSTD_outBuffer *")
1084 1142 out_buffer.dst = dest_buffer
1085 1143 out_buffer.size = len(dest_buffer)
1086 1144 out_buffer.pos = 0
1087 1145
1088 1146 if self._compress_into_buffer(out_buffer):
1089 1147 return out_buffer.pos
1090 1148
1091 1149 while not self._finished_input:
1092 1150 self._read_input()
1093 1151 if self._compress_into_buffer(out_buffer):
1094 1152 return out_buffer.pos
1095 1153
1096 1154 # EOF.
1097 1155 old_pos = out_buffer.pos
1098 1156 zresult = lib.ZSTD_compressStream2(
1099 1157 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1100 1158 )
1101 1159
1102 1160 self._bytes_compressed += out_buffer.pos - old_pos
1103 1161
1104 1162 if lib.ZSTD_isError(zresult):
1105 raise ZstdError("error ending compression stream: %s", _zstd_error(zresult))
1163 raise ZstdError(
1164 "error ending compression stream: %s", _zstd_error(zresult)
1165 )
1106 1166
1107 1167 if zresult == 0:
1108 1168 self._finished_output = True
1109 1169
1110 1170 return out_buffer.pos
1111 1171
1112 1172 def readinto1(self, b):
1113 1173 if self._closed:
1114 1174 raise ValueError("stream is closed")
1115 1175
1116 1176 if self._finished_output:
1117 1177 return 0
1118 1178
1119 1179 # TODO use writable=True once we require CFFI >= 1.12.
1120 1180 dest_buffer = ffi.from_buffer(b)
1121 1181 ffi.memmove(b, b"", 0)
1122 1182
1123 1183 out_buffer = ffi.new("ZSTD_outBuffer *")
1124 1184 out_buffer.dst = dest_buffer
1125 1185 out_buffer.size = len(dest_buffer)
1126 1186 out_buffer.pos = 0
1127 1187
1128 1188 self._compress_into_buffer(out_buffer)
1129 1189 if out_buffer.pos:
1130 1190 return out_buffer.pos
1131 1191
1132 1192 while not self._finished_input:
1133 1193 self._read_input()
1134 1194
1135 1195 if self._compress_into_buffer(out_buffer):
1136 1196 return out_buffer.pos
1137 1197
1138 1198 if out_buffer.pos and not self._finished_input:
1139 1199 return out_buffer.pos
1140 1200
1141 1201 # EOF.
1142 1202 old_pos = out_buffer.pos
1143 1203
1144 1204 zresult = lib.ZSTD_compressStream2(
1145 1205 self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end
1146 1206 )
1147 1207
1148 1208 self._bytes_compressed += out_buffer.pos - old_pos
1149 1209
1150 1210 if lib.ZSTD_isError(zresult):
1151 1211 raise ZstdError(
1152 1212 "error ending compression stream: %s" % _zstd_error(zresult)
1153 1213 )
1154 1214
1155 1215 if zresult == 0:
1156 1216 self._finished_output = True
1157 1217
1158 1218 return out_buffer.pos
1159 1219
1160 1220
1161 1221 class ZstdCompressor(object):
1162 1222 def __init__(
1163 1223 self,
1164 1224 level=3,
1165 1225 dict_data=None,
1166 1226 compression_params=None,
1167 1227 write_checksum=None,
1168 1228 write_content_size=None,
1169 1229 write_dict_id=None,
1170 1230 threads=0,
1171 1231 ):
1172 1232 if level > lib.ZSTD_maxCLevel():
1173 raise ValueError("level must be less than %d" % lib.ZSTD_maxCLevel())
1233 raise ValueError(
1234 "level must be less than %d" % lib.ZSTD_maxCLevel()
1235 )
1174 1236
1175 1237 if threads < 0:
1176 1238 threads = _cpu_count()
1177 1239
1178 1240 if compression_params and write_checksum is not None:
1179 raise ValueError("cannot define compression_params and " "write_checksum")
1241 raise ValueError(
1242 "cannot define compression_params and " "write_checksum"
1243 )
1180 1244
1181 1245 if compression_params and write_content_size is not None:
1182 1246 raise ValueError(
1183 1247 "cannot define compression_params and " "write_content_size"
1184 1248 )
1185 1249
1186 1250 if compression_params and write_dict_id is not None:
1187 raise ValueError("cannot define compression_params and " "write_dict_id")
1251 raise ValueError(
1252 "cannot define compression_params and " "write_dict_id"
1253 )
1188 1254
1189 1255 if compression_params and threads:
1190 1256 raise ValueError("cannot define compression_params and threads")
1191 1257
1192 1258 if compression_params:
1193 1259 self._params = _make_cctx_params(compression_params)
1194 1260 else:
1195 1261 if write_dict_id is None:
1196 1262 write_dict_id = True
1197 1263
1198 1264 params = lib.ZSTD_createCCtxParams()
1199 1265 if params == ffi.NULL:
1200 1266 raise MemoryError()
1201 1267
1202 1268 self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
1203 1269
1204 _set_compression_parameter(self._params, lib.ZSTD_c_compressionLevel, level)
1270 _set_compression_parameter(
1271 self._params, lib.ZSTD_c_compressionLevel, level
1272 )
1205 1273
1206 1274 _set_compression_parameter(
1207 1275 self._params,
1208 1276 lib.ZSTD_c_contentSizeFlag,
1209 1277 write_content_size if write_content_size is not None else 1,
1210 1278 )
1211 1279
1212 1280 _set_compression_parameter(
1213 self._params, lib.ZSTD_c_checksumFlag, 1 if write_checksum else 0
1281 self._params,
1282 lib.ZSTD_c_checksumFlag,
1283 1 if write_checksum else 0,
1214 1284 )
1215 1285
1216 1286 _set_compression_parameter(
1217 1287 self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0
1218 1288 )
1219 1289
1220 1290 if threads:
1221 _set_compression_parameter(self._params, lib.ZSTD_c_nbWorkers, threads)
1291 _set_compression_parameter(
1292 self._params, lib.ZSTD_c_nbWorkers, threads
1293 )
1222 1294
1223 1295 cctx = lib.ZSTD_createCCtx()
1224 1296 if cctx == ffi.NULL:
1225 1297 raise MemoryError()
1226 1298
1227 1299 self._cctx = cctx
1228 1300 self._dict_data = dict_data
1229 1301
1230 1302 # We defer setting up garbage collection until after calling
1231 1303 # _setup_cctx() to ensure the memory size estimate is more accurate.
1232 1304 try:
1233 1305 self._setup_cctx()
1234 1306 finally:
1235 1307 self._cctx = ffi.gc(
1236 1308 cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx)
1237 1309 )
1238 1310
1239 1311 def _setup_cctx(self):
1240 zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx, self._params)
1312 zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(
1313 self._cctx, self._params
1314 )
1241 1315 if lib.ZSTD_isError(zresult):
1242 1316 raise ZstdError(
1243 "could not set compression parameters: %s" % _zstd_error(zresult)
1317 "could not set compression parameters: %s"
1318 % _zstd_error(zresult)
1244 1319 )
1245 1320
1246 1321 dict_data = self._dict_data
1247 1322
1248 1323 if dict_data:
1249 1324 if dict_data._cdict:
1250 1325 zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict)
1251 1326 else:
1252 1327 zresult = lib.ZSTD_CCtx_loadDictionary_advanced(
1253 1328 self._cctx,
1254 1329 dict_data.as_bytes(),
1255 1330 len(dict_data),
1256 1331 lib.ZSTD_dlm_byRef,
1257 1332 dict_data._dict_type,
1258 1333 )
1259 1334
1260 1335 if lib.ZSTD_isError(zresult):
1261 1336 raise ZstdError(
1262 "could not load compression dictionary: %s" % _zstd_error(zresult)
1337 "could not load compression dictionary: %s"
1338 % _zstd_error(zresult)
1263 1339 )
1264 1340
1265 1341 def memory_size(self):
1266 1342 return lib.ZSTD_sizeof_CCtx(self._cctx)
1267 1343
1268 1344 def compress(self, data):
1269 1345 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1270 1346
1271 1347 data_buffer = ffi.from_buffer(data)
1272 1348
1273 1349 dest_size = lib.ZSTD_compressBound(len(data_buffer))
1274 1350 out = new_nonzero("char[]", dest_size)
1275 1351
1276 1352 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer))
1277 1353 if lib.ZSTD_isError(zresult):
1278 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1354 raise ZstdError(
1355 "error setting source size: %s" % _zstd_error(zresult)
1356 )
1279 1357
1280 1358 out_buffer = ffi.new("ZSTD_outBuffer *")
1281 1359 in_buffer = ffi.new("ZSTD_inBuffer *")
1282 1360
1283 1361 out_buffer.dst = out
1284 1362 out_buffer.size = dest_size
1285 1363 out_buffer.pos = 0
1286 1364
1287 1365 in_buffer.src = data_buffer
1288 1366 in_buffer.size = len(data_buffer)
1289 1367 in_buffer.pos = 0
1290 1368
1291 1369 zresult = lib.ZSTD_compressStream2(
1292 1370 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1293 1371 )
1294 1372
1295 1373 if lib.ZSTD_isError(zresult):
1296 1374 raise ZstdError("cannot compress: %s" % _zstd_error(zresult))
1297 1375 elif zresult:
1298 1376 raise ZstdError("unexpected partial frame flush")
1299 1377
1300 1378 return ffi.buffer(out, out_buffer.pos)[:]
1301 1379
1302 1380 def compressobj(self, size=-1):
1303 1381 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1304 1382
1305 1383 if size < 0:
1306 1384 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1307 1385
1308 1386 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1309 1387 if lib.ZSTD_isError(zresult):
1310 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1388 raise ZstdError(
1389 "error setting source size: %s" % _zstd_error(zresult)
1390 )
1311 1391
1312 1392 cobj = ZstdCompressionObj()
1313 1393 cobj._out = ffi.new("ZSTD_outBuffer *")
1314 cobj._dst_buffer = ffi.new("char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
1394 cobj._dst_buffer = ffi.new(
1395 "char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1396 )
1315 1397 cobj._out.dst = cobj._dst_buffer
1316 1398 cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
1317 1399 cobj._out.pos = 0
1318 1400 cobj._compressor = self
1319 1401 cobj._finished = False
1320 1402
1321 1403 return cobj
1322 1404
1323 1405 def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
1324 1406 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1325 1407
1326 1408 if size < 0:
1327 1409 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1328 1410
1329 1411 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1330 1412 if lib.ZSTD_isError(zresult):
1331 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1413 raise ZstdError(
1414 "error setting source size: %s" % _zstd_error(zresult)
1415 )
1332 1416
1333 1417 return ZstdCompressionChunker(self, chunk_size=chunk_size)
1334 1418
1335 1419 def copy_stream(
1336 1420 self,
1337 1421 ifh,
1338 1422 ofh,
1339 1423 size=-1,
1340 1424 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
1341 1425 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1342 1426 ):
1343 1427
1344 1428 if not hasattr(ifh, "read"):
1345 1429 raise ValueError("first argument must have a read() method")
1346 1430 if not hasattr(ofh, "write"):
1347 1431 raise ValueError("second argument must have a write() method")
1348 1432
1349 1433 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1350 1434
1351 1435 if size < 0:
1352 1436 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1353 1437
1354 1438 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1355 1439 if lib.ZSTD_isError(zresult):
1356 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1440 raise ZstdError(
1441 "error setting source size: %s" % _zstd_error(zresult)
1442 )
1357 1443
1358 1444 in_buffer = ffi.new("ZSTD_inBuffer *")
1359 1445 out_buffer = ffi.new("ZSTD_outBuffer *")
1360 1446
1361 1447 dst_buffer = ffi.new("char[]", write_size)
1362 1448 out_buffer.dst = dst_buffer
1363 1449 out_buffer.size = write_size
1364 1450 out_buffer.pos = 0
1365 1451
1366 1452 total_read, total_write = 0, 0
1367 1453
1368 1454 while True:
1369 1455 data = ifh.read(read_size)
1370 1456 if not data:
1371 1457 break
1372 1458
1373 1459 data_buffer = ffi.from_buffer(data)
1374 1460 total_read += len(data_buffer)
1375 1461 in_buffer.src = data_buffer
1376 1462 in_buffer.size = len(data_buffer)
1377 1463 in_buffer.pos = 0
1378 1464
1379 1465 while in_buffer.pos < in_buffer.size:
1380 1466 zresult = lib.ZSTD_compressStream2(
1381 1467 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
1382 1468 )
1383 1469 if lib.ZSTD_isError(zresult):
1384 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
1470 raise ZstdError(
1471 "zstd compress error: %s" % _zstd_error(zresult)
1472 )
1385 1473
1386 1474 if out_buffer.pos:
1387 1475 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
1388 1476 total_write += out_buffer.pos
1389 1477 out_buffer.pos = 0
1390 1478
1391 1479 # We've finished reading. Flush the compressor.
1392 1480 while True:
1393 1481 zresult = lib.ZSTD_compressStream2(
1394 1482 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1395 1483 )
1396 1484 if lib.ZSTD_isError(zresult):
1397 1485 raise ZstdError(
1398 1486 "error ending compression stream: %s" % _zstd_error(zresult)
1399 1487 )
1400 1488
1401 1489 if out_buffer.pos:
1402 1490 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
1403 1491 total_write += out_buffer.pos
1404 1492 out_buffer.pos = 0
1405 1493
1406 1494 if zresult == 0:
1407 1495 break
1408 1496
1409 1497 return total_read, total_write
1410 1498
1411 1499 def stream_reader(
1412 1500 self, source, size=-1, read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE
1413 1501 ):
1414 1502 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1415 1503
1416 1504 try:
1417 1505 size = len(source)
1418 1506 except Exception:
1419 1507 pass
1420 1508
1421 1509 if size < 0:
1422 1510 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1423 1511
1424 1512 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1425 1513 if lib.ZSTD_isError(zresult):
1426 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1514 raise ZstdError(
1515 "error setting source size: %s" % _zstd_error(zresult)
1516 )
1427 1517
1428 1518 return ZstdCompressionReader(self, source, read_size)
1429 1519
1430 1520 def stream_writer(
1431 1521 self,
1432 1522 writer,
1433 1523 size=-1,
1434 1524 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1435 1525 write_return_read=False,
1436 1526 ):
1437 1527
1438 1528 if not hasattr(writer, "write"):
1439 1529 raise ValueError("must pass an object with a write() method")
1440 1530
1441 1531 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1442 1532
1443 1533 if size < 0:
1444 1534 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1445 1535
1446 return ZstdCompressionWriter(self, writer, size, write_size, write_return_read)
1536 return ZstdCompressionWriter(
1537 self, writer, size, write_size, write_return_read
1538 )
1447 1539
1448 1540 write_to = stream_writer
1449 1541
1450 1542 def read_to_iter(
1451 1543 self,
1452 1544 reader,
1453 1545 size=-1,
1454 1546 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
1455 1547 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
1456 1548 ):
1457 1549 if hasattr(reader, "read"):
1458 1550 have_read = True
1459 1551 elif hasattr(reader, "__getitem__"):
1460 1552 have_read = False
1461 1553 buffer_offset = 0
1462 1554 size = len(reader)
1463 1555 else:
1464 1556 raise ValueError(
1465 1557 "must pass an object with a read() method or "
1466 1558 "conforms to buffer protocol"
1467 1559 )
1468 1560
1469 1561 lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
1470 1562
1471 1563 if size < 0:
1472 1564 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
1473 1565
1474 1566 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1475 1567 if lib.ZSTD_isError(zresult):
1476 raise ZstdError("error setting source size: %s" % _zstd_error(zresult))
1568 raise ZstdError(
1569 "error setting source size: %s" % _zstd_error(zresult)
1570 )
1477 1571
1478 1572 in_buffer = ffi.new("ZSTD_inBuffer *")
1479 1573 out_buffer = ffi.new("ZSTD_outBuffer *")
1480 1574
1481 1575 in_buffer.src = ffi.NULL
1482 1576 in_buffer.size = 0
1483 1577 in_buffer.pos = 0
1484 1578
1485 1579 dst_buffer = ffi.new("char[]", write_size)
1486 1580 out_buffer.dst = dst_buffer
1487 1581 out_buffer.size = write_size
1488 1582 out_buffer.pos = 0
1489 1583
1490 1584 while True:
1491 1585 # We should never have output data sitting around after a previous
1492 1586 # iteration.
1493 1587 assert out_buffer.pos == 0
1494 1588
1495 1589 # Collect input data.
1496 1590 if have_read:
1497 1591 read_result = reader.read(read_size)
1498 1592 else:
1499 1593 remaining = len(reader) - buffer_offset
1500 1594 slice_size = min(remaining, read_size)
1501 1595 read_result = reader[buffer_offset : buffer_offset + slice_size]
1502 1596 buffer_offset += slice_size
1503 1597
1504 1598 # No new input data. Break out of the read loop.
1505 1599 if not read_result:
1506 1600 break
1507 1601
1508 1602 # Feed all read data into the compressor and emit output until
1509 1603 # exhausted.
1510 1604 read_buffer = ffi.from_buffer(read_result)
1511 1605 in_buffer.src = read_buffer
1512 1606 in_buffer.size = len(read_buffer)
1513 1607 in_buffer.pos = 0
1514 1608
1515 1609 while in_buffer.pos < in_buffer.size:
1516 1610 zresult = lib.ZSTD_compressStream2(
1517 1611 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue
1518 1612 )
1519 1613 if lib.ZSTD_isError(zresult):
1520 raise ZstdError("zstd compress error: %s" % _zstd_error(zresult))
1614 raise ZstdError(
1615 "zstd compress error: %s" % _zstd_error(zresult)
1616 )
1521 1617
1522 1618 if out_buffer.pos:
1523 1619 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1524 1620 out_buffer.pos = 0
1525 1621 yield data
1526 1622
1527 1623 assert out_buffer.pos == 0
1528 1624
1529 1625 # And repeat the loop to collect more data.
1530 1626 continue
1531 1627
1532 1628 # If we get here, input is exhausted. End the stream and emit what
1533 1629 # remains.
1534 1630 while True:
1535 1631 assert out_buffer.pos == 0
1536 1632 zresult = lib.ZSTD_compressStream2(
1537 1633 self._cctx, out_buffer, in_buffer, lib.ZSTD_e_end
1538 1634 )
1539 1635 if lib.ZSTD_isError(zresult):
1540 1636 raise ZstdError(
1541 1637 "error ending compression stream: %s" % _zstd_error(zresult)
1542 1638 )
1543 1639
1544 1640 if out_buffer.pos:
1545 1641 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1546 1642 out_buffer.pos = 0
1547 1643 yield data
1548 1644
1549 1645 if zresult == 0:
1550 1646 break
1551 1647
1552 1648 read_from = read_to_iter
1553 1649
1554 1650 def frame_progression(self):
1555 1651 progression = lib.ZSTD_getFrameProgression(self._cctx)
1556 1652
1557 1653 return progression.ingested, progression.consumed, progression.produced
1558 1654
1559 1655
1560 1656 class FrameParameters(object):
1561 1657 def __init__(self, fparams):
1562 1658 self.content_size = fparams.frameContentSize
1563 1659 self.window_size = fparams.windowSize
1564 1660 self.dict_id = fparams.dictID
1565 1661 self.has_checksum = bool(fparams.checksumFlag)
1566 1662
1567 1663
1568 1664 def frame_content_size(data):
1569 1665 data_buffer = ffi.from_buffer(data)
1570 1666
1571 1667 size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
1572 1668
1573 1669 if size == lib.ZSTD_CONTENTSIZE_ERROR:
1574 1670 raise ZstdError("error when determining content size")
1575 1671 elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
1576 1672 return -1
1577 1673 else:
1578 1674 return size
1579 1675
1580 1676
1581 1677 def frame_header_size(data):
1582 1678 data_buffer = ffi.from_buffer(data)
1583 1679
1584 1680 zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer))
1585 1681 if lib.ZSTD_isError(zresult):
1586 1682 raise ZstdError(
1587 1683 "could not determine frame header size: %s" % _zstd_error(zresult)
1588 1684 )
1589 1685
1590 1686 return zresult
1591 1687
1592 1688
1593 1689 def get_frame_parameters(data):
1594 1690 params = ffi.new("ZSTD_frameHeader *")
1595 1691
1596 1692 data_buffer = ffi.from_buffer(data)
1597 1693 zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer))
1598 1694 if lib.ZSTD_isError(zresult):
1599 raise ZstdError("cannot get frame parameters: %s" % _zstd_error(zresult))
1695 raise ZstdError(
1696 "cannot get frame parameters: %s" % _zstd_error(zresult)
1697 )
1600 1698
1601 1699 if zresult:
1602 raise ZstdError("not enough data for frame parameters; need %d bytes" % zresult)
1700 raise ZstdError(
1701 "not enough data for frame parameters; need %d bytes" % zresult
1702 )
1603 1703
1604 1704 return FrameParameters(params[0])
1605 1705
1606 1706
1607 1707 class ZstdCompressionDict(object):
1608 1708 def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0):
1609 1709 assert isinstance(data, bytes_type)
1610 1710 self._data = data
1611 1711 self.k = k
1612 1712 self.d = d
1613 1713
1614 if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT, DICT_TYPE_FULLDICT):
1714 if dict_type not in (
1715 DICT_TYPE_AUTO,
1716 DICT_TYPE_RAWCONTENT,
1717 DICT_TYPE_FULLDICT,
1718 ):
1615 1719 raise ValueError(
1616 "invalid dictionary load mode: %d; must use " "DICT_TYPE_* constants"
1720 "invalid dictionary load mode: %d; must use "
1721 "DICT_TYPE_* constants"
1617 1722 )
1618 1723
1619 1724 self._dict_type = dict_type
1620 1725 self._cdict = None
1621 1726
1622 1727 def __len__(self):
1623 1728 return len(self._data)
1624 1729
1625 1730 def dict_id(self):
1626 1731 return int_type(lib.ZDICT_getDictID(self._data, len(self._data)))
1627 1732
1628 1733 def as_bytes(self):
1629 1734 return self._data
1630 1735
1631 1736 def precompute_compress(self, level=0, compression_params=None):
1632 1737 if level and compression_params:
1633 raise ValueError("must only specify one of level or " "compression_params")
1738 raise ValueError(
1739 "must only specify one of level or " "compression_params"
1740 )
1634 1741
1635 1742 if not level and not compression_params:
1636 1743 raise ValueError("must specify one of level or compression_params")
1637 1744
1638 1745 if level:
1639 1746 cparams = lib.ZSTD_getCParams(level, 0, len(self._data))
1640 1747 else:
1641 1748 cparams = ffi.new("ZSTD_compressionParameters")
1642 1749 cparams.chainLog = compression_params.chain_log
1643 1750 cparams.hashLog = compression_params.hash_log
1644 1751 cparams.minMatch = compression_params.min_match
1645 1752 cparams.searchLog = compression_params.search_log
1646 1753 cparams.strategy = compression_params.compression_strategy
1647 1754 cparams.targetLength = compression_params.target_length
1648 1755 cparams.windowLog = compression_params.window_log
1649 1756
1650 1757 cdict = lib.ZSTD_createCDict_advanced(
1651 1758 self._data,
1652 1759 len(self._data),
1653 1760 lib.ZSTD_dlm_byRef,
1654 1761 self._dict_type,
1655 1762 cparams,
1656 1763 lib.ZSTD_defaultCMem,
1657 1764 )
1658 1765 if cdict == ffi.NULL:
1659 1766 raise ZstdError("unable to precompute dictionary")
1660 1767
1661 1768 self._cdict = ffi.gc(
1662 1769 cdict, lib.ZSTD_freeCDict, size=lib.ZSTD_sizeof_CDict(cdict)
1663 1770 )
1664 1771
1665 1772 @property
1666 1773 def _ddict(self):
1667 1774 ddict = lib.ZSTD_createDDict_advanced(
1668 1775 self._data,
1669 1776 len(self._data),
1670 1777 lib.ZSTD_dlm_byRef,
1671 1778 self._dict_type,
1672 1779 lib.ZSTD_defaultCMem,
1673 1780 )
1674 1781
1675 1782 if ddict == ffi.NULL:
1676 1783 raise ZstdError("could not create decompression dict")
1677 1784
1678 ddict = ffi.gc(ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict))
1785 ddict = ffi.gc(
1786 ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict)
1787 )
1679 1788 self.__dict__["_ddict"] = ddict
1680 1789
1681 1790 return ddict
1682 1791
1683 1792
1684 1793 def train_dictionary(
1685 1794 dict_size,
1686 1795 samples,
1687 1796 k=0,
1688 1797 d=0,
1689 1798 notifications=0,
1690 1799 dict_id=0,
1691 1800 level=0,
1692 1801 steps=0,
1693 1802 threads=0,
1694 1803 ):
1695 1804 if not isinstance(samples, list):
1696 1805 raise TypeError("samples must be a list")
1697 1806
1698 1807 if threads < 0:
1699 1808 threads = _cpu_count()
1700 1809
1701 1810 total_size = sum(map(len, samples))
1702 1811
1703 1812 samples_buffer = new_nonzero("char[]", total_size)
1704 1813 sample_sizes = new_nonzero("size_t[]", len(samples))
1705 1814
1706 1815 offset = 0
1707 1816 for i, sample in enumerate(samples):
1708 1817 if not isinstance(sample, bytes_type):
1709 1818 raise ValueError("samples must be bytes")
1710 1819
1711 1820 l = len(sample)
1712 1821 ffi.memmove(samples_buffer + offset, sample, l)
1713 1822 offset += l
1714 1823 sample_sizes[i] = l
1715 1824
1716 1825 dict_data = new_nonzero("char[]", dict_size)
1717 1826
1718 1827 dparams = ffi.new("ZDICT_cover_params_t *")[0]
1719 1828 dparams.k = k
1720 1829 dparams.d = d
1721 1830 dparams.steps = steps
1722 1831 dparams.nbThreads = threads
1723 1832 dparams.zParams.notificationLevel = notifications
1724 1833 dparams.zParams.dictID = dict_id
1725 1834 dparams.zParams.compressionLevel = level
1726 1835
1727 1836 if (
1728 1837 not dparams.k
1729 1838 and not dparams.d
1730 1839 and not dparams.steps
1731 1840 and not dparams.nbThreads
1732 1841 and not dparams.zParams.notificationLevel
1733 1842 and not dparams.zParams.dictID
1734 1843 and not dparams.zParams.compressionLevel
1735 1844 ):
1736 1845 zresult = lib.ZDICT_trainFromBuffer(
1737 1846 ffi.addressof(dict_data),
1738 1847 dict_size,
1739 1848 ffi.addressof(samples_buffer),
1740 1849 ffi.addressof(sample_sizes, 0),
1741 1850 len(samples),
1742 1851 )
1743 1852 elif dparams.steps or dparams.nbThreads:
1744 1853 zresult = lib.ZDICT_optimizeTrainFromBuffer_cover(
1745 1854 ffi.addressof(dict_data),
1746 1855 dict_size,
1747 1856 ffi.addressof(samples_buffer),
1748 1857 ffi.addressof(sample_sizes, 0),
1749 1858 len(samples),
1750 1859 ffi.addressof(dparams),
1751 1860 )
1752 1861 else:
1753 1862 zresult = lib.ZDICT_trainFromBuffer_cover(
1754 1863 ffi.addressof(dict_data),
1755 1864 dict_size,
1756 1865 ffi.addressof(samples_buffer),
1757 1866 ffi.addressof(sample_sizes, 0),
1758 1867 len(samples),
1759 1868 dparams,
1760 1869 )
1761 1870
1762 1871 if lib.ZDICT_isError(zresult):
1763 1872 msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode("utf-8")
1764 1873 raise ZstdError("cannot train dict: %s" % msg)
1765 1874
1766 1875 return ZstdCompressionDict(
1767 1876 ffi.buffer(dict_data, zresult)[:],
1768 1877 dict_type=DICT_TYPE_FULLDICT,
1769 1878 k=dparams.k,
1770 1879 d=dparams.d,
1771 1880 )
1772 1881
1773 1882
1774 1883 class ZstdDecompressionObj(object):
1775 1884 def __init__(self, decompressor, write_size):
1776 1885 self._decompressor = decompressor
1777 1886 self._write_size = write_size
1778 1887 self._finished = False
1779 1888
1780 1889 def decompress(self, data):
1781 1890 if self._finished:
1782 1891 raise ZstdError("cannot use a decompressobj multiple times")
1783 1892
1784 1893 in_buffer = ffi.new("ZSTD_inBuffer *")
1785 1894 out_buffer = ffi.new("ZSTD_outBuffer *")
1786 1895
1787 1896 data_buffer = ffi.from_buffer(data)
1788 1897
1789 1898 if len(data_buffer) == 0:
1790 1899 return b""
1791 1900
1792 1901 in_buffer.src = data_buffer
1793 1902 in_buffer.size = len(data_buffer)
1794 1903 in_buffer.pos = 0
1795 1904
1796 1905 dst_buffer = ffi.new("char[]", self._write_size)
1797 1906 out_buffer.dst = dst_buffer
1798 1907 out_buffer.size = len(dst_buffer)
1799 1908 out_buffer.pos = 0
1800 1909
1801 1910 chunks = []
1802 1911
1803 1912 while True:
1804 1913 zresult = lib.ZSTD_decompressStream(
1805 1914 self._decompressor._dctx, out_buffer, in_buffer
1806 1915 )
1807 1916 if lib.ZSTD_isError(zresult):
1808 raise ZstdError("zstd decompressor error: %s" % _zstd_error(zresult))
1917 raise ZstdError(
1918 "zstd decompressor error: %s" % _zstd_error(zresult)
1919 )
1809 1920
1810 1921 if zresult == 0:
1811 1922 self._finished = True
1812 1923 self._decompressor = None
1813 1924
1814 1925 if out_buffer.pos:
1815 1926 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
1816 1927
1817 1928 if zresult == 0 or (
1818 1929 in_buffer.pos == in_buffer.size and out_buffer.pos == 0
1819 1930 ):
1820 1931 break
1821 1932
1822 1933 out_buffer.pos = 0
1823 1934
1824 1935 return b"".join(chunks)
1825 1936
1826 1937 def flush(self, length=0):
1827 1938 pass
1828 1939
1829 1940
1830 1941 class ZstdDecompressionReader(object):
1831 1942 def __init__(self, decompressor, source, read_size, read_across_frames):
1832 1943 self._decompressor = decompressor
1833 1944 self._source = source
1834 1945 self._read_size = read_size
1835 1946 self._read_across_frames = bool(read_across_frames)
1836 1947 self._entered = False
1837 1948 self._closed = False
1838 1949 self._bytes_decompressed = 0
1839 1950 self._finished_input = False
1840 1951 self._finished_output = False
1841 1952 self._in_buffer = ffi.new("ZSTD_inBuffer *")
1842 1953 # Holds a ref to self._in_buffer.src.
1843 1954 self._source_buffer = None
1844 1955
1845 1956 def __enter__(self):
1846 1957 if self._entered:
1847 1958 raise ValueError("cannot __enter__ multiple times")
1848 1959
1849 1960 self._entered = True
1850 1961 return self
1851 1962
1852 1963 def __exit__(self, exc_type, exc_value, exc_tb):
1853 1964 self._entered = False
1854 1965 self._closed = True
1855 1966 self._source = None
1856 1967 self._decompressor = None
1857 1968
1858 1969 return False
1859 1970
1860 1971 def readable(self):
1861 1972 return True
1862 1973
1863 1974 def writable(self):
1864 1975 return False
1865 1976
1866 1977 def seekable(self):
1867 1978 return True
1868 1979
1869 1980 def readline(self):
1870 1981 raise io.UnsupportedOperation()
1871 1982
1872 1983 def readlines(self):
1873 1984 raise io.UnsupportedOperation()
1874 1985
1875 1986 def write(self, data):
1876 1987 raise io.UnsupportedOperation()
1877 1988
1878 1989 def writelines(self, lines):
1879 1990 raise io.UnsupportedOperation()
1880 1991
1881 1992 def isatty(self):
1882 1993 return False
1883 1994
1884 1995 def flush(self):
1885 1996 return None
1886 1997
1887 1998 def close(self):
1888 1999 self._closed = True
1889 2000 return None
1890 2001
1891 2002 @property
1892 2003 def closed(self):
1893 2004 return self._closed
1894 2005
1895 2006 def tell(self):
1896 2007 return self._bytes_decompressed
1897 2008
1898 2009 def readall(self):
1899 2010 chunks = []
1900 2011
1901 2012 while True:
1902 2013 chunk = self.read(1048576)
1903 2014 if not chunk:
1904 2015 break
1905 2016
1906 2017 chunks.append(chunk)
1907 2018
1908 2019 return b"".join(chunks)
1909 2020
1910 2021 def __iter__(self):
1911 2022 raise io.UnsupportedOperation()
1912 2023
1913 2024 def __next__(self):
1914 2025 raise io.UnsupportedOperation()
1915 2026
1916 2027 next = __next__
1917 2028
1918 2029 def _read_input(self):
1919 2030 # We have data left over in the input buffer. Use it.
1920 2031 if self._in_buffer.pos < self._in_buffer.size:
1921 2032 return
1922 2033
1923 2034 # All input data exhausted. Nothing to do.
1924 2035 if self._finished_input:
1925 2036 return
1926 2037
1927 2038 # Else populate the input buffer from our source.
1928 2039 if hasattr(self._source, "read"):
1929 2040 data = self._source.read(self._read_size)
1930 2041
1931 2042 if not data:
1932 2043 self._finished_input = True
1933 2044 return
1934 2045
1935 2046 self._source_buffer = ffi.from_buffer(data)
1936 2047 self._in_buffer.src = self._source_buffer
1937 2048 self._in_buffer.size = len(self._source_buffer)
1938 2049 self._in_buffer.pos = 0
1939 2050 else:
1940 2051 self._source_buffer = ffi.from_buffer(self._source)
1941 2052 self._in_buffer.src = self._source_buffer
1942 2053 self._in_buffer.size = len(self._source_buffer)
1943 2054 self._in_buffer.pos = 0
1944 2055
1945 2056 def _decompress_into_buffer(self, out_buffer):
1946 2057 """Decompress available input into an output buffer.
1947 2058
1948 2059 Returns True if data in output buffer should be emitted.
1949 2060 """
1950 2061 zresult = lib.ZSTD_decompressStream(
1951 2062 self._decompressor._dctx, out_buffer, self._in_buffer
1952 2063 )
1953 2064
1954 2065 if self._in_buffer.pos == self._in_buffer.size:
1955 2066 self._in_buffer.src = ffi.NULL
1956 2067 self._in_buffer.pos = 0
1957 2068 self._in_buffer.size = 0
1958 2069 self._source_buffer = None
1959 2070
1960 2071 if not hasattr(self._source, "read"):
1961 2072 self._finished_input = True
1962 2073
1963 2074 if lib.ZSTD_isError(zresult):
1964 2075 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
1965 2076
1966 2077 # Emit data if there is data AND either:
1967 2078 # a) output buffer is full (read amount is satisfied)
1968 2079 # b) we're at end of a frame and not in frame spanning mode
1969 2080 return out_buffer.pos and (
1970 2081 out_buffer.pos == out_buffer.size
1971 2082 or zresult == 0
1972 2083 and not self._read_across_frames
1973 2084 )
1974 2085
1975 2086 def read(self, size=-1):
1976 2087 if self._closed:
1977 2088 raise ValueError("stream is closed")
1978 2089
1979 2090 if size < -1:
1980 2091 raise ValueError("cannot read negative amounts less than -1")
1981 2092
1982 2093 if size == -1:
1983 2094 # This is recursive. But it gets the job done.
1984 2095 return self.readall()
1985 2096
1986 2097 if self._finished_output or size == 0:
1987 2098 return b""
1988 2099
1989 2100 # We /could/ call into readinto() here. But that introduces more
1990 2101 # overhead.
1991 2102 dst_buffer = ffi.new("char[]", size)
1992 2103 out_buffer = ffi.new("ZSTD_outBuffer *")
1993 2104 out_buffer.dst = dst_buffer
1994 2105 out_buffer.size = size
1995 2106 out_buffer.pos = 0
1996 2107
1997 2108 self._read_input()
1998 2109 if self._decompress_into_buffer(out_buffer):
1999 2110 self._bytes_decompressed += out_buffer.pos
2000 2111 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2001 2112
2002 2113 while not self._finished_input:
2003 2114 self._read_input()
2004 2115 if self._decompress_into_buffer(out_buffer):
2005 2116 self._bytes_decompressed += out_buffer.pos
2006 2117 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2007 2118
2008 2119 self._bytes_decompressed += out_buffer.pos
2009 2120 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2010 2121
2011 2122 def readinto(self, b):
2012 2123 if self._closed:
2013 2124 raise ValueError("stream is closed")
2014 2125
2015 2126 if self._finished_output:
2016 2127 return 0
2017 2128
2018 2129 # TODO use writable=True once we require CFFI >= 1.12.
2019 2130 dest_buffer = ffi.from_buffer(b)
2020 2131 ffi.memmove(b, b"", 0)
2021 2132 out_buffer = ffi.new("ZSTD_outBuffer *")
2022 2133 out_buffer.dst = dest_buffer
2023 2134 out_buffer.size = len(dest_buffer)
2024 2135 out_buffer.pos = 0
2025 2136
2026 2137 self._read_input()
2027 2138 if self._decompress_into_buffer(out_buffer):
2028 2139 self._bytes_decompressed += out_buffer.pos
2029 2140 return out_buffer.pos
2030 2141
2031 2142 while not self._finished_input:
2032 2143 self._read_input()
2033 2144 if self._decompress_into_buffer(out_buffer):
2034 2145 self._bytes_decompressed += out_buffer.pos
2035 2146 return out_buffer.pos
2036 2147
2037 2148 self._bytes_decompressed += out_buffer.pos
2038 2149 return out_buffer.pos
2039 2150
2040 2151 def read1(self, size=-1):
2041 2152 if self._closed:
2042 2153 raise ValueError("stream is closed")
2043 2154
2044 2155 if size < -1:
2045 2156 raise ValueError("cannot read negative amounts less than -1")
2046 2157
2047 2158 if self._finished_output or size == 0:
2048 2159 return b""
2049 2160
2050 2161 # -1 returns arbitrary number of bytes.
2051 2162 if size == -1:
2052 2163 size = DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE
2053 2164
2054 2165 dst_buffer = ffi.new("char[]", size)
2055 2166 out_buffer = ffi.new("ZSTD_outBuffer *")
2056 2167 out_buffer.dst = dst_buffer
2057 2168 out_buffer.size = size
2058 2169 out_buffer.pos = 0
2059 2170
2060 2171 # read1() dictates that we can perform at most 1 call to underlying
2061 2172 # stream to get input. However, we can't satisfy this restriction with
2062 2173 # decompression because not all input generates output. So we allow
2063 2174 # multiple read(). But unlike read(), we stop once we have any output.
2064 2175 while not self._finished_input:
2065 2176 self._read_input()
2066 2177 self._decompress_into_buffer(out_buffer)
2067 2178
2068 2179 if out_buffer.pos:
2069 2180 break
2070 2181
2071 2182 self._bytes_decompressed += out_buffer.pos
2072 2183 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2073 2184
2074 2185 def readinto1(self, b):
2075 2186 if self._closed:
2076 2187 raise ValueError("stream is closed")
2077 2188
2078 2189 if self._finished_output:
2079 2190 return 0
2080 2191
2081 2192 # TODO use writable=True once we require CFFI >= 1.12.
2082 2193 dest_buffer = ffi.from_buffer(b)
2083 2194 ffi.memmove(b, b"", 0)
2084 2195
2085 2196 out_buffer = ffi.new("ZSTD_outBuffer *")
2086 2197 out_buffer.dst = dest_buffer
2087 2198 out_buffer.size = len(dest_buffer)
2088 2199 out_buffer.pos = 0
2089 2200
2090 2201 while not self._finished_input and not self._finished_output:
2091 2202 self._read_input()
2092 2203 self._decompress_into_buffer(out_buffer)
2093 2204
2094 2205 if out_buffer.pos:
2095 2206 break
2096 2207
2097 2208 self._bytes_decompressed += out_buffer.pos
2098 2209 return out_buffer.pos
2099 2210
2100 2211 def seek(self, pos, whence=os.SEEK_SET):
2101 2212 if self._closed:
2102 2213 raise ValueError("stream is closed")
2103 2214
2104 2215 read_amount = 0
2105 2216
2106 2217 if whence == os.SEEK_SET:
2107 2218 if pos < 0:
2108 raise ValueError("cannot seek to negative position with SEEK_SET")
2219 raise ValueError(
2220 "cannot seek to negative position with SEEK_SET"
2221 )
2109 2222
2110 2223 if pos < self._bytes_decompressed:
2111 raise ValueError("cannot seek zstd decompression stream " "backwards")
2224 raise ValueError(
2225 "cannot seek zstd decompression stream " "backwards"
2226 )
2112 2227
2113 2228 read_amount = pos - self._bytes_decompressed
2114 2229
2115 2230 elif whence == os.SEEK_CUR:
2116 2231 if pos < 0:
2117 raise ValueError("cannot seek zstd decompression stream " "backwards")
2232 raise ValueError(
2233 "cannot seek zstd decompression stream " "backwards"
2234 )
2118 2235
2119 2236 read_amount = pos
2120 2237 elif whence == os.SEEK_END:
2121 2238 raise ValueError(
2122 2239 "zstd decompression streams cannot be seeked " "with SEEK_END"
2123 2240 )
2124 2241
2125 2242 while read_amount:
2126 result = self.read(min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE))
2243 result = self.read(
2244 min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
2245 )
2127 2246
2128 2247 if not result:
2129 2248 break
2130 2249
2131 2250 read_amount -= len(result)
2132 2251
2133 2252 return self._bytes_decompressed
2134 2253
2135 2254
2136 2255 class ZstdDecompressionWriter(object):
2137 2256 def __init__(self, decompressor, writer, write_size, write_return_read):
2138 2257 decompressor._ensure_dctx()
2139 2258
2140 2259 self._decompressor = decompressor
2141 2260 self._writer = writer
2142 2261 self._write_size = write_size
2143 2262 self._write_return_read = bool(write_return_read)
2144 2263 self._entered = False
2145 2264 self._closed = False
2146 2265
2147 2266 def __enter__(self):
2148 2267 if self._closed:
2149 2268 raise ValueError("stream is closed")
2150 2269
2151 2270 if self._entered:
2152 2271 raise ZstdError("cannot __enter__ multiple times")
2153 2272
2154 2273 self._entered = True
2155 2274
2156 2275 return self
2157 2276
2158 2277 def __exit__(self, exc_type, exc_value, exc_tb):
2159 2278 self._entered = False
2160 2279 self.close()
2161 2280
2162 2281 def memory_size(self):
2163 2282 return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx)
2164 2283
2165 2284 def close(self):
2166 2285 if self._closed:
2167 2286 return
2168 2287
2169 2288 try:
2170 2289 self.flush()
2171 2290 finally:
2172 2291 self._closed = True
2173 2292
2174 2293 f = getattr(self._writer, "close", None)
2175 2294 if f:
2176 2295 f()
2177 2296
2178 2297 @property
2179 2298 def closed(self):
2180 2299 return self._closed
2181 2300
2182 2301 def fileno(self):
2183 2302 f = getattr(self._writer, "fileno", None)
2184 2303 if f:
2185 2304 return f()
2186 2305 else:
2187 2306 raise OSError("fileno not available on underlying writer")
2188 2307
2189 2308 def flush(self):
2190 2309 if self._closed:
2191 2310 raise ValueError("stream is closed")
2192 2311
2193 2312 f = getattr(self._writer, "flush", None)
2194 2313 if f:
2195 2314 return f()
2196 2315
2197 2316 def isatty(self):
2198 2317 return False
2199 2318
2200 2319 def readable(self):
2201 2320 return False
2202 2321
2203 2322 def readline(self, size=-1):
2204 2323 raise io.UnsupportedOperation()
2205 2324
2206 2325 def readlines(self, hint=-1):
2207 2326 raise io.UnsupportedOperation()
2208 2327
2209 2328 def seek(self, offset, whence=None):
2210 2329 raise io.UnsupportedOperation()
2211 2330
2212 2331 def seekable(self):
2213 2332 return False
2214 2333
2215 2334 def tell(self):
2216 2335 raise io.UnsupportedOperation()
2217 2336
2218 2337 def truncate(self, size=None):
2219 2338 raise io.UnsupportedOperation()
2220 2339
2221 2340 def writable(self):
2222 2341 return True
2223 2342
2224 2343 def writelines(self, lines):
2225 2344 raise io.UnsupportedOperation()
2226 2345
2227 2346 def read(self, size=-1):
2228 2347 raise io.UnsupportedOperation()
2229 2348
2230 2349 def readall(self):
2231 2350 raise io.UnsupportedOperation()
2232 2351
2233 2352 def readinto(self, b):
2234 2353 raise io.UnsupportedOperation()
2235 2354
2236 2355 def write(self, data):
2237 2356 if self._closed:
2238 2357 raise ValueError("stream is closed")
2239 2358
2240 2359 total_write = 0
2241 2360
2242 2361 in_buffer = ffi.new("ZSTD_inBuffer *")
2243 2362 out_buffer = ffi.new("ZSTD_outBuffer *")
2244 2363
2245 2364 data_buffer = ffi.from_buffer(data)
2246 2365 in_buffer.src = data_buffer
2247 2366 in_buffer.size = len(data_buffer)
2248 2367 in_buffer.pos = 0
2249 2368
2250 2369 dst_buffer = ffi.new("char[]", self._write_size)
2251 2370 out_buffer.dst = dst_buffer
2252 2371 out_buffer.size = len(dst_buffer)
2253 2372 out_buffer.pos = 0
2254 2373
2255 2374 dctx = self._decompressor._dctx
2256 2375
2257 2376 while in_buffer.pos < in_buffer.size:
2258 2377 zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer)
2259 2378 if lib.ZSTD_isError(zresult):
2260 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
2379 raise ZstdError(
2380 "zstd decompress error: %s" % _zstd_error(zresult)
2381 )
2261 2382
2262 2383 if out_buffer.pos:
2263 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
2384 self._writer.write(
2385 ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2386 )
2264 2387 total_write += out_buffer.pos
2265 2388 out_buffer.pos = 0
2266 2389
2267 2390 if self._write_return_read:
2268 2391 return in_buffer.pos
2269 2392 else:
2270 2393 return total_write
2271 2394
2272 2395
2273 2396 class ZstdDecompressor(object):
2274 2397 def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1):
2275 2398 self._dict_data = dict_data
2276 2399 self._max_window_size = max_window_size
2277 2400 self._format = format
2278 2401
2279 2402 dctx = lib.ZSTD_createDCtx()
2280 2403 if dctx == ffi.NULL:
2281 2404 raise MemoryError()
2282 2405
2283 2406 self._dctx = dctx
2284 2407
2285 2408 # Defer setting up garbage collection until full state is loaded so
2286 2409 # the memory size is more accurate.
2287 2410 try:
2288 2411 self._ensure_dctx()
2289 2412 finally:
2290 2413 self._dctx = ffi.gc(
2291 2414 dctx, lib.ZSTD_freeDCtx, size=lib.ZSTD_sizeof_DCtx(dctx)
2292 2415 )
2293 2416
2294 2417 def memory_size(self):
2295 2418 return lib.ZSTD_sizeof_DCtx(self._dctx)
2296 2419
2297 2420 def decompress(self, data, max_output_size=0):
2298 2421 self._ensure_dctx()
2299 2422
2300 2423 data_buffer = ffi.from_buffer(data)
2301 2424
2302 output_size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
2425 output_size = lib.ZSTD_getFrameContentSize(
2426 data_buffer, len(data_buffer)
2427 )
2303 2428
2304 2429 if output_size == lib.ZSTD_CONTENTSIZE_ERROR:
2305 2430 raise ZstdError("error determining content size from frame header")
2306 2431 elif output_size == 0:
2307 2432 return b""
2308 2433 elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2309 2434 if not max_output_size:
2310 raise ZstdError("could not determine content size in frame header")
2435 raise ZstdError(
2436 "could not determine content size in frame header"
2437 )
2311 2438
2312 2439 result_buffer = ffi.new("char[]", max_output_size)
2313 2440 result_size = max_output_size
2314 2441 output_size = 0
2315 2442 else:
2316 2443 result_buffer = ffi.new("char[]", output_size)
2317 2444 result_size = output_size
2318 2445
2319 2446 out_buffer = ffi.new("ZSTD_outBuffer *")
2320 2447 out_buffer.dst = result_buffer
2321 2448 out_buffer.size = result_size
2322 2449 out_buffer.pos = 0
2323 2450
2324 2451 in_buffer = ffi.new("ZSTD_inBuffer *")
2325 2452 in_buffer.src = data_buffer
2326 2453 in_buffer.size = len(data_buffer)
2327 2454 in_buffer.pos = 0
2328 2455
2329 2456 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2330 2457 if lib.ZSTD_isError(zresult):
2331 2458 raise ZstdError("decompression error: %s" % _zstd_error(zresult))
2332 2459 elif zresult:
2333 raise ZstdError("decompression error: did not decompress full frame")
2460 raise ZstdError(
2461 "decompression error: did not decompress full frame"
2462 )
2334 2463 elif output_size and out_buffer.pos != output_size:
2335 2464 raise ZstdError(
2336 2465 "decompression error: decompressed %d bytes; expected %d"
2337 2466 % (zresult, output_size)
2338 2467 )
2339 2468
2340 2469 return ffi.buffer(result_buffer, out_buffer.pos)[:]
2341 2470
2342 2471 def stream_reader(
2343 2472 self,
2344 2473 source,
2345 2474 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2346 2475 read_across_frames=False,
2347 2476 ):
2348 2477 self._ensure_dctx()
2349 return ZstdDecompressionReader(self, source, read_size, read_across_frames)
2478 return ZstdDecompressionReader(
2479 self, source, read_size, read_across_frames
2480 )
2350 2481
2351 2482 def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
2352 2483 if write_size < 1:
2353 2484 raise ValueError("write_size must be positive")
2354 2485
2355 2486 self._ensure_dctx()
2356 2487 return ZstdDecompressionObj(self, write_size=write_size)
2357 2488
2358 2489 def read_to_iter(
2359 2490 self,
2360 2491 reader,
2361 2492 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2362 2493 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2363 2494 skip_bytes=0,
2364 2495 ):
2365 2496 if skip_bytes >= read_size:
2366 2497 raise ValueError("skip_bytes must be smaller than read_size")
2367 2498
2368 2499 if hasattr(reader, "read"):
2369 2500 have_read = True
2370 2501 elif hasattr(reader, "__getitem__"):
2371 2502 have_read = False
2372 2503 buffer_offset = 0
2373 2504 size = len(reader)
2374 2505 else:
2375 2506 raise ValueError(
2376 2507 "must pass an object with a read() method or "
2377 2508 "conforms to buffer protocol"
2378 2509 )
2379 2510
2380 2511 if skip_bytes:
2381 2512 if have_read:
2382 2513 reader.read(skip_bytes)
2383 2514 else:
2384 2515 if skip_bytes > size:
2385 2516 raise ValueError("skip_bytes larger than first input chunk")
2386 2517
2387 2518 buffer_offset = skip_bytes
2388 2519
2389 2520 self._ensure_dctx()
2390 2521
2391 2522 in_buffer = ffi.new("ZSTD_inBuffer *")
2392 2523 out_buffer = ffi.new("ZSTD_outBuffer *")
2393 2524
2394 2525 dst_buffer = ffi.new("char[]", write_size)
2395 2526 out_buffer.dst = dst_buffer
2396 2527 out_buffer.size = len(dst_buffer)
2397 2528 out_buffer.pos = 0
2398 2529
2399 2530 while True:
2400 2531 assert out_buffer.pos == 0
2401 2532
2402 2533 if have_read:
2403 2534 read_result = reader.read(read_size)
2404 2535 else:
2405 2536 remaining = size - buffer_offset
2406 2537 slice_size = min(remaining, read_size)
2407 2538 read_result = reader[buffer_offset : buffer_offset + slice_size]
2408 2539 buffer_offset += slice_size
2409 2540
2410 2541 # No new input. Break out of read loop.
2411 2542 if not read_result:
2412 2543 break
2413 2544
2414 2545 # Feed all read data into decompressor and emit output until
2415 2546 # exhausted.
2416 2547 read_buffer = ffi.from_buffer(read_result)
2417 2548 in_buffer.src = read_buffer
2418 2549 in_buffer.size = len(read_buffer)
2419 2550 in_buffer.pos = 0
2420 2551
2421 2552 while in_buffer.pos < in_buffer.size:
2422 2553 assert out_buffer.pos == 0
2423 2554
2424 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2555 zresult = lib.ZSTD_decompressStream(
2556 self._dctx, out_buffer, in_buffer
2557 )
2425 2558 if lib.ZSTD_isError(zresult):
2426 raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult))
2559 raise ZstdError(
2560 "zstd decompress error: %s" % _zstd_error(zresult)
2561 )
2427 2562
2428 2563 if out_buffer.pos:
2429 2564 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
2430 2565 out_buffer.pos = 0
2431 2566 yield data
2432 2567
2433 2568 if zresult == 0:
2434 2569 return
2435 2570
2436 2571 # Repeat loop to collect more input data.
2437 2572 continue
2438 2573
2439 2574 # If we get here, input is exhausted.
2440 2575
2441 2576 read_from = read_to_iter
2442 2577
2443 2578 def stream_writer(
2444 2579 self,
2445 2580 writer,
2446 2581 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2447 2582 write_return_read=False,
2448 2583 ):
2449 2584 if not hasattr(writer, "write"):
2450 2585 raise ValueError("must pass an object with a write() method")
2451 2586
2452 return ZstdDecompressionWriter(self, writer, write_size, write_return_read)
2587 return ZstdDecompressionWriter(
2588 self, writer, write_size, write_return_read
2589 )
2453 2590
2454 2591 write_to = stream_writer
2455 2592
2456 2593 def copy_stream(
2457 2594 self,
2458 2595 ifh,
2459 2596 ofh,
2460 2597 read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
2461 2598 write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
2462 2599 ):
2463 2600 if not hasattr(ifh, "read"):
2464 2601 raise ValueError("first argument must have a read() method")
2465 2602 if not hasattr(ofh, "write"):
2466 2603 raise ValueError("second argument must have a write() method")
2467 2604
2468 2605 self._ensure_dctx()
2469 2606
2470 2607 in_buffer = ffi.new("ZSTD_inBuffer *")
2471 2608 out_buffer = ffi.new("ZSTD_outBuffer *")
2472 2609
2473 2610 dst_buffer = ffi.new("char[]", write_size)
2474 2611 out_buffer.dst = dst_buffer
2475 2612 out_buffer.size = write_size
2476 2613 out_buffer.pos = 0
2477 2614
2478 2615 total_read, total_write = 0, 0
2479 2616
2480 2617 # Read all available input.
2481 2618 while True:
2482 2619 data = ifh.read(read_size)
2483 2620 if not data:
2484 2621 break
2485 2622
2486 2623 data_buffer = ffi.from_buffer(data)
2487 2624 total_read += len(data_buffer)
2488 2625 in_buffer.src = data_buffer
2489 2626 in_buffer.size = len(data_buffer)
2490 2627 in_buffer.pos = 0
2491 2628
2492 2629 # Flush all read data to output.
2493 2630 while in_buffer.pos < in_buffer.size:
2494 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2631 zresult = lib.ZSTD_decompressStream(
2632 self._dctx, out_buffer, in_buffer
2633 )
2495 2634 if lib.ZSTD_isError(zresult):
2496 2635 raise ZstdError(
2497 2636 "zstd decompressor error: %s" % _zstd_error(zresult)
2498 2637 )
2499 2638
2500 2639 if out_buffer.pos:
2501 2640 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
2502 2641 total_write += out_buffer.pos
2503 2642 out_buffer.pos = 0
2504 2643
2505 2644 # Continue loop to keep reading.
2506 2645
2507 2646 return total_read, total_write
2508 2647
2509 2648 def decompress_content_dict_chain(self, frames):
2510 2649 if not isinstance(frames, list):
2511 2650 raise TypeError("argument must be a list")
2512 2651
2513 2652 if not frames:
2514 2653 raise ValueError("empty input chain")
2515 2654
2516 2655 # First chunk should not be using a dictionary. We handle it specially.
2517 2656 chunk = frames[0]
2518 2657 if not isinstance(chunk, bytes_type):
2519 2658 raise ValueError("chunk 0 must be bytes")
2520 2659
2521 2660 # All chunks should be zstd frames and should have content size set.
2522 2661 chunk_buffer = ffi.from_buffer(chunk)
2523 2662 params = ffi.new("ZSTD_frameHeader *")
2524 zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
2663 zresult = lib.ZSTD_getFrameHeader(
2664 params, chunk_buffer, len(chunk_buffer)
2665 )
2525 2666 if lib.ZSTD_isError(zresult):
2526 2667 raise ValueError("chunk 0 is not a valid zstd frame")
2527 2668 elif zresult:
2528 2669 raise ValueError("chunk 0 is too small to contain a zstd frame")
2529 2670
2530 2671 if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2531 2672 raise ValueError("chunk 0 missing content size in frame")
2532 2673
2533 2674 self._ensure_dctx(load_dict=False)
2534 2675
2535 2676 last_buffer = ffi.new("char[]", params.frameContentSize)
2536 2677
2537 2678 out_buffer = ffi.new("ZSTD_outBuffer *")
2538 2679 out_buffer.dst = last_buffer
2539 2680 out_buffer.size = len(last_buffer)
2540 2681 out_buffer.pos = 0
2541 2682
2542 2683 in_buffer = ffi.new("ZSTD_inBuffer *")
2543 2684 in_buffer.src = chunk_buffer
2544 2685 in_buffer.size = len(chunk_buffer)
2545 2686 in_buffer.pos = 0
2546 2687
2547 2688 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2548 2689 if lib.ZSTD_isError(zresult):
2549 raise ZstdError("could not decompress chunk 0: %s" % _zstd_error(zresult))
2690 raise ZstdError(
2691 "could not decompress chunk 0: %s" % _zstd_error(zresult)
2692 )
2550 2693 elif zresult:
2551 2694 raise ZstdError("chunk 0 did not decompress full frame")
2552 2695
2553 2696 # Special case of chain length of 1
2554 2697 if len(frames) == 1:
2555 2698 return ffi.buffer(last_buffer, len(last_buffer))[:]
2556 2699
2557 2700 i = 1
2558 2701 while i < len(frames):
2559 2702 chunk = frames[i]
2560 2703 if not isinstance(chunk, bytes_type):
2561 2704 raise ValueError("chunk %d must be bytes" % i)
2562 2705
2563 2706 chunk_buffer = ffi.from_buffer(chunk)
2564 zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
2707 zresult = lib.ZSTD_getFrameHeader(
2708 params, chunk_buffer, len(chunk_buffer)
2709 )
2565 2710 if lib.ZSTD_isError(zresult):
2566 2711 raise ValueError("chunk %d is not a valid zstd frame" % i)
2567 2712 elif zresult:
2568 raise ValueError("chunk %d is too small to contain a zstd frame" % i)
2713 raise ValueError(
2714 "chunk %d is too small to contain a zstd frame" % i
2715 )
2569 2716
2570 2717 if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
2571 2718 raise ValueError("chunk %d missing content size in frame" % i)
2572 2719
2573 2720 dest_buffer = ffi.new("char[]", params.frameContentSize)
2574 2721
2575 2722 out_buffer.dst = dest_buffer
2576 2723 out_buffer.size = len(dest_buffer)
2577 2724 out_buffer.pos = 0
2578 2725
2579 2726 in_buffer.src = chunk_buffer
2580 2727 in_buffer.size = len(chunk_buffer)
2581 2728 in_buffer.pos = 0
2582 2729
2583 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
2730 zresult = lib.ZSTD_decompressStream(
2731 self._dctx, out_buffer, in_buffer
2732 )
2584 2733 if lib.ZSTD_isError(zresult):
2585 2734 raise ZstdError(
2586 2735 "could not decompress chunk %d: %s" % _zstd_error(zresult)
2587 2736 )
2588 2737 elif zresult:
2589 2738 raise ZstdError("chunk %d did not decompress full frame" % i)
2590 2739
2591 2740 last_buffer = dest_buffer
2592 2741 i += 1
2593 2742
2594 2743 return ffi.buffer(last_buffer, len(last_buffer))[:]
2595 2744
2596 2745 def _ensure_dctx(self, load_dict=True):
2597 2746 lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only)
2598 2747
2599 2748 if self._max_window_size:
2600 zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx, self._max_window_size)
2749 zresult = lib.ZSTD_DCtx_setMaxWindowSize(
2750 self._dctx, self._max_window_size
2751 )
2601 2752 if lib.ZSTD_isError(zresult):
2602 2753 raise ZstdError(
2603 2754 "unable to set max window size: %s" % _zstd_error(zresult)
2604 2755 )
2605 2756
2606 2757 zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format)
2607 2758 if lib.ZSTD_isError(zresult):
2608 raise ZstdError("unable to set decoding format: %s" % _zstd_error(zresult))
2759 raise ZstdError(
2760 "unable to set decoding format: %s" % _zstd_error(zresult)
2761 )
2609 2762
2610 2763 if self._dict_data and load_dict:
2611 2764 zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict)
2612 2765 if lib.ZSTD_isError(zresult):
2613 2766 raise ZstdError(
2614 "unable to reference prepared dictionary: %s" % _zstd_error(zresult)
2767 "unable to reference prepared dictionary: %s"
2768 % _zstd_error(zresult)
2615 2769 )
@@ -1,48 +1,48 b''
1 1 SOURCES=$(notdir $(wildcard ../mercurial/helptext/*.[0-9].txt))
2 2 MAN=$(SOURCES:%.txt=%)
3 3 HTML=$(SOURCES:%.txt=%.html)
4 4 GENDOC=gendoc.py ../mercurial/commands.py ../mercurial/help.py \
5 5 ../mercurial/helptext/*.txt ../hgext/*.py ../hgext/*/__init__.py
6 6 PREFIX=/usr/local
7 7 MANDIR=$(PREFIX)/share/man
8 INSTALL=install -c -m 644
8 INSTALL=install -m 644
9 9 PYTHON?=python
10 10 RSTARGS=
11 11
12 12 export HGENCODING=UTF-8
13 13
14 14 all: man html
15 15
16 16 man: $(MAN)
17 17
18 18 html: $(HTML)
19 19
20 20 # This logic is duplicated in setup.py:hgbuilddoc()
21 21 common.txt $(SOURCES) $(SOURCES:%.txt=%.gendoc.txt): $(GENDOC)
22 22 ${PYTHON} gendoc.py "$(basename $@)" > $@.tmp
23 23 mv $@.tmp $@
24 24
25 25 %: %.txt %.gendoc.txt common.txt
26 26 $(PYTHON) runrst hgmanpage $(RSTARGS) --halt warning \
27 27 --strip-elements-with-class htmlonly $*.txt $*
28 28
29 29 %.html: %.txt %.gendoc.txt common.txt
30 30 $(PYTHON) runrst html $(RSTARGS) --halt warning \
31 31 --link-stylesheet --stylesheet-path style.css $*.txt $*.html
32 32
33 33 MANIFEST: man html
34 34 # tracked files are already in the main MANIFEST
35 35 $(RM) $@
36 36 for i in $(MAN) $(HTML); do \
37 37 echo "doc/$$i" >> $@ ; \
38 38 done
39 39
40 40 install: man
41 41 for i in $(MAN) ; do \
42 42 subdir=`echo $$i | sed -n 's/^.*\.\([0-9]\)$$/man\1/p'` ; \
43 43 mkdir -p "$(DESTDIR)$(MANDIR)"/$$subdir ; \
44 44 $(INSTALL) $$i "$(DESTDIR)$(MANDIR)"/$$subdir ; \
45 45 done
46 46
47 47 clean:
48 48 $(RM) $(MAN) $(HTML) common.txt $(SOURCES) $(SOURCES:%.txt=%.gendoc.txt) MANIFEST
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file copied from tests/test-rename.t to tests/test-rename-rev.t
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed
The requested commit or file is too big and content was truncated. Show full diff
General Comments 0
You need to be logged in to leave comments. Login now