]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

8a70c02704fbc0f445e4aa4cbadc594fdfbaa357
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import regex as re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40
41 from appdirs import user_cache_dir
42 from attr import dataclass, evolve, Factory
43 import click
44 import toml
45 from typed_ast import ast3, ast27
46
47 # lib2to3 fork
48 from blib2to3.pytree import Node, Leaf, type_repr
49 from blib2to3 import pygram, pytree
50 from blib2to3.pgen2 import driver, token
51 from blib2to3.pgen2.grammar import Grammar
52 from blib2to3.pgen2.parse import ParseError
53
54 from _version import version as __version__
55
56 DEFAULT_LINE_LENGTH = 88
57 DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
58 DEFAULT_INCLUDES = r"\.pyi?$"
59 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
60
61
62 # types
63 FileContent = str
64 Encoding = str
65 NewLine = str
66 Depth = int
67 NodeType = int
68 LeafID = int
69 Priority = int
70 Index = int
71 LN = Union[Leaf, Node]
72 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
73 Timestamp = float
74 FileSize = int
75 CacheInfo = Tuple[Timestamp, FileSize]
76 Cache = Dict[Path, CacheInfo]
77 out = partial(click.secho, bold=True, err=True)
78 err = partial(click.secho, fg="red", err=True)
79
80 pygram.initialize(CACHE_DIR)
81 syms = pygram.python_symbols
82
83
84 class NothingChanged(UserWarning):
85     """Raised when reformatted code is the same as source."""
86
87
88 class CannotSplit(Exception):
89     """A readable split that fits the allotted line length is impossible."""
90
91
92 class InvalidInput(ValueError):
93     """Raised when input source code fails all parse attempts."""
94
95
96 class WriteBack(Enum):
97     NO = 0
98     YES = 1
99     DIFF = 2
100     CHECK = 3
101
102     @classmethod
103     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
104         if check and not diff:
105             return cls.CHECK
106
107         return cls.DIFF if diff else cls.YES
108
109
110 class Changed(Enum):
111     NO = 0
112     CACHED = 1
113     YES = 2
114
115
116 class TargetVersion(Enum):
117     PY27 = 2
118     PY33 = 3
119     PY34 = 4
120     PY35 = 5
121     PY36 = 6
122     PY37 = 7
123     PY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.PY27
127
128
129 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA_IN_CALL = 4
138     TRAILING_COMMA_IN_DEF = 5
139     # The following two feature-flags are mutually exclusive, and exactly one should be
140     # set for every version of python.
141     ASYNC_IDENTIFIERS = 6
142     ASYNC_KEYWORDS = 7
143     ASSIGNMENT_EXPRESSIONS = 8
144     POS_ONLY_ARGUMENTS = 9
145
146
147 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
148     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
149     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
150     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
151     TargetVersion.PY35: {
152         Feature.UNICODE_LITERALS,
153         Feature.TRAILING_COMMA_IN_CALL,
154         Feature.ASYNC_IDENTIFIERS,
155     },
156     TargetVersion.PY36: {
157         Feature.UNICODE_LITERALS,
158         Feature.F_STRINGS,
159         Feature.NUMERIC_UNDERSCORES,
160         Feature.TRAILING_COMMA_IN_CALL,
161         Feature.TRAILING_COMMA_IN_DEF,
162         Feature.ASYNC_IDENTIFIERS,
163     },
164     TargetVersion.PY37: {
165         Feature.UNICODE_LITERALS,
166         Feature.F_STRINGS,
167         Feature.NUMERIC_UNDERSCORES,
168         Feature.TRAILING_COMMA_IN_CALL,
169         Feature.TRAILING_COMMA_IN_DEF,
170         Feature.ASYNC_KEYWORDS,
171     },
172     TargetVersion.PY38: {
173         Feature.UNICODE_LITERALS,
174         Feature.F_STRINGS,
175         Feature.NUMERIC_UNDERSCORES,
176         Feature.TRAILING_COMMA_IN_CALL,
177         Feature.TRAILING_COMMA_IN_DEF,
178         Feature.ASYNC_KEYWORDS,
179         Feature.ASSIGNMENT_EXPRESSIONS,
180         Feature.POS_ONLY_ARGUMENTS,
181     },
182 }
183
184
185 @dataclass
186 class FileMode:
187     target_versions: Set[TargetVersion] = Factory(set)
188     line_length: int = DEFAULT_LINE_LENGTH
189     string_normalization: bool = True
190     is_pyi: bool = False
191
192     def get_cache_key(self) -> str:
193         if self.target_versions:
194             version_str = ",".join(
195                 str(version.value)
196                 for version in sorted(self.target_versions, key=lambda v: v.value)
197             )
198         else:
199             version_str = "-"
200         parts = [
201             version_str,
202             str(self.line_length),
203             str(int(self.string_normalization)),
204             str(int(self.is_pyi)),
205         ]
206         return ".".join(parts)
207
208
209 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
210     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
211
212
213 def read_pyproject_toml(
214     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
215 ) -> Optional[str]:
216     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
217
218     Returns the path to a successfully found and read configuration file, None
219     otherwise.
220     """
221     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
222     if not value:
223         root = find_project_root(ctx.params.get("src", ()))
224         path = root / "pyproject.toml"
225         if path.is_file():
226             value = str(path)
227         else:
228             return None
229
230     try:
231         pyproject_toml = toml.load(value)
232         config = pyproject_toml.get("tool", {}).get("black", {})
233     except (toml.TomlDecodeError, OSError) as e:
234         raise click.FileError(
235             filename=value, hint=f"Error reading configuration file: {e}"
236         )
237
238     if not config:
239         return None
240
241     if ctx.default_map is None:
242         ctx.default_map = {}
243     ctx.default_map.update(  # type: ignore  # bad types in .pyi
244         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
245     )
246     return value
247
248
249 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
250 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
251 @click.option(
252     "-l",
253     "--line-length",
254     type=int,
255     default=DEFAULT_LINE_LENGTH,
256     help="How many characters per line to allow.",
257     show_default=True,
258 )
259 @click.option(
260     "-t",
261     "--target-version",
262     type=click.Choice([v.name.lower() for v in TargetVersion]),
263     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
264     multiple=True,
265     help=(
266         "Python versions that should be supported by Black's output. [default: "
267         "per-file auto-detection]"
268     ),
269 )
270 @click.option(
271     "--py36",
272     is_flag=True,
273     help=(
274         "Allow using Python 3.6-only syntax on all input files.  This will put "
275         "trailing commas in function signatures and calls also after *args and "
276         "**kwargs. Deprecated; use --target-version instead. "
277         "[default: per-file auto-detection]"
278     ),
279 )
280 @click.option(
281     "--pyi",
282     is_flag=True,
283     help=(
284         "Format all input files like typing stubs regardless of file extension "
285         "(useful when piping source on standard input)."
286     ),
287 )
288 @click.option(
289     "-S",
290     "--skip-string-normalization",
291     is_flag=True,
292     help="Don't normalize string quotes or prefixes.",
293 )
294 @click.option(
295     "--check",
296     is_flag=True,
297     help=(
298         "Don't write the files back, just return the status.  Return code 0 "
299         "means nothing would change.  Return code 1 means some files would be "
300         "reformatted.  Return code 123 means there was an internal error."
301     ),
302 )
303 @click.option(
304     "--diff",
305     is_flag=True,
306     help="Don't write the files back, just output a diff for each file on stdout.",
307 )
308 @click.option(
309     "--fast/--safe",
310     is_flag=True,
311     help="If --fast given, skip temporary sanity checks. [default: --safe]",
312 )
313 @click.option(
314     "--include",
315     type=str,
316     default=DEFAULT_INCLUDES,
317     help=(
318         "A regular expression that matches files and directories that should be "
319         "included on recursive searches.  An empty value means all files are "
320         "included regardless of the name.  Use forward slashes for directories on "
321         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
322         "later."
323     ),
324     show_default=True,
325 )
326 @click.option(
327     "--exclude",
328     type=str,
329     default=DEFAULT_EXCLUDES,
330     help=(
331         "A regular expression that matches files and directories that should be "
332         "excluded on recursive searches.  An empty value means no paths are excluded. "
333         "Use forward slashes for directories on all platforms (Windows, too).  "
334         "Exclusions are calculated first, inclusions later."
335     ),
336     show_default=True,
337 )
338 @click.option(
339     "-q",
340     "--quiet",
341     is_flag=True,
342     help=(
343         "Don't emit non-error messages to stderr. Errors are still emitted; "
344         "silence those with 2>/dev/null."
345     ),
346 )
347 @click.option(
348     "-v",
349     "--verbose",
350     is_flag=True,
351     help=(
352         "Also emit messages to stderr about files that were not changed or were "
353         "ignored due to --exclude=."
354     ),
355 )
356 @click.version_option(version=__version__)
357 @click.argument(
358     "src",
359     nargs=-1,
360     type=click.Path(
361         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
362     ),
363     is_eager=True,
364 )
365 @click.option(
366     "--config",
367     type=click.Path(
368         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
369     ),
370     is_eager=True,
371     callback=read_pyproject_toml,
372     help="Read configuration from PATH.",
373 )
374 @click.pass_context
375 def main(
376     ctx: click.Context,
377     code: Optional[str],
378     line_length: int,
379     target_version: List[TargetVersion],
380     check: bool,
381     diff: bool,
382     fast: bool,
383     pyi: bool,
384     py36: bool,
385     skip_string_normalization: bool,
386     quiet: bool,
387     verbose: bool,
388     include: str,
389     exclude: str,
390     src: Tuple[str],
391     config: Optional[str],
392 ) -> None:
393     """The uncompromising code formatter."""
394     write_back = WriteBack.from_configuration(check=check, diff=diff)
395     if target_version:
396         if py36:
397             err(f"Cannot use both --target-version and --py36")
398             ctx.exit(2)
399         else:
400             versions = set(target_version)
401     elif py36:
402         err(
403             "--py36 is deprecated and will be removed in a future version. "
404             "Use --target-version py36 instead."
405         )
406         versions = PY36_VERSIONS
407     else:
408         # We'll autodetect later.
409         versions = set()
410     mode = FileMode(
411         target_versions=versions,
412         line_length=line_length,
413         is_pyi=pyi,
414         string_normalization=not skip_string_normalization,
415     )
416     if config and verbose:
417         out(f"Using configuration from {config}.", bold=False, fg="blue")
418     if code is not None:
419         print(format_str(code, mode=mode))
420         ctx.exit(0)
421     try:
422         include_regex = re_compile_maybe_verbose(include)
423     except re.error:
424         err(f"Invalid regular expression for include given: {include!r}")
425         ctx.exit(2)
426     try:
427         exclude_regex = re_compile_maybe_verbose(exclude)
428     except re.error:
429         err(f"Invalid regular expression for exclude given: {exclude!r}")
430         ctx.exit(2)
431     report = Report(check=check, quiet=quiet, verbose=verbose)
432     root = find_project_root(src)
433     sources: Set[Path] = set()
434     path_empty(src, quiet, verbose, ctx)
435     for s in src:
436         p = Path(s)
437         if p.is_dir():
438             sources.update(
439                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
440             )
441         elif p.is_file() or s == "-":
442             # if a file was explicitly given, we don't care about its extension
443             sources.add(p)
444         else:
445             err(f"invalid path: {s}")
446     if len(sources) == 0:
447         if verbose or not quiet:
448             out("No Python files are present to be formatted. Nothing to do 😴")
449         ctx.exit(0)
450
451     if len(sources) == 1:
452         reformat_one(
453             src=sources.pop(),
454             fast=fast,
455             write_back=write_back,
456             mode=mode,
457             report=report,
458         )
459     else:
460         reformat_many(
461             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
462         )
463
464     if verbose or not quiet:
465         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
466         click.secho(str(report), err=True)
467     ctx.exit(report.return_code)
468
469
470 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
471     """
472     Exit if there is no `src` provided for formatting
473     """
474     if not src:
475         if verbose or not quiet:
476             out("No Path provided. Nothing to do 😴")
477             ctx.exit(0)
478
479
480 def reformat_one(
481     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
482 ) -> None:
483     """Reformat a single file under `src` without spawning child processes.
484
485     `fast`, `write_back`, and `mode` options are passed to
486     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
487     """
488     try:
489         changed = Changed.NO
490         if not src.is_file() and str(src) == "-":
491             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
492                 changed = Changed.YES
493         else:
494             cache: Cache = {}
495             if write_back != WriteBack.DIFF:
496                 cache = read_cache(mode)
497                 res_src = src.resolve()
498                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
499                     changed = Changed.CACHED
500             if changed is not Changed.CACHED and format_file_in_place(
501                 src, fast=fast, write_back=write_back, mode=mode
502             ):
503                 changed = Changed.YES
504             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
505                 write_back is WriteBack.CHECK and changed is Changed.NO
506             ):
507                 write_cache(cache, [src], mode)
508         report.done(src, changed)
509     except Exception as exc:
510         report.failed(src, str(exc))
511
512
513 def reformat_many(
514     sources: Set[Path],
515     fast: bool,
516     write_back: WriteBack,
517     mode: FileMode,
518     report: "Report",
519 ) -> None:
520     """Reformat multiple files using a ProcessPoolExecutor."""
521     loop = asyncio.get_event_loop()
522     worker_count = os.cpu_count()
523     if sys.platform == "win32":
524         # Work around https://bugs.python.org/issue26903
525         worker_count = min(worker_count, 61)
526     executor = ProcessPoolExecutor(max_workers=worker_count)
527     try:
528         loop.run_until_complete(
529             schedule_formatting(
530                 sources=sources,
531                 fast=fast,
532                 write_back=write_back,
533                 mode=mode,
534                 report=report,
535                 loop=loop,
536                 executor=executor,
537             )
538         )
539     finally:
540         shutdown(loop)
541         executor.shutdown()
542
543
544 async def schedule_formatting(
545     sources: Set[Path],
546     fast: bool,
547     write_back: WriteBack,
548     mode: FileMode,
549     report: "Report",
550     loop: asyncio.AbstractEventLoop,
551     executor: Executor,
552 ) -> None:
553     """Run formatting of `sources` in parallel using the provided `executor`.
554
555     (Use ProcessPoolExecutors for actual parallelism.)
556
557     `write_back`, `fast`, and `mode` options are passed to
558     :func:`format_file_in_place`.
559     """
560     cache: Cache = {}
561     if write_back != WriteBack.DIFF:
562         cache = read_cache(mode)
563         sources, cached = filter_cached(cache, sources)
564         for src in sorted(cached):
565             report.done(src, Changed.CACHED)
566     if not sources:
567         return
568
569     cancelled = []
570     sources_to_cache = []
571     lock = None
572     if write_back == WriteBack.DIFF:
573         # For diff output, we need locks to ensure we don't interleave output
574         # from different processes.
575         manager = Manager()
576         lock = manager.Lock()
577     tasks = {
578         asyncio.ensure_future(
579             loop.run_in_executor(
580                 executor, format_file_in_place, src, fast, mode, write_back, lock
581             )
582         ): src
583         for src in sorted(sources)
584     }
585     pending: Iterable[asyncio.Future] = tasks.keys()
586     try:
587         loop.add_signal_handler(signal.SIGINT, cancel, pending)
588         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
589     except NotImplementedError:
590         # There are no good alternatives for these on Windows.
591         pass
592     while pending:
593         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
594         for task in done:
595             src = tasks.pop(task)
596             if task.cancelled():
597                 cancelled.append(task)
598             elif task.exception():
599                 report.failed(src, str(task.exception()))
600             else:
601                 changed = Changed.YES if task.result() else Changed.NO
602                 # If the file was written back or was successfully checked as
603                 # well-formatted, store this information in the cache.
604                 if write_back is WriteBack.YES or (
605                     write_back is WriteBack.CHECK and changed is Changed.NO
606                 ):
607                     sources_to_cache.append(src)
608                 report.done(src, changed)
609     if cancelled:
610         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
611     if sources_to_cache:
612         write_cache(cache, sources_to_cache, mode)
613
614
615 def format_file_in_place(
616     src: Path,
617     fast: bool,
618     mode: FileMode,
619     write_back: WriteBack = WriteBack.NO,
620     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
621 ) -> bool:
622     """Format file under `src` path. Return True if changed.
623
624     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
625     code to the file.
626     `mode` and `fast` options are passed to :func:`format_file_contents`.
627     """
628     if src.suffix == ".pyi":
629         mode = evolve(mode, is_pyi=True)
630
631     then = datetime.utcfromtimestamp(src.stat().st_mtime)
632     with open(src, "rb") as buf:
633         src_contents, encoding, newline = decode_bytes(buf.read())
634     try:
635         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
636     except NothingChanged:
637         return False
638
639     if write_back == write_back.YES:
640         with open(src, "w", encoding=encoding, newline=newline) as f:
641             f.write(dst_contents)
642     elif write_back == write_back.DIFF:
643         now = datetime.utcnow()
644         src_name = f"{src}\t{then} +0000"
645         dst_name = f"{src}\t{now} +0000"
646         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
647
648         with lock or nullcontext():
649             f = io.TextIOWrapper(
650                 sys.stdout.buffer,
651                 encoding=encoding,
652                 newline=newline,
653                 write_through=True,
654             )
655             f.write(diff_contents)
656             f.detach()
657
658     return True
659
660
661 def format_stdin_to_stdout(
662     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
663 ) -> bool:
664     """Format file on stdin. Return True if changed.
665
666     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
667     write a diff to stdout. The `mode` argument is passed to
668     :func:`format_file_contents`.
669     """
670     then = datetime.utcnow()
671     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
672     dst = src
673     try:
674         dst = format_file_contents(src, fast=fast, mode=mode)
675         return True
676
677     except NothingChanged:
678         return False
679
680     finally:
681         f = io.TextIOWrapper(
682             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
683         )
684         if write_back == WriteBack.YES:
685             f.write(dst)
686         elif write_back == WriteBack.DIFF:
687             now = datetime.utcnow()
688             src_name = f"STDIN\t{then} +0000"
689             dst_name = f"STDOUT\t{now} +0000"
690             f.write(diff(src, dst, src_name, dst_name))
691         f.detach()
692
693
694 def format_file_contents(
695     src_contents: str, *, fast: bool, mode: FileMode
696 ) -> FileContent:
697     """Reformat contents a file and return new contents.
698
699     If `fast` is False, additionally confirm that the reformatted code is
700     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
701     `mode` is passed to :func:`format_str`.
702     """
703     if src_contents.strip() == "":
704         raise NothingChanged
705
706     dst_contents = format_str(src_contents, mode=mode)
707     if src_contents == dst_contents:
708         raise NothingChanged
709
710     if not fast:
711         assert_equivalent(src_contents, dst_contents)
712         assert_stable(src_contents, dst_contents, mode=mode)
713     return dst_contents
714
715
716 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
717     """Reformat a string and return new contents.
718
719     `mode` determines formatting options, such as how many characters per line are
720     allowed.
721     """
722     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
723     dst_contents = []
724     future_imports = get_future_imports(src_node)
725     if mode.target_versions:
726         versions = mode.target_versions
727     else:
728         versions = detect_target_versions(src_node)
729     normalize_fmt_off(src_node)
730     lines = LineGenerator(
731         remove_u_prefix="unicode_literals" in future_imports
732         or supports_feature(versions, Feature.UNICODE_LITERALS),
733         is_pyi=mode.is_pyi,
734         normalize_strings=mode.string_normalization,
735     )
736     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
737     empty_line = Line()
738     after = 0
739     split_line_features = {
740         feature
741         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
742         if supports_feature(versions, feature)
743     }
744     for current_line in lines.visit(src_node):
745         for _ in range(after):
746             dst_contents.append(str(empty_line))
747         before, after = elt.maybe_empty_lines(current_line)
748         for _ in range(before):
749             dst_contents.append(str(empty_line))
750         for line in split_line(
751             current_line, line_length=mode.line_length, features=split_line_features
752         ):
753             dst_contents.append(str(line))
754     return "".join(dst_contents)
755
756
757 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
758     """Return a tuple of (decoded_contents, encoding, newline).
759
760     `newline` is either CRLF or LF but `decoded_contents` is decoded with
761     universal newlines (i.e. only contains LF).
762     """
763     srcbuf = io.BytesIO(src)
764     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
765     if not lines:
766         return "", encoding, "\n"
767
768     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
769     srcbuf.seek(0)
770     with io.TextIOWrapper(srcbuf, encoding) as tiow:
771         return tiow.read(), encoding, newline
772
773
774 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
775     if not target_versions:
776         # No target_version specified, so try all grammars.
777         return [
778             # Python 3.7+
779             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
780             # Python 3.0-3.6
781             pygram.python_grammar_no_print_statement_no_exec_statement,
782             # Python 2.7 with future print_function import
783             pygram.python_grammar_no_print_statement,
784             # Python 2.7
785             pygram.python_grammar,
786         ]
787     elif all(version.is_python2() for version in target_versions):
788         # Python 2-only code, so try Python 2 grammars.
789         return [
790             # Python 2.7 with future print_function import
791             pygram.python_grammar_no_print_statement,
792             # Python 2.7
793             pygram.python_grammar,
794         ]
795     else:
796         # Python 3-compatible code, so only try Python 3 grammar.
797         grammars = []
798         # If we have to parse both, try to parse async as a keyword first
799         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
800             # Python 3.7+
801             grammars.append(
802                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
803             )
804         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
805             # Python 3.0-3.6
806             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
807         # At least one of the above branches must have been taken, because every Python
808         # version has exactly one of the two 'ASYNC_*' flags
809         return grammars
810
811
812 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
813     """Given a string with source, return the lib2to3 Node."""
814     if src_txt[-1:] != "\n":
815         src_txt += "\n"
816
817     for grammar in get_grammars(set(target_versions)):
818         drv = driver.Driver(grammar, pytree.convert)
819         try:
820             result = drv.parse_string(src_txt, True)
821             break
822
823         except ParseError as pe:
824             lineno, column = pe.context[1]
825             lines = src_txt.splitlines()
826             try:
827                 faulty_line = lines[lineno - 1]
828             except IndexError:
829                 faulty_line = "<line number missing in source>"
830             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
831     else:
832         raise exc from None
833
834     if isinstance(result, Leaf):
835         result = Node(syms.file_input, [result])
836     return result
837
838
839 def lib2to3_unparse(node: Node) -> str:
840     """Given a lib2to3 node, return its string representation."""
841     code = str(node)
842     return code
843
844
845 T = TypeVar("T")
846
847
848 class Visitor(Generic[T]):
849     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
850
851     def visit(self, node: LN) -> Iterator[T]:
852         """Main method to visit `node` and its children.
853
854         It tries to find a `visit_*()` method for the given `node.type`, like
855         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
856         If no dedicated `visit_*()` method is found, chooses `visit_default()`
857         instead.
858
859         Then yields objects of type `T` from the selected visitor.
860         """
861         if node.type < 256:
862             name = token.tok_name[node.type]
863         else:
864             name = type_repr(node.type)
865         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
866
867     def visit_default(self, node: LN) -> Iterator[T]:
868         """Default `visit_*()` implementation. Recurses to children of `node`."""
869         if isinstance(node, Node):
870             for child in node.children:
871                 yield from self.visit(child)
872
873
874 @dataclass
875 class DebugVisitor(Visitor[T]):
876     tree_depth: int = 0
877
878     def visit_default(self, node: LN) -> Iterator[T]:
879         indent = " " * (2 * self.tree_depth)
880         if isinstance(node, Node):
881             _type = type_repr(node.type)
882             out(f"{indent}{_type}", fg="yellow")
883             self.tree_depth += 1
884             for child in node.children:
885                 yield from self.visit(child)
886
887             self.tree_depth -= 1
888             out(f"{indent}/{_type}", fg="yellow", bold=False)
889         else:
890             _type = token.tok_name.get(node.type, str(node.type))
891             out(f"{indent}{_type}", fg="blue", nl=False)
892             if node.prefix:
893                 # We don't have to handle prefixes for `Node` objects since
894                 # that delegates to the first child anyway.
895                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
896             out(f" {node.value!r}", fg="blue", bold=False)
897
898     @classmethod
899     def show(cls, code: Union[str, Leaf, Node]) -> None:
900         """Pretty-print the lib2to3 AST of a given string of `code`.
901
902         Convenience method for debugging.
903         """
904         v: DebugVisitor[None] = DebugVisitor()
905         if isinstance(code, str):
906             code = lib2to3_parse(code)
907         list(v.visit(code))
908
909
910 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
911 STATEMENT = {
912     syms.if_stmt,
913     syms.while_stmt,
914     syms.for_stmt,
915     syms.try_stmt,
916     syms.except_clause,
917     syms.with_stmt,
918     syms.funcdef,
919     syms.classdef,
920 }
921 STANDALONE_COMMENT = 153
922 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
923 LOGIC_OPERATORS = {"and", "or"}
924 COMPARATORS = {
925     token.LESS,
926     token.GREATER,
927     token.EQEQUAL,
928     token.NOTEQUAL,
929     token.LESSEQUAL,
930     token.GREATEREQUAL,
931 }
932 MATH_OPERATORS = {
933     token.VBAR,
934     token.CIRCUMFLEX,
935     token.AMPER,
936     token.LEFTSHIFT,
937     token.RIGHTSHIFT,
938     token.PLUS,
939     token.MINUS,
940     token.STAR,
941     token.SLASH,
942     token.DOUBLESLASH,
943     token.PERCENT,
944     token.AT,
945     token.TILDE,
946     token.DOUBLESTAR,
947 }
948 STARS = {token.STAR, token.DOUBLESTAR}
949 VARARGS_SPECIALS = STARS | {token.SLASH}
950 VARARGS_PARENTS = {
951     syms.arglist,
952     syms.argument,  # double star in arglist
953     syms.trailer,  # single argument to call
954     syms.typedargslist,
955     syms.varargslist,  # lambdas
956 }
957 UNPACKING_PARENTS = {
958     syms.atom,  # single element of a list or set literal
959     syms.dictsetmaker,
960     syms.listmaker,
961     syms.testlist_gexp,
962     syms.testlist_star_expr,
963 }
964 TEST_DESCENDANTS = {
965     syms.test,
966     syms.lambdef,
967     syms.or_test,
968     syms.and_test,
969     syms.not_test,
970     syms.comparison,
971     syms.star_expr,
972     syms.expr,
973     syms.xor_expr,
974     syms.and_expr,
975     syms.shift_expr,
976     syms.arith_expr,
977     syms.trailer,
978     syms.term,
979     syms.power,
980 }
981 ASSIGNMENTS = {
982     "=",
983     "+=",
984     "-=",
985     "*=",
986     "@=",
987     "/=",
988     "%=",
989     "&=",
990     "|=",
991     "^=",
992     "<<=",
993     ">>=",
994     "**=",
995     "//=",
996 }
997 COMPREHENSION_PRIORITY = 20
998 COMMA_PRIORITY = 18
999 TERNARY_PRIORITY = 16
1000 LOGIC_PRIORITY = 14
1001 STRING_PRIORITY = 12
1002 COMPARATOR_PRIORITY = 10
1003 MATH_PRIORITIES = {
1004     token.VBAR: 9,
1005     token.CIRCUMFLEX: 8,
1006     token.AMPER: 7,
1007     token.LEFTSHIFT: 6,
1008     token.RIGHTSHIFT: 6,
1009     token.PLUS: 5,
1010     token.MINUS: 5,
1011     token.STAR: 4,
1012     token.SLASH: 4,
1013     token.DOUBLESLASH: 4,
1014     token.PERCENT: 4,
1015     token.AT: 4,
1016     token.TILDE: 3,
1017     token.DOUBLESTAR: 2,
1018 }
1019 DOT_PRIORITY = 1
1020
1021
1022 @dataclass
1023 class BracketTracker:
1024     """Keeps track of brackets on a line."""
1025
1026     depth: int = 0
1027     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1028     delimiters: Dict[LeafID, Priority] = Factory(dict)
1029     previous: Optional[Leaf] = None
1030     _for_loop_depths: List[int] = Factory(list)
1031     _lambda_argument_depths: List[int] = Factory(list)
1032
1033     def mark(self, leaf: Leaf) -> None:
1034         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1035
1036         All leaves receive an int `bracket_depth` field that stores how deep
1037         within brackets a given leaf is. 0 means there are no enclosing brackets
1038         that started on this line.
1039
1040         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1041         field that it forms a pair with. This is a one-directional link to
1042         avoid reference cycles.
1043
1044         If a leaf is a delimiter (a token on which Black can split the line if
1045         needed) and it's on depth 0, its `id()` is stored in the tracker's
1046         `delimiters` field.
1047         """
1048         if leaf.type == token.COMMENT:
1049             return
1050
1051         self.maybe_decrement_after_for_loop_variable(leaf)
1052         self.maybe_decrement_after_lambda_arguments(leaf)
1053         if leaf.type in CLOSING_BRACKETS:
1054             self.depth -= 1
1055             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1056             leaf.opening_bracket = opening_bracket
1057         leaf.bracket_depth = self.depth
1058         if self.depth == 0:
1059             delim = is_split_before_delimiter(leaf, self.previous)
1060             if delim and self.previous is not None:
1061                 self.delimiters[id(self.previous)] = delim
1062             else:
1063                 delim = is_split_after_delimiter(leaf, self.previous)
1064                 if delim:
1065                     self.delimiters[id(leaf)] = delim
1066         if leaf.type in OPENING_BRACKETS:
1067             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1068             self.depth += 1
1069         self.previous = leaf
1070         self.maybe_increment_lambda_arguments(leaf)
1071         self.maybe_increment_for_loop_variable(leaf)
1072
1073     def any_open_brackets(self) -> bool:
1074         """Return True if there is an yet unmatched open bracket on the line."""
1075         return bool(self.bracket_match)
1076
1077     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1078         """Return the highest priority of a delimiter found on the line.
1079
1080         Values are consistent with what `is_split_*_delimiter()` return.
1081         Raises ValueError on no delimiters.
1082         """
1083         return max(v for k, v in self.delimiters.items() if k not in exclude)
1084
1085     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1086         """Return the number of delimiters with the given `priority`.
1087
1088         If no `priority` is passed, defaults to max priority on the line.
1089         """
1090         if not self.delimiters:
1091             return 0
1092
1093         priority = priority or self.max_delimiter_priority()
1094         return sum(1 for p in self.delimiters.values() if p == priority)
1095
1096     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1097         """In a for loop, or comprehension, the variables are often unpacks.
1098
1099         To avoid splitting on the comma in this situation, increase the depth of
1100         tokens between `for` and `in`.
1101         """
1102         if leaf.type == token.NAME and leaf.value == "for":
1103             self.depth += 1
1104             self._for_loop_depths.append(self.depth)
1105             return True
1106
1107         return False
1108
1109     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1110         """See `maybe_increment_for_loop_variable` above for explanation."""
1111         if (
1112             self._for_loop_depths
1113             and self._for_loop_depths[-1] == self.depth
1114             and leaf.type == token.NAME
1115             and leaf.value == "in"
1116         ):
1117             self.depth -= 1
1118             self._for_loop_depths.pop()
1119             return True
1120
1121         return False
1122
1123     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1124         """In a lambda expression, there might be more than one argument.
1125
1126         To avoid splitting on the comma in this situation, increase the depth of
1127         tokens between `lambda` and `:`.
1128         """
1129         if leaf.type == token.NAME and leaf.value == "lambda":
1130             self.depth += 1
1131             self._lambda_argument_depths.append(self.depth)
1132             return True
1133
1134         return False
1135
1136     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1137         """See `maybe_increment_lambda_arguments` above for explanation."""
1138         if (
1139             self._lambda_argument_depths
1140             and self._lambda_argument_depths[-1] == self.depth
1141             and leaf.type == token.COLON
1142         ):
1143             self.depth -= 1
1144             self._lambda_argument_depths.pop()
1145             return True
1146
1147         return False
1148
1149     def get_open_lsqb(self) -> Optional[Leaf]:
1150         """Return the most recent opening square bracket (if any)."""
1151         return self.bracket_match.get((self.depth - 1, token.RSQB))
1152
1153
1154 @dataclass
1155 class Line:
1156     """Holds leaves and comments. Can be printed with `str(line)`."""
1157
1158     depth: int = 0
1159     leaves: List[Leaf] = Factory(list)
1160     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1161     bracket_tracker: BracketTracker = Factory(BracketTracker)
1162     inside_brackets: bool = False
1163     should_explode: bool = False
1164
1165     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1166         """Add a new `leaf` to the end of the line.
1167
1168         Unless `preformatted` is True, the `leaf` will receive a new consistent
1169         whitespace prefix and metadata applied by :class:`BracketTracker`.
1170         Trailing commas are maybe removed, unpacked for loop variables are
1171         demoted from being delimiters.
1172
1173         Inline comments are put aside.
1174         """
1175         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1176         if not has_value:
1177             return
1178
1179         if token.COLON == leaf.type and self.is_class_paren_empty:
1180             del self.leaves[-2:]
1181         if self.leaves and not preformatted:
1182             # Note: at this point leaf.prefix should be empty except for
1183             # imports, for which we only preserve newlines.
1184             leaf.prefix += whitespace(
1185                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1186             )
1187         if self.inside_brackets or not preformatted:
1188             self.bracket_tracker.mark(leaf)
1189             self.maybe_remove_trailing_comma(leaf)
1190         if not self.append_comment(leaf):
1191             self.leaves.append(leaf)
1192
1193     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1194         """Like :func:`append()` but disallow invalid standalone comment structure.
1195
1196         Raises ValueError when any `leaf` is appended after a standalone comment
1197         or when a standalone comment is not the first leaf on the line.
1198         """
1199         if self.bracket_tracker.depth == 0:
1200             if self.is_comment:
1201                 raise ValueError("cannot append to standalone comments")
1202
1203             if self.leaves and leaf.type == STANDALONE_COMMENT:
1204                 raise ValueError(
1205                     "cannot append standalone comments to a populated line"
1206                 )
1207
1208         self.append(leaf, preformatted=preformatted)
1209
1210     @property
1211     def is_comment(self) -> bool:
1212         """Is this line a standalone comment?"""
1213         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1214
1215     @property
1216     def is_decorator(self) -> bool:
1217         """Is this line a decorator?"""
1218         return bool(self) and self.leaves[0].type == token.AT
1219
1220     @property
1221     def is_import(self) -> bool:
1222         """Is this an import line?"""
1223         return bool(self) and is_import(self.leaves[0])
1224
1225     @property
1226     def is_class(self) -> bool:
1227         """Is this line a class definition?"""
1228         return (
1229             bool(self)
1230             and self.leaves[0].type == token.NAME
1231             and self.leaves[0].value == "class"
1232         )
1233
1234     @property
1235     def is_stub_class(self) -> bool:
1236         """Is this line a class definition with a body consisting only of "..."?"""
1237         return self.is_class and self.leaves[-3:] == [
1238             Leaf(token.DOT, ".") for _ in range(3)
1239         ]
1240
1241     @property
1242     def is_collection_with_optional_trailing_comma(self) -> bool:
1243         """Is this line a collection literal with a trailing comma that's optional?
1244
1245         Note that the trailing comma in a 1-tuple is not optional.
1246         """
1247         if not self.leaves or len(self.leaves) < 4:
1248             return False
1249         # Look for and address a trailing colon.
1250         if self.leaves[-1].type == token.COLON:
1251             closer = self.leaves[-2]
1252             close_index = -2
1253         else:
1254             closer = self.leaves[-1]
1255             close_index = -1
1256         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1257             return False
1258         if closer.type == token.RPAR:
1259             # Tuples require an extra check, because if there's only
1260             # one element in the tuple removing the comma unmakes the
1261             # tuple.
1262             #
1263             # We also check for parens before looking for the trailing
1264             # comma because in some cases (eg assigning a dict
1265             # literal) the literal gets wrapped in temporary parens
1266             # during parsing. This case is covered by the
1267             # collections.py test data.
1268             opener = closer.opening_bracket
1269             for _open_index, leaf in enumerate(self.leaves):
1270                 if leaf is opener:
1271                     break
1272             else:
1273                 # Couldn't find the matching opening paren, play it safe.
1274                 return False
1275             commas = 0
1276             comma_depth = self.leaves[close_index - 1].bracket_depth
1277             for leaf in self.leaves[_open_index + 1 : close_index]:
1278                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1279                     commas += 1
1280             if commas > 1:
1281                 # We haven't looked yet for the trailing comma because
1282                 # we might also have caught noop parens.
1283                 return self.leaves[close_index - 1].type == token.COMMA
1284             elif commas == 1:
1285                 return False  # it's either a one-tuple or didn't have a trailing comma
1286             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1287                 close_index -= 1
1288                 closer = self.leaves[close_index]
1289                 if closer.type == token.RPAR:
1290                     # TODO: this is a gut feeling. Will we ever see this?
1291                     return False
1292         if self.leaves[close_index - 1].type != token.COMMA:
1293             return False
1294         return True
1295
1296     @property
1297     def is_def(self) -> bool:
1298         """Is this a function definition? (Also returns True for async defs.)"""
1299         try:
1300             first_leaf = self.leaves[0]
1301         except IndexError:
1302             return False
1303
1304         try:
1305             second_leaf: Optional[Leaf] = self.leaves[1]
1306         except IndexError:
1307             second_leaf = None
1308         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1309             first_leaf.type == token.ASYNC
1310             and second_leaf is not None
1311             and second_leaf.type == token.NAME
1312             and second_leaf.value == "def"
1313         )
1314
1315     @property
1316     def is_class_paren_empty(self) -> bool:
1317         """Is this a class with no base classes but using parentheses?
1318
1319         Those are unnecessary and should be removed.
1320         """
1321         return (
1322             bool(self)
1323             and len(self.leaves) == 4
1324             and self.is_class
1325             and self.leaves[2].type == token.LPAR
1326             and self.leaves[2].value == "("
1327             and self.leaves[3].type == token.RPAR
1328             and self.leaves[3].value == ")"
1329         )
1330
1331     @property
1332     def is_triple_quoted_string(self) -> bool:
1333         """Is the line a triple quoted string?"""
1334         return (
1335             bool(self)
1336             and self.leaves[0].type == token.STRING
1337             and self.leaves[0].value.startswith(('"""', "'''"))
1338         )
1339
1340     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1341         """If so, needs to be split before emitting."""
1342         for leaf in self.leaves:
1343             if leaf.type == STANDALONE_COMMENT:
1344                 if leaf.bracket_depth <= depth_limit:
1345                     return True
1346         return False
1347
1348     def contains_uncollapsable_type_comments(self) -> bool:
1349         ignored_ids = set()
1350         try:
1351             last_leaf = self.leaves[-1]
1352             ignored_ids.add(id(last_leaf))
1353             if last_leaf.type == token.COMMA or (
1354                 last_leaf.type == token.RPAR and not last_leaf.value
1355             ):
1356                 # When trailing commas or optional parens are inserted by Black for
1357                 # consistency, comments after the previous last element are not moved
1358                 # (they don't have to, rendering will still be correct).  So we ignore
1359                 # trailing commas and invisible.
1360                 last_leaf = self.leaves[-2]
1361                 ignored_ids.add(id(last_leaf))
1362         except IndexError:
1363             return False
1364
1365         # A type comment is uncollapsable if it is attached to a leaf
1366         # that isn't at the end of the line (since that could cause it
1367         # to get associated to a different argument) or if there are
1368         # comments before it (since that could cause it to get hidden
1369         # behind a comment.
1370         comment_seen = False
1371         for leaf_id, comments in self.comments.items():
1372             for comment in comments:
1373                 if is_type_comment(comment):
1374                     if leaf_id not in ignored_ids or comment_seen:
1375                         return True
1376
1377                 comment_seen = True
1378
1379         return False
1380
1381     def contains_unsplittable_type_ignore(self) -> bool:
1382         if not self.leaves:
1383             return False
1384
1385         # If a 'type: ignore' is attached to the end of a line, we
1386         # can't split the line, because we can't know which of the
1387         # subexpressions the ignore was meant to apply to.
1388         #
1389         # We only want this to apply to actual physical lines from the
1390         # original source, though: we don't want the presence of a
1391         # 'type: ignore' at the end of a multiline expression to
1392         # justify pushing it all onto one line. Thus we
1393         # (unfortunately) need to check the actual source lines and
1394         # only report an unsplittable 'type: ignore' if this line was
1395         # one line in the original code.
1396
1397         # Grab the first and last line numbers, skipping generated leaves
1398         first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0)
1399         last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0)
1400
1401         if first_line == last_line:
1402             # We look at the last two leaves since a comma or an
1403             # invisible paren could have been added at the end of the
1404             # line.
1405             for node in self.leaves[-2:]:
1406                 for comment in self.comments.get(id(node), []):
1407                     if is_type_comment(comment, " ignore"):
1408                         return True
1409
1410         return False
1411
1412     def contains_multiline_strings(self) -> bool:
1413         for leaf in self.leaves:
1414             if is_multiline_string(leaf):
1415                 return True
1416
1417         return False
1418
1419     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1420         """Remove trailing comma if there is one and it's safe."""
1421         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1422             return False
1423         # We remove trailing commas only in the case of importing a
1424         # single name from a module.
1425         if not (
1426             self.leaves
1427             and self.is_import
1428             and len(self.leaves) > 4
1429             and self.leaves[-1].type == token.COMMA
1430             and closing.type in CLOSING_BRACKETS
1431             and self.leaves[-4].type == token.NAME
1432             and (
1433                 # regular `from foo import bar,`
1434                 self.leaves[-4].value == "import"
1435                 # `from foo import (bar as baz,)
1436                 or (
1437                     len(self.leaves) > 6
1438                     and self.leaves[-6].value == "import"
1439                     and self.leaves[-3].value == "as"
1440                 )
1441                 # `from foo import bar as baz,`
1442                 or (
1443                     len(self.leaves) > 5
1444                     and self.leaves[-5].value == "import"
1445                     and self.leaves[-3].value == "as"
1446                 )
1447             )
1448             and closing.type == token.RPAR
1449         ):
1450             return False
1451
1452         self.remove_trailing_comma()
1453         return True
1454
1455     def append_comment(self, comment: Leaf) -> bool:
1456         """Add an inline or standalone comment to the line."""
1457         if (
1458             comment.type == STANDALONE_COMMENT
1459             and self.bracket_tracker.any_open_brackets()
1460         ):
1461             comment.prefix = ""
1462             return False
1463
1464         if comment.type != token.COMMENT:
1465             return False
1466
1467         if not self.leaves:
1468             comment.type = STANDALONE_COMMENT
1469             comment.prefix = ""
1470             return False
1471
1472         last_leaf = self.leaves[-1]
1473         if (
1474             last_leaf.type == token.RPAR
1475             and not last_leaf.value
1476             and last_leaf.parent
1477             and len(list(last_leaf.parent.leaves())) <= 3
1478             and not is_type_comment(comment)
1479         ):
1480             # Comments on an optional parens wrapping a single leaf should belong to
1481             # the wrapped node except if it's a type comment. Pinning the comment like
1482             # this avoids unstable formatting caused by comment migration.
1483             if len(self.leaves) < 2:
1484                 comment.type = STANDALONE_COMMENT
1485                 comment.prefix = ""
1486                 return False
1487             last_leaf = self.leaves[-2]
1488         self.comments.setdefault(id(last_leaf), []).append(comment)
1489         return True
1490
1491     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1492         """Generate comments that should appear directly after `leaf`."""
1493         return self.comments.get(id(leaf), [])
1494
1495     def remove_trailing_comma(self) -> None:
1496         """Remove the trailing comma and moves the comments attached to it."""
1497         trailing_comma = self.leaves.pop()
1498         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1499         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1500             trailing_comma_comments
1501         )
1502
1503     def is_complex_subscript(self, leaf: Leaf) -> bool:
1504         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1505         open_lsqb = self.bracket_tracker.get_open_lsqb()
1506         if open_lsqb is None:
1507             return False
1508
1509         subscript_start = open_lsqb.next_sibling
1510
1511         if isinstance(subscript_start, Node):
1512             if subscript_start.type == syms.listmaker:
1513                 return False
1514
1515             if subscript_start.type == syms.subscriptlist:
1516                 subscript_start = child_towards(subscript_start, leaf)
1517         return subscript_start is not None and any(
1518             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1519         )
1520
1521     def __str__(self) -> str:
1522         """Render the line."""
1523         if not self:
1524             return "\n"
1525
1526         indent = "    " * self.depth
1527         leaves = iter(self.leaves)
1528         first = next(leaves)
1529         res = f"{first.prefix}{indent}{first.value}"
1530         for leaf in leaves:
1531             res += str(leaf)
1532         for comment in itertools.chain.from_iterable(self.comments.values()):
1533             res += str(comment)
1534         return res + "\n"
1535
1536     def __bool__(self) -> bool:
1537         """Return True if the line has leaves or comments."""
1538         return bool(self.leaves or self.comments)
1539
1540
1541 @dataclass
1542 class EmptyLineTracker:
1543     """Provides a stateful method that returns the number of potential extra
1544     empty lines needed before and after the currently processed line.
1545
1546     Note: this tracker works on lines that haven't been split yet.  It assumes
1547     the prefix of the first leaf consists of optional newlines.  Those newlines
1548     are consumed by `maybe_empty_lines()` and included in the computation.
1549     """
1550
1551     is_pyi: bool = False
1552     previous_line: Optional[Line] = None
1553     previous_after: int = 0
1554     previous_defs: List[int] = Factory(list)
1555
1556     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1557         """Return the number of extra empty lines before and after the `current_line`.
1558
1559         This is for separating `def`, `async def` and `class` with extra empty
1560         lines (two on module-level).
1561         """
1562         before, after = self._maybe_empty_lines(current_line)
1563         before = (
1564             # Black should not insert empty lines at the beginning
1565             # of the file
1566             0
1567             if self.previous_line is None
1568             else before - self.previous_after
1569         )
1570         self.previous_after = after
1571         self.previous_line = current_line
1572         return before, after
1573
1574     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1575         max_allowed = 1
1576         if current_line.depth == 0:
1577             max_allowed = 1 if self.is_pyi else 2
1578         if current_line.leaves:
1579             # Consume the first leaf's extra newlines.
1580             first_leaf = current_line.leaves[0]
1581             before = first_leaf.prefix.count("\n")
1582             before = min(before, max_allowed)
1583             first_leaf.prefix = ""
1584         else:
1585             before = 0
1586         depth = current_line.depth
1587         while self.previous_defs and self.previous_defs[-1] >= depth:
1588             self.previous_defs.pop()
1589             if self.is_pyi:
1590                 before = 0 if depth else 1
1591             else:
1592                 before = 1 if depth else 2
1593         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1594             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1595
1596         if (
1597             self.previous_line
1598             and self.previous_line.is_import
1599             and not current_line.is_import
1600             and depth == self.previous_line.depth
1601         ):
1602             return (before or 1), 0
1603
1604         if (
1605             self.previous_line
1606             and self.previous_line.is_class
1607             and current_line.is_triple_quoted_string
1608         ):
1609             return before, 1
1610
1611         return before, 0
1612
1613     def _maybe_empty_lines_for_class_or_def(
1614         self, current_line: Line, before: int
1615     ) -> Tuple[int, int]:
1616         if not current_line.is_decorator:
1617             self.previous_defs.append(current_line.depth)
1618         if self.previous_line is None:
1619             # Don't insert empty lines before the first line in the file.
1620             return 0, 0
1621
1622         if self.previous_line.is_decorator:
1623             return 0, 0
1624
1625         if self.previous_line.depth < current_line.depth and (
1626             self.previous_line.is_class or self.previous_line.is_def
1627         ):
1628             return 0, 0
1629
1630         if (
1631             self.previous_line.is_comment
1632             and self.previous_line.depth == current_line.depth
1633             and before == 0
1634         ):
1635             return 0, 0
1636
1637         if self.is_pyi:
1638             if self.previous_line.depth > current_line.depth:
1639                 newlines = 1
1640             elif current_line.is_class or self.previous_line.is_class:
1641                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1642                     # No blank line between classes with an empty body
1643                     newlines = 0
1644                 else:
1645                     newlines = 1
1646             elif current_line.is_def and not self.previous_line.is_def:
1647                 # Blank line between a block of functions and a block of non-functions
1648                 newlines = 1
1649             else:
1650                 newlines = 0
1651         else:
1652             newlines = 2
1653         if current_line.depth and newlines:
1654             newlines -= 1
1655         return newlines, 0
1656
1657
1658 @dataclass
1659 class LineGenerator(Visitor[Line]):
1660     """Generates reformatted Line objects.  Empty lines are not emitted.
1661
1662     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1663     in ways that will no longer stringify to valid Python code on the tree.
1664     """
1665
1666     is_pyi: bool = False
1667     normalize_strings: bool = True
1668     current_line: Line = Factory(Line)
1669     remove_u_prefix: bool = False
1670
1671     def line(self, indent: int = 0) -> Iterator[Line]:
1672         """Generate a line.
1673
1674         If the line is empty, only emit if it makes sense.
1675         If the line is too long, split it first and then generate.
1676
1677         If any lines were generated, set up a new current_line.
1678         """
1679         if not self.current_line:
1680             self.current_line.depth += indent
1681             return  # Line is empty, don't emit. Creating a new one unnecessary.
1682
1683         complete_line = self.current_line
1684         self.current_line = Line(depth=complete_line.depth + indent)
1685         yield complete_line
1686
1687     def visit_default(self, node: LN) -> Iterator[Line]:
1688         """Default `visit_*()` implementation. Recurses to children of `node`."""
1689         if isinstance(node, Leaf):
1690             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1691             for comment in generate_comments(node):
1692                 if any_open_brackets:
1693                     # any comment within brackets is subject to splitting
1694                     self.current_line.append(comment)
1695                 elif comment.type == token.COMMENT:
1696                     # regular trailing comment
1697                     self.current_line.append(comment)
1698                     yield from self.line()
1699
1700                 else:
1701                     # regular standalone comment
1702                     yield from self.line()
1703
1704                     self.current_line.append(comment)
1705                     yield from self.line()
1706
1707             normalize_prefix(node, inside_brackets=any_open_brackets)
1708             if self.normalize_strings and node.type == token.STRING:
1709                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1710                 normalize_string_quotes(node)
1711             if node.type == token.NUMBER:
1712                 normalize_numeric_literal(node)
1713             if node.type not in WHITESPACE:
1714                 self.current_line.append(node)
1715         yield from super().visit_default(node)
1716
1717     def visit_factor(self, node: Node) -> Iterator[Line]:
1718         """Force parentheses between a unary op and a binary power:
1719
1720         -2 ** 8 -> -(2 ** 8)
1721         """
1722         child = node.children[1]
1723         if child.type == syms.power and len(child.children) == 3:
1724             lpar = Leaf(token.LPAR, "(")
1725             rpar = Leaf(token.RPAR, ")")
1726             index = child.remove() or 0
1727             node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1728         yield from self.visit_default(node)
1729
1730     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1731         """Increase indentation level, maybe yield a line."""
1732         # In blib2to3 INDENT never holds comments.
1733         yield from self.line(+1)
1734         yield from self.visit_default(node)
1735
1736     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1737         """Decrease indentation level, maybe yield a line."""
1738         # The current line might still wait for trailing comments.  At DEDENT time
1739         # there won't be any (they would be prefixes on the preceding NEWLINE).
1740         # Emit the line then.
1741         yield from self.line()
1742
1743         # While DEDENT has no value, its prefix may contain standalone comments
1744         # that belong to the current indentation level.  Get 'em.
1745         yield from self.visit_default(node)
1746
1747         # Finally, emit the dedent.
1748         yield from self.line(-1)
1749
1750     def visit_stmt(
1751         self, node: Node, keywords: Set[str], parens: Set[str]
1752     ) -> Iterator[Line]:
1753         """Visit a statement.
1754
1755         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1756         `def`, `with`, `class`, `assert` and assignments.
1757
1758         The relevant Python language `keywords` for a given statement will be
1759         NAME leaves within it. This methods puts those on a separate line.
1760
1761         `parens` holds a set of string leaf values immediately after which
1762         invisible parens should be put.
1763         """
1764         normalize_invisible_parens(node, parens_after=parens)
1765         for child in node.children:
1766             if child.type == token.NAME and child.value in keywords:  # type: ignore
1767                 yield from self.line()
1768
1769             yield from self.visit(child)
1770
1771     def visit_suite(self, node: Node) -> Iterator[Line]:
1772         """Visit a suite."""
1773         if self.is_pyi and is_stub_suite(node):
1774             yield from self.visit(node.children[2])
1775         else:
1776             yield from self.visit_default(node)
1777
1778     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1779         """Visit a statement without nested statements."""
1780         is_suite_like = node.parent and node.parent.type in STATEMENT
1781         if is_suite_like:
1782             if self.is_pyi and is_stub_body(node):
1783                 yield from self.visit_default(node)
1784             else:
1785                 yield from self.line(+1)
1786                 yield from self.visit_default(node)
1787                 yield from self.line(-1)
1788
1789         else:
1790             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1791                 yield from self.line()
1792             yield from self.visit_default(node)
1793
1794     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1795         """Visit `async def`, `async for`, `async with`."""
1796         yield from self.line()
1797
1798         children = iter(node.children)
1799         for child in children:
1800             yield from self.visit(child)
1801
1802             if child.type == token.ASYNC:
1803                 break
1804
1805         internal_stmt = next(children)
1806         for child in internal_stmt.children:
1807             yield from self.visit(child)
1808
1809     def visit_decorators(self, node: Node) -> Iterator[Line]:
1810         """Visit decorators."""
1811         for child in node.children:
1812             yield from self.line()
1813             yield from self.visit(child)
1814
1815     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1816         """Remove a semicolon and put the other statement on a separate line."""
1817         yield from self.line()
1818
1819     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1820         """End of file. Process outstanding comments and end with a newline."""
1821         yield from self.visit_default(leaf)
1822         yield from self.line()
1823
1824     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1825         if not self.current_line.bracket_tracker.any_open_brackets():
1826             yield from self.line()
1827         yield from self.visit_default(leaf)
1828
1829     def __attrs_post_init__(self) -> None:
1830         """You are in a twisty little maze of passages."""
1831         v = self.visit_stmt
1832         Ø: Set[str] = set()
1833         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1834         self.visit_if_stmt = partial(
1835             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1836         )
1837         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1838         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1839         self.visit_try_stmt = partial(
1840             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1841         )
1842         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1843         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1844         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1845         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1846         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1847         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1848         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1849         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1850         self.visit_async_funcdef = self.visit_async_stmt
1851         self.visit_decorated = self.visit_decorators
1852
1853
1854 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1855 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1856 OPENING_BRACKETS = set(BRACKET.keys())
1857 CLOSING_BRACKETS = set(BRACKET.values())
1858 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1859 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1860
1861
1862 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1863     """Return whitespace prefix if needed for the given `leaf`.
1864
1865     `complex_subscript` signals whether the given leaf is part of a subscription
1866     which has non-trivial arguments, like arithmetic expressions or function calls.
1867     """
1868     NO = ""
1869     SPACE = " "
1870     DOUBLESPACE = "  "
1871     t = leaf.type
1872     p = leaf.parent
1873     v = leaf.value
1874     if t in ALWAYS_NO_SPACE:
1875         return NO
1876
1877     if t == token.COMMENT:
1878         return DOUBLESPACE
1879
1880     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1881     if t == token.COLON and p.type not in {
1882         syms.subscript,
1883         syms.subscriptlist,
1884         syms.sliceop,
1885     }:
1886         return NO
1887
1888     prev = leaf.prev_sibling
1889     if not prev:
1890         prevp = preceding_leaf(p)
1891         if not prevp or prevp.type in OPENING_BRACKETS:
1892             return NO
1893
1894         if t == token.COLON:
1895             if prevp.type == token.COLON:
1896                 return NO
1897
1898             elif prevp.type != token.COMMA and not complex_subscript:
1899                 return NO
1900
1901             return SPACE
1902
1903         if prevp.type == token.EQUAL:
1904             if prevp.parent:
1905                 if prevp.parent.type in {
1906                     syms.arglist,
1907                     syms.argument,
1908                     syms.parameters,
1909                     syms.varargslist,
1910                 }:
1911                     return NO
1912
1913                 elif prevp.parent.type == syms.typedargslist:
1914                     # A bit hacky: if the equal sign has whitespace, it means we
1915                     # previously found it's a typed argument.  So, we're using
1916                     # that, too.
1917                     return prevp.prefix
1918
1919         elif prevp.type in VARARGS_SPECIALS:
1920             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1921                 return NO
1922
1923         elif prevp.type == token.COLON:
1924             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1925                 return SPACE if complex_subscript else NO
1926
1927         elif (
1928             prevp.parent
1929             and prevp.parent.type == syms.factor
1930             and prevp.type in MATH_OPERATORS
1931         ):
1932             return NO
1933
1934         elif (
1935             prevp.type == token.RIGHTSHIFT
1936             and prevp.parent
1937             and prevp.parent.type == syms.shift_expr
1938             and prevp.prev_sibling
1939             and prevp.prev_sibling.type == token.NAME
1940             and prevp.prev_sibling.value == "print"  # type: ignore
1941         ):
1942             # Python 2 print chevron
1943             return NO
1944
1945     elif prev.type in OPENING_BRACKETS:
1946         return NO
1947
1948     if p.type in {syms.parameters, syms.arglist}:
1949         # untyped function signatures or calls
1950         if not prev or prev.type != token.COMMA:
1951             return NO
1952
1953     elif p.type == syms.varargslist:
1954         # lambdas
1955         if prev and prev.type != token.COMMA:
1956             return NO
1957
1958     elif p.type == syms.typedargslist:
1959         # typed function signatures
1960         if not prev:
1961             return NO
1962
1963         if t == token.EQUAL:
1964             if prev.type != syms.tname:
1965                 return NO
1966
1967         elif prev.type == token.EQUAL:
1968             # A bit hacky: if the equal sign has whitespace, it means we
1969             # previously found it's a typed argument.  So, we're using that, too.
1970             return prev.prefix
1971
1972         elif prev.type != token.COMMA:
1973             return NO
1974
1975     elif p.type == syms.tname:
1976         # type names
1977         if not prev:
1978             prevp = preceding_leaf(p)
1979             if not prevp or prevp.type != token.COMMA:
1980                 return NO
1981
1982     elif p.type == syms.trailer:
1983         # attributes and calls
1984         if t == token.LPAR or t == token.RPAR:
1985             return NO
1986
1987         if not prev:
1988             if t == token.DOT:
1989                 prevp = preceding_leaf(p)
1990                 if not prevp or prevp.type != token.NUMBER:
1991                     return NO
1992
1993             elif t == token.LSQB:
1994                 return NO
1995
1996         elif prev.type != token.COMMA:
1997             return NO
1998
1999     elif p.type == syms.argument:
2000         # single argument
2001         if t == token.EQUAL:
2002             return NO
2003
2004         if not prev:
2005             prevp = preceding_leaf(p)
2006             if not prevp or prevp.type == token.LPAR:
2007                 return NO
2008
2009         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2010             return NO
2011
2012     elif p.type == syms.decorator:
2013         # decorators
2014         return NO
2015
2016     elif p.type == syms.dotted_name:
2017         if prev:
2018             return NO
2019
2020         prevp = preceding_leaf(p)
2021         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2022             return NO
2023
2024     elif p.type == syms.classdef:
2025         if t == token.LPAR:
2026             return NO
2027
2028         if prev and prev.type == token.LPAR:
2029             return NO
2030
2031     elif p.type in {syms.subscript, syms.sliceop}:
2032         # indexing
2033         if not prev:
2034             assert p.parent is not None, "subscripts are always parented"
2035             if p.parent.type == syms.subscriptlist:
2036                 return SPACE
2037
2038             return NO
2039
2040         elif not complex_subscript:
2041             return NO
2042
2043     elif p.type == syms.atom:
2044         if prev and t == token.DOT:
2045             # dots, but not the first one.
2046             return NO
2047
2048     elif p.type == syms.dictsetmaker:
2049         # dict unpacking
2050         if prev and prev.type == token.DOUBLESTAR:
2051             return NO
2052
2053     elif p.type in {syms.factor, syms.star_expr}:
2054         # unary ops
2055         if not prev:
2056             prevp = preceding_leaf(p)
2057             if not prevp or prevp.type in OPENING_BRACKETS:
2058                 return NO
2059
2060             prevp_parent = prevp.parent
2061             assert prevp_parent is not None
2062             if prevp.type == token.COLON and prevp_parent.type in {
2063                 syms.subscript,
2064                 syms.sliceop,
2065             }:
2066                 return NO
2067
2068             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2069                 return NO
2070
2071         elif t in {token.NAME, token.NUMBER, token.STRING}:
2072             return NO
2073
2074     elif p.type == syms.import_from:
2075         if t == token.DOT:
2076             if prev and prev.type == token.DOT:
2077                 return NO
2078
2079         elif t == token.NAME:
2080             if v == "import":
2081                 return SPACE
2082
2083             if prev and prev.type == token.DOT:
2084                 return NO
2085
2086     elif p.type == syms.sliceop:
2087         return NO
2088
2089     return SPACE
2090
2091
2092 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2093     """Return the first leaf that precedes `node`, if any."""
2094     while node:
2095         res = node.prev_sibling
2096         if res:
2097             if isinstance(res, Leaf):
2098                 return res
2099
2100             try:
2101                 return list(res.leaves())[-1]
2102
2103             except IndexError:
2104                 return None
2105
2106         node = node.parent
2107     return None
2108
2109
2110 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2111     """Return the child of `ancestor` that contains `descendant`."""
2112     node: Optional[LN] = descendant
2113     while node and node.parent != ancestor:
2114         node = node.parent
2115     return node
2116
2117
2118 def container_of(leaf: Leaf) -> LN:
2119     """Return `leaf` or one of its ancestors that is the topmost container of it.
2120
2121     By "container" we mean a node where `leaf` is the very first child.
2122     """
2123     same_prefix = leaf.prefix
2124     container: LN = leaf
2125     while container:
2126         parent = container.parent
2127         if parent is None:
2128             break
2129
2130         if parent.children[0].prefix != same_prefix:
2131             break
2132
2133         if parent.type == syms.file_input:
2134             break
2135
2136         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2137             break
2138
2139         container = parent
2140     return container
2141
2142
2143 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2144     """Return the priority of the `leaf` delimiter, given a line break after it.
2145
2146     The delimiter priorities returned here are from those delimiters that would
2147     cause a line break after themselves.
2148
2149     Higher numbers are higher priority.
2150     """
2151     if leaf.type == token.COMMA:
2152         return COMMA_PRIORITY
2153
2154     return 0
2155
2156
2157 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2158     """Return the priority of the `leaf` delimiter, given a line break before it.
2159
2160     The delimiter priorities returned here are from those delimiters that would
2161     cause a line break before themselves.
2162
2163     Higher numbers are higher priority.
2164     """
2165     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2166         # * and ** might also be MATH_OPERATORS but in this case they are not.
2167         # Don't treat them as a delimiter.
2168         return 0
2169
2170     if (
2171         leaf.type == token.DOT
2172         and leaf.parent
2173         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2174         and (previous is None or previous.type in CLOSING_BRACKETS)
2175     ):
2176         return DOT_PRIORITY
2177
2178     if (
2179         leaf.type in MATH_OPERATORS
2180         and leaf.parent
2181         and leaf.parent.type not in {syms.factor, syms.star_expr}
2182     ):
2183         return MATH_PRIORITIES[leaf.type]
2184
2185     if leaf.type in COMPARATORS:
2186         return COMPARATOR_PRIORITY
2187
2188     if (
2189         leaf.type == token.STRING
2190         and previous is not None
2191         and previous.type == token.STRING
2192     ):
2193         return STRING_PRIORITY
2194
2195     if leaf.type not in {token.NAME, token.ASYNC}:
2196         return 0
2197
2198     if (
2199         leaf.value == "for"
2200         and leaf.parent
2201         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2202         or leaf.type == token.ASYNC
2203     ):
2204         if (
2205             not isinstance(leaf.prev_sibling, Leaf)
2206             or leaf.prev_sibling.value != "async"
2207         ):
2208             return COMPREHENSION_PRIORITY
2209
2210     if (
2211         leaf.value == "if"
2212         and leaf.parent
2213         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2214     ):
2215         return COMPREHENSION_PRIORITY
2216
2217     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2218         return TERNARY_PRIORITY
2219
2220     if leaf.value == "is":
2221         return COMPARATOR_PRIORITY
2222
2223     if (
2224         leaf.value == "in"
2225         and leaf.parent
2226         and leaf.parent.type in {syms.comp_op, syms.comparison}
2227         and not (
2228             previous is not None
2229             and previous.type == token.NAME
2230             and previous.value == "not"
2231         )
2232     ):
2233         return COMPARATOR_PRIORITY
2234
2235     if (
2236         leaf.value == "not"
2237         and leaf.parent
2238         and leaf.parent.type == syms.comp_op
2239         and not (
2240             previous is not None
2241             and previous.type == token.NAME
2242             and previous.value == "is"
2243         )
2244     ):
2245         return COMPARATOR_PRIORITY
2246
2247     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2248         return LOGIC_PRIORITY
2249
2250     return 0
2251
2252
2253 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2254 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2255
2256
2257 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2258     """Clean the prefix of the `leaf` and generate comments from it, if any.
2259
2260     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2261     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2262     move because it does away with modifying the grammar to include all the
2263     possible places in which comments can be placed.
2264
2265     The sad consequence for us though is that comments don't "belong" anywhere.
2266     This is why this function generates simple parentless Leaf objects for
2267     comments.  We simply don't know what the correct parent should be.
2268
2269     No matter though, we can live without this.  We really only need to
2270     differentiate between inline and standalone comments.  The latter don't
2271     share the line with any code.
2272
2273     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2274     are emitted with a fake STANDALONE_COMMENT token identifier.
2275     """
2276     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2277         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2278
2279
2280 @dataclass
2281 class ProtoComment:
2282     """Describes a piece of syntax that is a comment.
2283
2284     It's not a :class:`blib2to3.pytree.Leaf` so that:
2285
2286     * it can be cached (`Leaf` objects should not be reused more than once as
2287       they store their lineno, column, prefix, and parent information);
2288     * `newlines` and `consumed` fields are kept separate from the `value`. This
2289       simplifies handling of special marker comments like ``# fmt: off/on``.
2290     """
2291
2292     type: int  # token.COMMENT or STANDALONE_COMMENT
2293     value: str  # content of the comment
2294     newlines: int  # how many newlines before the comment
2295     consumed: int  # how many characters of the original leaf's prefix did we consume
2296
2297
2298 @lru_cache(maxsize=4096)
2299 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2300     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2301     result: List[ProtoComment] = []
2302     if not prefix or "#" not in prefix:
2303         return result
2304
2305     consumed = 0
2306     nlines = 0
2307     ignored_lines = 0
2308     for index, line in enumerate(prefix.split("\n")):
2309         consumed += len(line) + 1  # adding the length of the split '\n'
2310         line = line.lstrip()
2311         if not line:
2312             nlines += 1
2313         if not line.startswith("#"):
2314             # Escaped newlines outside of a comment are not really newlines at
2315             # all. We treat a single-line comment following an escaped newline
2316             # as a simple trailing comment.
2317             if line.endswith("\\"):
2318                 ignored_lines += 1
2319             continue
2320
2321         if index == ignored_lines and not is_endmarker:
2322             comment_type = token.COMMENT  # simple trailing comment
2323         else:
2324             comment_type = STANDALONE_COMMENT
2325         comment = make_comment(line)
2326         result.append(
2327             ProtoComment(
2328                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2329             )
2330         )
2331         nlines = 0
2332     return result
2333
2334
2335 def make_comment(content: str) -> str:
2336     """Return a consistently formatted comment from the given `content` string.
2337
2338     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2339     space between the hash sign and the content.
2340
2341     If `content` didn't start with a hash sign, one is provided.
2342     """
2343     content = content.rstrip()
2344     if not content:
2345         return "#"
2346
2347     if content[0] == "#":
2348         content = content[1:]
2349     if content and content[0] not in " !:#'%":
2350         content = " " + content
2351     return "#" + content
2352
2353
2354 def split_line(
2355     line: Line,
2356     line_length: int,
2357     inner: bool = False,
2358     features: Collection[Feature] = (),
2359 ) -> Iterator[Line]:
2360     """Split a `line` into potentially many lines.
2361
2362     They should fit in the allotted `line_length` but might not be able to.
2363     `inner` signifies that there were a pair of brackets somewhere around the
2364     current `line`, possibly transitively. This means we can fallback to splitting
2365     by delimiters if the LHS/RHS don't yield any results.
2366
2367     `features` are syntactical features that may be used in the output.
2368     """
2369     if line.is_comment:
2370         yield line
2371         return
2372
2373     line_str = str(line).strip("\n")
2374
2375     if (
2376         not line.contains_uncollapsable_type_comments()
2377         and not line.should_explode
2378         and not line.is_collection_with_optional_trailing_comma
2379         and (
2380             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2381             or line.contains_unsplittable_type_ignore()
2382         )
2383     ):
2384         yield line
2385         return
2386
2387     split_funcs: List[SplitFunc]
2388     if line.is_def:
2389         split_funcs = [left_hand_split]
2390     else:
2391
2392         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2393             for omit in generate_trailers_to_omit(line, line_length):
2394                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2395                 if is_line_short_enough(lines[0], line_length=line_length):
2396                     yield from lines
2397                     return
2398
2399             # All splits failed, best effort split with no omits.
2400             # This mostly happens to multiline strings that are by definition
2401             # reported as not fitting a single line.
2402             yield from right_hand_split(line, line_length, features=features)
2403
2404         if line.inside_brackets:
2405             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2406         else:
2407             split_funcs = [rhs]
2408     for split_func in split_funcs:
2409         # We are accumulating lines in `result` because we might want to abort
2410         # mission and return the original line in the end, or attempt a different
2411         # split altogether.
2412         result: List[Line] = []
2413         try:
2414             for l in split_func(line, features):
2415                 if str(l).strip("\n") == line_str:
2416                     raise CannotSplit("Split function returned an unchanged result")
2417
2418                 result.extend(
2419                     split_line(
2420                         l, line_length=line_length, inner=True, features=features
2421                     )
2422                 )
2423         except CannotSplit:
2424             continue
2425
2426         else:
2427             yield from result
2428             break
2429
2430     else:
2431         yield line
2432
2433
2434 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2435     """Split line into many lines, starting with the first matching bracket pair.
2436
2437     Note: this usually looks weird, only use this for function definitions.
2438     Prefer RHS otherwise.  This is why this function is not symmetrical with
2439     :func:`right_hand_split` which also handles optional parentheses.
2440     """
2441     tail_leaves: List[Leaf] = []
2442     body_leaves: List[Leaf] = []
2443     head_leaves: List[Leaf] = []
2444     current_leaves = head_leaves
2445     matching_bracket = None
2446     for leaf in line.leaves:
2447         if (
2448             current_leaves is body_leaves
2449             and leaf.type in CLOSING_BRACKETS
2450             and leaf.opening_bracket is matching_bracket
2451         ):
2452             current_leaves = tail_leaves if body_leaves else head_leaves
2453         current_leaves.append(leaf)
2454         if current_leaves is head_leaves:
2455             if leaf.type in OPENING_BRACKETS:
2456                 matching_bracket = leaf
2457                 current_leaves = body_leaves
2458     if not matching_bracket:
2459         raise CannotSplit("No brackets found")
2460
2461     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2462     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2463     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2464     bracket_split_succeeded_or_raise(head, body, tail)
2465     for result in (head, body, tail):
2466         if result:
2467             yield result
2468
2469
2470 def right_hand_split(
2471     line: Line,
2472     line_length: int,
2473     features: Collection[Feature] = (),
2474     omit: Collection[LeafID] = (),
2475 ) -> Iterator[Line]:
2476     """Split line into many lines, starting with the last matching bracket pair.
2477
2478     If the split was by optional parentheses, attempt splitting without them, too.
2479     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2480     this split.
2481
2482     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2483     """
2484     tail_leaves: List[Leaf] = []
2485     body_leaves: List[Leaf] = []
2486     head_leaves: List[Leaf] = []
2487     current_leaves = tail_leaves
2488     opening_bracket = None
2489     closing_bracket = None
2490     for leaf in reversed(line.leaves):
2491         if current_leaves is body_leaves:
2492             if leaf is opening_bracket:
2493                 current_leaves = head_leaves if body_leaves else tail_leaves
2494         current_leaves.append(leaf)
2495         if current_leaves is tail_leaves:
2496             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2497                 opening_bracket = leaf.opening_bracket
2498                 closing_bracket = leaf
2499                 current_leaves = body_leaves
2500     if not (opening_bracket and closing_bracket and head_leaves):
2501         # If there is no opening or closing_bracket that means the split failed and
2502         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2503         # the matching `opening_bracket` wasn't available on `line` anymore.
2504         raise CannotSplit("No brackets found")
2505
2506     tail_leaves.reverse()
2507     body_leaves.reverse()
2508     head_leaves.reverse()
2509     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2510     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2511     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2512     bracket_split_succeeded_or_raise(head, body, tail)
2513     if (
2514         # the body shouldn't be exploded
2515         not body.should_explode
2516         # the opening bracket is an optional paren
2517         and opening_bracket.type == token.LPAR
2518         and not opening_bracket.value
2519         # the closing bracket is an optional paren
2520         and closing_bracket.type == token.RPAR
2521         and not closing_bracket.value
2522         # it's not an import (optional parens are the only thing we can split on
2523         # in this case; attempting a split without them is a waste of time)
2524         and not line.is_import
2525         # there are no standalone comments in the body
2526         and not body.contains_standalone_comments(0)
2527         # and we can actually remove the parens
2528         and can_omit_invisible_parens(body, line_length)
2529     ):
2530         omit = {id(closing_bracket), *omit}
2531         try:
2532             yield from right_hand_split(line, line_length, features=features, omit=omit)
2533             return
2534
2535         except CannotSplit:
2536             if not (
2537                 can_be_split(body)
2538                 or is_line_short_enough(body, line_length=line_length)
2539             ):
2540                 raise CannotSplit(
2541                     "Splitting failed, body is still too long and can't be split."
2542                 )
2543
2544             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2545                 raise CannotSplit(
2546                     "The current optional pair of parentheses is bound to fail to "
2547                     "satisfy the splitting algorithm because the head or the tail "
2548                     "contains multiline strings which by definition never fit one "
2549                     "line."
2550                 )
2551
2552     ensure_visible(opening_bracket)
2553     ensure_visible(closing_bracket)
2554     for result in (head, body, tail):
2555         if result:
2556             yield result
2557
2558
2559 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2560     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2561
2562     Do nothing otherwise.
2563
2564     A left- or right-hand split is based on a pair of brackets. Content before
2565     (and including) the opening bracket is left on one line, content inside the
2566     brackets is put on a separate line, and finally content starting with and
2567     following the closing bracket is put on a separate line.
2568
2569     Those are called `head`, `body`, and `tail`, respectively. If the split
2570     produced the same line (all content in `head`) or ended up with an empty `body`
2571     and the `tail` is just the closing bracket, then it's considered failed.
2572     """
2573     tail_len = len(str(tail).strip())
2574     if not body:
2575         if tail_len == 0:
2576             raise CannotSplit("Splitting brackets produced the same line")
2577
2578         elif tail_len < 3:
2579             raise CannotSplit(
2580                 f"Splitting brackets on an empty body to save "
2581                 f"{tail_len} characters is not worth it"
2582             )
2583
2584
2585 def bracket_split_build_line(
2586     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2587 ) -> Line:
2588     """Return a new line with given `leaves` and respective comments from `original`.
2589
2590     If `is_body` is True, the result line is one-indented inside brackets and as such
2591     has its first leaf's prefix normalized and a trailing comma added when expected.
2592     """
2593     result = Line(depth=original.depth)
2594     if is_body:
2595         result.inside_brackets = True
2596         result.depth += 1
2597         if leaves:
2598             # Since body is a new indent level, remove spurious leading whitespace.
2599             normalize_prefix(leaves[0], inside_brackets=True)
2600             # Ensure a trailing comma for imports and standalone function arguments, but
2601             # be careful not to add one after any comments.
2602             no_commas = original.is_def and not any(
2603                 l.type == token.COMMA for l in leaves
2604             )
2605
2606             if original.is_import or no_commas:
2607                 for i in range(len(leaves) - 1, -1, -1):
2608                     if leaves[i].type == STANDALONE_COMMENT:
2609                         continue
2610                     elif leaves[i].type == token.COMMA:
2611                         break
2612                     else:
2613                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2614                         break
2615     # Populate the line
2616     for leaf in leaves:
2617         result.append(leaf, preformatted=True)
2618         for comment_after in original.comments_after(leaf):
2619             result.append(comment_after, preformatted=True)
2620     if is_body:
2621         result.should_explode = should_explode(result, opening_bracket)
2622     return result
2623
2624
2625 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2626     """Normalize prefix of the first leaf in every line returned by `split_func`.
2627
2628     This is a decorator over relevant split functions.
2629     """
2630
2631     @wraps(split_func)
2632     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2633         for l in split_func(line, features):
2634             normalize_prefix(l.leaves[0], inside_brackets=True)
2635             yield l
2636
2637     return split_wrapper
2638
2639
2640 @dont_increase_indentation
2641 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2642     """Split according to delimiters of the highest priority.
2643
2644     If the appropriate Features are given, the split will add trailing commas
2645     also in function signatures and calls that contain `*` and `**`.
2646     """
2647     try:
2648         last_leaf = line.leaves[-1]
2649     except IndexError:
2650         raise CannotSplit("Line empty")
2651
2652     bt = line.bracket_tracker
2653     try:
2654         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2655     except ValueError:
2656         raise CannotSplit("No delimiters found")
2657
2658     if delimiter_priority == DOT_PRIORITY:
2659         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2660             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2661
2662     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2663     lowest_depth = sys.maxsize
2664     trailing_comma_safe = True
2665
2666     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2667         """Append `leaf` to current line or to new line if appending impossible."""
2668         nonlocal current_line
2669         try:
2670             current_line.append_safe(leaf, preformatted=True)
2671         except ValueError:
2672             yield current_line
2673
2674             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2675             current_line.append(leaf)
2676
2677     for leaf in line.leaves:
2678         yield from append_to_line(leaf)
2679
2680         for comment_after in line.comments_after(leaf):
2681             yield from append_to_line(comment_after)
2682
2683         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2684         if leaf.bracket_depth == lowest_depth:
2685             if is_vararg(leaf, within={syms.typedargslist}):
2686                 trailing_comma_safe = (
2687                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2688                 )
2689             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2690                 trailing_comma_safe = (
2691                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2692                 )
2693
2694         leaf_priority = bt.delimiters.get(id(leaf))
2695         if leaf_priority == delimiter_priority:
2696             yield current_line
2697
2698             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2699     if current_line:
2700         if (
2701             trailing_comma_safe
2702             and delimiter_priority == COMMA_PRIORITY
2703             and current_line.leaves[-1].type != token.COMMA
2704             and current_line.leaves[-1].type != STANDALONE_COMMENT
2705         ):
2706             current_line.append(Leaf(token.COMMA, ","))
2707         yield current_line
2708
2709
2710 @dont_increase_indentation
2711 def standalone_comment_split(
2712     line: Line, features: Collection[Feature] = ()
2713 ) -> Iterator[Line]:
2714     """Split standalone comments from the rest of the line."""
2715     if not line.contains_standalone_comments(0):
2716         raise CannotSplit("Line does not have any standalone comments")
2717
2718     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2719
2720     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2721         """Append `leaf` to current line or to new line if appending impossible."""
2722         nonlocal current_line
2723         try:
2724             current_line.append_safe(leaf, preformatted=True)
2725         except ValueError:
2726             yield current_line
2727
2728             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2729             current_line.append(leaf)
2730
2731     for leaf in line.leaves:
2732         yield from append_to_line(leaf)
2733
2734         for comment_after in line.comments_after(leaf):
2735             yield from append_to_line(comment_after)
2736
2737     if current_line:
2738         yield current_line
2739
2740
2741 def is_import(leaf: Leaf) -> bool:
2742     """Return True if the given leaf starts an import statement."""
2743     p = leaf.parent
2744     t = leaf.type
2745     v = leaf.value
2746     return bool(
2747         t == token.NAME
2748         and (
2749             (v == "import" and p and p.type == syms.import_name)
2750             or (v == "from" and p and p.type == syms.import_from)
2751         )
2752     )
2753
2754
2755 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2756     """Return True if the given leaf is a special comment.
2757     Only returns true for type comments for now."""
2758     t = leaf.type
2759     v = leaf.value
2760     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
2761
2762
2763 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2764     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2765     else.
2766
2767     Note: don't use backslashes for formatting or you'll lose your voting rights.
2768     """
2769     if not inside_brackets:
2770         spl = leaf.prefix.split("#")
2771         if "\\" not in spl[0]:
2772             nl_count = spl[-1].count("\n")
2773             if len(spl) > 1:
2774                 nl_count -= 1
2775             leaf.prefix = "\n" * nl_count
2776             return
2777
2778     leaf.prefix = ""
2779
2780
2781 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2782     """Make all string prefixes lowercase.
2783
2784     If remove_u_prefix is given, also removes any u prefix from the string.
2785
2786     Note: Mutates its argument.
2787     """
2788     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2789     assert match is not None, f"failed to match string {leaf.value!r}"
2790     orig_prefix = match.group(1)
2791     new_prefix = orig_prefix.lower()
2792     if remove_u_prefix:
2793         new_prefix = new_prefix.replace("u", "")
2794     leaf.value = f"{new_prefix}{match.group(2)}"
2795
2796
2797 def normalize_string_quotes(leaf: Leaf) -> None:
2798     """Prefer double quotes but only if it doesn't cause more escaping.
2799
2800     Adds or removes backslashes as appropriate. Doesn't parse and fix
2801     strings nested in f-strings (yet).
2802
2803     Note: Mutates its argument.
2804     """
2805     value = leaf.value.lstrip("furbFURB")
2806     if value[:3] == '"""':
2807         return
2808
2809     elif value[:3] == "'''":
2810         orig_quote = "'''"
2811         new_quote = '"""'
2812     elif value[0] == '"':
2813         orig_quote = '"'
2814         new_quote = "'"
2815     else:
2816         orig_quote = "'"
2817         new_quote = '"'
2818     first_quote_pos = leaf.value.find(orig_quote)
2819     if first_quote_pos == -1:
2820         return  # There's an internal error
2821
2822     prefix = leaf.value[:first_quote_pos]
2823     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2824     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2825     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2826     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2827     if "r" in prefix.casefold():
2828         if unescaped_new_quote.search(body):
2829             # There's at least one unescaped new_quote in this raw string
2830             # so converting is impossible
2831             return
2832
2833         # Do not introduce or remove backslashes in raw strings
2834         new_body = body
2835     else:
2836         # remove unnecessary escapes
2837         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2838         if body != new_body:
2839             # Consider the string without unnecessary escapes as the original
2840             body = new_body
2841             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2842         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2843         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2844     if "f" in prefix.casefold():
2845         matches = re.findall(
2846             r"""
2847             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2848                 ([^{].*?)  # contents of the brackets except if begins with {{
2849             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2850             """,
2851             new_body,
2852             re.VERBOSE,
2853         )
2854         for m in matches:
2855             if "\\" in str(m):
2856                 # Do not introduce backslashes in interpolated expressions
2857                 return
2858     if new_quote == '"""' and new_body[-1:] == '"':
2859         # edge case:
2860         new_body = new_body[:-1] + '\\"'
2861     orig_escape_count = body.count("\\")
2862     new_escape_count = new_body.count("\\")
2863     if new_escape_count > orig_escape_count:
2864         return  # Do not introduce more escaping
2865
2866     if new_escape_count == orig_escape_count and orig_quote == '"':
2867         return  # Prefer double quotes
2868
2869     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2870
2871
2872 def normalize_numeric_literal(leaf: Leaf) -> None:
2873     """Normalizes numeric (float, int, and complex) literals.
2874
2875     All letters used in the representation are normalized to lowercase (except
2876     in Python 2 long literals).
2877     """
2878     text = leaf.value.lower()
2879     if text.startswith(("0o", "0b")):
2880         # Leave octal and binary literals alone.
2881         pass
2882     elif text.startswith("0x"):
2883         # Change hex literals to upper case.
2884         before, after = text[:2], text[2:]
2885         text = f"{before}{after.upper()}"
2886     elif "e" in text:
2887         before, after = text.split("e")
2888         sign = ""
2889         if after.startswith("-"):
2890             after = after[1:]
2891             sign = "-"
2892         elif after.startswith("+"):
2893             after = after[1:]
2894         before = format_float_or_int_string(before)
2895         text = f"{before}e{sign}{after}"
2896     elif text.endswith(("j", "l")):
2897         number = text[:-1]
2898         suffix = text[-1]
2899         # Capitalize in "2L" because "l" looks too similar to "1".
2900         if suffix == "l":
2901             suffix = "L"
2902         text = f"{format_float_or_int_string(number)}{suffix}"
2903     else:
2904         text = format_float_or_int_string(text)
2905     leaf.value = text
2906
2907
2908 def format_float_or_int_string(text: str) -> str:
2909     """Formats a float string like "1.0"."""
2910     if "." not in text:
2911         return text
2912
2913     before, after = text.split(".")
2914     return f"{before or 0}.{after or 0}"
2915
2916
2917 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2918     """Make existing optional parentheses invisible or create new ones.
2919
2920     `parens_after` is a set of string leaf values immediately after which parens
2921     should be put.
2922
2923     Standardizes on visible parentheses for single-element tuples, and keeps
2924     existing visible parentheses for other tuples and generator expressions.
2925     """
2926     for pc in list_comments(node.prefix, is_endmarker=False):
2927         if pc.value in FMT_OFF:
2928             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2929             return
2930
2931     check_lpar = False
2932     for index, child in enumerate(list(node.children)):
2933         # Add parentheses around long tuple unpacking in assignments.
2934         if (
2935             index == 0
2936             and isinstance(child, Node)
2937             and child.type == syms.testlist_star_expr
2938         ):
2939             check_lpar = True
2940
2941         if check_lpar:
2942             if is_walrus_assignment(child):
2943                 continue
2944             if child.type == syms.atom:
2945                 # Determines if the underlying atom should be surrounded with
2946                 # invisible params - also makes parens invisible recursively
2947                 # within the atom and removes repeated invisible parens within
2948                 # the atom
2949                 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2950                     child, parent=node
2951                 )
2952
2953                 if should_surround_with_parens:
2954                     lpar = Leaf(token.LPAR, "")
2955                     rpar = Leaf(token.RPAR, "")
2956                     index = child.remove() or 0
2957                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2958             elif is_one_tuple(child):
2959                 # wrap child in visible parentheses
2960                 lpar = Leaf(token.LPAR, "(")
2961                 rpar = Leaf(token.RPAR, ")")
2962                 child.remove()
2963                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2964             elif node.type == syms.import_from:
2965                 # "import from" nodes store parentheses directly as part of
2966                 # the statement
2967                 if child.type == token.LPAR:
2968                     # make parentheses invisible
2969                     child.value = ""  # type: ignore
2970                     node.children[-1].value = ""  # type: ignore
2971                 elif child.type != token.STAR:
2972                     # insert invisible parentheses
2973                     node.insert_child(index, Leaf(token.LPAR, ""))
2974                     node.append_child(Leaf(token.RPAR, ""))
2975                 break
2976
2977             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2978                 # wrap child in invisible parentheses
2979                 lpar = Leaf(token.LPAR, "")
2980                 rpar = Leaf(token.RPAR, "")
2981                 index = child.remove() or 0
2982                 prefix = child.prefix
2983                 child.prefix = ""
2984                 new_child = Node(syms.atom, [lpar, child, rpar])
2985                 new_child.prefix = prefix
2986                 node.insert_child(index, new_child)
2987
2988         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2989
2990
2991 def normalize_fmt_off(node: Node) -> None:
2992     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2993     try_again = True
2994     while try_again:
2995         try_again = convert_one_fmt_off_pair(node)
2996
2997
2998 def convert_one_fmt_off_pair(node: Node) -> bool:
2999     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
3000
3001     Returns True if a pair was converted.
3002     """
3003     for leaf in node.leaves():
3004         previous_consumed = 0
3005         for comment in list_comments(leaf.prefix, is_endmarker=False):
3006             if comment.value in FMT_OFF:
3007                 # We only want standalone comments. If there's no previous leaf or
3008                 # the previous leaf is indentation, it's a standalone comment in
3009                 # disguise.
3010                 if comment.type != STANDALONE_COMMENT:
3011                     prev = preceding_leaf(leaf)
3012                     if prev and prev.type not in WHITESPACE:
3013                         continue
3014
3015                 ignored_nodes = list(generate_ignored_nodes(leaf))
3016                 if not ignored_nodes:
3017                     continue
3018
3019                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
3020                 parent = first.parent
3021                 prefix = first.prefix
3022                 first.prefix = prefix[comment.consumed :]
3023                 hidden_value = (
3024                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3025                 )
3026                 if hidden_value.endswith("\n"):
3027                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
3028                     # leaf (possibly followed by a DEDENT).
3029                     hidden_value = hidden_value[:-1]
3030                 first_idx = None
3031                 for ignored in ignored_nodes:
3032                     index = ignored.remove()
3033                     if first_idx is None:
3034                         first_idx = index
3035                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3036                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3037                 parent.insert_child(
3038                     first_idx,
3039                     Leaf(
3040                         STANDALONE_COMMENT,
3041                         hidden_value,
3042                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3043                     ),
3044                 )
3045                 return True
3046
3047             previous_consumed = comment.consumed
3048
3049     return False
3050
3051
3052 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3053     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3054
3055     Stops at the end of the block.
3056     """
3057     container: Optional[LN] = container_of(leaf)
3058     while container is not None and container.type != token.ENDMARKER:
3059         for comment in list_comments(container.prefix, is_endmarker=False):
3060             if comment.value in FMT_ON:
3061                 return
3062
3063         yield container
3064
3065         container = container.next_sibling
3066
3067
3068 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3069     """If it's safe, make the parens in the atom `node` invisible, recursively.
3070     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3071     as they are redundant.
3072
3073     Returns whether the node should itself be wrapped in invisible parentheses.
3074
3075     """
3076     if (
3077         node.type != syms.atom
3078         or is_empty_tuple(node)
3079         or is_one_tuple(node)
3080         or (is_yield(node) and parent.type != syms.expr_stmt)
3081         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3082     ):
3083         return False
3084
3085     first = node.children[0]
3086     last = node.children[-1]
3087     if first.type == token.LPAR and last.type == token.RPAR:
3088         middle = node.children[1]
3089         # make parentheses invisible
3090         first.value = ""  # type: ignore
3091         last.value = ""  # type: ignore
3092         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3093
3094         if is_atom_with_invisible_parens(middle):
3095             # Strip the invisible parens from `middle` by replacing
3096             # it with the child in-between the invisible parens
3097             middle.replace(middle.children[1])
3098
3099         return False
3100
3101     return True
3102
3103
3104 def is_atom_with_invisible_parens(node: LN) -> bool:
3105     """Given a `LN`, determines whether it's an atom `node` with invisible
3106     parens. Useful in dedupe-ing and normalizing parens.
3107     """
3108     if isinstance(node, Leaf) or node.type != syms.atom:
3109         return False
3110
3111     first, last = node.children[0], node.children[-1]
3112     return (
3113         isinstance(first, Leaf)
3114         and first.type == token.LPAR
3115         and first.value == ""
3116         and isinstance(last, Leaf)
3117         and last.type == token.RPAR
3118         and last.value == ""
3119     )
3120
3121
3122 def is_empty_tuple(node: LN) -> bool:
3123     """Return True if `node` holds an empty tuple."""
3124     return (
3125         node.type == syms.atom
3126         and len(node.children) == 2
3127         and node.children[0].type == token.LPAR
3128         and node.children[1].type == token.RPAR
3129     )
3130
3131
3132 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3133     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3134
3135     Parenthesis can be optional. Returns None otherwise"""
3136     if len(node.children) != 3:
3137         return None
3138     lpar, wrapped, rpar = node.children
3139     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3140         return None
3141
3142     return wrapped
3143
3144
3145 def is_one_tuple(node: LN) -> bool:
3146     """Return True if `node` holds a tuple with one element, with or without parens."""
3147     if node.type == syms.atom:
3148         gexp = unwrap_singleton_parenthesis(node)
3149         if gexp is None or gexp.type != syms.testlist_gexp:
3150             return False
3151
3152         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3153
3154     return (
3155         node.type in IMPLICIT_TUPLE
3156         and len(node.children) == 2
3157         and node.children[1].type == token.COMMA
3158     )
3159
3160
3161 def is_walrus_assignment(node: LN) -> bool:
3162     """Return True iff `node` is of the shape ( test := test )"""
3163     inner = unwrap_singleton_parenthesis(node)
3164     return inner is not None and inner.type == syms.namedexpr_test
3165
3166
3167 def is_yield(node: LN) -> bool:
3168     """Return True if `node` holds a `yield` or `yield from` expression."""
3169     if node.type == syms.yield_expr:
3170         return True
3171
3172     if node.type == token.NAME and node.value == "yield":  # type: ignore
3173         return True
3174
3175     if node.type != syms.atom:
3176         return False
3177
3178     if len(node.children) != 3:
3179         return False
3180
3181     lpar, expr, rpar = node.children
3182     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3183         return is_yield(expr)
3184
3185     return False
3186
3187
3188 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3189     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3190
3191     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3192     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3193     extended iterable unpacking (PEP 3132) and additional unpacking
3194     generalizations (PEP 448).
3195     """
3196     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3197         return False
3198
3199     p = leaf.parent
3200     if p.type == syms.star_expr:
3201         # Star expressions are also used as assignment targets in extended
3202         # iterable unpacking (PEP 3132).  See what its parent is instead.
3203         if not p.parent:
3204             return False
3205
3206         p = p.parent
3207
3208     return p.type in within
3209
3210
3211 def is_multiline_string(leaf: Leaf) -> bool:
3212     """Return True if `leaf` is a multiline string that actually spans many lines."""
3213     value = leaf.value.lstrip("furbFURB")
3214     return value[:3] in {'"""', "'''"} and "\n" in value
3215
3216
3217 def is_stub_suite(node: Node) -> bool:
3218     """Return True if `node` is a suite with a stub body."""
3219     if (
3220         len(node.children) != 4
3221         or node.children[0].type != token.NEWLINE
3222         or node.children[1].type != token.INDENT
3223         or node.children[3].type != token.DEDENT
3224     ):
3225         return False
3226
3227     return is_stub_body(node.children[2])
3228
3229
3230 def is_stub_body(node: LN) -> bool:
3231     """Return True if `node` is a simple statement containing an ellipsis."""
3232     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3233         return False
3234
3235     if len(node.children) != 2:
3236         return False
3237
3238     child = node.children[0]
3239     return (
3240         child.type == syms.atom
3241         and len(child.children) == 3
3242         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3243     )
3244
3245
3246 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3247     """Return maximum delimiter priority inside `node`.
3248
3249     This is specific to atoms with contents contained in a pair of parentheses.
3250     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3251     """
3252     if node.type != syms.atom:
3253         return 0
3254
3255     first = node.children[0]
3256     last = node.children[-1]
3257     if not (first.type == token.LPAR and last.type == token.RPAR):
3258         return 0
3259
3260     bt = BracketTracker()
3261     for c in node.children[1:-1]:
3262         if isinstance(c, Leaf):
3263             bt.mark(c)
3264         else:
3265             for leaf in c.leaves():
3266                 bt.mark(leaf)
3267     try:
3268         return bt.max_delimiter_priority()
3269
3270     except ValueError:
3271         return 0
3272
3273
3274 def ensure_visible(leaf: Leaf) -> None:
3275     """Make sure parentheses are visible.
3276
3277     They could be invisible as part of some statements (see
3278     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3279     """
3280     if leaf.type == token.LPAR:
3281         leaf.value = "("
3282     elif leaf.type == token.RPAR:
3283         leaf.value = ")"
3284
3285
3286 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3287     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3288
3289     if not (
3290         opening_bracket.parent
3291         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3292         and opening_bracket.value in "[{("
3293     ):
3294         return False
3295
3296     try:
3297         last_leaf = line.leaves[-1]
3298         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3299         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3300     except (IndexError, ValueError):
3301         return False
3302
3303     return max_priority == COMMA_PRIORITY
3304
3305
3306 def get_features_used(node: Node) -> Set[Feature]:
3307     """Return a set of (relatively) new Python features used in this file.
3308
3309     Currently looking for:
3310     - f-strings;
3311     - underscores in numeric literals;
3312     - trailing commas after * or ** in function signatures and calls;
3313     - positional only arguments in function signatures and lambdas;
3314     """
3315     features: Set[Feature] = set()
3316     for n in node.pre_order():
3317         if n.type == token.STRING:
3318             value_head = n.value[:2]  # type: ignore
3319             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3320                 features.add(Feature.F_STRINGS)
3321
3322         elif n.type == token.NUMBER:
3323             if "_" in n.value:  # type: ignore
3324                 features.add(Feature.NUMERIC_UNDERSCORES)
3325
3326         elif n.type == token.SLASH:
3327             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3328                 features.add(Feature.POS_ONLY_ARGUMENTS)
3329
3330         elif n.type == token.COLONEQUAL:
3331             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3332
3333         elif (
3334             n.type in {syms.typedargslist, syms.arglist}
3335             and n.children
3336             and n.children[-1].type == token.COMMA
3337         ):
3338             if n.type == syms.typedargslist:
3339                 feature = Feature.TRAILING_COMMA_IN_DEF
3340             else:
3341                 feature = Feature.TRAILING_COMMA_IN_CALL
3342
3343             for ch in n.children:
3344                 if ch.type in STARS:
3345                     features.add(feature)
3346
3347                 if ch.type == syms.argument:
3348                     for argch in ch.children:
3349                         if argch.type in STARS:
3350                             features.add(feature)
3351
3352     return features
3353
3354
3355 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3356     """Detect the version to target based on the nodes used."""
3357     features = get_features_used(node)
3358     return {
3359         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3360     }
3361
3362
3363 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3364     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3365
3366     Brackets can be omitted if the entire trailer up to and including
3367     a preceding closing bracket fits in one line.
3368
3369     Yielded sets are cumulative (contain results of previous yields, too).  First
3370     set is empty.
3371     """
3372
3373     omit: Set[LeafID] = set()
3374     yield omit
3375
3376     length = 4 * line.depth
3377     opening_bracket = None
3378     closing_bracket = None
3379     inner_brackets: Set[LeafID] = set()
3380     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3381         length += leaf_length
3382         if length > line_length:
3383             break
3384
3385         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3386         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3387             break
3388
3389         if opening_bracket:
3390             if leaf is opening_bracket:
3391                 opening_bracket = None
3392             elif leaf.type in CLOSING_BRACKETS:
3393                 inner_brackets.add(id(leaf))
3394         elif leaf.type in CLOSING_BRACKETS:
3395             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3396                 # Empty brackets would fail a split so treat them as "inner"
3397                 # brackets (e.g. only add them to the `omit` set if another
3398                 # pair of brackets was good enough.
3399                 inner_brackets.add(id(leaf))
3400                 continue
3401
3402             if closing_bracket:
3403                 omit.add(id(closing_bracket))
3404                 omit.update(inner_brackets)
3405                 inner_brackets.clear()
3406                 yield omit
3407
3408             if leaf.value:
3409                 opening_bracket = leaf.opening_bracket
3410                 closing_bracket = leaf
3411
3412
3413 def get_future_imports(node: Node) -> Set[str]:
3414     """Return a set of __future__ imports in the file."""
3415     imports: Set[str] = set()
3416
3417     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3418         for child in children:
3419             if isinstance(child, Leaf):
3420                 if child.type == token.NAME:
3421                     yield child.value
3422             elif child.type == syms.import_as_name:
3423                 orig_name = child.children[0]
3424                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3425                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3426                 yield orig_name.value
3427             elif child.type == syms.import_as_names:
3428                 yield from get_imports_from_children(child.children)
3429             else:
3430                 raise AssertionError("Invalid syntax parsing imports")
3431
3432     for child in node.children:
3433         if child.type != syms.simple_stmt:
3434             break
3435         first_child = child.children[0]
3436         if isinstance(first_child, Leaf):
3437             # Continue looking if we see a docstring; otherwise stop.
3438             if (
3439                 len(child.children) == 2
3440                 and first_child.type == token.STRING
3441                 and child.children[1].type == token.NEWLINE
3442             ):
3443                 continue
3444             else:
3445                 break
3446         elif first_child.type == syms.import_from:
3447             module_name = first_child.children[1]
3448             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3449                 break
3450             imports |= set(get_imports_from_children(first_child.children[3:]))
3451         else:
3452             break
3453     return imports
3454
3455
3456 def gen_python_files_in_dir(
3457     path: Path,
3458     root: Path,
3459     include: Pattern[str],
3460     exclude: Pattern[str],
3461     report: "Report",
3462 ) -> Iterator[Path]:
3463     """Generate all files under `path` whose paths are not excluded by the
3464     `exclude` regex, but are included by the `include` regex.
3465
3466     Symbolic links pointing outside of the `root` directory are ignored.
3467
3468     `report` is where output about exclusions goes.
3469     """
3470     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3471     for child in path.iterdir():
3472         try:
3473             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3474         except ValueError:
3475             if child.is_symlink():
3476                 report.path_ignored(
3477                     child, f"is a symbolic link that points outside {root}"
3478                 )
3479                 continue
3480
3481             raise
3482
3483         if child.is_dir():
3484             normalized_path += "/"
3485         exclude_match = exclude.search(normalized_path)
3486         if exclude_match and exclude_match.group(0):
3487             report.path_ignored(child, f"matches the --exclude regular expression")
3488             continue
3489
3490         if child.is_dir():
3491             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3492
3493         elif child.is_file():
3494             include_match = include.search(normalized_path)
3495             if include_match:
3496                 yield child
3497
3498
3499 @lru_cache()
3500 def find_project_root(srcs: Iterable[str]) -> Path:
3501     """Return a directory containing .git, .hg, or pyproject.toml.
3502
3503     That directory can be one of the directories passed in `srcs` or their
3504     common parent.
3505
3506     If no directory in the tree contains a marker that would specify it's the
3507     project root, the root of the file system is returned.
3508     """
3509     if not srcs:
3510         return Path("/").resolve()
3511
3512     common_base = min(Path(src).resolve() for src in srcs)
3513     if common_base.is_dir():
3514         # Append a fake file so `parents` below returns `common_base_dir`, too.
3515         common_base /= "fake-file"
3516     for directory in common_base.parents:
3517         if (directory / ".git").is_dir():
3518             return directory
3519
3520         if (directory / ".hg").is_dir():
3521             return directory
3522
3523         if (directory / "pyproject.toml").is_file():
3524             return directory
3525
3526     return directory
3527
3528
3529 @dataclass
3530 class Report:
3531     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3532
3533     check: bool = False
3534     quiet: bool = False
3535     verbose: bool = False
3536     change_count: int = 0
3537     same_count: int = 0
3538     failure_count: int = 0
3539
3540     def done(self, src: Path, changed: Changed) -> None:
3541         """Increment the counter for successful reformatting. Write out a message."""
3542         if changed is Changed.YES:
3543             reformatted = "would reformat" if self.check else "reformatted"
3544             if self.verbose or not self.quiet:
3545                 out(f"{reformatted} {src}")
3546             self.change_count += 1
3547         else:
3548             if self.verbose:
3549                 if changed is Changed.NO:
3550                     msg = f"{src} already well formatted, good job."
3551                 else:
3552                     msg = f"{src} wasn't modified on disk since last run."
3553                 out(msg, bold=False)
3554             self.same_count += 1
3555
3556     def failed(self, src: Path, message: str) -> None:
3557         """Increment the counter for failed reformatting. Write out a message."""
3558         err(f"error: cannot format {src}: {message}")
3559         self.failure_count += 1
3560
3561     def path_ignored(self, path: Path, message: str) -> None:
3562         if self.verbose:
3563             out(f"{path} ignored: {message}", bold=False)
3564
3565     @property
3566     def return_code(self) -> int:
3567         """Return the exit code that the app should use.
3568
3569         This considers the current state of changed files and failures:
3570         - if there were any failures, return 123;
3571         - if any files were changed and --check is being used, return 1;
3572         - otherwise return 0.
3573         """
3574         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3575         # 126 we have special return codes reserved by the shell.
3576         if self.failure_count:
3577             return 123
3578
3579         elif self.change_count and self.check:
3580             return 1
3581
3582         return 0
3583
3584     def __str__(self) -> str:
3585         """Render a color report of the current state.
3586
3587         Use `click.unstyle` to remove colors.
3588         """
3589         if self.check:
3590             reformatted = "would be reformatted"
3591             unchanged = "would be left unchanged"
3592             failed = "would fail to reformat"
3593         else:
3594             reformatted = "reformatted"
3595             unchanged = "left unchanged"
3596             failed = "failed to reformat"
3597         report = []
3598         if self.change_count:
3599             s = "s" if self.change_count > 1 else ""
3600             report.append(
3601                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3602             )
3603         if self.same_count:
3604             s = "s" if self.same_count > 1 else ""
3605             report.append(f"{self.same_count} file{s} {unchanged}")
3606         if self.failure_count:
3607             s = "s" if self.failure_count > 1 else ""
3608             report.append(
3609                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3610             )
3611         return ", ".join(report) + "."
3612
3613
3614 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3615     filename = "<unknown>"
3616     if sys.version_info >= (3, 8):
3617         # TODO: support Python 4+ ;)
3618         for minor_version in range(sys.version_info[1], 4, -1):
3619             try:
3620                 return ast.parse(src, filename, feature_version=(3, minor_version))
3621             except SyntaxError:
3622                 continue
3623     else:
3624         for feature_version in (7, 6):
3625             try:
3626                 return ast3.parse(src, filename, feature_version=feature_version)
3627             except SyntaxError:
3628                 continue
3629
3630     return ast27.parse(src)
3631
3632
3633 def _fixup_ast_constants(
3634     node: Union[ast.AST, ast3.AST, ast27.AST]
3635 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3636     """Map ast nodes deprecated in 3.8 to Constant."""
3637     # casts are required until this is released:
3638     # https://github.com/python/typeshed/pull/3142
3639     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3640         return cast(ast.AST, ast.Constant(value=node.s))
3641     elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3642         return cast(ast.AST, ast.Constant(value=node.n))
3643     elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3644         return cast(ast.AST, ast.Constant(value=node.value))
3645     return node
3646
3647
3648 def assert_equivalent(src: str, dst: str) -> None:
3649     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3650
3651     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3652         """Simple visitor generating strings to compare ASTs by content."""
3653
3654         node = _fixup_ast_constants(node)
3655
3656         yield f"{'  ' * depth}{node.__class__.__name__}("
3657
3658         for field in sorted(node._fields):
3659             # TypeIgnore has only one field 'lineno' which breaks this comparison
3660             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3661             if sys.version_info >= (3, 8):
3662                 type_ignore_classes += (ast.TypeIgnore,)
3663             if isinstance(node, type_ignore_classes):
3664                 break
3665
3666             try:
3667                 value = getattr(node, field)
3668             except AttributeError:
3669                 continue
3670
3671             yield f"{'  ' * (depth+1)}{field}="
3672
3673             if isinstance(value, list):
3674                 for item in value:
3675                     # Ignore nested tuples within del statements, because we may insert
3676                     # parentheses and they change the AST.
3677                     if (
3678                         field == "targets"
3679                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3680                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3681                     ):
3682                         for item in item.elts:
3683                             yield from _v(item, depth + 2)
3684                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3685                         yield from _v(item, depth + 2)
3686
3687             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3688                 yield from _v(value, depth + 2)
3689
3690             else:
3691                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3692
3693         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3694
3695     try:
3696         src_ast = parse_ast(src)
3697     except Exception as exc:
3698         raise AssertionError(
3699             f"cannot use --safe with this file; failed to parse source file.  "
3700             f"AST error message: {exc}"
3701         )
3702
3703     try:
3704         dst_ast = parse_ast(dst)
3705     except Exception as exc:
3706         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3707         raise AssertionError(
3708             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3709             f"Please report a bug on https://github.com/psf/black/issues.  "
3710             f"This invalid output might be helpful: {log}"
3711         ) from None
3712
3713     src_ast_str = "\n".join(_v(src_ast))
3714     dst_ast_str = "\n".join(_v(dst_ast))
3715     if src_ast_str != dst_ast_str:
3716         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3717         raise AssertionError(
3718             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3719             f"the source.  "
3720             f"Please report a bug on https://github.com/psf/black/issues.  "
3721             f"This diff might be helpful: {log}"
3722         ) from None
3723
3724
3725 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3726     """Raise AssertionError if `dst` reformats differently the second time."""
3727     newdst = format_str(dst, mode=mode)
3728     if dst != newdst:
3729         log = dump_to_file(
3730             diff(src, dst, "source", "first pass"),
3731             diff(dst, newdst, "first pass", "second pass"),
3732         )
3733         raise AssertionError(
3734             f"INTERNAL ERROR: Black produced different code on the second pass "
3735             f"of the formatter.  "
3736             f"Please report a bug on https://github.com/psf/black/issues.  "
3737             f"This diff might be helpful: {log}"
3738         ) from None
3739
3740
3741 def dump_to_file(*output: str) -> str:
3742     """Dump `output` to a temporary file. Return path to the file."""
3743     with tempfile.NamedTemporaryFile(
3744         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3745     ) as f:
3746         for lines in output:
3747             f.write(lines)
3748             if lines and lines[-1] != "\n":
3749                 f.write("\n")
3750     return f.name
3751
3752
3753 @contextmanager
3754 def nullcontext() -> Iterator[None]:
3755     """Return context manager that does nothing.
3756     Similar to `nullcontext` from python 3.7"""
3757     yield
3758
3759
3760 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3761     """Return a unified diff string between strings `a` and `b`."""
3762     import difflib
3763
3764     a_lines = [line + "\n" for line in a.split("\n")]
3765     b_lines = [line + "\n" for line in b.split("\n")]
3766     return "".join(
3767         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3768     )
3769
3770
3771 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3772     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3773     err("Aborted!")
3774     for task in tasks:
3775         task.cancel()
3776
3777
3778 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3779     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3780     try:
3781         if sys.version_info[:2] >= (3, 7):
3782             all_tasks = asyncio.all_tasks
3783         else:
3784             all_tasks = asyncio.Task.all_tasks
3785         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3786         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3787         if not to_cancel:
3788             return
3789
3790         for task in to_cancel:
3791             task.cancel()
3792         loop.run_until_complete(
3793             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3794         )
3795     finally:
3796         # `concurrent.futures.Future` objects cannot be cancelled once they
3797         # are already running. There might be some when the `shutdown()` happened.
3798         # Silence their logger's spew about the event loop being closed.
3799         cf_logger = logging.getLogger("concurrent.futures")
3800         cf_logger.setLevel(logging.CRITICAL)
3801         loop.close()
3802
3803
3804 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3805     """Replace `regex` with `replacement` twice on `original`.
3806
3807     This is used by string normalization to perform replaces on
3808     overlapping matches.
3809     """
3810     return regex.sub(replacement, regex.sub(replacement, original))
3811
3812
3813 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3814     """Compile a regular expression string in `regex`.
3815
3816     If it contains newlines, use verbose mode.
3817     """
3818     if "\n" in regex:
3819         regex = "(?x)" + regex
3820     compiled: Pattern[str] = re.compile(regex)
3821     return compiled
3822
3823
3824 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3825     """Like `reversed(enumerate(sequence))` if that were possible."""
3826     index = len(sequence) - 1
3827     for element in reversed(sequence):
3828         yield (index, element)
3829         index -= 1
3830
3831
3832 def enumerate_with_length(
3833     line: Line, reversed: bool = False
3834 ) -> Iterator[Tuple[Index, Leaf, int]]:
3835     """Return an enumeration of leaves with their length.
3836
3837     Stops prematurely on multiline strings and standalone comments.
3838     """
3839     op = cast(
3840         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3841         enumerate_reversed if reversed else enumerate,
3842     )
3843     for index, leaf in op(line.leaves):
3844         length = len(leaf.prefix) + len(leaf.value)
3845         if "\n" in leaf.value:
3846             return  # Multiline strings, we can't continue.
3847
3848         for comment in line.comments_after(leaf):
3849             length += len(comment.value)
3850
3851         yield index, leaf, length
3852
3853
3854 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3855     """Return True if `line` is no longer than `line_length`.
3856
3857     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3858     """
3859     if not line_str:
3860         line_str = str(line).strip("\n")
3861     return (
3862         len(line_str) <= line_length
3863         and "\n" not in line_str  # multiline strings
3864         and not line.contains_standalone_comments()
3865     )
3866
3867
3868 def can_be_split(line: Line) -> bool:
3869     """Return False if the line cannot be split *for sure*.
3870
3871     This is not an exhaustive search but a cheap heuristic that we can use to
3872     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3873     in unnecessary parentheses).
3874     """
3875     leaves = line.leaves
3876     if len(leaves) < 2:
3877         return False
3878
3879     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3880         call_count = 0
3881         dot_count = 0
3882         next = leaves[-1]
3883         for leaf in leaves[-2::-1]:
3884             if leaf.type in OPENING_BRACKETS:
3885                 if next.type not in CLOSING_BRACKETS:
3886                     return False
3887
3888                 call_count += 1
3889             elif leaf.type == token.DOT:
3890                 dot_count += 1
3891             elif leaf.type == token.NAME:
3892                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3893                     return False
3894
3895             elif leaf.type not in CLOSING_BRACKETS:
3896                 return False
3897
3898             if dot_count > 1 and call_count > 1:
3899                 return False
3900
3901     return True
3902
3903
3904 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3905     """Does `line` have a shape safe to reformat without optional parens around it?
3906
3907     Returns True for only a subset of potentially nice looking formattings but
3908     the point is to not return false positives that end up producing lines that
3909     are too long.
3910     """
3911     bt = line.bracket_tracker
3912     if not bt.delimiters:
3913         # Without delimiters the optional parentheses are useless.
3914         return True
3915
3916     max_priority = bt.max_delimiter_priority()
3917     if bt.delimiter_count_with_priority(max_priority) > 1:
3918         # With more than one delimiter of a kind the optional parentheses read better.
3919         return False
3920
3921     if max_priority == DOT_PRIORITY:
3922         # A single stranded method call doesn't require optional parentheses.
3923         return True
3924
3925     assert len(line.leaves) >= 2, "Stranded delimiter"
3926
3927     first = line.leaves[0]
3928     second = line.leaves[1]
3929     penultimate = line.leaves[-2]
3930     last = line.leaves[-1]
3931
3932     # With a single delimiter, omit if the expression starts or ends with
3933     # a bracket.
3934     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3935         remainder = False
3936         length = 4 * line.depth
3937         for _index, leaf, leaf_length in enumerate_with_length(line):
3938             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3939                 remainder = True
3940             if remainder:
3941                 length += leaf_length
3942                 if length > line_length:
3943                     break
3944
3945                 if leaf.type in OPENING_BRACKETS:
3946                     # There are brackets we can further split on.
3947                     remainder = False
3948
3949         else:
3950             # checked the entire string and line length wasn't exceeded
3951             if len(line.leaves) == _index + 1:
3952                 return True
3953
3954         # Note: we are not returning False here because a line might have *both*
3955         # a leading opening bracket and a trailing closing bracket.  If the
3956         # opening bracket doesn't match our rule, maybe the closing will.
3957
3958     if (
3959         last.type == token.RPAR
3960         or last.type == token.RBRACE
3961         or (
3962             # don't use indexing for omitting optional parentheses;
3963             # it looks weird
3964             last.type == token.RSQB
3965             and last.parent
3966             and last.parent.type != syms.trailer
3967         )
3968     ):
3969         if penultimate.type in OPENING_BRACKETS:
3970             # Empty brackets don't help.
3971             return False
3972
3973         if is_multiline_string(first):
3974             # Additional wrapping of a multiline string in this situation is
3975             # unnecessary.
3976             return True
3977
3978         length = 4 * line.depth
3979         seen_other_brackets = False
3980         for _index, leaf, leaf_length in enumerate_with_length(line):
3981             length += leaf_length
3982             if leaf is last.opening_bracket:
3983                 if seen_other_brackets or length <= line_length:
3984                     return True
3985
3986             elif leaf.type in OPENING_BRACKETS:
3987                 # There are brackets we can further split on.
3988                 seen_other_brackets = True
3989
3990     return False
3991
3992
3993 def get_cache_file(mode: FileMode) -> Path:
3994     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3995
3996
3997 def read_cache(mode: FileMode) -> Cache:
3998     """Read the cache if it exists and is well formed.
3999
4000     If it is not well formed, the call to write_cache later should resolve the issue.
4001     """
4002     cache_file = get_cache_file(mode)
4003     if not cache_file.exists():
4004         return {}
4005
4006     with cache_file.open("rb") as fobj:
4007         try:
4008             cache: Cache = pickle.load(fobj)
4009         except (pickle.UnpicklingError, ValueError):
4010             return {}
4011
4012     return cache
4013
4014
4015 def get_cache_info(path: Path) -> CacheInfo:
4016     """Return the information used to check if a file is already formatted or not."""
4017     stat = path.stat()
4018     return stat.st_mtime, stat.st_size
4019
4020
4021 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4022     """Split an iterable of paths in `sources` into two sets.
4023
4024     The first contains paths of files that modified on disk or are not in the
4025     cache. The other contains paths to non-modified files.
4026     """
4027     todo, done = set(), set()
4028     for src in sources:
4029         src = src.resolve()
4030         if cache.get(src) != get_cache_info(src):
4031             todo.add(src)
4032         else:
4033             done.add(src)
4034     return todo, done
4035
4036
4037 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4038     """Update the cache file."""
4039     cache_file = get_cache_file(mode)
4040     try:
4041         CACHE_DIR.mkdir(parents=True, exist_ok=True)
4042         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4043         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4044             pickle.dump(new_cache, f, protocol=4)
4045         os.replace(f.name, cache_file)
4046     except OSError:
4047         pass
4048
4049
4050 def patch_click() -> None:
4051     """Make Click not crash.
4052
4053     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4054     default which restricts paths that it can access during the lifetime of the
4055     application.  Click refuses to work in this scenario by raising a RuntimeError.
4056
4057     In case of Black the likelihood that non-ASCII characters are going to be used in
4058     file paths is minimal since it's Python source code.  Moreover, this crash was
4059     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4060     """
4061     try:
4062         from click import core
4063         from click import _unicodefun  # type: ignore
4064     except ModuleNotFoundError:
4065         return
4066
4067     for module in (core, _unicodefun):
4068         if hasattr(module, "_verify_python3_env"):
4069             module._verify_python3_env = lambda: None
4070
4071
4072 def patched_main() -> None:
4073     freeze_support()
4074     patch_click()
4075     main()
4076
4077
4078 if __name__ == "__main__":
4079     patched_main()