##// END OF EJS Templates
cborutil: implement support for streaming encoding, bytestring decoding...
Gregory Szorc -
r37729:65a23cc8 default
parent child Browse files
Show More
@@ -0,0 +1,258 b''
1 # cborutil.py - CBOR extensions
2 #
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 #
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
7
8 from __future__ import absolute_import
9
10 import struct
11
12 from ..thirdparty.cbor.cbor2 import (
13 decoder as decodermod,
14 )
15
16 # Very short version of RFC 7049...
17 #
18 # Each item begins with a byte. The 3 high bits of that byte denote the
19 # "major type." The lower 5 bits denote the "subtype." Each major type
20 # has its own encoding mechanism.
21 #
22 # Most types have lengths. However, bytestring, string, array, and map
23 # can be indefinite length. These are denoted by a subtype with value 31.
24 # Sub-components of those types then come afterwards and are terminated
25 # by a "break" byte.
26
# Major types live in the 3 high bits of an item's initial byte. They are
# numbered consecutively from 0 (unsigned integer) to 7 (special).
(MAJOR_TYPE_UINT,
 MAJOR_TYPE_NEGINT,
 MAJOR_TYPE_BYTESTRING,
 MAJOR_TYPE_STRING,
 MAJOR_TYPE_ARRAY,
 MAJOR_TYPE_MAP,
 MAJOR_TYPE_SEMANTIC,
 MAJOR_TYPE_SPECIAL) = range(8)

# Selects the 5 low bits (the "subtype") of an item's initial byte.
SUBTYPE_MASK = 0x1f

SUBTYPE_HALF_FLOAT = 25
SUBTYPE_SINGLE_FLOAT = 26
SUBTYPE_DOUBLE_FLOAT = 27
SUBTYPE_INDEFINITE = 31

# The header of an indefinite length item is a single byte: the major type
# in the high bits ORd with subtype 31.
BEGIN_INDEFINITE_BYTESTRING = struct.pack(
    r'>B', (MAJOR_TYPE_BYTESTRING << 5) | SUBTYPE_INDEFINITE)
BEGIN_INDEFINITE_ARRAY = struct.pack(
    r'>B', (MAJOR_TYPE_ARRAY << 5) | SUBTYPE_INDEFINITE)
BEGIN_INDEFINITE_MAP = struct.pack(
    r'>B', (MAJOR_TYPE_MAP << 5) | SUBTYPE_INDEFINITE)

# Pre-compiled big-endian packers: an initial byte followed by nothing, or
# by a 1, 2, 4 or 8 byte unsigned integer.
ENCODED_LENGTH_1 = struct.Struct(r'>B')
ENCODED_LENGTH_2 = struct.Struct(r'>BB')
ENCODED_LENGTH_3 = struct.Struct(r'>BH')
ENCODED_LENGTH_4 = struct.Struct(r'>BL')
ENCODED_LENGTH_5 = struct.Struct(r'>BQ')

# The break byte terminates an indefinite length item. It is kept both as
# bytes (for emitting) and as an integer (for comparing a decoded byte).
BREAK = b'\xff'
BREAK_INT = 255
60
def encodelength(majortype, length):
    """Obtain a value encoding the major type and its length.

    ``majortype`` is one of the ``MAJOR_TYPE_*`` values. ``length`` is the
    unsigned integer that accompanies it. The smallest of the five
    encodings (inline, or a trailing 1, 2, 4 or 8 byte big-endian integer)
    that can hold ``length`` is used.

    Returns the packed header bytes.

    Raises ``ValueError`` if ``length`` is negative or does not fit in 64
    bits, instead of leaking an opaque ``struct.error`` from the packer.
    """
    if length < 0:
        raise ValueError('length must not be negative: %d' % length)

    initial = majortype << 5

    if length < 24:
        # Small values are stored directly in the subtype bits.
        return ENCODED_LENGTH_1.pack(initial | length)
    elif length < 256:
        return ENCODED_LENGTH_2.pack(initial | 24, length)
    elif length < 65536:
        return ENCODED_LENGTH_3.pack(initial | 25, length)
    elif length < 4294967296:
        return ENCODED_LENGTH_4.pack(initial | 26, length)
    elif length < 18446744073709551616:
        return ENCODED_LENGTH_5.pack(initial | 27, length)

    raise ValueError('length does not fit in 64 bits: %d' % length)
73
def streamencodebytestring(v):
    """Encode ``v`` as a definite length bytestring.

    Yields the length header followed by ``v`` itself, unmodified.
    """
    header = encodelength(MAJOR_TYPE_BYTESTRING, len(v))
    yield header
    yield v
77
def streamencodebytestringfromiter(it):
    """Emit the chunks produced by ``it`` as one indefinite length bytestring.

    Each element of ``it`` must have a length and be representable as
    bytes. The output is the indefinite bytestring header, then every
    element as its own definite length bytestring, then the break byte.
    """
    yield BEGIN_INDEFINITE_BYTESTRING

    for piece in it:
        header = encodelength(MAJOR_TYPE_BYTESTRING, len(piece))
        yield header
        yield piece

    yield BREAK
91
def streamencodeindefinitebytestring(source, chunksize=65536):
    """Given a large source buffer, emit as an indefinite length bytestring.

    This is a generator of chunks constituting the encoded CBOR data.

    ``source`` is sliced into pieces of at most ``chunksize`` bytes and each
    piece is emitted as its own definite length bytestring. An empty
    ``source`` still produces a single empty chunk.

    Raises ``ValueError`` if ``chunksize`` is not positive. Because this is
    a generator, the error surfaces when iteration starts.
    """
    # A chunk size of zero (or less) would never advance through a non-empty
    # source and the loop below would spin forever.
    if chunksize < 1:
        raise ValueError('chunksize must be a positive integer')

    yield BEGIN_INDEFINITE_BYTESTRING

    offset = 0
    total = len(source)

    while True:
        chunk = source[offset:offset + chunksize]
        offset += len(chunk)

        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
        yield chunk

        if offset >= total:
            break

    yield BREAK
113
def streamencodeint(v):
    """Encode the integer ``v`` as a CBOR unsigned or negative integer.

    Raises ``ValueError`` for values outside ``[-2**64, 2**64)``.
    """
    limit = 18446744073709551616

    if not -limit <= v < limit:
        raise ValueError('big integers not supported')

    if v < 0:
        # The negative integer major type stores -1 - v.
        yield encodelength(MAJOR_TYPE_NEGINT, -1 - v)
    else:
        yield encodelength(MAJOR_TYPE_UINT, v)
