]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

ff373c863d45b71849cb0ae0b1eb800af54c10a7
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import regex as re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40
41 from appdirs import user_cache_dir
42 from attr import dataclass, evolve, Factory
43 import click
44 import toml
45 from typed_ast import ast3, ast27
46
47 # lib2to3 fork
48 from blib2to3.pytree import Node, Leaf, type_repr
49 from blib2to3 import pygram, pytree
50 from blib2to3.pgen2 import driver, token
51 from blib2to3.pgen2.grammar import Grammar
52 from blib2to3.pgen2.parse import ParseError
53
54 from _version import version as __version__
55
56 DEFAULT_LINE_LENGTH = 88
57 DEFAULT_EXCLUDES = (
58     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
59 )
60 DEFAULT_INCLUDES = r"\.pyi?$"
61 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
62
63
64 # types
65 FileContent = str
66 Encoding = str
67 NewLine = str
68 Depth = int
69 NodeType = int
70 LeafID = int
71 Priority = int
72 Index = int
73 LN = Union[Leaf, Node]
74 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
75 Timestamp = float
76 FileSize = int
77 CacheInfo = Tuple[Timestamp, FileSize]
78 Cache = Dict[Path, CacheInfo]
79 out = partial(click.secho, bold=True, err=True)
80 err = partial(click.secho, fg="red", err=True)
81
82 pygram.initialize(CACHE_DIR)
83 syms = pygram.python_symbols
84
85
86 class NothingChanged(UserWarning):
87     """Raised when reformatted code is the same as source."""
88
89
90 class CannotSplit(Exception):
91     """A readable split that fits the allotted line length is impossible."""
92
93
94 class InvalidInput(ValueError):
95     """Raised when input source code fails all parse attempts."""
96
97
98 class WriteBack(Enum):
99     NO = 0
100     YES = 1
101     DIFF = 2
102     CHECK = 3
103
104     @classmethod
105     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
106         if check and not diff:
107             return cls.CHECK
108
109         return cls.DIFF if diff else cls.YES
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 class TargetVersion(Enum):
119     PY27 = 2
120     PY33 = 3
121     PY34 = 4
122     PY35 = 5
123     PY36 = 6
124     PY37 = 7
125     PY38 = 8
126
127     def is_python2(self) -> bool:
128         return self is TargetVersion.PY27
129
130
131 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
132
133
134 class Feature(Enum):
135     # All string literals are unicode
136     UNICODE_LITERALS = 1
137     F_STRINGS = 2
138     NUMERIC_UNDERSCORES = 3
139     TRAILING_COMMA_IN_CALL = 4
140     TRAILING_COMMA_IN_DEF = 5
141     # The following two feature-flags are mutually exclusive, and exactly one should be
142     # set for every version of python.
143     ASYNC_IDENTIFIERS = 6
144     ASYNC_KEYWORDS = 7
145     ASSIGNMENT_EXPRESSIONS = 8
146     POS_ONLY_ARGUMENTS = 9
147
148
149 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
150     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
151     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
152     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
153     TargetVersion.PY35: {
154         Feature.UNICODE_LITERALS,
155         Feature.TRAILING_COMMA_IN_CALL,
156         Feature.ASYNC_IDENTIFIERS,
157     },
158     TargetVersion.PY36: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA_IN_CALL,
163         Feature.TRAILING_COMMA_IN_DEF,
164         Feature.ASYNC_IDENTIFIERS,
165     },
166     TargetVersion.PY37: {
167         Feature.UNICODE_LITERALS,
168         Feature.F_STRINGS,
169         Feature.NUMERIC_UNDERSCORES,
170         Feature.TRAILING_COMMA_IN_CALL,
171         Feature.TRAILING_COMMA_IN_DEF,
172         Feature.ASYNC_KEYWORDS,
173     },
174     TargetVersion.PY38: {
175         Feature.UNICODE_LITERALS,
176         Feature.F_STRINGS,
177         Feature.NUMERIC_UNDERSCORES,
178         Feature.TRAILING_COMMA_IN_CALL,
179         Feature.TRAILING_COMMA_IN_DEF,
180         Feature.ASYNC_KEYWORDS,
181         Feature.ASSIGNMENT_EXPRESSIONS,
182         Feature.POS_ONLY_ARGUMENTS,
183     },
184 }
185
186
187 @dataclass
188 class FileMode:
189     target_versions: Set[TargetVersion] = Factory(set)
190     line_length: int = DEFAULT_LINE_LENGTH
191     string_normalization: bool = True
192     is_pyi: bool = False
193
194     def get_cache_key(self) -> str:
195         if self.target_versions:
196             version_str = ",".join(
197                 str(version.value)
198                 for version in sorted(self.target_versions, key=lambda v: v.value)
199             )
200         else:
201             version_str = "-"
202         parts = [
203             version_str,
204             str(self.line_length),
205             str(int(self.string_normalization)),
206             str(int(self.is_pyi)),
207         ]
208         return ".".join(parts)
209
210
211 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
212     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
213
214
215 def read_pyproject_toml(
216     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
217 ) -> Optional[str]:
218     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
219
220     Returns the path to a successfully found and read configuration file, None
221     otherwise.
222     """
223     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
224     if not value:
225         root = find_project_root(ctx.params.get("src", ()))
226         path = root / "pyproject.toml"
227         if path.is_file():
228             value = str(path)
229         else:
230             return None
231
232     try:
233         pyproject_toml = toml.load(value)
234         config = pyproject_toml.get("tool", {}).get("black", {})
235     except (toml.TomlDecodeError, OSError) as e:
236         raise click.FileError(
237             filename=value, hint=f"Error reading configuration file: {e}"
238         )
239
240     if not config:
241         return None
242
243     if ctx.default_map is None:
244         ctx.default_map = {}
245     ctx.default_map.update(  # type: ignore  # bad types in .pyi
246         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
247     )
248     return value
249
250
251 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
252 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
253 @click.option(
254     "-l",
255     "--line-length",
256     type=int,
257     default=DEFAULT_LINE_LENGTH,
258     help="How many characters per line to allow.",
259     show_default=True,
260 )
261 @click.option(
262     "-t",
263     "--target-version",
264     type=click.Choice([v.name.lower() for v in TargetVersion]),
265     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
266     multiple=True,
267     help=(
268         "Python versions that should be supported by Black's output. [default: "
269         "per-file auto-detection]"
270     ),
271 )
272 @click.option(
273     "--py36",
274     is_flag=True,
275     help=(
276         "Allow using Python 3.6-only syntax on all input files.  This will put "
277         "trailing commas in function signatures and calls also after *args and "
278         "**kwargs. Deprecated; use --target-version instead. "
279         "[default: per-file auto-detection]"
280     ),
281 )
282 @click.option(
283     "--pyi",
284     is_flag=True,
285     help=(
286         "Format all input files like typing stubs regardless of file extension "
287         "(useful when piping source on standard input)."
288     ),
289 )
290 @click.option(
291     "-S",
292     "--skip-string-normalization",
293     is_flag=True,
294     help="Don't normalize string quotes or prefixes.",
295 )
296 @click.option(
297     "--check",
298     is_flag=True,
299     help=(
300         "Don't write the files back, just return the status.  Return code 0 "
301         "means nothing would change.  Return code 1 means some files would be "
302         "reformatted.  Return code 123 means there was an internal error."
303     ),
304 )
305 @click.option(
306     "--diff",
307     is_flag=True,
308     help="Don't write the files back, just output a diff for each file on stdout.",
309 )
310 @click.option(
311     "--fast/--safe",
312     is_flag=True,
313     help="If --fast given, skip temporary sanity checks. [default: --safe]",
314 )
315 @click.option(
316     "--include",
317     type=str,
318     default=DEFAULT_INCLUDES,
319     help=(
320         "A regular expression that matches files and directories that should be "
321         "included on recursive searches.  An empty value means all files are "
322         "included regardless of the name.  Use forward slashes for directories on "
323         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
324         "later."
325     ),
326     show_default=True,
327 )
328 @click.option(
329     "--exclude",
330     type=str,
331     default=DEFAULT_EXCLUDES,
332     help=(
333         "A regular expression that matches files and directories that should be "
334         "excluded on recursive searches.  An empty value means no paths are excluded. "
335         "Use forward slashes for directories on all platforms (Windows, too).  "
336         "Exclusions are calculated first, inclusions later."
337     ),
338     show_default=True,
339 )
340 @click.option(
341     "-q",
342     "--quiet",
343     is_flag=True,
344     help=(
345         "Don't emit non-error messages to stderr. Errors are still emitted; "
346         "silence those with 2>/dev/null."
347     ),
348 )
349 @click.option(
350     "-v",
351     "--verbose",
352     is_flag=True,
353     help=(
354         "Also emit messages to stderr about files that were not changed or were "
355         "ignored due to --exclude=."
356     ),
357 )
358 @click.version_option(version=__version__)
359 @click.argument(
360     "src",
361     nargs=-1,
362     type=click.Path(
363         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
364     ),
365     is_eager=True,
366 )
367 @click.option(
368     "--config",
369     type=click.Path(
370         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
371     ),
372     is_eager=True,
373     callback=read_pyproject_toml,
374     help="Read configuration from PATH.",
375 )
376 @click.pass_context
377 def main(
378     ctx: click.Context,
379     code: Optional[str],
380     line_length: int,
381     target_version: List[TargetVersion],
382     check: bool,
383     diff: bool,
384     fast: bool,
385     pyi: bool,
386     py36: bool,
387     skip_string_normalization: bool,
388     quiet: bool,
389     verbose: bool,
390     include: str,
391     exclude: str,
392     src: Tuple[str],
393     config: Optional[str],
394 ) -> None:
395     """The uncompromising code formatter."""
396     write_back = WriteBack.from_configuration(check=check, diff=diff)
397     if target_version:
398         if py36:
399             err(f"Cannot use both --target-version and --py36")
400             ctx.exit(2)
401         else:
402             versions = set(target_version)
403     elif py36:
404         err(
405             "--py36 is deprecated and will be removed in a future version. "
406             "Use --target-version py36 instead."
407         )
408         versions = PY36_VERSIONS
409     else:
410         # We'll autodetect later.
411         versions = set()
412     mode = FileMode(
413         target_versions=versions,
414         line_length=line_length,
415         is_pyi=pyi,
416         string_normalization=not skip_string_normalization,
417     )
418     if config and verbose:
419         out(f"Using configuration from {config}.", bold=False, fg="blue")
420     if code is not None:
421         print(format_str(code, mode=mode))
422         ctx.exit(0)
423     try:
424         include_regex = re_compile_maybe_verbose(include)
425     except re.error:
426         err(f"Invalid regular expression for include given: {include!r}")
427         ctx.exit(2)
428     try:
429         exclude_regex = re_compile_maybe_verbose(exclude)
430     except re.error:
431         err(f"Invalid regular expression for exclude given: {exclude!r}")
432         ctx.exit(2)
433     report = Report(check=check, quiet=quiet, verbose=verbose)
434     root = find_project_root(src)
435     sources: Set[Path] = set()
436     path_empty(src, quiet, verbose, ctx)
437     for s in src:
438         p = Path(s)
439         if p.is_dir():
440             sources.update(
441                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
442             )
443         elif p.is_file() or s == "-":
444             # if a file was explicitly given, we don't care about its extension
445             sources.add(p)
446         else:
447             err(f"invalid path: {s}")
448     if len(sources) == 0:
449         if verbose or not quiet:
450             out("No Python files are present to be formatted. Nothing to do 😴")
451         ctx.exit(0)
452
453     if len(sources) == 1:
454         reformat_one(
455             src=sources.pop(),
456             fast=fast,
457             write_back=write_back,
458             mode=mode,
459             report=report,
460         )
461     else:
462         reformat_many(
463             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
464         )
465
466     if verbose or not quiet:
467         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
468         click.secho(str(report), err=True)
469     ctx.exit(report.return_code)
470
471
472 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
473     """
474     Exit if there is no `src` provided for formatting
475     """
476     if not src:
477         if verbose or not quiet:
478             out("No Path provided. Nothing to do 😴")
479             ctx.exit(0)
480
481
482 def reformat_one(
483     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
484 ) -> None:
485     """Reformat a single file under `src` without spawning child processes.
486
487     `fast`, `write_back`, and `mode` options are passed to
488     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
489     """
490     try:
491         changed = Changed.NO
492         if not src.is_file() and str(src) == "-":
493             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
494                 changed = Changed.YES
495         else:
496             cache: Cache = {}
497             if write_back != WriteBack.DIFF:
498                 cache = read_cache(mode)
499                 res_src = src.resolve()
500                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
501                     changed = Changed.CACHED
502             if changed is not Changed.CACHED and format_file_in_place(
503                 src, fast=fast, write_back=write_back, mode=mode
504             ):
505                 changed = Changed.YES
506             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
507                 write_back is WriteBack.CHECK and changed is Changed.NO
508             ):
509                 write_cache(cache, [src], mode)
510         report.done(src, changed)
511     except Exception as exc:
512         report.failed(src, str(exc))
513
514
515 def reformat_many(
516     sources: Set[Path],
517     fast: bool,
518     write_back: WriteBack,
519     mode: FileMode,
520     report: "Report",
521 ) -> None:
522     """Reformat multiple files using a ProcessPoolExecutor."""
523     loop = asyncio.get_event_loop()
524     worker_count = os.cpu_count()
525     if sys.platform == "win32":
526         # Work around https://bugs.python.org/issue26903
527         worker_count = min(worker_count, 61)
528     executor = ProcessPoolExecutor(max_workers=worker_count)
529     try:
530         loop.run_until_complete(
531             schedule_formatting(
532                 sources=sources,
533                 fast=fast,
534                 write_back=write_back,
535                 mode=mode,
536                 report=report,
537                 loop=loop,
538                 executor=executor,
539             )
540         )
541     finally:
542         shutdown(loop)
543         executor.shutdown()
544
545
546 async def schedule_formatting(
547     sources: Set[Path],
548     fast: bool,
549     write_back: WriteBack,
550     mode: FileMode,
551     report: "Report",
552     loop: asyncio.AbstractEventLoop,
553     executor: Executor,
554 ) -> None:
555     """Run formatting of `sources` in parallel using the provided `executor`.
556
557     (Use ProcessPoolExecutors for actual parallelism.)
558
559     `write_back`, `fast`, and `mode` options are passed to
560     :func:`format_file_in_place`.
561     """
562     cache: Cache = {}
563     if write_back != WriteBack.DIFF:
564         cache = read_cache(mode)
565         sources, cached = filter_cached(cache, sources)
566         for src in sorted(cached):
567             report.done(src, Changed.CACHED)
568     if not sources:
569         return
570
571     cancelled = []
572     sources_to_cache = []
573     lock = None
574     if write_back == WriteBack.DIFF:
575         # For diff output, we need locks to ensure we don't interleave output
576         # from different processes.
577         manager = Manager()
578         lock = manager.Lock()
579     tasks = {
580         asyncio.ensure_future(
581             loop.run_in_executor(
582                 executor, format_file_in_place, src, fast, mode, write_back, lock
583             )
584         ): src
585         for src in sorted(sources)
586     }
587     pending: Iterable[asyncio.Future] = tasks.keys()
588     try:
589         loop.add_signal_handler(signal.SIGINT, cancel, pending)
590         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
591     except NotImplementedError:
592         # There are no good alternatives for these on Windows.
593         pass
594     while pending:
595         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
596         for task in done:
597             src = tasks.pop(task)
598             if task.cancelled():
599                 cancelled.append(task)
600             elif task.exception():
601                 report.failed(src, str(task.exception()))
602             else:
603                 changed = Changed.YES if task.result() else Changed.NO
604                 # If the file was written back or was successfully checked as
605                 # well-formatted, store this information in the cache.
606                 if write_back is WriteBack.YES or (
607                     write_back is WriteBack.CHECK and changed is Changed.NO
608                 ):
609                     sources_to_cache.append(src)
610                 report.done(src, changed)
611     if cancelled:
612         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
613     if sources_to_cache:
614         write_cache(cache, sources_to_cache, mode)
615
616
617 def format_file_in_place(
618     src: Path,
619     fast: bool,
620     mode: FileMode,
621     write_back: WriteBack = WriteBack.NO,
622     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
623 ) -> bool:
624     """Format file under `src` path. Return True if changed.
625
626     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
627     code to the file.
628     `mode` and `fast` options are passed to :func:`format_file_contents`.
629     """
630     if src.suffix == ".pyi":
631         mode = evolve(mode, is_pyi=True)
632
633     then = datetime.utcfromtimestamp(src.stat().st_mtime)
634     with open(src, "rb") as buf:
635         src_contents, encoding, newline = decode_bytes(buf.read())
636     try:
637         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
638     except NothingChanged:
639         return False
640
641     if write_back == write_back.YES:
642         with open(src, "w", encoding=encoding, newline=newline) as f:
643             f.write(dst_contents)
644     elif write_back == write_back.DIFF:
645         now = datetime.utcnow()
646         src_name = f"{src}\t{then} +0000"
647         dst_name = f"{src}\t{now} +0000"
648         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
649
650         with lock or nullcontext():
651             f = io.TextIOWrapper(
652                 sys.stdout.buffer,
653                 encoding=encoding,
654                 newline=newline,
655                 write_through=True,
656             )
657             f.write(diff_contents)
658             f.detach()
659
660     return True
661
662
663 def format_stdin_to_stdout(
664     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
665 ) -> bool:
666     """Format file on stdin. Return True if changed.
667
668     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
669     write a diff to stdout. The `mode` argument is passed to
670     :func:`format_file_contents`.
671     """
672     then = datetime.utcnow()
673     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
674     dst = src
675     try:
676         dst = format_file_contents(src, fast=fast, mode=mode)
677         return True
678
679     except NothingChanged:
680         return False
681
682     finally:
683         f = io.TextIOWrapper(
684             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
685         )
686         if write_back == WriteBack.YES:
687             f.write(dst)
688         elif write_back == WriteBack.DIFF:
689             now = datetime.utcnow()
690             src_name = f"STDIN\t{then} +0000"
691             dst_name = f"STDOUT\t{now} +0000"
692             f.write(diff(src, dst, src_name, dst_name))
693         f.detach()
694
695
696 def format_file_contents(
697     src_contents: str, *, fast: bool, mode: FileMode
698 ) -> FileContent:
699     """Reformat contents a file and return new contents.
700
701     If `fast` is False, additionally confirm that the reformatted code is
702     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
703     `mode` is passed to :func:`format_str`.
704     """
705     if src_contents.strip() == "":
706         raise NothingChanged
707
708     dst_contents = format_str(src_contents, mode=mode)
709     if src_contents == dst_contents:
710         raise NothingChanged
711
712     if not fast:
713         assert_equivalent(src_contents, dst_contents)
714         assert_stable(src_contents, dst_contents, mode=mode)
715     return dst_contents
716
717
718 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
719     """Reformat a string and return new contents.
720
721     `mode` determines formatting options, such as how many characters per line are
722     allowed.
723     """
724     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
725     dst_contents = []
726     future_imports = get_future_imports(src_node)
727     if mode.target_versions:
728         versions = mode.target_versions
729     else:
730         versions = detect_target_versions(src_node)
731     normalize_fmt_off(src_node)
732     lines = LineGenerator(
733         remove_u_prefix="unicode_literals" in future_imports
734         or supports_feature(versions, Feature.UNICODE_LITERALS),
735         is_pyi=mode.is_pyi,
736         normalize_strings=mode.string_normalization,
737     )
738     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
739     empty_line = Line()
740     after = 0
741     split_line_features = {
742         feature
743         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
744         if supports_feature(versions, feature)
745     }
746     for current_line in lines.visit(src_node):
747         for _ in range(after):
748             dst_contents.append(str(empty_line))
749         before, after = elt.maybe_empty_lines(current_line)
750         for _ in range(before):
751             dst_contents.append(str(empty_line))
752         for line in split_line(
753             current_line, line_length=mode.line_length, features=split_line_features
754         ):
755             dst_contents.append(str(line))
756     return "".join(dst_contents)
757
758
759 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
760     """Return a tuple of (decoded_contents, encoding, newline).
761
762     `newline` is either CRLF or LF but `decoded_contents` is decoded with
763     universal newlines (i.e. only contains LF).
764     """
765     srcbuf = io.BytesIO(src)
766     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
767     if not lines:
768         return "", encoding, "\n"
769
770     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
771     srcbuf.seek(0)
772     with io.TextIOWrapper(srcbuf, encoding) as tiow:
773         return tiow.read(), encoding, newline
774
775
776 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
777     if not target_versions:
778         # No target_version specified, so try all grammars.
779         return [
780             # Python 3.7+
781             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
782             # Python 3.0-3.6
783             pygram.python_grammar_no_print_statement_no_exec_statement,
784             # Python 2.7 with future print_function import
785             pygram.python_grammar_no_print_statement,
786             # Python 2.7
787             pygram.python_grammar,
788         ]
789     elif all(version.is_python2() for version in target_versions):
790         # Python 2-only code, so try Python 2 grammars.
791         return [
792             # Python 2.7 with future print_function import
793             pygram.python_grammar_no_print_statement,
794             # Python 2.7
795             pygram.python_grammar,
796         ]
797     else:
798         # Python 3-compatible code, so only try Python 3 grammar.
799         grammars = []
800         # If we have to parse both, try to parse async as a keyword first
801         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
802             # Python 3.7+
803             grammars.append(
804                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
805             )
806         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
807             # Python 3.0-3.6
808             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
809         # At least one of the above branches must have been taken, because every Python
810         # version has exactly one of the two 'ASYNC_*' flags
811         return grammars
812
813
814 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
815     """Given a string with source, return the lib2to3 Node."""
816     if src_txt[-1:] != "\n":
817         src_txt += "\n"
818
819     for grammar in get_grammars(set(target_versions)):
820         drv = driver.Driver(grammar, pytree.convert)
821         try:
822             result = drv.parse_string(src_txt, True)
823             break
824
825         except ParseError as pe:
826             lineno, column = pe.context[1]
827             lines = src_txt.splitlines()
828             try:
829                 faulty_line = lines[lineno - 1]
830             except IndexError:
831                 faulty_line = "<line number missing in source>"
832             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
833     else:
834         raise exc from None
835
836     if isinstance(result, Leaf):
837         result = Node(syms.file_input, [result])
838     return result
839
840
841 def lib2to3_unparse(node: Node) -> str:
842     """Given a lib2to3 node, return its string representation."""
843     code = str(node)
844     return code
845
846
847 T = TypeVar("T")
848
849
850 class Visitor(Generic[T]):
851     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
852
853     def visit(self, node: LN) -> Iterator[T]:
854         """Main method to visit `node` and its children.
855
856         It tries to find a `visit_*()` method for the given `node.type`, like
857         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
858         If no dedicated `visit_*()` method is found, chooses `visit_default()`
859         instead.
860
861         Then yields objects of type `T` from the selected visitor.
862         """
863         if node.type < 256:
864             name = token.tok_name[node.type]
865         else:
866             name = type_repr(node.type)
867         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
868
869     def visit_default(self, node: LN) -> Iterator[T]:
870         """Default `visit_*()` implementation. Recurses to children of `node`."""
871         if isinstance(node, Node):
872             for child in node.children:
873                 yield from self.visit(child)
874
875
876 @dataclass
877 class DebugVisitor(Visitor[T]):
878     tree_depth: int = 0
879
880     def visit_default(self, node: LN) -> Iterator[T]:
881         indent = " " * (2 * self.tree_depth)
882         if isinstance(node, Node):
883             _type = type_repr(node.type)
884             out(f"{indent}{_type}", fg="yellow")
885             self.tree_depth += 1
886             for child in node.children:
887                 yield from self.visit(child)
888
889             self.tree_depth -= 1
890             out(f"{indent}/{_type}", fg="yellow", bold=False)
891         else:
892             _type = token.tok_name.get(node.type, str(node.type))
893             out(f"{indent}{_type}", fg="blue", nl=False)
894             if node.prefix:
895                 # We don't have to handle prefixes for `Node` objects since
896                 # that delegates to the first child anyway.
897                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
898             out(f" {node.value!r}", fg="blue", bold=False)
899
900     @classmethod
901     def show(cls, code: Union[str, Leaf, Node]) -> None:
902         """Pretty-print the lib2to3 AST of a given string of `code`.
903
904         Convenience method for debugging.
905         """
906         v: DebugVisitor[None] = DebugVisitor()
907         if isinstance(code, str):
908             code = lib2to3_parse(code)
909         list(v.visit(code))
910
911
912 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
913 STATEMENT = {
914     syms.if_stmt,
915     syms.while_stmt,
916     syms.for_stmt,
917     syms.try_stmt,
918     syms.except_clause,
919     syms.with_stmt,
920     syms.funcdef,
921     syms.classdef,
922 }
923 STANDALONE_COMMENT = 153
924 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
925 LOGIC_OPERATORS = {"and", "or"}
926 COMPARATORS = {
927     token.LESS,
928     token.GREATER,
929     token.EQEQUAL,
930     token.NOTEQUAL,
931     token.LESSEQUAL,
932     token.GREATEREQUAL,
933 }
934 MATH_OPERATORS = {
935     token.VBAR,
936     token.CIRCUMFLEX,
937     token.AMPER,
938     token.LEFTSHIFT,
939     token.RIGHTSHIFT,
940     token.PLUS,
941     token.MINUS,
942     token.STAR,
943     token.SLASH,
944     token.DOUBLESLASH,
945     token.PERCENT,
946     token.AT,
947     token.TILDE,
948     token.DOUBLESTAR,
949 }
950 STARS = {token.STAR, token.DOUBLESTAR}
951 VARARGS_SPECIALS = STARS | {token.SLASH}
952 VARARGS_PARENTS = {
953     syms.arglist,
954     syms.argument,  # double star in arglist
955     syms.trailer,  # single argument to call
956     syms.typedargslist,
957     syms.varargslist,  # lambdas
958 }
959 UNPACKING_PARENTS = {
960     syms.atom,  # single element of a list or set literal
961     syms.dictsetmaker,
962     syms.listmaker,
963     syms.testlist_gexp,
964     syms.testlist_star_expr,
965 }
966 TEST_DESCENDANTS = {
967     syms.test,
968     syms.lambdef,
969     syms.or_test,
970     syms.and_test,
971     syms.not_test,
972     syms.comparison,
973     syms.star_expr,
974     syms.expr,
975     syms.xor_expr,
976     syms.and_expr,
977     syms.shift_expr,
978     syms.arith_expr,
979     syms.trailer,
980     syms.term,
981     syms.power,
982 }
983 ASSIGNMENTS = {
984     "=",
985     "+=",
986     "-=",
987     "*=",
988     "@=",
989     "/=",
990     "%=",
991     "&=",
992     "|=",
993     "^=",
994     "<<=",
995     ">>=",
996     "**=",
997     "//=",
998 }
999 COMPREHENSION_PRIORITY = 20
1000 COMMA_PRIORITY = 18
1001 TERNARY_PRIORITY = 16
1002 LOGIC_PRIORITY = 14
1003 STRING_PRIORITY = 12
1004 COMPARATOR_PRIORITY = 10
1005 MATH_PRIORITIES = {
1006     token.VBAR: 9,
1007     token.CIRCUMFLEX: 8,
1008     token.AMPER: 7,
1009     token.LEFTSHIFT: 6,
1010     token.RIGHTSHIFT: 6,
1011     token.PLUS: 5,
1012     token.MINUS: 5,
1013     token.STAR: 4,
1014     token.SLASH: 4,
1015     token.DOUBLESLASH: 4,
1016     token.PERCENT: 4,
1017     token.AT: 4,
1018     token.TILDE: 3,
1019     token.DOUBLESTAR: 2,
1020 }
1021 DOT_PRIORITY = 1
1022
1023
1024 @dataclass
1025 class BracketTracker:
1026     """Keeps track of brackets on a line."""
1027
1028     depth: int = 0
1029     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1030     delimiters: Dict[LeafID, Priority] = Factory(dict)
1031     previous: Optional[Leaf] = None
1032     _for_loop_depths: List[int] = Factory(list)
1033     _lambda_argument_depths: List[int] = Factory(list)
1034
1035     def mark(self, leaf: Leaf) -> None:
1036         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1037
1038         All leaves receive an int `bracket_depth` field that stores how deep
1039         within brackets a given leaf is. 0 means there are no enclosing brackets
1040         that started on this line.
1041
1042         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1043         field that it forms a pair with. This is a one-directional link to
1044         avoid reference cycles.
1045
1046         If a leaf is a delimiter (a token on which Black can split the line if
1047         needed) and it's on depth 0, its `id()` is stored in the tracker's
1048         `delimiters` field.
1049         """
1050         if leaf.type == token.COMMENT:
1051             return
1052
1053         self.maybe_decrement_after_for_loop_variable(leaf)
1054         self.maybe_decrement_after_lambda_arguments(leaf)
1055         if leaf.type in CLOSING_BRACKETS:
1056             self.depth -= 1
1057             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1058             leaf.opening_bracket = opening_bracket
1059         leaf.bracket_depth = self.depth
1060         if self.depth == 0:
1061             delim = is_split_before_delimiter(leaf, self.previous)
1062             if delim and self.previous is not None:
1063                 self.delimiters[id(self.previous)] = delim
1064             else:
1065                 delim = is_split_after_delimiter(leaf, self.previous)
1066                 if delim:
1067                     self.delimiters[id(leaf)] = delim
1068         if leaf.type in OPENING_BRACKETS:
1069             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1070             self.depth += 1
1071         self.previous = leaf
1072         self.maybe_increment_lambda_arguments(leaf)
1073         self.maybe_increment_for_loop_variable(leaf)
1074
1075     def any_open_brackets(self) -> bool:
1076         """Return True if there is an yet unmatched open bracket on the line."""
1077         return bool(self.bracket_match)
1078
1079     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1080         """Return the highest priority of a delimiter found on the line.
1081
1082         Values are consistent with what `is_split_*_delimiter()` return.
1083         Raises ValueError on no delimiters.
1084         """
1085         return max(v for k, v in self.delimiters.items() if k not in exclude)
1086
1087     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1088         """Return the number of delimiters with the given `priority`.
1089
1090         If no `priority` is passed, defaults to max priority on the line.
1091         """
1092         if not self.delimiters:
1093             return 0
1094
1095         priority = priority or self.max_delimiter_priority()
1096         return sum(1 for p in self.delimiters.values() if p == priority)
1097
1098     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1099         """In a for loop, or comprehension, the variables are often unpacks.
1100
1101         To avoid splitting on the comma in this situation, increase the depth of
1102         tokens between `for` and `in`.
1103         """
1104         if leaf.type == token.NAME and leaf.value == "for":
1105             self.depth += 1
1106             self._for_loop_depths.append(self.depth)
1107             return True
1108
1109         return False
1110
1111     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1112         """See `maybe_increment_for_loop_variable` above for explanation."""
1113         if (
1114             self._for_loop_depths
1115             and self._for_loop_depths[-1] == self.depth
1116             and leaf.type == token.NAME
1117             and leaf.value == "in"
1118         ):
1119             self.depth -= 1
1120             self._for_loop_depths.pop()
1121             return True
1122
1123         return False
1124
1125     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1126         """In a lambda expression, there might be more than one argument.
1127
1128         To avoid splitting on the comma in this situation, increase the depth of
1129         tokens between `lambda` and `:`.
1130         """
1131         if leaf.type == token.NAME and leaf.value == "lambda":
1132             self.depth += 1
1133             self._lambda_argument_depths.append(self.depth)
1134             return True
1135
1136         return False
1137
1138     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1139         """See `maybe_increment_lambda_arguments` above for explanation."""
1140         if (
1141             self._lambda_argument_depths
1142             and self._lambda_argument_depths[-1] == self.depth
1143             and leaf.type == token.COLON
1144         ):
1145             self.depth -= 1
1146             self._lambda_argument_depths.pop()
1147             return True
1148
1149         return False
1150
1151     def get_open_lsqb(self) -> Optional[Leaf]:
1152         """Return the most recent opening square bracket (if any)."""
1153         return self.bracket_match.get((self.depth - 1, token.RSQB))
1154
1155
1156 @dataclass
1157 class Line:
1158     """Holds leaves and comments. Can be printed with `str(line)`."""
1159
1160     depth: int = 0
1161     leaves: List[Leaf] = Factory(list)
1162     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1163     bracket_tracker: BracketTracker = Factory(BracketTracker)
1164     inside_brackets: bool = False
1165     should_explode: bool = False
1166
1167     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1168         """Add a new `leaf` to the end of the line.
1169
1170         Unless `preformatted` is True, the `leaf` will receive a new consistent
1171         whitespace prefix and metadata applied by :class:`BracketTracker`.
1172         Trailing commas are maybe removed, unpacked for loop variables are
1173         demoted from being delimiters.
1174
1175         Inline comments are put aside.
1176         """
1177         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1178         if not has_value:
1179             return
1180
1181         if token.COLON == leaf.type and self.is_class_paren_empty:
1182             del self.leaves[-2:]
1183         if self.leaves and not preformatted:
1184             # Note: at this point leaf.prefix should be empty except for
1185             # imports, for which we only preserve newlines.
1186             leaf.prefix += whitespace(
1187                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1188             )
1189         if self.inside_brackets or not preformatted:
1190             self.bracket_tracker.mark(leaf)
1191             self.maybe_remove_trailing_comma(leaf)
1192         if not self.append_comment(leaf):
1193             self.leaves.append(leaf)
1194
1195     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1196         """Like :func:`append()` but disallow invalid standalone comment structure.
1197
1198         Raises ValueError when any `leaf` is appended after a standalone comment
1199         or when a standalone comment is not the first leaf on the line.
1200         """
1201         if self.bracket_tracker.depth == 0:
1202             if self.is_comment:
1203                 raise ValueError("cannot append to standalone comments")
1204
1205             if self.leaves and leaf.type == STANDALONE_COMMENT:
1206                 raise ValueError(
1207                     "cannot append standalone comments to a populated line"
1208                 )
1209
1210         self.append(leaf, preformatted=preformatted)
1211
1212     @property
1213     def is_comment(self) -> bool:
1214         """Is this line a standalone comment?"""
1215         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1216
1217     @property
1218     def is_decorator(self) -> bool:
1219         """Is this line a decorator?"""
1220         return bool(self) and self.leaves[0].type == token.AT
1221
1222     @property
1223     def is_import(self) -> bool:
1224         """Is this an import line?"""
1225         return bool(self) and is_import(self.leaves[0])
1226
1227     @property
1228     def is_class(self) -> bool:
1229         """Is this line a class definition?"""
1230         return (
1231             bool(self)
1232             and self.leaves[0].type == token.NAME
1233             and self.leaves[0].value == "class"
1234         )
1235
1236     @property
1237     def is_stub_class(self) -> bool:
1238         """Is this line a class definition with a body consisting only of "..."?"""
1239         return self.is_class and self.leaves[-3:] == [
1240             Leaf(token.DOT, ".") for _ in range(3)
1241         ]
1242
1243     @property
1244     def is_def(self) -> bool:
1245         """Is this a function definition? (Also returns True for async defs.)"""
1246         try:
1247             first_leaf = self.leaves[0]
1248         except IndexError:
1249             return False
1250
1251         try:
1252             second_leaf: Optional[Leaf] = self.leaves[1]
1253         except IndexError:
1254             second_leaf = None
1255         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1256             first_leaf.type == token.ASYNC
1257             and second_leaf is not None
1258             and second_leaf.type == token.NAME
1259             and second_leaf.value == "def"
1260         )
1261
1262     @property
1263     def is_class_paren_empty(self) -> bool:
1264         """Is this a class with no base classes but using parentheses?
1265
1266         Those are unnecessary and should be removed.
1267         """
1268         return (
1269             bool(self)
1270             and len(self.leaves) == 4
1271             and self.is_class
1272             and self.leaves[2].type == token.LPAR
1273             and self.leaves[2].value == "("
1274             and self.leaves[3].type == token.RPAR
1275             and self.leaves[3].value == ")"
1276         )
1277
1278     @property
1279     def is_triple_quoted_string(self) -> bool:
1280         """Is the line a triple quoted string?"""
1281         return (
1282             bool(self)
1283             and self.leaves[0].type == token.STRING
1284             and self.leaves[0].value.startswith(('"""', "'''"))
1285         )
1286
1287     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1288         """If so, needs to be split before emitting."""
1289         for leaf in self.leaves:
1290             if leaf.type == STANDALONE_COMMENT:
1291                 if leaf.bracket_depth <= depth_limit:
1292                     return True
1293         return False
1294
1295     def contains_uncollapsable_type_comments(self) -> bool:
1296         ignored_ids = set()
1297         try:
1298             last_leaf = self.leaves[-1]
1299             ignored_ids.add(id(last_leaf))
1300             if last_leaf.type == token.COMMA or (
1301                 last_leaf.type == token.RPAR and not last_leaf.value
1302             ):
1303                 # When trailing commas or optional parens are inserted by Black for
1304                 # consistency, comments after the previous last element are not moved
1305                 # (they don't have to, rendering will still be correct).  So we ignore
1306                 # trailing commas and invisible.
1307                 last_leaf = self.leaves[-2]
1308                 ignored_ids.add(id(last_leaf))
1309         except IndexError:
1310             return False
1311
1312         # A type comment is uncollapsable if it is attached to a leaf
1313         # that isn't at the end of the line (since that could cause it
1314         # to get associated to a different argument) or if there are
1315         # comments before it (since that could cause it to get hidden
1316         # behind a comment.
1317         comment_seen = False
1318         for leaf_id, comments in self.comments.items():
1319             for comment in comments:
1320                 if is_type_comment(comment):
1321                     if leaf_id not in ignored_ids or comment_seen:
1322                         return True
1323
1324                 comment_seen = True
1325
1326         return False
1327
1328     def contains_unsplittable_type_ignore(self) -> bool:
1329         if not self.leaves:
1330             return False
1331
1332         # If a 'type: ignore' is attached to the end of a line, we
1333         # can't split the line, because we can't know which of the
1334         # subexpressions the ignore was meant to apply to.
1335         #
1336         # We only want this to apply to actual physical lines from the
1337         # original source, though: we don't want the presence of a
1338         # 'type: ignore' at the end of a multiline expression to
1339         # justify pushing it all onto one line. Thus we
1340         # (unfortunately) need to check the actual source lines and
1341         # only report an unsplittable 'type: ignore' if this line was
1342         # one line in the original code.
1343
1344         # Like in the type comment check above, we need to skip a black added
1345         # trailing comma or invisible paren, since it will be the original leaf
1346         # before it that has the original line number.
1347         last_idx = -1
1348         last_leaf = self.leaves[-1]
1349         if len(self.leaves) > 2 and (
1350             last_leaf.type == token.COMMA
1351             or (last_leaf.type == token.RPAR and not last_leaf.value)
1352         ):
1353             last_idx = -2
1354
1355         if self.leaves[0].lineno == self.leaves[last_idx].lineno:
1356             for node in self.leaves[last_idx:]:
1357                 for comment in self.comments.get(id(node), []):
1358                     if is_type_comment(comment, " ignore"):
1359                         return True
1360
1361         return False
1362
1363     def contains_multiline_strings(self) -> bool:
1364         for leaf in self.leaves:
1365             if is_multiline_string(leaf):
1366                 return True
1367
1368         return False
1369
1370     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1371         """Remove trailing comma if there is one and it's safe."""
1372         if not (
1373             self.leaves
1374             and self.leaves[-1].type == token.COMMA
1375             and closing.type in CLOSING_BRACKETS
1376         ):
1377             return False
1378
1379         if closing.type == token.RBRACE:
1380             self.remove_trailing_comma()
1381             return True
1382
1383         if closing.type == token.RSQB:
1384             comma = self.leaves[-1]
1385             if comma.parent and comma.parent.type == syms.listmaker:
1386                 self.remove_trailing_comma()
1387                 return True
1388
1389         # For parens let's check if it's safe to remove the comma.
1390         # Imports are always safe.
1391         if self.is_import:
1392             self.remove_trailing_comma()
1393             return True
1394
1395         # Otherwise, if the trailing one is the only one, we might mistakenly
1396         # change a tuple into a different type by removing the comma.
1397         depth = closing.bracket_depth + 1
1398         commas = 0
1399         opening = closing.opening_bracket
1400         for _opening_index, leaf in enumerate(self.leaves):
1401             if leaf is opening:
1402                 break
1403
1404         else:
1405             return False
1406
1407         for leaf in self.leaves[_opening_index + 1 :]:
1408             if leaf is closing:
1409                 break
1410
1411             bracket_depth = leaf.bracket_depth
1412             if bracket_depth == depth and leaf.type == token.COMMA:
1413                 commas += 1
1414                 if leaf.parent and leaf.parent.type in {
1415                     syms.arglist,
1416                     syms.typedargslist,
1417                 }:
1418                     commas += 1
1419                     break
1420
1421         if commas > 1:
1422             self.remove_trailing_comma()
1423             return True
1424
1425         return False
1426
1427     def append_comment(self, comment: Leaf) -> bool:
1428         """Add an inline or standalone comment to the line."""
1429         if (
1430             comment.type == STANDALONE_COMMENT
1431             and self.bracket_tracker.any_open_brackets()
1432         ):
1433             comment.prefix = ""
1434             return False
1435
1436         if comment.type != token.COMMENT:
1437             return False
1438
1439         if not self.leaves:
1440             comment.type = STANDALONE_COMMENT
1441             comment.prefix = ""
1442             return False
1443
1444         last_leaf = self.leaves[-1]
1445         if (
1446             last_leaf.type == token.RPAR
1447             and not last_leaf.value
1448             and last_leaf.parent
1449             and len(list(last_leaf.parent.leaves())) <= 3
1450             and not is_type_comment(comment)
1451         ):
1452             # Comments on an optional parens wrapping a single leaf should belong to
1453             # the wrapped node except if it's a type comment. Pinning the comment like
1454             # this avoids unstable formatting caused by comment migration.
1455             if len(self.leaves) < 2:
1456                 comment.type = STANDALONE_COMMENT
1457                 comment.prefix = ""
1458                 return False
1459             last_leaf = self.leaves[-2]
1460         self.comments.setdefault(id(last_leaf), []).append(comment)
1461         return True
1462
1463     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1464         """Generate comments that should appear directly after `leaf`."""
1465         return self.comments.get(id(leaf), [])
1466
1467     def remove_trailing_comma(self) -> None:
1468         """Remove the trailing comma and moves the comments attached to it."""
1469         trailing_comma = self.leaves.pop()
1470         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1471         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1472             trailing_comma_comments
1473         )
1474
1475     def is_complex_subscript(self, leaf: Leaf) -> bool:
1476         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1477         open_lsqb = self.bracket_tracker.get_open_lsqb()
1478         if open_lsqb is None:
1479             return False
1480
1481         subscript_start = open_lsqb.next_sibling
1482
1483         if isinstance(subscript_start, Node):
1484             if subscript_start.type == syms.listmaker:
1485                 return False
1486
1487             if subscript_start.type == syms.subscriptlist:
1488                 subscript_start = child_towards(subscript_start, leaf)
1489         return subscript_start is not None and any(
1490             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1491         )
1492
1493     def __str__(self) -> str:
1494         """Render the line."""
1495         if not self:
1496             return "\n"
1497
1498         indent = "    " * self.depth
1499         leaves = iter(self.leaves)
1500         first = next(leaves)
1501         res = f"{first.prefix}{indent}{first.value}"
1502         for leaf in leaves:
1503             res += str(leaf)
1504         for comment in itertools.chain.from_iterable(self.comments.values()):
1505             res += str(comment)
1506         return res + "\n"
1507
1508     def __bool__(self) -> bool:
1509         """Return True if the line has leaves or comments."""
1510         return bool(self.leaves or self.comments)
1511
1512
1513 @dataclass
1514 class EmptyLineTracker:
1515     """Provides a stateful method that returns the number of potential extra
1516     empty lines needed before and after the currently processed line.
1517
1518     Note: this tracker works on lines that haven't been split yet.  It assumes
1519     the prefix of the first leaf consists of optional newlines.  Those newlines
1520     are consumed by `maybe_empty_lines()` and included in the computation.
1521     """
1522
1523     is_pyi: bool = False
1524     previous_line: Optional[Line] = None
1525     previous_after: int = 0
1526     previous_defs: List[int] = Factory(list)
1527
1528     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1529         """Return the number of extra empty lines before and after the `current_line`.
1530
1531         This is for separating `def`, `async def` and `class` with extra empty
1532         lines (two on module-level).
1533         """
1534         before, after = self._maybe_empty_lines(current_line)
1535         before = (
1536             # Black should not insert empty lines at the beginning
1537             # of the file
1538             0
1539             if self.previous_line is None
1540             else before - self.previous_after
1541         )
1542         self.previous_after = after
1543         self.previous_line = current_line
1544         return before, after
1545
1546     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1547         max_allowed = 1
1548         if current_line.depth == 0:
1549             max_allowed = 1 if self.is_pyi else 2
1550         if current_line.leaves:
1551             # Consume the first leaf's extra newlines.
1552             first_leaf = current_line.leaves[0]
1553             before = first_leaf.prefix.count("\n")
1554             before = min(before, max_allowed)
1555             first_leaf.prefix = ""
1556         else:
1557             before = 0
1558         depth = current_line.depth
1559         while self.previous_defs and self.previous_defs[-1] >= depth:
1560             self.previous_defs.pop()
1561             if self.is_pyi:
1562                 before = 0 if depth else 1
1563             else:
1564                 before = 1 if depth else 2
1565         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1566             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1567
1568         if (
1569             self.previous_line
1570             and self.previous_line.is_import
1571             and not current_line.is_import
1572             and depth == self.previous_line.depth
1573         ):
1574             return (before or 1), 0
1575
1576         if (
1577             self.previous_line
1578             and self.previous_line.is_class
1579             and current_line.is_triple_quoted_string
1580         ):
1581             return before, 1
1582
1583         return before, 0
1584
1585     def _maybe_empty_lines_for_class_or_def(
1586         self, current_line: Line, before: int
1587     ) -> Tuple[int, int]:
1588         if not current_line.is_decorator:
1589             self.previous_defs.append(current_line.depth)
1590         if self.previous_line is None:
1591             # Don't insert empty lines before the first line in the file.
1592             return 0, 0
1593
1594         if self.previous_line.is_decorator:
1595             return 0, 0
1596
1597         if self.previous_line.depth < current_line.depth and (
1598             self.previous_line.is_class or self.previous_line.is_def
1599         ):
1600             return 0, 0
1601
1602         if (
1603             self.previous_line.is_comment
1604             and self.previous_line.depth == current_line.depth
1605             and before == 0
1606         ):
1607             return 0, 0
1608
1609         if self.is_pyi:
1610             if self.previous_line.depth > current_line.depth:
1611                 newlines = 1
1612             elif current_line.is_class or self.previous_line.is_class:
1613                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1614                     # No blank line between classes with an empty body
1615                     newlines = 0
1616                 else:
1617                     newlines = 1
1618             elif current_line.is_def and not self.previous_line.is_def:
1619                 # Blank line between a block of functions and a block of non-functions
1620                 newlines = 1
1621             else:
1622                 newlines = 0
1623         else:
1624             newlines = 2
1625         if current_line.depth and newlines:
1626             newlines -= 1
1627         return newlines, 0
1628
1629
1630 @dataclass
1631 class LineGenerator(Visitor[Line]):
1632     """Generates reformatted Line objects.  Empty lines are not emitted.
1633
1634     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1635     in ways that will no longer stringify to valid Python code on the tree.
1636     """
1637
1638     is_pyi: bool = False
1639     normalize_strings: bool = True
1640     current_line: Line = Factory(Line)
1641     remove_u_prefix: bool = False
1642
1643     def line(self, indent: int = 0) -> Iterator[Line]:
1644         """Generate a line.
1645
1646         If the line is empty, only emit if it makes sense.
1647         If the line is too long, split it first and then generate.
1648
1649         If any lines were generated, set up a new current_line.
1650         """
1651         if not self.current_line:
1652             self.current_line.depth += indent
1653             return  # Line is empty, don't emit. Creating a new one unnecessary.
1654
1655         complete_line = self.current_line
1656         self.current_line = Line(depth=complete_line.depth + indent)
1657         yield complete_line
1658
1659     def visit_default(self, node: LN) -> Iterator[Line]:
1660         """Default `visit_*()` implementation. Recurses to children of `node`."""
1661         if isinstance(node, Leaf):
1662             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1663             for comment in generate_comments(node):
1664                 if any_open_brackets:
1665                     # any comment within brackets is subject to splitting
1666                     self.current_line.append(comment)
1667                 elif comment.type == token.COMMENT:
1668                     # regular trailing comment
1669                     self.current_line.append(comment)
1670                     yield from self.line()
1671
1672                 else:
1673                     # regular standalone comment
1674                     yield from self.line()
1675
1676                     self.current_line.append(comment)
1677                     yield from self.line()
1678
1679             normalize_prefix(node, inside_brackets=any_open_brackets)
1680             if self.normalize_strings and node.type == token.STRING:
1681                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1682                 normalize_string_quotes(node)
1683             if node.type == token.NUMBER:
1684                 normalize_numeric_literal(node)
1685             if node.type not in WHITESPACE:
1686                 self.current_line.append(node)
1687         yield from super().visit_default(node)
1688
1689     def visit_atom(self, node: Node) -> Iterator[Line]:
1690         # Always make parentheses invisible around a single node, because it should
1691         # not be needed (except in the case of yield, where removing the parentheses
1692         # produces a SyntaxError).
1693         if (
1694             len(node.children) == 3
1695             and isinstance(node.children[0], Leaf)
1696             and node.children[0].type == token.LPAR
1697             and isinstance(node.children[2], Leaf)
1698             and node.children[2].type == token.RPAR
1699             and isinstance(node.children[1], Leaf)
1700             and not (
1701                 node.children[1].type == token.NAME
1702                 and node.children[1].value == "yield"
1703             )
1704         ):
1705             node.children[0].value = ""
1706             node.children[2].value = ""
1707         yield from super().visit_default(node)
1708
1709     def visit_factor(self, node: Node) -> Iterator[Line]:
1710         """Force parentheses between a unary op and a binary power:
1711
1712         -2 ** 8 -> -(2 ** 8)
1713         """
1714         child = node.children[1]
1715         if child.type == syms.power and len(child.children) == 3:
1716             lpar = Leaf(token.LPAR, "(")
1717             rpar = Leaf(token.RPAR, ")")
1718             index = child.remove() or 0
1719             node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1720         yield from self.visit_default(node)
1721
1722     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1723         """Increase indentation level, maybe yield a line."""
1724         # In blib2to3 INDENT never holds comments.
1725         yield from self.line(+1)
1726         yield from self.visit_default(node)
1727
1728     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1729         """Decrease indentation level, maybe yield a line."""
1730         # The current line might still wait for trailing comments.  At DEDENT time
1731         # there won't be any (they would be prefixes on the preceding NEWLINE).
1732         # Emit the line then.
1733         yield from self.line()
1734
1735         # While DEDENT has no value, its prefix may contain standalone comments
1736         # that belong to the current indentation level.  Get 'em.
1737         yield from self.visit_default(node)
1738
1739         # Finally, emit the dedent.
1740         yield from self.line(-1)
1741
1742     def visit_stmt(
1743         self, node: Node, keywords: Set[str], parens: Set[str]
1744     ) -> Iterator[Line]:
1745         """Visit a statement.
1746
1747         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1748         `def`, `with`, `class`, `assert` and assignments.
1749
1750         The relevant Python language `keywords` for a given statement will be
1751         NAME leaves within it. This methods puts those on a separate line.
1752
1753         `parens` holds a set of string leaf values immediately after which
1754         invisible parens should be put.
1755         """
1756         normalize_invisible_parens(node, parens_after=parens)
1757         for child in node.children:
1758             if child.type == token.NAME and child.value in keywords:  # type: ignore
1759                 yield from self.line()
1760
1761             yield from self.visit(child)
1762
1763     def visit_suite(self, node: Node) -> Iterator[Line]:
1764         """Visit a suite."""
1765         if self.is_pyi and is_stub_suite(node):
1766             yield from self.visit(node.children[2])
1767         else:
1768             yield from self.visit_default(node)
1769
1770     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1771         """Visit a statement without nested statements."""
1772         is_suite_like = node.parent and node.parent.type in STATEMENT
1773         if is_suite_like:
1774             if self.is_pyi and is_stub_body(node):
1775                 yield from self.visit_default(node)
1776             else:
1777                 yield from self.line(+1)
1778                 yield from self.visit_default(node)
1779                 yield from self.line(-1)
1780
1781         else:
1782             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1783                 yield from self.line()
1784             yield from self.visit_default(node)
1785
1786     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1787         """Visit `async def`, `async for`, `async with`."""
1788         yield from self.line()
1789
1790         children = iter(node.children)
1791         for child in children:
1792             yield from self.visit(child)
1793
1794             if child.type == token.ASYNC:
1795                 break
1796
1797         internal_stmt = next(children)
1798         for child in internal_stmt.children:
1799             yield from self.visit(child)
1800
1801     def visit_decorators(self, node: Node) -> Iterator[Line]:
1802         """Visit decorators."""
1803         for child in node.children:
1804             yield from self.line()
1805             yield from self.visit(child)
1806
1807     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1808         """Remove a semicolon and put the other statement on a separate line."""
1809         yield from self.line()
1810
1811     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1812         """End of file. Process outstanding comments and end with a newline."""
1813         yield from self.visit_default(leaf)
1814         yield from self.line()
1815
1816     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1817         if not self.current_line.bracket_tracker.any_open_brackets():
1818             yield from self.line()
1819         yield from self.visit_default(leaf)
1820
1821     def __attrs_post_init__(self) -> None:
1822         """You are in a twisty little maze of passages."""
1823         v = self.visit_stmt
1824         Ø: Set[str] = set()
1825         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1826         self.visit_if_stmt = partial(
1827             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1828         )
1829         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1830         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1831         self.visit_try_stmt = partial(
1832             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1833         )
1834         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1835         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1836         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1837         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1838         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1839         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1840         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1841         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1842         self.visit_async_funcdef = self.visit_async_stmt
1843         self.visit_decorated = self.visit_decorators
1844
1845
1846 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1847 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1848 OPENING_BRACKETS = set(BRACKET.keys())
1849 CLOSING_BRACKETS = set(BRACKET.values())
1850 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1851 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1852
1853
1854 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1855     """Return whitespace prefix if needed for the given `leaf`.
1856
1857     `complex_subscript` signals whether the given leaf is part of a subscription
1858     which has non-trivial arguments, like arithmetic expressions or function calls.
1859     """
1860     NO = ""
1861     SPACE = " "
1862     DOUBLESPACE = "  "
1863     t = leaf.type
1864     p = leaf.parent
1865     v = leaf.value
1866     if t in ALWAYS_NO_SPACE:
1867         return NO
1868
1869     if t == token.COMMENT:
1870         return DOUBLESPACE
1871
1872     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1873     if t == token.COLON and p.type not in {
1874         syms.subscript,
1875         syms.subscriptlist,
1876         syms.sliceop,
1877     }:
1878         return NO
1879
1880     prev = leaf.prev_sibling
1881     if not prev:
1882         prevp = preceding_leaf(p)
1883         if not prevp or prevp.type in OPENING_BRACKETS:
1884             return NO
1885
1886         if t == token.COLON:
1887             if prevp.type == token.COLON:
1888                 return NO
1889
1890             elif prevp.type != token.COMMA and not complex_subscript:
1891                 return NO
1892
1893             return SPACE
1894
1895         if prevp.type == token.EQUAL:
1896             if prevp.parent:
1897                 if prevp.parent.type in {
1898                     syms.arglist,
1899                     syms.argument,
1900                     syms.parameters,
1901                     syms.varargslist,
1902                 }:
1903                     return NO
1904
1905                 elif prevp.parent.type == syms.typedargslist:
1906                     # A bit hacky: if the equal sign has whitespace, it means we
1907                     # previously found it's a typed argument.  So, we're using
1908                     # that, too.
1909                     return prevp.prefix
1910
1911         elif prevp.type in VARARGS_SPECIALS:
1912             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1913                 return NO
1914
1915         elif prevp.type == token.COLON:
1916             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1917                 return SPACE if complex_subscript else NO
1918
1919         elif (
1920             prevp.parent
1921             and prevp.parent.type == syms.factor
1922             and prevp.type in MATH_OPERATORS
1923         ):
1924             return NO
1925
1926         elif (
1927             prevp.type == token.RIGHTSHIFT
1928             and prevp.parent
1929             and prevp.parent.type == syms.shift_expr
1930             and prevp.prev_sibling
1931             and prevp.prev_sibling.type == token.NAME
1932             and prevp.prev_sibling.value == "print"  # type: ignore
1933         ):
1934             # Python 2 print chevron
1935             return NO
1936
1937     elif prev.type in OPENING_BRACKETS:
1938         return NO
1939
1940     if p.type in {syms.parameters, syms.arglist}:
1941         # untyped function signatures or calls
1942         if not prev or prev.type != token.COMMA:
1943             return NO
1944
1945     elif p.type == syms.varargslist:
1946         # lambdas
1947         if prev and prev.type != token.COMMA:
1948             return NO
1949
1950     elif p.type == syms.typedargslist:
1951         # typed function signatures
1952         if not prev:
1953             return NO
1954
1955         if t == token.EQUAL:
1956             if prev.type != syms.tname:
1957                 return NO
1958
1959         elif prev.type == token.EQUAL:
1960             # A bit hacky: if the equal sign has whitespace, it means we
1961             # previously found it's a typed argument.  So, we're using that, too.
1962             return prev.prefix
1963
1964         elif prev.type != token.COMMA:
1965             return NO
1966
1967     elif p.type == syms.tname:
1968         # type names
1969         if not prev:
1970             prevp = preceding_leaf(p)
1971             if not prevp or prevp.type != token.COMMA:
1972                 return NO
1973
1974     elif p.type == syms.trailer:
1975         # attributes and calls
1976         if t == token.LPAR or t == token.RPAR:
1977             return NO
1978
1979         if not prev:
1980             if t == token.DOT:
1981                 prevp = preceding_leaf(p)
1982                 if not prevp or prevp.type != token.NUMBER:
1983                     return NO
1984
1985             elif t == token.LSQB:
1986                 return NO
1987
1988         elif prev.type != token.COMMA:
1989             return NO
1990
1991     elif p.type == syms.argument:
1992         # single argument
1993         if t == token.EQUAL:
1994             return NO
1995
1996         if not prev:
1997             prevp = preceding_leaf(p)
1998             if not prevp or prevp.type == token.LPAR:
1999                 return NO
2000
2001         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2002             return NO
2003
2004     elif p.type == syms.decorator:
2005         # decorators
2006         return NO
2007
2008     elif p.type == syms.dotted_name:
2009         if prev:
2010             return NO
2011
2012         prevp = preceding_leaf(p)
2013         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2014             return NO
2015
2016     elif p.type == syms.classdef:
2017         if t == token.LPAR:
2018             return NO
2019
2020         if prev and prev.type == token.LPAR:
2021             return NO
2022
2023     elif p.type in {syms.subscript, syms.sliceop}:
2024         # indexing
2025         if not prev:
2026             assert p.parent is not None, "subscripts are always parented"
2027             if p.parent.type == syms.subscriptlist:
2028                 return SPACE
2029
2030             return NO
2031
2032         elif not complex_subscript:
2033             return NO
2034
2035     elif p.type == syms.atom:
2036         if prev and t == token.DOT:
2037             # dots, but not the first one.
2038             return NO
2039
2040     elif p.type == syms.dictsetmaker:
2041         # dict unpacking
2042         if prev and prev.type == token.DOUBLESTAR:
2043             return NO
2044
2045     elif p.type in {syms.factor, syms.star_expr}:
2046         # unary ops
2047         if not prev:
2048             prevp = preceding_leaf(p)
2049             if not prevp or prevp.type in OPENING_BRACKETS:
2050                 return NO
2051
2052             prevp_parent = prevp.parent
2053             assert prevp_parent is not None
2054             if prevp.type == token.COLON and prevp_parent.type in {
2055                 syms.subscript,
2056                 syms.sliceop,
2057             }:
2058                 return NO
2059
2060             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2061                 return NO
2062
2063         elif t in {token.NAME, token.NUMBER, token.STRING}:
2064             return NO
2065
2066     elif p.type == syms.import_from:
2067         if t == token.DOT:
2068             if prev and prev.type == token.DOT:
2069                 return NO
2070
2071         elif t == token.NAME:
2072             if v == "import":
2073                 return SPACE
2074
2075             if prev and prev.type == token.DOT:
2076                 return NO
2077
2078     elif p.type == syms.sliceop:
2079         return NO
2080
2081     return SPACE
2082
2083
2084 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2085     """Return the first leaf that precedes `node`, if any."""
2086     while node:
2087         res = node.prev_sibling
2088         if res:
2089             if isinstance(res, Leaf):
2090                 return res
2091
2092             try:
2093                 return list(res.leaves())[-1]
2094
2095             except IndexError:
2096                 return None
2097
2098         node = node.parent
2099     return None
2100
2101
2102 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2103     """Return the child of `ancestor` that contains `descendant`."""
2104     node: Optional[LN] = descendant
2105     while node and node.parent != ancestor:
2106         node = node.parent
2107     return node
2108
2109
2110 def container_of(leaf: Leaf) -> LN:
2111     """Return `leaf` or one of its ancestors that is the topmost container of it.
2112
2113     By "container" we mean a node where `leaf` is the very first child.
2114     """
2115     same_prefix = leaf.prefix
2116     container: LN = leaf
2117     while container:
2118         parent = container.parent
2119         if parent is None:
2120             break
2121
2122         if parent.children[0].prefix != same_prefix:
2123             break
2124
2125         if parent.type == syms.file_input:
2126             break
2127
2128         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2129             break
2130
2131         container = parent
2132     return container
2133
2134
2135 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2136     """Return the priority of the `leaf` delimiter, given a line break after it.
2137
2138     The delimiter priorities returned here are from those delimiters that would
2139     cause a line break after themselves.
2140
2141     Higher numbers are higher priority.
2142     """
2143     if leaf.type == token.COMMA:
2144         return COMMA_PRIORITY
2145
2146     return 0
2147
2148
2149 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2150     """Return the priority of the `leaf` delimiter, given a line break before it.
2151
2152     The delimiter priorities returned here are from those delimiters that would
2153     cause a line break before themselves.
2154
2155     Higher numbers are higher priority.
2156     """
2157     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2158         # * and ** might also be MATH_OPERATORS but in this case they are not.
2159         # Don't treat them as a delimiter.
2160         return 0
2161
2162     if (
2163         leaf.type == token.DOT
2164         and leaf.parent
2165         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2166         and (previous is None or previous.type in CLOSING_BRACKETS)
2167     ):
2168         return DOT_PRIORITY
2169
2170     if (
2171         leaf.type in MATH_OPERATORS
2172         and leaf.parent
2173         and leaf.parent.type not in {syms.factor, syms.star_expr}
2174     ):
2175         return MATH_PRIORITIES[leaf.type]
2176
2177     if leaf.type in COMPARATORS:
2178         return COMPARATOR_PRIORITY
2179
2180     if (
2181         leaf.type == token.STRING
2182         and previous is not None
2183         and previous.type == token.STRING
2184     ):
2185         return STRING_PRIORITY
2186
2187     if leaf.type not in {token.NAME, token.ASYNC}:
2188         return 0
2189
2190     if (
2191         leaf.value == "for"
2192         and leaf.parent
2193         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2194         or leaf.type == token.ASYNC
2195     ):
2196         if (
2197             not isinstance(leaf.prev_sibling, Leaf)
2198             or leaf.prev_sibling.value != "async"
2199         ):
2200             return COMPREHENSION_PRIORITY
2201
2202     if (
2203         leaf.value == "if"
2204         and leaf.parent
2205         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2206     ):
2207         return COMPREHENSION_PRIORITY
2208
2209     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2210         return TERNARY_PRIORITY
2211
2212     if leaf.value == "is":
2213         return COMPARATOR_PRIORITY
2214
2215     if (
2216         leaf.value == "in"
2217         and leaf.parent
2218         and leaf.parent.type in {syms.comp_op, syms.comparison}
2219         and not (
2220             previous is not None
2221             and previous.type == token.NAME
2222             and previous.value == "not"
2223         )
2224     ):
2225         return COMPARATOR_PRIORITY
2226
2227     if (
2228         leaf.value == "not"
2229         and leaf.parent
2230         and leaf.parent.type == syms.comp_op
2231         and not (
2232             previous is not None
2233             and previous.type == token.NAME
2234             and previous.value == "is"
2235         )
2236     ):
2237         return COMPARATOR_PRIORITY
2238
2239     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2240         return LOGIC_PRIORITY
2241
2242     return 0
2243
2244
2245 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2246 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2247
2248
2249 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2250     """Clean the prefix of the `leaf` and generate comments from it, if any.
2251
2252     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2253     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2254     move because it does away with modifying the grammar to include all the
2255     possible places in which comments can be placed.
2256
2257     The sad consequence for us though is that comments don't "belong" anywhere.
2258     This is why this function generates simple parentless Leaf objects for
2259     comments.  We simply don't know what the correct parent should be.
2260
2261     No matter though, we can live without this.  We really only need to
2262     differentiate between inline and standalone comments.  The latter don't
2263     share the line with any code.
2264
2265     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2266     are emitted with a fake STANDALONE_COMMENT token identifier.
2267     """
2268     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2269         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2270
2271
2272 @dataclass
2273 class ProtoComment:
2274     """Describes a piece of syntax that is a comment.
2275
2276     It's not a :class:`blib2to3.pytree.Leaf` so that:
2277
2278     * it can be cached (`Leaf` objects should not be reused more than once as
2279       they store their lineno, column, prefix, and parent information);
2280     * `newlines` and `consumed` fields are kept separate from the `value`. This
2281       simplifies handling of special marker comments like ``# fmt: off/on``.
2282     """
2283
2284     type: int  # token.COMMENT or STANDALONE_COMMENT
2285     value: str  # content of the comment
2286     newlines: int  # how many newlines before the comment
2287     consumed: int  # how many characters of the original leaf's prefix did we consume
2288
2289
2290 @lru_cache(maxsize=4096)
2291 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2292     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2293     result: List[ProtoComment] = []
2294     if not prefix or "#" not in prefix:
2295         return result
2296
2297     consumed = 0
2298     nlines = 0
2299     ignored_lines = 0
2300     for index, line in enumerate(prefix.split("\n")):
2301         consumed += len(line) + 1  # adding the length of the split '\n'
2302         line = line.lstrip()
2303         if not line:
2304             nlines += 1
2305         if not line.startswith("#"):
2306             # Escaped newlines outside of a comment are not really newlines at
2307             # all. We treat a single-line comment following an escaped newline
2308             # as a simple trailing comment.
2309             if line.endswith("\\"):
2310                 ignored_lines += 1
2311             continue
2312
2313         if index == ignored_lines and not is_endmarker:
2314             comment_type = token.COMMENT  # simple trailing comment
2315         else:
2316             comment_type = STANDALONE_COMMENT
2317         comment = make_comment(line)
2318         result.append(
2319             ProtoComment(
2320                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2321             )
2322         )
2323         nlines = 0
2324     return result
2325
2326
2327 def make_comment(content: str) -> str:
2328     """Return a consistently formatted comment from the given `content` string.
2329
2330     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2331     space between the hash sign and the content.
2332
2333     If `content` didn't start with a hash sign, one is provided.
2334     """
2335     content = content.rstrip()
2336     if not content:
2337         return "#"
2338
2339     if content[0] == "#":
2340         content = content[1:]
2341     if content and content[0] not in " !:#'%":
2342         content = " " + content
2343     return "#" + content
2344
2345
2346 def split_line(
2347     line: Line,
2348     line_length: int,
2349     inner: bool = False,
2350     features: Collection[Feature] = (),
2351 ) -> Iterator[Line]:
2352     """Split a `line` into potentially many lines.
2353
2354     They should fit in the allotted `line_length` but might not be able to.
2355     `inner` signifies that there were a pair of brackets somewhere around the
2356     current `line`, possibly transitively. This means we can fallback to splitting
2357     by delimiters if the LHS/RHS don't yield any results.
2358
2359     `features` are syntactical features that may be used in the output.
2360     """
2361     if line.is_comment:
2362         yield line
2363         return
2364
2365     line_str = str(line).strip("\n")
2366
2367     if (
2368         not line.contains_uncollapsable_type_comments()
2369         and not line.should_explode
2370         and (
2371             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2372             or line.contains_unsplittable_type_ignore()
2373         )
2374     ):
2375         yield line
2376         return
2377
2378     split_funcs: List[SplitFunc]
2379     if line.is_def:
2380         split_funcs = [left_hand_split]
2381     else:
2382
2383         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2384             for omit in generate_trailers_to_omit(line, line_length):
2385                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2386                 if is_line_short_enough(lines[0], line_length=line_length):
2387                     yield from lines
2388                     return
2389
2390             # All splits failed, best effort split with no omits.
2391             # This mostly happens to multiline strings that are by definition
2392             # reported as not fitting a single line.
2393             yield from right_hand_split(line, line_length, features=features)
2394
2395         if line.inside_brackets:
2396             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2397         else:
2398             split_funcs = [rhs]
2399     for split_func in split_funcs:
2400         # We are accumulating lines in `result` because we might want to abort
2401         # mission and return the original line in the end, or attempt a different
2402         # split altogether.
2403         result: List[Line] = []
2404         try:
2405             for l in split_func(line, features):
2406                 if str(l).strip("\n") == line_str:
2407                     raise CannotSplit("Split function returned an unchanged result")
2408
2409                 result.extend(
2410                     split_line(
2411                         l, line_length=line_length, inner=True, features=features
2412                     )
2413                 )
2414         except CannotSplit:
2415             continue
2416
2417         else:
2418             yield from result
2419             break
2420
2421     else:
2422         yield line
2423
2424
2425 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2426     """Split line into many lines, starting with the first matching bracket pair.
2427
2428     Note: this usually looks weird, only use this for function definitions.
2429     Prefer RHS otherwise.  This is why this function is not symmetrical with
2430     :func:`right_hand_split` which also handles optional parentheses.
2431     """
2432     tail_leaves: List[Leaf] = []
2433     body_leaves: List[Leaf] = []
2434     head_leaves: List[Leaf] = []
2435     current_leaves = head_leaves
2436     matching_bracket = None
2437     for leaf in line.leaves:
2438         if (
2439             current_leaves is body_leaves
2440             and leaf.type in CLOSING_BRACKETS
2441             and leaf.opening_bracket is matching_bracket
2442         ):
2443             current_leaves = tail_leaves if body_leaves else head_leaves
2444         current_leaves.append(leaf)
2445         if current_leaves is head_leaves:
2446             if leaf.type in OPENING_BRACKETS:
2447                 matching_bracket = leaf
2448                 current_leaves = body_leaves
2449     if not matching_bracket:
2450         raise CannotSplit("No brackets found")
2451
2452     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2453     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2454     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2455     bracket_split_succeeded_or_raise(head, body, tail)
2456     for result in (head, body, tail):
2457         if result:
2458             yield result
2459
2460
2461 def right_hand_split(
2462     line: Line,
2463     line_length: int,
2464     features: Collection[Feature] = (),
2465     omit: Collection[LeafID] = (),
2466 ) -> Iterator[Line]:
2467     """Split line into many lines, starting with the last matching bracket pair.
2468
2469     If the split was by optional parentheses, attempt splitting without them, too.
2470     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2471     this split.
2472
2473     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2474     """
2475     tail_leaves: List[Leaf] = []
2476     body_leaves: List[Leaf] = []
2477     head_leaves: List[Leaf] = []
2478     current_leaves = tail_leaves
2479     opening_bracket = None
2480     closing_bracket = None
2481     for leaf in reversed(line.leaves):
2482         if current_leaves is body_leaves:
2483             if leaf is opening_bracket:
2484                 current_leaves = head_leaves if body_leaves else tail_leaves
2485         current_leaves.append(leaf)
2486         if current_leaves is tail_leaves:
2487             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2488                 opening_bracket = leaf.opening_bracket
2489                 closing_bracket = leaf
2490                 current_leaves = body_leaves
2491     if not (opening_bracket and closing_bracket and head_leaves):
2492         # If there is no opening or closing_bracket that means the split failed and
2493         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2494         # the matching `opening_bracket` wasn't available on `line` anymore.
2495         raise CannotSplit("No brackets found")
2496
2497     tail_leaves.reverse()
2498     body_leaves.reverse()
2499     head_leaves.reverse()
2500     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2501     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2502     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2503     bracket_split_succeeded_or_raise(head, body, tail)
2504     if (
2505         # the body shouldn't be exploded
2506         not body.should_explode
2507         # the opening bracket is an optional paren
2508         and opening_bracket.type == token.LPAR
2509         and not opening_bracket.value
2510         # the closing bracket is an optional paren
2511         and closing_bracket.type == token.RPAR
2512         and not closing_bracket.value
2513         # it's not an import (optional parens are the only thing we can split on
2514         # in this case; attempting a split without them is a waste of time)
2515         and not line.is_import
2516         # there are no standalone comments in the body
2517         and not body.contains_standalone_comments(0)
2518         # and we can actually remove the parens
2519         and can_omit_invisible_parens(body, line_length)
2520     ):
2521         omit = {id(closing_bracket), *omit}
2522         try:
2523             yield from right_hand_split(line, line_length, features=features, omit=omit)
2524             return
2525
2526         except CannotSplit:
2527             if not (
2528                 can_be_split(body)
2529                 or is_line_short_enough(body, line_length=line_length)
2530             ):
2531                 raise CannotSplit(
2532                     "Splitting failed, body is still too long and can't be split."
2533                 )
2534
2535             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2536                 raise CannotSplit(
2537                     "The current optional pair of parentheses is bound to fail to "
2538                     "satisfy the splitting algorithm because the head or the tail "
2539                     "contains multiline strings which by definition never fit one "
2540                     "line."
2541                 )
2542
2543     ensure_visible(opening_bracket)
2544     ensure_visible(closing_bracket)
2545     for result in (head, body, tail):
2546         if result:
2547             yield result
2548
2549
2550 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2551     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2552
2553     Do nothing otherwise.
2554
2555     A left- or right-hand split is based on a pair of brackets. Content before
2556     (and including) the opening bracket is left on one line, content inside the
2557     brackets is put on a separate line, and finally content starting with and
2558     following the closing bracket is put on a separate line.
2559
2560     Those are called `head`, `body`, and `tail`, respectively. If the split
2561     produced the same line (all content in `head`) or ended up with an empty `body`
2562     and the `tail` is just the closing bracket, then it's considered failed.
2563     """
2564     tail_len = len(str(tail).strip())
2565     if not body:
2566         if tail_len == 0:
2567             raise CannotSplit("Splitting brackets produced the same line")
2568
2569         elif tail_len < 3:
2570             raise CannotSplit(
2571                 f"Splitting brackets on an empty body to save "
2572                 f"{tail_len} characters is not worth it"
2573             )
2574
2575
2576 def bracket_split_build_line(
2577     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2578 ) -> Line:
2579     """Return a new line with given `leaves` and respective comments from `original`.
2580
2581     If `is_body` is True, the result line is one-indented inside brackets and as such
2582     has its first leaf's prefix normalized and a trailing comma added when expected.
2583     """
2584     result = Line(depth=original.depth)
2585     if is_body:
2586         result.inside_brackets = True
2587         result.depth += 1
2588         if leaves:
2589             # Since body is a new indent level, remove spurious leading whitespace.
2590             normalize_prefix(leaves[0], inside_brackets=True)
2591             # Ensure a trailing comma for imports and standalone function arguments, but
2592             # be careful not to add one after any comments.
2593             no_commas = original.is_def and not any(
2594                 l.type == token.COMMA for l in leaves
2595             )
2596
2597             if original.is_import or no_commas:
2598                 for i in range(len(leaves) - 1, -1, -1):
2599                     if leaves[i].type == STANDALONE_COMMENT:
2600                         continue
2601                     elif leaves[i].type == token.COMMA:
2602                         break
2603                     else:
2604                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2605                         break
2606     # Populate the line
2607     for leaf in leaves:
2608         result.append(leaf, preformatted=True)
2609         for comment_after in original.comments_after(leaf):
2610             result.append(comment_after, preformatted=True)
2611     if is_body:
2612         result.should_explode = should_explode(result, opening_bracket)
2613     return result
2614
2615
2616 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2617     """Normalize prefix of the first leaf in every line returned by `split_func`.
2618
2619     This is a decorator over relevant split functions.
2620     """
2621
2622     @wraps(split_func)
2623     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2624         for l in split_func(line, features):
2625             normalize_prefix(l.leaves[0], inside_brackets=True)
2626             yield l
2627
2628     return split_wrapper
2629
2630
2631 @dont_increase_indentation
2632 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2633     """Split according to delimiters of the highest priority.
2634
2635     If the appropriate Features are given, the split will add trailing commas
2636     also in function signatures and calls that contain `*` and `**`.
2637     """
2638     try:
2639         last_leaf = line.leaves[-1]
2640     except IndexError:
2641         raise CannotSplit("Line empty")
2642
2643     bt = line.bracket_tracker
2644     try:
2645         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2646     except ValueError:
2647         raise CannotSplit("No delimiters found")
2648
2649     if delimiter_priority == DOT_PRIORITY:
2650         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2651             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2652
2653     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2654     lowest_depth = sys.maxsize
2655     trailing_comma_safe = True
2656
2657     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2658         """Append `leaf` to current line or to new line if appending impossible."""
2659         nonlocal current_line
2660         try:
2661             current_line.append_safe(leaf, preformatted=True)
2662         except ValueError:
2663             yield current_line
2664
2665             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2666             current_line.append(leaf)
2667
2668     for leaf in line.leaves:
2669         yield from append_to_line(leaf)
2670
2671         for comment_after in line.comments_after(leaf):
2672             yield from append_to_line(comment_after)
2673
2674         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2675         if leaf.bracket_depth == lowest_depth:
2676             if is_vararg(leaf, within={syms.typedargslist}):
2677                 trailing_comma_safe = (
2678                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2679                 )
2680             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2681                 trailing_comma_safe = (
2682                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2683                 )
2684
2685         leaf_priority = bt.delimiters.get(id(leaf))
2686         if leaf_priority == delimiter_priority:
2687             yield current_line
2688
2689             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2690     if current_line:
2691         if (
2692             trailing_comma_safe
2693             and delimiter_priority == COMMA_PRIORITY
2694             and current_line.leaves[-1].type != token.COMMA
2695             and current_line.leaves[-1].type != STANDALONE_COMMENT
2696         ):
2697             current_line.append(Leaf(token.COMMA, ","))
2698         yield current_line
2699
2700
2701 @dont_increase_indentation
2702 def standalone_comment_split(
2703     line: Line, features: Collection[Feature] = ()
2704 ) -> Iterator[Line]:
2705     """Split standalone comments from the rest of the line."""
2706     if not line.contains_standalone_comments(0):
2707         raise CannotSplit("Line does not have any standalone comments")
2708
2709     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2710
2711     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2712         """Append `leaf` to current line or to new line if appending impossible."""
2713         nonlocal current_line
2714         try:
2715             current_line.append_safe(leaf, preformatted=True)
2716         except ValueError:
2717             yield current_line
2718
2719             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2720             current_line.append(leaf)
2721
2722     for leaf in line.leaves:
2723         yield from append_to_line(leaf)
2724
2725         for comment_after in line.comments_after(leaf):
2726             yield from append_to_line(comment_after)
2727
2728     if current_line:
2729         yield current_line
2730
2731
2732 def is_import(leaf: Leaf) -> bool:
2733     """Return True if the given leaf starts an import statement."""
2734     p = leaf.parent
2735     t = leaf.type
2736     v = leaf.value
2737     return bool(
2738         t == token.NAME
2739         and (
2740             (v == "import" and p and p.type == syms.import_name)
2741             or (v == "from" and p and p.type == syms.import_from)
2742         )
2743     )
2744
2745
2746 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2747     """Return True if the given leaf is a special comment.
2748     Only returns true for type comments for now."""
2749     t = leaf.type
2750     v = leaf.value
2751     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith(
2752         "# type:" + suffix
2753     )
2754
2755
2756 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2757     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2758     else.
2759
2760     Note: don't use backslashes for formatting or you'll lose your voting rights.
2761     """
2762     if not inside_brackets:
2763         spl = leaf.prefix.split("#")
2764         if "\\" not in spl[0]:
2765             nl_count = spl[-1].count("\n")
2766             if len(spl) > 1:
2767                 nl_count -= 1
2768             leaf.prefix = "\n" * nl_count
2769             return
2770
2771     leaf.prefix = ""
2772
2773
2774 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2775     """Make all string prefixes lowercase.
2776
2777     If remove_u_prefix is given, also removes any u prefix from the string.
2778
2779     Note: Mutates its argument.
2780     """
2781     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2782     assert match is not None, f"failed to match string {leaf.value!r}"
2783     orig_prefix = match.group(1)
2784     new_prefix = orig_prefix.lower()
2785     if remove_u_prefix:
2786         new_prefix = new_prefix.replace("u", "")
2787     leaf.value = f"{new_prefix}{match.group(2)}"
2788
2789
2790 def normalize_string_quotes(leaf: Leaf) -> None:
2791     """Prefer double quotes but only if it doesn't cause more escaping.
2792
2793     Adds or removes backslashes as appropriate. Doesn't parse and fix
2794     strings nested in f-strings (yet).
2795
2796     Note: Mutates its argument.
2797     """
2798     value = leaf.value.lstrip("furbFURB")
2799     if value[:3] == '"""':
2800         return
2801
2802     elif value[:3] == "'''":
2803         orig_quote = "'''"
2804         new_quote = '"""'
2805     elif value[0] == '"':
2806         orig_quote = '"'
2807         new_quote = "'"
2808     else:
2809         orig_quote = "'"
2810         new_quote = '"'
2811     first_quote_pos = leaf.value.find(orig_quote)
2812     if first_quote_pos == -1:
2813         return  # There's an internal error
2814
2815     prefix = leaf.value[:first_quote_pos]
2816     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2817     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2818     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2819     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2820     if "r" in prefix.casefold():
2821         if unescaped_new_quote.search(body):
2822             # There's at least one unescaped new_quote in this raw string
2823             # so converting is impossible
2824             return
2825
2826         # Do not introduce or remove backslashes in raw strings
2827         new_body = body
2828     else:
2829         # remove unnecessary escapes
2830         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2831         if body != new_body:
2832             # Consider the string without unnecessary escapes as the original
2833             body = new_body
2834             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2835         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2836         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2837     if "f" in prefix.casefold():
2838         matches = re.findall(
2839             r"""
2840             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2841                 ([^{].*?)  # contents of the brackets except if begins with {{
2842             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2843             """,
2844             new_body,
2845             re.VERBOSE,
2846         )
2847         for m in matches:
2848             if "\\" in str(m):
2849                 # Do not introduce backslashes in interpolated expressions
2850                 return
2851     if new_quote == '"""' and new_body[-1:] == '"':
2852         # edge case:
2853         new_body = new_body[:-1] + '\\"'
2854     orig_escape_count = body.count("\\")
2855     new_escape_count = new_body.count("\\")
2856     if new_escape_count > orig_escape_count:
2857         return  # Do not introduce more escaping
2858
2859     if new_escape_count == orig_escape_count and orig_quote == '"':
2860         return  # Prefer double quotes
2861
2862     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2863
2864
2865 def normalize_numeric_literal(leaf: Leaf) -> None:
2866     """Normalizes numeric (float, int, and complex) literals.
2867
2868     All letters used in the representation are normalized to lowercase (except
2869     in Python 2 long literals).
2870     """
2871     text = leaf.value.lower()
2872     if text.startswith(("0o", "0b")):
2873         # Leave octal and binary literals alone.
2874         pass
2875     elif text.startswith("0x"):
2876         # Change hex literals to upper case.
2877         before, after = text[:2], text[2:]
2878         text = f"{before}{after.upper()}"
2879     elif "e" in text:
2880         before, after = text.split("e")
2881         sign = ""
2882         if after.startswith("-"):
2883             after = after[1:]
2884             sign = "-"
2885         elif after.startswith("+"):
2886             after = after[1:]
2887         before = format_float_or_int_string(before)
2888         text = f"{before}e{sign}{after}"
2889     elif text.endswith(("j", "l")):
2890         number = text[:-1]
2891         suffix = text[-1]
2892         # Capitalize in "2L" because "l" looks too similar to "1".
2893         if suffix == "l":
2894             suffix = "L"
2895         text = f"{format_float_or_int_string(number)}{suffix}"
2896     else:
2897         text = format_float_or_int_string(text)
2898     leaf.value = text
2899
2900
2901 def format_float_or_int_string(text: str) -> str:
2902     """Formats a float string like "1.0"."""
2903     if "." not in text:
2904         return text
2905
2906     before, after = text.split(".")
2907     return f"{before or 0}.{after or 0}"
2908
2909
2910 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2911     """Make existing optional parentheses invisible or create new ones.
2912
2913     `parens_after` is a set of string leaf values immediately after which parens
2914     should be put.
2915
2916     Standardizes on visible parentheses for single-element tuples, and keeps
2917     existing visible parentheses for other tuples and generator expressions.
2918     """
2919     for pc in list_comments(node.prefix, is_endmarker=False):
2920         if pc.value in FMT_OFF:
2921             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2922             return
2923
2924     check_lpar = False
2925     for index, child in enumerate(list(node.children)):
2926         # Add parentheses around long tuple unpacking in assignments.
2927         if (
2928             index == 0
2929             and isinstance(child, Node)
2930             and child.type == syms.testlist_star_expr
2931         ):
2932             check_lpar = True
2933
2934         if check_lpar:
2935             if is_walrus_assignment(child):
2936                 continue
2937             if child.type == syms.atom:
2938                 # Determines if the underlying atom should be surrounded with
2939                 # invisible params - also makes parens invisible recursively
2940                 # within the atom and removes repeated invisible parens within
2941                 # the atom
2942                 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2943                     child, parent=node
2944                 )
2945
2946                 if should_surround_with_parens:
2947                     lpar = Leaf(token.LPAR, "")
2948                     rpar = Leaf(token.RPAR, "")
2949                     index = child.remove() or 0
2950                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2951             elif is_one_tuple(child):
2952                 # wrap child in visible parentheses
2953                 lpar = Leaf(token.LPAR, "(")
2954                 rpar = Leaf(token.RPAR, ")")
2955                 child.remove()
2956                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2957             elif node.type == syms.import_from:
2958                 # "import from" nodes store parentheses directly as part of
2959                 # the statement
2960                 if child.type == token.LPAR:
2961                     # make parentheses invisible
2962                     child.value = ""  # type: ignore
2963                     node.children[-1].value = ""  # type: ignore
2964                 elif child.type != token.STAR:
2965                     # insert invisible parentheses
2966                     node.insert_child(index, Leaf(token.LPAR, ""))
2967                     node.append_child(Leaf(token.RPAR, ""))
2968                 break
2969
2970             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2971                 # wrap child in invisible parentheses
2972                 lpar = Leaf(token.LPAR, "")
2973                 rpar = Leaf(token.RPAR, "")
2974                 index = child.remove() or 0
2975                 prefix = child.prefix
2976                 child.prefix = ""
2977                 new_child = Node(syms.atom, [lpar, child, rpar])
2978                 new_child.prefix = prefix
2979                 node.insert_child(index, new_child)
2980
2981         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2982
2983
2984 def normalize_fmt_off(node: Node) -> None:
2985     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2986     try_again = True
2987     while try_again:
2988         try_again = convert_one_fmt_off_pair(node)
2989
2990
2991 def convert_one_fmt_off_pair(node: Node) -> bool:
2992     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2993
2994     Returns True if a pair was converted.
2995     """
2996     for leaf in node.leaves():
2997         previous_consumed = 0
2998         for comment in list_comments(leaf.prefix, is_endmarker=False):
2999             if comment.value in FMT_OFF:
3000                 # We only want standalone comments. If there's no previous leaf or
3001                 # the previous leaf is indentation, it's a standalone comment in
3002                 # disguise.
3003                 if comment.type != STANDALONE_COMMENT:
3004                     prev = preceding_leaf(leaf)
3005                     if prev and prev.type not in WHITESPACE:
3006                         continue
3007
3008                 ignored_nodes = list(generate_ignored_nodes(leaf))
3009                 if not ignored_nodes:
3010                     continue
3011
3012                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
3013                 parent = first.parent
3014                 prefix = first.prefix
3015                 first.prefix = prefix[comment.consumed :]
3016                 hidden_value = (
3017                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3018                 )
3019                 if hidden_value.endswith("\n"):
3020                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
3021                     # leaf (possibly followed by a DEDENT).
3022                     hidden_value = hidden_value[:-1]
3023                 first_idx = None
3024                 for ignored in ignored_nodes:
3025                     index = ignored.remove()
3026                     if first_idx is None:
3027                         first_idx = index
3028                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3029                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3030                 parent.insert_child(
3031                     first_idx,
3032                     Leaf(
3033                         STANDALONE_COMMENT,
3034                         hidden_value,
3035                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3036                     ),
3037                 )
3038                 return True
3039
3040             previous_consumed = comment.consumed
3041
3042     return False
3043
3044
3045 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3046     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3047
3048     Stops at the end of the block.
3049     """
3050     container: Optional[LN] = container_of(leaf)
3051     while container is not None and container.type != token.ENDMARKER:
3052         for comment in list_comments(container.prefix, is_endmarker=False):
3053             if comment.value in FMT_ON:
3054                 return
3055
3056         yield container
3057
3058         container = container.next_sibling
3059
3060
3061 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3062     """If it's safe, make the parens in the atom `node` invisible, recursively.
3063     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3064     as they are redundant.
3065
3066     Returns whether the node should itself be wrapped in invisible parentheses.
3067
3068     """
3069     if (
3070         node.type != syms.atom
3071         or is_empty_tuple(node)
3072         or is_one_tuple(node)
3073         or (is_yield(node) and parent.type != syms.expr_stmt)
3074         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3075     ):
3076         return False
3077
3078     first = node.children[0]
3079     last = node.children[-1]
3080     if first.type == token.LPAR and last.type == token.RPAR:
3081         middle = node.children[1]
3082         # make parentheses invisible
3083         first.value = ""  # type: ignore
3084         last.value = ""  # type: ignore
3085         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3086
3087         if is_atom_with_invisible_parens(middle):
3088             # Strip the invisible parens from `middle` by replacing
3089             # it with the child in-between the invisible parens
3090             middle.replace(middle.children[1])
3091
3092         return False
3093
3094     return True
3095
3096
3097 def is_atom_with_invisible_parens(node: LN) -> bool:
3098     """Given a `LN`, determines whether it's an atom `node` with invisible
3099     parens. Useful in dedupe-ing and normalizing parens.
3100     """
3101     if isinstance(node, Leaf) or node.type != syms.atom:
3102         return False
3103
3104     first, last = node.children[0], node.children[-1]
3105     return (
3106         isinstance(first, Leaf)
3107         and first.type == token.LPAR
3108         and first.value == ""
3109         and isinstance(last, Leaf)
3110         and last.type == token.RPAR
3111         and last.value == ""
3112     )
3113
3114
3115 def is_empty_tuple(node: LN) -> bool:
3116     """Return True if `node` holds an empty tuple."""
3117     return (
3118         node.type == syms.atom
3119         and len(node.children) == 2
3120         and node.children[0].type == token.LPAR
3121         and node.children[1].type == token.RPAR
3122     )
3123
3124
3125 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3126     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3127
3128     Parenthesis can be optional. Returns None otherwise"""
3129     if len(node.children) != 3:
3130         return None
3131     lpar, wrapped, rpar = node.children
3132     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3133         return None
3134
3135     return wrapped
3136
3137
3138 def is_one_tuple(node: LN) -> bool:
3139     """Return True if `node` holds a tuple with one element, with or without parens."""
3140     if node.type == syms.atom:
3141         gexp = unwrap_singleton_parenthesis(node)
3142         if gexp is None or gexp.type != syms.testlist_gexp:
3143             return False
3144
3145         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3146
3147     return (
3148         node.type in IMPLICIT_TUPLE
3149         and len(node.children) == 2
3150         and node.children[1].type == token.COMMA
3151     )
3152
3153
3154 def is_walrus_assignment(node: LN) -> bool:
3155     """Return True iff `node` is of the shape ( test := test )"""
3156     inner = unwrap_singleton_parenthesis(node)
3157     return inner is not None and inner.type == syms.namedexpr_test
3158
3159
3160 def is_yield(node: LN) -> bool:
3161     """Return True if `node` holds a `yield` or `yield from` expression."""
3162     if node.type == syms.yield_expr:
3163         return True
3164
3165     if node.type == token.NAME and node.value == "yield":  # type: ignore
3166         return True
3167
3168     if node.type != syms.atom:
3169         return False
3170
3171     if len(node.children) != 3:
3172         return False
3173
3174     lpar, expr, rpar = node.children
3175     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3176         return is_yield(expr)
3177
3178     return False
3179
3180
3181 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3182     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3183
3184     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3185     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3186     extended iterable unpacking (PEP 3132) and additional unpacking
3187     generalizations (PEP 448).
3188     """
3189     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3190         return False
3191
3192     p = leaf.parent
3193     if p.type == syms.star_expr:
3194         # Star expressions are also used as assignment targets in extended
3195         # iterable unpacking (PEP 3132).  See what its parent is instead.
3196         if not p.parent:
3197             return False
3198
3199         p = p.parent
3200
3201     return p.type in within
3202
3203
3204 def is_multiline_string(leaf: Leaf) -> bool:
3205     """Return True if `leaf` is a multiline string that actually spans many lines."""
3206     value = leaf.value.lstrip("furbFURB")
3207     return value[:3] in {'"""', "'''"} and "\n" in value
3208
3209
3210 def is_stub_suite(node: Node) -> bool:
3211     """Return True if `node` is a suite with a stub body."""
3212     if (
3213         len(node.children) != 4
3214         or node.children[0].type != token.NEWLINE
3215         or node.children[1].type != token.INDENT
3216         or node.children[3].type != token.DEDENT
3217     ):
3218         return False
3219
3220     return is_stub_body(node.children[2])
3221
3222
3223 def is_stub_body(node: LN) -> bool:
3224     """Return True if `node` is a simple statement containing an ellipsis."""
3225     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3226         return False
3227
3228     if len(node.children) != 2:
3229         return False
3230
3231     child = node.children[0]
3232     return (
3233         child.type == syms.atom
3234         and len(child.children) == 3
3235         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3236     )
3237
3238
3239 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3240     """Return maximum delimiter priority inside `node`.
3241
3242     This is specific to atoms with contents contained in a pair of parentheses.
3243     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3244     """
3245     if node.type != syms.atom:
3246         return 0
3247
3248     first = node.children[0]
3249     last = node.children[-1]
3250     if not (first.type == token.LPAR and last.type == token.RPAR):
3251         return 0
3252
3253     bt = BracketTracker()
3254     for c in node.children[1:-1]:
3255         if isinstance(c, Leaf):
3256             bt.mark(c)
3257         else:
3258             for leaf in c.leaves():
3259                 bt.mark(leaf)
3260     try:
3261         return bt.max_delimiter_priority()
3262
3263     except ValueError:
3264         return 0
3265
3266
3267 def ensure_visible(leaf: Leaf) -> None:
3268     """Make sure parentheses are visible.
3269
3270     They could be invisible as part of some statements (see
3271     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3272     """
3273     if leaf.type == token.LPAR:
3274         leaf.value = "("
3275     elif leaf.type == token.RPAR:
3276         leaf.value = ")"
3277
3278
3279 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3280     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3281
3282     if not (
3283         opening_bracket.parent
3284         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3285         and opening_bracket.value in "[{("
3286     ):
3287         return False
3288
3289     try:
3290         last_leaf = line.leaves[-1]
3291         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3292         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3293     except (IndexError, ValueError):
3294         return False
3295
3296     return max_priority == COMMA_PRIORITY
3297
3298
3299 def get_features_used(node: Node) -> Set[Feature]:
3300     """Return a set of (relatively) new Python features used in this file.
3301
3302     Currently looking for:
3303     - f-strings;
3304     - underscores in numeric literals;
3305     - trailing commas after * or ** in function signatures and calls;
3306     - positional only arguments in function signatures and lambdas;
3307     """
3308     features: Set[Feature] = set()
3309     for n in node.pre_order():
3310         if n.type == token.STRING:
3311             value_head = n.value[:2]  # type: ignore
3312             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3313                 features.add(Feature.F_STRINGS)
3314
3315         elif n.type == token.NUMBER:
3316             if "_" in n.value:  # type: ignore
3317                 features.add(Feature.NUMERIC_UNDERSCORES)
3318
3319         elif n.type == token.SLASH:
3320             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3321                 features.add(Feature.POS_ONLY_ARGUMENTS)
3322
3323         elif n.type == token.COLONEQUAL:
3324             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3325
3326         elif (
3327             n.type in {syms.typedargslist, syms.arglist}
3328             and n.children
3329             and n.children[-1].type == token.COMMA
3330         ):
3331             if n.type == syms.typedargslist:
3332                 feature = Feature.TRAILING_COMMA_IN_DEF
3333             else:
3334                 feature = Feature.TRAILING_COMMA_IN_CALL
3335
3336             for ch in n.children:
3337                 if ch.type in STARS:
3338                     features.add(feature)
3339
3340                 if ch.type == syms.argument:
3341                     for argch in ch.children:
3342                         if argch.type in STARS:
3343                             features.add(feature)
3344
3345     return features
3346
3347
3348 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3349     """Detect the version to target based on the nodes used."""
3350     features = get_features_used(node)
3351     return {
3352         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3353     }
3354
3355
3356 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3357     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3358
3359     Brackets can be omitted if the entire trailer up to and including
3360     a preceding closing bracket fits in one line.
3361
3362     Yielded sets are cumulative (contain results of previous yields, too).  First
3363     set is empty.
3364     """
3365
3366     omit: Set[LeafID] = set()
3367     yield omit
3368
3369     length = 4 * line.depth
3370     opening_bracket = None
3371     closing_bracket = None
3372     inner_brackets: Set[LeafID] = set()
3373     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3374         length += leaf_length
3375         if length > line_length:
3376             break
3377
3378         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3379         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3380             break
3381
3382         if opening_bracket:
3383             if leaf is opening_bracket:
3384                 opening_bracket = None
3385             elif leaf.type in CLOSING_BRACKETS:
3386                 inner_brackets.add(id(leaf))
3387         elif leaf.type in CLOSING_BRACKETS:
3388             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3389                 # Empty brackets would fail a split so treat them as "inner"
3390                 # brackets (e.g. only add them to the `omit` set if another
3391                 # pair of brackets was good enough.
3392                 inner_brackets.add(id(leaf))
3393                 continue
3394
3395             if closing_bracket:
3396                 omit.add(id(closing_bracket))
3397                 omit.update(inner_brackets)
3398                 inner_brackets.clear()
3399                 yield omit
3400
3401             if leaf.value:
3402                 opening_bracket = leaf.opening_bracket
3403                 closing_bracket = leaf
3404
3405
3406 def get_future_imports(node: Node) -> Set[str]:
3407     """Return a set of __future__ imports in the file."""
3408     imports: Set[str] = set()
3409
3410     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3411         for child in children:
3412             if isinstance(child, Leaf):
3413                 if child.type == token.NAME:
3414                     yield child.value
3415             elif child.type == syms.import_as_name:
3416                 orig_name = child.children[0]
3417                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3418                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3419                 yield orig_name.value
3420             elif child.type == syms.import_as_names:
3421                 yield from get_imports_from_children(child.children)
3422             else:
3423                 raise AssertionError("Invalid syntax parsing imports")
3424
3425     for child in node.children:
3426         if child.type != syms.simple_stmt:
3427             break
3428         first_child = child.children[0]
3429         if isinstance(first_child, Leaf):
3430             # Continue looking if we see a docstring; otherwise stop.
3431             if (
3432                 len(child.children) == 2
3433                 and first_child.type == token.STRING
3434                 and child.children[1].type == token.NEWLINE
3435             ):
3436                 continue
3437             else:
3438                 break
3439         elif first_child.type == syms.import_from:
3440             module_name = first_child.children[1]
3441             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3442                 break
3443             imports |= set(get_imports_from_children(first_child.children[3:]))
3444         else:
3445             break
3446     return imports
3447
3448
3449 def gen_python_files_in_dir(
3450     path: Path,
3451     root: Path,
3452     include: Pattern[str],
3453     exclude: Pattern[str],
3454     report: "Report",
3455 ) -> Iterator[Path]:
3456     """Generate all files under `path` whose paths are not excluded by the
3457     `exclude` regex, but are included by the `include` regex.
3458
3459     Symbolic links pointing outside of the `root` directory are ignored.
3460
3461     `report` is where output about exclusions goes.
3462     """
3463     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3464     for child in path.iterdir():
3465         try:
3466             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3467         except ValueError:
3468             if child.is_symlink():
3469                 report.path_ignored(
3470                     child, f"is a symbolic link that points outside {root}"
3471                 )
3472                 continue
3473
3474             raise
3475
3476         if child.is_dir():
3477             normalized_path += "/"
3478         exclude_match = exclude.search(normalized_path)
3479         if exclude_match and exclude_match.group(0):
3480             report.path_ignored(child, f"matches the --exclude regular expression")
3481             continue
3482
3483         if child.is_dir():
3484             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3485
3486         elif child.is_file():
3487             include_match = include.search(normalized_path)
3488             if include_match:
3489                 yield child
3490
3491
3492 @lru_cache()
3493 def find_project_root(srcs: Iterable[str]) -> Path:
3494     """Return a directory containing .git, .hg, or pyproject.toml.
3495
3496     That directory can be one of the directories passed in `srcs` or their
3497     common parent.
3498
3499     If no directory in the tree contains a marker that would specify it's the
3500     project root, the root of the file system is returned.
3501     """
3502     if not srcs:
3503         return Path("/").resolve()
3504
3505     common_base = min(Path(src).resolve() for src in srcs)
3506     if common_base.is_dir():
3507         # Append a fake file so `parents` below returns `common_base_dir`, too.
3508         common_base /= "fake-file"
3509     for directory in common_base.parents:
3510         if (directory / ".git").is_dir():
3511             return directory
3512
3513         if (directory / ".hg").is_dir():
3514             return directory
3515
3516         if (directory / "pyproject.toml").is_file():
3517             return directory
3518
3519     return directory
3520
3521
3522 @dataclass
3523 class Report:
3524     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3525
3526     check: bool = False
3527     quiet: bool = False
3528     verbose: bool = False
3529     change_count: int = 0
3530     same_count: int = 0
3531     failure_count: int = 0
3532
3533     def done(self, src: Path, changed: Changed) -> None:
3534         """Increment the counter for successful reformatting. Write out a message."""
3535         if changed is Changed.YES:
3536             reformatted = "would reformat" if self.check else "reformatted"
3537             if self.verbose or not self.quiet:
3538                 out(f"{reformatted} {src}")
3539             self.change_count += 1
3540         else:
3541             if self.verbose:
3542                 if changed is Changed.NO:
3543                     msg = f"{src} already well formatted, good job."
3544                 else:
3545                     msg = f"{src} wasn't modified on disk since last run."
3546                 out(msg, bold=False)
3547             self.same_count += 1
3548
3549     def failed(self, src: Path, message: str) -> None:
3550         """Increment the counter for failed reformatting. Write out a message."""
3551         err(f"error: cannot format {src}: {message}")
3552         self.failure_count += 1
3553
3554     def path_ignored(self, path: Path, message: str) -> None:
3555         if self.verbose:
3556             out(f"{path} ignored: {message}", bold=False)
3557
3558     @property
3559     def return_code(self) -> int:
3560         """Return the exit code that the app should use.
3561
3562         This considers the current state of changed files and failures:
3563         - if there were any failures, return 123;
3564         - if any files were changed and --check is being used, return 1;
3565         - otherwise return 0.
3566         """
3567         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3568         # 126 we have special return codes reserved by the shell.
3569         if self.failure_count:
3570             return 123
3571
3572         elif self.change_count and self.check:
3573             return 1
3574
3575         return 0
3576
3577     def __str__(self) -> str:
3578         """Render a color report of the current state.
3579
3580         Use `click.unstyle` to remove colors.
3581         """
3582         if self.check:
3583             reformatted = "would be reformatted"
3584             unchanged = "would be left unchanged"
3585             failed = "would fail to reformat"
3586         else:
3587             reformatted = "reformatted"
3588             unchanged = "left unchanged"
3589             failed = "failed to reformat"
3590         report = []
3591         if self.change_count:
3592             s = "s" if self.change_count > 1 else ""
3593             report.append(
3594                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3595             )
3596         if self.same_count:
3597             s = "s" if self.same_count > 1 else ""
3598             report.append(f"{self.same_count} file{s} {unchanged}")
3599         if self.failure_count:
3600             s = "s" if self.failure_count > 1 else ""
3601             report.append(
3602                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3603             )
3604         return ", ".join(report) + "."
3605
3606
3607 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3608     filename = "<unknown>"
3609     if sys.version_info >= (3, 8):
3610         # TODO: support Python 4+ ;)
3611         for minor_version in range(sys.version_info[1], 4, -1):
3612             try:
3613                 return ast.parse(src, filename, feature_version=(3, minor_version))
3614             except SyntaxError:
3615                 continue
3616     else:
3617         for feature_version in (7, 6):
3618             try:
3619                 return ast3.parse(src, filename, feature_version=feature_version)
3620             except SyntaxError:
3621                 continue
3622
3623     return ast27.parse(src)
3624
3625
3626 def _fixup_ast_constants(
3627     node: Union[ast.AST, ast3.AST, ast27.AST]
3628 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3629     """Map ast nodes deprecated in 3.8 to Constant."""
3630     # casts are required until this is released:
3631     # https://github.com/python/typeshed/pull/3142
3632     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3633         return cast(ast.AST, ast.Constant(value=node.s))
3634     elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3635         return cast(ast.AST, ast.Constant(value=node.n))
3636     elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3637         return cast(ast.AST, ast.Constant(value=node.value))
3638     return node
3639
3640
3641 def assert_equivalent(src: str, dst: str) -> None:
3642     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3643
3644     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3645         """Simple visitor generating strings to compare ASTs by content."""
3646
3647         node = _fixup_ast_constants(node)
3648
3649         yield f"{'  ' * depth}{node.__class__.__name__}("
3650
3651         for field in sorted(node._fields):
3652             # TypeIgnore has only one field 'lineno' which breaks this comparison
3653             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3654             if sys.version_info >= (3, 8):
3655                 type_ignore_classes += (ast.TypeIgnore,)
3656             if isinstance(node, type_ignore_classes):
3657                 break
3658
3659             try:
3660                 value = getattr(node, field)
3661             except AttributeError:
3662                 continue
3663
3664             yield f"{'  ' * (depth+1)}{field}="
3665
3666             if isinstance(value, list):
3667                 for item in value:
3668                     # Ignore nested tuples within del statements, because we may insert
3669                     # parentheses and they change the AST.
3670                     if (
3671                         field == "targets"
3672                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3673                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3674                     ):
3675                         for item in item.elts:
3676                             yield from _v(item, depth + 2)
3677                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3678                         yield from _v(item, depth + 2)
3679
3680             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3681                 yield from _v(value, depth + 2)
3682
3683             else:
3684                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3685
3686         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3687
3688     try:
3689         src_ast = parse_ast(src)
3690     except Exception as exc:
3691         raise AssertionError(
3692             f"cannot use --safe with this file; failed to parse source file.  "
3693             f"AST error message: {exc}"
3694         )
3695
3696     try:
3697         dst_ast = parse_ast(dst)
3698     except Exception as exc:
3699         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3700         raise AssertionError(
3701             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3702             f"Please report a bug on https://github.com/psf/black/issues.  "
3703             f"This invalid output might be helpful: {log}"
3704         ) from None
3705
3706     src_ast_str = "\n".join(_v(src_ast))
3707     dst_ast_str = "\n".join(_v(dst_ast))
3708     if src_ast_str != dst_ast_str:
3709         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3710         raise AssertionError(
3711             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3712             f"the source.  "
3713             f"Please report a bug on https://github.com/psf/black/issues.  "
3714             f"This diff might be helpful: {log}"
3715         ) from None
3716
3717
3718 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3719     """Raise AssertionError if `dst` reformats differently the second time."""
3720     newdst = format_str(dst, mode=mode)
3721     if dst != newdst:
3722         log = dump_to_file(
3723             diff(src, dst, "source", "first pass"),
3724             diff(dst, newdst, "first pass", "second pass"),
3725         )
3726         raise AssertionError(
3727             f"INTERNAL ERROR: Black produced different code on the second pass "
3728             f"of the formatter.  "
3729             f"Please report a bug on https://github.com/psf/black/issues.  "
3730             f"This diff might be helpful: {log}"
3731         ) from None
3732
3733
3734 def dump_to_file(*output: str) -> str:
3735     """Dump `output` to a temporary file. Return path to the file."""
3736     with tempfile.NamedTemporaryFile(
3737         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3738     ) as f:
3739         for lines in output:
3740             f.write(lines)
3741             if lines and lines[-1] != "\n":
3742                 f.write("\n")
3743     return f.name
3744
3745
3746 @contextmanager
3747 def nullcontext() -> Iterator[None]:
3748     """Return context manager that does nothing.
3749     Similar to `nullcontext` from python 3.7"""
3750     yield
3751
3752
3753 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3754     """Return a unified diff string between strings `a` and `b`."""
3755     import difflib
3756
3757     a_lines = [line + "\n" for line in a.split("\n")]
3758     b_lines = [line + "\n" for line in b.split("\n")]
3759     return "".join(
3760         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3761     )
3762
3763
3764 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3765     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3766     err("Aborted!")
3767     for task in tasks:
3768         task.cancel()
3769
3770
3771 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3772     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3773     try:
3774         if sys.version_info[:2] >= (3, 7):
3775             all_tasks = asyncio.all_tasks
3776         else:
3777             all_tasks = asyncio.Task.all_tasks
3778         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3779         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3780         if not to_cancel:
3781             return
3782
3783         for task in to_cancel:
3784             task.cancel()
3785         loop.run_until_complete(
3786             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3787         )
3788     finally:
3789         # `concurrent.futures.Future` objects cannot be cancelled once they
3790         # are already running. There might be some when the `shutdown()` happened.
3791         # Silence their logger's spew about the event loop being closed.
3792         cf_logger = logging.getLogger("concurrent.futures")
3793         cf_logger.setLevel(logging.CRITICAL)
3794         loop.close()
3795
3796
3797 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3798     """Replace `regex` with `replacement` twice on `original`.
3799
3800     This is used by string normalization to perform replaces on
3801     overlapping matches.
3802     """
3803     return regex.sub(replacement, regex.sub(replacement, original))
3804
3805
3806 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3807     """Compile a regular expression string in `regex`.
3808
3809     If it contains newlines, use verbose mode.
3810     """
3811     if "\n" in regex:
3812         regex = "(?x)" + regex
3813     compiled: Pattern[str] = re.compile(regex)
3814     return compiled
3815
3816
3817 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3818     """Like `reversed(enumerate(sequence))` if that were possible."""
3819     index = len(sequence) - 1
3820     for element in reversed(sequence):
3821         yield (index, element)
3822         index -= 1
3823
3824
3825 def enumerate_with_length(
3826     line: Line, reversed: bool = False
3827 ) -> Iterator[Tuple[Index, Leaf, int]]:
3828     """Return an enumeration of leaves with their length.
3829
3830     Stops prematurely on multiline strings and standalone comments.
3831     """
3832     op = cast(
3833         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3834         enumerate_reversed if reversed else enumerate,
3835     )
3836     for index, leaf in op(line.leaves):
3837         length = len(leaf.prefix) + len(leaf.value)
3838         if "\n" in leaf.value:
3839             return  # Multiline strings, we can't continue.
3840
3841         for comment in line.comments_after(leaf):
3842             length += len(comment.value)
3843
3844         yield index, leaf, length
3845
3846
3847 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3848     """Return True if `line` is no longer than `line_length`.
3849
3850     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3851     """
3852     if not line_str:
3853         line_str = str(line).strip("\n")
3854     return (
3855         len(line_str) <= line_length
3856         and "\n" not in line_str  # multiline strings
3857         and not line.contains_standalone_comments()
3858     )
3859
3860
3861 def can_be_split(line: Line) -> bool:
3862     """Return False if the line cannot be split *for sure*.
3863
3864     This is not an exhaustive search but a cheap heuristic that we can use to
3865     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3866     in unnecessary parentheses).
3867     """
3868     leaves = line.leaves
3869     if len(leaves) < 2:
3870         return False
3871
3872     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3873         call_count = 0
3874         dot_count = 0
3875         next = leaves[-1]
3876         for leaf in leaves[-2::-1]:
3877             if leaf.type in OPENING_BRACKETS:
3878                 if next.type not in CLOSING_BRACKETS:
3879                     return False
3880
3881                 call_count += 1
3882             elif leaf.type == token.DOT:
3883                 dot_count += 1
3884             elif leaf.type == token.NAME:
3885                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3886                     return False
3887
3888             elif leaf.type not in CLOSING_BRACKETS:
3889                 return False
3890
3891             if dot_count > 1 and call_count > 1:
3892                 return False
3893
3894     return True
3895
3896
3897 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3898     """Does `line` have a shape safe to reformat without optional parens around it?
3899
3900     Returns True for only a subset of potentially nice looking formattings but
3901     the point is to not return false positives that end up producing lines that
3902     are too long.
3903     """
3904     bt = line.bracket_tracker
3905     if not bt.delimiters:
3906         # Without delimiters the optional parentheses are useless.
3907         return True
3908
3909     max_priority = bt.max_delimiter_priority()
3910     if bt.delimiter_count_with_priority(max_priority) > 1:
3911         # With more than one delimiter of a kind the optional parentheses read better.
3912         return False
3913
3914     if max_priority == DOT_PRIORITY:
3915         # A single stranded method call doesn't require optional parentheses.
3916         return True
3917
3918     assert len(line.leaves) >= 2, "Stranded delimiter"
3919
3920     first = line.leaves[0]
3921     second = line.leaves[1]
3922     penultimate = line.leaves[-2]
3923     last = line.leaves[-1]
3924
3925     # With a single delimiter, omit if the expression starts or ends with
3926     # a bracket.
3927     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3928         remainder = False
3929         length = 4 * line.depth
3930         for _index, leaf, leaf_length in enumerate_with_length(line):
3931             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3932                 remainder = True
3933             if remainder:
3934                 length += leaf_length
3935                 if length > line_length:
3936                     break
3937
3938                 if leaf.type in OPENING_BRACKETS:
3939                     # There are brackets we can further split on.
3940                     remainder = False
3941
3942         else:
3943             # checked the entire string and line length wasn't exceeded
3944             if len(line.leaves) == _index + 1:
3945                 return True
3946
3947         # Note: we are not returning False here because a line might have *both*
3948         # a leading opening bracket and a trailing closing bracket.  If the
3949         # opening bracket doesn't match our rule, maybe the closing will.
3950
3951     if (
3952         last.type == token.RPAR
3953         or last.type == token.RBRACE
3954         or (
3955             # don't use indexing for omitting optional parentheses;
3956             # it looks weird
3957             last.type == token.RSQB
3958             and last.parent
3959             and last.parent.type != syms.trailer
3960         )
3961     ):
3962         if penultimate.type in OPENING_BRACKETS:
3963             # Empty brackets don't help.
3964             return False
3965
3966         if is_multiline_string(first):
3967             # Additional wrapping of a multiline string in this situation is
3968             # unnecessary.
3969             return True
3970
3971         length = 4 * line.depth
3972         seen_other_brackets = False
3973         for _index, leaf, leaf_length in enumerate_with_length(line):
3974             length += leaf_length
3975             if leaf is last.opening_bracket:
3976                 if seen_other_brackets or length <= line_length:
3977                     return True
3978
3979             elif leaf.type in OPENING_BRACKETS:
3980                 # There are brackets we can further split on.
3981                 seen_other_brackets = True
3982
3983     return False
3984
3985
3986 def get_cache_file(mode: FileMode) -> Path:
3987     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3988
3989
3990 def read_cache(mode: FileMode) -> Cache:
3991     """Read the cache if it exists and is well formed.
3992
3993     If it is not well formed, the call to write_cache later should resolve the issue.
3994     """
3995     cache_file = get_cache_file(mode)
3996     if not cache_file.exists():
3997         return {}
3998
3999     with cache_file.open("rb") as fobj:
4000         try:
4001             cache: Cache = pickle.load(fobj)
4002         except pickle.UnpicklingError:
4003             return {}
4004
4005     return cache
4006
4007
4008 def get_cache_info(path: Path) -> CacheInfo:
4009     """Return the information used to check if a file is already formatted or not."""
4010     stat = path.stat()
4011     return stat.st_mtime, stat.st_size
4012
4013
4014 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4015     """Split an iterable of paths in `sources` into two sets.
4016
4017     The first contains paths of files that modified on disk or are not in the
4018     cache. The other contains paths to non-modified files.
4019     """
4020     todo, done = set(), set()
4021     for src in sources:
4022         src = src.resolve()
4023         if cache.get(src) != get_cache_info(src):
4024             todo.add(src)
4025         else:
4026             done.add(src)
4027     return todo, done
4028
4029
4030 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4031     """Update the cache file."""
4032     cache_file = get_cache_file(mode)
4033     try:
4034         CACHE_DIR.mkdir(parents=True, exist_ok=True)
4035         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4036         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4037             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
4038         os.replace(f.name, cache_file)
4039     except OSError:
4040         pass
4041
4042
4043 def patch_click() -> None:
4044     """Make Click not crash.
4045
4046     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4047     default which restricts paths that it can access during the lifetime of the
4048     application.  Click refuses to work in this scenario by raising a RuntimeError.
4049
4050     In case of Black the likelihood that non-ASCII characters are going to be used in
4051     file paths is minimal since it's Python source code.  Moreover, this crash was
4052     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4053     """
4054     try:
4055         from click import core
4056         from click import _unicodefun  # type: ignore
4057     except ModuleNotFoundError:
4058         return
4059
4060     for module in (core, _unicodefun):
4061         if hasattr(module, "_verify_python3_env"):
4062             module._verify_python3_env = lambda: None
4063
4064
4065 def patched_main() -> None:
4066     freeze_support()
4067     patch_click()
4068     main()
4069
4070
4071 if __name__ == "__main__":
4072     patched_main()