contrib: add a fork of black (as "grey") that includes my changes...
Augie Fackler
r43353:7054fd37 default
@@ -0,0 +1,4094 @@
1 # no-check-code because 3rd party
2 import ast
3 import asyncio
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 from functools import lru_cache, partial, wraps
9 import io
10 import itertools
11 import logging
12 from multiprocessing import Manager, freeze_support
13 import os
14 from pathlib import Path
15 import pickle
16 import re
17 import signal
18 import sys
19 import tempfile
20 import tokenize
21 import traceback
22 from typing import (
23 Any,
24 Callable,
25 Collection,
26 Dict,
27 Generator,
28 Generic,
29 Iterable,
30 Iterator,
31 List,
32 Optional,
33 Pattern,
34 Sequence,
35 Set,
36 Tuple,
37 TypeVar,
38 Union,
39 cast,
40 )
41
42 from appdirs import user_cache_dir
43 from attr import dataclass, evolve, Factory
44 import click
45 import toml
46 from typed_ast import ast3, ast27
47
48 # lib2to3 fork
49 from blib2to3.pytree import Node, Leaf, type_repr
50 from blib2to3 import pygram, pytree
51 from blib2to3.pgen2 import driver, token
52 from blib2to3.pgen2.grammar import Grammar
53 from blib2to3.pgen2.parse import ParseError
54
55 __version__ = '19.3b1.dev95+gdc1add6.d20191005'
56
57 DEFAULT_LINE_LENGTH = 88
58 DEFAULT_EXCLUDES = (
59 r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
60 )
61 DEFAULT_INCLUDES = r"\.pyi?$"
62 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
63
64
65 # types
66 FileContent = str
67 Encoding = str
68 NewLine = str
69 Depth = int
70 NodeType = int
71 LeafID = int
72 Priority = int
73 Index = int
74 LN = Union[Leaf, Node]
75 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
76 Timestamp = float
77 FileSize = int
78 CacheInfo = Tuple[Timestamp, FileSize]
79 Cache = Dict[Path, CacheInfo]
80 out = partial(click.secho, bold=True, err=True)
81 err = partial(click.secho, fg="red", err=True)
82
83 pygram.initialize(CACHE_DIR)
84 syms = pygram.python_symbols
85
86
87 class NothingChanged(UserWarning):
88 """Raised when reformatted code is the same as source."""
89
90
91 class CannotSplit(Exception):
92 """A readable split that fits the allotted line length is impossible."""
93
94
95 class InvalidInput(ValueError):
96 """Raised when input source code fails all parse attempts."""
97
98
99 class WriteBack(Enum):
100 NO = 0
101 YES = 1
102 DIFF = 2
103 CHECK = 3
104
105 @classmethod
106 def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
107 if check and not diff:
108 return cls.CHECK
109
110 return cls.DIFF if diff else cls.YES
111
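# How the --check/--diff flags map onto WriteBack (for reference; derived
# from from_configuration above):
#   check=True,  diff=False -> WriteBack.CHECK
#   check=False, diff=True  -> WriteBack.DIFF
#   check=True,  diff=True  -> WriteBack.DIFF
#   check=False, diff=False -> WriteBack.YES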
112
113 class Changed(Enum):
114 NO = 0
115 CACHED = 1
116 YES = 2
117
118
119 class TargetVersion(Enum):
120 PY27 = 2
121 PY33 = 3
122 PY34 = 4
123 PY35 = 5
124 PY36 = 6
125 PY37 = 7
126 PY38 = 8
127
128 def is_python2(self) -> bool:
129 return self is TargetVersion.PY27
130
131
132 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
133
134
135 class Feature(Enum):
136 # All string literals are unicode
137 UNICODE_LITERALS = 1
138 F_STRINGS = 2
139 NUMERIC_UNDERSCORES = 3
140 TRAILING_COMMA_IN_CALL = 4
141 TRAILING_COMMA_IN_DEF = 5
142 # The following two feature-flags are mutually exclusive, and exactly one should be
143 # set for every version of python.
144 ASYNC_IDENTIFIERS = 6
145 ASYNC_KEYWORDS = 7
146 ASSIGNMENT_EXPRESSIONS = 8
147 POS_ONLY_ARGUMENTS = 9
148
149
150 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
151 TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
152 TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
153 TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
154 TargetVersion.PY35: {
155 Feature.UNICODE_LITERALS,
156 Feature.TRAILING_COMMA_IN_CALL,
157 Feature.ASYNC_IDENTIFIERS,
158 },
159 TargetVersion.PY36: {
160 Feature.UNICODE_LITERALS,
161 Feature.F_STRINGS,
162 Feature.NUMERIC_UNDERSCORES,
163 Feature.TRAILING_COMMA_IN_CALL,
164 Feature.TRAILING_COMMA_IN_DEF,
165 Feature.ASYNC_IDENTIFIERS,
166 },
167 TargetVersion.PY37: {
168 Feature.UNICODE_LITERALS,
169 Feature.F_STRINGS,
170 Feature.NUMERIC_UNDERSCORES,
171 Feature.TRAILING_COMMA_IN_CALL,
172 Feature.TRAILING_COMMA_IN_DEF,
173 Feature.ASYNC_KEYWORDS,
174 },
175 TargetVersion.PY38: {
176 Feature.UNICODE_LITERALS,
177 Feature.F_STRINGS,
178 Feature.NUMERIC_UNDERSCORES,
179 Feature.TRAILING_COMMA_IN_CALL,
180 Feature.TRAILING_COMMA_IN_DEF,
181 Feature.ASYNC_KEYWORDS,
182 Feature.ASSIGNMENT_EXPRESSIONS,
183 Feature.POS_ONLY_ARGUMENTS,
184 },
185 }
186
187
188 @dataclass
189 class FileMode:
190 target_versions: Set[TargetVersion] = Factory(set)
191 line_length: int = DEFAULT_LINE_LENGTH
192 string_normalization: bool = True
193 is_pyi: bool = False
194
195 def get_cache_key(self) -> str:
196 if self.target_versions:
197 version_str = ",".join(
198 str(version.value)
199 for version in sorted(self.target_versions, key=lambda v: v.value)
200 )
201 else:
202 version_str = "-"
203 parts = [
204 version_str,
205 str(self.line_length),
206 str(int(self.string_normalization)),
207 str(int(self.is_pyi)),
208 ]
209 return ".".join(parts)
210
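# For example, FileMode(target_versions={TargetVersion.PY36, TargetVersion.PY37})
# with the defaults above produces the cache key "6,7.88.1.0"; with no target
# versions the key is "-.88.1.0".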
211
212 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
213 return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
214
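# Note that supports_feature() is vacuously True for an empty target set; in
# that case detect_target_versions() narrows the versions later. For example:
#   supports_feature({TargetVersion.PY36, TargetVersion.PY37}, Feature.F_STRINGS)
# is True, but adding TargetVersion.PY27 to the set makes it False.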
215
216 def read_pyproject_toml(
217 ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
218 ) -> Optional[str]:
219 """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
220
221 Returns the path to a successfully found and read configuration file, None
222 otherwise.
223 """
224 assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
225 if not value:
226 root = find_project_root(ctx.params.get("src", ()))
227 path = root / "pyproject.toml"
228 if path.is_file():
229 value = str(path)
230 else:
231 return None
232
233 try:
234 pyproject_toml = toml.load(value)
235 config = pyproject_toml.get("tool", {}).get("black", {})
236 except (toml.TomlDecodeError, OSError) as e:
237 raise click.FileError(
238 filename=value, hint=f"Error reading configuration file: {e}"
239 )
240
241 if not config:
242 return None
243
244 if ctx.default_map is None:
245 ctx.default_map = {}
246 ctx.default_map.update( # type: ignore # bad types in .pyi
247 {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
248 )
249 return value
250
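# A pyproject.toml like the following (illustrative values):
#
#   [tool.black]
#   line-length = 79
#   target-version = ["py36"]
#
# ends up in ctx.default_map as {"line_length": 79, "target_version": ["py36"]},
# so explicit command-line options still override the file-based configuration.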
251
252 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
253 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
254 @click.option(
255 "-l",
256 "--line-length",
257 type=int,
258 default=DEFAULT_LINE_LENGTH,
259 help="How many characters per line to allow.",
260 show_default=True,
261 )
262 @click.option(
263 "-t",
264 "--target-version",
265 type=click.Choice([v.name.lower() for v in TargetVersion]),
266 callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
267 multiple=True,
268 help=(
269 "Python versions that should be supported by Black's output. [default: "
270 "per-file auto-detection]"
271 ),
272 )
273 @click.option(
274 "--py36",
275 is_flag=True,
276 help=(
277 "Allow using Python 3.6-only syntax on all input files. This will put "
278 "trailing commas in function signatures and calls also after *args and "
279 "**kwargs. Deprecated; use --target-version instead. "
280 "[default: per-file auto-detection]"
281 ),
282 )
283 @click.option(
284 "--pyi",
285 is_flag=True,
286 help=(
287 "Format all input files like typing stubs regardless of file extension "
288 "(useful when piping source on standard input)."
289 ),
290 )
291 @click.option(
292 "-S",
293 "--skip-string-normalization",
294 is_flag=True,
295 help="Don't normalize string quotes or prefixes.",
296 )
297 @click.option(
298 "--check",
299 is_flag=True,
300 help=(
301 "Don't write the files back, just return the status. Return code 0 "
302 "means nothing would change. Return code 1 means some files would be "
303 "reformatted. Return code 123 means there was an internal error."
304 ),
305 )
306 @click.option(
307 "--diff",
308 is_flag=True,
309 help="Don't write the files back, just output a diff for each file on stdout.",
310 )
311 @click.option(
312 "--fast/--safe",
313 is_flag=True,
314 help="If --fast given, skip temporary sanity checks. [default: --safe]",
315 )
316 @click.option(
317 "--include",
318 type=str,
319 default=DEFAULT_INCLUDES,
320 help=(
321 "A regular expression that matches files and directories that should be "
322 "included on recursive searches. An empty value means all files are "
323 "included regardless of the name. Use forward slashes for directories on "
324 "all platforms (Windows, too). Exclusions are calculated first, inclusions "
325 "later."
326 ),
327 show_default=True,
328 )
329 @click.option(
330 "--exclude",
331 type=str,
332 default=DEFAULT_EXCLUDES,
333 help=(
334 "A regular expression that matches files and directories that should be "
335 "excluded on recursive searches. An empty value means no paths are excluded. "
336 "Use forward slashes for directories on all platforms (Windows, too). "
337 "Exclusions are calculated first, inclusions later."
338 ),
339 show_default=True,
340 )
341 @click.option(
342 "-q",
343 "--quiet",
344 is_flag=True,
345 help=(
346 "Don't emit non-error messages to stderr. Errors are still emitted; "
347 "silence those with 2>/dev/null."
348 ),
349 )
350 @click.option(
351 "-v",
352 "--verbose",
353 is_flag=True,
354 help=(
355 "Also emit messages to stderr about files that were not changed or were "
356 "ignored due to --exclude=."
357 ),
358 )
359 @click.version_option(version=__version__)
360 @click.argument(
361 "src",
362 nargs=-1,
363 type=click.Path(
364 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
365 ),
366 is_eager=True,
367 )
368 @click.option(
369 "--config",
370 type=click.Path(
371 exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
372 ),
373 is_eager=True,
374 callback=read_pyproject_toml,
375 help="Read configuration from PATH.",
376 )
377 @click.pass_context
378 def main(
379 ctx: click.Context,
380 code: Optional[str],
381 line_length: int,
382 target_version: List[TargetVersion],
383 check: bool,
384 diff: bool,
385 fast: bool,
386 pyi: bool,
387 py36: bool,
388 skip_string_normalization: bool,
389 quiet: bool,
390 verbose: bool,
391 include: str,
392 exclude: str,
393 src: Tuple[str],
394 config: Optional[str],
395 ) -> None:
396 """The uncompromising code formatter."""
397 write_back = WriteBack.from_configuration(check=check, diff=diff)
398 if target_version:
399 if py36:
400 err(f"Cannot use both --target-version and --py36")
401 ctx.exit(2)
402 else:
403 versions = set(target_version)
404 elif py36:
405 err(
406 "--py36 is deprecated and will be removed in a future version. "
407 "Use --target-version py36 instead."
408 )
409 versions = PY36_VERSIONS
410 else:
411 # We'll autodetect later.
412 versions = set()
413 mode = FileMode(
414 target_versions=versions,
415 line_length=line_length,
416 is_pyi=pyi,
417 string_normalization=not skip_string_normalization,
418 )
419 if config and verbose:
420 out(f"Using configuration from {config}.", bold=False, fg="blue")
421 if code is not None:
422 print(format_str(code, mode=mode))
423 ctx.exit(0)
424 try:
425 include_regex = re_compile_maybe_verbose(include)
426 except re.error:
427 err(f"Invalid regular expression for include given: {include!r}")
428 ctx.exit(2)
429 try:
430 exclude_regex = re_compile_maybe_verbose(exclude)
431 except re.error:
432 err(f"Invalid regular expression for exclude given: {exclude!r}")
433 ctx.exit(2)
434 report = Report(check=check, quiet=quiet, verbose=verbose)
435 root = find_project_root(src)
436 sources: Set[Path] = set()
437 path_empty(src, quiet, verbose, ctx)
438 for s in src:
439 p = Path(s)
440 if p.is_dir():
441 sources.update(
442 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
443 )
444 elif p.is_file() or s == "-":
445 # if a file was explicitly given, we don't care about its extension
446 sources.add(p)
447 else:
448 err(f"invalid path: {s}")
449 if len(sources) == 0:
450 if verbose or not quiet:
451 out("No Python files are present to be formatted. Nothing to do 😴")
452 ctx.exit(0)
453
454 if len(sources) == 1:
455 reformat_one(
456 src=sources.pop(),
457 fast=fast,
458 write_back=write_back,
459 mode=mode,
460 report=report,
461 )
462 else:
463 reformat_many(
464 sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
465 )
466
467 if verbose or not quiet:
468 out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
469 click.secho(str(report), err=True)
470 ctx.exit(report.return_code)
471
472
473 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
474 """
475 Exit if there is no `src` provided for formatting.
476 """
477 if not src:
478 if verbose or not quiet:
479 out("No Path provided. Nothing to do 😴")
480 ctx.exit(0)
481
482
483 def reformat_one(
484 src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
485 ) -> None:
486 """Reformat a single file under `src` without spawning child processes.
487
488 `fast`, `write_back`, and `mode` options are passed to
489 :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
490 """
491 try:
492 changed = Changed.NO
493 if not src.is_file() and str(src) == "-":
494 if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
495 changed = Changed.YES
496 else:
497 cache: Cache = {}
498 if write_back != WriteBack.DIFF:
499 cache = read_cache(mode)
500 res_src = src.resolve()
501 if res_src in cache and cache[res_src] == get_cache_info(res_src):
502 changed = Changed.CACHED
503 if changed is not Changed.CACHED and format_file_in_place(
504 src, fast=fast, write_back=write_back, mode=mode
505 ):
506 changed = Changed.YES
507 if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
508 write_back is WriteBack.CHECK and changed is Changed.NO
509 ):
510 write_cache(cache, [src], mode)
511 report.done(src, changed)
512 except Exception as exc:
513 report.failed(src, str(exc))
514
515
516 def reformat_many(
517 sources: Set[Path],
518 fast: bool,
519 write_back: WriteBack,
520 mode: FileMode,
521 report: "Report",
522 ) -> None:
523 """Reformat multiple files using a ProcessPoolExecutor."""
524 loop = asyncio.get_event_loop()
525 worker_count = os.cpu_count()
526 if sys.platform == "win32":
527 # Work around https://bugs.python.org/issue26903
528 worker_count = min(worker_count or 1, 61)
529 executor = ProcessPoolExecutor(max_workers=worker_count)
530 try:
531 loop.run_until_complete(
532 schedule_formatting(
533 sources=sources,
534 fast=fast,
535 write_back=write_back,
536 mode=mode,
537 report=report,
538 loop=loop,
539 executor=executor,
540 )
541 )
542 finally:
543 shutdown(loop)
544 executor.shutdown()
545
546
547 async def schedule_formatting(
548 sources: Set[Path],
549 fast: bool,
550 write_back: WriteBack,
551 mode: FileMode,
552 report: "Report",
553 loop: asyncio.AbstractEventLoop,
554 executor: Executor,
555 ) -> None:
556 """Run formatting of `sources` in parallel using the provided `executor`.
557
558 (Use ProcessPoolExecutors for actual parallelism.)
559
560 `write_back`, `fast`, and `mode` options are passed to
561 :func:`format_file_in_place`.
562 """
563 cache: Cache = {}
564 if write_back != WriteBack.DIFF:
565 cache = read_cache(mode)
566 sources, cached = filter_cached(cache, sources)
567 for src in sorted(cached):
568 report.done(src, Changed.CACHED)
569 if not sources:
570 return
571
572 cancelled = []
573 sources_to_cache = []
574 lock = None
575 if write_back == WriteBack.DIFF:
576 # For diff output, we need locks to ensure we don't interleave output
577 # from different processes.
578 manager = Manager()
579 lock = manager.Lock()
580 tasks = {
581 asyncio.ensure_future(
582 loop.run_in_executor(
583 executor, format_file_in_place, src, fast, mode, write_back, lock
584 )
585 ): src
586 for src in sorted(sources)
587 }
588 pending: Iterable[asyncio.Future] = tasks.keys()
589 try:
590 loop.add_signal_handler(signal.SIGINT, cancel, pending)
591 loop.add_signal_handler(signal.SIGTERM, cancel, pending)
592 except NotImplementedError:
593 # There are no good alternatives for these on Windows.
594 pass
595 while pending:
596 done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
597 for task in done:
598 src = tasks.pop(task)
599 if task.cancelled():
600 cancelled.append(task)
601 elif task.exception():
602 report.failed(src, str(task.exception()))
603 else:
604 changed = Changed.YES if task.result() else Changed.NO
605 # If the file was written back or was successfully checked as
606 # well-formatted, store this information in the cache.
607 if write_back is WriteBack.YES or (
608 write_back is WriteBack.CHECK and changed is Changed.NO
609 ):
610 sources_to_cache.append(src)
611 report.done(src, changed)
612 if cancelled:
613 await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
614 if sources_to_cache:
615 write_cache(cache, sources_to_cache, mode)
616
617
618 def format_file_in_place(
619 src: Path,
620 fast: bool,
621 mode: FileMode,
622 write_back: WriteBack = WriteBack.NO,
623 lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
624 ) -> bool:
625 """Format file under `src` path. Return True if changed.
626
627 If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
628 code to the file.
629 `mode` and `fast` options are passed to :func:`format_file_contents`.
630 """
631 if src.suffix == ".pyi":
632 mode = evolve(mode, is_pyi=True)
633
634 then = datetime.utcfromtimestamp(src.stat().st_mtime)
635 with open(src, "rb") as buf:
636 src_contents, encoding, newline = decode_bytes(buf.read())
637 try:
638 dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
639 except NothingChanged:
640 return False
641
642 if write_back == WriteBack.YES:
643 with open(src, "w", encoding=encoding, newline=newline) as f:
644 f.write(dst_contents)
645 elif write_back == WriteBack.DIFF:
646 now = datetime.utcnow()
647 src_name = f"{src}\t{then} +0000"
648 dst_name = f"{src}\t{now} +0000"
649 diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
650
651 with lock or nullcontext():
652 f = io.TextIOWrapper(
653 sys.stdout.buffer,
654 encoding=encoding,
655 newline=newline,
656 write_through=True,
657 )
658 f.write(diff_contents)
659 f.detach()
660
661 return True
662
663
664 def format_stdin_to_stdout(
665 fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
666 ) -> bool:
667 """Format file on stdin. Return True if changed.
668
669 If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
670 write a diff to stdout. The `mode` argument is passed to
671 :func:`format_file_contents`.
672 """
673 then = datetime.utcnow()
674 src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
675 dst = src
676 try:
677 dst = format_file_contents(src, fast=fast, mode=mode)
678 return True
679
680 except NothingChanged:
681 return False
682
683 finally:
684 f = io.TextIOWrapper(
685 sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
686 )
687 if write_back == WriteBack.YES:
688 f.write(dst)
689 elif write_back == WriteBack.DIFF:
690 now = datetime.utcnow()
691 src_name = f"STDIN\t{then} +0000"
692 dst_name = f"STDOUT\t{now} +0000"
693 f.write(diff(src, dst, src_name, dst_name))
694 f.detach()
695
696
697 def format_file_contents(
698 src_contents: str, *, fast: bool, mode: FileMode
699 ) -> FileContent:
700 """Reformat contents a file and return new contents.
701
702 If `fast` is False, additionally confirm that the reformatted code is
703 valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
704 `mode` is passed to :func:`format_str`.
705 """
706 if src_contents.strip() == "":
707 raise NothingChanged
708
709 dst_contents = format_str(src_contents, mode=mode)
710 if src_contents == dst_contents:
711 raise NothingChanged
712
713 if not fast:
714 assert_equivalent(src_contents, dst_contents)
715 assert_stable(src_contents, dst_contents, mode=mode)
716 return dst_contents
717
718
719 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
720 """Reformat a string and return new contents.
721
722 `mode` determines formatting options, such as how many characters per line are
723 allowed.
724 """
725 src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
726 dst_contents = []
727 future_imports = get_future_imports(src_node)
728 if mode.target_versions:
729 versions = mode.target_versions
730 else:
731 versions = detect_target_versions(src_node)
732 normalize_fmt_off(src_node)
733 lines = LineGenerator(
734 remove_u_prefix="unicode_literals" in future_imports
735 or supports_feature(versions, Feature.UNICODE_LITERALS),
736 is_pyi=mode.is_pyi,
737 normalize_strings=mode.string_normalization,
738 )
739 elt = EmptyLineTracker(is_pyi=mode.is_pyi)
740 empty_line = Line()
741 after = 0
742 split_line_features = {
743 feature
744 for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
745 if supports_feature(versions, feature)
746 }
747 for current_line in lines.visit(src_node):
748 for _ in range(after):
749 dst_contents.append(str(empty_line))
750 before, after = elt.maybe_empty_lines(current_line)
751 for _ in range(before):
752 dst_contents.append(str(empty_line))
753 for line in split_line(
754 current_line, line_length=mode.line_length, features=split_line_features
755 ):
756 dst_contents.append(str(line))
757 return "".join(dst_contents)
758
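# For example, with the default mode:
#   format_str("l=[1,2,3]", mode=FileMode())  -> 'l = [1, 2, 3]\n'
#   format_str("x = 'a'", mode=FileMode())    -> 'x = "a"\n'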
759
760 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
761 """Return a tuple of (decoded_contents, encoding, newline).
762
763 `newline` is either CRLF or LF but `decoded_contents` is decoded with
764 universal newlines (i.e. only contains LF).
765 """
766 srcbuf = io.BytesIO(src)
767 encoding, lines = tokenize.detect_encoding(srcbuf.readline)
768 if not lines:
769 return "", encoding, "\n"
770
771 newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
772 srcbuf.seek(0)
773 with io.TextIOWrapper(srcbuf, encoding) as tiow:
774 return tiow.read(), encoding, newline
775
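# For example, decode_bytes(b"x = 1\r\ny = 2\r\n") returns
# ("x = 1\ny = 2\n", "utf-8", "\r\n"): contents are normalized to LF while
# the detected newline is remembered for writing results back.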
776
777 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
778 if not target_versions:
779 # No target_version specified, so try all grammars.
780 return [
781 # Python 3.7+
782 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
783 # Python 3.0-3.6
784 pygram.python_grammar_no_print_statement_no_exec_statement,
785 # Python 2.7 with future print_function import
786 pygram.python_grammar_no_print_statement,
787 # Python 2.7
788 pygram.python_grammar,
789 ]
790 elif all(version.is_python2() for version in target_versions):
791 # Python 2-only code, so try Python 2 grammars.
792 return [
793 # Python 2.7 with future print_function import
794 pygram.python_grammar_no_print_statement,
795 # Python 2.7
796 pygram.python_grammar,
797 ]
798 else:
799 # Python 3-compatible code, so only try Python 3 grammar.
800 grammars = []
801 # If we have to parse both, try to parse async as a keyword first
802 if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
803 # Python 3.7+
804 grammars.append(
805 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords # noqa: B950
806 )
807 if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
808 # Python 3.0-3.6
809 grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
810 # At least one of the above branches must have been taken, because every Python
811 # version has exactly one of the two 'ASYNC_*' flags
812 return grammars
813
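# For example, targeting {PY27} yields only the two Python 2 grammars, while
# {PY36, PY37} yields both Python 3 grammars, because PY36 still parses
# "async" as an identifier whereas PY37 treats it as a keyword.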
814
815 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
816 """Given a string with source, return the lib2to3 Node."""
817 if src_txt[-1:] != "\n":
818 src_txt += "\n"
819
820 for grammar in get_grammars(set(target_versions)):
821 drv = driver.Driver(grammar, pytree.convert)
822 try:
823 result = drv.parse_string(src_txt, True)
824 break
825
826 except ParseError as pe:
827 lineno, column = pe.context[1]
828 lines = src_txt.splitlines()
829 try:
830 faulty_line = lines[lineno - 1]
831 except IndexError:
832 faulty_line = "<line number missing in source>"
833 exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
834 else:
835 raise exc from None
836
837 if isinstance(result, Leaf):
838 result = Node(syms.file_input, [result])
839 return result
840
841
842 def lib2to3_unparse(node: Node) -> str:
843 """Given a lib2to3 node, return its string representation."""
844 code = str(node)
845 return code
846
847
848 T = TypeVar("T")
849
850
851 class Visitor(Generic[T]):
852 """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
853
854 def visit(self, node: LN) -> Iterator[T]:
855 """Main method to visit `node` and its children.
856
857 It tries to find a `visit_*()` method for the given `node.type`, like
858 `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
859 If no dedicated `visit_*()` method is found, chooses `visit_default()`
860 instead.
861
862 Then yields objects of type `T` from the selected visitor.
863 """
864 if node.type < 256:
865 name = token.tok_name[node.type]
866 else:
867 name = type_repr(node.type)
868 yield from getattr(self, f"visit_{name}", self.visit_default)(node)
869
870 def visit_default(self, node: LN) -> Iterator[T]:
871 """Default `visit_*()` implementation. Recurses to children of `node`."""
872 if isinstance(node, Node):
873 for child in node.children:
874 yield from self.visit(child)
875
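# A minimal usage sketch (hypothetical subclass, for illustration only):
# collect the value of every NAME leaf in a parsed tree.
#
#   class NameCollector(Visitor[str]):
#       def visit_NAME(self, leaf: Leaf) -> Iterator[str]:
#           yield leaf.value
#
#   names = list(NameCollector().visit(lib2to3_parse("x = y + 1\n")))
#   # names == ["x", "y"]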
876
877 @dataclass
878 class DebugVisitor(Visitor[T]):
879 tree_depth: int = 0
880
881 def visit_default(self, node: LN) -> Iterator[T]:
882 indent = " " * (2 * self.tree_depth)
883 if isinstance(node, Node):
884 _type = type_repr(node.type)
885 out(f"{indent}{_type}", fg="yellow")
886 self.tree_depth += 1
887 for child in node.children:
888 yield from self.visit(child)
889
890 self.tree_depth -= 1
891 out(f"{indent}/{_type}", fg="yellow", bold=False)
892 else:
893 _type = token.tok_name.get(node.type, str(node.type))
894 out(f"{indent}{_type}", fg="blue", nl=False)
895 if node.prefix:
896 # We don't have to handle prefixes for `Node` objects since
897 # that delegates to the first child anyway.
898 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
899 out(f" {node.value!r}", fg="blue", bold=False)
900
901 @classmethod
902 def show(cls, code: Union[str, Leaf, Node]) -> None:
903 """Pretty-print the lib2to3 AST of a given string of `code`.
904
905 Convenience method for debugging.
906 """
907 v: DebugVisitor[None] = DebugVisitor()
908 if isinstance(code, str):
909 code = lib2to3_parse(code)
910 list(v.visit(code))
911
912
913 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
914 STATEMENT = {
915 syms.if_stmt,
916 syms.while_stmt,
917 syms.for_stmt,
918 syms.try_stmt,
919 syms.except_clause,
920 syms.with_stmt,
921 syms.funcdef,
922 syms.classdef,
923 }
924 STANDALONE_COMMENT = 153
925 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
926 LOGIC_OPERATORS = {"and", "or"}
927 COMPARATORS = {
928 token.LESS,
929 token.GREATER,
930 token.EQEQUAL,
931 token.NOTEQUAL,
932 token.LESSEQUAL,
933 token.GREATEREQUAL,
934 }
935 MATH_OPERATORS = {
936 token.VBAR,
937 token.CIRCUMFLEX,
938 token.AMPER,
939 token.LEFTSHIFT,
940 token.RIGHTSHIFT,
941 token.PLUS,
942 token.MINUS,
943 token.STAR,
944 token.SLASH,
945 token.DOUBLESLASH,
946 token.PERCENT,
947 token.AT,
948 token.TILDE,
949 token.DOUBLESTAR,
950 }
951 STARS = {token.STAR, token.DOUBLESTAR}
952 VARARGS_SPECIALS = STARS | {token.SLASH}
953 VARARGS_PARENTS = {
954 syms.arglist,
955 syms.argument, # double star in arglist
956 syms.trailer, # single argument to call
957 syms.typedargslist,
958 syms.varargslist, # lambdas
959 }
960 UNPACKING_PARENTS = {
961 syms.atom, # single element of a list or set literal
962 syms.dictsetmaker,
963 syms.listmaker,
964 syms.testlist_gexp,
965 syms.testlist_star_expr,
966 }
967 TEST_DESCENDANTS = {
968 syms.test,
969 syms.lambdef,
970 syms.or_test,
971 syms.and_test,
972 syms.not_test,
973 syms.comparison,
974 syms.star_expr,
975 syms.expr,
976 syms.xor_expr,
977 syms.and_expr,
978 syms.shift_expr,
979 syms.arith_expr,
980 syms.trailer,
981 syms.term,
982 syms.power,
983 }
984 ASSIGNMENTS = {
985 "=",
986 "+=",
987 "-=",
988 "*=",
989 "@=",
990 "/=",
991 "%=",
992 "&=",
993 "|=",
994 "^=",
995 "<<=",
996 ">>=",
997 "**=",
998 "//=",
999 }
1000 COMPREHENSION_PRIORITY = 20
1001 COMMA_PRIORITY = 18
1002 TERNARY_PRIORITY = 16
1003 LOGIC_PRIORITY = 14
1004 STRING_PRIORITY = 12
1005 COMPARATOR_PRIORITY = 10
1006 MATH_PRIORITIES = {
1007 token.VBAR: 9,
1008 token.CIRCUMFLEX: 8,
1009 token.AMPER: 7,
1010 token.LEFTSHIFT: 6,
1011 token.RIGHTSHIFT: 6,
1012 token.PLUS: 5,
1013 token.MINUS: 5,
1014 token.STAR: 4,
1015 token.SLASH: 4,
1016 token.DOUBLESLASH: 4,
1017 token.PERCENT: 4,
1018 token.AT: 4,
1019 token.TILDE: 3,
1020 token.DOUBLESTAR: 2,
1021 }
1022 DOT_PRIORITY = 1
1023
1024
1025 @dataclass
1026 class BracketTracker:
1027 """Keeps track of brackets on a line."""
1028
1029 depth: int = 0
1030 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1031 delimiters: Dict[LeafID, Priority] = Factory(dict)
1032 previous: Optional[Leaf] = None
1033 _for_loop_depths: List[int] = Factory(list)
1034 _lambda_argument_depths: List[int] = Factory(list)
1035
1036 def mark(self, leaf: Leaf) -> None:
1037 """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1038
1039 All leaves receive an int `bracket_depth` field that stores how deep
1040 within brackets a given leaf is. 0 means there are no enclosing brackets
1041 that started on this line.
1042
1043 If a leaf is itself a closing bracket, it receives an `opening_bracket`
1044 field that it forms a pair with. This is a one-directional link to
1045 avoid reference cycles.
1046
1047 If a leaf is a delimiter (a token on which Black can split the line if
1048 needed) and it's on depth 0, its `id()` is stored in the tracker's
1049 `delimiters` field.
1050 """
1051 if leaf.type == token.COMMENT:
1052 return
1053
1054 self.maybe_decrement_after_for_loop_variable(leaf)
1055 self.maybe_decrement_after_lambda_arguments(leaf)
1056 if leaf.type in CLOSING_BRACKETS:
1057 self.depth -= 1
1058 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1059 leaf.opening_bracket = opening_bracket
1060 leaf.bracket_depth = self.depth
1061 if self.depth == 0:
1062 delim = is_split_before_delimiter(leaf, self.previous)
1063 if delim and self.previous is not None:
1064 self.delimiters[id(self.previous)] = delim
1065 else:
1066 delim = is_split_after_delimiter(leaf, self.previous)
1067 if delim:
1068 self.delimiters[id(leaf)] = delim
1069 if leaf.type in OPENING_BRACKETS:
1070 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1071 self.depth += 1
1072 self.previous = leaf
1073 self.maybe_increment_lambda_arguments(leaf)
1074 self.maybe_increment_for_loop_variable(leaf)
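# For example, after marking the leaves of "fn(a, b)" the tracker has
# assigned bracket_depth 0 to "fn" and both parentheses, bracket_depth 1 to
# "a", the comma and "b", and recorded a single COMMA_PRIORITY delimiter
# keyed by the comma's id().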
1075
1076 def any_open_brackets(self) -> bool:
1077 """Return True if there is an yet unmatched open bracket on the line."""
1078 return bool(self.bracket_match)
1079
1080 def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1081 """Return the highest priority of a delimiter found on the line.
1082
1083 Values are consistent with what `is_split_*_delimiter()` return.
1084 Raises ValueError on no delimiters.
1085 """
1086 return max(v for k, v in self.delimiters.items() if k not in exclude)
1087
1088 def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1089 """Return the number of delimiters with the given `priority`.
1090
1091 If no `priority` is passed, defaults to max priority on the line.
1092 """
1093 if not self.delimiters:
1094 return 0
1095
1096 priority = priority or self.max_delimiter_priority()
1097 return sum(1 for p in self.delimiters.values() if p == priority)
1098
1099 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1100 """In a for loop, or comprehension, the variables are often unpacks.
1101
1102 To avoid splitting on the comma in this situation, increase the depth of
1103 tokens between `for` and `in`.
1104 """
1105 if leaf.type == token.NAME and leaf.value == "for":
1106 self.depth += 1
1107 self._for_loop_depths.append(self.depth)
1108 return True
1109
1110 return False
1111
1112 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1113 """See `maybe_increment_for_loop_variable` above for explanation."""
1114 if (
1115 self._for_loop_depths
1116 and self._for_loop_depths[-1] == self.depth
1117 and leaf.type == token.NAME
1118 and leaf.value == "in"
1119 ):
1120 self.depth -= 1
1121 self._for_loop_depths.pop()
1122 return True
1123
1124 return False
1125
1126 def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1127 """In a lambda expression, there might be more than one argument.
1128
1129 To avoid splitting on the comma in this situation, increase the depth of
1130 tokens between `lambda` and `:`.
1131 """
1132 if leaf.type == token.NAME and leaf.value == "lambda":
1133 self.depth += 1
1134 self._lambda_argument_depths.append(self.depth)
1135 return True
1136
1137 return False
1138
1139 def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1140 """See `maybe_increment_lambda_arguments` above for explanation."""
1141 if (
1142 self._lambda_argument_depths
1143 and self._lambda_argument_depths[-1] == self.depth
1144 and leaf.type == token.COLON
1145 ):
1146 self.depth -= 1
1147 self._lambda_argument_depths.pop()
1148 return True
1149
1150 return False
1151
1152 def get_open_lsqb(self) -> Optional[Leaf]:
1153 """Return the most recent opening square bracket (if any)."""
1154 return self.bracket_match.get((self.depth - 1, token.RSQB))
1155
1156
1157 @dataclass
1158 class Line:
1159 """Holds leaves and comments. Can be printed with `str(line)`."""
1160
1161 depth: int = 0
1162 leaves: List[Leaf] = Factory(list)
1163 comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves`
1164 bracket_tracker: BracketTracker = Factory(BracketTracker)
1165 inside_brackets: bool = False
1166 should_explode: bool = False
1167
1168 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1169 """Add a new `leaf` to the end of the line.
1170
1171 Unless `preformatted` is True, the `leaf` will receive a new consistent
1172 whitespace prefix and metadata applied by :class:`BracketTracker`.
1173 Trailing commas are maybe removed, unpacked for loop variables are
1174 demoted from being delimiters.
1175
1176 Inline comments are put aside.
1177 """
1178 has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1179 if not has_value:
1180 return
1181
1182 if token.COLON == leaf.type and self.is_class_paren_empty:
1183 del self.leaves[-2:]
1184 if self.leaves and not preformatted:
1185 # Note: at this point leaf.prefix should be empty except for
1186 # imports, for which we only preserve newlines.
1187 leaf.prefix += whitespace(
1188 leaf, complex_subscript=self.is_complex_subscript(leaf)
1189 )
1190 if self.inside_brackets or not preformatted:
1191 self.bracket_tracker.mark(leaf)
1192 self.maybe_remove_trailing_comma(leaf)
1193 if not self.append_comment(leaf):
1194 self.leaves.append(leaf)
1195
1196 def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1197 """Like :func:`append()` but disallow invalid standalone comment structure.
1198
1199 Raises ValueError when any `leaf` is appended after a standalone comment
1200 or when a standalone comment is not the first leaf on the line.
1201 """
1202 if self.bracket_tracker.depth == 0:
1203 if self.is_comment:
1204 raise ValueError("cannot append to standalone comments")
1205
1206 if self.leaves and leaf.type == STANDALONE_COMMENT:
1207 raise ValueError(
1208 "cannot append standalone comments to a populated line"
1209 )
1210
1211 self.append(leaf, preformatted=preformatted)
1212
1213 @property
1214 def is_comment(self) -> bool:
1215 """Is this line a standalone comment?"""
1216 return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1217
1218 @property
1219 def is_decorator(self) -> bool:
1220 """Is this line a decorator?"""
1221 return bool(self) and self.leaves[0].type == token.AT
1222
1223 @property
1224 def is_import(self) -> bool:
1225 """Is this an import line?"""
1226 return bool(self) and is_import(self.leaves[0])
1227
1228 @property
1229 def is_class(self) -> bool:
1230 """Is this line a class definition?"""
1231 return (
1232 bool(self)
1233 and self.leaves[0].type == token.NAME
1234 and self.leaves[0].value == "class"
1235 )
1236
1237 @property
1238 def is_stub_class(self) -> bool:
1239 """Is this line a class definition with a body consisting only of "..."?"""
1240 return self.is_class and self.leaves[-3:] == [
1241 Leaf(token.DOT, ".") for _ in range(3)
1242 ]
1243
1244 @property
1245 def is_collection_with_optional_trailing_comma(self) -> bool:
1246 """Is this line a collection literal with a trailing comma that's optional?
1247
1248 Note that the trailing comma in a 1-tuple is not optional.
1249 """
1250 if not self.leaves or len(self.leaves) < 4:
1251 return False
1252 # Look for and address a trailing colon.
1253 if self.leaves[-1].type == token.COLON:
1254 closer = self.leaves[-2]
1255 close_index = -2
1256 else:
1257 closer = self.leaves[-1]
1258 close_index = -1
1259 if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1260 return False
1261 if closer.type == token.RPAR:
1262 # Tuples require an extra check, because if there's only
1263 # one element in the tuple removing the comma unmakes the
1264 # tuple.
1265 #
1266 # We also check for parens before looking for the trailing
1267 # comma because in some cases (eg assigning a dict
1268 # literal) the literal gets wrapped in temporary parens
1269 # during parsing. This case is covered by the
1270 # collections.py test data.
1271 opener = closer.opening_bracket
1272 for _open_index, leaf in enumerate(self.leaves):
1273 if leaf is opener:
1274 break
1275 else:
1276 # Couldn't find the matching opening paren, play it safe.
1277 return False
1278 commas = 0
1279 comma_depth = self.leaves[close_index - 1].bracket_depth
1280 for leaf in self.leaves[_open_index + 1 : close_index]:
1281 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1282 commas += 1
1283 if commas > 1:
1284 # We haven't looked yet for the trailing comma because
1285 # we might also have caught noop parens.
1286 return self.leaves[close_index - 1].type == token.COMMA
1287 elif commas == 1:
1288 return False # it's either a one-tuple or didn't have a trailing comma
1289 if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1290 close_index -= 1
1291 closer = self.leaves[close_index]
1292 if closer.type == token.RPAR:
1293 # TODO: this is a gut feeling. Will we ever see this?
1294 return False
1295 if self.leaves[close_index - 1].type != token.COMMA:
1296 return False
1297 return True
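# For example, the trailing comma in "[1, 2, 3,]" or "{1: 2,}" is optional,
# while the one in "(1,)" is not: removing it would turn the tuple into a
# parenthesized expression.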
1298
1299 @property
1300 def is_def(self) -> bool:
1301 """Is this a function definition? (Also returns True for async defs.)"""
1302 try:
1303 first_leaf = self.leaves[0]
1304 except IndexError:
1305 return False
1306
1307 try:
1308 second_leaf: Optional[Leaf] = self.leaves[1]
1309 except IndexError:
1310 second_leaf = None
1311 return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1312 first_leaf.type == token.ASYNC
1313 and second_leaf is not None
1314 and second_leaf.type == token.NAME
1315 and second_leaf.value == "def"
1316 )
1317
1318 @property
1319 def is_class_paren_empty(self) -> bool:
1320 """Is this a class with no base classes but using parentheses?
1321
1322 Those are unnecessary and should be removed.
1323 """
1324 return (
1325 bool(self)
1326 and len(self.leaves) == 4
1327 and self.is_class
1328 and self.leaves[2].type == token.LPAR
1329 and self.leaves[2].value == "("
1330 and self.leaves[3].type == token.RPAR
1331 and self.leaves[3].value == ")"
1332 )
1333
1334 @property
1335 def is_triple_quoted_string(self) -> bool:
1336 """Is the line a triple quoted string?"""
1337 return (
1338 bool(self)
1339 and self.leaves[0].type == token.STRING
1340 and self.leaves[0].value.startswith(('"""', "'''"))
1341 )
1342
1343 def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1344 """If so, needs to be split before emitting."""
1345 for leaf in self.leaves:
1346 if leaf.type == STANDALONE_COMMENT:
1347 if leaf.bracket_depth <= depth_limit:
1348 return True
1349 return False
1350
1351 def contains_uncollapsable_type_comments(self) -> bool:
1352 ignored_ids = set()
1353 try:
1354 last_leaf = self.leaves[-1]
1355 ignored_ids.add(id(last_leaf))
1356 if last_leaf.type == token.COMMA or (
1357 last_leaf.type == token.RPAR and not last_leaf.value
1358 ):
1359 # When trailing commas or optional parens are inserted by Black for
1360 # consistency, comments after the previous last element are not moved
1361 # (they don't have to, rendering will still be correct). So we ignore
1362 # trailing commas and invisible parens.
1363 last_leaf = self.leaves[-2]
1364 ignored_ids.add(id(last_leaf))
1365 except IndexError:
1366 return False
1367
1368 # A type comment is uncollapsable if it is attached to a leaf
1369 # that isn't at the end of the line (since that could cause it
1370 # to get associated to a different argument) or if there are
1371 # comments before it (since that could cause it to get hidden
1372 # behind a comment).
1373 comment_seen = False
1374 for leaf_id, comments in self.comments.items():
1375 for comment in comments:
1376 if is_type_comment(comment):
1377 if leaf_id not in ignored_ids or comment_seen:
1378 return True
1379
1380 comment_seen = True
1381
1382 return False
1383
1384 def contains_unsplittable_type_ignore(self) -> bool:
1385 if not self.leaves:
1386 return False
1387
1388 # If a 'type: ignore' is attached to the end of a line, we
1389 # can't split the line, because we can't know which of the
1390 # subexpressions the ignore was meant to apply to.
1391 #
1392 # We only want this to apply to actual physical lines from the
1393 # original source, though: we don't want the presence of a
1394 # 'type: ignore' at the end of a multiline expression to
1395 # justify pushing it all onto one line. Thus we
1396 # (unfortunately) need to check the actual source lines and
1397 # only report an unsplittable 'type: ignore' if this line was
1398 # one line in the original code.
1399 if self.leaves[0].lineno == self.leaves[-1].lineno:
1400 for comment in self.comments.get(id(self.leaves[-1]), []):
1401 if is_type_comment(comment, " ignore"):
1402 return True
1403
1404 return False
1405
1406 def contains_multiline_strings(self) -> bool:
1407 for leaf in self.leaves:
1408 if is_multiline_string(leaf):
1409 return True
1410
1411 return False
1412
1413 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1414 """Remove trailing comma if there is one and it's safe."""
1415 if not (self.leaves and self.leaves[-1].type == token.COMMA):
1416 return False
1417 # We remove trailing commas only in the case of importing a
1418 # single name from a module.
1419 if not (
1420 self.leaves
1421 and self.is_import
1422 and len(self.leaves) > 4
1423 and self.leaves[-1].type == token.COMMA
1424 and closing.type in CLOSING_BRACKETS
1425 and self.leaves[-4].type == token.NAME
1426 and (
1427 # regular `from foo import bar,`
1428 self.leaves[-4].value == "import"
1429 # `from foo import (bar as baz,)`
1430 or (
1431 len(self.leaves) > 6
1432 and self.leaves[-6].value == "import"
1433 and self.leaves[-3].value == "as"
1434 )
1435 # `from foo import bar as baz,`
1436 or (
1437 len(self.leaves) > 5
1438 and self.leaves[-5].value == "import"
1439 and self.leaves[-3].value == "as"
1440 )
1441 )
1442 and closing.type == token.RPAR
1443 ):
1444 return False
1445
1446 self.remove_trailing_comma()
1447 return True
1448
1449 def append_comment(self, comment: Leaf) -> bool:
1450 """Add an inline or standalone comment to the line."""
1451 if (
1452 comment.type == STANDALONE_COMMENT
1453 and self.bracket_tracker.any_open_brackets()
1454 ):
1455 comment.prefix = ""
1456 return False
1457
1458 if comment.type != token.COMMENT:
1459 return False
1460
1461 if not self.leaves:
1462 comment.type = STANDALONE_COMMENT
1463 comment.prefix = ""
1464 return False
1465
1466 last_leaf = self.leaves[-1]
1467 if (
1468 last_leaf.type == token.RPAR
1469 and not last_leaf.value
1470 and last_leaf.parent
1471 and len(list(last_leaf.parent.leaves())) <= 3
1472 and not is_type_comment(comment)
1473 ):
1474 # Comments on an optional parens wrapping a single leaf should belong to
1475 # the wrapped node except if it's a type comment. Pinning the comment like
1476 # this avoids unstable formatting caused by comment migration.
1477 if len(self.leaves) < 2:
1478 comment.type = STANDALONE_COMMENT
1479 comment.prefix = ""
1480 return False
1481 last_leaf = self.leaves[-2]
1482 self.comments.setdefault(id(last_leaf), []).append(comment)
1483 return True
1484
1485 def comments_after(self, leaf: Leaf) -> List[Leaf]:
1486 """Generate comments that should appear directly after `leaf`."""
1487 return self.comments.get(id(leaf), [])
1488
1489 def remove_trailing_comma(self) -> None:
1490 """Remove the trailing comma and moves the comments attached to it."""
1491 trailing_comma = self.leaves.pop()
1492 trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1493 self.comments.setdefault(id(self.leaves[-1]), []).extend(
1494 trailing_comma_comments
1495 )
1496
1497 def is_complex_subscript(self, leaf: Leaf) -> bool:
1498 """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1499 open_lsqb = self.bracket_tracker.get_open_lsqb()
1500 if open_lsqb is None:
1501 return False
1502
1503 subscript_start = open_lsqb.next_sibling
1504
1505 if isinstance(subscript_start, Node):
1506 if subscript_start.type == syms.listmaker:
1507 return False
1508
1509 if subscript_start.type == syms.subscriptlist:
1510 subscript_start = child_towards(subscript_start, leaf)
1511 return subscript_start is not None and any(
1512 n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1513 )
1514
1515 def __str__(self) -> str:
1516 """Render the line."""
1517 if not self:
1518 return "\n"
1519
1520 indent = " " * self.depth
1521 leaves = iter(self.leaves)
1522 first = next(leaves)
1523 res = f"{first.prefix}{indent}{first.value}"
1524 for leaf in leaves:
1525 res += str(leaf)
1526 for comment in itertools.chain.from_iterable(self.comments.values()):
1527 res += str(comment)
1528 return res + "\n"
1529
1530 def __bool__(self) -> bool:
1531 """Return True if the line has leaves or comments."""
1532 return bool(self.leaves or self.comments)
1533
1534
1535 @dataclass
1536 class EmptyLineTracker:
1537 """Provides a stateful method that returns the number of potential extra
1538 empty lines needed before and after the currently processed line.
1539
1540 Note: this tracker works on lines that haven't been split yet. It assumes
1541 the prefix of the first leaf consists of optional newlines. Those newlines
1542 are consumed by `maybe_empty_lines()` and included in the computation.
1543 """
1544
1545 is_pyi: bool = False
1546 previous_line: Optional[Line] = None
1547 previous_after: int = 0
1548 previous_defs: List[int] = Factory(list)
1549
1550 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1551 """Return the number of extra empty lines before and after the `current_line`.
1552
1553 This is for separating `def`, `async def` and `class` with extra empty
1554 lines (two on module-level).
1555 """
1556 before, after = self._maybe_empty_lines(current_line)
1557 before = (
1558 # Black should not insert empty lines at the beginning
1559 # of the file
1560 0
1561 if self.previous_line is None
1562 else before - self.previous_after
1563 )
1564 self.previous_after = after
1565 self.previous_line = current_line
1566 return before, after
1567
1568 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1569 max_allowed = 1
1570 if current_line.depth == 0:
1571 max_allowed = 1 if self.is_pyi else 2
1572 if current_line.leaves:
1573 # Consume the first leaf's extra newlines.
1574 first_leaf = current_line.leaves[0]
1575 before = first_leaf.prefix.count("\n")
1576 before = min(before, max_allowed)
1577 first_leaf.prefix = ""
1578 else:
1579 before = 0
1580 depth = current_line.depth
1581 while self.previous_defs and self.previous_defs[-1] >= depth:
1582 self.previous_defs.pop()
1583 if self.is_pyi:
1584 before = 0 if depth else 1
1585 else:
1586 before = 1 if depth else 2
1587 if current_line.is_decorator or current_line.is_def or current_line.is_class:
1588 return self._maybe_empty_lines_for_class_or_def(current_line, before)
1589
1590 if (
1591 self.previous_line
1592 and self.previous_line.is_import
1593 and not current_line.is_import
1594 and depth == self.previous_line.depth
1595 ):
1596 return (before or 1), 0
1597
1598 if (
1599 self.previous_line
1600 and self.previous_line.is_class
1601 and current_line.is_triple_quoted_string
1602 ):
1603 return before, 1
1604
1605 return before, 0
1606
1607 def _maybe_empty_lines_for_class_or_def(
1608 self, current_line: Line, before: int
1609 ) -> Tuple[int, int]:
1610 if not current_line.is_decorator:
1611 self.previous_defs.append(current_line.depth)
1612 if self.previous_line is None:
1613 # Don't insert empty lines before the first line in the file.
1614 return 0, 0
1615
1616 if self.previous_line.is_decorator:
1617 return 0, 0
1618
1619 if self.previous_line.depth < current_line.depth and (
1620 self.previous_line.is_class or self.previous_line.is_def
1621 ):
1622 return 0, 0
1623
1624 if (
1625 self.previous_line.is_comment
1626 and self.previous_line.depth == current_line.depth
1627 and before == 0
1628 ):
1629 return 0, 0
1630
1631 if self.is_pyi:
1632 if self.previous_line.depth > current_line.depth:
1633 newlines = 1
1634 elif current_line.is_class or self.previous_line.is_class:
1635 if current_line.is_stub_class and self.previous_line.is_stub_class:
1636 # No blank line between classes with an empty body
1637 newlines = 0
1638 else:
1639 newlines = 1
1640 elif current_line.is_def and not self.previous_line.is_def:
1641 # Blank line between a block of functions and a block of non-functions
1642 newlines = 1
1643 else:
1644 newlines = 0
1645 else:
1646 newlines = 2
1647 if current_line.depth and newlines:
1648 newlines -= 1
1649 return newlines, 0
1650
1651
1652 @dataclass
1653 class LineGenerator(Visitor[Line]):
1654 """Generates reformatted Line objects. Empty lines are not emitted.
1655
1656 Note: destroys the tree it's visiting by mutating prefixes of its leaves
1657 in ways that will no longer stringify to valid Python code on the tree.
1658 """
1659
1660 is_pyi: bool = False
1661 normalize_strings: bool = True
1662 current_line: Line = Factory(Line)
1663 remove_u_prefix: bool = False
1664
1665 def line(self, indent: int = 0) -> Iterator[Line]:
1666 """Generate a line.
1667
1668 If the line is empty, only emit if it makes sense.
1669 If the line is too long, split it first and then generate.
1670
1671 If any lines were generated, set up a new current_line.
1672 """
1673 if not self.current_line:
1674 self.current_line.depth += indent
1675 return # Line is empty, don't emit. Creating a new one is unnecessary.
1676
1677 complete_line = self.current_line
1678 self.current_line = Line(depth=complete_line.depth + indent)
1679 yield complete_line
1680
1681 def visit_default(self, node: LN) -> Iterator[Line]:
1682 """Default `visit_*()` implementation. Recurses to children of `node`."""
1683 if isinstance(node, Leaf):
1684 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1685 for comment in generate_comments(node):
1686 if any_open_brackets:
1687 # any comment within brackets is subject to splitting
1688 self.current_line.append(comment)
1689 elif comment.type == token.COMMENT:
1690 # regular trailing comment
1691 self.current_line.append(comment)
1692 yield from self.line()
1693
1694 else:
1695 # regular standalone comment
1696 yield from self.line()
1697
1698 self.current_line.append(comment)
1699 yield from self.line()
1700
1701 normalize_prefix(node, inside_brackets=any_open_brackets)
1702 if self.normalize_strings and node.type == token.STRING:
1703 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1704 normalize_string_quotes(node)
1705 if node.type == token.NUMBER:
1706 normalize_numeric_literal(node)
1707 if node.type not in WHITESPACE:
1708 self.current_line.append(node)
1709 yield from super().visit_default(node)
1710
1711 def visit_atom(self, node: Node) -> Iterator[Line]:
1712 # Always make parentheses invisible around a single node, because it should
1713 # not be needed (except in the case of yield, where removing the parentheses
1714 # produces a SyntaxError).
1715 if (
1716 len(node.children) == 3
1717 and isinstance(node.children[0], Leaf)
1718 and node.children[0].type == token.LPAR
1719 and isinstance(node.children[2], Leaf)
1720 and node.children[2].type == token.RPAR
1721 and isinstance(node.children[1], Leaf)
1722 and not (
1723 node.children[1].type == token.NAME
1724 and node.children[1].value == "yield"
1725 )
1726 ):
1727 node.children[0].value = ""
1728 node.children[2].value = ""
1729 yield from super().visit_default(node)
1730
1731 def visit_factor(self, node: Node) -> Iterator[Line]:
1732 """Force parentheses between a unary op and a binary power:
1733
1734 -2 ** 8 -> -(2 ** 8)
1735 """
1736 child = node.children[1]
1737 if child.type == syms.power and len(child.children) == 3:
1738 lpar = Leaf(token.LPAR, "(")
1739 rpar = Leaf(token.RPAR, ")")
1740 index = child.remove() or 0
1741 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1742 yield from self.visit_default(node)
1743
1744 def visit_INDENT(self, node: Node) -> Iterator[Line]:
1745 """Increase indentation level, maybe yield a line."""
1746 # In blib2to3 INDENT never holds comments.
1747 yield from self.line(+1)
1748 yield from self.visit_default(node)
1749
1750 def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1751 """Decrease indentation level, maybe yield a line."""
1752 # The current line might still wait for trailing comments. At DEDENT time
1753 # there won't be any (they would be prefixes on the preceding NEWLINE).
1754 # Emit the line then.
1755 yield from self.line()
1756
1757 # While DEDENT has no value, its prefix may contain standalone comments
1758 # that belong to the current indentation level. Get 'em.
1759 yield from self.visit_default(node)
1760
1761 # Finally, emit the dedent.
1762 yield from self.line(-1)
1763
1764 def visit_stmt(
1765 self, node: Node, keywords: Set[str], parens: Set[str]
1766 ) -> Iterator[Line]:
1767 """Visit a statement.
1768
1769 This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1770 `def`, `with`, `class`, `assert` and assignments.
1771
1772 The relevant Python language `keywords` for a given statement will be
1773 NAME leaves within it. This method puts those on a separate line.
1774
1775 `parens` holds a set of string leaf values immediately after which
1776 invisible parens should be put.
1777 """
1778 normalize_invisible_parens(node, parens_after=parens)
1779 for child in node.children:
1780 if child.type == token.NAME and child.value in keywords: # type: ignore
1781 yield from self.line()
1782
1783 yield from self.visit(child)
1784
1785 def visit_suite(self, node: Node) -> Iterator[Line]:
1786 """Visit a suite."""
1787 if self.is_pyi and is_stub_suite(node):
1788 yield from self.visit(node.children[2])
1789 else:
1790 yield from self.visit_default(node)
1791
1792 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1793 """Visit a statement without nested statements."""
1794 is_suite_like = node.parent and node.parent.type in STATEMENT
1795 if is_suite_like:
1796 if self.is_pyi and is_stub_body(node):
1797 yield from self.visit_default(node)
1798 else:
1799 yield from self.line(+1)
1800 yield from self.visit_default(node)
1801 yield from self.line(-1)
1802
1803 else:
1804 if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1805 yield from self.line()
1806 yield from self.visit_default(node)
1807
1808 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1809 """Visit `async def`, `async for`, `async with`."""
1810 yield from self.line()
1811
1812 children = iter(node.children)
1813 for child in children:
1814 yield from self.visit(child)
1815
1816 if child.type == token.ASYNC:
1817 break
1818
1819 internal_stmt = next(children)
1820 for child in internal_stmt.children:
1821 yield from self.visit(child)
1822
1823 def visit_decorators(self, node: Node) -> Iterator[Line]:
1824 """Visit decorators."""
1825 for child in node.children:
1826 yield from self.line()
1827 yield from self.visit(child)
1828
1829 def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1830 """Remove a semicolon and put the other statement on a separate line."""
1831 yield from self.line()
1832
1833 def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1834 """End of file. Process outstanding comments and end with a newline."""
1835 yield from self.visit_default(leaf)
1836 yield from self.line()
1837
1838 def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1839 if not self.current_line.bracket_tracker.any_open_brackets():
1840 yield from self.line()
1841 yield from self.visit_default(leaf)
1842
1843 def __attrs_post_init__(self) -> None:
1844 """You are in a twisty little maze of passages."""
1845 v = self.visit_stmt
1846 Ø: Set[str] = set()
1847 self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1848 self.visit_if_stmt = partial(
1849 v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1850 )
1851 self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1852 self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1853 self.visit_try_stmt = partial(
1854 v, keywords={"try", "except", "else", "finally"}, parens=Ø
1855 )
1856 self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1857 self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1858 self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1859 self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1860 self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1861 self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1862 self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1863 self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1864 self.visit_async_funcdef = self.visit_async_stmt
1865 self.visit_decorated = self.visit_decorators
1866
1867
1868 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1869 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1870 OPENING_BRACKETS = set(BRACKET.keys())
1871 CLOSING_BRACKETS = set(BRACKET.values())
1872 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1873 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1874
1875
1876 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901
1877 """Return whitespace prefix if needed for the given `leaf`.
1878
1879 `complex_subscript` signals whether the given leaf is part of a subscription
1880 which has non-trivial arguments, like arithmetic expressions or function calls.
1881 """
1882 NO = ""
1883 SPACE = " "
1884 DOUBLESPACE = " "
1885 t = leaf.type
1886 p = leaf.parent
1887 v = leaf.value
1888 if t in ALWAYS_NO_SPACE:
1889 return NO
1890
1891 if t == token.COMMENT:
1892 return DOUBLESPACE
1893
1894 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1895 if t == token.COLON and p.type not in {
1896 syms.subscript,
1897 syms.subscriptlist,
1898 syms.sliceop,
1899 }:
1900 return NO
1901
1902 prev = leaf.prev_sibling
1903 if not prev:
1904 prevp = preceding_leaf(p)
1905 if not prevp or prevp.type in OPENING_BRACKETS:
1906 return NO
1907
1908 if t == token.COLON:
1909 if prevp.type == token.COLON:
1910 return NO
1911
1912 elif prevp.type != token.COMMA and not complex_subscript:
1913 return NO
1914
1915 return SPACE
1916
1917 if prevp.type == token.EQUAL:
1918 if prevp.parent:
1919 if prevp.parent.type in {
1920 syms.arglist,
1921 syms.argument,
1922 syms.parameters,
1923 syms.varargslist,
1924 }:
1925 return NO
1926
1927 elif prevp.parent.type == syms.typedargslist:
1928 # A bit hacky: if the equal sign has whitespace, it means we
1929 # previously determined it is a typed argument. So, we're using
1930 # that, too.
1931 return prevp.prefix
1932
1933 elif prevp.type in VARARGS_SPECIALS:
1934 if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1935 return NO
1936
1937 elif prevp.type == token.COLON:
1938 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1939 return SPACE if complex_subscript else NO
1940
1941 elif (
1942 prevp.parent
1943 and prevp.parent.type == syms.factor
1944 and prevp.type in MATH_OPERATORS
1945 ):
1946 return NO
1947
1948 elif (
1949 prevp.type == token.RIGHTSHIFT
1950 and prevp.parent
1951 and prevp.parent.type == syms.shift_expr
1952 and prevp.prev_sibling
1953 and prevp.prev_sibling.type == token.NAME
1954 and prevp.prev_sibling.value == "print" # type: ignore
1955 ):
1956 # Python 2 print chevron
1957 return NO
1958
1959 elif prev.type in OPENING_BRACKETS:
1960 return NO
1961
1962 if p.type in {syms.parameters, syms.arglist}:
1963 # untyped function signatures or calls
1964 if not prev or prev.type != token.COMMA:
1965 return NO
1966
1967 elif p.type == syms.varargslist:
1968 # lambdas
1969 if prev and prev.type != token.COMMA:
1970 return NO
1971
1972 elif p.type == syms.typedargslist:
1973 # typed function signatures
1974 if not prev:
1975 return NO
1976
1977 if t == token.EQUAL:
1978 if prev.type != syms.tname:
1979 return NO
1980
1981 elif prev.type == token.EQUAL:
1982 # A bit hacky: if the equal sign has whitespace, it means we
1983 # previously determined it is a typed argument. So, we're using that, too.
1984 return prev.prefix
1985
1986 elif prev.type != token.COMMA:
1987 return NO
1988
1989 elif p.type == syms.tname:
1990 # type names
1991 if not prev:
1992 prevp = preceding_leaf(p)
1993 if not prevp or prevp.type != token.COMMA:
1994 return NO
1995
1996 elif p.type == syms.trailer:
1997 # attributes and calls
1998 if t == token.LPAR or t == token.RPAR:
1999 return NO
2000
2001 if not prev:
2002 if t == token.DOT:
2003 prevp = preceding_leaf(p)
2004 if not prevp or prevp.type != token.NUMBER:
2005 return NO
2006
2007 elif t == token.LSQB:
2008 return NO
2009
2010 elif prev.type != token.COMMA:
2011 return NO
2012
2013 elif p.type == syms.argument:
2014 # single argument
2015 if t == token.EQUAL:
2016 return NO
2017
2018 if not prev:
2019 prevp = preceding_leaf(p)
2020 if not prevp or prevp.type == token.LPAR:
2021 return NO
2022
2023 elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2024 return NO
2025
2026 elif p.type == syms.decorator:
2027 # decorators
2028 return NO
2029
2030 elif p.type == syms.dotted_name:
2031 if prev:
2032 return NO
2033
2034 prevp = preceding_leaf(p)
2035 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2036 return NO
2037
2038 elif p.type == syms.classdef:
2039 if t == token.LPAR:
2040 return NO
2041
2042 if prev and prev.type == token.LPAR:
2043 return NO
2044
2045 elif p.type in {syms.subscript, syms.sliceop}:
2046 # indexing
2047 if not prev:
2048 assert p.parent is not None, "subscripts are always parented"
2049 if p.parent.type == syms.subscriptlist:
2050 return SPACE
2051
2052 return NO
2053
2054 elif not complex_subscript:
2055 return NO
2056
2057 elif p.type == syms.atom:
2058 if prev and t == token.DOT:
2059 # dots, but not the first one.
2060 return NO
2061
2062 elif p.type == syms.dictsetmaker:
2063 # dict unpacking
2064 if prev and prev.type == token.DOUBLESTAR:
2065 return NO
2066
2067 elif p.type in {syms.factor, syms.star_expr}:
2068 # unary ops
2069 if not prev:
2070 prevp = preceding_leaf(p)
2071 if not prevp or prevp.type in OPENING_BRACKETS:
2072 return NO
2073
2074 prevp_parent = prevp.parent
2075 assert prevp_parent is not None
2076 if prevp.type == token.COLON and prevp_parent.type in {
2077 syms.subscript,
2078 syms.sliceop,
2079 }:
2080 return NO
2081
2082 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2083 return NO
2084
2085 elif t in {token.NAME, token.NUMBER, token.STRING}:
2086 return NO
2087
2088 elif p.type == syms.import_from:
2089 if t == token.DOT:
2090 if prev and prev.type == token.DOT:
2091 return NO
2092
2093 elif t == token.NAME:
2094 if v == "import":
2095 return SPACE
2096
2097 if prev and prev.type == token.DOT:
2098 return NO
2099
2100 elif p.type == syms.sliceop:
2101 return NO
2102
2103 return SPACE
2104
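# Illustrative doctest-style sketch (not executed at import; uses only names
# defined in this module). For leaf types that are classified without
# consulting a parent:
#
#     >>> whitespace(Leaf(token.COMMENT, "# note"), complex_subscript=False)
#     '  '
#     >>> whitespace(Leaf(token.RPAR, ")"), complex_subscript=False)
#     ''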
2105
2106 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2107 """Return the first leaf that precedes `node`, if any."""
2108 while node:
2109 res = node.prev_sibling
2110 if res:
2111 if isinstance(res, Leaf):
2112 return res
2113
2114 try:
2115 return list(res.leaves())[-1]
2116
2117 except IndexError:
2118 return None
2119
2120 node = node.parent
2121 return None
2122
2123
2124 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2125 """Return the child of `ancestor` that contains `descendant`."""
2126 node: Optional[LN] = descendant
2127 while node and node.parent != ancestor:
2128 node = node.parent
2129 return node
2130
2131
2132 def container_of(leaf: Leaf) -> LN:
2133 """Return `leaf` or one of its ancestors that is the topmost container of it.
2134
2135 By "container" we mean a node where `leaf` is the very first child.
2136 """
2137 same_prefix = leaf.prefix
2138 container: LN = leaf
2139 while container:
2140 parent = container.parent
2141 if parent is None:
2142 break
2143
2144 if parent.children[0].prefix != same_prefix:
2145 break
2146
2147 if parent.type == syms.file_input:
2148 break
2149
2150 if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2151 break
2152
2153 container = parent
2154 return container
2155
2156
2157 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2158 """Return the priority of the `leaf` delimiter, given a line break after it.
2159
2160 The delimiter priorities returned here are from those delimiters that would
2161 cause a line break after themselves.
2162
2163 Higher numbers are higher priority.
2164 """
2165 if leaf.type == token.COMMA:
2166 return COMMA_PRIORITY
2167
2168 return 0
2169
2170
2171 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2172 """Return the priority of the `leaf` delimiter, given a line break before it.
2173
2174 The delimiter priorities returned here are from those delimiters that would
2175 cause a line break before themselves.
2176
2177 Higher numbers are higher priority.
2178 """
2179 if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2180 # * and ** might also be MATH_OPERATORS but in this case they are not.
2181 # Don't treat them as a delimiter.
2182 return 0
2183
2184 if (
2185 leaf.type == token.DOT
2186 and leaf.parent
2187 and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2188 and (previous is None or previous.type in CLOSING_BRACKETS)
2189 ):
2190 return DOT_PRIORITY
2191
2192 if (
2193 leaf.type in MATH_OPERATORS
2194 and leaf.parent
2195 and leaf.parent.type not in {syms.factor, syms.star_expr}
2196 ):
2197 return MATH_PRIORITIES[leaf.type]
2198
2199 if leaf.type in COMPARATORS:
2200 return COMPARATOR_PRIORITY
2201
2202 if (
2203 leaf.type == token.STRING
2204 and previous is not None
2205 and previous.type == token.STRING
2206 ):
2207 return STRING_PRIORITY
2208
2209 if leaf.type not in {token.NAME, token.ASYNC}:
2210 return 0
2211
2212 if (
2213 leaf.value == "for"
2214 and leaf.parent
2215 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2216 or leaf.type == token.ASYNC
2217 ):
2218 if (
2219 not isinstance(leaf.prev_sibling, Leaf)
2220 or leaf.prev_sibling.value != "async"
2221 ):
2222 return COMPREHENSION_PRIORITY
2223
2224 if (
2225 leaf.value == "if"
2226 and leaf.parent
2227 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2228 ):
2229 return COMPREHENSION_PRIORITY
2230
2231 if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2232 return TERNARY_PRIORITY
2233
2234 if leaf.value == "is":
2235 return COMPARATOR_PRIORITY
2236
2237 if (
2238 leaf.value == "in"
2239 and leaf.parent
2240 and leaf.parent.type in {syms.comp_op, syms.comparison}
2241 and not (
2242 previous is not None
2243 and previous.type == token.NAME
2244 and previous.value == "not"
2245 )
2246 ):
2247 return COMPARATOR_PRIORITY
2248
2249 if (
2250 leaf.value == "not"
2251 and leaf.parent
2252 and leaf.parent.type == syms.comp_op
2253 and not (
2254 previous is not None
2255 and previous.type == token.NAME
2256 and previous.value == "is"
2257 )
2258 ):
2259 return COMPARATOR_PRIORITY
2260
2261 if leaf.value in LOGIC_OPERATORS and leaf.parent:
2262 return LOGIC_PRIORITY
2263
2264 return 0
2265
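# Illustration of the branches above: in `x for x in xs` the `for` leaf gets
# COMPREHENSION_PRIORITY; in `a < b` the `<` leaf gets COMPARATOR_PRIORITY;
# a `.` gets DOT_PRIORITY only when `previous` is None or a closing bracket,
# as in chained calls like `foo().bar`.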
2266
2267 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2268 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2269
2270
2271 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2272 """Clean the prefix of the `leaf` and generate comments from it, if any.
2273
2274 Comments in lib2to3 are shoved into the whitespace prefix. This happens
2275 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
2276 move because it does away with modifying the grammar to include all the
2277 possible places in which comments can be placed.
2278
2279 The sad consequence for us though is that comments don't "belong" anywhere.
2280 This is why this function generates simple parentless Leaf objects for
2281 comments. We simply don't know what the correct parent should be.
2282
2283 No matter though, we can live without this. We really only need to
2284 differentiate between inline and standalone comments. The latter don't
2285 share the line with any code.
2286
2287 Inline comments are emitted as regular token.COMMENT leaves. Standalone
2288 are emitted with a fake STANDALONE_COMMENT token identifier.
2289 """
2290 for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2291 yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2292
2293
2294 @dataclass
2295 class ProtoComment:
2296 """Describes a piece of syntax that is a comment.
2297
2298 It's not a :class:`blib2to3.pytree.Leaf` so that:
2299
2300 * it can be cached (`Leaf` objects should not be reused more than once as
2301 they store their lineno, column, prefix, and parent information);
2302 * `newlines` and `consumed` fields are kept separate from the `value`. This
2303 simplifies handling of special marker comments like ``# fmt: off/on``.
2304 """
2305
2306 type: int # token.COMMENT or STANDALONE_COMMENT
2307 value: str # content of the comment
2308 newlines: int # how many newlines before the comment
2309 consumed: int # how many characters of the original leaf's prefix did we consume
2310
2311
2312 @lru_cache(maxsize=4096)
2313 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2314 """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2315 result: List[ProtoComment] = []
2316 if not prefix or "#" not in prefix:
2317 return result
2318
2319 consumed = 0
2320 nlines = 0
2321 ignored_lines = 0
2322 for index, line in enumerate(prefix.split("\n")):
2323 consumed += len(line) + 1 # adding the length of the split '\n'
2324 line = line.lstrip()
2325 if not line:
2326 nlines += 1
2327 if not line.startswith("#"):
2328 # Escaped newlines outside of a comment are not really newlines at
2329 # all. We treat a single-line comment following an escaped newline
2330 # as a simple trailing comment.
2331 if line.endswith("\\"):
2332 ignored_lines += 1
2333 continue
2334
2335 if index == ignored_lines and not is_endmarker:
2336 comment_type = token.COMMENT # simple trailing comment
2337 else:
2338 comment_type = STANDALONE_COMMENT
2339 comment = make_comment(line)
2340 result.append(
2341 ProtoComment(
2342 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2343 )
2344 )
2345 nlines = 0
2346 return result
2347
2348
2349 def make_comment(content: str) -> str:
2350 """Return a consistently formatted comment from the given `content` string.
2351
2352 All comments (except for "##", "#!", "#:", "#'", "#%%") should have a single
2353 space between the hash sign and the content.
2354
2355 If `content` didn't start with a hash sign, one is provided.
2356 """
2357 content = content.rstrip()
2358 if not content:
2359 return "#"
2360
2361 if content[0] == "#":
2362 content = content[1:]
2363 if content and content[0] not in " !:#'%":
2364 content = " " + content
2365 return "#" + content
2366
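# Doctest-style examples (illustrative):
#
#     >>> make_comment("#comment")
#     '# comment'
#     >>> make_comment("#!/usr/bin/env python")
#     '#!/usr/bin/env python'
#     >>> make_comment("no hash")
#     '# no hash'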
2367
2368 def split_line(
2369 line: Line,
2370 line_length: int,
2371 inner: bool = False,
2372 features: Collection[Feature] = (),
2373 ) -> Iterator[Line]:
2374 """Split a `line` into potentially many lines.
2375
2376 They should fit in the allotted `line_length` but might not be able to.
2377 `inner` signifies that there were a pair of brackets somewhere around the
2378 current `line`, possibly transitively. This means we can fall back to splitting
2379 by delimiters if the LHS/RHS don't yield any results.
2380
2381 `features` are syntactical features that may be used in the output.
2382 """
2383 if line.is_comment:
2384 yield line
2385 return
2386
2387 line_str = str(line).strip("\n")
2388
2389 if (
2390 not line.contains_uncollapsable_type_comments()
2391 and not line.should_explode
2392 and not line.is_collection_with_optional_trailing_comma
2393 and (
2394 is_line_short_enough(line, line_length=line_length, line_str=line_str)
2395 or line.contains_unsplittable_type_ignore()
2396 )
2397 ):
2398 yield line
2399 return
2400
2401 split_funcs: List[SplitFunc]
2402 if line.is_def:
2403 split_funcs = [left_hand_split]
2404 else:
2405
2406 def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2407 for omit in generate_trailers_to_omit(line, line_length):
2408 lines = list(right_hand_split(line, line_length, features, omit=omit))
2409 if is_line_short_enough(lines[0], line_length=line_length):
2410 yield from lines
2411 return
2412
2413 # All splits failed, best effort split with no omits.
2414 # This mostly happens to multiline strings that are by definition
2415 # reported as not fitting a single line.
2416 yield from right_hand_split(line, line_length, features=features)
2417
2418 if line.inside_brackets:
2419 split_funcs = [delimiter_split, standalone_comment_split, rhs]
2420 else:
2421 split_funcs = [rhs]
2422 for split_func in split_funcs:
2423 # We are accumulating lines in `result` because we might want to abort
2424 # mission and return the original line in the end, or attempt a different
2425 # split altogether.
2426 result: List[Line] = []
2427 try:
2428 for l in split_func(line, features):
2429 if str(l).strip("\n") == line_str:
2430 raise CannotSplit("Split function returned an unchanged result")
2431
2432 result.extend(
2433 split_line(
2434 l, line_length=line_length, inner=True, features=features
2435 )
2436 )
2437 except CannotSplit:
2438 continue
2439
2440 else:
2441 yield from result
2442 break
2443
2444 else:
2445 yield line
2446
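# Rough usage sketch (how the formatter drives this function; `line` is a
# populated Line, and the names below mirror this module, nothing external):
#
#     dst_contents = []
#     for sub_line in split_line(line, line_length=88, features=set()):
#         dst_contents.append(str(sub_line))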
2447
2448 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2449 """Split line into many lines, starting with the first matching bracket pair.
2450
2451 Note: this usually looks weird; only use it for function definitions.
2452 Prefer RHS otherwise. This is why this function is not symmetrical with
2453 :func:`right_hand_split` which also handles optional parentheses.
2454 """
2455 tail_leaves: List[Leaf] = []
2456 body_leaves: List[Leaf] = []
2457 head_leaves: List[Leaf] = []
2458 current_leaves = head_leaves
2459 matching_bracket = None
2460 for leaf in line.leaves:
2461 if (
2462 current_leaves is body_leaves
2463 and leaf.type in CLOSING_BRACKETS
2464 and leaf.opening_bracket is matching_bracket
2465 ):
2466 current_leaves = tail_leaves if body_leaves else head_leaves
2467 current_leaves.append(leaf)
2468 if current_leaves is head_leaves:
2469 if leaf.type in OPENING_BRACKETS:
2470 matching_bracket = leaf
2471 current_leaves = body_leaves
2472 if not matching_bracket:
2473 raise CannotSplit("No brackets found")
2474
2475 head = bracket_split_build_line(head_leaves, line, matching_bracket)
2476 body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2477 tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2478 bracket_split_succeeded_or_raise(head, body, tail)
2479 for result in (head, body, tail):
2480 if result:
2481 yield result
2482
2483
2484 def right_hand_split(
2485 line: Line,
2486 line_length: int,
2487 features: Collection[Feature] = (),
2488 omit: Collection[LeafID] = (),
2489 ) -> Iterator[Line]:
2490 """Split line into many lines, starting with the last matching bracket pair.
2491
2492 If the split was by optional parentheses, attempt splitting without them, too.
2493 `omit` is a collection of closing bracket IDs that shouldn't be considered for
2494 this split.
2495
2496 Note: running this function modifies `bracket_depth` on the leaves of `line`.
2497 """
2498 tail_leaves: List[Leaf] = []
2499 body_leaves: List[Leaf] = []
2500 head_leaves: List[Leaf] = []
2501 current_leaves = tail_leaves
2502 opening_bracket = None
2503 closing_bracket = None
2504 for leaf in reversed(line.leaves):
2505 if current_leaves is body_leaves:
2506 if leaf is opening_bracket:
2507 current_leaves = head_leaves if body_leaves else tail_leaves
2508 current_leaves.append(leaf)
2509 if current_leaves is tail_leaves:
2510 if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2511 opening_bracket = leaf.opening_bracket
2512 closing_bracket = leaf
2513 current_leaves = body_leaves
2514 if not (opening_bracket and closing_bracket and head_leaves):
2515 # If there is no `opening_bracket` or `closing_bracket`, the split failed
2516 # and all content is in the tail. Otherwise, if `head_leaves` is empty,
2517 # the matching `opening_bracket` wasn't available on `line` anymore.
2518 raise CannotSplit("No brackets found")
2519
2520 tail_leaves.reverse()
2521 body_leaves.reverse()
2522 head_leaves.reverse()
2523 head = bracket_split_build_line(head_leaves, line, opening_bracket)
2524 body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2525 tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2526 bracket_split_succeeded_or_raise(head, body, tail)
2527 if (
2528 # the body shouldn't be exploded
2529 not body.should_explode
2530 # the opening bracket is an optional paren
2531 and opening_bracket.type == token.LPAR
2532 and not opening_bracket.value
2533 # the closing bracket is an optional paren
2534 and closing_bracket.type == token.RPAR
2535 and not closing_bracket.value
2536 # it's not an import (optional parens are the only thing we can split on
2537 # in this case; attempting a split without them is a waste of time)
2538 and not line.is_import
2539 # there are no standalone comments in the body
2540 and not body.contains_standalone_comments(0)
2541 # and we can actually remove the parens
2542 and can_omit_invisible_parens(body, line_length)
2543 ):
2544 omit = {id(closing_bracket), *omit}
2545 try:
2546 yield from right_hand_split(line, line_length, features=features, omit=omit)
2547 return
2548
2549 except CannotSplit:
2550 if not (
2551 can_be_split(body)
2552 or is_line_short_enough(body, line_length=line_length)
2553 ):
2554 raise CannotSplit(
2555 "Splitting failed, body is still too long and can't be split."
2556 )
2557
2558 elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2559 raise CannotSplit(
2560 "The current optional pair of parentheses is bound to fail to "
2561 "satisfy the splitting algorithm because the head or the tail "
2562 "contains multiline strings which by definition never fit one "
2563 "line."
2564 )
2565
2566 ensure_visible(opening_bracket)
2567 ensure_visible(closing_bracket)
2568 for result in (head, body, tail):
2569 if result:
2570 yield result
2571
2572
2573 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2574 """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2575
2576 Do nothing otherwise.
2577
2578 A left- or right-hand split is based on a pair of brackets. Content before
2579 (and including) the opening bracket is left on one line, content inside the
2580 brackets is put on a separate line, and finally content starting with and
2581 following the closing bracket is put on a separate line.
2582
2583 Those are called `head`, `body`, and `tail`, respectively. If the split
2584 produced the same line (all content in `head`) or ended up with an empty `body`
2585 and the `tail` is just the closing bracket, then it's considered failed.
2586 """
2587 tail_len = len(str(tail).strip())
2588 if not body:
2589 if tail_len == 0:
2590 raise CannotSplit("Splitting brackets produced the same line")
2591
2592 elif tail_len < 3:
2593 raise CannotSplit(
2594 f"Splitting brackets on an empty body to save "
2595 f"{tail_len} characters is not worth it"
2596 )
2597
2598
2599 def bracket_split_build_line(
2600 leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2601 ) -> Line:
2602 """Return a new line with given `leaves` and respective comments from `original`.
2603
2604 If `is_body` is True, the result line is one-indented inside brackets and as such
2605 has its first leaf's prefix normalized and a trailing comma added when expected.
2606 """
2607 result = Line(depth=original.depth)
2608 if is_body:
2609 result.inside_brackets = True
2610 result.depth += 1
2611 if leaves:
2612 # Since body is a new indent level, remove spurious leading whitespace.
2613 normalize_prefix(leaves[0], inside_brackets=True)
2614 # Ensure a trailing comma for imports and standalone function arguments, but
2615 # be careful not to add one after any comments.
2616 no_commas = original.is_def and not any(
2617 l.type == token.COMMA for l in leaves
2618 )
2619
2620 if original.is_import or no_commas:
2621 for i in range(len(leaves) - 1, -1, -1):
2622 if leaves[i].type == STANDALONE_COMMENT:
2623 continue
2624 elif leaves[i].type == token.COMMA:
2625 break
2626 else:
2627 leaves.insert(i + 1, Leaf(token.COMMA, ","))
2628 break
2629 # Populate the line
2630 for leaf in leaves:
2631 result.append(leaf, preformatted=True)
2632 for comment_after in original.comments_after(leaf):
2633 result.append(comment_after, preformatted=True)
2634 if is_body:
2635 result.should_explode = should_explode(result, opening_bracket)
2636 return result
2637
2638
2639 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2640 """Normalize prefix of the first leaf in every line returned by `split_func`.
2641
2642 This is a decorator over relevant split functions.
2643 """
2644
2645 @wraps(split_func)
2646 def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2647 for l in split_func(line, features):
2648 normalize_prefix(l.leaves[0], inside_brackets=True)
2649 yield l
2650
2651 return split_wrapper
2652
2653
2654 @dont_increase_indentation
2655 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2656 """Split according to delimiters of the highest priority.
2657
2658 If the appropriate Features are given, the split will add trailing commas
2659 also in function signatures and calls that contain `*` and `**`.
2660 """
2661 try:
2662 last_leaf = line.leaves[-1]
2663 except IndexError:
2664 raise CannotSplit("Line empty")
2665
2666 bt = line.bracket_tracker
2667 try:
2668 delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2669 except ValueError:
2670 raise CannotSplit("No delimiters found")
2671
2672 if delimiter_priority == DOT_PRIORITY:
2673 if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2674 raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2675
2676 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2677 lowest_depth = sys.maxsize
2678 trailing_comma_safe = True
2679
2680 def append_to_line(leaf: Leaf) -> Iterator[Line]:
2681 """Append `leaf` to current line or to new line if appending impossible."""
2682 nonlocal current_line
2683 try:
2684 current_line.append_safe(leaf, preformatted=True)
2685 except ValueError:
2686 yield current_line
2687
2688 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2689 current_line.append(leaf)
2690
2691 for leaf in line.leaves:
2692 yield from append_to_line(leaf)
2693
2694 for comment_after in line.comments_after(leaf):
2695 yield from append_to_line(comment_after)
2696
2697 lowest_depth = min(lowest_depth, leaf.bracket_depth)
2698 if leaf.bracket_depth == lowest_depth:
2699 if is_vararg(leaf, within={syms.typedargslist}):
2700 trailing_comma_safe = (
2701 trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2702 )
2703 elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2704 trailing_comma_safe = (
2705 trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2706 )
2707
2708 leaf_priority = bt.delimiters.get(id(leaf))
2709 if leaf_priority == delimiter_priority:
2710 yield current_line
2711
2712 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2713 if current_line:
2714 if (
2715 trailing_comma_safe
2716 and delimiter_priority == COMMA_PRIORITY
2717 and current_line.leaves[-1].type != token.COMMA
2718 and current_line.leaves[-1].type != STANDALONE_COMMENT
2719 ):
2720 current_line.append(Leaf(token.COMMA, ","))
2721 yield current_line
2722
2723
2724 @dont_increase_indentation
2725 def standalone_comment_split(
2726 line: Line, features: Collection[Feature] = ()
2727 ) -> Iterator[Line]:
2728 """Split standalone comments from the rest of the line."""
2729 if not line.contains_standalone_comments(0):
2730 raise CannotSplit("Line does not have any standalone comments")
2731
2732 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2733
2734 def append_to_line(leaf: Leaf) -> Iterator[Line]:
2735 """Append `leaf` to current line or to new line if appending impossible."""
2736 nonlocal current_line
2737 try:
2738 current_line.append_safe(leaf, preformatted=True)
2739 except ValueError:
2740 yield current_line
2741
2742 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2743 current_line.append(leaf)
2744
2745 for leaf in line.leaves:
2746 yield from append_to_line(leaf)
2747
2748 for comment_after in line.comments_after(leaf):
2749 yield from append_to_line(comment_after)
2750
2751 if current_line:
2752 yield current_line
2753
2754
2755 def is_import(leaf: Leaf) -> bool:
2756 """Return True if the given leaf starts an import statement."""
2757 p = leaf.parent
2758 t = leaf.type
2759 v = leaf.value
2760 return bool(
2761 t == token.NAME
2762 and (
2763 (v == "import" and p and p.type == syms.import_name)
2764 or (v == "from" and p and p.type == syms.import_from)
2765 )
2766 )
2767
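# E.g. in `from os import path`, the `from` leaf satisfies is_import();
# in `import os`, the `import` leaf does; any other leaf returns False.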
2768
2769 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2770 """Return True if the given leaf is a special comment.
2771 Only returns True for type comments for now."""
2772 t = leaf.type
2773 v = leaf.value
2774 return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith(
2775 "# type:" + suffix
2776 )
2777
2778
2779 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2780 """Leave existing extra newlines if not `inside_brackets`. Remove everything
2781 else.
2782
2783 Note: don't use backslashes for formatting or you'll lose your voting rights.
2784 """
2785 if not inside_brackets:
2786 spl = leaf.prefix.split("#")
2787 if "\\" not in spl[0]:
2788 nl_count = spl[-1].count("\n")
2789 if len(spl) > 1:
2790 nl_count -= 1
2791 leaf.prefix = "\n" * nl_count
2792 return
2793
2794 leaf.prefix = ""
2795
2796
2797 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2798 """Make all string prefixes lowercase.
2799
2800 If remove_u_prefix is given, also removes any u prefix from the string.
2801
2802 Note: Mutates its argument.
2803 """
2804 match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2805 assert match is not None, f"failed to match string {leaf.value!r}"
2806 orig_prefix = match.group(1)
2807 new_prefix = orig_prefix.lower()
2808 if remove_u_prefix:
2809 new_prefix = new_prefix.replace("u", "")
2810 leaf.value = f"{new_prefix}{match.group(2)}"
2811
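# Illustrative example (mutates the leaf in place):
#
#     >>> leaf = Leaf(token.STRING, 'U"x"')
#     >>> normalize_string_prefix(leaf, remove_u_prefix=True)
#     >>> leaf.value
#     '"x"'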
2812
2813 def normalize_string_quotes(leaf: Leaf) -> None:
2814 """Prefer double quotes but only if it doesn't cause more escaping.
2815
2816 Adds or removes backslashes as appropriate. Doesn't parse and fix
2817 strings nested in f-strings (yet).
2818
2819 Note: Mutates its argument.
2820 """
2821 value = leaf.value.lstrip("furbFURB")
2822 if value[:3] == '"""':
2823 return
2824
2825 elif value[:3] == "'''":
2826 orig_quote = "'''"
2827 new_quote = '"""'
2828 elif value[0] == '"':
2829 orig_quote = '"'
2830 new_quote = "'"
2831 else:
2832 orig_quote = "'"
2833 new_quote = '"'
2834 first_quote_pos = leaf.value.find(orig_quote)
2835 if first_quote_pos == -1:
2836 return # There's an internal error
2837
2838 prefix = leaf.value[:first_quote_pos]
2839 unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2840 escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2841 escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2842 body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2843 if "r" in prefix.casefold():
2844 if unescaped_new_quote.search(body):
2845 # There's at least one unescaped new_quote in this raw string
2846 # so converting is impossible
2847 return
2848
2849 # Do not introduce or remove backslashes in raw strings
2850 new_body = body
2851 else:
2852 # remove unnecessary escapes
2853 new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2854 if body != new_body:
2855 # Consider the string without unnecessary escapes as the original
2856 body = new_body
2857 leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2858 new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2859 new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2860 if "f" in prefix.casefold():
2861 matches = re.findall(
2862 r"""
2863 (?:[^{]|^)\{ # start of the string or a non-{ followed by a single {
2864 ([^{].*?) # contents of the brackets, unless it begins with {{
2865 \}(?:[^}]|$) # A } followed by end of the string or a non-}
2866 """,
2867 new_body,
2868 re.VERBOSE,
2869 )
2870 for m in matches:
2871 if "\\" in str(m):
2872 # Do not introduce backslashes in interpolated expressions
2873 return
2874 if new_quote == '"""' and new_body[-1:] == '"':
2875 # Edge case: escape a trailing '"' so it can't merge with the closing '"""'.
2876 new_body = new_body[:-1] + '\\"'
2877 orig_escape_count = body.count("\\")
2878 new_escape_count = new_body.count("\\")
2879 if new_escape_count > orig_escape_count:
2880 return # Do not introduce more escaping
2881
2882 if new_escape_count == orig_escape_count and orig_quote == '"':
2883 return # Prefer double quotes
2884
2885 leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2886
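# Illustrative example (mutates the leaf in place):
#
#     >>> leaf = Leaf(token.STRING, "'hello'")
#     >>> normalize_string_quotes(leaf)
#     >>> leaf.value
#     '"hello"'
#
# Conversely, a value like "she said \"hi\"" is rewritten to 'she said "hi"',
# since keeping double quotes would require more escaping.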
2887
2888 def normalize_numeric_literal(leaf: Leaf) -> None:
2889 """Normalizes numeric (float, int, and complex) literals.
2890
2891 All letters used in the representation are normalized to lowercase (except
2892 in Python 2 long literals).
2893 """
2894 text = leaf.value.lower()
2895 if text.startswith(("0o", "0b")):
2896 # Leave octal and binary literals alone.
2897 pass
2898 elif text.startswith("0x"):
2899 # Change hex literals to upper case.
2900 before, after = text[:2], text[2:]
2901 text = f"{before}{after.upper()}"
2902 elif "e" in text:
2903 before, after = text.split("e")
2904 sign = ""
2905 if after.startswith("-"):
2906 after = after[1:]
2907 sign = "-"
2908 elif after.startswith("+"):
2909 after = after[1:]
2910 before = format_float_or_int_string(before)
2911 text = f"{before}e{sign}{after}"
2912 elif text.endswith(("j", "l")):
2913 number = text[:-1]
2914 suffix = text[-1]
2915 # Capitalize in "2L" because "l" looks too similar to "1".
2916 if suffix == "l":
2917 suffix = "L"
2918 text = f"{format_float_or_int_string(number)}{suffix}"
2919 else:
2920 text = format_float_or_int_string(text)
2921 leaf.value = text
2922
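# Illustrative rewrites performed on leaf.value:
#     0XAB  ->  0xAB    (hex digits upper-cased, prefix lower-cased)
#     1E5   ->  1e5
#     10l   ->  10L     (Python 2 long; "L" because "l" looks too much like "1")
#     .5    ->  0.5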
2923
2924 def format_float_or_int_string(text: str) -> str:
2925 """Formats a float string like "1.0"."""
2926 if "." not in text:
2927 return text
2928
2929 before, after = text.split(".")
2930 return f"{before or 0}.{after or 0}"
2931
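# Doctest-style examples (illustrative):
#
#     >>> format_float_or_int_string("1.")
#     '1.0'
#     >>> format_float_or_int_string(".5")
#     '0.5'
#     >>> format_float_or_int_string("42")
#     '42'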
2932
2933 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2934 """Make existing optional parentheses invisible or create new ones.
2935
2936 `parens_after` is a set of string leaf values immediately after which parens
2937 should be put.
2938
2939 Standardizes on visible parentheses for single-element tuples, and keeps
2940 existing visible parentheses for other tuples and generator expressions.
2941 """
2942 for pc in list_comments(node.prefix, is_endmarker=False):
2943 if pc.value in FMT_OFF:
2944 # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2945 return
2946
2947 check_lpar = False
2948 for index, child in enumerate(list(node.children)):
2949 # Add parentheses around long tuple unpacking in assignments.
2950 if (
2951 index == 0
2952 and isinstance(child, Node)
2953 and child.type == syms.testlist_star_expr
2954 ):
2955 check_lpar = True
2956
2957 if check_lpar:
2958 if is_walrus_assignment(child):
2959 continue
2960 if child.type == syms.atom:
2961 # Determines if the underlying atom should be surrounded with
2962 # invisible parens; also makes parens invisible recursively
2963 # within the atom and removes repeated invisible parens within
2964 # the atom.
2965 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2966 child, parent=node
2967 )
2968
2969 if should_surround_with_parens:
2970 lpar = Leaf(token.LPAR, "")
2971 rpar = Leaf(token.RPAR, "")
2972 index = child.remove() or 0
2973 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2974 elif is_one_tuple(child):
2975 # wrap child in visible parentheses
2976 lpar = Leaf(token.LPAR, "(")
2977 rpar = Leaf(token.RPAR, ")")
2978 child.remove()
2979 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2980 elif node.type == syms.import_from:
2981 # "import from" nodes store parentheses directly as part of
2982 # the statement
2983 if child.type == token.LPAR:
2984 # make parentheses invisible
2985 child.value = "" # type: ignore
2986 node.children[-1].value = "" # type: ignore
2987 elif child.type != token.STAR:
2988 # insert invisible parentheses
2989 node.insert_child(index, Leaf(token.LPAR, ""))
2990 node.append_child(Leaf(token.RPAR, ""))
2991 break
2992
2993 elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2994 # wrap child in invisible parentheses
2995 lpar = Leaf(token.LPAR, "")
2996 rpar = Leaf(token.RPAR, "")
2997 index = child.remove() or 0
2998 prefix = child.prefix
2999 child.prefix = ""
3000 new_child = Node(syms.atom, [lpar, child, rpar])
3001 new_child.prefix = prefix
3002 node.insert_child(index, new_child)
3003
3004 check_lpar = isinstance(child, Leaf) and child.value in parens_after
3005
3006
3007 def normalize_fmt_off(node: Node) -> None:
3008 """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
3009 try_again = True
3010 while try_again:
3011 try_again = convert_one_fmt_off_pair(node)
3012
3013
3014 def convert_one_fmt_off_pair(node: Node) -> bool:
3015 """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
3016
3017 Returns True if a pair was converted.
3018 """
3019 for leaf in node.leaves():
3020 previous_consumed = 0
3021 for comment in list_comments(leaf.prefix, is_endmarker=False):
3022 if comment.value in FMT_OFF:
3023 # We only want standalone comments. If there's no previous leaf or
3024 # the previous leaf is indentation, it's a standalone comment in
3025 # disguise.
3026 if comment.type != STANDALONE_COMMENT:
3027 prev = preceding_leaf(leaf)
3028 if prev and prev.type not in WHITESPACE:
3029 continue
3030
3031 ignored_nodes = list(generate_ignored_nodes(leaf))
3032 if not ignored_nodes:
3033 continue
3034
3035 first = ignored_nodes[0] # Can be a container node with the `leaf`.
3036 parent = first.parent
3037 prefix = first.prefix
3038 first.prefix = prefix[comment.consumed :]
3039 hidden_value = (
3040 comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3041 )
3042 if hidden_value.endswith("\n"):
3043 # That happens when one of the `ignored_nodes` ended with a NEWLINE
3044 # leaf (possibly followed by a DEDENT).
3045 hidden_value = hidden_value[:-1]
3046 first_idx = None
3047 for ignored in ignored_nodes:
3048 index = ignored.remove()
3049 if first_idx is None:
3050 first_idx = index
3051 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3052 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3053 parent.insert_child(
3054 first_idx,
3055 Leaf(
3056 STANDALONE_COMMENT,
3057 hidden_value,
3058 prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3059 ),
3060 )
3061 return True
3062
3063 previous_consumed = comment.consumed
3064
3065 return False
3066
3067
3068 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3069 """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3070
3071 Stops at the end of the block.
3072 """
3073 container: Optional[LN] = container_of(leaf)
3074 while container is not None and container.type != token.ENDMARKER:
3075 for comment in list_comments(container.prefix, is_endmarker=False):
3076 if comment.value in FMT_ON:
3077 return
3078
3079 yield container
3080
3081 container = container.next_sibling
3082
3083
3084 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3085 """If it's safe, make the parens in the atom `node` invisible, recursively.
3086 Additionally, remove repeated, adjacent invisible parens from the atom `node`
3087 as they are redundant.
3088
3089 Returns whether the node should itself be wrapped in invisible parentheses.
3090
3091 """
3092 if (
3093 node.type != syms.atom
3094 or is_empty_tuple(node)
3095 or is_one_tuple(node)
3096 or (is_yield(node) and parent.type != syms.expr_stmt)
3097 or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3098 ):
3099 return False
3100
3101 first = node.children[0]
3102 last = node.children[-1]
3103 if first.type == token.LPAR and last.type == token.RPAR:
3104 middle = node.children[1]
3105 # make parentheses invisible
3106 first.value = "" # type: ignore
3107 last.value = "" # type: ignore
3108 maybe_make_parens_invisible_in_atom(middle, parent=parent)
3109
3110 if is_atom_with_invisible_parens(middle):
3111 # Strip the invisible parens from `middle` by replacing
3112 # it with the child in-between the invisible parens
3113 middle.replace(middle.children[1])
3114
3115 return False
3116
3117 return True
3118
3119
3120 def is_atom_with_invisible_parens(node: LN) -> bool:
3121 """Given a `LN`, determines whether it's an atom `node` with invisible
3122 parens. Useful in deduplicating and normalizing parens.
3123 """
3124 if isinstance(node, Leaf) or node.type != syms.atom:
3125 return False
3126
3127 first, last = node.children[0], node.children[-1]
3128 return (
3129 isinstance(first, Leaf)
3130 and first.type == token.LPAR
3131 and first.value == ""
3132 and isinstance(last, Leaf)
3133 and last.type == token.RPAR
3134 and last.value == ""
3135 )
3136
3137
3138 def is_empty_tuple(node: LN) -> bool:
3139 """Return True if `node` holds an empty tuple."""
3140 return (
3141 node.type == syms.atom
3142 and len(node.children) == 2
3143 and node.children[0].type == token.LPAR
3144 and node.children[1].type == token.RPAR
3145 )
3146
3147
3148 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3149 """Returns `wrapped` if `node` is of the shape ( wrapped ).
3150
3151 Parentheses can be optional. Returns None otherwise."""
3152 if len(node.children) != 3:
3153 return None
3154 lpar, wrapped, rpar = node.children
3155 if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3156 return None
3157
3158 return wrapped
3159
3160
3161 def is_one_tuple(node: LN) -> bool:
3162 """Return True if `node` holds a tuple with one element, with or without parens."""
3163 if node.type == syms.atom:
3164 gexp = unwrap_singleton_parenthesis(node)
3165 if gexp is None or gexp.type != syms.testlist_gexp:
3166 return False
3167
3168 return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3169
3170 return (
3171 node.type in IMPLICIT_TUPLE
3172 and len(node.children) == 2
3173 and node.children[1].type == token.COMMA
3174 )
3175
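# E.g. `(1,)` and a bare `1,` are one-tuples; `(1)` and `(1, 2)` are not.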
3176
3177 def is_walrus_assignment(node: LN) -> bool:
3178 """Return True iff `node` is of the shape ( test := test )"""
3179 inner = unwrap_singleton_parenthesis(node)
3180 return inner is not None and inner.type == syms.namedexpr_test
3181
3182
3183 def is_yield(node: LN) -> bool:
3184 """Return True if `node` holds a `yield` or `yield from` expression."""
3185 if node.type == syms.yield_expr:
3186 return True
3187
3188 if node.type == token.NAME and node.value == "yield": # type: ignore
3189 return True
3190
3191 if node.type != syms.atom:
3192 return False
3193
3194 if len(node.children) != 3:
3195 return False
3196
3197 lpar, expr, rpar = node.children
3198 if lpar.type == token.LPAR and rpar.type == token.RPAR:
3199 return is_yield(expr)
3200
3201 return False
3202
3203
3204 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3205 """Return True if `leaf` is a star or double star in a vararg or kwarg.
3206
3207 If `within` includes VARARGS_PARENTS, this applies to function signatures.
3208 If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3209 extended iterable unpacking (PEP 3132) and additional unpacking
3210 generalizations (PEP 448).
3211 """
3212 if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3213 return False
3214
3215 p = leaf.parent
3216 if p.type == syms.star_expr:
3217 # Star expressions are also used as assignment targets in extended
3218 # iterable unpacking (PEP 3132). See what its parent is instead.
3219 if not p.parent:
3220 return False
3221
3222 p = p.parent
3223
3224 return p.type in within
3225
3226
3227 def is_multiline_string(leaf: Leaf) -> bool:
3228 """Return True if `leaf` is a multiline string that actually spans many lines."""
3229 value = leaf.value.lstrip("furbFURB")
3230 return value[:3] in {'"""', "'''"} and "\n" in value
3231
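# Doctest-style examples (illustrative; the "\n" is a real newline in the value):
#
#     >>> is_multiline_string(Leaf(token.STRING, '"""a\nb"""'))
#     True
#     >>> is_multiline_string(Leaf(token.STRING, '"""one line"""'))
#     False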
3232
3233 def is_stub_suite(node: Node) -> bool:
3234 """Return True if `node` is a suite with a stub body."""
3235 if (
3236 len(node.children) != 4
3237 or node.children[0].type != token.NEWLINE
3238 or node.children[1].type != token.INDENT
3239 or node.children[3].type != token.DEDENT
3240 ):
3241 return False
3242
3243 return is_stub_body(node.children[2])
3244
3245
3246 def is_stub_body(node: LN) -> bool:
3247 """Return True if `node` is a simple statement containing an ellipsis."""
3248 if not isinstance(node, Node) or node.type != syms.simple_stmt:
3249 return False
3250
3251 if len(node.children) != 2:
3252 return False
3253
3254 child = node.children[0]
3255 return (
3256 child.type == syms.atom
3257 and len(child.children) == 3
3258 and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3259 )
3260
3261
3262 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3263 """Return maximum delimiter priority inside `node`.
3264
3265 This is specific to atoms with contents contained in a pair of parentheses.
3266 If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3267 """
3268 if node.type != syms.atom:
3269 return 0
3270
3271 first = node.children[0]
3272 last = node.children[-1]
3273 if not (first.type == token.LPAR and last.type == token.RPAR):
3274 return 0
3275
3276 bt = BracketTracker()
3277 for c in node.children[1:-1]:
3278 if isinstance(c, Leaf):
3279 bt.mark(c)
3280 else:
3281 for leaf in c.leaves():
3282 bt.mark(leaf)
3283 try:
3284 return bt.max_delimiter_priority()
3285
3286 except ValueError:
3287 return 0
3288
3289
3290 def ensure_visible(leaf: Leaf) -> None:
3291 """Make sure parentheses are visible.
3292
3293 They could be invisible as part of some statements (see
3294 :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3295 """
3296 if leaf.type == token.LPAR:
3297 leaf.value = "("
3298 elif leaf.type == token.RPAR:
3299 leaf.value = ")"
3300
3301
3302 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3303 """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3304
3305 if not (
3306 opening_bracket.parent
3307 and opening_bracket.parent.type in {syms.atom, syms.import_from}
3308 and opening_bracket.value in "[{("
3309 ):
3310 return False
3311
3312 try:
3313 last_leaf = line.leaves[-1]
3314 exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3315 max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3316 except (IndexError, ValueError):
3317 return False
3318
3319 return max_priority == COMMA_PRIORITY
3320
3321
3322 def get_features_used(node: Node) -> Set[Feature]:
3323 """Return a set of (relatively) new Python features used in this file.
3324
3325 Currently looking for:
3326 - f-strings;
3327 - underscores in numeric literals;
3328 - trailing commas after * or ** in function signatures and calls;
3329 - positional only arguments in function signatures and lambdas.
3330 """
3331 features: Set[Feature] = set()
3332 for n in node.pre_order():
3333 if n.type == token.STRING:
3334 value_head = n.value[:2] # type: ignore
3335 if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3336 features.add(Feature.F_STRINGS)
3337
3338 elif n.type == token.NUMBER:
3339 if "_" in n.value: # type: ignore
3340 features.add(Feature.NUMERIC_UNDERSCORES)
3341
3342 elif n.type == token.SLASH:
3343 if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3344 features.add(Feature.POS_ONLY_ARGUMENTS)
3345
3346 elif n.type == token.COLONEQUAL:
3347 features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3348
3349 elif (
3350 n.type in {syms.typedargslist, syms.arglist}
3351 and n.children
3352 and n.children[-1].type == token.COMMA
3353 ):
3354 if n.type == syms.typedargslist:
3355 feature = Feature.TRAILING_COMMA_IN_DEF
3356 else:
3357 feature = Feature.TRAILING_COMMA_IN_CALL
3358
3359 for ch in n.children:
3360 if ch.type in STARS:
3361 features.add(feature)
3362
3363 if ch.type == syms.argument:
3364 for argch in ch.children:
3365 if argch.type in STARS:
3366 features.add(feature)
3367
3368 return features
3369
3370
3371 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3372 """Detect the version to target based on the nodes used."""
3373 features = get_features_used(node)
3374 return {
3375 version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3376 }
3377
3378
3379 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3380 """Generate sets of closing bracket IDs that should be omitted in a RHS.
3381
3382 Brackets can be omitted if the entire trailer up to and including
3383 a preceding closing bracket fits in one line.
3384
3385 Yielded sets are cumulative (contain results of previous yields, too). First
3386 set is empty.
3387 """
3388
3389 omit: Set[LeafID] = set()
3390 yield omit
3391
3392 length = 4 * line.depth
3393 opening_bracket = None
3394 closing_bracket = None
3395 inner_brackets: Set[LeafID] = set()
3396 for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3397 length += leaf_length
3398 if length > line_length:
3399 break
3400
3401 has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3402 if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3403 break
3404
3405 if opening_bracket:
3406 if leaf is opening_bracket:
3407 opening_bracket = None
3408 elif leaf.type in CLOSING_BRACKETS:
3409 inner_brackets.add(id(leaf))
3410 elif leaf.type in CLOSING_BRACKETS:
3411 if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3412 # Empty brackets would fail a split so treat them as "inner"
3413 # brackets (e.g. only add them to the `omit` set if another
3414 # pair of brackets was good enough).
3415 inner_brackets.add(id(leaf))
3416 continue
3417
3418 if closing_bracket:
3419 omit.add(id(closing_bracket))
3420 omit.update(inner_brackets)
3421 inner_brackets.clear()
3422 yield omit
3423
3424 if leaf.value:
3425 opening_bracket = leaf.opening_bracket
3426 closing_bracket = leaf
3427
3428
3429 def get_future_imports(node: Node) -> Set[str]:
3430 """Return a set of __future__ imports in the file."""
3431 imports: Set[str] = set()
3432
3433 def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3434 for child in children:
3435 if isinstance(child, Leaf):
3436 if child.type == token.NAME:
3437 yield child.value
3438 elif child.type == syms.import_as_name:
3439 orig_name = child.children[0]
3440 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3441 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3442 yield orig_name.value
3443 elif child.type == syms.import_as_names:
3444 yield from get_imports_from_children(child.children)
3445 else:
3446 raise AssertionError("Invalid syntax parsing imports")
3447
3448 for child in node.children:
3449 if child.type != syms.simple_stmt:
3450 break
3451 first_child = child.children[0]
3452 if isinstance(first_child, Leaf):
3453 # Continue looking if we see a docstring; otherwise stop.
3454 if (
3455 len(child.children) == 2
3456 and first_child.type == token.STRING
3457 and child.children[1].type == token.NEWLINE
3458 ):
3459 continue
3460 else:
3461 break
3462 elif first_child.type == syms.import_from:
3463 module_name = first_child.children[1]
3464 if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3465 break
3466 imports |= set(get_imports_from_children(first_child.children[3:]))
3467 else:
3468 break
3469 return imports
3470
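# E.g. for a module that begins with a docstring followed by
# `from __future__ import generator_stop, annotations`, this returns
# {"generator_stop", "annotations"}; scanning stops at the first statement
# that is neither a docstring nor a __future__ import.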
3471
3472 def gen_python_files_in_dir(
3473 path: Path,
3474 root: Path,
3475 include: Pattern[str],
3476 exclude: Pattern[str],
3477 report: "Report",
3478 ) -> Iterator[Path]:
3479 """Generate all files under `path` whose paths are not excluded by the
3480 `exclude` regex, but are included by the `include` regex.
3481
3482 Symbolic links pointing outside of the `root` directory are ignored.
3483
3484 `report` is where output about exclusions goes.
3485 """
3486 assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3487 for child in path.iterdir():
3488 try:
3489 normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3490 except ValueError:
3491 if child.is_symlink():
3492 report.path_ignored(
3493 child, f"is a symbolic link that points outside {root}"
3494 )
3495 continue
3496
3497 raise
3498
3499 if child.is_dir():
3500 normalized_path += "/"
3501 exclude_match = exclude.search(normalized_path)
3502 if exclude_match and exclude_match.group(0):
3503 report.path_ignored(child, "matches the --exclude regular expression")
3504 continue
3505
3506 if child.is_dir():
3507 yield from gen_python_files_in_dir(child, root, include, exclude, report)
3508
3509 elif child.is_file():
3510 include_match = include.search(normalized_path)
3511 if include_match:
3512 yield child
3513
3514
3515 @lru_cache()
3516 def find_project_root(srcs: Iterable[str]) -> Path:
3517 """Return a directory containing .git, .hg, or pyproject.toml.
3518
3519 That directory can be one of the directories passed in `srcs` or their
3520 common parent.
3521
3522 If no directory in the tree contains a marker that would specify it's the
3523 project root, the root of the file system is returned.
3524 """
3525 if not srcs:
3526 return Path("/").resolve()
3527
3528 common_base = min(Path(src).resolve() for src in srcs)
3529 if common_base.is_dir():
3530 # Append a fake file so `parents` below returns `common_base_dir`, too.
3531 common_base /= "fake-file"
3532 for directory in common_base.parents:
3533 if (directory / ".git").is_dir():
3534 return directory
3535
3536 if (directory / ".hg").is_dir():
3537 return directory
3538
3539 if (directory / "pyproject.toml").is_file():
3540 return directory
3541
3542 return directory
3543
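# Usage sketch (the paths below are hypothetical):
#
#     root = find_project_root(("src/mypkg", "tests/test_foo.py"))
#
# This resolves both paths, walks up from their common base, and returns the
# first parent directory containing .git, .hg, or pyproject.toml (or the
# filesystem root if no marker is found).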
3544
3545 @dataclass
3546 class Report:
3547 """Provides a reformatting counter. Can be rendered with `str(report)`."""
3548
3549 check: bool = False
3550 quiet: bool = False
3551 verbose: bool = False
3552 change_count: int = 0
3553 same_count: int = 0
3554 failure_count: int = 0
3555
3556 def done(self, src: Path, changed: Changed) -> None:
3557 """Increment the counter for successful reformatting. Write out a message."""
3558 if changed is Changed.YES:
3559 reformatted = "would reformat" if self.check else "reformatted"
3560 if self.verbose or not self.quiet:
3561 out(f"{reformatted} {src}")
3562 self.change_count += 1
3563 else:
3564 if self.verbose:
3565 if changed is Changed.NO:
3566 msg = f"{src} already well formatted, good job."
3567 else:
3568 msg = f"{src} wasn't modified on disk since last run."
3569 out(msg, bold=False)
3570 self.same_count += 1
3571
3572 def failed(self, src: Path, message: str) -> None:
3573 """Increment the counter for failed reformatting. Write out a message."""
3574 err(f"error: cannot format {src}: {message}")
3575 self.failure_count += 1
3576
3577 def path_ignored(self, path: Path, message: str) -> None:
3578 if self.verbose:
3579 out(f"{path} ignored: {message}", bold=False)
3580
3581 @property
3582 def return_code(self) -> int:
3583 """Return the exit code that the app should use.
3584
3585 This considers the current state of changed files and failures:
3586 - if there were any failures, return 123;
3587 - if any files were changed and --check is being used, return 1;
3588 - otherwise return 0.
3589 """
3590         # According to http://tldp.org/LDP/abs/html/exitcodes.html, return codes
3591         # 126 and above are reserved by the shell for special meanings.
3592 if self.failure_count:
3593 return 123
3594
3595 elif self.change_count and self.check:
3596 return 1
3597
3598 return 0
3599
3600 def __str__(self) -> str:
3601 """Render a color report of the current state.
3602
3603 Use `click.unstyle` to remove colors.
3604 """
3605 if self.check:
3606 reformatted = "would be reformatted"
3607 unchanged = "would be left unchanged"
3608 failed = "would fail to reformat"
3609 else:
3610 reformatted = "reformatted"
3611 unchanged = "left unchanged"
3612 failed = "failed to reformat"
3613 report = []
3614 if self.change_count:
3615 s = "s" if self.change_count > 1 else ""
3616 report.append(
3617 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3618 )
3619 if self.same_count:
3620 s = "s" if self.same_count > 1 else ""
3621 report.append(f"{self.same_count} file{s} {unchanged}")
3622 if self.failure_count:
3623 s = "s" if self.failure_count > 1 else ""
3624 report.append(
3625 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3626 )
3627 return ", ".join(report) + "."
3628
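# A minimal illustrative sketch (not part of grey): the exit-code policy
# implemented by `Report.return_code` above, as a standalone function with
# the same rules. The name `exit_code_sketch` is hypothetical.
def exit_code_sketch(failures, changes, check):
    if failures:
        return 123  # distinct failure code, below the shell-reserved 126+
    if changes and check:
        return 1    # --check found files that would be reformatted
    return 0

assert exit_code_sketch(failures=0, changes=3, check=True) == 1
assert exit_code_sketch(failures=1, changes=3, check=True) == 123
assert exit_code_sketch(failures=0, changes=3, check=False) == 0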
3629
3630 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3631 filename = "<unknown>"
3632 if sys.version_info >= (3, 8):
3633 # TODO: support Python 4+ ;)
3634 for minor_version in range(sys.version_info[1], 4, -1):
3635 try:
3636 return ast.parse(src, filename, feature_version=(3, minor_version))
3637 except SyntaxError:
3638 continue
3639 else:
3640 for feature_version in (7, 6):
3641 try:
3642 return ast3.parse(src, filename, feature_version=feature_version)
3643 except SyntaxError:
3644 continue
3645
3646 return ast27.parse(src)
3647
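# A minimal illustrative sketch (not part of grey): the newest-grammar-first
# fallback, using only the stdlib. It assumes Python 3.8+, where ast.parse()
# gained the keyword-only `feature_version` parameter; the helper name
# `parse_newest_first` is hypothetical.
import ast
import sys

def parse_newest_first(src):
    for minor in range(sys.version_info[1], 4, -1):  # newest grammar first
        try:
            return ast.parse(src, feature_version=(3, minor))
        except SyntaxError:
            continue
    raise SyntaxError("source does not parse under any 3.5+ grammar")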
3648
3649 def _fixup_ast_constants(
3650 node: Union[ast.AST, ast3.AST, ast27.AST]
3651 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3652 """Map ast nodes deprecated in 3.8 to Constant."""
3653 # casts are required until this is released:
3654 # https://github.com/python/typeshed/pull/3142
3655 if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3656 return cast(ast.AST, ast.Constant(value=node.s))
3657 elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3658 return cast(ast.AST, ast.Constant(value=node.n))
3659 elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3660 return cast(ast.AST, ast.Constant(value=node.value))
3661 return node
3662
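# A quick illustration (not part of grey) of why the normalization above
# exists: Python 3.8 parses literals as ast.Constant, while older grammars
# (and typed_ast) produce ast.Str / ast.Num / ast.NameConstant, so equivalent
# sources would otherwise stringify differently in the comparison visitor.
import ast
import sys

_lit = ast.parse("x = 'hi'").body[0].value
if sys.version_info >= (3, 8):
    assert isinstance(_lit, ast.Constant) and _lit.value == "hi"
else:
    assert isinstance(_lit, ast.Str) and _lit.s == "hi"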
3663
3664 def assert_equivalent(src: str, dst: str) -> None:
3665 """Raise AssertionError if `src` and `dst` aren't equivalent."""
3666
3667 def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3668 """Simple visitor generating strings to compare ASTs by content."""
3669
3670 node = _fixup_ast_constants(node)
3671
3672 yield f"{' ' * depth}{node.__class__.__name__}("
3673
3674 for field in sorted(node._fields):
3675 # TypeIgnore has only one field 'lineno' which breaks this comparison
3676 type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3677 if sys.version_info >= (3, 8):
3678 type_ignore_classes += (ast.TypeIgnore,)
3679 if isinstance(node, type_ignore_classes):
3680 break
3681
3682 try:
3683 value = getattr(node, field)
3684 except AttributeError:
3685 continue
3686
3687 yield f"{' ' * (depth+1)}{field}="
3688
3689 if isinstance(value, list):
3690 for item in value:
3691 # Ignore nested tuples within del statements, because we may insert
3692 # parentheses and they change the AST.
3693 if (
3694 field == "targets"
3695 and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3696 and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3697 ):
3698 for item in item.elts:
3699 yield from _v(item, depth + 2)
3700 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3701 yield from _v(item, depth + 2)
3702
3703 elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3704 yield from _v(value, depth + 2)
3705
3706 else:
3707 yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"
3708
3709 yield f"{' ' * depth}) # /{node.__class__.__name__}"
3710
3711 try:
3712 src_ast = parse_ast(src)
3713 except Exception as exc:
3714 raise AssertionError(
3715 f"cannot use --safe with this file; failed to parse source file. "
3716 f"AST error message: {exc}"
3717 )
3718
3719 try:
3720 dst_ast = parse_ast(dst)
3721 except Exception as exc:
3722 log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3723 raise AssertionError(
3724 f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3725 f"Please report a bug on https://github.com/psf/black/issues. "
3726 f"This invalid output might be helpful: {log}"
3727 ) from None
3728
3729 src_ast_str = "\n".join(_v(src_ast))
3730 dst_ast_str = "\n".join(_v(dst_ast))
3731 if src_ast_str != dst_ast_str:
3732 log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3733 raise AssertionError(
3734 f"INTERNAL ERROR: Black produced code that is not equivalent to "
3735 f"the source. "
3736 f"Please report a bug on https://github.com/psf/black/issues. "
3737 f"This diff might be helpful: {log}"
3738 ) from None
3739
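# A minimal illustrative sketch (not part of grey): the core of the
# equivalence check. grey renders its own visitor (to paper over Constant
# vs Str/Num and the del-tuple case above), but the idea is visible with
# plain ast.dump(): formatting-only changes leave the dumped tree identical.
import ast

assert ast.dump(ast.parse("x=(1,  2)")) == ast.dump(ast.parse("x = (1, 2)"))
assert ast.dump(ast.parse("x = 1")) != ast.dump(ast.parse("x = 2"))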
3740
3741 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3742 """Raise AssertionError if `dst` reformats differently the second time."""
3743 newdst = format_str(dst, mode=mode)
3744 if dst != newdst:
3745 log = dump_to_file(
3746 diff(src, dst, "source", "first pass"),
3747 diff(dst, newdst, "first pass", "second pass"),
3748 )
3749 raise AssertionError(
3750 f"INTERNAL ERROR: Black produced different code on the second pass "
3751 f"of the formatter. "
3752 f"Please report a bug on https://github.com/psf/black/issues. "
3753 f"This diff might be helpful: {log}"
3754 ) from None
3755
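# A minimal illustrative sketch (not part of grey): stability is plain
# idempotence. Any formatter f must satisfy f(f(src)) == f(src); the helper
# name `assert_idempotent` is hypothetical.
def assert_idempotent(fmt, src):
    once = fmt(src)
    assert fmt(once) == once, "formatter changed its own output"

assert_idempotent(str.strip, "  spaced  ")  # str.strip is idempotent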
3756
3757 def dump_to_file(*output: str) -> str:
3758 """Dump `output` to a temporary file. Return path to the file."""
3759 with tempfile.NamedTemporaryFile(
3760 mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3761 ) as f:
3762 for lines in output:
3763 f.write(lines)
3764 if lines and lines[-1] != "\n":
3765 f.write("\n")
3766 return f.name
3767
3768
3769 @contextmanager
3770 def nullcontext() -> Iterator[None]:
3771     """Return a context manager that does nothing.
3772     Similar to `nullcontext` from Python 3.7."""
3773 yield
3774
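# A quick illustration (not part of grey) of typical nullcontext() use: one
# `with` statement regardless of whether a real resource is held.
# `maybe_lock` is a hypothetical variable.
import threading

for maybe_lock in (None, threading.Lock()):
    with (maybe_lock if maybe_lock is not None else nullcontext()):
        pass  # body runs either way; the lock is held only when present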
3775
3776 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3777 """Return a unified diff string between strings `a` and `b`."""
3778 import difflib
3779
3780 a_lines = [line + "\n" for line in a.split("\n")]
3781 b_lines = [line + "\n" for line in b.split("\n")]
3782 return "".join(
3783 difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3784 )
3785
3786
3787 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3788 """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3789 err("Aborted!")
3790 for task in tasks:
3791 task.cancel()
3792
3793
3794 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3795 """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3796 try:
3797 if sys.version_info[:2] >= (3, 7):
3798 all_tasks = asyncio.all_tasks
3799 else:
3800 all_tasks = asyncio.Task.all_tasks
3801 # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3802 to_cancel = [task for task in all_tasks(loop) if not task.done()]
3803 if not to_cancel:
3804 return
3805
3806 for task in to_cancel:
3807 task.cancel()
3808 loop.run_until_complete(
3809 asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3810 )
3811 finally:
3812         # `concurrent.futures.Future` objects cannot be cancelled once they
3813         # are already running; some may still be running when `shutdown()` is called.
3814 # Silence their logger's spew about the event loop being closed.
3815 cf_logger = logging.getLogger("concurrent.futures")
3816 cf_logger.setLevel(logging.CRITICAL)
3817 loop.close()
3818
3819
3820 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3821 """Replace `regex` with `replacement` twice on `original`.
3822
3823 This is used by string normalization to perform replaces on
3824 overlapping matches.
3825 """
3826 return regex.sub(replacement, regex.sub(replacement, original))
3827
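# A quick illustration (not part of grey) of why one pass is not enough:
# the regex scanner resumes after each replacement, so adjacent occurrences
# survive a single pass and only a second pass catches them.
import re

_pat = re.compile(r"aa")
assert _pat.sub("a", "aaaa") == "aa"        # one pass leaves a pair behind
assert sub_twice(_pat, "a", "aaaa") == "a"  # the second pass catches it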
3828
3829 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3830 """Compile a regular expression string in `regex`.
3831
3832 If it contains newlines, use verbose mode.
3833 """
3834 if "\n" in regex:
3835 regex = "(?x)" + regex
3836 return re.compile(regex)
3837
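# A quick illustration (not part of grey): a multi-line pattern is compiled
# in verbose mode, so it can be written with comments and insignificant
# whitespace, as --include/--exclude values in pyproject.toml often are.
_multiline = r"""
/(\.git|\.hg)/  # VCS metadata directories
"""
_pat2 = re_compile_maybe_verbose(_multiline)
assert _pat2.search("/.git/") and not _pat2.search("/git/")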
3838
3839 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3840 """Like `reversed(enumerate(sequence))` if that were possible."""
3841 index = len(sequence) - 1
3842 for element in reversed(sequence):
3843 yield (index, element)
3844 index -= 1
3845
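# A quick illustration (not part of grey) of what the generator yields:
assert list(enumerate_reversed("abc")) == [(2, "c"), (1, "b"), (0, "a")]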
3846
3847 def enumerate_with_length(
3848 line: Line, reversed: bool = False
3849 ) -> Iterator[Tuple[Index, Leaf, int]]:
3850 """Return an enumeration of leaves with their length.
3851
3852 Stops prematurely on multiline strings and standalone comments.
3853 """
3854 op = cast(
3855 Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3856 enumerate_reversed if reversed else enumerate,
3857 )
3858 for index, leaf in op(line.leaves):
3859 length = len(leaf.prefix) + len(leaf.value)
3860 if "\n" in leaf.value:
3861 return # Multiline strings, we can't continue.
3862
3863 for comment in line.comments_after(leaf):
3864 length += len(comment.value)
3865
3866 yield index, leaf, length
3867
3868
3869 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3870 """Return True if `line` is no longer than `line_length`.
3871
3872 Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3873 """
3874 if not line_str:
3875 line_str = str(line).strip("\n")
3876 return (
3877 len(line_str) <= line_length
3878 and "\n" not in line_str # multiline strings
3879 and not line.contains_standalone_comments()
3880 )
3881
3882
3883 def can_be_split(line: Line) -> bool:
3884 """Return False if the line cannot be split *for sure*.
3885
3886 This is not an exhaustive search but a cheap heuristic that we can use to
3887 avoid some unfortunate formattings (mostly around wrapping unsplittable code
3888 in unnecessary parentheses).
3889 """
3890 leaves = line.leaves
3891 if len(leaves) < 2:
3892 return False
3893
3894 if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3895 call_count = 0
3896 dot_count = 0
3897 next = leaves[-1]
3898 for leaf in leaves[-2::-1]:
3899 if leaf.type in OPENING_BRACKETS:
3900 if next.type not in CLOSING_BRACKETS:
3901 return False
3902
3903 call_count += 1
3904 elif leaf.type == token.DOT:
3905 dot_count += 1
3906 elif leaf.type == token.NAME:
3907 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3908 return False
3909
3910 elif leaf.type not in CLOSING_BRACKETS:
3911 return False
3912
3913 if dot_count > 1 and call_count > 1:
3914 return False
3915
3916 return True
3917
3918
3919 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3920 """Does `line` have a shape safe to reformat without optional parens around it?
3921
3922     Returns True only for a subset of potentially nice-looking formattings; the
3923     point is to avoid false positives that would end up producing lines that are
3924     too long.
3925 """
3926 bt = line.bracket_tracker
3927 if not bt.delimiters:
3928 # Without delimiters the optional parentheses are useless.
3929 return True
3930
3931 max_priority = bt.max_delimiter_priority()
3932 if bt.delimiter_count_with_priority(max_priority) > 1:
3933 # With more than one delimiter of a kind the optional parentheses read better.
3934 return False
3935
3936 if max_priority == DOT_PRIORITY:
3937 # A single stranded method call doesn't require optional parentheses.
3938 return True
3939
3940 assert len(line.leaves) >= 2, "Stranded delimiter"
3941
3942 first = line.leaves[0]
3943 second = line.leaves[1]
3944 penultimate = line.leaves[-2]
3945 last = line.leaves[-1]
3946
3947 # With a single delimiter, omit if the expression starts or ends with
3948 # a bracket.
3949 if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3950 remainder = False
3951 length = 4 * line.depth
3952 for _index, leaf, leaf_length in enumerate_with_length(line):
3953 if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3954 remainder = True
3955 if remainder:
3956 length += leaf_length
3957 if length > line_length:
3958 break
3959
3960 if leaf.type in OPENING_BRACKETS:
3961 # There are brackets we can further split on.
3962 remainder = False
3963
3964 else:
3965 # checked the entire string and line length wasn't exceeded
3966 if len(line.leaves) == _index + 1:
3967 return True
3968
3969 # Note: we are not returning False here because a line might have *both*
3970 # a leading opening bracket and a trailing closing bracket. If the
3971 # opening bracket doesn't match our rule, maybe the closing will.
3972
3973 if (
3974 last.type == token.RPAR
3975 or last.type == token.RBRACE
3976 or (
3977 # don't use indexing for omitting optional parentheses;
3978 # it looks weird
3979 last.type == token.RSQB
3980 and last.parent
3981 and last.parent.type != syms.trailer
3982 )
3983 ):
3984 if penultimate.type in OPENING_BRACKETS:
3985 # Empty brackets don't help.
3986 return False
3987
3988 if is_multiline_string(first):
3989 # Additional wrapping of a multiline string in this situation is
3990 # unnecessary.
3991 return True
3992
3993 length = 4 * line.depth
3994 seen_other_brackets = False
3995 for _index, leaf, leaf_length in enumerate_with_length(line):
3996 length += leaf_length
3997 if leaf is last.opening_bracket:
3998 if seen_other_brackets or length <= line_length:
3999 return True
4000
4001 elif leaf.type in OPENING_BRACKETS:
4002 # There are brackets we can further split on.
4003 seen_other_brackets = True
4004
4005 return False
4006
4007
4008 def get_cache_file(mode: FileMode) -> Path:
4009 return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
4010
4011
4012 def read_cache(mode: FileMode) -> Cache:
4013 """Read the cache if it exists and is well formed.
4014
4015 If it is not well formed, the call to write_cache later should resolve the issue.
4016 """
4017 cache_file = get_cache_file(mode)
4018 if not cache_file.exists():
4019 return {}
4020
4021 with cache_file.open("rb") as fobj:
4022 try:
4023 cache: Cache = pickle.load(fobj)
4024 except pickle.UnpicklingError:
4025 return {}
4026
4027 return cache
4028
4029
4030 def get_cache_info(path: Path) -> CacheInfo:
4031 """Return the information used to check if a file is already formatted or not."""
4032 stat = path.stat()
4033 return stat.st_mtime, stat.st_size
4034
4035
4036 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4037 """Split an iterable of paths in `sources` into two sets.
4038
4039     The first contains paths of files that were modified on disk or are not in
4040     the cache. The other contains paths to non-modified files.
4041 """
4042 todo, done = set(), set()
4043 for src in sources:
4044 src = src.resolve()
4045 if cache.get(src) != get_cache_info(src):
4046 todo.add(src)
4047 else:
4048 done.add(src)
4049 return todo, done
4050
4051
4052 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4053 """Update the cache file."""
4054 cache_file = get_cache_file(mode)
4055 try:
4056 CACHE_DIR.mkdir(parents=True, exist_ok=True)
4057 new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4058 with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4059 pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
4060 os.replace(f.name, cache_file)
4061 except OSError:
4062 pass
4063
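# A minimal illustrative sketch (not part of grey): the freshness test the
# cache machinery above relies on. The (st_mtime, st_size) pair is a cheap
# proxy for content identity; a mismatch on either member forces a reformat.
# `is_fresh_sketch` is a hypothetical name.
from pathlib import Path

def is_fresh_sketch(cache, src):
    try:
        stat = src.stat()
    except OSError:
        return False  # vanished files are never fresh
    return cache.get(src.resolve()) == (stat.st_mtime, stat.st_size)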
4064
4065 def patch_click() -> None:
4066 """Make Click not crash.
4067
4068 On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4069 default which restricts paths that it can access during the lifetime of the
4070 application. Click refuses to work in this scenario by raising a RuntimeError.
4071
4072     In the case of Black, the likelihood that non-ASCII characters will be used in
4073 file paths is minimal since it's Python source code. Moreover, this crash was
4074 spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4075 """
4076 try:
4077 from click import core
4078 from click import _unicodefun # type: ignore
4079 except ModuleNotFoundError:
4080 return
4081
4082 for module in (core, _unicodefun):
4083 if hasattr(module, "_verify_python3_env"):
4084 module._verify_python3_env = lambda: None
4085
4086
4087 def patched_main() -> None:
4088 freeze_support()
4089 patch_click()
4090 main()
4091
4092
4093 if __name__ == "__main__":
4094 patched_main()
@@ -5,6 +5,11 b' clang-format:pattern = (**.c or **.cc or'
5 5 rustfmt:command = rustfmt {rootpath}
6 6 rustfmt:pattern = set:**.rs
7 7
8 # We use black, but currently with https://github.com/psf/black/pull/826 applied.
9 # black:command = black --skip-string-normalization
8 # We use black, but currently with
9 # https://github.com/psf/black/pull/826 applied. For now
10 # contrib/grey.py is our fork of black. You need to pip install
11 # git+https://github.com/python/black/@d9e71a75ccfefa3d9156a64c03313a0d4ad981e5
12 # to have the dependencies for grey.
13 #
14 # black:command = python3.7 contrib/grey.py --skip-string-normalization
10 15 # black:pattern = set:**.py - hgext/fsmonitor/pywatchman/** - mercurial/thirdparty/** - "contrib/python-zstandard/**"
@@ -21,6 +21,7 b' New errors are not allowed. Warnings are'
21 21 Skipping contrib/automation/hgautomation/try_server.py it has no-che?k-code (glob)
22 22 Skipping contrib/automation/hgautomation/windows.py it has no-che?k-code (glob)
23 23 Skipping contrib/automation/hgautomation/winrm.py it has no-che?k-code (glob)
24 Skipping contrib/grey.py it has no-che?k-code (glob)
24 25 Skipping contrib/packaging/hgpackaging/downloads.py it has no-che?k-code (glob)
25 26 Skipping contrib/packaging/hgpackaging/inno.py it has no-che?k-code (glob)
26 27 Skipping contrib/packaging/hgpackaging/py2exe.py it has no-che?k-code (glob)
@@ -20,6 +20,7 b' outputs, which should be fixed later.'
20 20 > -X setup.py \
21 21 > -X contrib/automation/ \
22 22 > -X contrib/debugshell.py \
23 > -X contrib/grey.py \
23 24 > -X contrib/hgweb.fcgi \
24 25 > -X contrib/packaging/hg-docker \
25 26 > -X contrib/packaging/hgpackaging/ \
@@ -6,6 +6,7 b''
6 6 #if no-py3
7 7 $ testrepohg files 'set:(**.py)' \
8 8 > -X contrib/automation/ \
9 > -X contrib/grey.py \
9 10 > -X contrib/packaging/hgpackaging/ \
10 11 > -X contrib/packaging/inno/ \
11 12 > -X contrib/packaging/wix/ \