122
def streamencodearray(l):
    """Encode a sized iterable as a definite length array.

    The array header carries ``len(l)``; every member is then encoded
    recursively through ``streamencode()``.
    """
    header = encodelength(MAJOR_TYPE_ARRAY, len(l))
    yield header

    for member in l:
        for encoded in streamencode(member):
            yield encoded
131
def streamencodearrayfromiter(it):
    """Encode the items produced by ``it`` as an indefinite length array.

    Emits the indefinite array header, each item encoded through
    ``streamencode()``, and finally the break byte.
    """
    yield BEGIN_INDEFINITE_ARRAY

    for member in it:
        for encoded in streamencode(member):
            yield encoded

    yield BREAK
142
def streamencodeset(s):
    """Encode the set ``s`` as semantic tag 258 wrapping a sorted array.

    Members are sorted so the output is deterministic. They are grouped by
    type name before their values are compared, so a set mixing types that
    Python 3 refuses to order against each other (e.g. bytes, int and None)
    can still be encoded.
    """
    # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
    # semantic tag 258 for finite sets.
    yield encodelength(MAJOR_TYPE_SEMANTIC, 258)

    members = sorted(s, key=lambda v: (type(v).__name__, v))

    for chunk in streamencodearray(members):
        yield chunk
150
def streamencodemap(d):
    """Encode dictionary to a generator.

    Does not support indefinite length dictionaries.

    Entries are emitted in sorted key order so the output is deterministic.
    Only keys take part in the ordering, and they are grouped by type name
    before their values are compared. Neither unorderable values nor keys of
    mixed types (e.g. bytes and int) make the sort fail on Python 3.
    """
    yield encodelength(MAJOR_TYPE_MAP, len(d))

    # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    entries = sorted(d.items(),
                     key=lambda kv: (type(kv[0]).__name__, kv[0]))

    for key, value in entries:
        for chunk in streamencode(key):
            yield chunk
        for chunk in streamencode(value):
            yield chunk
163
def streamencodemapfromiter(it):
    """Encode an iterable of (key, value) pairs as an indefinite length map.

    Pairs are emitted in the order ``it`` produces them; no sorting is
    performed.
    """
    yield BEGIN_INDEFINITE_MAP

    for key, value in it:
        for member in (key, value):
            for encoded in streamencode(member):
                yield encoded

    yield BREAK
175
def streamencodebool(b):
    """Encode a boolean as major type 7, simple value 21 (true) or 20 (false)."""
    if b:
        yield b'\xf5'
    else:
        yield b'\xf4'
179
def streamencodenone(v):
    """Encode ``None`` as major type 7, simple value 22.

    ``v`` is accepted only so the signature matches the other encoders; its
    value is not inspected.
    """
    encoded = b'\xf6'
    yield encoded
183
# Maps the exact class of a value to the generator function that encodes it.
# streamencode() looks values up by ``v.__class__``, so subclasses of these
# types are not matched.
STREAM_ENCODERS = {
    # Scalars.
    type(None): streamencodenone,
    bool: streamencodebool,
    int: streamencodeint,
    bytes: streamencodebytestring,
    # Collections.
    list: streamencodearray,
    tuple: streamencodearray,
    set: streamencodeset,
    dict: streamencodemap,
}
194
def streamencode(v):
    """Encode a value in a streaming manner.

    The encoder is chosen from ``STREAM_ENCODERS`` by the exact class of
    ``v`` and applied recursively to nested values.

    Returns a generator of CBOR encoded bytes. There is no guarantee
    that each emitted chunk fully decodes to a value or sub-value.

    Encoding is deterministic - unordered collections are sorted.

    Raises ``ValueError`` when no encoder is registered for ``v``'s class.
    """
    encoder = STREAM_ENCODERS.get(v.__class__)

    if encoder:
        return encoder(v)

    raise ValueError('do not know how to encode %s' % type(v))
211
def readindefinitebytestringtoiter(fh, expectheader=True):
    """Read an indefinite bytestring to a generator.

    Receives an object with a ``read(X)`` method to read N bytes.

    If ``expectheader`` is True, it is expected that the first byte read
    will represent an indefinite length bytestring. Otherwise, we
    expect the first byte to be part of the first bytestring chunk.

    Yields the payload of each chunk until the break byte is read.

    Raises ``CBORDecodeError`` if the header is not an indefinite length
    bytestring, if a chunk is not itself a bytestring, if a chunk is
    shorter than its declared length, or if the stream ends before the
    break byte.
    """
    read = fh.read
    decodeuint = decodermod.decode_uint
    byteasinteger = decodermod.byte_as_integer
    decodeerror = decodermod.CBORDecodeError

    def readinitial():
        # Report end of input as a decode error rather than handing an
        # empty string to byte_as_integer().
        raw = read(1)
        if not raw:
            raise decodeerror(
                'premature end of stream reading indefinite bytestring')
        return byteasinteger(raw)

    if expectheader:
        initial = readinitial()

        majortype = initial >> 5
        subtype = initial & SUBTYPE_MASK

        if majortype != MAJOR_TYPE_BYTESTRING:
            raise decodeerror(
                'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
                                                    majortype))

        if subtype != SUBTYPE_INDEFINITE:
            raise decodeerror(
                'expected indefinite subtype; got %d' % subtype)

    # The indefinite bytestring is composed of chunks of normal bytestrings.
    # Read chunks until we hit a BREAK byte.

    while True:
        # We need to sniff for the BREAK byte.
        initial = readinitial()

        if initial == BREAK_INT:
            break

        # Every chunk must itself be a bytestring. Without this check the
        # length of e.g. an array or a string would be misread as the size
        # of a bytestring chunk.
        majortype = initial >> 5
        if majortype != MAJOR_TYPE_BYTESTRING:
            raise decodeerror(
                'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
                                                    majortype))

        length = decodeuint(fh, initial & SUBTYPE_MASK)
        chunk = read(length)

        if len(chunk) != length:
            raise decodeerror(
                'failed to read bytestring chunk: got %d bytes; expected %d' % (
                    len(chunk), length))

        yield chunk
@@ -0,0 +1,210 b''
1 from __future__ import absolute_import
2
3 import io
4 import unittest
5
6 from mercurial.thirdparty import (
7 cbor,
8 )
9 from mercurial.utils import (
10 cborutil,
11 )
12
def loadit(it):
    """Join the chunks produced by ``it`` and decode them as one CBOR value."""
    encoded = b''.join(it)
    return cbor.loads(encoded)
