]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

restore cursor to same line of code, not same line of buffer (#989)
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import regex as re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40
41 from appdirs import user_cache_dir
42 from attr import dataclass, evolve, Factory
43 import click
44 import toml
45 from typed_ast import ast3, ast27
46
47 # lib2to3 fork
48 from blib2to3.pytree import Node, Leaf, type_repr
49 from blib2to3 import pygram, pytree
50 from blib2to3.pgen2 import driver, token
51 from blib2to3.pgen2.grammar import Grammar
52 from blib2to3.pgen2.parse import ParseError
53
54 from _version import version as __version__
55
56 DEFAULT_LINE_LENGTH = 88
57 DEFAULT_EXCLUDES = (
58     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
59 )
60 DEFAULT_INCLUDES = r"\.pyi?$"
61 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
62
63
64 # types
65 FileContent = str
66 Encoding = str
67 NewLine = str
68 Depth = int
69 NodeType = int
70 LeafID = int
71 Priority = int
72 Index = int
73 LN = Union[Leaf, Node]
74 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
75 Timestamp = float
76 FileSize = int
77 CacheInfo = Tuple[Timestamp, FileSize]
78 Cache = Dict[Path, CacheInfo]
79 out = partial(click.secho, bold=True, err=True)
80 err = partial(click.secho, fg="red", err=True)
81
82 pygram.initialize(CACHE_DIR)
83 syms = pygram.python_symbols
84
85
86 class NothingChanged(UserWarning):
87     """Raised when reformatted code is the same as source."""
88
89
90 class CannotSplit(Exception):
91     """A readable split that fits the allotted line length is impossible."""
92
93
94 class InvalidInput(ValueError):
95     """Raised when input source code fails all parse attempts."""
96
97
98 class WriteBack(Enum):
99     NO = 0
100     YES = 1
101     DIFF = 2
102     CHECK = 3
103
104     @classmethod
105     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
106         if check and not diff:
107             return cls.CHECK
108
109         return cls.DIFF if diff else cls.YES
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 class TargetVersion(Enum):
119     PY27 = 2
120     PY33 = 3
121     PY34 = 4
122     PY35 = 5
123     PY36 = 6
124     PY37 = 7
125     PY38 = 8
126
127     def is_python2(self) -> bool:
128         return self is TargetVersion.PY27
129
130
131 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
132
133
134 class Feature(Enum):
135     # All string literals are unicode
136     UNICODE_LITERALS = 1
137     F_STRINGS = 2
138     NUMERIC_UNDERSCORES = 3
139     TRAILING_COMMA_IN_CALL = 4
140     TRAILING_COMMA_IN_DEF = 5
141     # The following two feature-flags are mutually exclusive, and exactly one should be
142     # set for every version of python.
143     ASYNC_IDENTIFIERS = 6
144     ASYNC_KEYWORDS = 7
145     ASSIGNMENT_EXPRESSIONS = 8
146     POS_ONLY_ARGUMENTS = 9
147
148
149 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
150     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
151     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
152     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
153     TargetVersion.PY35: {
154         Feature.UNICODE_LITERALS,
155         Feature.TRAILING_COMMA_IN_CALL,
156         Feature.ASYNC_IDENTIFIERS,
157     },
158     TargetVersion.PY36: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA_IN_CALL,
163         Feature.TRAILING_COMMA_IN_DEF,
164         Feature.ASYNC_IDENTIFIERS,
165     },
166     TargetVersion.PY37: {
167         Feature.UNICODE_LITERALS,
168         Feature.F_STRINGS,
169         Feature.NUMERIC_UNDERSCORES,
170         Feature.TRAILING_COMMA_IN_CALL,
171         Feature.TRAILING_COMMA_IN_DEF,
172         Feature.ASYNC_KEYWORDS,
173     },
174     TargetVersion.PY38: {
175         Feature.UNICODE_LITERALS,
176         Feature.F_STRINGS,
177         Feature.NUMERIC_UNDERSCORES,
178         Feature.TRAILING_COMMA_IN_CALL,
179         Feature.TRAILING_COMMA_IN_DEF,
180         Feature.ASYNC_KEYWORDS,
181         Feature.ASSIGNMENT_EXPRESSIONS,
182         Feature.POS_ONLY_ARGUMENTS,
183     },
184 }
185
186
187 @dataclass
188 class FileMode:
189     target_versions: Set[TargetVersion] = Factory(set)
190     line_length: int = DEFAULT_LINE_LENGTH
191     string_normalization: bool = True
192     is_pyi: bool = False
193
194     def get_cache_key(self) -> str:
195         if self.target_versions:
196             version_str = ",".join(
197                 str(version.value)
198                 for version in sorted(self.target_versions, key=lambda v: v.value)
199             )
200         else:
201             version_str = "-"
202         parts = [
203             version_str,
204             str(self.line_length),
205             str(int(self.string_normalization)),
206             str(int(self.is_pyi)),
207         ]
208         return ".".join(parts)
209
210
211 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
212     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
213
214
215 def read_pyproject_toml(
216     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
217 ) -> Optional[str]:
218     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
219
220     Returns the path to a successfully found and read configuration file, None
221     otherwise.
222     """
223     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
224     if not value:
225         root = find_project_root(ctx.params.get("src", ()))
226         path = root / "pyproject.toml"
227         if path.is_file():
228             value = str(path)
229         else:
230             return None
231
232     try:
233         pyproject_toml = toml.load(value)
234         config = pyproject_toml.get("tool", {}).get("black", {})
235     except (toml.TomlDecodeError, OSError) as e:
236         raise click.FileError(
237             filename=value, hint=f"Error reading configuration file: {e}"
238         )
239
240     if not config:
241         return None
242
243     if ctx.default_map is None:
244         ctx.default_map = {}
245     ctx.default_map.update(  # type: ignore  # bad types in .pyi
246         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
247     )
248     return value
249
250
251 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
252 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
253 @click.option(
254     "-l",
255     "--line-length",
256     type=int,
257     default=DEFAULT_LINE_LENGTH,
258     help="How many characters per line to allow.",
259     show_default=True,
260 )
261 @click.option(
262     "-t",
263     "--target-version",
264     type=click.Choice([v.name.lower() for v in TargetVersion]),
265     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
266     multiple=True,
267     help=(
268         "Python versions that should be supported by Black's output. [default: "
269         "per-file auto-detection]"
270     ),
271 )
272 @click.option(
273     "--py36",
274     is_flag=True,
275     help=(
276         "Allow using Python 3.6-only syntax on all input files.  This will put "
277         "trailing commas in function signatures and calls also after *args and "
278         "**kwargs. Deprecated; use --target-version instead. "
279         "[default: per-file auto-detection]"
280     ),
281 )
282 @click.option(
283     "--pyi",
284     is_flag=True,
285     help=(
286         "Format all input files like typing stubs regardless of file extension "
287         "(useful when piping source on standard input)."
288     ),
289 )
290 @click.option(
291     "-S",
292     "--skip-string-normalization",
293     is_flag=True,
294     help="Don't normalize string quotes or prefixes.",
295 )
296 @click.option(
297     "--check",
298     is_flag=True,
299     help=(
300         "Don't write the files back, just return the status.  Return code 0 "
301         "means nothing would change.  Return code 1 means some files would be "
302         "reformatted.  Return code 123 means there was an internal error."
303     ),
304 )
305 @click.option(
306     "--diff",
307     is_flag=True,
308     help="Don't write the files back, just output a diff for each file on stdout.",
309 )
310 @click.option(
311     "--fast/--safe",
312     is_flag=True,
313     help="If --fast given, skip temporary sanity checks. [default: --safe]",
314 )
315 @click.option(
316     "--include",
317     type=str,
318     default=DEFAULT_INCLUDES,
319     help=(
320         "A regular expression that matches files and directories that should be "
321         "included on recursive searches.  An empty value means all files are "
322         "included regardless of the name.  Use forward slashes for directories on "
323         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
324         "later."
325     ),
326     show_default=True,
327 )
328 @click.option(
329     "--exclude",
330     type=str,
331     default=DEFAULT_EXCLUDES,
332     help=(
333         "A regular expression that matches files and directories that should be "
334         "excluded on recursive searches.  An empty value means no paths are excluded. "
335         "Use forward slashes for directories on all platforms (Windows, too).  "
336         "Exclusions are calculated first, inclusions later."
337     ),
338     show_default=True,
339 )
340 @click.option(
341     "-q",
342     "--quiet",
343     is_flag=True,
344     help=(
345         "Don't emit non-error messages to stderr. Errors are still emitted; "
346         "silence those with 2>/dev/null."
347     ),
348 )
349 @click.option(
350     "-v",
351     "--verbose",
352     is_flag=True,
353     help=(
354         "Also emit messages to stderr about files that were not changed or were "
355         "ignored due to --exclude=."
356     ),
357 )
358 @click.version_option(version=__version__)
359 @click.argument(
360     "src",
361     nargs=-1,
362     type=click.Path(
363         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
364     ),
365     is_eager=True,
366 )
367 @click.option(
368     "--config",
369     type=click.Path(
370         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
371     ),
372     is_eager=True,
373     callback=read_pyproject_toml,
374     help="Read configuration from PATH.",
375 )
376 @click.pass_context
377 def main(
378     ctx: click.Context,
379     code: Optional[str],
380     line_length: int,
381     target_version: List[TargetVersion],
382     check: bool,
383     diff: bool,
384     fast: bool,
385     pyi: bool,
386     py36: bool,
387     skip_string_normalization: bool,
388     quiet: bool,
389     verbose: bool,
390     include: str,
391     exclude: str,
392     src: Tuple[str],
393     config: Optional[str],
394 ) -> None:
395     """The uncompromising code formatter."""
396     write_back = WriteBack.from_configuration(check=check, diff=diff)
397     if target_version:
398         if py36:
399             err(f"Cannot use both --target-version and --py36")
400             ctx.exit(2)
401         else:
402             versions = set(target_version)
403     elif py36:
404         err(
405             "--py36 is deprecated and will be removed in a future version. "
406             "Use --target-version py36 instead."
407         )
408         versions = PY36_VERSIONS
409     else:
410         # We'll autodetect later.
411         versions = set()
412     mode = FileMode(
413         target_versions=versions,
414         line_length=line_length,
415         is_pyi=pyi,
416         string_normalization=not skip_string_normalization,
417     )
418     if config and verbose:
419         out(f"Using configuration from {config}.", bold=False, fg="blue")
420     if code is not None:
421         print(format_str(code, mode=mode))
422         ctx.exit(0)
423     try:
424         include_regex = re_compile_maybe_verbose(include)
425     except re.error:
426         err(f"Invalid regular expression for include given: {include!r}")
427         ctx.exit(2)
428     try:
429         exclude_regex = re_compile_maybe_verbose(exclude)
430     except re.error:
431         err(f"Invalid regular expression for exclude given: {exclude!r}")
432         ctx.exit(2)
433     report = Report(check=check, quiet=quiet, verbose=verbose)
434     root = find_project_root(src)
435     sources: Set[Path] = set()
436     path_empty(src, quiet, verbose, ctx)
437     for s in src:
438         p = Path(s)
439         if p.is_dir():
440             sources.update(
441                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
442             )
443         elif p.is_file() or s == "-":
444             # if a file was explicitly given, we don't care about its extension
445             sources.add(p)
446         else:
447             err(f"invalid path: {s}")
448     if len(sources) == 0:
449         if verbose or not quiet:
450             out("No Python files are present to be formatted. Nothing to do 😴")
451         ctx.exit(0)
452
453     if len(sources) == 1:
454         reformat_one(
455             src=sources.pop(),
456             fast=fast,
457             write_back=write_back,
458             mode=mode,
459             report=report,
460         )
461     else:
462         reformat_many(
463             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
464         )
465
466     if verbose or not quiet:
467         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
468         click.secho(str(report), err=True)
469     ctx.exit(report.return_code)
470
471
472 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
473     """
474     Exit if there is no `src` provided for formatting
475     """
476     if not src:
477         if verbose or not quiet:
478             out("No Path provided. Nothing to do 😴")
479             ctx.exit(0)
480
481
482 def reformat_one(
483     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
484 ) -> None:
485     """Reformat a single file under `src` without spawning child processes.
486
487     `fast`, `write_back`, and `mode` options are passed to
488     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
489     """
490     try:
491         changed = Changed.NO
492         if not src.is_file() and str(src) == "-":
493             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
494                 changed = Changed.YES
495         else:
496             cache: Cache = {}
497             if write_back != WriteBack.DIFF:
498                 cache = read_cache(mode)
499                 res_src = src.resolve()
500                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
501                     changed = Changed.CACHED
502             if changed is not Changed.CACHED and format_file_in_place(
503                 src, fast=fast, write_back=write_back, mode=mode
504             ):
505                 changed = Changed.YES
506             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
507                 write_back is WriteBack.CHECK and changed is Changed.NO
508             ):
509                 write_cache(cache, [src], mode)
510         report.done(src, changed)
511     except Exception as exc:
512         report.failed(src, str(exc))
513
514
515 def reformat_many(
516     sources: Set[Path],
517     fast: bool,
518     write_back: WriteBack,
519     mode: FileMode,
520     report: "Report",
521 ) -> None:
522     """Reformat multiple files using a ProcessPoolExecutor."""
523     loop = asyncio.get_event_loop()
524     worker_count = os.cpu_count()
525     if sys.platform == "win32":
526         # Work around https://bugs.python.org/issue26903
527         worker_count = min(worker_count, 61)
528     executor = ProcessPoolExecutor(max_workers=worker_count)
529     try:
530         loop.run_until_complete(
531             schedule_formatting(
532                 sources=sources,
533                 fast=fast,
534                 write_back=write_back,
535                 mode=mode,
536                 report=report,
537                 loop=loop,
538                 executor=executor,
539             )
540         )
541     finally:
542         shutdown(loop)
543         executor.shutdown()
544
545
546 async def schedule_formatting(
547     sources: Set[Path],
548     fast: bool,
549     write_back: WriteBack,
550     mode: FileMode,
551     report: "Report",
552     loop: asyncio.AbstractEventLoop,
553     executor: Executor,
554 ) -> None:
555     """Run formatting of `sources` in parallel using the provided `executor`.
556
557     (Use ProcessPoolExecutors for actual parallelism.)
558
559     `write_back`, `fast`, and `mode` options are passed to
560     :func:`format_file_in_place`.
561     """
562     cache: Cache = {}
563     if write_back != WriteBack.DIFF:
564         cache = read_cache(mode)
565         sources, cached = filter_cached(cache, sources)
566         for src in sorted(cached):
567             report.done(src, Changed.CACHED)
568     if not sources:
569         return
570
571     cancelled = []
572     sources_to_cache = []
573     lock = None
574     if write_back == WriteBack.DIFF:
575         # For diff output, we need locks to ensure we don't interleave output
576         # from different processes.
577         manager = Manager()
578         lock = manager.Lock()
579     tasks = {
580         asyncio.ensure_future(
581             loop.run_in_executor(
582                 executor, format_file_in_place, src, fast, mode, write_back, lock
583             )
584         ): src
585         for src in sorted(sources)
586     }
587     pending: Iterable[asyncio.Future] = tasks.keys()
588     try:
589         loop.add_signal_handler(signal.SIGINT, cancel, pending)
590         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
591     except NotImplementedError:
592         # There are no good alternatives for these on Windows.
593         pass
594     while pending:
595         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
596         for task in done:
597             src = tasks.pop(task)
598             if task.cancelled():
599                 cancelled.append(task)
600             elif task.exception():
601                 report.failed(src, str(task.exception()))
602             else:
603                 changed = Changed.YES if task.result() else Changed.NO
604                 # If the file was written back or was successfully checked as
605                 # well-formatted, store this information in the cache.
606                 if write_back is WriteBack.YES or (
607                     write_back is WriteBack.CHECK and changed is Changed.NO
608                 ):
609                     sources_to_cache.append(src)
610                 report.done(src, changed)
611     if cancelled:
612         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
613     if sources_to_cache:
614         write_cache(cache, sources_to_cache, mode)
615
616
617 def format_file_in_place(
618     src: Path,
619     fast: bool,
620     mode: FileMode,
621     write_back: WriteBack = WriteBack.NO,
622     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
623 ) -> bool:
624     """Format file under `src` path. Return True if changed.
625
626     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
627     code to the file.
628     `mode` and `fast` options are passed to :func:`format_file_contents`.
629     """
630     if src.suffix == ".pyi":
631         mode = evolve(mode, is_pyi=True)
632
633     then = datetime.utcfromtimestamp(src.stat().st_mtime)
634     with open(src, "rb") as buf:
635         src_contents, encoding, newline = decode_bytes(buf.read())
636     try:
637         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
638     except NothingChanged:
639         return False
640
641     if write_back == write_back.YES:
642         with open(src, "w", encoding=encoding, newline=newline) as f:
643             f.write(dst_contents)
644     elif write_back == write_back.DIFF:
645         now = datetime.utcnow()
646         src_name = f"{src}\t{then} +0000"
647         dst_name = f"{src}\t{now} +0000"
648         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
649
650         with lock or nullcontext():
651             f = io.TextIOWrapper(
652                 sys.stdout.buffer,
653                 encoding=encoding,
654                 newline=newline,
655                 write_through=True,
656             )
657             f.write(diff_contents)
658             f.detach()
659
660     return True
661
662
663 def format_stdin_to_stdout(
664     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
665 ) -> bool:
666     """Format file on stdin. Return True if changed.
667
668     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
669     write a diff to stdout. The `mode` argument is passed to
670     :func:`format_file_contents`.
671     """
672     then = datetime.utcnow()
673     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
674     dst = src
675     try:
676         dst = format_file_contents(src, fast=fast, mode=mode)
677         return True
678
679     except NothingChanged:
680         return False
681
682     finally:
683         f = io.TextIOWrapper(
684             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
685         )
686         if write_back == WriteBack.YES:
687             f.write(dst)
688         elif write_back == WriteBack.DIFF:
689             now = datetime.utcnow()
690             src_name = f"STDIN\t{then} +0000"
691             dst_name = f"STDOUT\t{now} +0000"
692             f.write(diff(src, dst, src_name, dst_name))
693         f.detach()
694
695
696 def format_file_contents(
697     src_contents: str, *, fast: bool, mode: FileMode
698 ) -> FileContent:
699     """Reformat contents a file and return new contents.
700
701     If `fast` is False, additionally confirm that the reformatted code is
702     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
703     `mode` is passed to :func:`format_str`.
704     """
705     if src_contents.strip() == "":
706         raise NothingChanged
707
708     dst_contents = format_str(src_contents, mode=mode)
709     if src_contents == dst_contents:
710         raise NothingChanged
711
712     if not fast:
713         assert_equivalent(src_contents, dst_contents)
714         assert_stable(src_contents, dst_contents, mode=mode)
715     return dst_contents
716
717
718 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
719     """Reformat a string and return new contents.
720
721     `mode` determines formatting options, such as how many characters per line are
722     allowed.
723     """
724     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
725     dst_contents = []
726     future_imports = get_future_imports(src_node)
727     if mode.target_versions:
728         versions = mode.target_versions
729     else:
730         versions = detect_target_versions(src_node)
731     normalize_fmt_off(src_node)
732     lines = LineGenerator(
733         remove_u_prefix="unicode_literals" in future_imports
734         or supports_feature(versions, Feature.UNICODE_LITERALS),
735         is_pyi=mode.is_pyi,
736         normalize_strings=mode.string_normalization,
737     )
738     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
739     empty_line = Line()
740     after = 0
741     split_line_features = {
742         feature
743         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
744         if supports_feature(versions, feature)
745     }
746     for current_line in lines.visit(src_node):
747         for _ in range(after):
748             dst_contents.append(str(empty_line))
749         before, after = elt.maybe_empty_lines(current_line)
750         for _ in range(before):
751             dst_contents.append(str(empty_line))
752         for line in split_line(
753             current_line, line_length=mode.line_length, features=split_line_features
754         ):
755             dst_contents.append(str(line))
756     return "".join(dst_contents)
757
758
759 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
760     """Return a tuple of (decoded_contents, encoding, newline).
761
762     `newline` is either CRLF or LF but `decoded_contents` is decoded with
763     universal newlines (i.e. only contains LF).
764     """
765     srcbuf = io.BytesIO(src)
766     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
767     if not lines:
768         return "", encoding, "\n"
769
770     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
771     srcbuf.seek(0)
772     with io.TextIOWrapper(srcbuf, encoding) as tiow:
773         return tiow.read(), encoding, newline
774
775
776 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
777     if not target_versions:
778         # No target_version specified, so try all grammars.
779         return [
780             # Python 3.7+
781             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
782             # Python 3.0-3.6
783             pygram.python_grammar_no_print_statement_no_exec_statement,
784             # Python 2.7 with future print_function import
785             pygram.python_grammar_no_print_statement,
786             # Python 2.7
787             pygram.python_grammar,
788         ]
789     elif all(version.is_python2() for version in target_versions):
790         # Python 2-only code, so try Python 2 grammars.
791         return [
792             # Python 2.7 with future print_function import
793             pygram.python_grammar_no_print_statement,
794             # Python 2.7
795             pygram.python_grammar,
796         ]
797     else:
798         # Python 3-compatible code, so only try Python 3 grammar.
799         grammars = []
800         # If we have to parse both, try to parse async as a keyword first
801         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
802             # Python 3.7+
803             grammars.append(
804                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
805             )
806         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
807             # Python 3.0-3.6
808             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
809         # At least one of the above branches must have been taken, because every Python
810         # version has exactly one of the two 'ASYNC_*' flags
811         return grammars
812
813
814 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
815     """Given a string with source, return the lib2to3 Node."""
816     if src_txt[-1:] != "\n":
817         src_txt += "\n"
818
819     for grammar in get_grammars(set(target_versions)):
820         drv = driver.Driver(grammar, pytree.convert)
821         try:
822             result = drv.parse_string(src_txt, True)
823             break
824
825         except ParseError as pe:
826             lineno, column = pe.context[1]
827             lines = src_txt.splitlines()
828             try:
829                 faulty_line = lines[lineno - 1]
830             except IndexError:
831                 faulty_line = "<line number missing in source>"
832             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
833     else:
834         raise exc from None
835
836     if isinstance(result, Leaf):
837         result = Node(syms.file_input, [result])
838     return result
839
840
841 def lib2to3_unparse(node: Node) -> str:
842     """Given a lib2to3 node, return its string representation."""
843     code = str(node)
844     return code
845
846
847 T = TypeVar("T")
848
849
850 class Visitor(Generic[T]):
851     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
852
853     def visit(self, node: LN) -> Iterator[T]:
854         """Main method to visit `node` and its children.
855
856         It tries to find a `visit_*()` method for the given `node.type`, like
857         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
858         If no dedicated `visit_*()` method is found, chooses `visit_default()`
859         instead.
860
861         Then yields objects of type `T` from the selected visitor.
862         """
863         if node.type < 256:
864             name = token.tok_name[node.type]
865         else:
866             name = type_repr(node.type)
867         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
868
869     def visit_default(self, node: LN) -> Iterator[T]:
870         """Default `visit_*()` implementation. Recurses to children of `node`."""
871         if isinstance(node, Node):
872             for child in node.children:
873                 yield from self.visit(child)
874
875
876 @dataclass
877 class DebugVisitor(Visitor[T]):
878     tree_depth: int = 0
879
880     def visit_default(self, node: LN) -> Iterator[T]:
881         indent = " " * (2 * self.tree_depth)
882         if isinstance(node, Node):
883             _type = type_repr(node.type)
884             out(f"{indent}{_type}", fg="yellow")
885             self.tree_depth += 1
886             for child in node.children:
887                 yield from self.visit(child)
888
889             self.tree_depth -= 1
890             out(f"{indent}/{_type}", fg="yellow", bold=False)
891         else:
892             _type = token.tok_name.get(node.type, str(node.type))
893             out(f"{indent}{_type}", fg="blue", nl=False)
894             if node.prefix:
895                 # We don't have to handle prefixes for `Node` objects since
896                 # that delegates to the first child anyway.
897                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
898             out(f" {node.value!r}", fg="blue", bold=False)
899
900     @classmethod
901     def show(cls, code: Union[str, Leaf, Node]) -> None:
902         """Pretty-print the lib2to3 AST of a given string of `code`.
903
904         Convenience method for debugging.
905         """
906         v: DebugVisitor[None] = DebugVisitor()
907         if isinstance(code, str):
908             code = lib2to3_parse(code)
909         list(v.visit(code))
910
911
912 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
913 STATEMENT = {
914     syms.if_stmt,
915     syms.while_stmt,
916     syms.for_stmt,
917     syms.try_stmt,
918     syms.except_clause,
919     syms.with_stmt,
920     syms.funcdef,
921     syms.classdef,
922 }
923 STANDALONE_COMMENT = 153
924 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
925 LOGIC_OPERATORS = {"and", "or"}
926 COMPARATORS = {
927     token.LESS,
928     token.GREATER,
929     token.EQEQUAL,
930     token.NOTEQUAL,
931     token.LESSEQUAL,
932     token.GREATEREQUAL,
933 }
934 MATH_OPERATORS = {
935     token.VBAR,
936     token.CIRCUMFLEX,
937     token.AMPER,
938     token.LEFTSHIFT,
939     token.RIGHTSHIFT,
940     token.PLUS,
941     token.MINUS,
942     token.STAR,
943     token.SLASH,
944     token.DOUBLESLASH,
945     token.PERCENT,
946     token.AT,
947     token.TILDE,
948     token.DOUBLESTAR,
949 }
950 STARS = {token.STAR, token.DOUBLESTAR}
951 VARARGS_SPECIALS = STARS | {token.SLASH}
952 VARARGS_PARENTS = {
953     syms.arglist,
954     syms.argument,  # double star in arglist
955     syms.trailer,  # single argument to call
956     syms.typedargslist,
957     syms.varargslist,  # lambdas
958 }
959 UNPACKING_PARENTS = {
960     syms.atom,  # single element of a list or set literal
961     syms.dictsetmaker,
962     syms.listmaker,
963     syms.testlist_gexp,
964     syms.testlist_star_expr,
965 }
966 TEST_DESCENDANTS = {
967     syms.test,
968     syms.lambdef,
969     syms.or_test,
970     syms.and_test,
971     syms.not_test,
972     syms.comparison,
973     syms.star_expr,
974     syms.expr,
975     syms.xor_expr,
976     syms.and_expr,
977     syms.shift_expr,
978     syms.arith_expr,
979     syms.trailer,
980     syms.term,
981     syms.power,
982 }
983 ASSIGNMENTS = {
984     "=",
985     "+=",
986     "-=",
987     "*=",
988     "@=",
989     "/=",
990     "%=",
991     "&=",
992     "|=",
993     "^=",
994     "<<=",
995     ">>=",
996     "**=",
997     "//=",
998 }
999 COMPREHENSION_PRIORITY = 20
1000 COMMA_PRIORITY = 18
1001 TERNARY_PRIORITY = 16
1002 LOGIC_PRIORITY = 14
1003 STRING_PRIORITY = 12
1004 COMPARATOR_PRIORITY = 10
1005 MATH_PRIORITIES = {
1006     token.VBAR: 9,
1007     token.CIRCUMFLEX: 8,
1008     token.AMPER: 7,
1009     token.LEFTSHIFT: 6,
1010     token.RIGHTSHIFT: 6,
1011     token.PLUS: 5,
1012     token.MINUS: 5,
1013     token.STAR: 4,
1014     token.SLASH: 4,
1015     token.DOUBLESLASH: 4,
1016     token.PERCENT: 4,
1017     token.AT: 4,
1018     token.TILDE: 3,
1019     token.DOUBLESTAR: 2,
1020 }
1021 DOT_PRIORITY = 1
1022
1023
1024 @dataclass
1025 class BracketTracker:
1026     """Keeps track of brackets on a line."""
1027
1028     depth: int = 0
1029     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1030     delimiters: Dict[LeafID, Priority] = Factory(dict)
1031     previous: Optional[Leaf] = None
1032     _for_loop_depths: List[int] = Factory(list)
1033     _lambda_argument_depths: List[int] = Factory(list)
1034
1035     def mark(self, leaf: Leaf) -> None:
1036         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1037
1038         All leaves receive an int `bracket_depth` field that stores how deep
1039         within brackets a given leaf is. 0 means there are no enclosing brackets
1040         that started on this line.
1041
1042         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1043         field that it forms a pair with. This is a one-directional link to
1044         avoid reference cycles.
1045
1046         If a leaf is a delimiter (a token on which Black can split the line if
1047         needed) and it's on depth 0, its `id()` is stored in the tracker's
1048         `delimiters` field.
1049         """
1050         if leaf.type == token.COMMENT:
1051             return
1052
1053         self.maybe_decrement_after_for_loop_variable(leaf)
1054         self.maybe_decrement_after_lambda_arguments(leaf)
1055         if leaf.type in CLOSING_BRACKETS:
1056             self.depth -= 1
1057             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1058             leaf.opening_bracket = opening_bracket
1059         leaf.bracket_depth = self.depth
1060         if self.depth == 0:
1061             delim = is_split_before_delimiter(leaf, self.previous)
1062             if delim and self.previous is not None:
1063                 self.delimiters[id(self.previous)] = delim
1064             else:
1065                 delim = is_split_after_delimiter(leaf, self.previous)
1066                 if delim:
1067                     self.delimiters[id(leaf)] = delim
1068         if leaf.type in OPENING_BRACKETS:
1069             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1070             self.depth += 1
1071         self.previous = leaf
1072         self.maybe_increment_lambda_arguments(leaf)
1073         self.maybe_increment_for_loop_variable(leaf)
1074
1075     def any_open_brackets(self) -> bool:
1076         """Return True if there is an yet unmatched open bracket on the line."""
1077         return bool(self.bracket_match)
1078
1079     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1080         """Return the highest priority of a delimiter found on the line.
1081
1082         Values are consistent with what `is_split_*_delimiter()` return.
1083         Raises ValueError on no delimiters.
1084         """
1085         return max(v for k, v in self.delimiters.items() if k not in exclude)
1086
1087     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1088         """Return the number of delimiters with the given `priority`.
1089
1090         If no `priority` is passed, defaults to max priority on the line.
1091         """
1092         if not self.delimiters:
1093             return 0
1094
1095         priority = priority or self.max_delimiter_priority()
1096         return sum(1 for p in self.delimiters.values() if p == priority)
1097
1098     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1099         """In a for loop, or comprehension, the variables are often unpacks.
1100
1101         To avoid splitting on the comma in this situation, increase the depth of
1102         tokens between `for` and `in`.
1103         """
1104         if leaf.type == token.NAME and leaf.value == "for":
1105             self.depth += 1
1106             self._for_loop_depths.append(self.depth)
1107             return True
1108
1109         return False
1110
1111     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1112         """See `maybe_increment_for_loop_variable` above for explanation."""
1113         if (
1114             self._for_loop_depths
1115             and self._for_loop_depths[-1] == self.depth
1116             and leaf.type == token.NAME
1117             and leaf.value == "in"
1118         ):
1119             self.depth -= 1
1120             self._for_loop_depths.pop()
1121             return True
1122
1123         return False
1124
1125     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1126         """In a lambda expression, there might be more than one argument.
1127
1128         To avoid splitting on the comma in this situation, increase the depth of
1129         tokens between `lambda` and `:`.
1130         """
1131         if leaf.type == token.NAME and leaf.value == "lambda":
1132             self.depth += 1
1133             self._lambda_argument_depths.append(self.depth)
1134             return True
1135
1136         return False
1137
1138     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1139         """See `maybe_increment_lambda_arguments` above for explanation."""
1140         if (
1141             self._lambda_argument_depths
1142             and self._lambda_argument_depths[-1] == self.depth
1143             and leaf.type == token.COLON
1144         ):
1145             self.depth -= 1
1146             self._lambda_argument_depths.pop()
1147             return True
1148
1149         return False
1150
1151     def get_open_lsqb(self) -> Optional[Leaf]:
1152         """Return the most recent opening square bracket (if any)."""
1153         return self.bracket_match.get((self.depth - 1, token.RSQB))
1154
1155
1156 @dataclass
1157 class Line:
1158     """Holds leaves and comments. Can be printed with `str(line)`."""
1159
1160     depth: int = 0
1161     leaves: List[Leaf] = Factory(list)
1162     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1163     bracket_tracker: BracketTracker = Factory(BracketTracker)
1164     inside_brackets: bool = False
1165     should_explode: bool = False
1166
1167     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1168         """Add a new `leaf` to the end of the line.
1169
1170         Unless `preformatted` is True, the `leaf` will receive a new consistent
1171         whitespace prefix and metadata applied by :class:`BracketTracker`.
1172         Trailing commas are maybe removed, unpacked for loop variables are
1173         demoted from being delimiters.
1174
1175         Inline comments are put aside.
1176         """
1177         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1178         if not has_value:
1179             return
1180
1181         if token.COLON == leaf.type and self.is_class_paren_empty:
1182             del self.leaves[-2:]
1183         if self.leaves and not preformatted:
1184             # Note: at this point leaf.prefix should be empty except for
1185             # imports, for which we only preserve newlines.
1186             leaf.prefix += whitespace(
1187                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1188             )
1189         if self.inside_brackets or not preformatted:
1190             self.bracket_tracker.mark(leaf)
1191             self.maybe_remove_trailing_comma(leaf)
1192         if not self.append_comment(leaf):
1193             self.leaves.append(leaf)
1194
1195     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1196         """Like :func:`append()` but disallow invalid standalone comment structure.
1197
1198         Raises ValueError when any `leaf` is appended after a standalone comment
1199         or when a standalone comment is not the first leaf on the line.
1200         """
1201         if self.bracket_tracker.depth == 0:
1202             if self.is_comment:
1203                 raise ValueError("cannot append to standalone comments")
1204
1205             if self.leaves and leaf.type == STANDALONE_COMMENT:
1206                 raise ValueError(
1207                     "cannot append standalone comments to a populated line"
1208                 )
1209
1210         self.append(leaf, preformatted=preformatted)
1211
1212     @property
1213     def is_comment(self) -> bool:
1214         """Is this line a standalone comment?"""
1215         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1216
1217     @property
1218     def is_decorator(self) -> bool:
1219         """Is this line a decorator?"""
1220         return bool(self) and self.leaves[0].type == token.AT
1221
1222     @property
1223     def is_import(self) -> bool:
1224         """Is this an import line?"""
1225         return bool(self) and is_import(self.leaves[0])
1226
1227     @property
1228     def is_class(self) -> bool:
1229         """Is this line a class definition?"""
1230         return (
1231             bool(self)
1232             and self.leaves[0].type == token.NAME
1233             and self.leaves[0].value == "class"
1234         )
1235
1236     @property
1237     def is_stub_class(self) -> bool:
1238         """Is this line a class definition with a body consisting only of "..."?"""
1239         return self.is_class and self.leaves[-3:] == [
1240             Leaf(token.DOT, ".") for _ in range(3)
1241         ]
1242
1243     @property
1244     def is_def(self) -> bool:
1245         """Is this a function definition? (Also returns True for async defs.)"""
1246         try:
1247             first_leaf = self.leaves[0]
1248         except IndexError:
1249             return False
1250
1251         try:
1252             second_leaf: Optional[Leaf] = self.leaves[1]
1253         except IndexError:
1254             second_leaf = None
1255         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1256             first_leaf.type == token.ASYNC
1257             and second_leaf is not None
1258             and second_leaf.type == token.NAME
1259             and second_leaf.value == "def"
1260         )
1261
1262     @property
1263     def is_class_paren_empty(self) -> bool:
1264         """Is this a class with no base classes but using parentheses?
1265
1266         Those are unnecessary and should be removed.
1267         """
1268         return (
1269             bool(self)
1270             and len(self.leaves) == 4
1271             and self.is_class
1272             and self.leaves[2].type == token.LPAR
1273             and self.leaves[2].value == "("
1274             and self.leaves[3].type == token.RPAR
1275             and self.leaves[3].value == ")"
1276         )
1277
1278     @property
1279     def is_triple_quoted_string(self) -> bool:
1280         """Is the line a triple quoted string?"""
1281         return (
1282             bool(self)
1283             and self.leaves[0].type == token.STRING
1284             and self.leaves[0].value.startswith(('"""', "'''"))
1285         )
1286
1287     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1288         """If so, needs to be split before emitting."""
1289         for leaf in self.leaves:
1290             if leaf.type == STANDALONE_COMMENT:
1291                 if leaf.bracket_depth <= depth_limit:
1292                     return True
1293         return False
1294
1295     def contains_uncollapsable_type_comments(self) -> bool:
1296         ignored_ids = set()
1297         try:
1298             last_leaf = self.leaves[-1]
1299             ignored_ids.add(id(last_leaf))
1300             if last_leaf.type == token.COMMA or (
1301                 last_leaf.type == token.RPAR and not last_leaf.value
1302             ):
1303                 # When trailing commas or optional parens are inserted by Black for
1304                 # consistency, comments after the previous last element are not moved
1305                 # (they don't have to, rendering will still be correct).  So we ignore
1306                 # trailing commas and invisible.
1307                 last_leaf = self.leaves[-2]
1308                 ignored_ids.add(id(last_leaf))
1309         except IndexError:
1310             return False
1311
1312         # A type comment is uncollapsable if it is attached to a leaf
1313         # that isn't at the end of the line (since that could cause it
1314         # to get associated to a different argument) or if there are
1315         # comments before it (since that could cause it to get hidden
1316         # behind a comment.
1317         comment_seen = False
1318         for leaf_id, comments in self.comments.items():
1319             for comment in comments:
1320                 if is_type_comment(comment):
1321                     if leaf_id not in ignored_ids or comment_seen:
1322                         return True
1323
1324                 comment_seen = True
1325
1326         return False
1327
1328     def contains_unsplittable_type_ignore(self) -> bool:
1329         if not self.leaves:
1330             return False
1331
1332         # If a 'type: ignore' is attached to the end of a line, we
1333         # can't split the line, because we can't know which of the
1334         # subexpressions the ignore was meant to apply to.
1335         #
1336         # We only want this to apply to actual physical lines from the
1337         # original source, though: we don't want the presence of a
1338         # 'type: ignore' at the end of a multiline expression to
1339         # justify pushing it all onto one line. Thus we
1340         # (unfortunately) need to check the actual source lines and
1341         # only report an unsplittable 'type: ignore' if this line was
1342         # one line in the original code.
1343
1344         # Grab the first and last line numbers, skipping generated leaves
1345         first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0)
1346         last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0)
1347
1348         if first_line == last_line:
1349             # We look at the last two leaves since a comma or an
1350             # invisible paren could have been added at the end of the
1351             # line.
1352             for node in self.leaves[-2:]:
1353                 for comment in self.comments.get(id(node), []):
1354                     if is_type_comment(comment, " ignore"):
1355                         return True
1356
1357         return False
1358
1359     def contains_multiline_strings(self) -> bool:
1360         for leaf in self.leaves:
1361             if is_multiline_string(leaf):
1362                 return True
1363
1364         return False
1365
1366     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1367         """Remove trailing comma if there is one and it's safe."""
1368         if not (
1369             self.leaves
1370             and self.leaves[-1].type == token.COMMA
1371             and closing.type in CLOSING_BRACKETS
1372         ):
1373             return False
1374
1375         if closing.type == token.RBRACE:
1376             self.remove_trailing_comma()
1377             return True
1378
1379         if closing.type == token.RSQB:
1380             comma = self.leaves[-1]
1381             if comma.parent and comma.parent.type == syms.listmaker:
1382                 self.remove_trailing_comma()
1383                 return True
1384
1385         # For parens let's check if it's safe to remove the comma.
1386         # Imports are always safe.
1387         if self.is_import:
1388             self.remove_trailing_comma()
1389             return True
1390
1391         # Otherwise, if the trailing one is the only one, we might mistakenly
1392         # change a tuple into a different type by removing the comma.
1393         depth = closing.bracket_depth + 1
1394         commas = 0
1395         opening = closing.opening_bracket
1396         for _opening_index, leaf in enumerate(self.leaves):
1397             if leaf is opening:
1398                 break
1399
1400         else:
1401             return False
1402
1403         for leaf in self.leaves[_opening_index + 1 :]:
1404             if leaf is closing:
1405                 break
1406
1407             bracket_depth = leaf.bracket_depth
1408             if bracket_depth == depth and leaf.type == token.COMMA:
1409                 commas += 1
1410                 if leaf.parent and leaf.parent.type in {
1411                     syms.arglist,
1412                     syms.typedargslist,
1413                 }:
1414                     commas += 1
1415                     break
1416
1417         if commas > 1:
1418             self.remove_trailing_comma()
1419             return True
1420
1421         return False
1422
1423     def append_comment(self, comment: Leaf) -> bool:
1424         """Add an inline or standalone comment to the line."""
1425         if (
1426             comment.type == STANDALONE_COMMENT
1427             and self.bracket_tracker.any_open_brackets()
1428         ):
1429             comment.prefix = ""
1430             return False
1431
1432         if comment.type != token.COMMENT:
1433             return False
1434
1435         if not self.leaves:
1436             comment.type = STANDALONE_COMMENT
1437             comment.prefix = ""
1438             return False
1439
1440         last_leaf = self.leaves[-1]
1441         if (
1442             last_leaf.type == token.RPAR
1443             and not last_leaf.value
1444             and last_leaf.parent
1445             and len(list(last_leaf.parent.leaves())) <= 3
1446             and not is_type_comment(comment)
1447         ):
1448             # Comments on an optional parens wrapping a single leaf should belong to
1449             # the wrapped node except if it's a type comment. Pinning the comment like
1450             # this avoids unstable formatting caused by comment migration.
1451             if len(self.leaves) < 2:
1452                 comment.type = STANDALONE_COMMENT
1453                 comment.prefix = ""
1454                 return False
1455             last_leaf = self.leaves[-2]
1456         self.comments.setdefault(id(last_leaf), []).append(comment)
1457         return True
1458
1459     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1460         """Generate comments that should appear directly after `leaf`."""
1461         return self.comments.get(id(leaf), [])
1462
1463     def remove_trailing_comma(self) -> None:
1464         """Remove the trailing comma and moves the comments attached to it."""
1465         trailing_comma = self.leaves.pop()
1466         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1467         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1468             trailing_comma_comments
1469         )
1470
1471     def is_complex_subscript(self, leaf: Leaf) -> bool:
1472         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1473         open_lsqb = self.bracket_tracker.get_open_lsqb()
1474         if open_lsqb is None:
1475             return False
1476
1477         subscript_start = open_lsqb.next_sibling
1478
1479         if isinstance(subscript_start, Node):
1480             if subscript_start.type == syms.listmaker:
1481                 return False
1482
1483             if subscript_start.type == syms.subscriptlist:
1484                 subscript_start = child_towards(subscript_start, leaf)
1485         return subscript_start is not None and any(
1486             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1487         )
1488
1489     def __str__(self) -> str:
1490         """Render the line."""
1491         if not self:
1492             return "\n"
1493
1494         indent = "    " * self.depth
1495         leaves = iter(self.leaves)
1496         first = next(leaves)
1497         res = f"{first.prefix}{indent}{first.value}"
1498         for leaf in leaves:
1499             res += str(leaf)
1500         for comment in itertools.chain.from_iterable(self.comments.values()):
1501             res += str(comment)
1502         return res + "\n"
1503
1504     def __bool__(self) -> bool:
1505         """Return True if the line has leaves or comments."""
1506         return bool(self.leaves or self.comments)
1507
1508
1509 @dataclass
1510 class EmptyLineTracker:
1511     """Provides a stateful method that returns the number of potential extra
1512     empty lines needed before and after the currently processed line.
1513
1514     Note: this tracker works on lines that haven't been split yet.  It assumes
1515     the prefix of the first leaf consists of optional newlines.  Those newlines
1516     are consumed by `maybe_empty_lines()` and included in the computation.
1517     """
1518
1519     is_pyi: bool = False
1520     previous_line: Optional[Line] = None
1521     previous_after: int = 0
1522     previous_defs: List[int] = Factory(list)
1523
1524     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1525         """Return the number of extra empty lines before and after the `current_line`.
1526
1527         This is for separating `def`, `async def` and `class` with extra empty
1528         lines (two on module-level).
1529         """
1530         before, after = self._maybe_empty_lines(current_line)
1531         before = (
1532             # Black should not insert empty lines at the beginning
1533             # of the file
1534             0
1535             if self.previous_line is None
1536             else before - self.previous_after
1537         )
1538         self.previous_after = after
1539         self.previous_line = current_line
1540         return before, after
1541
1542     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1543         max_allowed = 1
1544         if current_line.depth == 0:
1545             max_allowed = 1 if self.is_pyi else 2
1546         if current_line.leaves:
1547             # Consume the first leaf's extra newlines.
1548             first_leaf = current_line.leaves[0]
1549             before = first_leaf.prefix.count("\n")
1550             before = min(before, max_allowed)
1551             first_leaf.prefix = ""
1552         else:
1553             before = 0
1554         depth = current_line.depth
1555         while self.previous_defs and self.previous_defs[-1] >= depth:
1556             self.previous_defs.pop()
1557             if self.is_pyi:
1558                 before = 0 if depth else 1
1559             else:
1560                 before = 1 if depth else 2
1561         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1562             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1563
1564         if (
1565             self.previous_line
1566             and self.previous_line.is_import
1567             and not current_line.is_import
1568             and depth == self.previous_line.depth
1569         ):
1570             return (before or 1), 0
1571
1572         if (
1573             self.previous_line
1574             and self.previous_line.is_class
1575             and current_line.is_triple_quoted_string
1576         ):
1577             return before, 1
1578
1579         return before, 0
1580
1581     def _maybe_empty_lines_for_class_or_def(
1582         self, current_line: Line, before: int
1583     ) -> Tuple[int, int]:
1584         if not current_line.is_decorator:
1585             self.previous_defs.append(current_line.depth)
1586         if self.previous_line is None:
1587             # Don't insert empty lines before the first line in the file.
1588             return 0, 0
1589
1590         if self.previous_line.is_decorator:
1591             return 0, 0
1592
1593         if self.previous_line.depth < current_line.depth and (
1594             self.previous_line.is_class or self.previous_line.is_def
1595         ):
1596             return 0, 0
1597
1598         if (
1599             self.previous_line.is_comment
1600             and self.previous_line.depth == current_line.depth
1601             and before == 0
1602         ):
1603             return 0, 0
1604
1605         if self.is_pyi:
1606             if self.previous_line.depth > current_line.depth:
1607                 newlines = 1
1608             elif current_line.is_class or self.previous_line.is_class:
1609                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1610                     # No blank line between classes with an empty body
1611                     newlines = 0
1612                 else:
1613                     newlines = 1
1614             elif current_line.is_def and not self.previous_line.is_def:
1615                 # Blank line between a block of functions and a block of non-functions
1616                 newlines = 1
1617             else:
1618                 newlines = 0
1619         else:
1620             newlines = 2
1621         if current_line.depth and newlines:
1622             newlines -= 1
1623         return newlines, 0
1624
1625
1626 @dataclass
1627 class LineGenerator(Visitor[Line]):
1628     """Generates reformatted Line objects.  Empty lines are not emitted.
1629
1630     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1631     in ways that will no longer stringify to valid Python code on the tree.
1632     """
1633
1634     is_pyi: bool = False
1635     normalize_strings: bool = True
1636     current_line: Line = Factory(Line)
1637     remove_u_prefix: bool = False
1638
1639     def line(self, indent: int = 0) -> Iterator[Line]:
1640         """Generate a line.
1641
1642         If the line is empty, only emit if it makes sense.
1643         If the line is too long, split it first and then generate.
1644
1645         If any lines were generated, set up a new current_line.
1646         """
1647         if not self.current_line:
1648             self.current_line.depth += indent
1649             return  # Line is empty, don't emit. Creating a new one unnecessary.
1650
1651         complete_line = self.current_line
1652         self.current_line = Line(depth=complete_line.depth + indent)
1653         yield complete_line
1654
1655     def visit_default(self, node: LN) -> Iterator[Line]:
1656         """Default `visit_*()` implementation. Recurses to children of `node`."""
1657         if isinstance(node, Leaf):
1658             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1659             for comment in generate_comments(node):
1660                 if any_open_brackets:
1661                     # any comment within brackets is subject to splitting
1662                     self.current_line.append(comment)
1663                 elif comment.type == token.COMMENT:
1664                     # regular trailing comment
1665                     self.current_line.append(comment)
1666                     yield from self.line()
1667
1668                 else:
1669                     # regular standalone comment
1670                     yield from self.line()
1671
1672                     self.current_line.append(comment)
1673                     yield from self.line()
1674
1675             normalize_prefix(node, inside_brackets=any_open_brackets)
1676             if self.normalize_strings and node.type == token.STRING:
1677                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1678                 normalize_string_quotes(node)
1679             if node.type == token.NUMBER:
1680                 normalize_numeric_literal(node)
1681             if node.type not in WHITESPACE:
1682                 self.current_line.append(node)
1683         yield from super().visit_default(node)
1684
1685     def visit_atom(self, node: Node) -> Iterator[Line]:
1686         # Always make parentheses invisible around a single node, because it should
1687         # not be needed (except in the case of yield, where removing the parentheses
1688         # produces a SyntaxError).
1689         if (
1690             len(node.children) == 3
1691             and isinstance(node.children[0], Leaf)
1692             and node.children[0].type == token.LPAR
1693             and isinstance(node.children[2], Leaf)
1694             and node.children[2].type == token.RPAR
1695             and isinstance(node.children[1], Leaf)
1696             and not (
1697                 node.children[1].type == token.NAME
1698                 and node.children[1].value == "yield"
1699             )
1700         ):
1701             node.children[0].value = ""
1702             node.children[2].value = ""
1703         yield from super().visit_default(node)
1704
1705     def visit_factor(self, node: Node) -> Iterator[Line]:
1706         """Force parentheses between a unary op and a binary power:
1707
1708         -2 ** 8 -> -(2 ** 8)
1709         """
1710         child = node.children[1]
1711         if child.type == syms.power and len(child.children) == 3:
1712             lpar = Leaf(token.LPAR, "(")
1713             rpar = Leaf(token.RPAR, ")")
1714             index = child.remove() or 0
1715             node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1716         yield from self.visit_default(node)
1717
1718     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1719         """Increase indentation level, maybe yield a line."""
1720         # In blib2to3 INDENT never holds comments.
1721         yield from self.line(+1)
1722         yield from self.visit_default(node)
1723
1724     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1725         """Decrease indentation level, maybe yield a line."""
1726         # The current line might still wait for trailing comments.  At DEDENT time
1727         # there won't be any (they would be prefixes on the preceding NEWLINE).
1728         # Emit the line then.
1729         yield from self.line()
1730
1731         # While DEDENT has no value, its prefix may contain standalone comments
1732         # that belong to the current indentation level.  Get 'em.
1733         yield from self.visit_default(node)
1734
1735         # Finally, emit the dedent.
1736         yield from self.line(-1)
1737
1738     def visit_stmt(
1739         self, node: Node, keywords: Set[str], parens: Set[str]
1740     ) -> Iterator[Line]:
1741         """Visit a statement.
1742
1743         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1744         `def`, `with`, `class`, `assert` and assignments.
1745
1746         The relevant Python language `keywords` for a given statement will be
1747         NAME leaves within it. This methods puts those on a separate line.
1748
1749         `parens` holds a set of string leaf values immediately after which
1750         invisible parens should be put.
1751         """
1752         normalize_invisible_parens(node, parens_after=parens)
1753         for child in node.children:
1754             if child.type == token.NAME and child.value in keywords:  # type: ignore
1755                 yield from self.line()
1756
1757             yield from self.visit(child)
1758
1759     def visit_suite(self, node: Node) -> Iterator[Line]:
1760         """Visit a suite."""
1761         if self.is_pyi and is_stub_suite(node):
1762             yield from self.visit(node.children[2])
1763         else:
1764             yield from self.visit_default(node)
1765
1766     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1767         """Visit a statement without nested statements."""
1768         is_suite_like = node.parent and node.parent.type in STATEMENT
1769         if is_suite_like:
1770             if self.is_pyi and is_stub_body(node):
1771                 yield from self.visit_default(node)
1772             else:
1773                 yield from self.line(+1)
1774                 yield from self.visit_default(node)
1775                 yield from self.line(-1)
1776
1777         else:
1778             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1779                 yield from self.line()
1780             yield from self.visit_default(node)
1781
1782     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1783         """Visit `async def`, `async for`, `async with`."""
1784         yield from self.line()
1785
1786         children = iter(node.children)
1787         for child in children:
1788             yield from self.visit(child)
1789
1790             if child.type == token.ASYNC:
1791                 break
1792
1793         internal_stmt = next(children)
1794         for child in internal_stmt.children:
1795             yield from self.visit(child)
1796
1797     def visit_decorators(self, node: Node) -> Iterator[Line]:
1798         """Visit decorators."""
1799         for child in node.children:
1800             yield from self.line()
1801             yield from self.visit(child)
1802
1803     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1804         """Remove a semicolon and put the other statement on a separate line."""
1805         yield from self.line()
1806
1807     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1808         """End of file. Process outstanding comments and end with a newline."""
1809         yield from self.visit_default(leaf)
1810         yield from self.line()
1811
1812     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1813         if not self.current_line.bracket_tracker.any_open_brackets():
1814             yield from self.line()
1815         yield from self.visit_default(leaf)
1816
1817     def __attrs_post_init__(self) -> None:
1818         """You are in a twisty little maze of passages."""
1819         v = self.visit_stmt
1820         Ø: Set[str] = set()
1821         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1822         self.visit_if_stmt = partial(
1823             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1824         )
1825         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1826         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1827         self.visit_try_stmt = partial(
1828             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1829         )
1830         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1831         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1832         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1833         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1834         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1835         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1836         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1837         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1838         self.visit_async_funcdef = self.visit_async_stmt
1839         self.visit_decorated = self.visit_decorators
1840
1841
1842 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1843 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1844 OPENING_BRACKETS = set(BRACKET.keys())
1845 CLOSING_BRACKETS = set(BRACKET.values())
1846 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1847 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1848
1849
1850 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1851     """Return whitespace prefix if needed for the given `leaf`.
1852
1853     `complex_subscript` signals whether the given leaf is part of a subscription
1854     which has non-trivial arguments, like arithmetic expressions or function calls.
1855     """
1856     NO = ""
1857     SPACE = " "
1858     DOUBLESPACE = "  "
1859     t = leaf.type
1860     p = leaf.parent
1861     v = leaf.value
1862     if t in ALWAYS_NO_SPACE:
1863         return NO
1864
1865     if t == token.COMMENT:
1866         return DOUBLESPACE
1867
1868     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1869     if t == token.COLON and p.type not in {
1870         syms.subscript,
1871         syms.subscriptlist,
1872         syms.sliceop,
1873     }:
1874         return NO
1875
1876     prev = leaf.prev_sibling
1877     if not prev:
1878         prevp = preceding_leaf(p)
1879         if not prevp or prevp.type in OPENING_BRACKETS:
1880             return NO
1881
1882         if t == token.COLON:
1883             if prevp.type == token.COLON:
1884                 return NO
1885
1886             elif prevp.type != token.COMMA and not complex_subscript:
1887                 return NO
1888
1889             return SPACE
1890
1891         if prevp.type == token.EQUAL:
1892             if prevp.parent:
1893                 if prevp.parent.type in {
1894                     syms.arglist,
1895                     syms.argument,
1896                     syms.parameters,
1897                     syms.varargslist,
1898                 }:
1899                     return NO
1900
1901                 elif prevp.parent.type == syms.typedargslist:
1902                     # A bit hacky: if the equal sign has whitespace, it means we
1903                     # previously found it's a typed argument.  So, we're using
1904                     # that, too.
1905                     return prevp.prefix
1906
1907         elif prevp.type in VARARGS_SPECIALS:
1908             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1909                 return NO
1910
1911         elif prevp.type == token.COLON:
1912             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1913                 return SPACE if complex_subscript else NO
1914
1915         elif (
1916             prevp.parent
1917             and prevp.parent.type == syms.factor
1918             and prevp.type in MATH_OPERATORS
1919         ):
1920             return NO
1921
1922         elif (
1923             prevp.type == token.RIGHTSHIFT
1924             and prevp.parent
1925             and prevp.parent.type == syms.shift_expr
1926             and prevp.prev_sibling
1927             and prevp.prev_sibling.type == token.NAME
1928             and prevp.prev_sibling.value == "print"  # type: ignore
1929         ):
1930             # Python 2 print chevron
1931             return NO
1932
1933     elif prev.type in OPENING_BRACKETS:
1934         return NO
1935
1936     if p.type in {syms.parameters, syms.arglist}:
1937         # untyped function signatures or calls
1938         if not prev or prev.type != token.COMMA:
1939             return NO
1940
1941     elif p.type == syms.varargslist:
1942         # lambdas
1943         if prev and prev.type != token.COMMA:
1944             return NO
1945
1946     elif p.type == syms.typedargslist:
1947         # typed function signatures
1948         if not prev:
1949             return NO
1950
1951         if t == token.EQUAL:
1952             if prev.type != syms.tname:
1953                 return NO
1954
1955         elif prev.type == token.EQUAL:
1956             # A bit hacky: if the equal sign has whitespace, it means we
1957             # previously found it's a typed argument.  So, we're using that, too.
1958             return prev.prefix
1959
1960         elif prev.type != token.COMMA:
1961             return NO
1962
1963     elif p.type == syms.tname:
1964         # type names
1965         if not prev:
1966             prevp = preceding_leaf(p)
1967             if not prevp or prevp.type != token.COMMA:
1968                 return NO
1969
1970     elif p.type == syms.trailer:
1971         # attributes and calls
1972         if t == token.LPAR or t == token.RPAR:
1973             return NO
1974
1975         if not prev:
1976             if t == token.DOT:
1977                 prevp = preceding_leaf(p)
1978                 if not prevp or prevp.type != token.NUMBER:
1979                     return NO
1980
1981             elif t == token.LSQB:
1982                 return NO
1983
1984         elif prev.type != token.COMMA:
1985             return NO
1986
1987     elif p.type == syms.argument:
1988         # single argument
1989         if t == token.EQUAL:
1990             return NO
1991
1992         if not prev:
1993             prevp = preceding_leaf(p)
1994             if not prevp or prevp.type == token.LPAR:
1995                 return NO
1996
1997         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
1998             return NO
1999
2000     elif p.type == syms.decorator:
2001         # decorators
2002         return NO
2003
2004     elif p.type == syms.dotted_name:
2005         if prev:
2006             return NO
2007
2008         prevp = preceding_leaf(p)
2009         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2010             return NO
2011
2012     elif p.type == syms.classdef:
2013         if t == token.LPAR:
2014             return NO
2015
2016         if prev and prev.type == token.LPAR:
2017             return NO
2018
2019     elif p.type in {syms.subscript, syms.sliceop}:
2020         # indexing
2021         if not prev:
2022             assert p.parent is not None, "subscripts are always parented"
2023             if p.parent.type == syms.subscriptlist:
2024                 return SPACE
2025
2026             return NO
2027
2028         elif not complex_subscript:
2029             return NO
2030
2031     elif p.type == syms.atom:
2032         if prev and t == token.DOT:
2033             # dots, but not the first one.
2034             return NO
2035
2036     elif p.type == syms.dictsetmaker:
2037         # dict unpacking
2038         if prev and prev.type == token.DOUBLESTAR:
2039             return NO
2040
2041     elif p.type in {syms.factor, syms.star_expr}:
2042         # unary ops
2043         if not prev:
2044             prevp = preceding_leaf(p)
2045             if not prevp or prevp.type in OPENING_BRACKETS:
2046                 return NO
2047
2048             prevp_parent = prevp.parent
2049             assert prevp_parent is not None
2050             if prevp.type == token.COLON and prevp_parent.type in {
2051                 syms.subscript,
2052                 syms.sliceop,
2053             }:
2054                 return NO
2055
2056             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2057                 return NO
2058
2059         elif t in {token.NAME, token.NUMBER, token.STRING}:
2060             return NO
2061
2062     elif p.type == syms.import_from:
2063         if t == token.DOT:
2064             if prev and prev.type == token.DOT:
2065                 return NO
2066
2067         elif t == token.NAME:
2068             if v == "import":
2069                 return SPACE
2070
2071             if prev and prev.type == token.DOT:
2072                 return NO
2073
2074     elif p.type == syms.sliceop:
2075         return NO
2076
2077     return SPACE
2078
2079
2080 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2081     """Return the first leaf that precedes `node`, if any."""
2082     while node:
2083         res = node.prev_sibling
2084         if res:
2085             if isinstance(res, Leaf):
2086                 return res
2087
2088             try:
2089                 return list(res.leaves())[-1]
2090
2091             except IndexError:
2092                 return None
2093
2094         node = node.parent
2095     return None
2096
2097
2098 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2099     """Return the child of `ancestor` that contains `descendant`."""
2100     node: Optional[LN] = descendant
2101     while node and node.parent != ancestor:
2102         node = node.parent
2103     return node
2104
2105
2106 def container_of(leaf: Leaf) -> LN:
2107     """Return `leaf` or one of its ancestors that is the topmost container of it.
2108
2109     By "container" we mean a node where `leaf` is the very first child.
2110     """
2111     same_prefix = leaf.prefix
2112     container: LN = leaf
2113     while container:
2114         parent = container.parent
2115         if parent is None:
2116             break
2117
2118         if parent.children[0].prefix != same_prefix:
2119             break
2120
2121         if parent.type == syms.file_input:
2122             break
2123
2124         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2125             break
2126
2127         container = parent
2128     return container
2129
2130
2131 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2132     """Return the priority of the `leaf` delimiter, given a line break after it.
2133
2134     The delimiter priorities returned here are from those delimiters that would
2135     cause a line break after themselves.
2136
2137     Higher numbers are higher priority.
2138     """
2139     if leaf.type == token.COMMA:
2140         return COMMA_PRIORITY
2141
2142     return 0
2143
2144
2145 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2146     """Return the priority of the `leaf` delimiter, given a line break before it.
2147
2148     The delimiter priorities returned here are from those delimiters that would
2149     cause a line break before themselves.
2150
2151     Higher numbers are higher priority.
2152     """
2153     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2154         # * and ** might also be MATH_OPERATORS but in this case they are not.
2155         # Don't treat them as a delimiter.
2156         return 0
2157
2158     if (
2159         leaf.type == token.DOT
2160         and leaf.parent
2161         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2162         and (previous is None or previous.type in CLOSING_BRACKETS)
2163     ):
2164         return DOT_PRIORITY
2165
2166     if (
2167         leaf.type in MATH_OPERATORS
2168         and leaf.parent
2169         and leaf.parent.type not in {syms.factor, syms.star_expr}
2170     ):
2171         return MATH_PRIORITIES[leaf.type]
2172
2173     if leaf.type in COMPARATORS:
2174         return COMPARATOR_PRIORITY
2175
2176     if (
2177         leaf.type == token.STRING
2178         and previous is not None
2179         and previous.type == token.STRING
2180     ):
2181         return STRING_PRIORITY
2182
2183     if leaf.type not in {token.NAME, token.ASYNC}:
2184         return 0
2185
2186     if (
2187         leaf.value == "for"
2188         and leaf.parent
2189         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2190         or leaf.type == token.ASYNC
2191     ):
2192         if (
2193             not isinstance(leaf.prev_sibling, Leaf)
2194             or leaf.prev_sibling.value != "async"
2195         ):
2196             return COMPREHENSION_PRIORITY
2197
2198     if (
2199         leaf.value == "if"
2200         and leaf.parent
2201         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2202     ):
2203         return COMPREHENSION_PRIORITY
2204
2205     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2206         return TERNARY_PRIORITY
2207
2208     if leaf.value == "is":
2209         return COMPARATOR_PRIORITY
2210
2211     if (
2212         leaf.value == "in"
2213         and leaf.parent
2214         and leaf.parent.type in {syms.comp_op, syms.comparison}
2215         and not (
2216             previous is not None
2217             and previous.type == token.NAME
2218             and previous.value == "not"
2219         )
2220     ):
2221         return COMPARATOR_PRIORITY
2222
2223     if (
2224         leaf.value == "not"
2225         and leaf.parent
2226         and leaf.parent.type == syms.comp_op
2227         and not (
2228             previous is not None
2229             and previous.type == token.NAME
2230             and previous.value == "is"
2231         )
2232     ):
2233         return COMPARATOR_PRIORITY
2234
2235     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2236         return LOGIC_PRIORITY
2237
2238     return 0
2239
2240
2241 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2242 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2243
2244
2245 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2246     """Clean the prefix of the `leaf` and generate comments from it, if any.
2247
2248     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2249     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2250     move because it does away with modifying the grammar to include all the
2251     possible places in which comments can be placed.
2252
2253     The sad consequence for us though is that comments don't "belong" anywhere.
2254     This is why this function generates simple parentless Leaf objects for
2255     comments.  We simply don't know what the correct parent should be.
2256
2257     No matter though, we can live without this.  We really only need to
2258     differentiate between inline and standalone comments.  The latter don't
2259     share the line with any code.
2260
2261     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2262     are emitted with a fake STANDALONE_COMMENT token identifier.
2263     """
2264     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2265         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2266
2267
2268 @dataclass
2269 class ProtoComment:
2270     """Describes a piece of syntax that is a comment.
2271
2272     It's not a :class:`blib2to3.pytree.Leaf` so that:
2273
2274     * it can be cached (`Leaf` objects should not be reused more than once as
2275       they store their lineno, column, prefix, and parent information);
2276     * `newlines` and `consumed` fields are kept separate from the `value`. This
2277       simplifies handling of special marker comments like ``# fmt: off/on``.
2278     """
2279
2280     type: int  # token.COMMENT or STANDALONE_COMMENT
2281     value: str  # content of the comment
2282     newlines: int  # how many newlines before the comment
2283     consumed: int  # how many characters of the original leaf's prefix did we consume
2284
2285
2286 @lru_cache(maxsize=4096)
2287 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2288     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2289     result: List[ProtoComment] = []
2290     if not prefix or "#" not in prefix:
2291         return result
2292
2293     consumed = 0
2294     nlines = 0
2295     ignored_lines = 0
2296     for index, line in enumerate(prefix.split("\n")):
2297         consumed += len(line) + 1  # adding the length of the split '\n'
2298         line = line.lstrip()
2299         if not line:
2300             nlines += 1
2301         if not line.startswith("#"):
2302             # Escaped newlines outside of a comment are not really newlines at
2303             # all. We treat a single-line comment following an escaped newline
2304             # as a simple trailing comment.
2305             if line.endswith("\\"):
2306                 ignored_lines += 1
2307             continue
2308
2309         if index == ignored_lines and not is_endmarker:
2310             comment_type = token.COMMENT  # simple trailing comment
2311         else:
2312             comment_type = STANDALONE_COMMENT
2313         comment = make_comment(line)
2314         result.append(
2315             ProtoComment(
2316                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2317             )
2318         )
2319         nlines = 0
2320     return result
2321
2322
2323 def make_comment(content: str) -> str:
2324     """Return a consistently formatted comment from the given `content` string.
2325
2326     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2327     space between the hash sign and the content.
2328
2329     If `content` didn't start with a hash sign, one is provided.
2330     """
2331     content = content.rstrip()
2332     if not content:
2333         return "#"
2334
2335     if content[0] == "#":
2336         content = content[1:]
2337     if content and content[0] not in " !:#'%":
2338         content = " " + content
2339     return "#" + content
2340
2341
2342 def split_line(
2343     line: Line,
2344     line_length: int,
2345     inner: bool = False,
2346     features: Collection[Feature] = (),
2347 ) -> Iterator[Line]:
2348     """Split a `line` into potentially many lines.
2349
2350     They should fit in the allotted `line_length` but might not be able to.
2351     `inner` signifies that there were a pair of brackets somewhere around the
2352     current `line`, possibly transitively. This means we can fallback to splitting
2353     by delimiters if the LHS/RHS don't yield any results.
2354
2355     `features` are syntactical features that may be used in the output.
2356     """
2357     if line.is_comment:
2358         yield line
2359         return
2360
2361     line_str = str(line).strip("\n")
2362
2363     if (
2364         not line.contains_uncollapsable_type_comments()
2365         and not line.should_explode
2366         and (
2367             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2368             or line.contains_unsplittable_type_ignore()
2369         )
2370     ):
2371         yield line
2372         return
2373
2374     split_funcs: List[SplitFunc]
2375     if line.is_def:
2376         split_funcs = [left_hand_split]
2377     else:
2378
2379         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2380             for omit in generate_trailers_to_omit(line, line_length):
2381                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2382                 if is_line_short_enough(lines[0], line_length=line_length):
2383                     yield from lines
2384                     return
2385
2386             # All splits failed, best effort split with no omits.
2387             # This mostly happens to multiline strings that are by definition
2388             # reported as not fitting a single line.
2389             yield from right_hand_split(line, line_length, features=features)
2390
2391         if line.inside_brackets:
2392             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2393         else:
2394             split_funcs = [rhs]
2395     for split_func in split_funcs:
2396         # We are accumulating lines in `result` because we might want to abort
2397         # mission and return the original line in the end, or attempt a different
2398         # split altogether.
2399         result: List[Line] = []
2400         try:
2401             for l in split_func(line, features):
2402                 if str(l).strip("\n") == line_str:
2403                     raise CannotSplit("Split function returned an unchanged result")
2404
2405                 result.extend(
2406                     split_line(
2407                         l, line_length=line_length, inner=True, features=features
2408                     )
2409                 )
2410         except CannotSplit:
2411             continue
2412
2413         else:
2414             yield from result
2415             break
2416
2417     else:
2418         yield line
2419
2420
2421 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2422     """Split line into many lines, starting with the first matching bracket pair.
2423
2424     Note: this usually looks weird, only use this for function definitions.
2425     Prefer RHS otherwise.  This is why this function is not symmetrical with
2426     :func:`right_hand_split` which also handles optional parentheses.
2427     """
2428     tail_leaves: List[Leaf] = []
2429     body_leaves: List[Leaf] = []
2430     head_leaves: List[Leaf] = []
2431     current_leaves = head_leaves
2432     matching_bracket = None
2433     for leaf in line.leaves:
2434         if (
2435             current_leaves is body_leaves
2436             and leaf.type in CLOSING_BRACKETS
2437             and leaf.opening_bracket is matching_bracket
2438         ):
2439             current_leaves = tail_leaves if body_leaves else head_leaves
2440         current_leaves.append(leaf)
2441         if current_leaves is head_leaves:
2442             if leaf.type in OPENING_BRACKETS:
2443                 matching_bracket = leaf
2444                 current_leaves = body_leaves
2445     if not matching_bracket:
2446         raise CannotSplit("No brackets found")
2447
2448     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2449     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2450     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2451     bracket_split_succeeded_or_raise(head, body, tail)
2452     for result in (head, body, tail):
2453         if result:
2454             yield result
2455
2456
2457 def right_hand_split(
2458     line: Line,
2459     line_length: int,
2460     features: Collection[Feature] = (),
2461     omit: Collection[LeafID] = (),
2462 ) -> Iterator[Line]:
2463     """Split line into many lines, starting with the last matching bracket pair.
2464
2465     If the split was by optional parentheses, attempt splitting without them, too.
2466     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2467     this split.
2468
2469     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2470     """
2471     tail_leaves: List[Leaf] = []
2472     body_leaves: List[Leaf] = []
2473     head_leaves: List[Leaf] = []
2474     current_leaves = tail_leaves
2475     opening_bracket = None
2476     closing_bracket = None
2477     for leaf in reversed(line.leaves):
2478         if current_leaves is body_leaves:
2479             if leaf is opening_bracket:
2480                 current_leaves = head_leaves if body_leaves else tail_leaves
2481         current_leaves.append(leaf)
2482         if current_leaves is tail_leaves:
2483             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2484                 opening_bracket = leaf.opening_bracket
2485                 closing_bracket = leaf
2486                 current_leaves = body_leaves
2487     if not (opening_bracket and closing_bracket and head_leaves):
2488         # If there is no opening or closing_bracket that means the split failed and
2489         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2490         # the matching `opening_bracket` wasn't available on `line` anymore.
2491         raise CannotSplit("No brackets found")
2492
2493     tail_leaves.reverse()
2494     body_leaves.reverse()
2495     head_leaves.reverse()
2496     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2497     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2498     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2499     bracket_split_succeeded_or_raise(head, body, tail)
2500     if (
2501         # the body shouldn't be exploded
2502         not body.should_explode
2503         # the opening bracket is an optional paren
2504         and opening_bracket.type == token.LPAR
2505         and not opening_bracket.value
2506         # the closing bracket is an optional paren
2507         and closing_bracket.type == token.RPAR
2508         and not closing_bracket.value
2509         # it's not an import (optional parens are the only thing we can split on
2510         # in this case; attempting a split without them is a waste of time)
2511         and not line.is_import
2512         # there are no standalone comments in the body
2513         and not body.contains_standalone_comments(0)
2514         # and we can actually remove the parens
2515         and can_omit_invisible_parens(body, line_length)
2516     ):
2517         omit = {id(closing_bracket), *omit}
2518         try:
2519             yield from right_hand_split(line, line_length, features=features, omit=omit)
2520             return
2521
2522         except CannotSplit:
2523             if not (
2524                 can_be_split(body)
2525                 or is_line_short_enough(body, line_length=line_length)
2526             ):
2527                 raise CannotSplit(
2528                     "Splitting failed, body is still too long and can't be split."
2529                 )
2530
2531             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2532                 raise CannotSplit(
2533                     "The current optional pair of parentheses is bound to fail to "
2534                     "satisfy the splitting algorithm because the head or the tail "
2535                     "contains multiline strings which by definition never fit one "
2536                     "line."
2537                 )
2538
2539     ensure_visible(opening_bracket)
2540     ensure_visible(closing_bracket)
2541     for result in (head, body, tail):
2542         if result:
2543             yield result
2544
2545
2546 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2547     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2548
2549     Do nothing otherwise.
2550
2551     A left- or right-hand split is based on a pair of brackets. Content before
2552     (and including) the opening bracket is left on one line, content inside the
2553     brackets is put on a separate line, and finally content starting with and
2554     following the closing bracket is put on a separate line.
2555
2556     Those are called `head`, `body`, and `tail`, respectively. If the split
2557     produced the same line (all content in `head`) or ended up with an empty `body`
2558     and the `tail` is just the closing bracket, then it's considered failed.
2559     """
2560     tail_len = len(str(tail).strip())
2561     if not body:
2562         if tail_len == 0:
2563             raise CannotSplit("Splitting brackets produced the same line")
2564
2565         elif tail_len < 3:
2566             raise CannotSplit(
2567                 f"Splitting brackets on an empty body to save "
2568                 f"{tail_len} characters is not worth it"
2569             )
2570
2571
2572 def bracket_split_build_line(
2573     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2574 ) -> Line:
2575     """Return a new line with given `leaves` and respective comments from `original`.
2576
2577     If `is_body` is True, the result line is one-indented inside brackets and as such
2578     has its first leaf's prefix normalized and a trailing comma added when expected.
2579     """
2580     result = Line(depth=original.depth)
2581     if is_body:
2582         result.inside_brackets = True
2583         result.depth += 1
2584         if leaves:
2585             # Since body is a new indent level, remove spurious leading whitespace.
2586             normalize_prefix(leaves[0], inside_brackets=True)
2587             # Ensure a trailing comma for imports and standalone function arguments, but
2588             # be careful not to add one after any comments.
2589             no_commas = original.is_def and not any(
2590                 l.type == token.COMMA for l in leaves
2591             )
2592
2593             if original.is_import or no_commas:
2594                 for i in range(len(leaves) - 1, -1, -1):
2595                     if leaves[i].type == STANDALONE_COMMENT:
2596                         continue
2597                     elif leaves[i].type == token.COMMA:
2598                         break
2599                     else:
2600                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2601                         break
2602     # Populate the line
2603     for leaf in leaves:
2604         result.append(leaf, preformatted=True)
2605         for comment_after in original.comments_after(leaf):
2606             result.append(comment_after, preformatted=True)
2607     if is_body:
2608         result.should_explode = should_explode(result, opening_bracket)
2609     return result
2610
2611
2612 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2613     """Normalize prefix of the first leaf in every line returned by `split_func`.
2614
2615     This is a decorator over relevant split functions.
2616     """
2617
2618     @wraps(split_func)
2619     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2620         for l in split_func(line, features):
2621             normalize_prefix(l.leaves[0], inside_brackets=True)
2622             yield l
2623
2624     return split_wrapper
2625
2626
2627 @dont_increase_indentation
2628 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2629     """Split according to delimiters of the highest priority.
2630
2631     If the appropriate Features are given, the split will add trailing commas
2632     also in function signatures and calls that contain `*` and `**`.
2633     """
2634     try:
2635         last_leaf = line.leaves[-1]
2636     except IndexError:
2637         raise CannotSplit("Line empty")
2638
2639     bt = line.bracket_tracker
2640     try:
2641         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2642     except ValueError:
2643         raise CannotSplit("No delimiters found")
2644
2645     if delimiter_priority == DOT_PRIORITY:
2646         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2647             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2648
2649     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2650     lowest_depth = sys.maxsize
2651     trailing_comma_safe = True
2652
2653     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2654         """Append `leaf` to current line or to new line if appending impossible."""
2655         nonlocal current_line
2656         try:
2657             current_line.append_safe(leaf, preformatted=True)
2658         except ValueError:
2659             yield current_line
2660
2661             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2662             current_line.append(leaf)
2663
2664     for leaf in line.leaves:
2665         yield from append_to_line(leaf)
2666
2667         for comment_after in line.comments_after(leaf):
2668             yield from append_to_line(comment_after)
2669
2670         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2671         if leaf.bracket_depth == lowest_depth:
2672             if is_vararg(leaf, within={syms.typedargslist}):
2673                 trailing_comma_safe = (
2674                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2675                 )
2676             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2677                 trailing_comma_safe = (
2678                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2679                 )
2680
2681         leaf_priority = bt.delimiters.get(id(leaf))
2682         if leaf_priority == delimiter_priority:
2683             yield current_line
2684
2685             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2686     if current_line:
2687         if (
2688             trailing_comma_safe
2689             and delimiter_priority == COMMA_PRIORITY
2690             and current_line.leaves[-1].type != token.COMMA
2691             and current_line.leaves[-1].type != STANDALONE_COMMENT
2692         ):
2693             current_line.append(Leaf(token.COMMA, ","))
2694         yield current_line
2695
2696
2697 @dont_increase_indentation
2698 def standalone_comment_split(
2699     line: Line, features: Collection[Feature] = ()
2700 ) -> Iterator[Line]:
2701     """Split standalone comments from the rest of the line."""
2702     if not line.contains_standalone_comments(0):
2703         raise CannotSplit("Line does not have any standalone comments")
2704
2705     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2706
2707     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2708         """Append `leaf` to current line or to new line if appending impossible."""
2709         nonlocal current_line
2710         try:
2711             current_line.append_safe(leaf, preformatted=True)
2712         except ValueError:
2713             yield current_line
2714
2715             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2716             current_line.append(leaf)
2717
2718     for leaf in line.leaves:
2719         yield from append_to_line(leaf)
2720
2721         for comment_after in line.comments_after(leaf):
2722             yield from append_to_line(comment_after)
2723
2724     if current_line:
2725         yield current_line
2726
2727
2728 def is_import(leaf: Leaf) -> bool:
2729     """Return True if the given leaf starts an import statement."""
2730     p = leaf.parent
2731     t = leaf.type
2732     v = leaf.value
2733     return bool(
2734         t == token.NAME
2735         and (
2736             (v == "import" and p and p.type == syms.import_name)
2737             or (v == "from" and p and p.type == syms.import_from)
2738         )
2739     )
2740
2741
2742 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2743     """Return True if the given leaf is a special comment.
2744     Only returns true for type comments for now."""
2745     t = leaf.type
2746     v = leaf.value
2747     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith(
2748         "# type:" + suffix
2749     )
2750
2751
2752 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2753     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2754     else.
2755
2756     Note: don't use backslashes for formatting or you'll lose your voting rights.
2757     """
2758     if not inside_brackets:
2759         spl = leaf.prefix.split("#")
2760         if "\\" not in spl[0]:
2761             nl_count = spl[-1].count("\n")
2762             if len(spl) > 1:
2763                 nl_count -= 1
2764             leaf.prefix = "\n" * nl_count
2765             return
2766
2767     leaf.prefix = ""
2768
2769
2770 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2771     """Make all string prefixes lowercase.
2772
2773     If remove_u_prefix is given, also removes any u prefix from the string.
2774
2775     Note: Mutates its argument.
2776     """
2777     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2778     assert match is not None, f"failed to match string {leaf.value!r}"
2779     orig_prefix = match.group(1)
2780     new_prefix = orig_prefix.lower()
2781     if remove_u_prefix:
2782         new_prefix = new_prefix.replace("u", "")
2783     leaf.value = f"{new_prefix}{match.group(2)}"
2784
2785
2786 def normalize_string_quotes(leaf: Leaf) -> None:
2787     """Prefer double quotes but only if it doesn't cause more escaping.
2788
2789     Adds or removes backslashes as appropriate. Doesn't parse and fix
2790     strings nested in f-strings (yet).
2791
2792     Note: Mutates its argument.
2793     """
2794     value = leaf.value.lstrip("furbFURB")
2795     if value[:3] == '"""':
2796         return
2797
2798     elif value[:3] == "'''":
2799         orig_quote = "'''"
2800         new_quote = '"""'
2801     elif value[0] == '"':
2802         orig_quote = '"'
2803         new_quote = "'"
2804     else:
2805         orig_quote = "'"
2806         new_quote = '"'
2807     first_quote_pos = leaf.value.find(orig_quote)
2808     if first_quote_pos == -1:
2809         return  # There's an internal error
2810
2811     prefix = leaf.value[:first_quote_pos]
2812     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2813     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2814     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2815     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2816     if "r" in prefix.casefold():
2817         if unescaped_new_quote.search(body):
2818             # There's at least one unescaped new_quote in this raw string
2819             # so converting is impossible
2820             return
2821
2822         # Do not introduce or remove backslashes in raw strings
2823         new_body = body
2824     else:
2825         # remove unnecessary escapes
2826         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2827         if body != new_body:
2828             # Consider the string without unnecessary escapes as the original
2829             body = new_body
2830             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2831         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2832         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2833     if "f" in prefix.casefold():
2834         matches = re.findall(
2835             r"""
2836             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2837                 ([^{].*?)  # contents of the brackets except if begins with {{
2838             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2839             """,
2840             new_body,
2841             re.VERBOSE,
2842         )
2843         for m in matches:
2844             if "\\" in str(m):
2845                 # Do not introduce backslashes in interpolated expressions
2846                 return
2847     if new_quote == '"""' and new_body[-1:] == '"':
2848         # edge case:
2849         new_body = new_body[:-1] + '\\"'
2850     orig_escape_count = body.count("\\")
2851     new_escape_count = new_body.count("\\")
2852     if new_escape_count > orig_escape_count:
2853         return  # Do not introduce more escaping
2854
2855     if new_escape_count == orig_escape_count and orig_quote == '"':
2856         return  # Prefer double quotes
2857
2858     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2859
2860
2861 def normalize_numeric_literal(leaf: Leaf) -> None:
2862     """Normalizes numeric (float, int, and complex) literals.
2863
2864     All letters used in the representation are normalized to lowercase (except
2865     in Python 2 long literals).
2866     """
2867     text = leaf.value.lower()
2868     if text.startswith(("0o", "0b")):
2869         # Leave octal and binary literals alone.
2870         pass
2871     elif text.startswith("0x"):
2872         # Change hex literals to upper case.
2873         before, after = text[:2], text[2:]
2874         text = f"{before}{after.upper()}"
2875     elif "e" in text:
2876         before, after = text.split("e")
2877         sign = ""
2878         if after.startswith("-"):
2879             after = after[1:]
2880             sign = "-"
2881         elif after.startswith("+"):
2882             after = after[1:]
2883         before = format_float_or_int_string(before)
2884         text = f"{before}e{sign}{after}"
2885     elif text.endswith(("j", "l")):
2886         number = text[:-1]
2887         suffix = text[-1]
2888         # Capitalize in "2L" because "l" looks too similar to "1".
2889         if suffix == "l":
2890             suffix = "L"
2891         text = f"{format_float_or_int_string(number)}{suffix}"
2892     else:
2893         text = format_float_or_int_string(text)
2894     leaf.value = text
2895
2896
2897 def format_float_or_int_string(text: str) -> str:
2898     """Formats a float string like "1.0"."""
2899     if "." not in text:
2900         return text
2901
2902     before, after = text.split(".")
2903     return f"{before or 0}.{after or 0}"
2904
2905
2906 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2907     """Make existing optional parentheses invisible or create new ones.
2908
2909     `parens_after` is a set of string leaf values immediately after which parens
2910     should be put.
2911
2912     Standardizes on visible parentheses for single-element tuples, and keeps
2913     existing visible parentheses for other tuples and generator expressions.
2914     """
2915     for pc in list_comments(node.prefix, is_endmarker=False):
2916         if pc.value in FMT_OFF:
2917             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2918             return
2919
2920     check_lpar = False
2921     for index, child in enumerate(list(node.children)):
2922         # Add parentheses around long tuple unpacking in assignments.
2923         if (
2924             index == 0
2925             and isinstance(child, Node)
2926             and child.type == syms.testlist_star_expr
2927         ):
2928             check_lpar = True
2929
2930         if check_lpar:
2931             if is_walrus_assignment(child):
2932                 continue
2933             if child.type == syms.atom:
2934                 # Determines if the underlying atom should be surrounded with
2935                 # invisible params - also makes parens invisible recursively
2936                 # within the atom and removes repeated invisible parens within
2937                 # the atom
2938                 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2939                     child, parent=node
2940                 )
2941
2942                 if should_surround_with_parens:
2943                     lpar = Leaf(token.LPAR, "")
2944                     rpar = Leaf(token.RPAR, "")
2945                     index = child.remove() or 0
2946                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2947             elif is_one_tuple(child):
2948                 # wrap child in visible parentheses
2949                 lpar = Leaf(token.LPAR, "(")
2950                 rpar = Leaf(token.RPAR, ")")
2951                 child.remove()
2952                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2953             elif node.type == syms.import_from:
2954                 # "import from" nodes store parentheses directly as part of
2955                 # the statement
2956                 if child.type == token.LPAR:
2957                     # make parentheses invisible
2958                     child.value = ""  # type: ignore
2959                     node.children[-1].value = ""  # type: ignore
2960                 elif child.type != token.STAR:
2961                     # insert invisible parentheses
2962                     node.insert_child(index, Leaf(token.LPAR, ""))
2963                     node.append_child(Leaf(token.RPAR, ""))
2964                 break
2965
2966             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2967                 # wrap child in invisible parentheses
2968                 lpar = Leaf(token.LPAR, "")
2969                 rpar = Leaf(token.RPAR, "")
2970                 index = child.remove() or 0
2971                 prefix = child.prefix
2972                 child.prefix = ""
2973                 new_child = Node(syms.atom, [lpar, child, rpar])
2974                 new_child.prefix = prefix
2975                 node.insert_child(index, new_child)
2976
2977         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2978
2979
2980 def normalize_fmt_off(node: Node) -> None:
2981     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2982     try_again = True
2983     while try_again:
2984         try_again = convert_one_fmt_off_pair(node)
2985
2986
2987 def convert_one_fmt_off_pair(node: Node) -> bool:
2988     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2989
2990     Returns True if a pair was converted.
2991     """
2992     for leaf in node.leaves():
2993         previous_consumed = 0
2994         for comment in list_comments(leaf.prefix, is_endmarker=False):
2995             if comment.value in FMT_OFF:
2996                 # We only want standalone comments. If there's no previous leaf or
2997                 # the previous leaf is indentation, it's a standalone comment in
2998                 # disguise.
2999                 if comment.type != STANDALONE_COMMENT:
3000                     prev = preceding_leaf(leaf)
3001                     if prev and prev.type not in WHITESPACE:
3002                         continue
3003
3004                 ignored_nodes = list(generate_ignored_nodes(leaf))
3005                 if not ignored_nodes:
3006                     continue
3007
3008                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
3009                 parent = first.parent
3010                 prefix = first.prefix
3011                 first.prefix = prefix[comment.consumed :]
3012                 hidden_value = (
3013                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3014                 )
3015                 if hidden_value.endswith("\n"):
3016                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
3017                     # leaf (possibly followed by a DEDENT).
3018                     hidden_value = hidden_value[:-1]
3019                 first_idx = None
3020                 for ignored in ignored_nodes:
3021                     index = ignored.remove()
3022                     if first_idx is None:
3023                         first_idx = index
3024                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3025                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3026                 parent.insert_child(
3027                     first_idx,
3028                     Leaf(
3029                         STANDALONE_COMMENT,
3030                         hidden_value,
3031                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3032                     ),
3033                 )
3034                 return True
3035
3036             previous_consumed = comment.consumed
3037
3038     return False
3039
3040
3041 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3042     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3043
3044     Stops at the end of the block.
3045     """
3046     container: Optional[LN] = container_of(leaf)
3047     while container is not None and container.type != token.ENDMARKER:
3048         for comment in list_comments(container.prefix, is_endmarker=False):
3049             if comment.value in FMT_ON:
3050                 return
3051
3052         yield container
3053
3054         container = container.next_sibling
3055
3056
3057 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3058     """If it's safe, make the parens in the atom `node` invisible, recursively.
3059     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3060     as they are redundant.
3061
3062     Returns whether the node should itself be wrapped in invisible parentheses.
3063
3064     """
3065     if (
3066         node.type != syms.atom
3067         or is_empty_tuple(node)
3068         or is_one_tuple(node)
3069         or (is_yield(node) and parent.type != syms.expr_stmt)
3070         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3071     ):
3072         return False
3073
3074     first = node.children[0]
3075     last = node.children[-1]
3076     if first.type == token.LPAR and last.type == token.RPAR:
3077         middle = node.children[1]
3078         # make parentheses invisible
3079         first.value = ""  # type: ignore
3080         last.value = ""  # type: ignore
3081         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3082
3083         if is_atom_with_invisible_parens(middle):
3084             # Strip the invisible parens from `middle` by replacing
3085             # it with the child in-between the invisible parens
3086             middle.replace(middle.children[1])
3087
3088         return False
3089
3090     return True
3091
3092
3093 def is_atom_with_invisible_parens(node: LN) -> bool:
3094     """Given a `LN`, determines whether it's an atom `node` with invisible
3095     parens. Useful in dedupe-ing and normalizing parens.
3096     """
3097     if isinstance(node, Leaf) or node.type != syms.atom:
3098         return False
3099
3100     first, last = node.children[0], node.children[-1]
3101     return (
3102         isinstance(first, Leaf)
3103         and first.type == token.LPAR
3104         and first.value == ""
3105         and isinstance(last, Leaf)
3106         and last.type == token.RPAR
3107         and last.value == ""
3108     )
3109
3110
3111 def is_empty_tuple(node: LN) -> bool:
3112     """Return True if `node` holds an empty tuple."""
3113     return (
3114         node.type == syms.atom
3115         and len(node.children) == 2
3116         and node.children[0].type == token.LPAR
3117         and node.children[1].type == token.RPAR
3118     )
3119
3120
3121 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3122     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3123
3124     Parenthesis can be optional. Returns None otherwise"""
3125     if len(node.children) != 3:
3126         return None
3127     lpar, wrapped, rpar = node.children
3128     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3129         return None
3130
3131     return wrapped
3132
3133
3134 def is_one_tuple(node: LN) -> bool:
3135     """Return True if `node` holds a tuple with one element, with or without parens."""
3136     if node.type == syms.atom:
3137         gexp = unwrap_singleton_parenthesis(node)
3138         if gexp is None or gexp.type != syms.testlist_gexp:
3139             return False
3140
3141         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3142
3143     return (
3144         node.type in IMPLICIT_TUPLE
3145         and len(node.children) == 2
3146         and node.children[1].type == token.COMMA
3147     )
3148
3149
3150 def is_walrus_assignment(node: LN) -> bool:
3151     """Return True iff `node` is of the shape ( test := test )"""
3152     inner = unwrap_singleton_parenthesis(node)
3153     return inner is not None and inner.type == syms.namedexpr_test
3154
3155
3156 def is_yield(node: LN) -> bool:
3157     """Return True if `node` holds a `yield` or `yield from` expression."""
3158     if node.type == syms.yield_expr:
3159         return True
3160
3161     if node.type == token.NAME and node.value == "yield":  # type: ignore
3162         return True
3163
3164     if node.type != syms.atom:
3165         return False
3166
3167     if len(node.children) != 3:
3168         return False
3169
3170     lpar, expr, rpar = node.children
3171     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3172         return is_yield(expr)
3173
3174     return False
3175
3176
3177 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3178     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3179
3180     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3181     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3182     extended iterable unpacking (PEP 3132) and additional unpacking
3183     generalizations (PEP 448).
3184     """
3185     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3186         return False
3187
3188     p = leaf.parent
3189     if p.type == syms.star_expr:
3190         # Star expressions are also used as assignment targets in extended
3191         # iterable unpacking (PEP 3132).  See what its parent is instead.
3192         if not p.parent:
3193             return False
3194
3195         p = p.parent
3196
3197     return p.type in within
3198
3199
3200 def is_multiline_string(leaf: Leaf) -> bool:
3201     """Return True if `leaf` is a multiline string that actually spans many lines."""
3202     value = leaf.value.lstrip("furbFURB")
3203     return value[:3] in {'"""', "'''"} and "\n" in value
3204
3205
3206 def is_stub_suite(node: Node) -> bool:
3207     """Return True if `node` is a suite with a stub body."""
3208     if (
3209         len(node.children) != 4
3210         or node.children[0].type != token.NEWLINE
3211         or node.children[1].type != token.INDENT
3212         or node.children[3].type != token.DEDENT
3213     ):
3214         return False
3215
3216     return is_stub_body(node.children[2])
3217
3218
3219 def is_stub_body(node: LN) -> bool:
3220     """Return True if `node` is a simple statement containing an ellipsis."""
3221     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3222         return False
3223
3224     if len(node.children) != 2:
3225         return False
3226
3227     child = node.children[0]
3228     return (
3229         child.type == syms.atom
3230         and len(child.children) == 3
3231         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3232     )
3233
3234
3235 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3236     """Return maximum delimiter priority inside `node`.
3237
3238     This is specific to atoms with contents contained in a pair of parentheses.
3239     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3240     """
3241     if node.type != syms.atom:
3242         return 0
3243
3244     first = node.children[0]
3245     last = node.children[-1]
3246     if not (first.type == token.LPAR and last.type == token.RPAR):
3247         return 0
3248
3249     bt = BracketTracker()
3250     for c in node.children[1:-1]:
3251         if isinstance(c, Leaf):
3252             bt.mark(c)
3253         else:
3254             for leaf in c.leaves():
3255                 bt.mark(leaf)
3256     try:
3257         return bt.max_delimiter_priority()
3258
3259     except ValueError:
3260         return 0
3261
3262
3263 def ensure_visible(leaf: Leaf) -> None:
3264     """Make sure parentheses are visible.
3265
3266     They could be invisible as part of some statements (see
3267     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3268     """
3269     if leaf.type == token.LPAR:
3270         leaf.value = "("
3271     elif leaf.type == token.RPAR:
3272         leaf.value = ")"
3273
3274
3275 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3276     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3277
3278     if not (
3279         opening_bracket.parent
3280         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3281         and opening_bracket.value in "[{("
3282     ):
3283         return False
3284
3285     try:
3286         last_leaf = line.leaves[-1]
3287         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3288         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3289     except (IndexError, ValueError):
3290         return False
3291
3292     return max_priority == COMMA_PRIORITY
3293
3294
3295 def get_features_used(node: Node) -> Set[Feature]:
3296     """Return a set of (relatively) new Python features used in this file.
3297
3298     Currently looking for:
3299     - f-strings;
3300     - underscores in numeric literals;
3301     - trailing commas after * or ** in function signatures and calls;
3302     - positional only arguments in function signatures and lambdas;
3303     """
3304     features: Set[Feature] = set()
3305     for n in node.pre_order():
3306         if n.type == token.STRING:
3307             value_head = n.value[:2]  # type: ignore
3308             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3309                 features.add(Feature.F_STRINGS)
3310
3311         elif n.type == token.NUMBER:
3312             if "_" in n.value:  # type: ignore
3313                 features.add(Feature.NUMERIC_UNDERSCORES)
3314
3315         elif n.type == token.SLASH:
3316             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3317                 features.add(Feature.POS_ONLY_ARGUMENTS)
3318
3319         elif n.type == token.COLONEQUAL:
3320             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3321
3322         elif (
3323             n.type in {syms.typedargslist, syms.arglist}
3324             and n.children
3325             and n.children[-1].type == token.COMMA
3326         ):
3327             if n.type == syms.typedargslist:
3328                 feature = Feature.TRAILING_COMMA_IN_DEF
3329             else:
3330                 feature = Feature.TRAILING_COMMA_IN_CALL
3331
3332             for ch in n.children:
3333                 if ch.type in STARS:
3334                     features.add(feature)
3335
3336                 if ch.type == syms.argument:
3337                     for argch in ch.children:
3338                         if argch.type in STARS:
3339                             features.add(feature)
3340
3341     return features
3342
3343
3344 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3345     """Detect the version to target based on the nodes used."""
3346     features = get_features_used(node)
3347     return {
3348         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3349     }
3350
3351
3352 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3353     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3354
3355     Brackets can be omitted if the entire trailer up to and including
3356     a preceding closing bracket fits in one line.
3357
3358     Yielded sets are cumulative (contain results of previous yields, too).  First
3359     set is empty.
3360     """
3361
3362     omit: Set[LeafID] = set()
3363     yield omit
3364
3365     length = 4 * line.depth
3366     opening_bracket = None
3367     closing_bracket = None
3368     inner_brackets: Set[LeafID] = set()
3369     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3370         length += leaf_length
3371         if length > line_length:
3372             break
3373
3374         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3375         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3376             break
3377
3378         if opening_bracket:
3379             if leaf is opening_bracket:
3380                 opening_bracket = None
3381             elif leaf.type in CLOSING_BRACKETS:
3382                 inner_brackets.add(id(leaf))
3383         elif leaf.type in CLOSING_BRACKETS:
3384             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3385                 # Empty brackets would fail a split so treat them as "inner"
3386                 # brackets (e.g. only add them to the `omit` set if another
3387                 # pair of brackets was good enough.
3388                 inner_brackets.add(id(leaf))
3389                 continue
3390
3391             if closing_bracket:
3392                 omit.add(id(closing_bracket))
3393                 omit.update(inner_brackets)
3394                 inner_brackets.clear()
3395                 yield omit
3396
3397             if leaf.value:
3398                 opening_bracket = leaf.opening_bracket
3399                 closing_bracket = leaf
3400
3401
3402 def get_future_imports(node: Node) -> Set[str]:
3403     """Return a set of __future__ imports in the file."""
3404     imports: Set[str] = set()
3405
3406     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3407         for child in children:
3408             if isinstance(child, Leaf):
3409                 if child.type == token.NAME:
3410                     yield child.value
3411             elif child.type == syms.import_as_name:
3412                 orig_name = child.children[0]
3413                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3414                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3415                 yield orig_name.value
3416             elif child.type == syms.import_as_names:
3417                 yield from get_imports_from_children(child.children)
3418             else:
3419                 raise AssertionError("Invalid syntax parsing imports")
3420
3421     for child in node.children:
3422         if child.type != syms.simple_stmt:
3423             break
3424         first_child = child.children[0]
3425         if isinstance(first_child, Leaf):
3426             # Continue looking if we see a docstring; otherwise stop.
3427             if (
3428                 len(child.children) == 2
3429                 and first_child.type == token.STRING
3430                 and child.children[1].type == token.NEWLINE
3431             ):
3432                 continue
3433             else:
3434                 break
3435         elif first_child.type == syms.import_from:
3436             module_name = first_child.children[1]
3437             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3438                 break
3439             imports |= set(get_imports_from_children(first_child.children[3:]))
3440         else:
3441             break
3442     return imports
3443
3444
3445 def gen_python_files_in_dir(
3446     path: Path,
3447     root: Path,
3448     include: Pattern[str],
3449     exclude: Pattern[str],
3450     report: "Report",
3451 ) -> Iterator[Path]:
3452     """Generate all files under `path` whose paths are not excluded by the
3453     `exclude` regex, but are included by the `include` regex.
3454
3455     Symbolic links pointing outside of the `root` directory are ignored.
3456
3457     `report` is where output about exclusions goes.
3458     """
3459     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3460     for child in path.iterdir():
3461         try:
3462             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3463         except ValueError:
3464             if child.is_symlink():
3465                 report.path_ignored(
3466                     child, f"is a symbolic link that points outside {root}"
3467                 )
3468                 continue
3469
3470             raise
3471
3472         if child.is_dir():
3473             normalized_path += "/"
3474         exclude_match = exclude.search(normalized_path)
3475         if exclude_match and exclude_match.group(0):
3476             report.path_ignored(child, f"matches the --exclude regular expression")
3477             continue
3478
3479         if child.is_dir():
3480             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3481
3482         elif child.is_file():
3483             include_match = include.search(normalized_path)
3484             if include_match:
3485                 yield child
3486
3487
3488 @lru_cache()
3489 def find_project_root(srcs: Iterable[str]) -> Path:
3490     """Return a directory containing .git, .hg, or pyproject.toml.
3491
3492     That directory can be one of the directories passed in `srcs` or their
3493     common parent.
3494
3495     If no directory in the tree contains a marker that would specify it's the
3496     project root, the root of the file system is returned.
3497     """
3498     if not srcs:
3499         return Path("/").resolve()
3500
3501     common_base = min(Path(src).resolve() for src in srcs)
3502     if common_base.is_dir():
3503         # Append a fake file so `parents` below returns `common_base_dir`, too.
3504         common_base /= "fake-file"
3505     for directory in common_base.parents:
3506         if (directory / ".git").is_dir():
3507             return directory
3508
3509         if (directory / ".hg").is_dir():
3510             return directory
3511
3512         if (directory / "pyproject.toml").is_file():
3513             return directory
3514
3515     return directory
3516
3517
3518 @dataclass
3519 class Report:
3520     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3521
3522     check: bool = False
3523     quiet: bool = False
3524     verbose: bool = False
3525     change_count: int = 0
3526     same_count: int = 0
3527     failure_count: int = 0
3528
3529     def done(self, src: Path, changed: Changed) -> None:
3530         """Increment the counter for successful reformatting. Write out a message."""
3531         if changed is Changed.YES:
3532             reformatted = "would reformat" if self.check else "reformatted"
3533             if self.verbose or not self.quiet:
3534                 out(f"{reformatted} {src}")
3535             self.change_count += 1
3536         else:
3537             if self.verbose:
3538                 if changed is Changed.NO:
3539                     msg = f"{src} already well formatted, good job."
3540                 else:
3541                     msg = f"{src} wasn't modified on disk since last run."
3542                 out(msg, bold=False)
3543             self.same_count += 1
3544
3545     def failed(self, src: Path, message: str) -> None:
3546         """Increment the counter for failed reformatting. Write out a message."""
3547         err(f"error: cannot format {src}: {message}")
3548         self.failure_count += 1
3549
3550     def path_ignored(self, path: Path, message: str) -> None:
3551         if self.verbose:
3552             out(f"{path} ignored: {message}", bold=False)
3553
3554     @property
3555     def return_code(self) -> int:
3556         """Return the exit code that the app should use.
3557
3558         This considers the current state of changed files and failures:
3559         - if there were any failures, return 123;
3560         - if any files were changed and --check is being used, return 1;
3561         - otherwise return 0.
3562         """
3563         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3564         # 126 we have special return codes reserved by the shell.
3565         if self.failure_count:
3566             return 123
3567
3568         elif self.change_count and self.check:
3569             return 1
3570
3571         return 0
3572
3573     def __str__(self) -> str:
3574         """Render a color report of the current state.
3575
3576         Use `click.unstyle` to remove colors.
3577         """
3578         if self.check:
3579             reformatted = "would be reformatted"
3580             unchanged = "would be left unchanged"
3581             failed = "would fail to reformat"
3582         else:
3583             reformatted = "reformatted"
3584             unchanged = "left unchanged"
3585             failed = "failed to reformat"
3586         report = []
3587         if self.change_count:
3588             s = "s" if self.change_count > 1 else ""
3589             report.append(
3590                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3591             )
3592         if self.same_count:
3593             s = "s" if self.same_count > 1 else ""
3594             report.append(f"{self.same_count} file{s} {unchanged}")
3595         if self.failure_count:
3596             s = "s" if self.failure_count > 1 else ""
3597             report.append(
3598                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3599             )
3600         return ", ".join(report) + "."
3601
3602
3603 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3604     filename = "<unknown>"
3605     if sys.version_info >= (3, 8):
3606         # TODO: support Python 4+ ;)
3607         for minor_version in range(sys.version_info[1], 4, -1):
3608             try:
3609                 return ast.parse(src, filename, feature_version=(3, minor_version))
3610             except SyntaxError:
3611                 continue
3612     else:
3613         for feature_version in (7, 6):
3614             try:
3615                 return ast3.parse(src, filename, feature_version=feature_version)
3616             except SyntaxError:
3617                 continue
3618
3619     return ast27.parse(src)
3620
3621
3622 def _fixup_ast_constants(
3623     node: Union[ast.AST, ast3.AST, ast27.AST]
3624 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3625     """Map ast nodes deprecated in 3.8 to Constant."""
3626     # casts are required until this is released:
3627     # https://github.com/python/typeshed/pull/3142
3628     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3629         return cast(ast.AST, ast.Constant(value=node.s))
3630     elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3631         return cast(ast.AST, ast.Constant(value=node.n))
3632     elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3633         return cast(ast.AST, ast.Constant(value=node.value))
3634     return node
3635
3636
3637 def assert_equivalent(src: str, dst: str) -> None:
3638     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3639
3640     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3641         """Simple visitor generating strings to compare ASTs by content."""
3642
3643         node = _fixup_ast_constants(node)
3644
3645         yield f"{'  ' * depth}{node.__class__.__name__}("
3646
3647         for field in sorted(node._fields):
3648             # TypeIgnore has only one field 'lineno' which breaks this comparison
3649             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3650             if sys.version_info >= (3, 8):
3651                 type_ignore_classes += (ast.TypeIgnore,)
3652             if isinstance(node, type_ignore_classes):
3653                 break
3654
3655             try:
3656                 value = getattr(node, field)
3657             except AttributeError:
3658                 continue
3659
3660             yield f"{'  ' * (depth+1)}{field}="
3661
3662             if isinstance(value, list):
3663                 for item in value:
3664                     # Ignore nested tuples within del statements, because we may insert
3665                     # parentheses and they change the AST.
3666                     if (
3667                         field == "targets"
3668                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3669                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3670                     ):
3671                         for item in item.elts:
3672                             yield from _v(item, depth + 2)
3673                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3674                         yield from _v(item, depth + 2)
3675
3676             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3677                 yield from _v(value, depth + 2)
3678
3679             else:
3680                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3681
3682         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3683
3684     try:
3685         src_ast = parse_ast(src)
3686     except Exception as exc:
3687         raise AssertionError(
3688             f"cannot use --safe with this file; failed to parse source file.  "
3689             f"AST error message: {exc}"
3690         )
3691
3692     try:
3693         dst_ast = parse_ast(dst)
3694     except Exception as exc:
3695         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3696         raise AssertionError(
3697             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3698             f"Please report a bug on https://github.com/psf/black/issues.  "
3699             f"This invalid output might be helpful: {log}"
3700         ) from None
3701
3702     src_ast_str = "\n".join(_v(src_ast))
3703     dst_ast_str = "\n".join(_v(dst_ast))
3704     if src_ast_str != dst_ast_str:
3705         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3706         raise AssertionError(
3707             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3708             f"the source.  "
3709             f"Please report a bug on https://github.com/psf/black/issues.  "
3710             f"This diff might be helpful: {log}"
3711         ) from None
3712
3713
3714 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3715     """Raise AssertionError if `dst` reformats differently the second time."""
3716     newdst = format_str(dst, mode=mode)
3717     if dst != newdst:
3718         log = dump_to_file(
3719             diff(src, dst, "source", "first pass"),
3720             diff(dst, newdst, "first pass", "second pass"),
3721         )
3722         raise AssertionError(
3723             f"INTERNAL ERROR: Black produced different code on the second pass "
3724             f"of the formatter.  "
3725             f"Please report a bug on https://github.com/psf/black/issues.  "
3726             f"This diff might be helpful: {log}"
3727         ) from None
3728
3729
3730 def dump_to_file(*output: str) -> str:
3731     """Dump `output` to a temporary file. Return path to the file."""
3732     with tempfile.NamedTemporaryFile(
3733         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3734     ) as f:
3735         for lines in output:
3736             f.write(lines)
3737             if lines and lines[-1] != "\n":
3738                 f.write("\n")
3739     return f.name
3740
3741
3742 @contextmanager
3743 def nullcontext() -> Iterator[None]:
3744     """Return context manager that does nothing.
3745     Similar to `nullcontext` from python 3.7"""
3746     yield
3747
3748
3749 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3750     """Return a unified diff string between strings `a` and `b`."""
3751     import difflib
3752
3753     a_lines = [line + "\n" for line in a.split("\n")]
3754     b_lines = [line + "\n" for line in b.split("\n")]
3755     return "".join(
3756         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3757     )
3758
3759
3760 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3761     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3762     err("Aborted!")
3763     for task in tasks:
3764         task.cancel()
3765
3766
3767 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3768     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3769     try:
3770         if sys.version_info[:2] >= (3, 7):
3771             all_tasks = asyncio.all_tasks
3772         else:
3773             all_tasks = asyncio.Task.all_tasks
3774         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3775         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3776         if not to_cancel:
3777             return
3778
3779         for task in to_cancel:
3780             task.cancel()
3781         loop.run_until_complete(
3782             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3783         )
3784     finally:
3785         # `concurrent.futures.Future` objects cannot be cancelled once they
3786         # are already running. There might be some when the `shutdown()` happened.
3787         # Silence their logger's spew about the event loop being closed.
3788         cf_logger = logging.getLogger("concurrent.futures")
3789         cf_logger.setLevel(logging.CRITICAL)
3790         loop.close()
3791
3792
3793 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3794     """Replace `regex` with `replacement` twice on `original`.
3795
3796     This is used by string normalization to perform replaces on
3797     overlapping matches.
3798     """
3799     return regex.sub(replacement, regex.sub(replacement, original))
3800
3801
3802 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3803     """Compile a regular expression string in `regex`.
3804
3805     If it contains newlines, use verbose mode.
3806     """
3807     if "\n" in regex:
3808         regex = "(?x)" + regex
3809     compiled: Pattern[str] = re.compile(regex)
3810     return compiled
3811
3812
3813 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3814     """Like `reversed(enumerate(sequence))` if that were possible."""
3815     index = len(sequence) - 1
3816     for element in reversed(sequence):
3817         yield (index, element)
3818         index -= 1
3819
3820
3821 def enumerate_with_length(
3822     line: Line, reversed: bool = False
3823 ) -> Iterator[Tuple[Index, Leaf, int]]:
3824     """Return an enumeration of leaves with their length.
3825
3826     Stops prematurely on multiline strings and standalone comments.
3827     """
3828     op = cast(
3829         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3830         enumerate_reversed if reversed else enumerate,
3831     )
3832     for index, leaf in op(line.leaves):
3833         length = len(leaf.prefix) + len(leaf.value)
3834         if "\n" in leaf.value:
3835             return  # Multiline strings, we can't continue.
3836
3837         for comment in line.comments_after(leaf):
3838             length += len(comment.value)
3839
3840         yield index, leaf, length
3841
3842
3843 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3844     """Return True if `line` is no longer than `line_length`.
3845
3846     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3847     """
3848     if not line_str:
3849         line_str = str(line).strip("\n")
3850     return (
3851         len(line_str) <= line_length
3852         and "\n" not in line_str  # multiline strings
3853         and not line.contains_standalone_comments()
3854     )
3855
3856
3857 def can_be_split(line: Line) -> bool:
3858     """Return False if the line cannot be split *for sure*.
3859
3860     This is not an exhaustive search but a cheap heuristic that we can use to
3861     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3862     in unnecessary parentheses).
3863     """
3864     leaves = line.leaves
3865     if len(leaves) < 2:
3866         return False
3867
3868     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3869         call_count = 0
3870         dot_count = 0
3871         next = leaves[-1]
3872         for leaf in leaves[-2::-1]:
3873             if leaf.type in OPENING_BRACKETS:
3874                 if next.type not in CLOSING_BRACKETS:
3875                     return False
3876
3877                 call_count += 1
3878             elif leaf.type == token.DOT:
3879                 dot_count += 1
3880             elif leaf.type == token.NAME:
3881                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3882                     return False
3883
3884             elif leaf.type not in CLOSING_BRACKETS:
3885                 return False
3886
3887             if dot_count > 1 and call_count > 1:
3888                 return False
3889
3890     return True
3891
3892
3893 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3894     """Does `line` have a shape safe to reformat without optional parens around it?
3895
3896     Returns True for only a subset of potentially nice looking formattings but
3897     the point is to not return false positives that end up producing lines that
3898     are too long.
3899     """
3900     bt = line.bracket_tracker
3901     if not bt.delimiters:
3902         # Without delimiters the optional parentheses are useless.
3903         return True
3904
3905     max_priority = bt.max_delimiter_priority()
3906     if bt.delimiter_count_with_priority(max_priority) > 1:
3907         # With more than one delimiter of a kind the optional parentheses read better.
3908         return False
3909
3910     if max_priority == DOT_PRIORITY:
3911         # A single stranded method call doesn't require optional parentheses.
3912         return True
3913
3914     assert len(line.leaves) >= 2, "Stranded delimiter"
3915
3916     first = line.leaves[0]
3917     second = line.leaves[1]
3918     penultimate = line.leaves[-2]
3919     last = line.leaves[-1]
3920
3921     # With a single delimiter, omit if the expression starts or ends with
3922     # a bracket.
3923     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3924         remainder = False
3925         length = 4 * line.depth
3926         for _index, leaf, leaf_length in enumerate_with_length(line):
3927             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3928                 remainder = True
3929             if remainder:
3930                 length += leaf_length
3931                 if length > line_length:
3932                     break
3933
3934                 if leaf.type in OPENING_BRACKETS:
3935                     # There are brackets we can further split on.
3936                     remainder = False
3937
3938         else:
3939             # checked the entire string and line length wasn't exceeded
3940             if len(line.leaves) == _index + 1:
3941                 return True
3942
3943         # Note: we are not returning False here because a line might have *both*
3944         # a leading opening bracket and a trailing closing bracket.  If the
3945         # opening bracket doesn't match our rule, maybe the closing will.
3946
3947     if (
3948         last.type == token.RPAR
3949         or last.type == token.RBRACE
3950         or (
3951             # don't use indexing for omitting optional parentheses;
3952             # it looks weird
3953             last.type == token.RSQB
3954             and last.parent
3955             and last.parent.type != syms.trailer
3956         )
3957     ):
3958         if penultimate.type in OPENING_BRACKETS:
3959             # Empty brackets don't help.
3960             return False
3961
3962         if is_multiline_string(first):
3963             # Additional wrapping of a multiline string in this situation is
3964             # unnecessary.
3965             return True
3966
3967         length = 4 * line.depth
3968         seen_other_brackets = False
3969         for _index, leaf, leaf_length in enumerate_with_length(line):
3970             length += leaf_length
3971             if leaf is last.opening_bracket:
3972                 if seen_other_brackets or length <= line_length:
3973                     return True
3974
3975             elif leaf.type in OPENING_BRACKETS:
3976                 # There are brackets we can further split on.
3977                 seen_other_brackets = True
3978
3979     return False
3980
3981
3982 def get_cache_file(mode: FileMode) -> Path:
3983     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3984
3985
3986 def read_cache(mode: FileMode) -> Cache:
3987     """Read the cache if it exists and is well formed.
3988
3989     If it is not well formed, the call to write_cache later should resolve the issue.
3990     """
3991     cache_file = get_cache_file(mode)
3992     if not cache_file.exists():
3993         return {}
3994
3995     with cache_file.open("rb") as fobj:
3996         try:
3997             cache: Cache = pickle.load(fobj)
3998         except pickle.UnpicklingError:
3999             return {}
4000
4001     return cache
4002
4003
4004 def get_cache_info(path: Path) -> CacheInfo:
4005     """Return the information used to check if a file is already formatted or not."""
4006     stat = path.stat()
4007     return stat.st_mtime, stat.st_size
4008
4009
4010 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4011     """Split an iterable of paths in `sources` into two sets.
4012
4013     The first contains paths of files that modified on disk or are not in the
4014     cache. The other contains paths to non-modified files.
4015     """
4016     todo, done = set(), set()
4017     for src in sources:
4018         src = src.resolve()
4019         if cache.get(src) != get_cache_info(src):
4020             todo.add(src)
4021         else:
4022             done.add(src)
4023     return todo, done
4024
4025
4026 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4027     """Update the cache file."""
4028     cache_file = get_cache_file(mode)
4029     try:
4030         CACHE_DIR.mkdir(parents=True, exist_ok=True)
4031         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4032         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4033             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
4034         os.replace(f.name, cache_file)
4035     except OSError:
4036         pass
4037
4038
4039 def patch_click() -> None:
4040     """Make Click not crash.
4041
4042     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4043     default which restricts paths that it can access during the lifetime of the
4044     application.  Click refuses to work in this scenario by raising a RuntimeError.
4045
4046     In case of Black the likelihood that non-ASCII characters are going to be used in
4047     file paths is minimal since it's Python source code.  Moreover, this crash was
4048     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4049     """
4050     try:
4051         from click import core
4052         from click import _unicodefun  # type: ignore
4053     except ModuleNotFoundError:
4054         return
4055
4056     for module in (core, _unicodefun):
4057         if hasattr(module, "_verify_python3_env"):
4058             module._verify_python3_env = lambda: None
4059
4060
4061 def patched_main() -> None:
4062     freeze_support()
4063     patch_click()
4064     main()
4065
4066
4067 if __name__ == "__main__":
4068     patched_main()