##// END OF EJS Templates
typing: add type hints to most mercurial/pycompat.py functions...
Matt Harbison -
r50702:c5a06cc3 default
parent child Browse files
Show More
@@ -1,432 +1,478 b''
1 1 # pycompat.py - portability shim for python 3
2 2 #
3 3 # This software may be used and distributed according to the terms of the
4 4 # GNU General Public License version 2 or any later version.
5 5
6 6 """Mercurial portability shim for python 3.
7 7
8 8 This contains aliases to hide python version-specific details from the core.
9 9 """
10 10
11 11
12 12 import builtins
13 13 import codecs
14 14 import concurrent.futures as futures
15 15 import functools
16 16 import getopt
17 17 import http.client as httplib
18 18 import http.cookiejar as cookielib
19 19 import inspect
20 20 import io
21 21 import json
22 22 import os
23 23 import queue
24 24 import shlex
25 25 import socketserver
26 26 import struct
27 27 import sys
28 28 import tempfile
29 29 import xmlrpc.client as xmlrpclib
30 30
31 31 from typing import (
32 Any,
33 AnyStr,
34 BinaryIO,
35 Dict,
32 36 Iterable,
33 37 Iterator,
34 38 List,
39 Mapping,
40 NoReturn,
35 41 Optional,
42 Sequence,
43 Tuple,
36 44 Type,
37 45 TypeVar,
46 cast,
47 overload,
38 48 )
39 49
40 50 ispy3 = sys.version_info[0] >= 3
41 51 ispypy = '__pypy__' in sys.builtin_module_names
42 52 TYPE_CHECKING = False
43 53
44 54 if not globals(): # hide this from non-pytype users
45 55 import typing
46 56
47 57 TYPE_CHECKING = typing.TYPE_CHECKING
48 58
59 _GetOptResult = Tuple[List[Tuple[bytes, bytes]], List[bytes]]
60 _T0 = TypeVar('_T0')
49 61 _Tbytestr = TypeVar('_Tbytestr', bound='bytestr')
50 62
51 63
52 64 def future_set_exception_info(f, exc_info):
53 65 f.set_exception(exc_info[0])
54 66
55 67
56 68 FileNotFoundError = builtins.FileNotFoundError
57 69
58 70
59 def identity(a):
71 def identity(a: _T0) -> _T0:
60 72 return a
61 73
62 74
63 75 def _rapply(f, xs):
64 76 if xs is None:
65 77 # assume None means non-value of optional data
66 78 return xs
67 79 if isinstance(xs, (list, set, tuple)):
68 80 return type(xs)(_rapply(f, x) for x in xs)
69 81 if isinstance(xs, dict):
70 82 return type(xs)((_rapply(f, k), _rapply(f, v)) for k, v in xs.items())
71 83 return f(xs)
72 84
73 85
74 86 def rapply(f, xs):
75 87 """Apply function recursively to every item preserving the data structure
76 88
77 89 >>> def f(x):
78 90 ... return 'f(%s)' % x
79 91 >>> rapply(f, None) is None
80 92 True
81 93 >>> rapply(f, 'a')
82 94 'f(a)'
83 95 >>> rapply(f, {'a'}) == {'f(a)'}
84 96 True
85 97 >>> rapply(f, ['a', 'b', None, {'c': 'd'}, []])
86 98 ['f(a)', 'f(b)', None, {'f(c)': 'f(d)'}, []]
87 99
88 100 >>> xs = [object()]
89 101 >>> rapply(identity, xs) is xs
90 102 True
91 103 """
92 104 if f is identity:
93 105 # fast path mainly for py2
94 106 return xs
95 107 return _rapply(f, xs)
96 108
97 109
98 110 if os.name == r'nt':
99 111 # MBCS (or ANSI) filesystem encoding must be used as before.
100 112 # Otherwise non-ASCII filenames in existing repositories would be
101 113 # corrupted.
102 114 # This must be set once prior to any fsencode/fsdecode calls.
103 115 sys._enablelegacywindowsfsencoding() # pytype: disable=module-attr
104 116
105 117 fsencode = os.fsencode
106 118 fsdecode = os.fsdecode
107 119 oscurdir: bytes = os.curdir.encode('ascii')
108 120 oslinesep: bytes = os.linesep.encode('ascii')
109 121 osname: bytes = os.name.encode('ascii')
110 122 ospathsep: bytes = os.pathsep.encode('ascii')
111 123 ospardir: bytes = os.pardir.encode('ascii')
112 124 ossep: bytes = os.sep.encode('ascii')
113 125 osaltsep: Optional[bytes] = os.altsep.encode('ascii') if os.altsep else None
114 126 osdevnull: bytes = os.devnull.encode('ascii')
115 127
116 128 sysplatform: bytes = sys.platform.encode('ascii')
117 129 sysexecutable: bytes = os.fsencode(sys.executable) if sys.executable else b''
118 130
119 131
120 132 def maplist(*args):
121 133 return list(map(*args))
122 134
123 135
124 136 def rangelist(*args):
125 137 return list(range(*args))
126 138
127 139
128 140 def ziplist(*args):
129 141 return list(zip(*args))
130 142
131 143
132 144 rawinput = input
133 145 getargspec = inspect.getfullargspec
134 146
135 147 long = int
136 148
137 149 if builtins.getattr(sys, 'argv', None) is not None:
138 150 # On POSIX, the char** argv array is converted to Python str using
139 151 # Py_DecodeLocale(). The inverse of this is Py_EncodeLocale(), which
140 152 # isn't directly callable from Python code. In practice, os.fsencode()
141 153 # can be used instead (this is recommended by Python's documentation
142 154 # for sys.argv).
143 155 #
144 156 # On Windows, the wchar_t **argv is passed into the interpreter as-is.
145 157 # Like POSIX, we need to emulate what Py_EncodeLocale() would do. But
146 158 # there's an additional wrinkle. What we really want to access is the
147 159 # ANSI codepage representation of the arguments, as this is what
148 160 # `int main()` would receive if Python 3 didn't define `int wmain()`
149 161 # (this is how Python 2 worked). To get that, we encode with the mbcs
150 162 # encoding, which will pass CP_ACP to the underlying Windows API to
151 163 # produce bytes.
152 164 sysargv: List[bytes] = []
153 165 if os.name == r'nt':
154 166 sysargv = [a.encode("mbcs", "ignore") for a in sys.argv]
155 167 else:
156 168 sysargv = [fsencode(a) for a in sys.argv]
157 169
158 170 bytechr = struct.Struct('>B').pack
159 171 byterepr = b'%r'.__mod__
160 172
161 173
162 174 class bytestr(bytes):
163 175 """A bytes which mostly acts as a Python 2 str
164 176
165 177 >>> bytestr(), bytestr(bytearray(b'foo')), bytestr(u'ascii'), bytestr(1)
166 178 ('', 'foo', 'ascii', '1')
167 179 >>> s = bytestr(b'foo')
168 180 >>> assert s is bytestr(s)
169 181
170 182 __bytes__() should be called if provided:
171 183
172 184 >>> class bytesable:
173 185 ... def __bytes__(self):
174 186 ... return b'bytes'
175 187 >>> bytestr(bytesable())
176 188 'bytes'
177 189
178 190 There's no implicit conversion from non-ascii str as its encoding is
179 191 unknown:
180 192
181 193 >>> bytestr(chr(0x80)) # doctest: +ELLIPSIS
182 194 Traceback (most recent call last):
183 195 ...
184 196 UnicodeEncodeError: ...
185 197
186 198 Comparison between bytestr and bytes should work:
187 199
188 200 >>> assert bytestr(b'foo') == b'foo'
189 201 >>> assert b'foo' == bytestr(b'foo')
190 202 >>> assert b'f' in bytestr(b'foo')
191 203 >>> assert bytestr(b'f') in b'foo'
192 204
193 205 Sliced elements should be bytes, not integer:
194 206
195 207 >>> s[1], s[:2]
196 208 (b'o', b'fo')
197 209 >>> list(s), list(reversed(s))
198 210 ([b'f', b'o', b'o'], [b'o', b'o', b'f'])
199 211
200 212 As bytestr type isn't propagated across operations, you need to cast
201 213 bytes to bytestr explicitly:
202 214
203 215 >>> s = bytestr(b'foo').upper()
204 216 >>> t = bytestr(s)
205 217 >>> s[0], t[0]
206 218 (70, b'F')
207 219
208 220 Be careful to not pass a bytestr object to a function which expects
209 221 bytearray-like behavior.
210 222
211 223 >>> t = bytes(t) # cast to bytes
212 224 >>> assert type(t) is bytes
213 225 """
214 226
215 227 # Trick pytype into not demanding Iterable[int] be passed to __new__(),
216 228 # since the appropriate bytes format is done internally.
217 229 #
218 230 # https://github.com/google/pytype/issues/500
219 231 if TYPE_CHECKING:
220 232
221 233 def __init__(self, s: object = b'') -> None:
222 234 pass
223 235
224 236 def __new__(cls: Type[_Tbytestr], s: object = b'') -> _Tbytestr:
225 237 if isinstance(s, bytestr):
226 238 return s
227 239 if not isinstance(
228 240 s, (bytes, bytearray)
229 241 ) and not builtins.hasattr( # hasattr-py3-only
230 242 s, u'__bytes__'
231 243 ):
232 244 s = str(s).encode('ascii')
233 245 return bytes.__new__(cls, s)
234 246
235 247 def __getitem__(self, key) -> bytes:
236 248 s = bytes.__getitem__(self, key)
237 249 if not isinstance(s, bytes):
238 250 s = bytechr(s)
239 251 return s
240 252
241 253 def __iter__(self) -> Iterator[bytes]:
242 254 return iterbytestr(bytes.__iter__(self))
243 255
244 256 def __repr__(self) -> str:
245 257 return bytes.__repr__(self)[1:] # drop b''
246 258
247 259
248 260 def iterbytestr(s: Iterable[int]) -> Iterator[bytes]:
249 261 """Iterate bytes as if it were a str object of Python 2"""
250 262 return map(bytechr, s)
251 263
252 264
265 if TYPE_CHECKING:
266
267 @overload
268 def maybebytestr(s: bytes) -> bytestr:
269 ...
270
271 @overload
272 def maybebytestr(s: _T0) -> _T0:
273 ...
274
275
253 276 def maybebytestr(s):
254 277 """Promote bytes to bytestr"""
255 278 if isinstance(s, bytes):
256 279 return bytestr(s)
257 280 return s
258 281
259 282
260 def sysbytes(s):
283 def sysbytes(s: AnyStr) -> bytes:
261 284 """Convert an internal str (e.g. keyword, __doc__) back to bytes
262 285
263 286 This never raises UnicodeEncodeError, but only ASCII characters
264 287 can be round-trip by sysstr(sysbytes(s)).
265 288 """
266 289 if isinstance(s, bytes):
267 290 return s
268 291 return s.encode('utf-8')
269 292
270 293
271 def sysstr(s):
294 def sysstr(s: AnyStr) -> str:
272 295 """Return a keyword str to be passed to Python functions such as
273 296 getattr() and str.encode()
274 297
275 298 This never raises UnicodeDecodeError. Non-ascii characters are
276 299 considered invalid and mapped to arbitrary but unique code points
277 300 such that 'sysstr(a) != sysstr(b)' for all 'a != b'.
278 301 """
279 302 if isinstance(s, builtins.str):
280 303 return s
281 304 return s.decode('latin-1')
282 305
283 306
284 def strurl(url):
307 def strurl(url: AnyStr) -> str:
285 308 """Converts a bytes url back to str"""
286 309 if isinstance(url, bytes):
287 310 return url.decode('ascii')
288 311 return url
289 312
290 313
291 def bytesurl(url):
314 def bytesurl(url: AnyStr) -> bytes:
292 315 """Converts a str url to bytes by encoding in ascii"""
293 316 if isinstance(url, str):
294 317 return url.encode('ascii')
295 318 return url
296 319
297 320
298 def raisewithtb(exc, tb):
321 def raisewithtb(exc: BaseException, tb) -> NoReturn:
299 322 """Raise exception with the given traceback"""
300 323 raise exc.with_traceback(tb)
301 324
302 325
303 def getdoc(obj):
326 def getdoc(obj: object) -> Optional[bytes]:
304 327 """Get docstring as bytes; may be None so gettext() won't confuse it
305 328 with _('')"""
306 329 doc = builtins.getattr(obj, '__doc__', None)
307 330 if doc is None:
308 331 return doc
309 332 return sysbytes(doc)
310 333
311 334
312 335 def _wrapattrfunc(f):
313 336 @functools.wraps(f)
314 337 def w(object, name, *args):
315 338 return f(object, sysstr(name), *args)
316 339
317 340 return w
318 341
319 342
320 343 # these wrappers are automagically imported by hgloader
321 344 delattr = _wrapattrfunc(builtins.delattr)
322 345 getattr = _wrapattrfunc(builtins.getattr)
323 346 hasattr = _wrapattrfunc(builtins.hasattr)
324 347 setattr = _wrapattrfunc(builtins.setattr)
325 348 xrange = builtins.range
326 349 unicode = str
327 350
328 351
329 def open(name, mode=b'r', buffering=-1, encoding=None):
352 def open(
353 name,
354 mode: AnyStr = b'r',
355 buffering: int = -1,
356 encoding: Optional[str] = None,
357 ) -> Any:
358 # TODO: assert binary mode, and cast result to BinaryIO?
330 359 return builtins.open(name, sysstr(mode), buffering, encoding)
331 360
332 361
333 362 safehasattr = _wrapattrfunc(builtins.hasattr)
334 363
335 364
336 def _getoptbwrapper(orig, args, shortlist, namelist):
365 def _getoptbwrapper(
366 orig, args: Sequence[bytes], shortlist: bytes, namelist: Sequence[bytes]
367 ) -> _GetOptResult:
337 368 """
338 369 Takes bytes arguments, converts them to unicode, pass them to
339 370 getopt.getopt(), convert the returned values back to bytes and then
340 371 return them for Python 3 compatibility as getopt.getopt() don't accepts
341 372 bytes on Python 3.
342 373 """
343 374 args = [a.decode('latin-1') for a in args]
344 375 shortlist = shortlist.decode('latin-1')
345 376 namelist = [a.decode('latin-1') for a in namelist]
346 377 opts, args = orig(args, shortlist, namelist)
347 378 opts = [(a[0].encode('latin-1'), a[1].encode('latin-1')) for a in opts]
348 379 args = [a.encode('latin-1') for a in args]
349 380 return opts, args
350 381
351 382
352 def strkwargs(dic):
383 def strkwargs(dic: Mapping[bytes, _T0]) -> Dict[str, _T0]:
353 384 """
354 385 Converts the keys of a python dictonary to str i.e. unicodes so that
355 386 they can be passed as keyword arguments as dictionaries with bytes keys
356 387 can't be passed as keyword arguments to functions on Python 3.
357 388 """
358 389 dic = {k.decode('latin-1'): v for k, v in dic.items()}
359 390 return dic
360 391
361 392
362 def byteskwargs(dic):
393 def byteskwargs(dic: Mapping[str, _T0]) -> Dict[bytes, _T0]:
363 394 """
364 395 Converts keys of python dictionaries to bytes as they were converted to
365 396 str to pass that dictonary as a keyword argument on Python 3.
366 397 """
367 398 dic = {k.encode('latin-1'): v for k, v in dic.items()}
368 399 return dic
369 400
370 401
371 402 # TODO: handle shlex.shlex().
372 def shlexsplit(s, comments=False, posix=True):
403 def shlexsplit(
404 s: bytes, comments: bool = False, posix: bool = True
405 ) -> List[bytes]:
373 406 """
374 407 Takes bytes argument, convert it to str i.e. unicodes, pass that into
375 408 shlex.split(), convert the returned value to bytes and return that for
376 409 Python 3 compatibility as shelx.split() don't accept bytes on Python 3.
377 410 """
378 411 ret = shlex.split(s.decode('latin-1'), comments, posix)
379 412 return [a.encode('latin-1') for a in ret]
380 413
381 414
382 415 iteritems = lambda x: x.items()
383 416 itervalues = lambda x: x.values()
384 417
385 418 json_loads = json.loads
386 419
387 420 isjython: bool = sysplatform.startswith(b'java')
388 421
389 422 isdarwin: bool = sysplatform.startswith(b'darwin')
390 423 islinux: bool = sysplatform.startswith(b'linux')
391 424 isposix: bool = osname == b'posix'
392 425 iswindows: bool = osname == b'nt'
393 426
394 427
395 def getoptb(args, shortlist, namelist):
428 def getoptb(
429 args: Sequence[bytes], shortlist: bytes, namelist: Sequence[bytes]
430 ) -> _GetOptResult:
396 431 return _getoptbwrapper(getopt.getopt, args, shortlist, namelist)
397 432
398 433
399 def gnugetoptb(args, shortlist, namelist):
434 def gnugetoptb(
435 args: Sequence[bytes], shortlist: bytes, namelist: Sequence[bytes]
436 ) -> _GetOptResult:
400 437 return _getoptbwrapper(getopt.gnu_getopt, args, shortlist, namelist)
401 438
402 439
403 def mkdtemp(suffix=b'', prefix=b'tmp', dir=None):
440 def mkdtemp(
441 suffix: bytes = b'', prefix: bytes = b'tmp', dir: Optional[bytes] = None
442 ) -> bytes:
404 443 return tempfile.mkdtemp(suffix, prefix, dir)
405 444
406 445
407 446 # text=True is not supported; use util.from/tonativeeol() instead
408 def mkstemp(suffix=b'', prefix=b'tmp', dir=None):
447 def mkstemp(
448 suffix: bytes = b'', prefix: bytes = b'tmp', dir: Optional[bytes] = None
449 ) -> Tuple[int, bytes]:
409 450 return tempfile.mkstemp(suffix, prefix, dir)
410 451
411 452
412 453 # TemporaryFile does not support an "encoding=" argument on python2.
413 454 # This wrapper file are always open in byte mode.
414 def unnamedtempfile(mode=None, *args, **kwargs):
455 def unnamedtempfile(mode: Optional[bytes] = None, *args, **kwargs) -> BinaryIO:
415 456 if mode is None:
416 457 mode = 'w+b'
417 458 else:
418 459 mode = sysstr(mode)
419 460 assert 'b' in mode
420 return tempfile.TemporaryFile(mode, *args, **kwargs)
461 return cast(BinaryIO, tempfile.TemporaryFile(mode, *args, **kwargs))
421 462
422 463
423 464 # NamedTemporaryFile does not support an "encoding=" argument on python2.
424 465 # This wrapper file are always open in byte mode.
425 466 def namedtempfile(
426 mode=b'w+b', bufsize=-1, suffix=b'', prefix=b'tmp', dir=None, delete=True
467 mode: bytes = b'w+b',
468 bufsize: int = -1,
469 suffix: bytes = b'',
470 prefix: bytes = b'tmp',
471 dir: Optional[bytes] = None,
472 delete: bool = True,
427 473 ):
428 474 mode = sysstr(mode)
429 475 assert 'b' in mode
430 476 return tempfile.NamedTemporaryFile(
431 477 mode, bufsize, suffix=suffix, prefix=prefix, dir=dir, delete=delete
432 478 )
General Comments 0
You need to be logged in to leave comments. Login now