15
class BytestringTests(unittest.TestCase):
    """Tests for encoding and reading definite/indefinite bytestrings."""

    def testsimple(self):
        self.assertEqual(
            list(cborutil.streamencode(b'foobar')),
            [b'\x46', b'foobar'])

        self.assertEqual(
            loadit(cborutil.streamencode(b'foobar')),
            b'foobar')

    def testlong(self):
        source = b'x' * 1048576

        self.assertEqual(loadit(cborutil.streamencode(source)), source)

    def testfromiter(self):
        # This is the example from RFC 7049 Section 2.2.2.
        source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']

        self.assertEqual(
            list(cborutil.streamencodebytestringfromiter(source)),
            [
                b'\x5f',
                b'\x44',
                b'\xaa\xbb\xcc\xdd',
                b'\x43',
                b'\xee\xff\x99',
                b'\xff',
            ])

        self.assertEqual(
            loadit(cborutil.streamencodebytestringfromiter(source)),
            b''.join(source))

    def testfromiterlarge(self):
        source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]

        self.assertEqual(
            loadit(cborutil.streamencodebytestringfromiter(source)),
            b''.join(source))

    def testindefinite(self):
        source = b'\x00\x01\x02\x03' + b'\xff' * 16384

        it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)

        self.assertEqual(next(it), b'\x5f')
        self.assertEqual(next(it), b'\x42')
        self.assertEqual(next(it), b'\x00\x01')
        self.assertEqual(next(it), b'\x42')
        self.assertEqual(next(it), b'\x02\x03')
        self.assertEqual(next(it), b'\x42')
        self.assertEqual(next(it), b'\xff\xff')

        dest = b''.join(cborutil.streamencodeindefinitebytestring(
            source, chunksize=42))
        # ``source`` is already a bytes object. Joining it would iterate
        # integers on Python 3 and raise TypeError, so compare it directly.
        self.assertEqual(cbor.loads(dest), source)

    def testreadtoiter(self):
        source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')

        it = cborutil.readindefinitebytestringtoiter(source)
        self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd')
        self.assertEqual(next(it), b'\xee\xff\x99')

        with self.assertRaises(StopIteration):
            next(it)
83
class IntTests(unittest.TestCase):
    """Tests for integer encoding."""

    def testsmall(self):
        # Small non-negative integers are stored inline in the initial byte.
        expectations = [
            (0, b'\x00'),
            (1, b'\x01'),
            (2, b'\x02'),
            (3, b'\x03'),
            (4, b'\x04'),
        ]

        for value, encoded in expectations:
            self.assertEqual(list(cborutil.streamencode(value)), [encoded])

    def testnegativesmall(self):
        # Negative integers use major type 1 and store -1 - value.
        expectations = [
            (-1, b'\x20'),
            (-2, b'\x21'),
            (-3, b'\x22'),
            (-4, b'\x23'),
            (-5, b'\x24'),
        ]

        for value, encoded in expectations:
            self.assertEqual(list(cborutil.streamencode(value)), [encoded])

    def testrange(self):
        # Cross-check against the reference encoder over a span that covers
        # the inline, 1 byte, 2 byte and 4 byte length encodings.
        for value in range(-70000, 70000, 10):
            ours = b''.join(cborutil.streamencode(value))
            self.assertEqual(ours, cbor.dumps(value))
