]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

f283ffc99b3a50d1e6b225c0a93714c67247fd35
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40
41 from appdirs import user_cache_dir
42 from attr import dataclass, evolve, Factory
43 import click
44 import toml
45 from typed_ast import ast3, ast27
46
47 # lib2to3 fork
48 from blib2to3.pytree import Node, Leaf, type_repr
49 from blib2to3 import pygram, pytree
50 from blib2to3.pgen2 import driver, token
51 from blib2to3.pgen2.grammar import Grammar
52 from blib2to3.pgen2.parse import ParseError
53
54 from _version import version as __version__
55
56 DEFAULT_LINE_LENGTH = 88
57 DEFAULT_EXCLUDES = (
58     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
59 )
60 DEFAULT_INCLUDES = r"\.pyi?$"
61 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
62
63
64 # types
65 FileContent = str
66 Encoding = str
67 NewLine = str
68 Depth = int
69 NodeType = int
70 LeafID = int
71 Priority = int
72 Index = int
73 LN = Union[Leaf, Node]
74 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
75 Timestamp = float
76 FileSize = int
77 CacheInfo = Tuple[Timestamp, FileSize]
78 Cache = Dict[Path, CacheInfo]
79 out = partial(click.secho, bold=True, err=True)
80 err = partial(click.secho, fg="red", err=True)
81
82 pygram.initialize(CACHE_DIR)
83 syms = pygram.python_symbols
84
85
86 class NothingChanged(UserWarning):
87     """Raised when reformatted code is the same as source."""
88
89
90 class CannotSplit(Exception):
91     """A readable split that fits the allotted line length is impossible."""
92
93
94 class InvalidInput(ValueError):
95     """Raised when input source code fails all parse attempts."""
96
97
98 class WriteBack(Enum):
99     NO = 0
100     YES = 1
101     DIFF = 2
102     CHECK = 3
103
104     @classmethod
105     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
106         if check and not diff:
107             return cls.CHECK
108
109         return cls.DIFF if diff else cls.YES
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 class TargetVersion(Enum):
119     PY27 = 2
120     PY33 = 3
121     PY34 = 4
122     PY35 = 5
123     PY36 = 6
124     PY37 = 7
125     PY38 = 8
126
127     def is_python2(self) -> bool:
128         return self is TargetVersion.PY27
129
130
131 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
132
133
134 class Feature(Enum):
135     # All string literals are unicode
136     UNICODE_LITERALS = 1
137     F_STRINGS = 2
138     NUMERIC_UNDERSCORES = 3
139     TRAILING_COMMA_IN_CALL = 4
140     TRAILING_COMMA_IN_DEF = 5
141     # The following two feature-flags are mutually exclusive, and exactly one should be
142     # set for every version of python.
143     ASYNC_IDENTIFIERS = 6
144     ASYNC_KEYWORDS = 7
145     ASSIGNMENT_EXPRESSIONS = 8
146     POS_ONLY_ARGUMENTS = 9
147
148
149 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
150     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
151     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
152     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
153     TargetVersion.PY35: {
154         Feature.UNICODE_LITERALS,
155         Feature.TRAILING_COMMA_IN_CALL,
156         Feature.ASYNC_IDENTIFIERS,
157     },
158     TargetVersion.PY36: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA_IN_CALL,
163         Feature.TRAILING_COMMA_IN_DEF,
164         Feature.ASYNC_IDENTIFIERS,
165     },
166     TargetVersion.PY37: {
167         Feature.UNICODE_LITERALS,
168         Feature.F_STRINGS,
169         Feature.NUMERIC_UNDERSCORES,
170         Feature.TRAILING_COMMA_IN_CALL,
171         Feature.TRAILING_COMMA_IN_DEF,
172         Feature.ASYNC_KEYWORDS,
173     },
174     TargetVersion.PY38: {
175         Feature.UNICODE_LITERALS,
176         Feature.F_STRINGS,
177         Feature.NUMERIC_UNDERSCORES,
178         Feature.TRAILING_COMMA_IN_CALL,
179         Feature.TRAILING_COMMA_IN_DEF,
180         Feature.ASYNC_KEYWORDS,
181         Feature.ASSIGNMENT_EXPRESSIONS,
182         Feature.POS_ONLY_ARGUMENTS,
183     },
184 }
185
186
187 @dataclass
188 class FileMode:
189     target_versions: Set[TargetVersion] = Factory(set)
190     line_length: int = DEFAULT_LINE_LENGTH
191     string_normalization: bool = True
192     is_pyi: bool = False
193
194     def get_cache_key(self) -> str:
195         if self.target_versions:
196             version_str = ",".join(
197                 str(version.value)
198                 for version in sorted(self.target_versions, key=lambda v: v.value)
199             )
200         else:
201             version_str = "-"
202         parts = [
203             version_str,
204             str(self.line_length),
205             str(int(self.string_normalization)),
206             str(int(self.is_pyi)),
207         ]
208         return ".".join(parts)
209
210
211 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
212     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
213
214
215 def read_pyproject_toml(
216     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
217 ) -> Optional[str]:
218     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
219
220     Returns the path to a successfully found and read configuration file, None
221     otherwise.
222     """
223     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
224     if not value:
225         root = find_project_root(ctx.params.get("src", ()))
226         path = root / "pyproject.toml"
227         if path.is_file():
228             value = str(path)
229         else:
230             return None
231
232     try:
233         pyproject_toml = toml.load(value)
234         config = pyproject_toml.get("tool", {}).get("black", {})
235     except (toml.TomlDecodeError, OSError) as e:
236         raise click.FileError(
237             filename=value, hint=f"Error reading configuration file: {e}"
238         )
239
240     if not config:
241         return None
242
243     if ctx.default_map is None:
244         ctx.default_map = {}
245     ctx.default_map.update(  # type: ignore  # bad types in .pyi
246         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
247     )
248     return value
249
250
251 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
252 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
253 @click.option(
254     "-l",
255     "--line-length",
256     type=int,
257     default=DEFAULT_LINE_LENGTH,
258     help="How many characters per line to allow.",
259     show_default=True,
260 )
261 @click.option(
262     "-t",
263     "--target-version",
264     type=click.Choice([v.name.lower() for v in TargetVersion]),
265     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
266     multiple=True,
267     help=(
268         "Python versions that should be supported by Black's output. [default: "
269         "per-file auto-detection]"
270     ),
271 )
272 @click.option(
273     "--py36",
274     is_flag=True,
275     help=(
276         "Allow using Python 3.6-only syntax on all input files.  This will put "
277         "trailing commas in function signatures and calls also after *args and "
278         "**kwargs. Deprecated; use --target-version instead. "
279         "[default: per-file auto-detection]"
280     ),
281 )
282 @click.option(
283     "--pyi",
284     is_flag=True,
285     help=(
286         "Format all input files like typing stubs regardless of file extension "
287         "(useful when piping source on standard input)."
288     ),
289 )
290 @click.option(
291     "-S",
292     "--skip-string-normalization",
293     is_flag=True,
294     help="Don't normalize string quotes or prefixes.",
295 )
296 @click.option(
297     "--check",
298     is_flag=True,
299     help=(
300         "Don't write the files back, just return the status.  Return code 0 "
301         "means nothing would change.  Return code 1 means some files would be "
302         "reformatted.  Return code 123 means there was an internal error."
303     ),
304 )
305 @click.option(
306     "--diff",
307     is_flag=True,
308     help="Don't write the files back, just output a diff for each file on stdout.",
309 )
310 @click.option(
311     "--fast/--safe",
312     is_flag=True,
313     help="If --fast given, skip temporary sanity checks. [default: --safe]",
314 )
315 @click.option(
316     "--include",
317     type=str,
318     default=DEFAULT_INCLUDES,
319     help=(
320         "A regular expression that matches files and directories that should be "
321         "included on recursive searches.  An empty value means all files are "
322         "included regardless of the name.  Use forward slashes for directories on "
323         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
324         "later."
325     ),
326     show_default=True,
327 )
328 @click.option(
329     "--exclude",
330     type=str,
331     default=DEFAULT_EXCLUDES,
332     help=(
333         "A regular expression that matches files and directories that should be "
334         "excluded on recursive searches.  An empty value means no paths are excluded. "
335         "Use forward slashes for directories on all platforms (Windows, too).  "
336         "Exclusions are calculated first, inclusions later."
337     ),
338     show_default=True,
339 )
340 @click.option(
341     "-q",
342     "--quiet",
343     is_flag=True,
344     help=(
345         "Don't emit non-error messages to stderr. Errors are still emitted; "
346         "silence those with 2>/dev/null."
347     ),
348 )
349 @click.option(
350     "-v",
351     "--verbose",
352     is_flag=True,
353     help=(
354         "Also emit messages to stderr about files that were not changed or were "
355         "ignored due to --exclude=."
356     ),
357 )
358 @click.version_option(version=__version__)
359 @click.argument(
360     "src",
361     nargs=-1,
362     type=click.Path(
363         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
364     ),
365     is_eager=True,
366 )
367 @click.option(
368     "--config",
369     type=click.Path(
370         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
371     ),
372     is_eager=True,
373     callback=read_pyproject_toml,
374     help="Read configuration from PATH.",
375 )
376 @click.pass_context
377 def main(
378     ctx: click.Context,
379     code: Optional[str],
380     line_length: int,
381     target_version: List[TargetVersion],
382     check: bool,
383     diff: bool,
384     fast: bool,
385     pyi: bool,
386     py36: bool,
387     skip_string_normalization: bool,
388     quiet: bool,
389     verbose: bool,
390     include: str,
391     exclude: str,
392     src: Tuple[str],
393     config: Optional[str],
394 ) -> None:
395     """The uncompromising code formatter."""
396     write_back = WriteBack.from_configuration(check=check, diff=diff)
397     if target_version:
398         if py36:
399             err(f"Cannot use both --target-version and --py36")
400             ctx.exit(2)
401         else:
402             versions = set(target_version)
403     elif py36:
404         err(
405             "--py36 is deprecated and will be removed in a future version. "
406             "Use --target-version py36 instead."
407         )
408         versions = PY36_VERSIONS
409     else:
410         # We'll autodetect later.
411         versions = set()
412     mode = FileMode(
413         target_versions=versions,
414         line_length=line_length,
415         is_pyi=pyi,
416         string_normalization=not skip_string_normalization,
417     )
418     if config and verbose:
419         out(f"Using configuration from {config}.", bold=False, fg="blue")
420     if code is not None:
421         print(format_str(code, mode=mode))
422         ctx.exit(0)
423     try:
424         include_regex = re_compile_maybe_verbose(include)
425     except re.error:
426         err(f"Invalid regular expression for include given: {include!r}")
427         ctx.exit(2)
428     try:
429         exclude_regex = re_compile_maybe_verbose(exclude)
430     except re.error:
431         err(f"Invalid regular expression for exclude given: {exclude!r}")
432         ctx.exit(2)
433     report = Report(check=check, quiet=quiet, verbose=verbose)
434     root = find_project_root(src)
435     sources: Set[Path] = set()
436     path_empty(src, quiet, verbose, ctx)
437     for s in src:
438         p = Path(s)
439         if p.is_dir():
440             sources.update(
441                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
442             )
443         elif p.is_file() or s == "-":
444             # if a file was explicitly given, we don't care about its extension
445             sources.add(p)
446         else:
447             err(f"invalid path: {s}")
448     if len(sources) == 0:
449         if verbose or not quiet:
450             out("No Python files are present to be formatted. Nothing to do 😴")
451         ctx.exit(0)
452
453     if len(sources) == 1:
454         reformat_one(
455             src=sources.pop(),
456             fast=fast,
457             write_back=write_back,
458             mode=mode,
459             report=report,
460         )
461     else:
462         reformat_many(
463             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
464         )
465
466     if verbose or not quiet:
467         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
468         click.secho(str(report), err=True)
469     ctx.exit(report.return_code)
470
471
472 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
473     """
474     Exit if there is no `src` provided for formatting
475     """
476     if not src:
477         if verbose or not quiet:
478             out("No Path provided. Nothing to do 😴")
479             ctx.exit(0)
480
481
482 def reformat_one(
483     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
484 ) -> None:
485     """Reformat a single file under `src` without spawning child processes.
486
487     `fast`, `write_back`, and `mode` options are passed to
488     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
489     """
490     try:
491         changed = Changed.NO
492         if not src.is_file() and str(src) == "-":
493             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
494                 changed = Changed.YES
495         else:
496             cache: Cache = {}
497             if write_back != WriteBack.DIFF:
498                 cache = read_cache(mode)
499                 res_src = src.resolve()
500                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
501                     changed = Changed.CACHED
502             if changed is not Changed.CACHED and format_file_in_place(
503                 src, fast=fast, write_back=write_back, mode=mode
504             ):
505                 changed = Changed.YES
506             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
507                 write_back is WriteBack.CHECK and changed is Changed.NO
508             ):
509                 write_cache(cache, [src], mode)
510         report.done(src, changed)
511     except Exception as exc:
512         report.failed(src, str(exc))
513
514
515 def reformat_many(
516     sources: Set[Path],
517     fast: bool,
518     write_back: WriteBack,
519     mode: FileMode,
520     report: "Report",
521 ) -> None:
522     """Reformat multiple files using a ProcessPoolExecutor."""
523     loop = asyncio.get_event_loop()
524     worker_count = os.cpu_count()
525     if sys.platform == "win32":
526         # Work around https://bugs.python.org/issue26903
527         worker_count = min(worker_count, 61)
528     executor = ProcessPoolExecutor(max_workers=worker_count)
529     try:
530         loop.run_until_complete(
531             schedule_formatting(
532                 sources=sources,
533                 fast=fast,
534                 write_back=write_back,
535                 mode=mode,
536                 report=report,
537                 loop=loop,
538                 executor=executor,
539             )
540         )
541     finally:
542         shutdown(loop)
543         executor.shutdown()
544
545
546 async def schedule_formatting(
547     sources: Set[Path],
548     fast: bool,
549     write_back: WriteBack,
550     mode: FileMode,
551     report: "Report",
552     loop: asyncio.AbstractEventLoop,
553     executor: Executor,
554 ) -> None:
555     """Run formatting of `sources` in parallel using the provided `executor`.
556
557     (Use ProcessPoolExecutors for actual parallelism.)
558
559     `write_back`, `fast`, and `mode` options are passed to
560     :func:`format_file_in_place`.
561     """
562     cache: Cache = {}
563     if write_back != WriteBack.DIFF:
564         cache = read_cache(mode)
565         sources, cached = filter_cached(cache, sources)
566         for src in sorted(cached):
567             report.done(src, Changed.CACHED)
568     if not sources:
569         return
570
571     cancelled = []
572     sources_to_cache = []
573     lock = None
574     if write_back == WriteBack.DIFF:
575         # For diff output, we need locks to ensure we don't interleave output
576         # from different processes.
577         manager = Manager()
578         lock = manager.Lock()
579     tasks = {
580         asyncio.ensure_future(
581             loop.run_in_executor(
582                 executor, format_file_in_place, src, fast, mode, write_back, lock
583             )
584         ): src
585         for src in sorted(sources)
586     }
587     pending: Iterable[asyncio.Future] = tasks.keys()
588     try:
589         loop.add_signal_handler(signal.SIGINT, cancel, pending)
590         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
591     except NotImplementedError:
592         # There are no good alternatives for these on Windows.
593         pass
594     while pending:
595         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
596         for task in done:
597             src = tasks.pop(task)
598             if task.cancelled():
599                 cancelled.append(task)
600             elif task.exception():
601                 report.failed(src, str(task.exception()))
602             else:
603                 changed = Changed.YES if task.result() else Changed.NO
604                 # If the file was written back or was successfully checked as
605                 # well-formatted, store this information in the cache.
606                 if write_back is WriteBack.YES or (
607                     write_back is WriteBack.CHECK and changed is Changed.NO
608                 ):
609                     sources_to_cache.append(src)
610                 report.done(src, changed)
611     if cancelled:
612         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
613     if sources_to_cache:
614         write_cache(cache, sources_to_cache, mode)
615
616
617 def format_file_in_place(
618     src: Path,
619     fast: bool,
620     mode: FileMode,
621     write_back: WriteBack = WriteBack.NO,
622     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
623 ) -> bool:
624     """Format file under `src` path. Return True if changed.
625
626     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
627     code to the file.
628     `mode` and `fast` options are passed to :func:`format_file_contents`.
629     """
630     if src.suffix == ".pyi":
631         mode = evolve(mode, is_pyi=True)
632
633     then = datetime.utcfromtimestamp(src.stat().st_mtime)
634     with open(src, "rb") as buf:
635         src_contents, encoding, newline = decode_bytes(buf.read())
636     try:
637         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
638     except NothingChanged:
639         return False
640
641     if write_back == write_back.YES:
642         with open(src, "w", encoding=encoding, newline=newline) as f:
643             f.write(dst_contents)
644     elif write_back == write_back.DIFF:
645         now = datetime.utcnow()
646         src_name = f"{src}\t{then} +0000"
647         dst_name = f"{src}\t{now} +0000"
648         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
649
650         with lock or nullcontext():
651             f = io.TextIOWrapper(
652                 sys.stdout.buffer,
653                 encoding=encoding,
654                 newline=newline,
655                 write_through=True,
656             )
657             f.write(diff_contents)
658             f.detach()
659
660     return True
661
662
663 def format_stdin_to_stdout(
664     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
665 ) -> bool:
666     """Format file on stdin. Return True if changed.
667
668     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
669     write a diff to stdout. The `mode` argument is passed to
670     :func:`format_file_contents`.
671     """
672     then = datetime.utcnow()
673     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
674     dst = src
675     try:
676         dst = format_file_contents(src, fast=fast, mode=mode)
677         return True
678
679     except NothingChanged:
680         return False
681
682     finally:
683         f = io.TextIOWrapper(
684             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
685         )
686         if write_back == WriteBack.YES:
687             f.write(dst)
688         elif write_back == WriteBack.DIFF:
689             now = datetime.utcnow()
690             src_name = f"STDIN\t{then} +0000"
691             dst_name = f"STDOUT\t{now} +0000"
692             f.write(diff(src, dst, src_name, dst_name))
693         f.detach()
694
695
696 def format_file_contents(
697     src_contents: str, *, fast: bool, mode: FileMode
698 ) -> FileContent:
699     """Reformat contents a file and return new contents.
700
701     If `fast` is False, additionally confirm that the reformatted code is
702     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
703     `mode` is passed to :func:`format_str`.
704     """
705     if src_contents.strip() == "":
706         raise NothingChanged
707
708     dst_contents = format_str(src_contents, mode=mode)
709     if src_contents == dst_contents:
710         raise NothingChanged
711
712     if not fast:
713         assert_equivalent(src_contents, dst_contents)
714         assert_stable(src_contents, dst_contents, mode=mode)
715     return dst_contents
716
717
718 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
719     """Reformat a string and return new contents.
720
721     `mode` determines formatting options, such as how many characters per line are
722     allowed.
723     """
724     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
725     dst_contents = []
726     future_imports = get_future_imports(src_node)
727     if mode.target_versions:
728         versions = mode.target_versions
729     else:
730         versions = detect_target_versions(src_node)
731     normalize_fmt_off(src_node)
732     lines = LineGenerator(
733         remove_u_prefix="unicode_literals" in future_imports
734         or supports_feature(versions, Feature.UNICODE_LITERALS),
735         is_pyi=mode.is_pyi,
736         normalize_strings=mode.string_normalization,
737     )
738     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
739     empty_line = Line()
740     after = 0
741     split_line_features = {
742         feature
743         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
744         if supports_feature(versions, feature)
745     }
746     for current_line in lines.visit(src_node):
747         for _ in range(after):
748             dst_contents.append(str(empty_line))
749         before, after = elt.maybe_empty_lines(current_line)
750         for _ in range(before):
751             dst_contents.append(str(empty_line))
752         for line in split_line(
753             current_line, line_length=mode.line_length, features=split_line_features
754         ):
755             dst_contents.append(str(line))
756     return "".join(dst_contents)
757
758
759 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
760     """Return a tuple of (decoded_contents, encoding, newline).
761
762     `newline` is either CRLF or LF but `decoded_contents` is decoded with
763     universal newlines (i.e. only contains LF).
764     """
765     srcbuf = io.BytesIO(src)
766     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
767     if not lines:
768         return "", encoding, "\n"
769
770     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
771     srcbuf.seek(0)
772     with io.TextIOWrapper(srcbuf, encoding) as tiow:
773         return tiow.read(), encoding, newline
774
775
776 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
777     if not target_versions:
778         # No target_version specified, so try all grammars.
779         return [
780             # Python 3.7+
781             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
782             # Python 3.0-3.6
783             pygram.python_grammar_no_print_statement_no_exec_statement,
784             # Python 2.7 with future print_function import
785             pygram.python_grammar_no_print_statement,
786             # Python 2.7
787             pygram.python_grammar,
788         ]
789     elif all(version.is_python2() for version in target_versions):
790         # Python 2-only code, so try Python 2 grammars.
791         return [
792             # Python 2.7 with future print_function import
793             pygram.python_grammar_no_print_statement,
794             # Python 2.7
795             pygram.python_grammar,
796         ]
797     else:
798         # Python 3-compatible code, so only try Python 3 grammar.
799         grammars = []
800         # If we have to parse both, try to parse async as a keyword first
801         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
802             # Python 3.7+
803             grammars.append(
804                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
805             )
806         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
807             # Python 3.0-3.6
808             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
809         # At least one of the above branches must have been taken, because every Python
810         # version has exactly one of the two 'ASYNC_*' flags
811         return grammars
812
813
814 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
815     """Given a string with source, return the lib2to3 Node."""
816     if src_txt[-1:] != "\n":
817         src_txt += "\n"
818
819     for grammar in get_grammars(set(target_versions)):
820         drv = driver.Driver(grammar, pytree.convert)
821         try:
822             result = drv.parse_string(src_txt, True)
823             break
824
825         except ParseError as pe:
826             lineno, column = pe.context[1]
827             lines = src_txt.splitlines()
828             try:
829                 faulty_line = lines[lineno - 1]
830             except IndexError:
831                 faulty_line = "<line number missing in source>"
832             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
833     else:
834         raise exc from None
835
836     if isinstance(result, Leaf):
837         result = Node(syms.file_input, [result])
838     return result
839
840
841 def lib2to3_unparse(node: Node) -> str:
842     """Given a lib2to3 node, return its string representation."""
843     code = str(node)
844     return code
845
846
847 T = TypeVar("T")
848
849
850 class Visitor(Generic[T]):
851     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
852
853     def visit(self, node: LN) -> Iterator[T]:
854         """Main method to visit `node` and its children.
855
856         It tries to find a `visit_*()` method for the given `node.type`, like
857         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
858         If no dedicated `visit_*()` method is found, chooses `visit_default()`
859         instead.
860
861         Then yields objects of type `T` from the selected visitor.
862         """
863         if node.type < 256:
864             name = token.tok_name[node.type]
865         else:
866             name = type_repr(node.type)
867         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
868
869     def visit_default(self, node: LN) -> Iterator[T]:
870         """Default `visit_*()` implementation. Recurses to children of `node`."""
871         if isinstance(node, Node):
872             for child in node.children:
873                 yield from self.visit(child)
874
875
876 @dataclass
877 class DebugVisitor(Visitor[T]):
878     tree_depth: int = 0
879
880     def visit_default(self, node: LN) -> Iterator[T]:
881         indent = " " * (2 * self.tree_depth)
882         if isinstance(node, Node):
883             _type = type_repr(node.type)
884             out(f"{indent}{_type}", fg="yellow")
885             self.tree_depth += 1
886             for child in node.children:
887                 yield from self.visit(child)
888
889             self.tree_depth -= 1
890             out(f"{indent}/{_type}", fg="yellow", bold=False)
891         else:
892             _type = token.tok_name.get(node.type, str(node.type))
893             out(f"{indent}{_type}", fg="blue", nl=False)
894             if node.prefix:
895                 # We don't have to handle prefixes for `Node` objects since
896                 # that delegates to the first child anyway.
897                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
898             out(f" {node.value!r}", fg="blue", bold=False)
899
900     @classmethod
901     def show(cls, code: Union[str, Leaf, Node]) -> None:
902         """Pretty-print the lib2to3 AST of a given string of `code`.
903
904         Convenience method for debugging.
905         """
906         v: DebugVisitor[None] = DebugVisitor()
907         if isinstance(code, str):
908             code = lib2to3_parse(code)
909         list(v.visit(code))
910
911
912 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
913 STATEMENT = {
914     syms.if_stmt,
915     syms.while_stmt,
916     syms.for_stmt,
917     syms.try_stmt,
918     syms.except_clause,
919     syms.with_stmt,
920     syms.funcdef,
921     syms.classdef,
922 }
923 STANDALONE_COMMENT = 153
924 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
925 LOGIC_OPERATORS = {"and", "or"}
926 COMPARATORS = {
927     token.LESS,
928     token.GREATER,
929     token.EQEQUAL,
930     token.NOTEQUAL,
931     token.LESSEQUAL,
932     token.GREATEREQUAL,
933 }
934 MATH_OPERATORS = {
935     token.VBAR,
936     token.CIRCUMFLEX,
937     token.AMPER,
938     token.LEFTSHIFT,
939     token.RIGHTSHIFT,
940     token.PLUS,
941     token.MINUS,
942     token.STAR,
943     token.SLASH,
944     token.DOUBLESLASH,
945     token.PERCENT,
946     token.AT,
947     token.TILDE,
948     token.DOUBLESTAR,
949 }
950 STARS = {token.STAR, token.DOUBLESTAR}
951 VARARGS_SPECIALS = STARS | {token.SLASH}
952 VARARGS_PARENTS = {
953     syms.arglist,
954     syms.argument,  # double star in arglist
955     syms.trailer,  # single argument to call
956     syms.typedargslist,
957     syms.varargslist,  # lambdas
958 }
959 UNPACKING_PARENTS = {
960     syms.atom,  # single element of a list or set literal
961     syms.dictsetmaker,
962     syms.listmaker,
963     syms.testlist_gexp,
964     syms.testlist_star_expr,
965 }
966 TEST_DESCENDANTS = {
967     syms.test,
968     syms.lambdef,
969     syms.or_test,
970     syms.and_test,
971     syms.not_test,
972     syms.comparison,
973     syms.star_expr,
974     syms.expr,
975     syms.xor_expr,
976     syms.and_expr,
977     syms.shift_expr,
978     syms.arith_expr,
979     syms.trailer,
980     syms.term,
981     syms.power,
982 }
983 ASSIGNMENTS = {
984     "=",
985     "+=",
986     "-=",
987     "*=",
988     "@=",
989     "/=",
990     "%=",
991     "&=",
992     "|=",
993     "^=",
994     "<<=",
995     ">>=",
996     "**=",
997     "//=",
998 }
999 COMPREHENSION_PRIORITY = 20
1000 COMMA_PRIORITY = 18
1001 TERNARY_PRIORITY = 16
1002 LOGIC_PRIORITY = 14
1003 STRING_PRIORITY = 12
1004 COMPARATOR_PRIORITY = 10
1005 MATH_PRIORITIES = {
1006     token.VBAR: 9,
1007     token.CIRCUMFLEX: 8,
1008     token.AMPER: 7,
1009     token.LEFTSHIFT: 6,
1010     token.RIGHTSHIFT: 6,
1011     token.PLUS: 5,
1012     token.MINUS: 5,
1013     token.STAR: 4,
1014     token.SLASH: 4,
1015     token.DOUBLESLASH: 4,
1016     token.PERCENT: 4,
1017     token.AT: 4,
1018     token.TILDE: 3,
1019     token.DOUBLESTAR: 2,
1020 }
1021 DOT_PRIORITY = 1
1022
1023
1024 @dataclass
1025 class BracketTracker:
1026     """Keeps track of brackets on a line."""
1027
1028     depth: int = 0
1029     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1030     delimiters: Dict[LeafID, Priority] = Factory(dict)
1031     previous: Optional[Leaf] = None
1032     _for_loop_depths: List[int] = Factory(list)
1033     _lambda_argument_depths: List[int] = Factory(list)
1034
1035     def mark(self, leaf: Leaf) -> None:
1036         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1037
1038         All leaves receive an int `bracket_depth` field that stores how deep
1039         within brackets a given leaf is. 0 means there are no enclosing brackets
1040         that started on this line.
1041
1042         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1043         field that it forms a pair with. This is a one-directional link to
1044         avoid reference cycles.
1045
1046         If a leaf is a delimiter (a token on which Black can split the line if
1047         needed) and it's on depth 0, its `id()` is stored in the tracker's
1048         `delimiters` field.
1049         """
1050         if leaf.type == token.COMMENT:
1051             return
1052
1053         self.maybe_decrement_after_for_loop_variable(leaf)
1054         self.maybe_decrement_after_lambda_arguments(leaf)
1055         if leaf.type in CLOSING_BRACKETS:
1056             self.depth -= 1
1057             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1058             leaf.opening_bracket = opening_bracket
1059         leaf.bracket_depth = self.depth
1060         if self.depth == 0:
1061             delim = is_split_before_delimiter(leaf, self.previous)
1062             if delim and self.previous is not None:
1063                 self.delimiters[id(self.previous)] = delim
1064             else:
1065                 delim = is_split_after_delimiter(leaf, self.previous)
1066                 if delim:
1067                     self.delimiters[id(leaf)] = delim
1068         if leaf.type in OPENING_BRACKETS:
1069             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1070             self.depth += 1
1071         self.previous = leaf
1072         self.maybe_increment_lambda_arguments(leaf)
1073         self.maybe_increment_for_loop_variable(leaf)
1074
1075     def any_open_brackets(self) -> bool:
1076         """Return True if there is an yet unmatched open bracket on the line."""
1077         return bool(self.bracket_match)
1078
1079     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1080         """Return the highest priority of a delimiter found on the line.
1081
1082         Values are consistent with what `is_split_*_delimiter()` return.
1083         Raises ValueError on no delimiters.
1084         """
1085         return max(v for k, v in self.delimiters.items() if k not in exclude)
1086
1087     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1088         """Return the number of delimiters with the given `priority`.
1089
1090         If no `priority` is passed, defaults to max priority on the line.
1091         """
1092         if not self.delimiters:
1093             return 0
1094
1095         priority = priority or self.max_delimiter_priority()
1096         return sum(1 for p in self.delimiters.values() if p == priority)
1097
1098     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1099         """In a for loop, or comprehension, the variables are often unpacks.
1100
1101         To avoid splitting on the comma in this situation, increase the depth of
1102         tokens between `for` and `in`.
1103         """
1104         if leaf.type == token.NAME and leaf.value == "for":
1105             self.depth += 1
1106             self._for_loop_depths.append(self.depth)
1107             return True
1108
1109         return False
1110
1111     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1112         """See `maybe_increment_for_loop_variable` above for explanation."""
1113         if (
1114             self._for_loop_depths
1115             and self._for_loop_depths[-1] == self.depth
1116             and leaf.type == token.NAME
1117             and leaf.value == "in"
1118         ):
1119             self.depth -= 1
1120             self._for_loop_depths.pop()
1121             return True
1122
1123         return False
1124
1125     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1126         """In a lambda expression, there might be more than one argument.
1127
1128         To avoid splitting on the comma in this situation, increase the depth of
1129         tokens between `lambda` and `:`.
1130         """
1131         if leaf.type == token.NAME and leaf.value == "lambda":
1132             self.depth += 1
1133             self._lambda_argument_depths.append(self.depth)
1134             return True
1135
1136         return False
1137
1138     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1139         """See `maybe_increment_lambda_arguments` above for explanation."""
1140         if (
1141             self._lambda_argument_depths
1142             and self._lambda_argument_depths[-1] == self.depth
1143             and leaf.type == token.COLON
1144         ):
1145             self.depth -= 1
1146             self._lambda_argument_depths.pop()
1147             return True
1148
1149         return False
1150
1151     def get_open_lsqb(self) -> Optional[Leaf]:
1152         """Return the most recent opening square bracket (if any)."""
1153         return self.bracket_match.get((self.depth - 1, token.RSQB))
1154
1155
1156 @dataclass
1157 class Line:
1158     """Holds leaves and comments. Can be printed with `str(line)`."""
1159
1160     depth: int = 0
1161     leaves: List[Leaf] = Factory(list)
1162     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1163     bracket_tracker: BracketTracker = Factory(BracketTracker)
1164     inside_brackets: bool = False
1165     should_explode: bool = False
1166
1167     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1168         """Add a new `leaf` to the end of the line.
1169
1170         Unless `preformatted` is True, the `leaf` will receive a new consistent
1171         whitespace prefix and metadata applied by :class:`BracketTracker`.
1172         Trailing commas are maybe removed, unpacked for loop variables are
1173         demoted from being delimiters.
1174
1175         Inline comments are put aside.
1176         """
1177         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1178         if not has_value:
1179             return
1180
1181         if token.COLON == leaf.type and self.is_class_paren_empty:
1182             del self.leaves[-2:]
1183         if self.leaves and not preformatted:
1184             # Note: at this point leaf.prefix should be empty except for
1185             # imports, for which we only preserve newlines.
1186             leaf.prefix += whitespace(
1187                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1188             )
1189         if self.inside_brackets or not preformatted:
1190             self.bracket_tracker.mark(leaf)
1191             self.maybe_remove_trailing_comma(leaf)
1192         if not self.append_comment(leaf):
1193             self.leaves.append(leaf)
1194
1195     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1196         """Like :func:`append()` but disallow invalid standalone comment structure.
1197
1198         Raises ValueError when any `leaf` is appended after a standalone comment
1199         or when a standalone comment is not the first leaf on the line.
1200         """
1201         if self.bracket_tracker.depth == 0:
1202             if self.is_comment:
1203                 raise ValueError("cannot append to standalone comments")
1204
1205             if self.leaves and leaf.type == STANDALONE_COMMENT:
1206                 raise ValueError(
1207                     "cannot append standalone comments to a populated line"
1208                 )
1209
1210         self.append(leaf, preformatted=preformatted)
1211
1212     @property
1213     def is_comment(self) -> bool:
1214         """Is this line a standalone comment?"""
1215         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1216
1217     @property
1218     def is_decorator(self) -> bool:
1219         """Is this line a decorator?"""
1220         return bool(self) and self.leaves[0].type == token.AT
1221
1222     @property
1223     def is_import(self) -> bool:
1224         """Is this an import line?"""
1225         return bool(self) and is_import(self.leaves[0])
1226
1227     @property
1228     def is_class(self) -> bool:
1229         """Is this line a class definition?"""
1230         return (
1231             bool(self)
1232             and self.leaves[0].type == token.NAME
1233             and self.leaves[0].value == "class"
1234         )
1235
1236     @property
1237     def is_stub_class(self) -> bool:
1238         """Is this line a class definition with a body consisting only of "..."?"""
1239         return self.is_class and self.leaves[-3:] == [
1240             Leaf(token.DOT, ".") for _ in range(3)
1241         ]
1242
1243     @property
1244     def is_def(self) -> bool:
1245         """Is this a function definition? (Also returns True for async defs.)"""
1246         try:
1247             first_leaf = self.leaves[0]
1248         except IndexError:
1249             return False
1250
1251         try:
1252             second_leaf: Optional[Leaf] = self.leaves[1]
1253         except IndexError:
1254             second_leaf = None
1255         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1256             first_leaf.type == token.ASYNC
1257             and second_leaf is not None
1258             and second_leaf.type == token.NAME
1259             and second_leaf.value == "def"
1260         )
1261
1262     @property
1263     def is_class_paren_empty(self) -> bool:
1264         """Is this a class with no base classes but using parentheses?
1265
1266         Those are unnecessary and should be removed.
1267         """
1268         return (
1269             bool(self)
1270             and len(self.leaves) == 4
1271             and self.is_class
1272             and self.leaves[2].type == token.LPAR
1273             and self.leaves[2].value == "("
1274             and self.leaves[3].type == token.RPAR
1275             and self.leaves[3].value == ")"
1276         )
1277
1278     @property
1279     def is_triple_quoted_string(self) -> bool:
1280         """Is the line a triple quoted string?"""
1281         return (
1282             bool(self)
1283             and self.leaves[0].type == token.STRING
1284             and self.leaves[0].value.startswith(('"""', "'''"))
1285         )
1286
1287     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1288         """If so, needs to be split before emitting."""
1289         for leaf in self.leaves:
1290             if leaf.type == STANDALONE_COMMENT:
1291                 if leaf.bracket_depth <= depth_limit:
1292                     return True
1293         return False
1294
1295     def contains_uncollapsable_type_comments(self) -> bool:
1296         ignored_ids = set()
1297         try:
1298             last_leaf = self.leaves[-1]
1299             ignored_ids.add(id(last_leaf))
1300             if last_leaf.type == token.COMMA or (
1301                 last_leaf.type == token.RPAR and not last_leaf.value
1302             ):
1303                 # When trailing commas or optional parens are inserted by Black for
1304                 # consistency, comments after the previous last element are not moved
1305                 # (they don't have to, rendering will still be correct).  So we ignore
1306                 # trailing commas and invisible.
1307                 last_leaf = self.leaves[-2]
1308                 ignored_ids.add(id(last_leaf))
1309         except IndexError:
1310             return False
1311
1312         # A type comment is uncollapsable if it is attached to a leaf
1313         # that isn't at the end of the line (since that could cause it
1314         # to get associated to a different argument) or if there are
1315         # comments before it (since that could cause it to get hidden
1316         # behind a comment.
1317         comment_seen = False
1318         for leaf_id, comments in self.comments.items():
1319             for comment in comments:
1320                 if is_type_comment(comment):
1321                     if leaf_id not in ignored_ids or comment_seen:
1322                         return True
1323
1324             comment_seen = True
1325
1326         return False
1327
1328     def contains_unsplittable_type_ignore(self) -> bool:
1329         if not self.leaves:
1330             return False
1331
1332         # If a 'type: ignore' is attached to the end of a line, we
1333         # can't split the line, because we can't know which of the
1334         # subexpressions the ignore was meant to apply to.
1335         #
1336         # We only want this to apply to actual physical lines from the
1337         # original source, though: we don't want the presence of a
1338         # 'type: ignore' at the end of a multiline expression to
1339         # justify pushing it all onto one line. Thus we
1340         # (unfortunately) need to check the actual source lines and
1341         # only report an unsplittable 'type: ignore' if this line was
1342         # one line in the original code.
1343         if self.leaves[0].lineno == self.leaves[-1].lineno:
1344             for comment in self.comments.get(id(self.leaves[-1]), []):
1345                 if is_type_comment(comment, " ignore"):
1346                     return True
1347
1348         return False
1349
1350     def contains_multiline_strings(self) -> bool:
1351         for leaf in self.leaves:
1352             if is_multiline_string(leaf):
1353                 return True
1354
1355         return False
1356
1357     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1358         """Remove trailing comma if there is one and it's safe."""
1359         if not (
1360             self.leaves
1361             and self.leaves[-1].type == token.COMMA
1362             and closing.type in CLOSING_BRACKETS
1363         ):
1364             return False
1365
1366         if closing.type == token.RBRACE:
1367             self.remove_trailing_comma()
1368             return True
1369
1370         if closing.type == token.RSQB:
1371             comma = self.leaves[-1]
1372             if comma.parent and comma.parent.type == syms.listmaker:
1373                 self.remove_trailing_comma()
1374                 return True
1375
1376         # For parens let's check if it's safe to remove the comma.
1377         # Imports are always safe.
1378         if self.is_import:
1379             self.remove_trailing_comma()
1380             return True
1381
1382         # Otherwise, if the trailing one is the only one, we might mistakenly
1383         # change a tuple into a different type by removing the comma.
1384         depth = closing.bracket_depth + 1
1385         commas = 0
1386         opening = closing.opening_bracket
1387         for _opening_index, leaf in enumerate(self.leaves):
1388             if leaf is opening:
1389                 break
1390
1391         else:
1392             return False
1393
1394         for leaf in self.leaves[_opening_index + 1 :]:
1395             if leaf is closing:
1396                 break
1397
1398             bracket_depth = leaf.bracket_depth
1399             if bracket_depth == depth and leaf.type == token.COMMA:
1400                 commas += 1
1401                 if leaf.parent and leaf.parent.type in {
1402                     syms.arglist,
1403                     syms.typedargslist,
1404                 }:
1405                     commas += 1
1406                     break
1407
1408         if commas > 1:
1409             self.remove_trailing_comma()
1410             return True
1411
1412         return False
1413
1414     def append_comment(self, comment: Leaf) -> bool:
1415         """Add an inline or standalone comment to the line."""
1416         if (
1417             comment.type == STANDALONE_COMMENT
1418             and self.bracket_tracker.any_open_brackets()
1419         ):
1420             comment.prefix = ""
1421             return False
1422
1423         if comment.type != token.COMMENT:
1424             return False
1425
1426         if not self.leaves:
1427             comment.type = STANDALONE_COMMENT
1428             comment.prefix = ""
1429             return False
1430
1431         last_leaf = self.leaves[-1]
1432         if (
1433             last_leaf.type == token.RPAR
1434             and not last_leaf.value
1435             and last_leaf.parent
1436             and len(list(last_leaf.parent.leaves())) <= 3
1437             and not is_type_comment(comment)
1438         ):
1439             # Comments on an optional parens wrapping a single leaf should belong to
1440             # the wrapped node except if it's a type comment. Pinning the comment like
1441             # this avoids unstable formatting caused by comment migration.
1442             if len(self.leaves) < 2:
1443                 comment.type = STANDALONE_COMMENT
1444                 comment.prefix = ""
1445                 return False
1446             last_leaf = self.leaves[-2]
1447         self.comments.setdefault(id(last_leaf), []).append(comment)
1448         return True
1449
1450     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1451         """Generate comments that should appear directly after `leaf`."""
1452         return self.comments.get(id(leaf), [])
1453
1454     def remove_trailing_comma(self) -> None:
1455         """Remove the trailing comma and moves the comments attached to it."""
1456         trailing_comma = self.leaves.pop()
1457         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1458         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1459             trailing_comma_comments
1460         )
1461
1462     def is_complex_subscript(self, leaf: Leaf) -> bool:
1463         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1464         open_lsqb = self.bracket_tracker.get_open_lsqb()
1465         if open_lsqb is None:
1466             return False
1467
1468         subscript_start = open_lsqb.next_sibling
1469
1470         if isinstance(subscript_start, Node):
1471             if subscript_start.type == syms.listmaker:
1472                 return False
1473
1474             if subscript_start.type == syms.subscriptlist:
1475                 subscript_start = child_towards(subscript_start, leaf)
1476         return subscript_start is not None and any(
1477             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1478         )
1479
1480     def __str__(self) -> str:
1481         """Render the line."""
1482         if not self:
1483             return "\n"
1484
1485         indent = "    " * self.depth
1486         leaves = iter(self.leaves)
1487         first = next(leaves)
1488         res = f"{first.prefix}{indent}{first.value}"
1489         for leaf in leaves:
1490             res += str(leaf)
1491         for comment in itertools.chain.from_iterable(self.comments.values()):
1492             res += str(comment)
1493         return res + "\n"
1494
1495     def __bool__(self) -> bool:
1496         """Return True if the line has leaves or comments."""
1497         return bool(self.leaves or self.comments)
1498
1499
1500 @dataclass
1501 class EmptyLineTracker:
1502     """Provides a stateful method that returns the number of potential extra
1503     empty lines needed before and after the currently processed line.
1504
1505     Note: this tracker works on lines that haven't been split yet.  It assumes
1506     the prefix of the first leaf consists of optional newlines.  Those newlines
1507     are consumed by `maybe_empty_lines()` and included in the computation.
1508     """
1509
1510     is_pyi: bool = False
1511     previous_line: Optional[Line] = None
1512     previous_after: int = 0
1513     previous_defs: List[int] = Factory(list)
1514
1515     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1516         """Return the number of extra empty lines before and after the `current_line`.
1517
1518         This is for separating `def`, `async def` and `class` with extra empty
1519         lines (two on module-level).
1520         """
1521         before, after = self._maybe_empty_lines(current_line)
1522         before = (
1523             # Black should not insert empty lines at the beginning
1524             # of the file
1525             0
1526             if self.previous_line is None
1527             else before - self.previous_after
1528         )
1529         self.previous_after = after
1530         self.previous_line = current_line
1531         return before, after
1532
1533     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1534         max_allowed = 1
1535         if current_line.depth == 0:
1536             max_allowed = 1 if self.is_pyi else 2
1537         if current_line.leaves:
1538             # Consume the first leaf's extra newlines.
1539             first_leaf = current_line.leaves[0]
1540             before = first_leaf.prefix.count("\n")
1541             before = min(before, max_allowed)
1542             first_leaf.prefix = ""
1543         else:
1544             before = 0
1545         depth = current_line.depth
1546         while self.previous_defs and self.previous_defs[-1] >= depth:
1547             self.previous_defs.pop()
1548             if self.is_pyi:
1549                 before = 0 if depth else 1
1550             else:
1551                 before = 1 if depth else 2
1552         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1553             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1554
1555         if (
1556             self.previous_line
1557             and self.previous_line.is_import
1558             and not current_line.is_import
1559             and depth == self.previous_line.depth
1560         ):
1561             return (before or 1), 0
1562
1563         if (
1564             self.previous_line
1565             and self.previous_line.is_class
1566             and current_line.is_triple_quoted_string
1567         ):
1568             return before, 1
1569
1570         return before, 0
1571
1572     def _maybe_empty_lines_for_class_or_def(
1573         self, current_line: Line, before: int
1574     ) -> Tuple[int, int]:
1575         if not current_line.is_decorator:
1576             self.previous_defs.append(current_line.depth)
1577         if self.previous_line is None:
1578             # Don't insert empty lines before the first line in the file.
1579             return 0, 0
1580
1581         if self.previous_line.is_decorator:
1582             return 0, 0
1583
1584         if self.previous_line.depth < current_line.depth and (
1585             self.previous_line.is_class or self.previous_line.is_def
1586         ):
1587             return 0, 0
1588
1589         if (
1590             self.previous_line.is_comment
1591             and self.previous_line.depth == current_line.depth
1592             and before == 0
1593         ):
1594             return 0, 0
1595
1596         if self.is_pyi:
1597             if self.previous_line.depth > current_line.depth:
1598                 newlines = 1
1599             elif current_line.is_class or self.previous_line.is_class:
1600                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1601                     # No blank line between classes with an empty body
1602                     newlines = 0
1603                 else:
1604                     newlines = 1
1605             elif current_line.is_def and not self.previous_line.is_def:
1606                 # Blank line between a block of functions and a block of non-functions
1607                 newlines = 1
1608             else:
1609                 newlines = 0
1610         else:
1611             newlines = 2
1612         if current_line.depth and newlines:
1613             newlines -= 1
1614         return newlines, 0
1615
1616
1617 @dataclass
1618 class LineGenerator(Visitor[Line]):
1619     """Generates reformatted Line objects.  Empty lines are not emitted.
1620
1621     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1622     in ways that will no longer stringify to valid Python code on the tree.
1623     """
1624
1625     is_pyi: bool = False
1626     normalize_strings: bool = True
1627     current_line: Line = Factory(Line)
1628     remove_u_prefix: bool = False
1629
1630     def line(self, indent: int = 0) -> Iterator[Line]:
1631         """Generate a line.
1632
1633         If the line is empty, only emit if it makes sense.
1634         If the line is too long, split it first and then generate.
1635
1636         If any lines were generated, set up a new current_line.
1637         """
1638         if not self.current_line:
1639             self.current_line.depth += indent
1640             return  # Line is empty, don't emit. Creating a new one unnecessary.
1641
1642         complete_line = self.current_line
1643         self.current_line = Line(depth=complete_line.depth + indent)
1644         yield complete_line
1645
1646     def visit_default(self, node: LN) -> Iterator[Line]:
1647         """Default `visit_*()` implementation. Recurses to children of `node`."""
1648         if isinstance(node, Leaf):
1649             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1650             for comment in generate_comments(node):
1651                 if any_open_brackets:
1652                     # any comment within brackets is subject to splitting
1653                     self.current_line.append(comment)
1654                 elif comment.type == token.COMMENT:
1655                     # regular trailing comment
1656                     self.current_line.append(comment)
1657                     yield from self.line()
1658
1659                 else:
1660                     # regular standalone comment
1661                     yield from self.line()
1662
1663                     self.current_line.append(comment)
1664                     yield from self.line()
1665
1666             normalize_prefix(node, inside_brackets=any_open_brackets)
1667             if self.normalize_strings and node.type == token.STRING:
1668                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1669                 normalize_string_quotes(node)
1670             if node.type == token.NUMBER:
1671                 normalize_numeric_literal(node)
1672             if node.type not in WHITESPACE:
1673                 self.current_line.append(node)
1674         yield from super().visit_default(node)
1675
1676     def visit_atom(self, node: Node) -> Iterator[Line]:
1677         # Always make parentheses invisible around a single node, because it should
1678         # not be needed (except in the case of yield, where removing the parentheses
1679         # produces a SyntaxError).
1680         if (
1681             len(node.children) == 3
1682             and isinstance(node.children[0], Leaf)
1683             and node.children[0].type == token.LPAR
1684             and isinstance(node.children[2], Leaf)
1685             and node.children[2].type == token.RPAR
1686             and isinstance(node.children[1], Leaf)
1687             and not (
1688                 node.children[1].type == token.NAME
1689                 and node.children[1].value == "yield"
1690             )
1691         ):
1692             node.children[0].value = ""
1693             node.children[2].value = ""
1694         yield from super().visit_default(node)
1695
1696     def visit_factor(self, node: Node) -> Iterator[Line]:
1697         """Force parentheses between a unary op and a binary power:
1698
1699         -2 ** 8 -> -(2 ** 8)
1700         """
1701         child = node.children[1]
1702         if child.type == syms.power and len(child.children) == 3:
1703             lpar = Leaf(token.LPAR, "(")
1704             rpar = Leaf(token.RPAR, ")")
1705             index = child.remove() or 0
1706             node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1707         yield from self.visit_default(node)
1708
1709     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1710         """Increase indentation level, maybe yield a line."""
1711         # In blib2to3 INDENT never holds comments.
1712         yield from self.line(+1)
1713         yield from self.visit_default(node)
1714
1715     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1716         """Decrease indentation level, maybe yield a line."""
1717         # The current line might still wait for trailing comments.  At DEDENT time
1718         # there won't be any (they would be prefixes on the preceding NEWLINE).
1719         # Emit the line then.
1720         yield from self.line()
1721
1722         # While DEDENT has no value, its prefix may contain standalone comments
1723         # that belong to the current indentation level.  Get 'em.
1724         yield from self.visit_default(node)
1725
1726         # Finally, emit the dedent.
1727         yield from self.line(-1)
1728
1729     def visit_stmt(
1730         self, node: Node, keywords: Set[str], parens: Set[str]
1731     ) -> Iterator[Line]:
1732         """Visit a statement.
1733
1734         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1735         `def`, `with`, `class`, `assert` and assignments.
1736
1737         The relevant Python language `keywords` for a given statement will be
1738         NAME leaves within it. This methods puts those on a separate line.
1739
1740         `parens` holds a set of string leaf values immediately after which
1741         invisible parens should be put.
1742         """
1743         normalize_invisible_parens(node, parens_after=parens)
1744         for child in node.children:
1745             if child.type == token.NAME and child.value in keywords:  # type: ignore
1746                 yield from self.line()
1747
1748             yield from self.visit(child)
1749
1750     def visit_suite(self, node: Node) -> Iterator[Line]:
1751         """Visit a suite."""
1752         if self.is_pyi and is_stub_suite(node):
1753             yield from self.visit(node.children[2])
1754         else:
1755             yield from self.visit_default(node)
1756
1757     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1758         """Visit a statement without nested statements."""
1759         is_suite_like = node.parent and node.parent.type in STATEMENT
1760         if is_suite_like:
1761             if self.is_pyi and is_stub_body(node):
1762                 yield from self.visit_default(node)
1763             else:
1764                 yield from self.line(+1)
1765                 yield from self.visit_default(node)
1766                 yield from self.line(-1)
1767
1768         else:
1769             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1770                 yield from self.line()
1771             yield from self.visit_default(node)
1772
1773     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1774         """Visit `async def`, `async for`, `async with`."""
1775         yield from self.line()
1776
1777         children = iter(node.children)
1778         for child in children:
1779             yield from self.visit(child)
1780
1781             if child.type == token.ASYNC:
1782                 break
1783
1784         internal_stmt = next(children)
1785         for child in internal_stmt.children:
1786             yield from self.visit(child)
1787
1788     def visit_decorators(self, node: Node) -> Iterator[Line]:
1789         """Visit decorators."""
1790         for child in node.children:
1791             yield from self.line()
1792             yield from self.visit(child)
1793
1794     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1795         """Remove a semicolon and put the other statement on a separate line."""
1796         yield from self.line()
1797
1798     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1799         """End of file. Process outstanding comments and end with a newline."""
1800         yield from self.visit_default(leaf)
1801         yield from self.line()
1802
1803     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1804         if not self.current_line.bracket_tracker.any_open_brackets():
1805             yield from self.line()
1806         yield from self.visit_default(leaf)
1807
1808     def __attrs_post_init__(self) -> None:
1809         """You are in a twisty little maze of passages."""
1810         v = self.visit_stmt
1811         Ø: Set[str] = set()
1812         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1813         self.visit_if_stmt = partial(
1814             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1815         )
1816         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1817         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1818         self.visit_try_stmt = partial(
1819             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1820         )
1821         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1822         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1823         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1824         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1825         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1826         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1827         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1828         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1829         self.visit_async_funcdef = self.visit_async_stmt
1830         self.visit_decorated = self.visit_decorators
1831
1832
1833 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1834 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1835 OPENING_BRACKETS = set(BRACKET.keys())
1836 CLOSING_BRACKETS = set(BRACKET.values())
1837 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1838 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1839
1840
1841 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1842     """Return whitespace prefix if needed for the given `leaf`.
1843
1844     `complex_subscript` signals whether the given leaf is part of a subscription
1845     which has non-trivial arguments, like arithmetic expressions or function calls.
1846     """
1847     NO = ""
1848     SPACE = " "
1849     DOUBLESPACE = "  "
1850     t = leaf.type
1851     p = leaf.parent
1852     v = leaf.value
1853     if t in ALWAYS_NO_SPACE:
1854         return NO
1855
1856     if t == token.COMMENT:
1857         return DOUBLESPACE
1858
1859     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1860     if t == token.COLON and p.type not in {
1861         syms.subscript,
1862         syms.subscriptlist,
1863         syms.sliceop,
1864     }:
1865         return NO
1866
1867     prev = leaf.prev_sibling
1868     if not prev:
1869         prevp = preceding_leaf(p)
1870         if not prevp or prevp.type in OPENING_BRACKETS:
1871             return NO
1872
1873         if t == token.COLON:
1874             if prevp.type == token.COLON:
1875                 return NO
1876
1877             elif prevp.type != token.COMMA and not complex_subscript:
1878                 return NO
1879
1880             return SPACE
1881
1882         if prevp.type == token.EQUAL:
1883             if prevp.parent:
1884                 if prevp.parent.type in {
1885                     syms.arglist,
1886                     syms.argument,
1887                     syms.parameters,
1888                     syms.varargslist,
1889                 }:
1890                     return NO
1891
1892                 elif prevp.parent.type == syms.typedargslist:
1893                     # A bit hacky: if the equal sign has whitespace, it means we
1894                     # previously found it's a typed argument.  So, we're using
1895                     # that, too.
1896                     return prevp.prefix
1897
1898         elif prevp.type in VARARGS_SPECIALS:
1899             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1900                 return NO
1901
1902         elif prevp.type == token.COLON:
1903             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1904                 return SPACE if complex_subscript else NO
1905
1906         elif (
1907             prevp.parent
1908             and prevp.parent.type == syms.factor
1909             and prevp.type in MATH_OPERATORS
1910         ):
1911             return NO
1912
1913         elif (
1914             prevp.type == token.RIGHTSHIFT
1915             and prevp.parent
1916             and prevp.parent.type == syms.shift_expr
1917             and prevp.prev_sibling
1918             and prevp.prev_sibling.type == token.NAME
1919             and prevp.prev_sibling.value == "print"  # type: ignore
1920         ):
1921             # Python 2 print chevron
1922             return NO
1923
1924     elif prev.type in OPENING_BRACKETS:
1925         return NO
1926
1927     if p.type in {syms.parameters, syms.arglist}:
1928         # untyped function signatures or calls
1929         if not prev or prev.type != token.COMMA:
1930             return NO
1931
1932     elif p.type == syms.varargslist:
1933         # lambdas
1934         if prev and prev.type != token.COMMA:
1935             return NO
1936
1937     elif p.type == syms.typedargslist:
1938         # typed function signatures
1939         if not prev:
1940             return NO
1941
1942         if t == token.EQUAL:
1943             if prev.type != syms.tname:
1944                 return NO
1945
1946         elif prev.type == token.EQUAL:
1947             # A bit hacky: if the equal sign has whitespace, it means we
1948             # previously found it's a typed argument.  So, we're using that, too.
1949             return prev.prefix
1950
1951         elif prev.type != token.COMMA:
1952             return NO
1953
1954     elif p.type == syms.tname:
1955         # type names
1956         if not prev:
1957             prevp = preceding_leaf(p)
1958             if not prevp or prevp.type != token.COMMA:
1959                 return NO
1960
1961     elif p.type == syms.trailer:
1962         # attributes and calls
1963         if t == token.LPAR or t == token.RPAR:
1964             return NO
1965
1966         if not prev:
1967             if t == token.DOT:
1968                 prevp = preceding_leaf(p)
1969                 if not prevp or prevp.type != token.NUMBER:
1970                     return NO
1971
1972             elif t == token.LSQB:
1973                 return NO
1974
1975         elif prev.type != token.COMMA:
1976             return NO
1977
1978     elif p.type == syms.argument:
1979         # single argument
1980         if t == token.EQUAL:
1981             return NO
1982
1983         if not prev:
1984             prevp = preceding_leaf(p)
1985             if not prevp or prevp.type == token.LPAR:
1986                 return NO
1987
1988         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
1989             return NO
1990
1991     elif p.type == syms.decorator:
1992         # decorators
1993         return NO
1994
1995     elif p.type == syms.dotted_name:
1996         if prev:
1997             return NO
1998
1999         prevp = preceding_leaf(p)
2000         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2001             return NO
2002
2003     elif p.type == syms.classdef:
2004         if t == token.LPAR:
2005             return NO
2006
2007         if prev and prev.type == token.LPAR:
2008             return NO
2009
2010     elif p.type in {syms.subscript, syms.sliceop}:
2011         # indexing
2012         if not prev:
2013             assert p.parent is not None, "subscripts are always parented"
2014             if p.parent.type == syms.subscriptlist:
2015                 return SPACE
2016
2017             return NO
2018
2019         elif not complex_subscript:
2020             return NO
2021
2022     elif p.type == syms.atom:
2023         if prev and t == token.DOT:
2024             # dots, but not the first one.
2025             return NO
2026
2027     elif p.type == syms.dictsetmaker:
2028         # dict unpacking
2029         if prev and prev.type == token.DOUBLESTAR:
2030             return NO
2031
2032     elif p.type in {syms.factor, syms.star_expr}:
2033         # unary ops
2034         if not prev:
2035             prevp = preceding_leaf(p)
2036             if not prevp or prevp.type in OPENING_BRACKETS:
2037                 return NO
2038
2039             prevp_parent = prevp.parent
2040             assert prevp_parent is not None
2041             if prevp.type == token.COLON and prevp_parent.type in {
2042                 syms.subscript,
2043                 syms.sliceop,
2044             }:
2045                 return NO
2046
2047             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2048                 return NO
2049
2050         elif t in {token.NAME, token.NUMBER, token.STRING}:
2051             return NO
2052
2053     elif p.type == syms.import_from:
2054         if t == token.DOT:
2055             if prev and prev.type == token.DOT:
2056                 return NO
2057
2058         elif t == token.NAME:
2059             if v == "import":
2060                 return SPACE
2061
2062             if prev and prev.type == token.DOT:
2063                 return NO
2064
2065     elif p.type == syms.sliceop:
2066         return NO
2067
2068     return SPACE
2069
2070
2071 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2072     """Return the first leaf that precedes `node`, if any."""
2073     while node:
2074         res = node.prev_sibling
2075         if res:
2076             if isinstance(res, Leaf):
2077                 return res
2078
2079             try:
2080                 return list(res.leaves())[-1]
2081
2082             except IndexError:
2083                 return None
2084
2085         node = node.parent
2086     return None
2087
2088
2089 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2090     """Return the child of `ancestor` that contains `descendant`."""
2091     node: Optional[LN] = descendant
2092     while node and node.parent != ancestor:
2093         node = node.parent
2094     return node
2095
2096
2097 def container_of(leaf: Leaf) -> LN:
2098     """Return `leaf` or one of its ancestors that is the topmost container of it.
2099
2100     By "container" we mean a node where `leaf` is the very first child.
2101     """
2102     same_prefix = leaf.prefix
2103     container: LN = leaf
2104     while container:
2105         parent = container.parent
2106         if parent is None:
2107             break
2108
2109         if parent.children[0].prefix != same_prefix:
2110             break
2111
2112         if parent.type == syms.file_input:
2113             break
2114
2115         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2116             break
2117
2118         container = parent
2119     return container
2120
2121
2122 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2123     """Return the priority of the `leaf` delimiter, given a line break after it.
2124
2125     The delimiter priorities returned here are from those delimiters that would
2126     cause a line break after themselves.
2127
2128     Higher numbers are higher priority.
2129     """
2130     if leaf.type == token.COMMA:
2131         return COMMA_PRIORITY
2132
2133     return 0
2134
2135
2136 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2137     """Return the priority of the `leaf` delimiter, given a line break before it.
2138
2139     The delimiter priorities returned here are from those delimiters that would
2140     cause a line break before themselves.
2141
2142     Higher numbers are higher priority.
2143     """
2144     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2145         # * and ** might also be MATH_OPERATORS but in this case they are not.
2146         # Don't treat them as a delimiter.
2147         return 0
2148
2149     if (
2150         leaf.type == token.DOT
2151         and leaf.parent
2152         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2153         and (previous is None or previous.type in CLOSING_BRACKETS)
2154     ):
2155         return DOT_PRIORITY
2156
2157     if (
2158         leaf.type in MATH_OPERATORS
2159         and leaf.parent
2160         and leaf.parent.type not in {syms.factor, syms.star_expr}
2161     ):
2162         return MATH_PRIORITIES[leaf.type]
2163
2164     if leaf.type in COMPARATORS:
2165         return COMPARATOR_PRIORITY
2166
2167     if (
2168         leaf.type == token.STRING
2169         and previous is not None
2170         and previous.type == token.STRING
2171     ):
2172         return STRING_PRIORITY
2173
2174     if leaf.type not in {token.NAME, token.ASYNC}:
2175         return 0
2176
2177     if (
2178         leaf.value == "for"
2179         and leaf.parent
2180         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2181         or leaf.type == token.ASYNC
2182     ):
2183         if (
2184             not isinstance(leaf.prev_sibling, Leaf)
2185             or leaf.prev_sibling.value != "async"
2186         ):
2187             return COMPREHENSION_PRIORITY
2188
2189     if (
2190         leaf.value == "if"
2191         and leaf.parent
2192         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2193     ):
2194         return COMPREHENSION_PRIORITY
2195
2196     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2197         return TERNARY_PRIORITY
2198
2199     if leaf.value == "is":
2200         return COMPARATOR_PRIORITY
2201
2202     if (
2203         leaf.value == "in"
2204         and leaf.parent
2205         and leaf.parent.type in {syms.comp_op, syms.comparison}
2206         and not (
2207             previous is not None
2208             and previous.type == token.NAME
2209             and previous.value == "not"
2210         )
2211     ):
2212         return COMPARATOR_PRIORITY
2213
2214     if (
2215         leaf.value == "not"
2216         and leaf.parent
2217         and leaf.parent.type == syms.comp_op
2218         and not (
2219             previous is not None
2220             and previous.type == token.NAME
2221             and previous.value == "is"
2222         )
2223     ):
2224         return COMPARATOR_PRIORITY
2225
2226     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2227         return LOGIC_PRIORITY
2228
2229     return 0
2230
2231
2232 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2233 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2234
2235
2236 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2237     """Clean the prefix of the `leaf` and generate comments from it, if any.
2238
2239     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2240     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2241     move because it does away with modifying the grammar to include all the
2242     possible places in which comments can be placed.
2243
2244     The sad consequence for us though is that comments don't "belong" anywhere.
2245     This is why this function generates simple parentless Leaf objects for
2246     comments.  We simply don't know what the correct parent should be.
2247
2248     No matter though, we can live without this.  We really only need to
2249     differentiate between inline and standalone comments.  The latter don't
2250     share the line with any code.
2251
2252     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2253     are emitted with a fake STANDALONE_COMMENT token identifier.
2254     """
2255     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2256         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2257
2258
2259 @dataclass
2260 class ProtoComment:
2261     """Describes a piece of syntax that is a comment.
2262
2263     It's not a :class:`blib2to3.pytree.Leaf` so that:
2264
2265     * it can be cached (`Leaf` objects should not be reused more than once as
2266       they store their lineno, column, prefix, and parent information);
2267     * `newlines` and `consumed` fields are kept separate from the `value`. This
2268       simplifies handling of special marker comments like ``# fmt: off/on``.
2269     """
2270
2271     type: int  # token.COMMENT or STANDALONE_COMMENT
2272     value: str  # content of the comment
2273     newlines: int  # how many newlines before the comment
2274     consumed: int  # how many characters of the original leaf's prefix did we consume
2275
2276
2277 @lru_cache(maxsize=4096)
2278 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2279     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2280     result: List[ProtoComment] = []
2281     if not prefix or "#" not in prefix:
2282         return result
2283
2284     consumed = 0
2285     nlines = 0
2286     ignored_lines = 0
2287     for index, line in enumerate(prefix.split("\n")):
2288         consumed += len(line) + 1  # adding the length of the split '\n'
2289         line = line.lstrip()
2290         if not line:
2291             nlines += 1
2292         if not line.startswith("#"):
2293             # Escaped newlines outside of a comment are not really newlines at
2294             # all. We treat a single-line comment following an escaped newline
2295             # as a simple trailing comment.
2296             if line.endswith("\\"):
2297                 ignored_lines += 1
2298             continue
2299
2300         if index == ignored_lines and not is_endmarker:
2301             comment_type = token.COMMENT  # simple trailing comment
2302         else:
2303             comment_type = STANDALONE_COMMENT
2304         comment = make_comment(line)
2305         result.append(
2306             ProtoComment(
2307                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2308             )
2309         )
2310         nlines = 0
2311     return result
2312
2313
2314 def make_comment(content: str) -> str:
2315     """Return a consistently formatted comment from the given `content` string.
2316
2317     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2318     space between the hash sign and the content.
2319
2320     If `content` didn't start with a hash sign, one is provided.
2321     """
2322     content = content.rstrip()
2323     if not content:
2324         return "#"
2325
2326     if content[0] == "#":
2327         content = content[1:]
2328     if content and content[0] not in " !:#'%":
2329         content = " " + content
2330     return "#" + content
2331
2332
2333 def split_line(
2334     line: Line,
2335     line_length: int,
2336     inner: bool = False,
2337     features: Collection[Feature] = (),
2338 ) -> Iterator[Line]:
2339     """Split a `line` into potentially many lines.
2340
2341     They should fit in the allotted `line_length` but might not be able to.
2342     `inner` signifies that there were a pair of brackets somewhere around the
2343     current `line`, possibly transitively. This means we can fallback to splitting
2344     by delimiters if the LHS/RHS don't yield any results.
2345
2346     `features` are syntactical features that may be used in the output.
2347     """
2348     if line.is_comment:
2349         yield line
2350         return
2351
2352     line_str = str(line).strip("\n")
2353
2354     if (
2355         not line.contains_uncollapsable_type_comments()
2356         and not line.should_explode
2357         and (
2358             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2359             or line.contains_unsplittable_type_ignore()
2360         )
2361     ):
2362         yield line
2363         return
2364
2365     split_funcs: List[SplitFunc]
2366     if line.is_def:
2367         split_funcs = [left_hand_split]
2368     else:
2369
2370         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2371             for omit in generate_trailers_to_omit(line, line_length):
2372                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2373                 if is_line_short_enough(lines[0], line_length=line_length):
2374                     yield from lines
2375                     return
2376
2377             # All splits failed, best effort split with no omits.
2378             # This mostly happens to multiline strings that are by definition
2379             # reported as not fitting a single line.
2380             yield from right_hand_split(line, line_length, features=features)
2381
2382         if line.inside_brackets:
2383             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2384         else:
2385             split_funcs = [rhs]
2386     for split_func in split_funcs:
2387         # We are accumulating lines in `result` because we might want to abort
2388         # mission and return the original line in the end, or attempt a different
2389         # split altogether.
2390         result: List[Line] = []
2391         try:
2392             for l in split_func(line, features):
2393                 if str(l).strip("\n") == line_str:
2394                     raise CannotSplit("Split function returned an unchanged result")
2395
2396                 result.extend(
2397                     split_line(
2398                         l, line_length=line_length, inner=True, features=features
2399                     )
2400                 )
2401         except CannotSplit:
2402             continue
2403
2404         else:
2405             yield from result
2406             break
2407
2408     else:
2409         yield line
2410
2411
2412 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2413     """Split line into many lines, starting with the first matching bracket pair.
2414
2415     Note: this usually looks weird, only use this for function definitions.
2416     Prefer RHS otherwise.  This is why this function is not symmetrical with
2417     :func:`right_hand_split` which also handles optional parentheses.
2418     """
2419     tail_leaves: List[Leaf] = []
2420     body_leaves: List[Leaf] = []
2421     head_leaves: List[Leaf] = []
2422     current_leaves = head_leaves
2423     matching_bracket = None
2424     for leaf in line.leaves:
2425         if (
2426             current_leaves is body_leaves
2427             and leaf.type in CLOSING_BRACKETS
2428             and leaf.opening_bracket is matching_bracket
2429         ):
2430             current_leaves = tail_leaves if body_leaves else head_leaves
2431         current_leaves.append(leaf)
2432         if current_leaves is head_leaves:
2433             if leaf.type in OPENING_BRACKETS:
2434                 matching_bracket = leaf
2435                 current_leaves = body_leaves
2436     if not matching_bracket:
2437         raise CannotSplit("No brackets found")
2438
2439     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2440     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2441     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2442     bracket_split_succeeded_or_raise(head, body, tail)
2443     for result in (head, body, tail):
2444         if result:
2445             yield result
2446
2447
2448 def right_hand_split(
2449     line: Line,
2450     line_length: int,
2451     features: Collection[Feature] = (),
2452     omit: Collection[LeafID] = (),
2453 ) -> Iterator[Line]:
2454     """Split line into many lines, starting with the last matching bracket pair.
2455
2456     If the split was by optional parentheses, attempt splitting without them, too.
2457     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2458     this split.
2459
2460     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2461     """
2462     tail_leaves: List[Leaf] = []
2463     body_leaves: List[Leaf] = []
2464     head_leaves: List[Leaf] = []
2465     current_leaves = tail_leaves
2466     opening_bracket = None
2467     closing_bracket = None
2468     for leaf in reversed(line.leaves):
2469         if current_leaves is body_leaves:
2470             if leaf is opening_bracket:
2471                 current_leaves = head_leaves if body_leaves else tail_leaves
2472         current_leaves.append(leaf)
2473         if current_leaves is tail_leaves:
2474             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2475                 opening_bracket = leaf.opening_bracket
2476                 closing_bracket = leaf
2477                 current_leaves = body_leaves
2478     if not (opening_bracket and closing_bracket and head_leaves):
2479         # If there is no opening or closing_bracket that means the split failed and
2480         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2481         # the matching `opening_bracket` wasn't available on `line` anymore.
2482         raise CannotSplit("No brackets found")
2483
2484     tail_leaves.reverse()
2485     body_leaves.reverse()
2486     head_leaves.reverse()
2487     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2488     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2489     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2490     bracket_split_succeeded_or_raise(head, body, tail)
2491     if (
2492         # the body shouldn't be exploded
2493         not body.should_explode
2494         # the opening bracket is an optional paren
2495         and opening_bracket.type == token.LPAR
2496         and not opening_bracket.value
2497         # the closing bracket is an optional paren
2498         and closing_bracket.type == token.RPAR
2499         and not closing_bracket.value
2500         # it's not an import (optional parens are the only thing we can split on
2501         # in this case; attempting a split without them is a waste of time)
2502         and not line.is_import
2503         # there are no standalone comments in the body
2504         and not body.contains_standalone_comments(0)
2505         # and we can actually remove the parens
2506         and can_omit_invisible_parens(body, line_length)
2507     ):
2508         omit = {id(closing_bracket), *omit}
2509         try:
2510             yield from right_hand_split(line, line_length, features=features, omit=omit)
2511             return
2512
2513         except CannotSplit:
2514             if not (
2515                 can_be_split(body)
2516                 or is_line_short_enough(body, line_length=line_length)
2517             ):
2518                 raise CannotSplit(
2519                     "Splitting failed, body is still too long and can't be split."
2520                 )
2521
2522             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2523                 raise CannotSplit(
2524                     "The current optional pair of parentheses is bound to fail to "
2525                     "satisfy the splitting algorithm because the head or the tail "
2526                     "contains multiline strings which by definition never fit one "
2527                     "line."
2528                 )
2529
2530     ensure_visible(opening_bracket)
2531     ensure_visible(closing_bracket)
2532     for result in (head, body, tail):
2533         if result:
2534             yield result
2535
2536
2537 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2538     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2539
2540     Do nothing otherwise.
2541
2542     A left- or right-hand split is based on a pair of brackets. Content before
2543     (and including) the opening bracket is left on one line, content inside the
2544     brackets is put on a separate line, and finally content starting with and
2545     following the closing bracket is put on a separate line.
2546
2547     Those are called `head`, `body`, and `tail`, respectively. If the split
2548     produced the same line (all content in `head`) or ended up with an empty `body`
2549     and the `tail` is just the closing bracket, then it's considered failed.
2550     """
2551     tail_len = len(str(tail).strip())
2552     if not body:
2553         if tail_len == 0:
2554             raise CannotSplit("Splitting brackets produced the same line")
2555
2556         elif tail_len < 3:
2557             raise CannotSplit(
2558                 f"Splitting brackets on an empty body to save "
2559                 f"{tail_len} characters is not worth it"
2560             )
2561
2562
2563 def bracket_split_build_line(
2564     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2565 ) -> Line:
2566     """Return a new line with given `leaves` and respective comments from `original`.
2567
2568     If `is_body` is True, the result line is one-indented inside brackets and as such
2569     has its first leaf's prefix normalized and a trailing comma added when expected.
2570     """
2571     result = Line(depth=original.depth)
2572     if is_body:
2573         result.inside_brackets = True
2574         result.depth += 1
2575         if leaves:
2576             # Since body is a new indent level, remove spurious leading whitespace.
2577             normalize_prefix(leaves[0], inside_brackets=True)
2578             # Ensure a trailing comma for imports and standalone function arguments, but
2579             # be careful not to add one after any comments.
2580             no_commas = original.is_def and not any(
2581                 l.type == token.COMMA for l in leaves
2582             )
2583
2584             if original.is_import or no_commas:
2585                 for i in range(len(leaves) - 1, -1, -1):
2586                     if leaves[i].type == STANDALONE_COMMENT:
2587                         continue
2588                     elif leaves[i].type == token.COMMA:
2589                         break
2590                     else:
2591                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2592                         break
2593     # Populate the line
2594     for leaf in leaves:
2595         result.append(leaf, preformatted=True)
2596         for comment_after in original.comments_after(leaf):
2597             result.append(comment_after, preformatted=True)
2598     if is_body:
2599         result.should_explode = should_explode(result, opening_bracket)
2600     return result
2601
2602
2603 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2604     """Normalize prefix of the first leaf in every line returned by `split_func`.
2605
2606     This is a decorator over relevant split functions.
2607     """
2608
2609     @wraps(split_func)
2610     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2611         for l in split_func(line, features):
2612             normalize_prefix(l.leaves[0], inside_brackets=True)
2613             yield l
2614
2615     return split_wrapper
2616
2617
2618 @dont_increase_indentation
2619 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2620     """Split according to delimiters of the highest priority.
2621
2622     If the appropriate Features are given, the split will add trailing commas
2623     also in function signatures and calls that contain `*` and `**`.
2624     """
2625     try:
2626         last_leaf = line.leaves[-1]
2627     except IndexError:
2628         raise CannotSplit("Line empty")
2629
2630     bt = line.bracket_tracker
2631     try:
2632         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2633     except ValueError:
2634         raise CannotSplit("No delimiters found")
2635
2636     if delimiter_priority == DOT_PRIORITY:
2637         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2638             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2639
2640     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2641     lowest_depth = sys.maxsize
2642     trailing_comma_safe = True
2643
2644     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2645         """Append `leaf` to current line or to new line if appending impossible."""
2646         nonlocal current_line
2647         try:
2648             current_line.append_safe(leaf, preformatted=True)
2649         except ValueError:
2650             yield current_line
2651
2652             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2653             current_line.append(leaf)
2654
2655     for leaf in line.leaves:
2656         yield from append_to_line(leaf)
2657
2658         for comment_after in line.comments_after(leaf):
2659             yield from append_to_line(comment_after)
2660
2661         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2662         if leaf.bracket_depth == lowest_depth:
2663             if is_vararg(leaf, within={syms.typedargslist}):
2664                 trailing_comma_safe = (
2665                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2666                 )
2667             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2668                 trailing_comma_safe = (
2669                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2670                 )
2671
2672         leaf_priority = bt.delimiters.get(id(leaf))
2673         if leaf_priority == delimiter_priority:
2674             yield current_line
2675
2676             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2677     if current_line:
2678         if (
2679             trailing_comma_safe
2680             and delimiter_priority == COMMA_PRIORITY
2681             and current_line.leaves[-1].type != token.COMMA
2682             and current_line.leaves[-1].type != STANDALONE_COMMENT
2683         ):
2684             current_line.append(Leaf(token.COMMA, ","))
2685         yield current_line
2686
2687
2688 @dont_increase_indentation
2689 def standalone_comment_split(
2690     line: Line, features: Collection[Feature] = ()
2691 ) -> Iterator[Line]:
2692     """Split standalone comments from the rest of the line."""
2693     if not line.contains_standalone_comments(0):
2694         raise CannotSplit("Line does not have any standalone comments")
2695
2696     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2697
2698     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2699         """Append `leaf` to current line or to new line if appending impossible."""
2700         nonlocal current_line
2701         try:
2702             current_line.append_safe(leaf, preformatted=True)
2703         except ValueError:
2704             yield current_line
2705
2706             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2707             current_line.append(leaf)
2708
2709     for leaf in line.leaves:
2710         yield from append_to_line(leaf)
2711
2712         for comment_after in line.comments_after(leaf):
2713             yield from append_to_line(comment_after)
2714
2715     if current_line:
2716         yield current_line
2717
2718
2719 def is_import(leaf: Leaf) -> bool:
2720     """Return True if the given leaf starts an import statement."""
2721     p = leaf.parent
2722     t = leaf.type
2723     v = leaf.value
2724     return bool(
2725         t == token.NAME
2726         and (
2727             (v == "import" and p and p.type == syms.import_name)
2728             or (v == "from" and p and p.type == syms.import_from)
2729         )
2730     )
2731
2732
2733 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2734     """Return True if the given leaf is a special comment.
2735     Only returns true for type comments for now."""
2736     t = leaf.type
2737     v = leaf.value
2738     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith(
2739         "# type:" + suffix
2740     )
2741
2742
2743 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2744     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2745     else.
2746
2747     Note: don't use backslashes for formatting or you'll lose your voting rights.
2748     """
2749     if not inside_brackets:
2750         spl = leaf.prefix.split("#")
2751         if "\\" not in spl[0]:
2752             nl_count = spl[-1].count("\n")
2753             if len(spl) > 1:
2754                 nl_count -= 1
2755             leaf.prefix = "\n" * nl_count
2756             return
2757
2758     leaf.prefix = ""
2759
2760
2761 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2762     """Make all string prefixes lowercase.
2763
2764     If remove_u_prefix is given, also removes any u prefix from the string.
2765
2766     Note: Mutates its argument.
2767     """
2768     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2769     assert match is not None, f"failed to match string {leaf.value!r}"
2770     orig_prefix = match.group(1)
2771     new_prefix = orig_prefix.lower()
2772     if remove_u_prefix:
2773         new_prefix = new_prefix.replace("u", "")
2774     leaf.value = f"{new_prefix}{match.group(2)}"
2775
2776
2777 def normalize_string_quotes(leaf: Leaf) -> None:
2778     """Prefer double quotes but only if it doesn't cause more escaping.
2779
2780     Adds or removes backslashes as appropriate. Doesn't parse and fix
2781     strings nested in f-strings (yet).
2782
2783     Note: Mutates its argument.
2784     """
2785     value = leaf.value.lstrip("furbFURB")
2786     if value[:3] == '"""':
2787         return
2788
2789     elif value[:3] == "'''":
2790         orig_quote = "'''"
2791         new_quote = '"""'
2792     elif value[0] == '"':
2793         orig_quote = '"'
2794         new_quote = "'"
2795     else:
2796         orig_quote = "'"
2797         new_quote = '"'
2798     first_quote_pos = leaf.value.find(orig_quote)
2799     if first_quote_pos == -1:
2800         return  # There's an internal error
2801
2802     prefix = leaf.value[:first_quote_pos]
2803     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2804     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2805     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2806     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2807     if "r" in prefix.casefold():
2808         if unescaped_new_quote.search(body):
2809             # There's at least one unescaped new_quote in this raw string
2810             # so converting is impossible
2811             return
2812
2813         # Do not introduce or remove backslashes in raw strings
2814         new_body = body
2815     else:
2816         # remove unnecessary escapes
2817         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2818         if body != new_body:
2819             # Consider the string without unnecessary escapes as the original
2820             body = new_body
2821             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2822         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2823         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2824     if "f" in prefix.casefold():
2825         matches = re.findall(
2826             r"""
2827             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2828                 ([^{].*?)  # contents of the brackets except if begins with {{
2829             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2830             """,
2831             new_body,
2832             re.VERBOSE,
2833         )
2834         for m in matches:
2835             if "\\" in str(m):
2836                 # Do not introduce backslashes in interpolated expressions
2837                 return
2838     if new_quote == '"""' and new_body[-1:] == '"':
2839         # edge case:
2840         new_body = new_body[:-1] + '\\"'
2841     orig_escape_count = body.count("\\")
2842     new_escape_count = new_body.count("\\")
2843     if new_escape_count > orig_escape_count:
2844         return  # Do not introduce more escaping
2845
2846     if new_escape_count == orig_escape_count and orig_quote == '"':
2847         return  # Prefer double quotes
2848
2849     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2850
2851
2852 def normalize_numeric_literal(leaf: Leaf) -> None:
2853     """Normalizes numeric (float, int, and complex) literals.
2854
2855     All letters used in the representation are normalized to lowercase (except
2856     in Python 2 long literals).
2857     """
2858     text = leaf.value.lower()
2859     if text.startswith(("0o", "0b")):
2860         # Leave octal and binary literals alone.
2861         pass
2862     elif text.startswith("0x"):
2863         # Change hex literals to upper case.
2864         before, after = text[:2], text[2:]
2865         text = f"{before}{after.upper()}"
2866     elif "e" in text:
2867         before, after = text.split("e")
2868         sign = ""
2869         if after.startswith("-"):
2870             after = after[1:]
2871             sign = "-"
2872         elif after.startswith("+"):
2873             after = after[1:]
2874         before = format_float_or_int_string(before)
2875         text = f"{before}e{sign}{after}"
2876     elif text.endswith(("j", "l")):
2877         number = text[:-1]
2878         suffix = text[-1]
2879         # Capitalize in "2L" because "l" looks too similar to "1".
2880         if suffix == "l":
2881             suffix = "L"
2882         text = f"{format_float_or_int_string(number)}{suffix}"
2883     else:
2884         text = format_float_or_int_string(text)
2885     leaf.value = text
2886
2887
2888 def format_float_or_int_string(text: str) -> str:
2889     """Formats a float string like "1.0"."""
2890     if "." not in text:
2891         return text
2892
2893     before, after = text.split(".")
2894     return f"{before or 0}.{after or 0}"
2895
2896
2897 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2898     """Make existing optional parentheses invisible or create new ones.
2899
2900     `parens_after` is a set of string leaf values immediately after which parens
2901     should be put.
2902
2903     Standardizes on visible parentheses for single-element tuples, and keeps
2904     existing visible parentheses for other tuples and generator expressions.
2905     """
2906     for pc in list_comments(node.prefix, is_endmarker=False):
2907         if pc.value in FMT_OFF:
2908             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2909             return
2910
2911     check_lpar = False
2912     for index, child in enumerate(list(node.children)):
2913         # Add parentheses around long tuple unpacking in assignments.
2914         if (
2915             index == 0
2916             and isinstance(child, Node)
2917             and child.type == syms.testlist_star_expr
2918         ):
2919             check_lpar = True
2920
2921         if check_lpar:
2922             if is_walrus_assignment(child):
2923                 continue
2924             if child.type == syms.atom:
2925                 # Determines if the underlying atom should be surrounded with
2926                 # invisible params - also makes parens invisible recursively
2927                 # within the atom and removes repeated invisible parens within
2928                 # the atom
2929                 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2930                     child, parent=node
2931                 )
2932
2933                 if should_surround_with_parens:
2934                     lpar = Leaf(token.LPAR, "")
2935                     rpar = Leaf(token.RPAR, "")
2936                     index = child.remove() or 0
2937                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2938             elif is_one_tuple(child):
2939                 # wrap child in visible parentheses
2940                 lpar = Leaf(token.LPAR, "(")
2941                 rpar = Leaf(token.RPAR, ")")
2942                 child.remove()
2943                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2944             elif node.type == syms.import_from:
2945                 # "import from" nodes store parentheses directly as part of
2946                 # the statement
2947                 if child.type == token.LPAR:
2948                     # make parentheses invisible
2949                     child.value = ""  # type: ignore
2950                     node.children[-1].value = ""  # type: ignore
2951                 elif child.type != token.STAR:
2952                     # insert invisible parentheses
2953                     node.insert_child(index, Leaf(token.LPAR, ""))
2954                     node.append_child(Leaf(token.RPAR, ""))
2955                 break
2956
2957             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2958                 # wrap child in invisible parentheses
2959                 lpar = Leaf(token.LPAR, "")
2960                 rpar = Leaf(token.RPAR, "")
2961                 index = child.remove() or 0
2962                 prefix = child.prefix
2963                 child.prefix = ""
2964                 new_child = Node(syms.atom, [lpar, child, rpar])
2965                 new_child.prefix = prefix
2966                 node.insert_child(index, new_child)
2967
2968         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2969
2970
2971 def normalize_fmt_off(node: Node) -> None:
2972     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2973     try_again = True
2974     while try_again:
2975         try_again = convert_one_fmt_off_pair(node)
2976
2977
2978 def convert_one_fmt_off_pair(node: Node) -> bool:
2979     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2980
2981     Returns True if a pair was converted.
2982     """
2983     for leaf in node.leaves():
2984         previous_consumed = 0
2985         for comment in list_comments(leaf.prefix, is_endmarker=False):
2986             if comment.value in FMT_OFF:
2987                 # We only want standalone comments. If there's no previous leaf or
2988                 # the previous leaf is indentation, it's a standalone comment in
2989                 # disguise.
2990                 if comment.type != STANDALONE_COMMENT:
2991                     prev = preceding_leaf(leaf)
2992                     if prev and prev.type not in WHITESPACE:
2993                         continue
2994
2995                 ignored_nodes = list(generate_ignored_nodes(leaf))
2996                 if not ignored_nodes:
2997                     continue
2998
2999                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
3000                 parent = first.parent
3001                 prefix = first.prefix
3002                 first.prefix = prefix[comment.consumed :]
3003                 hidden_value = (
3004                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3005                 )
3006                 if hidden_value.endswith("\n"):
3007                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
3008                     # leaf (possibly followed by a DEDENT).
3009                     hidden_value = hidden_value[:-1]
3010                 first_idx = None
3011                 for ignored in ignored_nodes:
3012                     index = ignored.remove()
3013                     if first_idx is None:
3014                         first_idx = index
3015                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3016                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3017                 parent.insert_child(
3018                     first_idx,
3019                     Leaf(
3020                         STANDALONE_COMMENT,
3021                         hidden_value,
3022                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3023                     ),
3024                 )
3025                 return True
3026
3027             previous_consumed = comment.consumed
3028
3029     return False
3030
3031
3032 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3033     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3034
3035     Stops at the end of the block.
3036     """
3037     container: Optional[LN] = container_of(leaf)
3038     while container is not None and container.type != token.ENDMARKER:
3039         for comment in list_comments(container.prefix, is_endmarker=False):
3040             if comment.value in FMT_ON:
3041                 return
3042
3043         yield container
3044
3045         container = container.next_sibling
3046
3047
3048 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3049     """If it's safe, make the parens in the atom `node` invisible, recursively.
3050     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3051     as they are redundant.
3052
3053     Returns whether the node should itself be wrapped in invisible parentheses.
3054
3055     """
3056     if (
3057         node.type != syms.atom
3058         or is_empty_tuple(node)
3059         or is_one_tuple(node)
3060         or (is_yield(node) and parent.type != syms.expr_stmt)
3061         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3062     ):
3063         return False
3064
3065     first = node.children[0]
3066     last = node.children[-1]
3067     if first.type == token.LPAR and last.type == token.RPAR:
3068         middle = node.children[1]
3069         # make parentheses invisible
3070         first.value = ""  # type: ignore
3071         last.value = ""  # type: ignore
3072         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3073
3074         if is_atom_with_invisible_parens(middle):
3075             # Strip the invisible parens from `middle` by replacing
3076             # it with the child in-between the invisible parens
3077             middle.replace(middle.children[1])
3078
3079         return False
3080
3081     return True
3082
3083
3084 def is_atom_with_invisible_parens(node: LN) -> bool:
3085     """Given a `LN`, determines whether it's an atom `node` with invisible
3086     parens. Useful in dedupe-ing and normalizing parens.
3087     """
3088     if isinstance(node, Leaf) or node.type != syms.atom:
3089         return False
3090
3091     first, last = node.children[0], node.children[-1]
3092     return (
3093         isinstance(first, Leaf)
3094         and first.type == token.LPAR
3095         and first.value == ""
3096         and isinstance(last, Leaf)
3097         and last.type == token.RPAR
3098         and last.value == ""
3099     )
3100
3101
3102 def is_empty_tuple(node: LN) -> bool:
3103     """Return True if `node` holds an empty tuple."""
3104     return (
3105         node.type == syms.atom
3106         and len(node.children) == 2
3107         and node.children[0].type == token.LPAR
3108         and node.children[1].type == token.RPAR
3109     )
3110
3111
3112 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3113     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3114
3115     Parenthesis can be optional. Returns None otherwise"""
3116     if len(node.children) != 3:
3117         return None
3118     lpar, wrapped, rpar = node.children
3119     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3120         return None
3121
3122     return wrapped
3123
3124
3125 def is_one_tuple(node: LN) -> bool:
3126     """Return True if `node` holds a tuple with one element, with or without parens."""
3127     if node.type == syms.atom:
3128         gexp = unwrap_singleton_parenthesis(node)
3129         if gexp is None or gexp.type != syms.testlist_gexp:
3130             return False
3131
3132         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3133
3134     return (
3135         node.type in IMPLICIT_TUPLE
3136         and len(node.children) == 2
3137         and node.children[1].type == token.COMMA
3138     )
3139
3140
3141 def is_walrus_assignment(node: LN) -> bool:
3142     """Return True iff `node` is of the shape ( test := test )"""
3143     inner = unwrap_singleton_parenthesis(node)
3144     return inner is not None and inner.type == syms.namedexpr_test
3145
3146
3147 def is_yield(node: LN) -> bool:
3148     """Return True if `node` holds a `yield` or `yield from` expression."""
3149     if node.type == syms.yield_expr:
3150         return True
3151
3152     if node.type == token.NAME and node.value == "yield":  # type: ignore
3153         return True
3154
3155     if node.type != syms.atom:
3156         return False
3157
3158     if len(node.children) != 3:
3159         return False
3160
3161     lpar, expr, rpar = node.children
3162     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3163         return is_yield(expr)
3164
3165     return False
3166
3167
3168 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3169     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3170
3171     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3172     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3173     extended iterable unpacking (PEP 3132) and additional unpacking
3174     generalizations (PEP 448).
3175     """
3176     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3177         return False
3178
3179     p = leaf.parent
3180     if p.type == syms.star_expr:
3181         # Star expressions are also used as assignment targets in extended
3182         # iterable unpacking (PEP 3132).  See what its parent is instead.
3183         if not p.parent:
3184             return False
3185
3186         p = p.parent
3187
3188     return p.type in within
3189
3190
3191 def is_multiline_string(leaf: Leaf) -> bool:
3192     """Return True if `leaf` is a multiline string that actually spans many lines."""
3193     value = leaf.value.lstrip("furbFURB")
3194     return value[:3] in {'"""', "'''"} and "\n" in value
3195
3196
3197 def is_stub_suite(node: Node) -> bool:
3198     """Return True if `node` is a suite with a stub body."""
3199     if (
3200         len(node.children) != 4
3201         or node.children[0].type != token.NEWLINE
3202         or node.children[1].type != token.INDENT
3203         or node.children[3].type != token.DEDENT
3204     ):
3205         return False
3206
3207     return is_stub_body(node.children[2])
3208
3209
3210 def is_stub_body(node: LN) -> bool:
3211     """Return True if `node` is a simple statement containing an ellipsis."""
3212     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3213         return False
3214
3215     if len(node.children) != 2:
3216         return False
3217
3218     child = node.children[0]
3219     return (
3220         child.type == syms.atom
3221         and len(child.children) == 3
3222         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3223     )
3224
3225
3226 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3227     """Return maximum delimiter priority inside `node`.
3228
3229     This is specific to atoms with contents contained in a pair of parentheses.
3230     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3231     """
3232     if node.type != syms.atom:
3233         return 0
3234
3235     first = node.children[0]
3236     last = node.children[-1]
3237     if not (first.type == token.LPAR and last.type == token.RPAR):
3238         return 0
3239
3240     bt = BracketTracker()
3241     for c in node.children[1:-1]:
3242         if isinstance(c, Leaf):
3243             bt.mark(c)
3244         else:
3245             for leaf in c.leaves():
3246                 bt.mark(leaf)
3247     try:
3248         return bt.max_delimiter_priority()
3249
3250     except ValueError:
3251         return 0
3252
3253
3254 def ensure_visible(leaf: Leaf) -> None:
3255     """Make sure parentheses are visible.
3256
3257     They could be invisible as part of some statements (see
3258     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3259     """
3260     if leaf.type == token.LPAR:
3261         leaf.value = "("
3262     elif leaf.type == token.RPAR:
3263         leaf.value = ")"
3264
3265
3266 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3267     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3268
3269     if not (
3270         opening_bracket.parent
3271         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3272         and opening_bracket.value in "[{("
3273     ):
3274         return False
3275
3276     try:
3277         last_leaf = line.leaves[-1]
3278         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3279         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3280     except (IndexError, ValueError):
3281         return False
3282
3283     return max_priority == COMMA_PRIORITY
3284
3285
3286 def get_features_used(node: Node) -> Set[Feature]:
3287     """Return a set of (relatively) new Python features used in this file.
3288
3289     Currently looking for:
3290     - f-strings;
3291     - underscores in numeric literals;
3292     - trailing commas after * or ** in function signatures and calls;
3293     - positional only arguments in function signatures and lambdas;
3294     """
3295     features: Set[Feature] = set()
3296     for n in node.pre_order():
3297         if n.type == token.STRING:
3298             value_head = n.value[:2]  # type: ignore
3299             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3300                 features.add(Feature.F_STRINGS)
3301
3302         elif n.type == token.NUMBER:
3303             if "_" in n.value:  # type: ignore
3304                 features.add(Feature.NUMERIC_UNDERSCORES)
3305
3306         elif n.type == token.SLASH:
3307             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3308                 features.add(Feature.POS_ONLY_ARGUMENTS)
3309
3310         elif n.type == token.COLONEQUAL:
3311             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3312
3313         elif (
3314             n.type in {syms.typedargslist, syms.arglist}
3315             and n.children
3316             and n.children[-1].type == token.COMMA
3317         ):
3318             if n.type == syms.typedargslist:
3319                 feature = Feature.TRAILING_COMMA_IN_DEF
3320             else:
3321                 feature = Feature.TRAILING_COMMA_IN_CALL
3322
3323             for ch in n.children:
3324                 if ch.type in STARS:
3325                     features.add(feature)
3326
3327                 if ch.type == syms.argument:
3328                     for argch in ch.children:
3329                         if argch.type in STARS:
3330                             features.add(feature)
3331
3332     return features
3333
3334
3335 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3336     """Detect the version to target based on the nodes used."""
3337     features = get_features_used(node)
3338     return {
3339         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3340     }
3341
3342
3343 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3344     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3345
3346     Brackets can be omitted if the entire trailer up to and including
3347     a preceding closing bracket fits in one line.
3348
3349     Yielded sets are cumulative (contain results of previous yields, too).  First
3350     set is empty.
3351     """
3352
3353     omit: Set[LeafID] = set()
3354     yield omit
3355
3356     length = 4 * line.depth
3357     opening_bracket = None
3358     closing_bracket = None
3359     inner_brackets: Set[LeafID] = set()
3360     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3361         length += leaf_length
3362         if length > line_length:
3363             break
3364
3365         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3366         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3367             break
3368
3369         if opening_bracket:
3370             if leaf is opening_bracket:
3371                 opening_bracket = None
3372             elif leaf.type in CLOSING_BRACKETS:
3373                 inner_brackets.add(id(leaf))
3374         elif leaf.type in CLOSING_BRACKETS:
3375             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3376                 # Empty brackets would fail a split so treat them as "inner"
3377                 # brackets (e.g. only add them to the `omit` set if another
3378                 # pair of brackets was good enough.
3379                 inner_brackets.add(id(leaf))
3380                 continue
3381
3382             if closing_bracket:
3383                 omit.add(id(closing_bracket))
3384                 omit.update(inner_brackets)
3385                 inner_brackets.clear()
3386                 yield omit
3387
3388             if leaf.value:
3389                 opening_bracket = leaf.opening_bracket
3390                 closing_bracket = leaf
3391
3392
3393 def get_future_imports(node: Node) -> Set[str]:
3394     """Return a set of __future__ imports in the file."""
3395     imports: Set[str] = set()
3396
3397     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3398         for child in children:
3399             if isinstance(child, Leaf):
3400                 if child.type == token.NAME:
3401                     yield child.value
3402             elif child.type == syms.import_as_name:
3403                 orig_name = child.children[0]
3404                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3405                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3406                 yield orig_name.value
3407             elif child.type == syms.import_as_names:
3408                 yield from get_imports_from_children(child.children)
3409             else:
3410                 raise AssertionError("Invalid syntax parsing imports")
3411
3412     for child in node.children:
3413         if child.type != syms.simple_stmt:
3414             break
3415         first_child = child.children[0]
3416         if isinstance(first_child, Leaf):
3417             # Continue looking if we see a docstring; otherwise stop.
3418             if (
3419                 len(child.children) == 2
3420                 and first_child.type == token.STRING
3421                 and child.children[1].type == token.NEWLINE
3422             ):
3423                 continue
3424             else:
3425                 break
3426         elif first_child.type == syms.import_from:
3427             module_name = first_child.children[1]
3428             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3429                 break
3430             imports |= set(get_imports_from_children(first_child.children[3:]))
3431         else:
3432             break
3433     return imports
3434
3435
3436 def gen_python_files_in_dir(
3437     path: Path,
3438     root: Path,
3439     include: Pattern[str],
3440     exclude: Pattern[str],
3441     report: "Report",
3442 ) -> Iterator[Path]:
3443     """Generate all files under `path` whose paths are not excluded by the
3444     `exclude` regex, but are included by the `include` regex.
3445
3446     Symbolic links pointing outside of the `root` directory are ignored.
3447
3448     `report` is where output about exclusions goes.
3449     """
3450     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3451     for child in path.iterdir():
3452         try:
3453             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3454         except ValueError:
3455             if child.is_symlink():
3456                 report.path_ignored(
3457                     child, f"is a symbolic link that points outside {root}"
3458                 )
3459                 continue
3460
3461             raise
3462
3463         if child.is_dir():
3464             normalized_path += "/"
3465         exclude_match = exclude.search(normalized_path)
3466         if exclude_match and exclude_match.group(0):
3467             report.path_ignored(child, f"matches the --exclude regular expression")
3468             continue
3469
3470         if child.is_dir():
3471             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3472
3473         elif child.is_file():
3474             include_match = include.search(normalized_path)
3475             if include_match:
3476                 yield child
3477
3478
3479 @lru_cache()
3480 def find_project_root(srcs: Iterable[str]) -> Path:
3481     """Return a directory containing .git, .hg, or pyproject.toml.
3482
3483     That directory can be one of the directories passed in `srcs` or their
3484     common parent.
3485
3486     If no directory in the tree contains a marker that would specify it's the
3487     project root, the root of the file system is returned.
3488     """
3489     if not srcs:
3490         return Path("/").resolve()
3491
3492     common_base = min(Path(src).resolve() for src in srcs)
3493     if common_base.is_dir():
3494         # Append a fake file so `parents` below returns `common_base_dir`, too.
3495         common_base /= "fake-file"
3496     for directory in common_base.parents:
3497         if (directory / ".git").is_dir():
3498             return directory
3499
3500         if (directory / ".hg").is_dir():
3501             return directory
3502
3503         if (directory / "pyproject.toml").is_file():
3504             return directory
3505
3506     return directory
3507
3508
3509 @dataclass
3510 class Report:
3511     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3512
3513     check: bool = False
3514     quiet: bool = False
3515     verbose: bool = False
3516     change_count: int = 0
3517     same_count: int = 0
3518     failure_count: int = 0
3519
3520     def done(self, src: Path, changed: Changed) -> None:
3521         """Increment the counter for successful reformatting. Write out a message."""
3522         if changed is Changed.YES:
3523             reformatted = "would reformat" if self.check else "reformatted"
3524             if self.verbose or not self.quiet:
3525                 out(f"{reformatted} {src}")
3526             self.change_count += 1
3527         else:
3528             if self.verbose:
3529                 if changed is Changed.NO:
3530                     msg = f"{src} already well formatted, good job."
3531                 else:
3532                     msg = f"{src} wasn't modified on disk since last run."
3533                 out(msg, bold=False)
3534             self.same_count += 1
3535
3536     def failed(self, src: Path, message: str) -> None:
3537         """Increment the counter for failed reformatting. Write out a message."""
3538         err(f"error: cannot format {src}: {message}")
3539         self.failure_count += 1
3540
3541     def path_ignored(self, path: Path, message: str) -> None:
3542         if self.verbose:
3543             out(f"{path} ignored: {message}", bold=False)
3544
3545     @property
3546     def return_code(self) -> int:
3547         """Return the exit code that the app should use.
3548
3549         This considers the current state of changed files and failures:
3550         - if there were any failures, return 123;
3551         - if any files were changed and --check is being used, return 1;
3552         - otherwise return 0.
3553         """
3554         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3555         # 126 we have special return codes reserved by the shell.
3556         if self.failure_count:
3557             return 123
3558
3559         elif self.change_count and self.check:
3560             return 1
3561
3562         return 0
3563
3564     def __str__(self) -> str:
3565         """Render a color report of the current state.
3566
3567         Use `click.unstyle` to remove colors.
3568         """
3569         if self.check:
3570             reformatted = "would be reformatted"
3571             unchanged = "would be left unchanged"
3572             failed = "would fail to reformat"
3573         else:
3574             reformatted = "reformatted"
3575             unchanged = "left unchanged"
3576             failed = "failed to reformat"
3577         report = []
3578         if self.change_count:
3579             s = "s" if self.change_count > 1 else ""
3580             report.append(
3581                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3582             )
3583         if self.same_count:
3584             s = "s" if self.same_count > 1 else ""
3585             report.append(f"{self.same_count} file{s} {unchanged}")
3586         if self.failure_count:
3587             s = "s" if self.failure_count > 1 else ""
3588             report.append(
3589                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3590             )
3591         return ", ".join(report) + "."
3592
3593
3594 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3595     filename = "<unknown>"
3596     if sys.version_info >= (3, 8):
3597         # TODO: support Python 4+ ;)
3598         for minor_version in range(sys.version_info[1], 4, -1):
3599             try:
3600                 return ast.parse(src, filename, feature_version=(3, minor_version))
3601             except SyntaxError:
3602                 continue
3603     else:
3604         for feature_version in (7, 6):
3605             try:
3606                 return ast3.parse(src, filename, feature_version=feature_version)
3607             except SyntaxError:
3608                 continue
3609
3610     return ast27.parse(src)
3611
3612
3613 def _fixup_ast_constants(
3614     node: Union[ast.AST, ast3.AST, ast27.AST]
3615 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3616     """Map ast nodes deprecated in 3.8 to Constant."""
3617     # casts are required until this is released:
3618     # https://github.com/python/typeshed/pull/3142
3619     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3620         return cast(ast.AST, ast.Constant(value=node.s))
3621     elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3622         return cast(ast.AST, ast.Constant(value=node.n))
3623     elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3624         return cast(ast.AST, ast.Constant(value=node.value))
3625     return node
3626
3627
3628 def assert_equivalent(src: str, dst: str) -> None:
3629     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3630
3631     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3632         """Simple visitor generating strings to compare ASTs by content."""
3633
3634         node = _fixup_ast_constants(node)
3635
3636         yield f"{'  ' * depth}{node.__class__.__name__}("
3637
3638         for field in sorted(node._fields):
3639             # TypeIgnore has only one field 'lineno' which breaks this comparison
3640             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3641             if sys.version_info >= (3, 8):
3642                 type_ignore_classes += (ast.TypeIgnore,)
3643             if isinstance(node, type_ignore_classes):
3644                 break
3645
3646             try:
3647                 value = getattr(node, field)
3648             except AttributeError:
3649                 continue
3650
3651             yield f"{'  ' * (depth+1)}{field}="
3652
3653             if isinstance(value, list):
3654                 for item in value:
3655                     # Ignore nested tuples within del statements, because we may insert
3656                     # parentheses and they change the AST.
3657                     if (
3658                         field == "targets"
3659                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3660                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3661                     ):
3662                         for item in item.elts:
3663                             yield from _v(item, depth + 2)
3664                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3665                         yield from _v(item, depth + 2)
3666
3667             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3668                 yield from _v(value, depth + 2)
3669
3670             else:
3671                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3672
3673         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3674
3675     try:
3676         src_ast = parse_ast(src)
3677     except Exception as exc:
3678         raise AssertionError(
3679             f"cannot use --safe with this file; failed to parse source file.  "
3680             f"AST error message: {exc}"
3681         )
3682
3683     try:
3684         dst_ast = parse_ast(dst)
3685     except Exception as exc:
3686         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3687         raise AssertionError(
3688             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3689             f"Please report a bug on https://github.com/psf/black/issues.  "
3690             f"This invalid output might be helpful: {log}"
3691         ) from None
3692
3693     src_ast_str = "\n".join(_v(src_ast))
3694     dst_ast_str = "\n".join(_v(dst_ast))
3695     if src_ast_str != dst_ast_str:
3696         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3697         raise AssertionError(
3698             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3699             f"the source.  "
3700             f"Please report a bug on https://github.com/psf/black/issues.  "
3701             f"This diff might be helpful: {log}"
3702         ) from None
3703
3704
3705 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3706     """Raise AssertionError if `dst` reformats differently the second time."""
3707     newdst = format_str(dst, mode=mode)
3708     if dst != newdst:
3709         log = dump_to_file(
3710             diff(src, dst, "source", "first pass"),
3711             diff(dst, newdst, "first pass", "second pass"),
3712         )
3713         raise AssertionError(
3714             f"INTERNAL ERROR: Black produced different code on the second pass "
3715             f"of the formatter.  "
3716             f"Please report a bug on https://github.com/psf/black/issues.  "
3717             f"This diff might be helpful: {log}"
3718         ) from None
3719
3720
3721 def dump_to_file(*output: str) -> str:
3722     """Dump `output` to a temporary file. Return path to the file."""
3723     with tempfile.NamedTemporaryFile(
3724         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3725     ) as f:
3726         for lines in output:
3727             f.write(lines)
3728             if lines and lines[-1] != "\n":
3729                 f.write("\n")
3730     return f.name
3731
3732
3733 @contextmanager
3734 def nullcontext() -> Iterator[None]:
3735     """Return context manager that does nothing.
3736     Similar to `nullcontext` from python 3.7"""
3737     yield
3738
3739
3740 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3741     """Return a unified diff string between strings `a` and `b`."""
3742     import difflib
3743
3744     a_lines = [line + "\n" for line in a.split("\n")]
3745     b_lines = [line + "\n" for line in b.split("\n")]
3746     return "".join(
3747         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3748     )
3749
3750
3751 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3752     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3753     err("Aborted!")
3754     for task in tasks:
3755         task.cancel()
3756
3757
3758 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3759     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3760     try:
3761         if sys.version_info[:2] >= (3, 7):
3762             all_tasks = asyncio.all_tasks
3763         else:
3764             all_tasks = asyncio.Task.all_tasks
3765         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3766         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3767         if not to_cancel:
3768             return
3769
3770         for task in to_cancel:
3771             task.cancel()
3772         loop.run_until_complete(
3773             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3774         )
3775     finally:
3776         # `concurrent.futures.Future` objects cannot be cancelled once they
3777         # are already running. There might be some when the `shutdown()` happened.
3778         # Silence their logger's spew about the event loop being closed.
3779         cf_logger = logging.getLogger("concurrent.futures")
3780         cf_logger.setLevel(logging.CRITICAL)
3781         loop.close()
3782
3783
3784 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3785     """Replace `regex` with `replacement` twice on `original`.
3786
3787     This is used by string normalization to perform replaces on
3788     overlapping matches.
3789     """
3790     return regex.sub(replacement, regex.sub(replacement, original))
3791
3792
3793 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3794     """Compile a regular expression string in `regex`.
3795
3796     If it contains newlines, use verbose mode.
3797     """
3798     if "\n" in regex:
3799         regex = "(?x)" + regex
3800     return re.compile(regex)
3801
3802
3803 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3804     """Like `reversed(enumerate(sequence))` if that were possible."""
3805     index = len(sequence) - 1
3806     for element in reversed(sequence):
3807         yield (index, element)
3808         index -= 1
3809
3810
3811 def enumerate_with_length(
3812     line: Line, reversed: bool = False
3813 ) -> Iterator[Tuple[Index, Leaf, int]]:
3814     """Return an enumeration of leaves with their length.
3815
3816     Stops prematurely on multiline strings and standalone comments.
3817     """
3818     op = cast(
3819         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3820         enumerate_reversed if reversed else enumerate,
3821     )
3822     for index, leaf in op(line.leaves):
3823         length = len(leaf.prefix) + len(leaf.value)
3824         if "\n" in leaf.value:
3825             return  # Multiline strings, we can't continue.
3826
3827         for comment in line.comments_after(leaf):
3828             length += len(comment.value)
3829
3830         yield index, leaf, length
3831
3832
3833 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3834     """Return True if `line` is no longer than `line_length`.
3835
3836     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3837     """
3838     if not line_str:
3839         line_str = str(line).strip("\n")
3840     return (
3841         len(line_str) <= line_length
3842         and "\n" not in line_str  # multiline strings
3843         and not line.contains_standalone_comments()
3844     )
3845
3846
3847 def can_be_split(line: Line) -> bool:
3848     """Return False if the line cannot be split *for sure*.
3849
3850     This is not an exhaustive search but a cheap heuristic that we can use to
3851     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3852     in unnecessary parentheses).
3853     """
3854     leaves = line.leaves
3855     if len(leaves) < 2:
3856         return False
3857
3858     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3859         call_count = 0
3860         dot_count = 0
3861         next = leaves[-1]
3862         for leaf in leaves[-2::-1]:
3863             if leaf.type in OPENING_BRACKETS:
3864                 if next.type not in CLOSING_BRACKETS:
3865                     return False
3866
3867                 call_count += 1
3868             elif leaf.type == token.DOT:
3869                 dot_count += 1
3870             elif leaf.type == token.NAME:
3871                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3872                     return False
3873
3874             elif leaf.type not in CLOSING_BRACKETS:
3875                 return False
3876
3877             if dot_count > 1 and call_count > 1:
3878                 return False
3879
3880     return True
3881
3882
3883 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3884     """Does `line` have a shape safe to reformat without optional parens around it?
3885
3886     Returns True for only a subset of potentially nice looking formattings but
3887     the point is to not return false positives that end up producing lines that
3888     are too long.
3889     """
3890     bt = line.bracket_tracker
3891     if not bt.delimiters:
3892         # Without delimiters the optional parentheses are useless.
3893         return True
3894
3895     max_priority = bt.max_delimiter_priority()
3896     if bt.delimiter_count_with_priority(max_priority) > 1:
3897         # With more than one delimiter of a kind the optional parentheses read better.
3898         return False
3899
3900     if max_priority == DOT_PRIORITY:
3901         # A single stranded method call doesn't require optional parentheses.
3902         return True
3903
3904     assert len(line.leaves) >= 2, "Stranded delimiter"
3905
3906     first = line.leaves[0]
3907     second = line.leaves[1]
3908     penultimate = line.leaves[-2]
3909     last = line.leaves[-1]
3910
3911     # With a single delimiter, omit if the expression starts or ends with
3912     # a bracket.
3913     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3914         remainder = False
3915         length = 4 * line.depth
3916         for _index, leaf, leaf_length in enumerate_with_length(line):
3917             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3918                 remainder = True
3919             if remainder:
3920                 length += leaf_length
3921                 if length > line_length:
3922                     break
3923
3924                 if leaf.type in OPENING_BRACKETS:
3925                     # There are brackets we can further split on.
3926                     remainder = False
3927
3928         else:
3929             # checked the entire string and line length wasn't exceeded
3930             if len(line.leaves) == _index + 1:
3931                 return True
3932
3933         # Note: we are not returning False here because a line might have *both*
3934         # a leading opening bracket and a trailing closing bracket.  If the
3935         # opening bracket doesn't match our rule, maybe the closing will.
3936
3937     if (
3938         last.type == token.RPAR
3939         or last.type == token.RBRACE
3940         or (
3941             # don't use indexing for omitting optional parentheses;
3942             # it looks weird
3943             last.type == token.RSQB
3944             and last.parent
3945             and last.parent.type != syms.trailer
3946         )
3947     ):
3948         if penultimate.type in OPENING_BRACKETS:
3949             # Empty brackets don't help.
3950             return False
3951
3952         if is_multiline_string(first):
3953             # Additional wrapping of a multiline string in this situation is
3954             # unnecessary.
3955             return True
3956
3957         length = 4 * line.depth
3958         seen_other_brackets = False
3959         for _index, leaf, leaf_length in enumerate_with_length(line):
3960             length += leaf_length
3961             if leaf is last.opening_bracket:
3962                 if seen_other_brackets or length <= line_length:
3963                     return True
3964
3965             elif leaf.type in OPENING_BRACKETS:
3966                 # There are brackets we can further split on.
3967                 seen_other_brackets = True
3968
3969     return False
3970
3971
3972 def get_cache_file(mode: FileMode) -> Path:
3973     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3974
3975
3976 def read_cache(mode: FileMode) -> Cache:
3977     """Read the cache if it exists and is well formed.
3978
3979     If it is not well formed, the call to write_cache later should resolve the issue.
3980     """
3981     cache_file = get_cache_file(mode)
3982     if not cache_file.exists():
3983         return {}
3984
3985     with cache_file.open("rb") as fobj:
3986         try:
3987             cache: Cache = pickle.load(fobj)
3988         except pickle.UnpicklingError:
3989             return {}
3990
3991     return cache
3992
3993
3994 def get_cache_info(path: Path) -> CacheInfo:
3995     """Return the information used to check if a file is already formatted or not."""
3996     stat = path.stat()
3997     return stat.st_mtime, stat.st_size
3998
3999
4000 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4001     """Split an iterable of paths in `sources` into two sets.
4002
4003     The first contains paths of files that modified on disk or are not in the
4004     cache. The other contains paths to non-modified files.
4005     """
4006     todo, done = set(), set()
4007     for src in sources:
4008         src = src.resolve()
4009         if cache.get(src) != get_cache_info(src):
4010             todo.add(src)
4011         else:
4012             done.add(src)
4013     return todo, done
4014
4015
4016 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4017     """Update the cache file."""
4018     cache_file = get_cache_file(mode)
4019     try:
4020         CACHE_DIR.mkdir(parents=True, exist_ok=True)
4021         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4022         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4023             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
4024         os.replace(f.name, cache_file)
4025     except OSError:
4026         pass
4027
4028
4029 def patch_click() -> None:
4030     """Make Click not crash.
4031
4032     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4033     default which restricts paths that it can access during the lifetime of the
4034     application.  Click refuses to work in this scenario by raising a RuntimeError.
4035
4036     In case of Black the likelihood that non-ASCII characters are going to be used in
4037     file paths is minimal since it's Python source code.  Moreover, this crash was
4038     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4039     """
4040     try:
4041         from click import core
4042         from click import _unicodefun  # type: ignore
4043     except ModuleNotFoundError:
4044         return
4045
4046     for module in (core, _unicodefun):
4047         if hasattr(module, "_verify_python3_env"):
4048             module._verify_python3_env = lambda: None
4049
4050
4051 def patched_main() -> None:
4052     freeze_support()
4053     patch_click()
4054     main()
4055
4056
4057 if __name__ == "__main__":
4058     patched_main()