##// END OF EJS Templates
cborutil: implement support for streaming encoding, bytestring decoding...
Gregory Szorc -
r37729:65a23cc8 default
parent child Browse files
Show More
@@ -0,0 +1,258 b''
1 # cborutil.py - CBOR extensions
2 #
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 #
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
7
8 from __future__ import absolute_import
9
10 import struct
11
12 from ..thirdparty.cbor.cbor2 import (
13 decoder as decodermod,
14 )
15
16 # Very short very of RFC 7049...
17 #
18 # Each item begins with a byte. The 3 high bits of that byte denote the
19 # "major type." The lower 5 bits denote the "subtype." Each major type
20 # has its own encoding mechanism.
21 #
22 # Most types have lengths. However, bytestring, string, array, and map
23 # can be indefinite length. These are denotes by a subtype with value 31.
24 # Sub-components of those types then come afterwards and are terminated
25 # by a "break" byte.
26
27 MAJOR_TYPE_UINT = 0
28 MAJOR_TYPE_NEGINT = 1
29 MAJOR_TYPE_BYTESTRING = 2
30 MAJOR_TYPE_STRING = 3
31 MAJOR_TYPE_ARRAY = 4
32 MAJOR_TYPE_MAP = 5
33 MAJOR_TYPE_SEMANTIC = 6
34 MAJOR_TYPE_SPECIAL = 7
35
36 SUBTYPE_MASK = 0b00011111
37
38 SUBTYPE_HALF_FLOAT = 25
39 SUBTYPE_SINGLE_FLOAT = 26
40 SUBTYPE_DOUBLE_FLOAT = 27
41 SUBTYPE_INDEFINITE = 31
42
43 # Indefinite types begin with their major type ORd with information value 31.
44 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
45 r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
46 BEGIN_INDEFINITE_ARRAY = struct.pack(
47 r'>B', MAJOR_TYPE_ARRAY << 5 | SUBTYPE_INDEFINITE)
48 BEGIN_INDEFINITE_MAP = struct.pack(
49 r'>B', MAJOR_TYPE_MAP << 5 | SUBTYPE_INDEFINITE)
50
51 ENCODED_LENGTH_1 = struct.Struct(r'>B')
52 ENCODED_LENGTH_2 = struct.Struct(r'>BB')
53 ENCODED_LENGTH_3 = struct.Struct(r'>BH')
54 ENCODED_LENGTH_4 = struct.Struct(r'>BL')
55 ENCODED_LENGTH_5 = struct.Struct(r'>BQ')
56
57 # The break ends an indefinite length item.
58 BREAK = b'\xff'
59 BREAK_INT = 255
60
61 def encodelength(majortype, length):
62 """Obtain a value encoding the major type and its length."""
63 if length < 24:
64 return ENCODED_LENGTH_1.pack(majortype << 5 | length)
65 elif length < 256:
66 return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
67 elif length < 65536:
68 return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
69 elif length < 4294967296:
70 return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
71 else:
72 return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
73
74 def streamencodebytestring(v):
75 yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
76 yield v
77
78 def streamencodebytestringfromiter(it):
79 """Convert an iterator of chunks to an indefinite bytestring.
80
81 Given an input that is iterable and each element in the iterator is
82 representable as bytes, emit an indefinite length bytestring.
83 """
84 yield BEGIN_INDEFINITE_BYTESTRING
85
86 for chunk in it:
87 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
88 yield chunk
89
90 yield BREAK
91
92 def streamencodeindefinitebytestring(source, chunksize=65536):
93 """Given a large source buffer, emit as an indefinite length bytestring.
94
95 This is a generator of chunks constituting the encoded CBOR data.
96 """
97 yield BEGIN_INDEFINITE_BYTESTRING
98
99 i = 0
100 l = len(source)
101
102 while True:
103 chunk = source[i:i + chunksize]
104 i += len(chunk)
105
106 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
107 yield chunk
108
109 if i >= l:
110 break
111
112 yield BREAK
113
114 def streamencodeint(v):
115 if v >= 18446744073709551616 or v < -18446744073709551616:
116 raise ValueError('big integers not supported')
117
118 if v >= 0:
119 yield encodelength(MAJOR_TYPE_UINT, v)
120 else:
121 yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
122
123 def streamencodearray(l):
124 """Encode a known size iterable to an array."""
125
126 yield encodelength(MAJOR_TYPE_ARRAY, len(l))
127
128 for i in l:
129 for chunk in streamencode(i):
130 yield chunk
131
132 def streamencodearrayfromiter(it):
133 """Encode an iterator of items to an indefinite length array."""
134
135 yield BEGIN_INDEFINITE_ARRAY
136
137 for i in it:
138 for chunk in streamencode(i):
139 yield chunk
140
141 yield BREAK
142
143 def streamencodeset(s):
144 # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
145 # semantic tag 258 for finite sets.
146 yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
147
148 for chunk in streamencodearray(sorted(s)):
149 yield chunk
150
151 def streamencodemap(d):
152 """Encode dictionary to a generator.
153
154 Does not supporting indefinite length dictionaries.
155 """
156 yield encodelength(MAJOR_TYPE_MAP, len(d))
157
158 for key, value in sorted(d.iteritems()):
159 for chunk in streamencode(key):
160 yield chunk
161 for chunk in streamencode(value):
162 yield chunk
163
164 def streamencodemapfromiter(it):
165 """Given an iterable of (key, value), encode to an indefinite length map."""
166 yield BEGIN_INDEFINITE_MAP
167
168 for key, value in it:
169 for chunk in streamencode(key):
170 yield chunk
171 for chunk in streamencode(value):
172 yield chunk
173
174 yield BREAK
175
176 def streamencodebool(b):
177 # major type 7, simple value 20 and 21.
178 yield b'\xf5' if b else b'\xf4'
179
180 def streamencodenone(v):
181 # major type 7, simple value 22.
182 yield b'\xf6'
183
184 STREAM_ENCODERS = {
185 bytes: streamencodebytestring,
186 int: streamencodeint,
187 list: streamencodearray,
188 tuple: streamencodearray,
189 dict: streamencodemap,
190 set: streamencodeset,
191 bool: streamencodebool,
192 type(None): streamencodenone,
193 }
194
195 def streamencode(v):
196 """Encode a value in a streaming manner.
197
198 Given an input object, encode it to CBOR recursively.
199
200 Returns a generator of CBOR encoded bytes. There is no guarantee
201 that each emitted chunk fully decodes to a value or sub-value.
202
203 Encoding is deterministic - unordered collections are sorted.
204 """
205 fn = STREAM_ENCODERS.get(v.__class__)
206
207 if not fn:
208 raise ValueError('do not know how to encode %s' % type(v))
209
210 return fn(v)
211
212 def readindefinitebytestringtoiter(fh, expectheader=True):
213 """Read an indefinite bytestring to a generator.
214
215 Receives an object with a ``read(X)`` method to read N bytes.
216
217 If ``expectheader`` is True, it is expected that the first byte read
218 will represent an indefinite length bytestring. Otherwise, we
219 expect the first byte to be part of the first bytestring chunk.
220 """
221 read = fh.read
222 decodeuint = decodermod.decode_uint
223 byteasinteger = decodermod.byte_as_integer
224
225 if expectheader:
226 initial = decodermod.byte_as_integer(read(1))
227
228 majortype = initial >> 5
229 subtype = initial & SUBTYPE_MASK
230
231 if majortype != MAJOR_TYPE_BYTESTRING:
232 raise decodermod.CBORDecodeError(
233 'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
234 majortype))
235
236 if subtype != SUBTYPE_INDEFINITE:
237 raise decodermod.CBORDecodeError(
238 'expected indefinite subtype; got %d' % subtype)
239
240 # The indefinite bytestring is composed of chunks of normal bytestrings.
241 # Read chunks until we hit a BREAK byte.
242
243 while True:
244 # We need to sniff for the BREAK byte.
245 initial = byteasinteger(read(1))
246
247 if initial == BREAK_INT:
248 break
249
250 length = decodeuint(fh, initial & SUBTYPE_MASK)
251 chunk = read(length)
252
253 if len(chunk) != length:
254 raise decodermod.CBORDecodeError(
255 'failed to read bytestring chunk: got %d bytes; expected %d' % (
256 len(chunk), length))
257
258 yield chunk
@@ -0,0 +1,210 b''
1 from __future__ import absolute_import
2
3 import io
4 import unittest
5
6 from mercurial.thirdparty import (
7 cbor,
8 )
9 from mercurial.utils import (
10 cborutil,
11 )
12
13 def loadit(it):
14 return cbor.loads(b''.join(it))
15
16 class BytestringTests(unittest.TestCase):
17 def testsimple(self):
18 self.assertEqual(
19 list(cborutil.streamencode(b'foobar')),
20 [b'\x46', b'foobar'])
21
22 self.assertEqual(
23 loadit(cborutil.streamencode(b'foobar')),
24 b'foobar')
25
26 def testlong(self):
27 source = b'x' * 1048576
28
29 self.assertEqual(loadit(cborutil.streamencode(source)), source)
30
31 def testfromiter(self):
32 # This is the example from RFC 7049 Section 2.2.2.
33 source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
34
35 self.assertEqual(
36 list(cborutil.streamencodebytestringfromiter(source)),
37 [
38 b'\x5f',
39 b'\x44',
40 b'\xaa\xbb\xcc\xdd',
41 b'\x43',
42 b'\xee\xff\x99',
43 b'\xff',
44 ])
45
46 self.assertEqual(
47 loadit(cborutil.streamencodebytestringfromiter(source)),
48 b''.join(source))
49
50 def testfromiterlarge(self):
51 source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
52
53 self.assertEqual(
54 loadit(cborutil.streamencodebytestringfromiter(source)),
55 b''.join(source))
56
57 def testindefinite(self):
58 source = b'\x00\x01\x02\x03' + b'\xff' * 16384
59
60 it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
61
62 self.assertEqual(next(it), b'\x5f')
63 self.assertEqual(next(it), b'\x42')
64 self.assertEqual(next(it), b'\x00\x01')
65 self.assertEqual(next(it), b'\x42')
66 self.assertEqual(next(it), b'\x02\x03')
67 self.assertEqual(next(it), b'\x42')
68 self.assertEqual(next(it), b'\xff\xff')
69
70 dest = b''.join(cborutil.streamencodeindefinitebytestring(
71 source, chunksize=42))
72 self.assertEqual(cbor.loads(dest), b''.join(source))
73
74 def testreadtoiter(self):
75 source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
76
77 it = cborutil.readindefinitebytestringtoiter(source)
78 self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd')
79 self.assertEqual(next(it), b'\xee\xff\x99')
80
81 with self.assertRaises(StopIteration):
82 next(it)
83
84 class IntTests(unittest.TestCase):
85 def testsmall(self):
86 self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
87 self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
88 self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
89 self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
90 self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
91
92 def testnegativesmall(self):
93 self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
94 self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
95 self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
96 self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
97 self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
98
99 def testrange(self):
100 for i in range(-70000, 70000, 10):
101 self.assertEqual(
102 b''.join(cborutil.streamencode(i)),
103 cbor.dumps(i))
104
105 class ArrayTests(unittest.TestCase):
106 def testempty(self):
107 self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
108 self.assertEqual(loadit(cborutil.streamencode([])), [])
109
110 def testbasic(self):
111 source = [b'foo', b'bar', 1, -10]
112
113 self.assertEqual(list(cborutil.streamencode(source)), [
114 b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
115
116 def testemptyfromiter(self):
117 self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
118 b'\x9f\xff')
119
120 def testfromiter1(self):
121 source = [b'foo']
122
123 self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [
124 b'\x9f',
125 b'\x43', b'foo',
126 b'\xff',
127 ])
128
129 dest = b''.join(cborutil.streamencodearrayfromiter(source))
130 self.assertEqual(cbor.loads(dest), source)
131
132 def testtuple(self):
133 source = (b'foo', None, 42)
134
135 self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
136 list(source))
137
138 class SetTests(unittest.TestCase):
139 def testempty(self):
140 self.assertEqual(list(cborutil.streamencode(set())), [
141 b'\xd9\x01\x02',
142 b'\x80',
143 ])
144
145 def testset(self):
146 source = {b'foo', None, 42}
147
148 self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
149 source)
150
151 class BoolTests(unittest.TestCase):
152 def testbasic(self):
153 self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
154 self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
155
156 self.assertIs(loadit(cborutil.streamencode(True)), True)
157 self.assertIs(loadit(cborutil.streamencode(False)), False)
158
159 class NoneTests(unittest.TestCase):
160 def testbasic(self):
161 self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
162
163 self.assertIs(loadit(cborutil.streamencode(None)), None)
164
165 class MapTests(unittest.TestCase):
166 def testempty(self):
167 self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
168 self.assertEqual(loadit(cborutil.streamencode({})), {})
169
170 def testemptyindefinite(self):
171 self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
172 b'\xbf', b'\xff'])
173
174 self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
175
176 def testone(self):
177 source = {b'foo': b'bar'}
178 self.assertEqual(list(cborutil.streamencode(source)), [
179 b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
180
181 self.assertEqual(loadit(cborutil.streamencode(source)), source)
182
183 def testmultiple(self):
184 source = {
185 b'foo': b'bar',
186 b'baz': b'value1',
187 }
188
189 self.assertEqual(loadit(cborutil.streamencode(source)), source)
190
191 self.assertEqual(
192 loadit(cborutil.streamencodemapfromiter(source.items())),
193 source)
194
195 def testcomplex(self):
196 source = {
197 b'key': 1,
198 2: -10,
199 }
200
201 self.assertEqual(loadit(cborutil.streamencode(source)),
202 source)
203
204 self.assertEqual(
205 loadit(cborutil.streamencodemapfromiter(source.items())),
206 source)
207
208 if __name__ == '__main__':
209 import silenttestrunner
210 silenttestrunner.main(__name__)
@@ -1,779 +1,781 b''
1 1 #!/usr/bin/env python
2 2
3 3 from __future__ import absolute_import, print_function
4 4
5 5 import ast
6 6 import collections
7 7 import os
8 8 import re
9 9 import sys
10 10
11 11 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
12 12 # to work when run from a virtualenv. The modules were chosen empirically
13 13 # so that the return value matches the return value without virtualenv.
14 14 if True: # disable lexical sorting checks
15 15 try:
16 16 import BaseHTTPServer as basehttpserver
17 17 except ImportError:
18 18 basehttpserver = None
19 19 import zlib
20 20
21 21 # Whitelist of modules that symbols can be directly imported from.
22 22 allowsymbolimports = (
23 23 '__future__',
24 24 'bzrlib',
25 25 'hgclient',
26 26 'mercurial',
27 27 'mercurial.hgweb.common',
28 28 'mercurial.hgweb.request',
29 29 'mercurial.i18n',
30 30 'mercurial.node',
31 31 # for cffi modules to re-export pure functions
32 32 'mercurial.pure.base85',
33 33 'mercurial.pure.bdiff',
34 34 'mercurial.pure.mpatch',
35 35 'mercurial.pure.osutil',
36 36 'mercurial.pure.parsers',
37 37 # third-party imports should be directly imported
38 38 'mercurial.thirdparty',
39 'mercurial.thirdparty.cbor',
40 'mercurial.thirdparty.cbor.cbor2',
39 41 'mercurial.thirdparty.zope',
40 42 'mercurial.thirdparty.zope.interface',
41 43 )
42 44
43 45 # Whitelist of symbols that can be directly imported.
44 46 directsymbols = (
45 47 'demandimport',
46 48 )
47 49
48 50 # Modules that must be aliased because they are commonly confused with
49 51 # common variables and can create aliasing and readability issues.
50 52 requirealias = {
51 53 'ui': 'uimod',
52 54 }
53 55
54 56 def usingabsolute(root):
55 57 """Whether absolute imports are being used."""
56 58 if sys.version_info[0] >= 3:
57 59 return True
58 60
59 61 for node in ast.walk(root):
60 62 if isinstance(node, ast.ImportFrom):
61 63 if node.module == '__future__':
62 64 for n in node.names:
63 65 if n.name == 'absolute_import':
64 66 return True
65 67
66 68 return False
67 69
68 70 def walklocal(root):
69 71 """Recursively yield all descendant nodes but not in a different scope"""
70 72 todo = collections.deque(ast.iter_child_nodes(root))
71 73 yield root, False
72 74 while todo:
73 75 node = todo.popleft()
74 76 newscope = isinstance(node, ast.FunctionDef)
75 77 if not newscope:
76 78 todo.extend(ast.iter_child_nodes(node))
77 79 yield node, newscope
78 80
79 81 def dotted_name_of_path(path):
80 82 """Given a relative path to a source file, return its dotted module name.
81 83
82 84 >>> dotted_name_of_path('mercurial/error.py')
83 85 'mercurial.error'
84 86 >>> dotted_name_of_path('zlibmodule.so')
85 87 'zlib'
86 88 """
87 89 parts = path.replace(os.sep, '/').split('/')
88 90 parts[-1] = parts[-1].split('.', 1)[0] # remove .py and .so and .ARCH.so
89 91 if parts[-1].endswith('module'):
90 92 parts[-1] = parts[-1][:-6]
91 93 return '.'.join(parts)
92 94
93 95 def fromlocalfunc(modulename, localmods):
94 96 """Get a function to examine which locally defined module the
95 97 target source imports via a specified name.
96 98
97 99 `modulename` is an `dotted_name_of_path()`-ed source file path,
98 100 which may have `.__init__` at the end of it, of the target source.
99 101
100 102 `localmods` is a set of absolute `dotted_name_of_path()`-ed source file
101 103 paths of locally defined (= Mercurial specific) modules.
102 104
103 105 This function assumes that module names not existing in
104 106 `localmods` are from the Python standard library.
105 107
106 108 This function returns the function, which takes `name` argument,
107 109 and returns `(absname, dottedpath, hassubmod)` tuple if `name`
108 110 matches against locally defined module. Otherwise, it returns
109 111 False.
110 112
111 113 It is assumed that `name` doesn't have `.__init__`.
112 114
113 115 `absname` is an absolute module name of specified `name`
114 116 (e.g. "hgext.convert"). This can be used to compose prefix for sub
115 117 modules or so.
116 118
117 119 `dottedpath` is a `dotted_name_of_path()`-ed source file path
118 120 (e.g. "hgext.convert.__init__") of `name`. This is used to look
119 121 module up in `localmods` again.
120 122
121 123 `hassubmod` is whether it may have sub modules under it (for
122 124 convenient, even though this is also equivalent to "absname !=
123 125 dottednpath")
124 126
125 127 >>> localmods = {'foo.__init__', 'foo.foo1',
126 128 ... 'foo.bar.__init__', 'foo.bar.bar1',
127 129 ... 'baz.__init__', 'baz.baz1'}
128 130 >>> fromlocal = fromlocalfunc('foo.xxx', localmods)
129 131 >>> # relative
130 132 >>> fromlocal('foo1')
131 133 ('foo.foo1', 'foo.foo1', False)
132 134 >>> fromlocal('bar')
133 135 ('foo.bar', 'foo.bar.__init__', True)
134 136 >>> fromlocal('bar.bar1')
135 137 ('foo.bar.bar1', 'foo.bar.bar1', False)
136 138 >>> # absolute
137 139 >>> fromlocal('baz')
138 140 ('baz', 'baz.__init__', True)
139 141 >>> fromlocal('baz.baz1')
140 142 ('baz.baz1', 'baz.baz1', False)
141 143 >>> # unknown = maybe standard library
142 144 >>> fromlocal('os')
143 145 False
144 146 >>> fromlocal(None, 1)
145 147 ('foo', 'foo.__init__', True)
146 148 >>> fromlocal('foo1', 1)
147 149 ('foo.foo1', 'foo.foo1', False)
148 150 >>> fromlocal2 = fromlocalfunc('foo.xxx.yyy', localmods)
149 151 >>> fromlocal2(None, 2)
150 152 ('foo', 'foo.__init__', True)
151 153 >>> fromlocal2('bar2', 1)
152 154 False
153 155 >>> fromlocal2('bar', 2)
154 156 ('foo.bar', 'foo.bar.__init__', True)
155 157 """
156 158 if not isinstance(modulename, str):
157 159 modulename = modulename.decode('ascii')
158 160 prefix = '.'.join(modulename.split('.')[:-1])
159 161 if prefix:
160 162 prefix += '.'
161 163 def fromlocal(name, level=0):
162 164 # name is false value when relative imports are used.
163 165 if not name:
164 166 # If relative imports are used, level must not be absolute.
165 167 assert level > 0
166 168 candidates = ['.'.join(modulename.split('.')[:-level])]
167 169 else:
168 170 if not level:
169 171 # Check relative name first.
170 172 candidates = [prefix + name, name]
171 173 else:
172 174 candidates = ['.'.join(modulename.split('.')[:-level]) +
173 175 '.' + name]
174 176
175 177 for n in candidates:
176 178 if n in localmods:
177 179 return (n, n, False)
178 180 dottedpath = n + '.__init__'
179 181 if dottedpath in localmods:
180 182 return (n, dottedpath, True)
181 183 return False
182 184 return fromlocal
183 185
184 186 def populateextmods(localmods):
185 187 """Populate C extension modules based on pure modules"""
186 188 newlocalmods = set(localmods)
187 189 for n in localmods:
188 190 if n.startswith('mercurial.pure.'):
189 191 m = n[len('mercurial.pure.'):]
190 192 newlocalmods.add('mercurial.cext.' + m)
191 193 newlocalmods.add('mercurial.cffi._' + m)
192 194 return newlocalmods
193 195
194 196 def list_stdlib_modules():
195 197 """List the modules present in the stdlib.
196 198
197 199 >>> py3 = sys.version_info[0] >= 3
198 200 >>> mods = set(list_stdlib_modules())
199 201 >>> 'BaseHTTPServer' in mods or py3
200 202 True
201 203
202 204 os.path isn't really a module, so it's missing:
203 205
204 206 >>> 'os.path' in mods
205 207 False
206 208
207 209 sys requires special treatment, because it's baked into the
208 210 interpreter, but it should still appear:
209 211
210 212 >>> 'sys' in mods
211 213 True
212 214
213 215 >>> 'collections' in mods
214 216 True
215 217
216 218 >>> 'cStringIO' in mods or py3
217 219 True
218 220
219 221 >>> 'cffi' in mods
220 222 True
221 223 """
222 224 for m in sys.builtin_module_names:
223 225 yield m
224 226 # These modules only exist on windows, but we should always
225 227 # consider them stdlib.
226 228 for m in ['msvcrt', '_winreg']:
227 229 yield m
228 230 yield '__builtin__'
229 231 yield 'builtins' # python3 only
230 232 yield 'importlib.abc' # python3 only
231 233 yield 'importlib.machinery' # python3 only
232 234 yield 'importlib.util' # python3 only
233 235 for m in 'fcntl', 'grp', 'pwd', 'termios': # Unix only
234 236 yield m
235 237 for m in 'cPickle', 'datetime': # in Python (not C) on PyPy
236 238 yield m
237 239 for m in ['cffi']:
238 240 yield m
239 241 stdlib_prefixes = {sys.prefix, sys.exec_prefix}
240 242 # We need to supplement the list of prefixes for the search to work
241 243 # when run from within a virtualenv.
242 244 for mod in (basehttpserver, zlib):
243 245 if mod is None:
244 246 continue
245 247 try:
246 248 # Not all module objects have a __file__ attribute.
247 249 filename = mod.__file__
248 250 except AttributeError:
249 251 continue
250 252 dirname = os.path.dirname(filename)
251 253 for prefix in stdlib_prefixes:
252 254 if dirname.startswith(prefix):
253 255 # Then this directory is redundant.
254 256 break
255 257 else:
256 258 stdlib_prefixes.add(dirname)
257 259 for libpath in sys.path:
258 260 # We want to walk everything in sys.path that starts with
259 261 # something in stdlib_prefixes.
260 262 if not any(libpath.startswith(p) for p in stdlib_prefixes):
261 263 continue
262 264 for top, dirs, files in os.walk(libpath):
263 265 for i, d in reversed(list(enumerate(dirs))):
264 266 if (not os.path.exists(os.path.join(top, d, '__init__.py'))
265 267 or top == libpath and d in ('hgdemandimport', 'hgext',
266 268 'mercurial')):
267 269 del dirs[i]
268 270 for name in files:
269 271 if not name.endswith(('.py', '.so', '.pyc', '.pyo', '.pyd')):
270 272 continue
271 273 if name.startswith('__init__.py'):
272 274 full_path = top
273 275 else:
274 276 full_path = os.path.join(top, name)
275 277 rel_path = full_path[len(libpath) + 1:]
276 278 mod = dotted_name_of_path(rel_path)
277 279 yield mod
278 280
279 281 stdlib_modules = set(list_stdlib_modules())
280 282
281 283 def imported_modules(source, modulename, f, localmods, ignore_nested=False):
282 284 """Given the source of a file as a string, yield the names
283 285 imported by that file.
284 286
285 287 Args:
286 288 source: The python source to examine as a string.
287 289 modulename: of specified python source (may have `__init__`)
288 290 localmods: set of locally defined module names (may have `__init__`)
289 291 ignore_nested: If true, import statements that do not start in
290 292 column zero will be ignored.
291 293
292 294 Returns:
293 295 A list of absolute module names imported by the given source.
294 296
295 297 >>> f = 'foo/xxx.py'
296 298 >>> modulename = 'foo.xxx'
297 299 >>> localmods = {'foo.__init__': True,
298 300 ... 'foo.foo1': True, 'foo.foo2': True,
299 301 ... 'foo.bar.__init__': True, 'foo.bar.bar1': True,
300 302 ... 'baz.__init__': True, 'baz.baz1': True }
301 303 >>> # standard library (= not locally defined ones)
302 304 >>> sorted(imported_modules(
303 305 ... 'from stdlib1 import foo, bar; import stdlib2',
304 306 ... modulename, f, localmods))
305 307 []
306 308 >>> # relative importing
307 309 >>> sorted(imported_modules(
308 310 ... 'import foo1; from bar import bar1',
309 311 ... modulename, f, localmods))
310 312 ['foo.bar.bar1', 'foo.foo1']
311 313 >>> sorted(imported_modules(
312 314 ... 'from bar.bar1 import name1, name2, name3',
313 315 ... modulename, f, localmods))
314 316 ['foo.bar.bar1']
315 317 >>> # absolute importing
316 318 >>> sorted(imported_modules(
317 319 ... 'from baz import baz1, name1',
318 320 ... modulename, f, localmods))
319 321 ['baz.__init__', 'baz.baz1']
320 322 >>> # mixed importing, even though it shouldn't be recommended
321 323 >>> sorted(imported_modules(
322 324 ... 'import stdlib, foo1, baz',
323 325 ... modulename, f, localmods))
324 326 ['baz.__init__', 'foo.foo1']
325 327 >>> # ignore_nested
326 328 >>> sorted(imported_modules(
327 329 ... '''import foo
328 330 ... def wat():
329 331 ... import bar
330 332 ... ''', modulename, f, localmods))
331 333 ['foo.__init__', 'foo.bar.__init__']
332 334 >>> sorted(imported_modules(
333 335 ... '''import foo
334 336 ... def wat():
335 337 ... import bar
336 338 ... ''', modulename, f, localmods, ignore_nested=True))
337 339 ['foo.__init__']
338 340 """
339 341 fromlocal = fromlocalfunc(modulename, localmods)
340 342 for node in ast.walk(ast.parse(source, f)):
341 343 if ignore_nested and getattr(node, 'col_offset', 0) > 0:
342 344 continue
343 345 if isinstance(node, ast.Import):
344 346 for n in node.names:
345 347 found = fromlocal(n.name)
346 348 if not found:
347 349 # this should import standard library
348 350 continue
349 351 yield found[1]
350 352 elif isinstance(node, ast.ImportFrom):
351 353 found = fromlocal(node.module, node.level)
352 354 if not found:
353 355 # this should import standard library
354 356 continue
355 357
356 358 absname, dottedpath, hassubmod = found
357 359 if not hassubmod:
358 360 # "dottedpath" is not a package; must be imported
359 361 yield dottedpath
360 362 # examination of "node.names" should be redundant
361 363 # e.g.: from mercurial.node import nullid, nullrev
362 364 continue
363 365
364 366 modnotfound = False
365 367 prefix = absname + '.'
366 368 for n in node.names:
367 369 found = fromlocal(prefix + n.name)
368 370 if not found:
369 371 # this should be a function or a property of "node.module"
370 372 modnotfound = True
371 373 continue
372 374 yield found[1]
373 375 if modnotfound:
374 376 # "dottedpath" is a package, but imported because of non-module
375 377 # lookup
376 378 yield dottedpath
377 379
378 380 def verify_import_convention(module, source, localmods):
379 381 """Verify imports match our established coding convention.
380 382
381 383 We have 2 conventions: legacy and modern. The modern convention is in
382 384 effect when using absolute imports.
383 385
384 386 The legacy convention only looks for mixed imports. The modern convention
385 387 is much more thorough.
386 388 """
387 389 root = ast.parse(source)
388 390 absolute = usingabsolute(root)
389 391
390 392 if absolute:
391 393 return verify_modern_convention(module, root, localmods)
392 394 else:
393 395 return verify_stdlib_on_own_line(root)
394 396
395 397 def verify_modern_convention(module, root, localmods, root_col_offset=0):
396 398 """Verify a file conforms to the modern import convention rules.
397 399
398 400 The rules of the modern convention are:
399 401
400 402 * Ordering is stdlib followed by local imports. Each group is lexically
401 403 sorted.
402 404 * Importing multiple modules via "import X, Y" is not allowed: use
403 405 separate import statements.
404 406 * Importing multiple modules via "from X import ..." is allowed if using
405 407 parenthesis and one entry per line.
406 408 * Only 1 relative import statement per import level ("from .", "from ..")
407 409 is allowed.
408 410 * Relative imports from higher levels must occur before lower levels. e.g.
409 411 "from .." must be before "from .".
410 412 * Imports from peer packages should use relative import (e.g. do not
411 413 "import mercurial.foo" from a "mercurial.*" module).
412 414 * Symbols can only be imported from specific modules (see
413 415 `allowsymbolimports`). For other modules, first import the module then
414 416 assign the symbol to a module-level variable. In addition, these imports
415 417 must be performed before other local imports. This rule only
416 418 applies to import statements outside of any blocks.
417 419 * Relative imports from the standard library are not allowed, unless that
418 420 library is also a local module.
419 421 * Certain modules must be aliased to alternate names to avoid aliasing
420 422 and readability problems. See `requirealias`.
421 423 """
422 424 if not isinstance(module, str):
423 425 module = module.decode('ascii')
424 426 topmodule = module.split('.')[0]
425 427 fromlocal = fromlocalfunc(module, localmods)
426 428
427 429 # Whether a local/non-stdlib import has been performed.
428 430 seenlocal = None
429 431 # Whether a local/non-stdlib, non-symbol import has been seen.
430 432 seennonsymbollocal = False
431 433 # The last name to be imported (for sorting).
432 434 lastname = None
433 435 laststdlib = None
434 436 # Relative import levels encountered so far.
435 437 seenlevels = set()
436 438
437 439 for node, newscope in walklocal(root):
438 440 def msg(fmt, *args):
439 441 return (fmt % args, node.lineno)
440 442 if newscope:
441 443 # Check for local imports in function
442 444 for r in verify_modern_convention(module, node, localmods,
443 445 node.col_offset + 4):
444 446 yield r
445 447 elif isinstance(node, ast.Import):
446 448 # Disallow "import foo, bar" and require separate imports
447 449 # for each module.
448 450 if len(node.names) > 1:
449 451 yield msg('multiple imported names: %s',
450 452 ', '.join(n.name for n in node.names))
451 453
452 454 name = node.names[0].name
453 455 asname = node.names[0].asname
454 456
455 457 stdlib = name in stdlib_modules
456 458
457 459 # Ignore sorting rules on imports inside blocks.
458 460 if node.col_offset == root_col_offset:
459 461 if lastname and name < lastname and laststdlib == stdlib:
460 462 yield msg('imports not lexically sorted: %s < %s',
461 463 name, lastname)
462 464
463 465 lastname = name
464 466 laststdlib = stdlib
465 467
466 468 # stdlib imports should be before local imports.
467 469 if stdlib and seenlocal and node.col_offset == root_col_offset:
468 470 yield msg('stdlib import "%s" follows local import: %s',
469 471 name, seenlocal)
470 472
471 473 if not stdlib:
472 474 seenlocal = name
473 475
474 476 # Import of sibling modules should use relative imports.
475 477 topname = name.split('.')[0]
476 478 if topname == topmodule:
477 479 yield msg('import should be relative: %s', name)
478 480
479 481 if name in requirealias and asname != requirealias[name]:
480 482 yield msg('%s module must be "as" aliased to %s',
481 483 name, requirealias[name])
482 484
483 485 elif isinstance(node, ast.ImportFrom):
484 486 # Resolve the full imported module name.
485 487 if node.level > 0:
486 488 fullname = '.'.join(module.split('.')[:-node.level])
487 489 if node.module:
488 490 fullname += '.%s' % node.module
489 491 else:
490 492 assert node.module
491 493 fullname = node.module
492 494
493 495 topname = fullname.split('.')[0]
494 496 if topname == topmodule:
495 497 yield msg('import should be relative: %s', fullname)
496 498
497 499 # __future__ is special since it needs to come first and use
498 500 # symbol import.
499 501 if fullname != '__future__':
500 502 if not fullname or (
501 503 fullname in stdlib_modules
502 504 and fullname not in localmods
503 505 and fullname + '.__init__' not in localmods):
504 506 yield msg('relative import of stdlib module')
505 507 else:
506 508 seenlocal = fullname
507 509
508 510 # Direct symbol import is only allowed from certain modules and
509 511 # must occur before non-symbol imports.
510 512 found = fromlocal(node.module, node.level)
511 513 if found and found[2]: # node.module is a package
512 514 prefix = found[0] + '.'
513 515 symbols = (n.name for n in node.names
514 516 if not fromlocal(prefix + n.name))
515 517 else:
516 518 symbols = (n.name for n in node.names)
517 519 symbols = [sym for sym in symbols if sym not in directsymbols]
518 520 if node.module and node.col_offset == root_col_offset:
519 521 if symbols and fullname not in allowsymbolimports:
520 522 yield msg('direct symbol import %s from %s',
521 523 ', '.join(symbols), fullname)
522 524
523 525 if symbols and seennonsymbollocal:
524 526 yield msg('symbol import follows non-symbol import: %s',
525 527 fullname)
526 528 if not symbols and fullname not in stdlib_modules:
527 529 seennonsymbollocal = True
528 530
529 531 if not node.module:
530 532 assert node.level
531 533
532 534 # Only allow 1 group per level.
533 535 if (node.level in seenlevels
534 536 and node.col_offset == root_col_offset):
535 537 yield msg('multiple "from %s import" statements',
536 538 '.' * node.level)
537 539
538 540 # Higher-level groups come before lower-level groups.
539 541 if any(node.level > l for l in seenlevels):
540 542 yield msg('higher-level import should come first: %s',
541 543 fullname)
542 544
543 545 seenlevels.add(node.level)
544 546
545 547 # Entries in "from .X import ( ... )" lists must be lexically
546 548 # sorted.
547 549 lastentryname = None
548 550
549 551 for n in node.names:
550 552 if lastentryname and n.name < lastentryname:
551 553 yield msg('imports from %s not lexically sorted: %s < %s',
552 554 fullname, n.name, lastentryname)
553 555
554 556 lastentryname = n.name
555 557
556 558 if n.name in requirealias and n.asname != requirealias[n.name]:
557 559 yield msg('%s from %s must be "as" aliased to %s',
558 560 n.name, fullname, requirealias[n.name])
559 561
560 562 def verify_stdlib_on_own_line(root):
561 563 """Given some python source, verify that stdlib imports are done
562 564 in separate statements from relative local module imports.
563 565
564 566 >>> list(verify_stdlib_on_own_line(ast.parse('import sys, foo')))
565 567 [('mixed imports\\n stdlib: sys\\n relative: foo', 1)]
566 568 >>> list(verify_stdlib_on_own_line(ast.parse('import sys, os')))
567 569 []
568 570 >>> list(verify_stdlib_on_own_line(ast.parse('import foo, bar')))
569 571 []
570 572 """
571 573 for node in ast.walk(root):
572 574 if isinstance(node, ast.Import):
573 575 from_stdlib = {False: [], True: []}
574 576 for n in node.names:
575 577 from_stdlib[n.name in stdlib_modules].append(n.name)
576 578 if from_stdlib[True] and from_stdlib[False]:
577 579 yield ('mixed imports\n stdlib: %s\n relative: %s' %
578 580 (', '.join(sorted(from_stdlib[True])),
579 581 ', '.join(sorted(from_stdlib[False]))), node.lineno)
580 582
581 583 class CircularImport(Exception):
582 584 pass
583 585
584 586 def checkmod(mod, imports):
585 587 shortest = {}
586 588 visit = [[mod]]
587 589 while visit:
588 590 path = visit.pop(0)
589 591 for i in sorted(imports.get(path[-1], [])):
590 592 if len(path) < shortest.get(i, 1000):
591 593 shortest[i] = len(path)
592 594 if i in path:
593 595 if i == path[0]:
594 596 raise CircularImport(path)
595 597 continue
596 598 visit.append(path + [i])
597 599
598 600 def rotatecycle(cycle):
599 601 """arrange a cycle so that the lexicographically first module listed first
600 602
601 603 >>> rotatecycle(['foo', 'bar'])
602 604 ['bar', 'foo', 'bar']
603 605 """
604 606 lowest = min(cycle)
605 607 idx = cycle.index(lowest)
606 608 return cycle[idx:] + cycle[:idx] + [lowest]
607 609
608 610 def find_cycles(imports):
609 611 """Find cycles in an already-loaded import graph.
610 612
611 613 All module names recorded in `imports` should be absolute one.
612 614
613 615 >>> from __future__ import print_function
614 616 >>> imports = {'top.foo': ['top.bar', 'os.path', 'top.qux'],
615 617 ... 'top.bar': ['top.baz', 'sys'],
616 618 ... 'top.baz': ['top.foo'],
617 619 ... 'top.qux': ['top.foo']}
618 620 >>> print('\\n'.join(sorted(find_cycles(imports))))
619 621 top.bar -> top.baz -> top.foo -> top.bar
620 622 top.foo -> top.qux -> top.foo
621 623 """
622 624 cycles = set()
623 625 for mod in sorted(imports.keys()):
624 626 try:
625 627 checkmod(mod, imports)
626 628 except CircularImport as e:
627 629 cycle = e.args[0]
628 630 cycles.add(" -> ".join(rotatecycle(cycle)))
629 631 return cycles
630 632
631 633 def _cycle_sortkey(c):
632 634 return len(c), c
633 635
634 636 def embedded(f, modname, src):
635 637 """Extract embedded python code
636 638
637 639 >>> def _forcestr(thing):
638 640 ... if not isinstance(thing, str):
639 641 ... return thing.decode('ascii')
640 642 ... return thing
641 643 >>> def test(fn, lines):
642 644 ... for s, m, f, l in embedded(fn, b"example", lines):
643 645 ... print("%s %s %d" % (_forcestr(m), _forcestr(f), l))
644 646 ... print(repr(_forcestr(s)))
645 647 >>> lines = [
646 648 ... b'comment',
647 649 ... b' >>> from __future__ import print_function',
648 650 ... b" >>> ' multiline",
649 651 ... b" ... string'",
650 652 ... b' ',
651 653 ... b'comment',
652 654 ... b' $ cat > foo.py <<EOF',
653 655 ... b' > from __future__ import print_function',
654 656 ... b' > EOF',
655 657 ... ]
656 658 >>> test(b"example.t", lines)
657 659 example[2] doctest.py 2
658 660 "from __future__ import print_function\\n' multiline\\nstring'\\n"
659 661 example[7] foo.py 7
660 662 'from __future__ import print_function\\n'
661 663 """
662 664 inlinepython = 0
663 665 shpython = 0
664 666 script = []
665 667 prefix = 6
666 668 t = ''
667 669 n = 0
668 670 for l in src:
669 671 n += 1
670 672 if not l.endswith(b'\n'):
671 673 l += b'\n'
672 674 if l.startswith(b' >>> '): # python inlines
673 675 if shpython:
674 676 print("%s:%d: Parse Error" % (f, n))
675 677 if not inlinepython:
676 678 # We've just entered a Python block.
677 679 inlinepython = n
678 680 t = b'doctest.py'
679 681 script.append(l[prefix:])
680 682 continue
681 683 if l.startswith(b' ... '): # python inlines
682 684 script.append(l[prefix:])
683 685 continue
684 686 cat = re.search(br"\$ \s*cat\s*>\s*(\S+\.py)\s*<<\s*EOF", l)
685 687 if cat:
686 688 if inlinepython:
687 689 yield b''.join(script), (b"%s[%d]" %
688 690 (modname, inlinepython)), t, inlinepython
689 691 script = []
690 692 inlinepython = 0
691 693 shpython = n
692 694 t = cat.group(1)
693 695 continue
694 696 if shpython and l.startswith(b' > '): # sh continuation
695 697 if l == b' > EOF\n':
696 698 yield b''.join(script), (b"%s[%d]" %
697 699 (modname, shpython)), t, shpython
698 700 script = []
699 701 shpython = 0
700 702 else:
701 703 script.append(l[4:])
702 704 continue
703 705 # If we have an empty line or a command for sh, we end the
704 706 # inline script.
705 707 if inlinepython and (l == b' \n'
706 708 or l.startswith(b' $ ')):
707 709 yield b''.join(script), (b"%s[%d]" %
708 710 (modname, inlinepython)), t, inlinepython
709 711 script = []
710 712 inlinepython = 0
711 713 continue
712 714
713 715 def sources(f, modname):
714 716 """Yields possibly multiple sources from a filepath
715 717
716 718 input: filepath, modulename
717 719 yields: script(string), modulename, filepath, linenumber
718 720
719 721 For embedded scripts, the modulename and filepath will be different
720 722 from the function arguments. linenumber is an offset relative to
721 723 the input file.
722 724 """
723 725 py = False
724 726 if not f.endswith('.t'):
725 727 with open(f, 'rb') as src:
726 728 yield src.read(), modname, f, 0
727 729 py = True
728 730 if py or f.endswith('.t'):
729 731 with open(f, 'rb') as src:
730 732 for script, modname, t, line in embedded(f, modname, src):
731 733 yield script, modname, t, line
732 734
733 735 def main(argv):
734 736 if len(argv) < 2 or (argv[1] == '-' and len(argv) > 2):
735 737 print('Usage: %s {-|file [file] [file] ...}')
736 738 return 1
737 739 if argv[1] == '-':
738 740 argv = argv[:1]
739 741 argv.extend(l.rstrip() for l in sys.stdin.readlines())
740 742 localmodpaths = {}
741 743 used_imports = {}
742 744 any_errors = False
743 745 for source_path in argv[1:]:
744 746 modname = dotted_name_of_path(source_path)
745 747 localmodpaths[modname] = source_path
746 748 localmods = populateextmods(localmodpaths)
747 749 for localmodname, source_path in sorted(localmodpaths.items()):
748 750 if not isinstance(localmodname, bytes):
749 751 # This is only safe because all hg's files are ascii
750 752 localmodname = localmodname.encode('ascii')
751 753 for src, modname, name, line in sources(source_path, localmodname):
752 754 try:
753 755 used_imports[modname] = sorted(
754 756 imported_modules(src, modname, name, localmods,
755 757 ignore_nested=True))
756 758 for error, lineno in verify_import_convention(modname, src,
757 759 localmods):
758 760 any_errors = True
759 761 print('%s:%d: %s' % (source_path, lineno + line, error))
760 762 except SyntaxError as e:
761 763 print('%s:%d: SyntaxError: %s' %
762 764 (source_path, e.lineno + line, e))
763 765 cycles = find_cycles(used_imports)
764 766 if cycles:
765 767 firstmods = set()
766 768 for c in sorted(cycles, key=_cycle_sortkey):
767 769 first = c.split()[0]
768 770 # As a rough cut, ignore any cycle that starts with the
769 771 # same module as some other cycle. Otherwise we see lots
770 772 # of cycles that are effectively duplicates.
771 773 if first in firstmods:
772 774 continue
773 775 print('Import cycle:', c)
774 776 firstmods.add(first)
775 777 any_errors = True
776 778 return any_errors != 0
777 779
778 780 if __name__ == '__main__':
779 781 sys.exit(int(main(sys.argv)))
General Comments 0
You need to be logged in to leave comments. Login now