104
class ArrayTests(unittest.TestCase):
    """Tests for definite and indefinite length array encoding."""

    def testempty(self):
        chunks = list(cborutil.streamencode([]))
        self.assertEqual(chunks, [b'\x80'])

        decoded = loadit(cborutil.streamencode([]))
        self.assertEqual(decoded, [])

    def testbasic(self):
        source = [b'foo', b'bar', 1, -10]
        expected = [
            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']

        self.assertEqual(list(cborutil.streamencode(source)), expected)

    def testemptyfromiter(self):
        encoded = b''.join(cborutil.streamencodearrayfromiter([]))
        self.assertEqual(encoded, b'\x9f\xff')

    def testfromiter1(self):
        source = [b'foo']
        expected = [
            b'\x9f',
            b'\x43', b'foo',
            b'\xff',
        ]

        self.assertEqual(
            list(cborutil.streamencodearrayfromiter(source)), expected)

        dest = b''.join(cborutil.streamencodearrayfromiter(source))
        self.assertEqual(cbor.loads(dest), source)

    def testtuple(self):
        source = (b'foo', None, 42)
        encoded = b''.join(cborutil.streamencode(source))

        # Tuples are encoded as arrays and therefore decode to lists.
        self.assertEqual(cbor.loads(encoded), list(source))
137
class SetTests(unittest.TestCase):
    """Tests for set encoding (semantic tag 258 wrapping an array)."""

    def testempty(self):
        chunks = list(cborutil.streamencode(set()))
        expected = [
            b'\xd9\x01\x02',
            b'\x80',
        ]

        self.assertEqual(chunks, expected)

    def testset(self):
        source = {b'foo', None, 42}
        encoded = b''.join(cborutil.streamencode(source))

        self.assertEqual(cbor.loads(encoded), source)
150
class BoolTests(unittest.TestCase):
    """Tests for boolean encoding."""

    def testbasic(self):
        trueencoded = list(cborutil.streamencode(True))
        falseencoded = list(cborutil.streamencode(False))
        self.assertEqual(trueencoded, [b'\xf5'])
        self.assertEqual(falseencoded, [b'\xf4'])

        # Round-tripping must give back the bool singletons themselves.
        self.assertIs(loadit(cborutil.streamencode(True)), True)
        self.assertIs(loadit(cborutil.streamencode(False)), False)
158
class NoneTests(unittest.TestCase):
    """Tests for None encoding."""

    def testbasic(self):
        chunks = list(cborutil.streamencode(None))
        self.assertEqual(chunks, [b'\xf6'])

        decoded = loadit(cborutil.streamencode(None))
        self.assertIs(decoded, None)
164
class MapTests(unittest.TestCase):
    """Tests for definite and indefinite length map encoding."""

    def testempty(self):
        chunks = list(cborutil.streamencode({}))
        self.assertEqual(chunks, [b'\xa0'])

        decoded = loadit(cborutil.streamencode({}))
        self.assertEqual(decoded, {})

    def testemptyindefinite(self):
        chunks = list(cborutil.streamencodemapfromiter([]))
        self.assertEqual(chunks, [b'\xbf', b'\xff'])

        decoded = loadit(cborutil.streamencodemapfromiter([]))
        self.assertEqual(decoded, {})

    def testone(self):
        source = {b'foo': b'bar'}
        expected = [b'\xa1', b'\x43', b'foo', b'\x43', b'bar']

        self.assertEqual(list(cborutil.streamencode(source)), expected)
        self.assertEqual(loadit(cborutil.streamencode(source)), source)

    def testmultiple(self):
        source = {
            b'foo': b'bar',
            b'baz': b'value1',
        }

        definite = loadit(cborutil.streamencode(source))
        self.assertEqual(definite, source)

        indefinite = loadit(
            cborutil.streamencodemapfromiter(source.items()))
        self.assertEqual(indefinite, source)

    def testcomplex(self):
        # Keys of different types in the same map.
        source = {
            b'key': 1,
            2: -10,
        }

        definite = loadit(cborutil.streamencode(source))
        self.assertEqual(definite, source)

        indefinite = loadit(
            cborutil.streamencodemapfromiter(source.items()))
        self.assertEqual(indefinite, source)
207
if __name__ == '__main__':
    # Imported here so that merely importing this module does not require
    # the test runner to be importable.
    import silenttestrunner
    silenttestrunner.main(__name__)
@@ -1,779 +1,781 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2
2
3 from __future__ import absolute_import, print_function
3 from __future__ import absolute_import, print_function
4
4
5 import ast
5 import ast
6 import collections
6 import collections
7 import os
7 import os
8 import re
8 import re
9 import sys
9 import sys
10
10
11 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
11 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
12 # to work when run from a virtualenv. The modules were chosen empirically
12 # to work when run from a virtualenv. The modules were chosen empirically
13 # so that the return value matches the return value without virtualenv.
13 # so that the return value matches the return value without virtualenv.
14 if True: # disable lexical sorting checks
14 if True: # disable lexical sorting checks
15 try:
15 try:
16 import BaseHTTPServer as basehttpserver
16 import BaseHTTPServer as basehttpserver
17 except ImportError:
17 except ImportError:
18 basehttpserver = None
18 basehttpserver = None
19 import zlib
19 import zlib
20
20
# Modules from which individual symbols may be imported directly
# (``from <module> import <symbol>``).
allowsymbolimports = (
    '__future__',
    'bzrlib',
    'hgclient',
    'mercurial',
    'mercurial.hgweb.common',
    'mercurial.hgweb.request',
    'mercurial.i18n',
    'mercurial.node',
    # for cffi modules to re-export pure functions
    'mercurial.pure.base85',
    'mercurial.pure.bdiff',
    'mercurial.pure.mpatch',
    'mercurial.pure.osutil',
    'mercurial.pure.parsers',
    # third-party imports should be directly imported
    'mercurial.thirdparty',
    'mercurial.thirdparty.cbor',
    'mercurial.thirdparty.cbor.cbor2',
    'mercurial.thirdparty.zope',
    'mercurial.thirdparty.zope.interface',
)

# Symbols that may be imported directly.
directsymbols = (
    'demandimport',
)

# Modules that must be imported under an alias: their names are commonly
# used for variables too, which creates aliasing and readability issues.
requirealias = {
    'ui': 'uimod',
}
53
55
def usingabsolute(root):
    """Report whether the module rooted at ``root`` uses absolute imports.

    Always True on Python 3. On Python 2 the AST is searched for a
    ``from __future__ import absolute_import`` statement.
    """
    if sys.version_info[0] >= 3:
        return True

    for node in ast.walk(root):
        if not isinstance(node, ast.ImportFrom):
            continue
        if node.module != '__future__':
            continue
        if any(alias.name == 'absolute_import' for alias in node.names):
            return True

    return False
67
69
def walklocal(root):
    """Yield ``(node, newscope)`` for ``root`` and its same-scope descendants.

    Traversal is breadth first. A function definition is yielded with
    ``newscope`` set to True, but its body is not descended into. ``root``
    itself is always yielded first with ``newscope`` False.
    """
    pending = collections.deque(ast.iter_child_nodes(root))
    yield root, False

    while pending:
        current = pending.popleft()
        opensscope = isinstance(current, ast.FunctionDef)
        if not opensscope:
            pending.extend(ast.iter_child_nodes(current))
        yield current, opensscope
78
80
def dotted_name_of_path(path):
    """Given a relative path to a source file, return its dotted module name.

    Everything from the first dot of the file name is dropped, as is a
    trailing ``module`` on what remains.

    >>> dotted_name_of_path('mercurial/error.py')
    'mercurial.error'
    >>> dotted_name_of_path('zlibmodule.so')
    'zlib'
    """
    components = path.replace(os.sep, '/').split('/')

    # remove .py and .so and .ARCH.so
    leaf = components[-1].split('.', 1)[0]
    if leaf.endswith('module'):
        leaf = leaf[:-6]
    components[-1] = leaf

    return '.'.join(components)
92
94
def fromlocalfunc(modulename, localmods):
    """Build a resolver mapping an imported name to a locally defined module.

    `modulename` is the `dotted_name_of_path()`-ed source file path of the
    importing source; it may end in `.__init__`.

    `localmods` is a set of absolute `dotted_name_of_path()`-ed source file
    paths of locally defined (= Mercurial specific) modules.

    Module names that are not in `localmods` are assumed to come from the
    Python standard library.

    The returned function takes `name` (which must not end in `.__init__`)
    and an optional relative import `level`. It returns False when `name`
    does not match a local module, and otherwise a tuple
    `(absname, dottedpath, hassubmod)`:

    `absname` is the absolute module name of `name` (e.g. "hgext.convert").
    It can be used to compose the prefix for sub modules.

    `dottedpath` is the `dotted_name_of_path()`-ed source file path of
    `name` (e.g. "hgext.convert.__init__"). It is used to look the module
    up in `localmods` again.

    `hassubmod` tells whether the module may have sub modules under it
    (a convenience; it is equivalent to "absname != dottedpath").

    >>> localmods = {'foo.__init__', 'foo.foo1',
    ...              'foo.bar.__init__', 'foo.bar.bar1',
    ...              'baz.__init__', 'baz.baz1'}
    >>> fromlocal = fromlocalfunc('foo.xxx', localmods)
    >>> # relative
    >>> fromlocal('foo1')
    ('foo.foo1', 'foo.foo1', False)
    >>> fromlocal('bar')
    ('foo.bar', 'foo.bar.__init__', True)
    >>> fromlocal('bar.bar1')
    ('foo.bar.bar1', 'foo.bar.bar1', False)
    >>> # absolute
    >>> fromlocal('baz')
    ('baz', 'baz.__init__', True)
    >>> fromlocal('baz.baz1')
    ('baz.baz1', 'baz.baz1', False)
    >>> # unknown = maybe standard library
    >>> fromlocal('os')
    False
    >>> fromlocal(None, 1)
    ('foo', 'foo.__init__', True)
    >>> fromlocal('foo1', 1)
    ('foo.foo1', 'foo.foo1', False)
    >>> fromlocal2 = fromlocalfunc('foo.xxx.yyy', localmods)
    >>> fromlocal2(None, 2)
    ('foo', 'foo.__init__', True)
    >>> fromlocal2('bar2', 1)
    False
    >>> fromlocal2('bar', 2)
    ('foo.bar', 'foo.bar.__init__', True)
    """
    if not isinstance(modulename, str):
        modulename = modulename.decode('ascii')

    modparts = modulename.split('.')

    # Package containing the importing module, with a trailing dot so an
    # implicitly relative name can simply be appended.
    prefix = '.'.join(modparts[:-1])
    if prefix:
        prefix += '.'

    def fromlocal(name, level=0):
        if not name:
            # name is false value when relative imports are used, in which
            # case level must not be absolute.
            assert level > 0
            candidates = ['.'.join(modparts[:-level])]
        elif not level:
            # Check relative name first.
            candidates = [prefix + name, name]
        else:
            candidates = ['.'.join(modparts[:-level]) + '.' + name]

        for candidate in candidates:
            if candidate in localmods:
                return (candidate, candidate, False)
            initpath = candidate + '.__init__'
            if initpath in localmods:
                return (candidate, initpath, True)
        return False

    return fromlocal
183
185
def populateextmods(localmods):
    """Return ``localmods`` extended with C extension module names.

    For every ``mercurial.pure.<name>`` entry, ``mercurial.cext.<name>`` and
    ``mercurial.cffi._<name>`` are added. The input is not modified.
    """
    pureprefix = 'mercurial.pure.'
    extended = set(localmods)

    for modname in localmods:
        if not modname.startswith(pureprefix):
            continue
        suffix = modname[len(pureprefix):]
        extended.add('mercurial.cext.' + suffix)
        extended.add('mercurial.cffi._' + suffix)

    return extended
193
195
def list_stdlib_modules():
    """List the modules present in the stdlib.

    >>> py3 = sys.version_info[0] >= 3
    >>> mods = set(list_stdlib_modules())
    >>> 'BaseHTTPServer' in mods or py3
    True

    os.path isn't really a module, so it's missing:

    >>> 'os.path' in mods
    False

    sys requires special treatment, because it's baked into the
    interpreter, but it should still appear:

    >>> 'sys' in mods
    True

    >>> 'collections' in mods
    True

    >>> 'cStringIO' in mods or py3
    True

    >>> 'cffi' in mods
    True
    """
    for name in sys.builtin_module_names:
        yield name

    # Names that are always treated as stdlib, whatever platform or
    # interpreter this runs on.
    alwaysstdlib = (
        # These modules only exist on windows.
        'msvcrt', '_winreg',
        '__builtin__',
        # python3 only
        'builtins',
        'importlib.abc',
        'importlib.machinery',
        'importlib.util',
        # Unix only
        'fcntl', 'grp', 'pwd', 'termios',
        # in Python (not C) on PyPy
        'cPickle', 'datetime',
        'cffi',
    )
    for name in alwaysstdlib:
        yield name

    stdlib_prefixes = {sys.prefix, sys.exec_prefix}
    # We need to supplement the list of prefixes for the search to work
    # when run from within a virtualenv.
    for mod in (basehttpserver, zlib):
        if mod is None:
            continue
        try:
            # Not all module objects have a __file__ attribute.
            filename = mod.__file__
        except AttributeError:
            continue
        dirname = os.path.dirname(filename)
        # A directory under an already known prefix is redundant.
        if not any(dirname.startswith(p) for p in stdlib_prefixes):
            stdlib_prefixes.add(dirname)

    skiptoplevel = ('hgdemandimport', 'hgext', 'mercurial')

    for libpath in sys.path:
        # We want to walk everything in sys.path that starts with
        # something in stdlib_prefixes.
        if not any(libpath.startswith(p) for p in stdlib_prefixes):
            continue
        for top, dirs, files in os.walk(libpath):
            # Prune in place so os.walk() does not descend into directories
            # that are not packages, nor into Mercurial's own top-level
            # packages.
            dirs[:] = [
                d for d in dirs
                if os.path.exists(os.path.join(top, d, '__init__.py'))
                and not (top == libpath and d in skiptoplevel)
            ]
            for name in files:
                if not name.endswith(('.py', '.so', '.pyc', '.pyo', '.pyd')):
                    continue
                if name.startswith('__init__.py'):
                    # A package is named after its directory.
                    full_path = top
                else:
                    full_path = os.path.join(top, name)
                rel_path = full_path[len(libpath) + 1:]
                yield dotted_name_of_path(rel_path)

stdlib_modules = set(list_stdlib_modules())
280
282
def imported_modules(source, modulename, f, localmods, ignore_nested=False):
    """Yield the names of the local modules imported by some Python source.

    Imports that cannot be resolved against ``localmods`` are skipped; they
    are presumed to refer to the standard library.

    Args:
      source: The python source to examine as a string.
      modulename: of specified python source (may have `__init__`)
      f: file name passed to ``ast.parse`` along with ``source``.
      localmods: set of locally defined module names (may have `__init__`)
      ignore_nested: If true, import statements that do not start in
                     column zero will be ignored.

    Returns:
      A generator of absolute module names imported by the given source.

    >>> f = 'foo/xxx.py'
    >>> modulename = 'foo.xxx'
    >>> localmods = {'foo.__init__': True,
    ...              'foo.foo1': True, 'foo.foo2': True,
    ...              'foo.bar.__init__': True, 'foo.bar.bar1': True,
    ...              'baz.__init__': True, 'baz.baz1': True }
    >>> # standard library (= not locally defined ones)
    >>> sorted(imported_modules(
    ...        'from stdlib1 import foo, bar; import stdlib2',
    ...        modulename, f, localmods))
    []
    >>> # relative importing
    >>> sorted(imported_modules(
    ...        'import foo1; from bar import bar1',
    ...        modulename, f, localmods))
    ['foo.bar.bar1', 'foo.foo1']
    >>> sorted(imported_modules(
    ...        'from bar.bar1 import name1, name2, name3',
    ...        modulename, f, localmods))
    ['foo.bar.bar1']
    >>> # absolute importing
    >>> sorted(imported_modules(
    ...        'from baz import baz1, name1',
    ...        modulename, f, localmods))
    ['baz.__init__', 'baz.baz1']
    >>> # mixed importing, even though it shouldn't be recommended
    >>> sorted(imported_modules(
    ...        'import stdlib, foo1, baz',
    ...        modulename, f, localmods))
    ['baz.__init__', 'foo.foo1']
    >>> # ignore_nested
    >>> sorted(imported_modules(
    ...        '''import foo
    ... def wat():
    ...     import bar
    ... ''', modulename, f, localmods))
    ['foo.__init__', 'foo.bar.__init__']
    >>> sorted(imported_modules(
    ...        '''import foo
    ... def wat():
    ...     import bar
    ... ''', modulename, f, localmods, ignore_nested=True))
    ['foo.__init__']
    """
    fromlocal = fromlocalfunc(modulename, localmods)
    for node in ast.walk(ast.parse(source, f)):
        # Nodes without a col_offset attribute are treated as column zero.
        if ignore_nested and getattr(node, 'col_offset', 0) > 0:
            continue

        if isinstance(node, ast.Import):
            for alias in node.names:
                resolved = fromlocal(alias.name)
                # An unresolved name should be a standard library import.
                if resolved:
                    yield resolved[1]

        elif isinstance(node, ast.ImportFrom):
            resolved = fromlocal(node.module, node.level)
            if not resolved:
                # this should import standard library
                continue

            absname, dottedpath, hassubmod = resolved
            if not hassubmod:
                # "dottedpath" is not a package, so it is the thing being
                # imported; looking at "node.names" would be redundant
                # (e.g.: from mercurial.node import nullid, nullrev).
                yield dottedpath
                continue

            prefix = absname + '.'
            sawnonmodule = False
            for alias in node.names:
                submodule = fromlocal(prefix + alias.name)
                if submodule:
                    yield submodule[1]
                else:
                    # this should be a function or a property of
                    # "node.module"
                    sawnonmodule = True
            if sawnonmodule:
                # "dottedpath" is a package, but it is imported itself
                # because of the non-module lookup
                yield dottedpath
377
379
def verify_import_convention(module, source, localmods):
    """Check the imports of ``source`` against our coding convention.

    Two conventions exist: legacy and modern. Which one applies is decided
    by ``usingabsolute()`` on the parsed source: the modern convention is in
    effect when using absolute imports.

    The legacy convention only looks for mixed imports. The modern convention
    is much more thorough.

    Returns whatever the selected checker returns, i.e. the result of
    ``verify_modern_convention()`` or ``verify_stdlib_on_own_line()``.
    """
    tree = ast.parse(source)
    if usingabsolute(tree):
        return verify_modern_convention(module, tree, localmods)
    return verify_stdlib_on_own_line(tree)
394
396
def verify_modern_convention(module, root, localmods, root_col_offset=0):
    """Verify a file conforms to the modern import convention rules.

    The rules of the modern convention are:

    * Ordering is stdlib followed by local imports. Each group is lexically
      sorted.
    * Importing multiple modules via "import X, Y" is not allowed: use
      separate import statements.
    * Importing multiple modules via "from X import ..." is allowed if using
      parenthesis and one entry per line.
    * Only 1 relative import statement per import level ("from .", "from ..")
      is allowed.
    * Relative imports from higher levels must occur before lower levels. e.g.
      "from .." must be before "from .".
    * Imports from peer packages should use relative import (e.g. do not
      "import mercurial.foo" from a "mercurial.*" module).
    * Symbols can only be imported from specific modules (see
      `allowsymbolimports`). For other modules, first import the module then
      assign the symbol to a module-level variable. In addition, these imports
      must be performed before other local imports. This rule only
      applies to import statements outside of any blocks.
    * Relative imports from the standard library are not allowed, unless that
      library is also a local module.
    * Certain modules must be aliased to alternate names to avoid aliasing
      and readability problems. See `requirealias`.

    Args:
      module: dotted name of the module being checked; a non-``str`` value
              is decoded as ASCII first.
      root: AST node whose imports are examined (walked with ``walklocal``).
      localmods: locally defined module names, as given to ``fromlocalfunc``.
      root_col_offset: column at which an import counts as being directly
                       in ``root`` rather than inside a nested block. New
                       scopes are checked recursively with their own
                       ``col_offset + 4``.

    Yields:
      ``(message, lineno)`` tuples, one per violation.
    """
    if not isinstance(module, str):
        module = module.decode('ascii')
    topmodule = module.split('.')[0]
    fromlocal = fromlocalfunc(module, localmods)

    # Name of the most recent local/non-stdlib import (None until one is seen).
    seenlocal = None
    # Whether a local/non-stdlib, non-symbol import has been seen.
    seennonsymbollocal = False
    # The last name imported by an "import X" statement, and whether it was a
    # stdlib module (both used for the sorting check).
    lastname = None
    laststdlib = None
    # Relative import levels encountered so far.
    seenlevels = set()

    def msg(fmt, *args):
        # "node" is resolved when msg() is called, so this always reports
        # the line of the node currently being examined.
        return (fmt % args, node.lineno)

    for node, newscope in walklocal(root):
        if newscope:
            # Check for local imports in function
            for problem in verify_modern_convention(module, node, localmods,
                                                    node.col_offset + 4):
                yield problem

        elif isinstance(node, ast.Import):
            # Sorting/ordering rules are ignored on imports inside blocks.
            toplevel = node.col_offset == root_col_offset

            # Disallow "import foo, bar" and require separate imports
            # for each module.
            if len(node.names) > 1:
                yield msg('multiple imported names: %s',
                          ', '.join(n.name for n in node.names))

            # Only the first name is examined by the remaining checks.
            first = node.names[0]
            name = first.name
            asname = first.asname
            stdlib = name in stdlib_modules

            if (toplevel and lastname and name < lastname
                and laststdlib == stdlib):
                yield msg('imports not lexically sorted: %s < %s',
                          name, lastname)

            lastname = name
            laststdlib = stdlib

            # stdlib imports should be before local imports.
            if toplevel and stdlib and seenlocal:
                yield msg('stdlib import "%s" follows local import: %s',
                          name, seenlocal)

            if not stdlib:
                seenlocal = name

            # Import of sibling modules should use relative imports.
            if name.split('.')[0] == topmodule:
                yield msg('import should be relative: %s', name)

            if name in requirealias and asname != requirealias[name]:
                yield msg('%s module must be "as" aliased to %s',
                          name, requirealias[name])

        elif isinstance(node, ast.ImportFrom):
            toplevel = node.col_offset == root_col_offset

            # Resolve the full imported module name.
            if node.level > 0:
                fullname = '.'.join(module.split('.')[:-node.level])
                if node.module:
                    fullname += '.%s' % node.module
            else:
                assert node.module
                fullname = node.module

                # Only an absolute import can violate the "use relative
                # imports for siblings" rule.
                if fullname.split('.')[0] == topmodule:
                    yield msg('import should be relative: %s', fullname)

            # __future__ is special since it needs to come first and use
            # symbol import.
            if fullname != '__future__':
                if not fullname or (
                    fullname in stdlib_modules
                    and fullname not in localmods
                    and fullname + '.__init__' not in localmods):
                    yield msg('relative import of stdlib module')
                else:
                    seenlocal = fullname

            # Direct symbol import is only allowed from certain modules and
            # must occur before non-symbol imports.
            found = fromlocal(node.module, node.level)
            if found and found[2]:  # node.module is a package
                prefix = found[0] + '.'
                # Names that resolve to local submodules are not symbols.
                symbols = [alias.name for alias in node.names
                           if not fromlocal(prefix + alias.name)
                           and alias.name not in directsymbols]
            else:
                symbols = [alias.name for alias in node.names
                           if alias.name not in directsymbols]

            if node.module and toplevel:
                if symbols and fullname not in allowsymbolimports:
                    yield msg('direct symbol import %s from %s',
                              ', '.join(symbols), fullname)

                if symbols and seennonsymbollocal:
                    yield msg('symbol import follows non-symbol import: %s',
                              fullname)
            if not symbols and fullname not in stdlib_modules:
                seennonsymbollocal = True

            if not node.module:
                assert node.level

                # Only allow 1 group per level.
                if node.level in seenlevels and toplevel:
                    yield msg('multiple "from %s import" statements',
                              '.' * node.level)

                # Higher-level groups come before lower-level groups.
                if any(node.level > l for l in seenlevels):
                    yield msg('higher-level import should come first: %s',
                              fullname)

                seenlevels.add(node.level)

            # Entries in "from .X import ( ... )" lists must be lexically
            # sorted.
            lastentryname = None

            for alias in node.names:
                entry = alias.name
                if lastentryname and entry < lastentryname:
                    yield msg('imports from %s not lexically sorted: %s < %s',
                              fullname, entry, lastentryname)

                lastentryname = entry

                if entry in requirealias and alias.asname != requirealias[entry]:
                    yield msg('%s from %s must be "as" aliased to %s',
                              entry, fullname, requirealias[entry])
559
561
def verify_stdlib_on_own_line(root):
    """Report ``import`` statements that mix stdlib and non-stdlib modules.

    Walks every node under ``root`` and, for each plain ``import a, b, ...``
    statement naming both modules found in ``stdlib_modules`` and modules
    that are not, yields a ``(message, lineno)`` tuple.

    >>> list(verify_stdlib_on_own_line(ast.parse('import sys, foo')))
    [('mixed imports\\n stdlib: sys\\n relative: foo', 1)]
    >>> list(verify_stdlib_on_own_line(ast.parse('import sys, os')))
    []
    >>> list(verify_stdlib_on_own_line(ast.parse('import foo, bar')))
    []
    """
    for node in ast.walk(root):
        if not isinstance(node, ast.Import):
            continue

        stdlib_names = []
        other_names = []
        for alias in node.names:
            if alias.name in stdlib_modules:
                stdlib_names.append(alias.name)
            else:
                other_names.append(alias.name)

        if stdlib_names and other_names:
            message = 'mixed imports\n stdlib: %s\n relative: %s' % (
                ', '.join(sorted(stdlib_names)),
                ', '.join(sorted(other_names)))
            yield (message, node.lineno)
580
582
class CircularImport(Exception):
    """Raised by checkmod() with the import path that leads back to its start.

    The path (a list of module names) is the exception's first argument.
    """
583
585
def checkmod(mod, imports):
    """Search the import graph for a cycle that leads back to ``mod``.

    ``imports`` maps a module name to the names it imports. Paths starting
    at ``mod`` are explored in first-in, first-out order; a path is only
    extended through a module when it is shorter than the shortest path
    already recorded for that module.

    Raises CircularImport(path) when a module on ``path`` imports ``mod``
    again. Returns None otherwise.
    """
    shortest = {}
    pending = [[mod]]
    while pending:
        path = pending.pop(0)
        depth = len(path)
        for dep in sorted(imports.get(path[-1], [])):
            # Skip modules already reached by an equally short or shorter
            # path (1000 stands in for "not reached yet").
            if depth >= shortest.get(dep, 1000):
                continue
            shortest[dep] = depth
            if dep not in path:
                pending.append(path + [dep])
            elif dep == path[0]:
                raise CircularImport(path)
            # Otherwise the cycle does not involve the starting module;
            # it is reported when that other module is the starting point.
597
599
def rotatecycle(cycle):
    """Rotate a cycle so its lexicographically smallest module comes first.

    The smallest module is repeated at the end to close the cycle. A new
    list is returned; ``cycle`` itself is left untouched.

    >>> rotatecycle(['foo', 'bar'])
    ['bar', 'foo', 'bar']
    """
    start = cycle.index(min(cycle))
    rotated = cycle[start:] + cycle[:start]
    return rotated + [rotated[0]]
607
609
def find_cycles(imports):
    """Find cycles in an already-loaded import graph.

    All module names recorded in `imports` should be absolute one.

    Each module is used in turn as the starting point for checkmod(); every
    cycle it reports is normalised with rotatecycle() and rendered as
    "a -> b -> a". The result is a set of those strings.

    >>> from __future__ import print_function
    >>> imports = {'top.foo': ['top.bar', 'os.path', 'top.qux'],
    ...            'top.bar': ['top.baz', 'sys'],
    ...            'top.baz': ['top.foo'],
    ...            'top.qux': ['top.foo']}
    >>> print('\\n'.join(sorted(find_cycles(imports))))
    top.bar -> top.baz -> top.foo -> top.bar
    top.foo -> top.qux -> top.foo
    """
    found = set()
    for mod in sorted(imports):
        try:
            checkmod(mod, imports)
        except CircularImport as exc:
            found.add(" -> ".join(rotatecycle(exc.args[0])))
    return found
630
632
def _cycle_sortkey(c):
    """Sort key ordering cycles by length first, then by their own value."""
    return (len(c), c)
633
635
def embedded(f, modname, src):
    """Extract embedded python code

    Two kinds of embedded Python are recognised in ``src``:

    * inline blocks made of lines starting with ``  >>> `` (continued by
      ``  ... `` lines), ended by a blank ``  `` line or a ``  $ `` command;
    * here-documents of the form ``$ cat > NAME.py << EOF`` whose body lines
      start with ``  > `` and which end at ``  > EOF``.

    Args:
      f: name of the input, only used in the "Parse Error" diagnostic.
      modname: module name (bytes) used to build the name of each script.
      src: iterable of lines (bytes); a missing trailing newline is added.

    Yields:
      ``(script, name, filename, lineno)`` where ``script`` is the extracted
      code, ``name`` is ``modname[lineno]``, ``filename`` is ``doctest.py``
      for inline blocks or the here-document's target file, and ``lineno``
      is the 1-based line of ``src`` on which the block started.

    >>> def _forcestr(thing):
    ...     if not isinstance(thing, str):
    ...         return thing.decode('ascii')
    ...     return thing
    >>> def test(fn, lines):
    ...     for s, m, f, l in embedded(fn, b"example", lines):
    ...         print("%s %s %d" % (_forcestr(m), _forcestr(f), l))
    ...         print(repr(_forcestr(s)))
    >>> lines = [
    ...   b'comment',
    ...   b'  >>> from __future__ import print_function',
    ...   b"  >>> ' multiline",
    ...   b"  ... string'",
    ...   b'  ',
    ...   b'comment',
    ...   b'  $ cat > foo.py <<EOF',
    ...   b'  > from __future__ import print_function',
    ...   b'  > EOF',
    ... ]
    >>> test(b"example.t", lines)
    example[2] doctest.py 2
    "from __future__ import print_function\\n' multiline\\nstring'\\n"
    example[7] foo.py 7
    'from __future__ import print_function\\n'
    """
    # Line markers. The amount stripped from each line is derived from the
    # marker itself so the two can never get out of step.
    pyprompt = b'  >>> '
    pycont = b'  ... '
    shcont = b'  > '
    shend = shcont + b'EOF\n'

    inlinepython = 0  # line on which the current inline block started, or 0
    shpython = 0  # line on which the current here-document started, or 0
    script = []
    t = b''
    n = 0
    for l in src:
        n += 1
        if not l.endswith(b'\n'):
            l += b'\n'
        if l.startswith(pyprompt):  # python inlines
            if shpython:
                print("%s:%d: Parse Error" % (f, n))
            if not inlinepython:
                # We've just entered a Python block.
                inlinepython = n
                t = b'doctest.py'
            script.append(l[len(pyprompt):])
            continue
        if l.startswith(pycont):  # python inlines
            script.append(l[len(pycont):])
            continue
        cat = re.search(br"\$ \s*cat\s*>\s*(\S+\.py)\s*<<\s*EOF", l)
        if cat:
            if inlinepython:
                # A here-document also terminates a pending inline block.
                yield b''.join(script), (b"%s[%d]" %
                       (modname, inlinepython)), t, inlinepython
                script = []
                inlinepython = 0
            shpython = n
            t = cat.group(1)
            continue
        if shpython and l.startswith(shcont):  # sh continuation
            if l == shend:
                yield b''.join(script), (b"%s[%d]" %
                       (modname, shpython)), t, shpython
                script = []
                shpython = 0
            else:
                script.append(l[len(shcont):])
            continue
        # If we have an empty line or a command for sh, we end the
        # inline script.
        if inlinepython and (l == b'  \n'
                             or l.startswith(b'  $ ')):
            yield b''.join(script), (b"%s[%d]" %
                   (modname, inlinepython)), t, inlinepython
            script = []
            inlinepython = 0
712
714
def sources(f, modname):
    """Yields possibly multiple sources from a filepath

    input: filepath, modulename
    yields: script(string), modulename, filepath, linenumber

    Unless ``f`` ends with ``.t``, its whole content is yielded first with
    the given modulename, ``f`` itself and line number 0. In every case the
    file is then scanned with embedded() and each script found is yielded.

    For embedded scripts, the modulename and filepath will be different
    from the function arguments. linenumber is an offset relative to
    the input file.
    """
    if not f.endswith('.t'):
        with open(f, 'rb') as fh:
            yield fh.read(), modname, f, 0
    # Reached for .t files and, once the caller resumes the generator, for
    # every other file as well.
    with open(f, 'rb') as fh:
        for script, embeddedmod, path, line in embedded(f, modname, fh):
            yield script, embeddedmod, path, line
732
734
def main(argv):
    """Check the imports of the given files and report problems.

    ``argv`` is a command line: the program name followed either by file
    paths or by a single ``-``, in which case the paths are read from stdin,
    one per line.

    Convention violations, syntax errors and import cycles are printed to
    stdout. Returns 1 on a usage error, otherwise True when a convention
    violation or an import cycle was found and False when not.
    """
    if len(argv) < 2 or (argv[1] == '-' and len(argv) > 2):
        prog = argv[0] if argv else sys.argv[0]
        print('Usage: %s {-|file [file] [file] ...}' % prog)
        return 1
    if argv[1] == '-':
        # Work on a copy so the caller's list is not modified.
        argv = argv[:1]
        argv.extend(l.rstrip() for l in sys.stdin.readlines())
    localmodpaths = {}
    used_imports = {}
    any_errors = False
    for source_path in argv[1:]:
        modname = dotted_name_of_path(source_path)
        localmodpaths[modname] = source_path
    localmods = populateextmods(localmodpaths)
    for localmodname, source_path in sorted(localmodpaths.items()):
        if not isinstance(localmodname, bytes):
            # This is only safe because all hg's files are ascii
            localmodname = localmodname.encode('ascii')
        for src, modname, name, line in sources(source_path, localmodname):
            try:
                used_imports[modname] = sorted(
                    imported_modules(src, modname, name, localmods,
                                     ignore_nested=True))
                for error, lineno in verify_import_convention(modname, src,
                                                              localmods):
                    any_errors = True
                    print('%s:%d: %s' % (source_path, lineno + line, error))
            except SyntaxError as e:
                # SyntaxError.lineno may be None; report the block's own
                # offset in that case instead of failing on the addition.
                print('%s:%d: SyntaxError: %s' %
                      (source_path, (e.lineno or 0) + line, e))
    cycles = find_cycles(used_imports)
    if cycles:
        firstmods = set()
        for c in sorted(cycles, key=_cycle_sortkey):
            first = c.split()[0]
            # As a rough cut, ignore any cycle that starts with the
            # same module as some other cycle. Otherwise we see lots
            # of cycles that are effectively duplicates.
            if first in firstmods:
                continue
            print('Import cycle:', c)
            firstmods.add(first)
        any_errors = True
    return any_errors != 0
777
779
# Script entry point: main()'s result (1, True or False) becomes the integer
# exit status of the process.
if __name__ == '__main__':
    sys.exit(int(main(sys.argv)))
General Comments 0
You need to be logged in to leave comments. Login now