]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

40b4d2b12876013b0925c9a1c706110befeb9b28
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import regex as re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40
41 from appdirs import user_cache_dir
42 from attr import dataclass, evolve, Factory
43 import click
44 import toml
45 from typed_ast import ast3, ast27
46
47 # lib2to3 fork
48 from blib2to3.pytree import Node, Leaf, type_repr
49 from blib2to3 import pygram, pytree
50 from blib2to3.pgen2 import driver, token
51 from blib2to3.pgen2.grammar import Grammar
52 from blib2to3.pgen2.parse import ParseError
53
54 from _version import version as __version__
55
56 DEFAULT_LINE_LENGTH = 88
57 DEFAULT_EXCLUDES = (
58     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
59 )
60 DEFAULT_INCLUDES = r"\.pyi?$"
61 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
62
63
64 # types
65 FileContent = str
66 Encoding = str
67 NewLine = str
68 Depth = int
69 NodeType = int
70 LeafID = int
71 Priority = int
72 Index = int
73 LN = Union[Leaf, Node]
74 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
75 Timestamp = float
76 FileSize = int
77 CacheInfo = Tuple[Timestamp, FileSize]
78 Cache = Dict[Path, CacheInfo]
79 out = partial(click.secho, bold=True, err=True)
80 err = partial(click.secho, fg="red", err=True)
81
82 pygram.initialize(CACHE_DIR)
83 syms = pygram.python_symbols
84
85
86 class NothingChanged(UserWarning):
87     """Raised when reformatted code is the same as source."""
88
89
90 class CannotSplit(Exception):
91     """A readable split that fits the allotted line length is impossible."""
92
93
94 class InvalidInput(ValueError):
95     """Raised when input source code fails all parse attempts."""
96
97
98 class WriteBack(Enum):
99     NO = 0
100     YES = 1
101     DIFF = 2
102     CHECK = 3
103
104     @classmethod
105     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
106         if check and not diff:
107             return cls.CHECK
108
109         return cls.DIFF if diff else cls.YES
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 class TargetVersion(Enum):
119     PY27 = 2
120     PY33 = 3
121     PY34 = 4
122     PY35 = 5
123     PY36 = 6
124     PY37 = 7
125     PY38 = 8
126
127     def is_python2(self) -> bool:
128         return self is TargetVersion.PY27
129
130
131 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
132
133
134 class Feature(Enum):
135     # All string literals are unicode
136     UNICODE_LITERALS = 1
137     F_STRINGS = 2
138     NUMERIC_UNDERSCORES = 3
139     TRAILING_COMMA_IN_CALL = 4
140     TRAILING_COMMA_IN_DEF = 5
141     # The following two feature-flags are mutually exclusive, and exactly one should be
142     # set for every version of python.
143     ASYNC_IDENTIFIERS = 6
144     ASYNC_KEYWORDS = 7
145     ASSIGNMENT_EXPRESSIONS = 8
146     POS_ONLY_ARGUMENTS = 9
147
148
149 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
150     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
151     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
152     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
153     TargetVersion.PY35: {
154         Feature.UNICODE_LITERALS,
155         Feature.TRAILING_COMMA_IN_CALL,
156         Feature.ASYNC_IDENTIFIERS,
157     },
158     TargetVersion.PY36: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA_IN_CALL,
163         Feature.TRAILING_COMMA_IN_DEF,
164         Feature.ASYNC_IDENTIFIERS,
165     },
166     TargetVersion.PY37: {
167         Feature.UNICODE_LITERALS,
168         Feature.F_STRINGS,
169         Feature.NUMERIC_UNDERSCORES,
170         Feature.TRAILING_COMMA_IN_CALL,
171         Feature.TRAILING_COMMA_IN_DEF,
172         Feature.ASYNC_KEYWORDS,
173     },
174     TargetVersion.PY38: {
175         Feature.UNICODE_LITERALS,
176         Feature.F_STRINGS,
177         Feature.NUMERIC_UNDERSCORES,
178         Feature.TRAILING_COMMA_IN_CALL,
179         Feature.TRAILING_COMMA_IN_DEF,
180         Feature.ASYNC_KEYWORDS,
181         Feature.ASSIGNMENT_EXPRESSIONS,
182         Feature.POS_ONLY_ARGUMENTS,
183     },
184 }
185
186
187 @dataclass
188 class FileMode:
189     target_versions: Set[TargetVersion] = Factory(set)
190     line_length: int = DEFAULT_LINE_LENGTH
191     string_normalization: bool = True
192     is_pyi: bool = False
193
194     def get_cache_key(self) -> str:
195         if self.target_versions:
196             version_str = ",".join(
197                 str(version.value)
198                 for version in sorted(self.target_versions, key=lambda v: v.value)
199             )
200         else:
201             version_str = "-"
202         parts = [
203             version_str,
204             str(self.line_length),
205             str(int(self.string_normalization)),
206             str(int(self.is_pyi)),
207         ]
208         return ".".join(parts)
209
210
211 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
212     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
213
214
215 def read_pyproject_toml(
216     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
217 ) -> Optional[str]:
218     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
219
220     Returns the path to a successfully found and read configuration file, None
221     otherwise.
222     """
223     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
224     if not value:
225         root = find_project_root(ctx.params.get("src", ()))
226         path = root / "pyproject.toml"
227         if path.is_file():
228             value = str(path)
229         else:
230             return None
231
232     try:
233         pyproject_toml = toml.load(value)
234         config = pyproject_toml.get("tool", {}).get("black", {})
235     except (toml.TomlDecodeError, OSError) as e:
236         raise click.FileError(
237             filename=value, hint=f"Error reading configuration file: {e}"
238         )
239
240     if not config:
241         return None
242
243     if ctx.default_map is None:
244         ctx.default_map = {}
245     ctx.default_map.update(  # type: ignore  # bad types in .pyi
246         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
247     )
248     return value
249
250
251 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
252 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
253 @click.option(
254     "-l",
255     "--line-length",
256     type=int,
257     default=DEFAULT_LINE_LENGTH,
258     help="How many characters per line to allow.",
259     show_default=True,
260 )
261 @click.option(
262     "-t",
263     "--target-version",
264     type=click.Choice([v.name.lower() for v in TargetVersion]),
265     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
266     multiple=True,
267     help=(
268         "Python versions that should be supported by Black's output. [default: "
269         "per-file auto-detection]"
270     ),
271 )
272 @click.option(
273     "--py36",
274     is_flag=True,
275     help=(
276         "Allow using Python 3.6-only syntax on all input files.  This will put "
277         "trailing commas in function signatures and calls also after *args and "
278         "**kwargs. Deprecated; use --target-version instead. "
279         "[default: per-file auto-detection]"
280     ),
281 )
282 @click.option(
283     "--pyi",
284     is_flag=True,
285     help=(
286         "Format all input files like typing stubs regardless of file extension "
287         "(useful when piping source on standard input)."
288     ),
289 )
290 @click.option(
291     "-S",
292     "--skip-string-normalization",
293     is_flag=True,
294     help="Don't normalize string quotes or prefixes.",
295 )
296 @click.option(
297     "--check",
298     is_flag=True,
299     help=(
300         "Don't write the files back, just return the status.  Return code 0 "
301         "means nothing would change.  Return code 1 means some files would be "
302         "reformatted.  Return code 123 means there was an internal error."
303     ),
304 )
305 @click.option(
306     "--diff",
307     is_flag=True,
308     help="Don't write the files back, just output a diff for each file on stdout.",
309 )
310 @click.option(
311     "--fast/--safe",
312     is_flag=True,
313     help="If --fast given, skip temporary sanity checks. [default: --safe]",
314 )
315 @click.option(
316     "--include",
317     type=str,
318     default=DEFAULT_INCLUDES,
319     help=(
320         "A regular expression that matches files and directories that should be "
321         "included on recursive searches.  An empty value means all files are "
322         "included regardless of the name.  Use forward slashes for directories on "
323         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
324         "later."
325     ),
326     show_default=True,
327 )
328 @click.option(
329     "--exclude",
330     type=str,
331     default=DEFAULT_EXCLUDES,
332     help=(
333         "A regular expression that matches files and directories that should be "
334         "excluded on recursive searches.  An empty value means no paths are excluded. "
335         "Use forward slashes for directories on all platforms (Windows, too).  "
336         "Exclusions are calculated first, inclusions later."
337     ),
338     show_default=True,
339 )
340 @click.option(
341     "-q",
342     "--quiet",
343     is_flag=True,
344     help=(
345         "Don't emit non-error messages to stderr. Errors are still emitted; "
346         "silence those with 2>/dev/null."
347     ),
348 )
349 @click.option(
350     "-v",
351     "--verbose",
352     is_flag=True,
353     help=(
354         "Also emit messages to stderr about files that were not changed or were "
355         "ignored due to --exclude=."
356     ),
357 )
358 @click.version_option(version=__version__)
359 @click.argument(
360     "src",
361     nargs=-1,
362     type=click.Path(
363         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
364     ),
365     is_eager=True,
366 )
367 @click.option(
368     "--config",
369     type=click.Path(
370         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
371     ),
372     is_eager=True,
373     callback=read_pyproject_toml,
374     help="Read configuration from PATH.",
375 )
376 @click.pass_context
377 def main(
378     ctx: click.Context,
379     code: Optional[str],
380     line_length: int,
381     target_version: List[TargetVersion],
382     check: bool,
383     diff: bool,
384     fast: bool,
385     pyi: bool,
386     py36: bool,
387     skip_string_normalization: bool,
388     quiet: bool,
389     verbose: bool,
390     include: str,
391     exclude: str,
392     src: Tuple[str],
393     config: Optional[str],
394 ) -> None:
395     """The uncompromising code formatter."""
396     write_back = WriteBack.from_configuration(check=check, diff=diff)
397     if target_version:
398         if py36:
399             err(f"Cannot use both --target-version and --py36")
400             ctx.exit(2)
401         else:
402             versions = set(target_version)
403     elif py36:
404         err(
405             "--py36 is deprecated and will be removed in a future version. "
406             "Use --target-version py36 instead."
407         )
408         versions = PY36_VERSIONS
409     else:
410         # We'll autodetect later.
411         versions = set()
412     mode = FileMode(
413         target_versions=versions,
414         line_length=line_length,
415         is_pyi=pyi,
416         string_normalization=not skip_string_normalization,
417     )
418     if config and verbose:
419         out(f"Using configuration from {config}.", bold=False, fg="blue")
420     if code is not None:
421         print(format_str(code, mode=mode))
422         ctx.exit(0)
423     try:
424         include_regex = re_compile_maybe_verbose(include)
425     except re.error:
426         err(f"Invalid regular expression for include given: {include!r}")
427         ctx.exit(2)
428     try:
429         exclude_regex = re_compile_maybe_verbose(exclude)
430     except re.error:
431         err(f"Invalid regular expression for exclude given: {exclude!r}")
432         ctx.exit(2)
433     report = Report(check=check, quiet=quiet, verbose=verbose)
434     root = find_project_root(src)
435     sources: Set[Path] = set()
436     path_empty(src, quiet, verbose, ctx)
437     for s in src:
438         p = Path(s)
439         if p.is_dir():
440             sources.update(
441                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
442             )
443         elif p.is_file() or s == "-":
444             # if a file was explicitly given, we don't care about its extension
445             sources.add(p)
446         else:
447             err(f"invalid path: {s}")
448     if len(sources) == 0:
449         if verbose or not quiet:
450             out("No Python files are present to be formatted. Nothing to do 😴")
451         ctx.exit(0)
452
453     if len(sources) == 1:
454         reformat_one(
455             src=sources.pop(),
456             fast=fast,
457             write_back=write_back,
458             mode=mode,
459             report=report,
460         )
461     else:
462         reformat_many(
463             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
464         )
465
466     if verbose or not quiet:
467         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
468         click.secho(str(report), err=True)
469     ctx.exit(report.return_code)
470
471
472 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
473     """
474     Exit if there is no `src` provided for formatting
475     """
476     if not src:
477         if verbose or not quiet:
478             out("No Path provided. Nothing to do 😴")
479             ctx.exit(0)
480
481
482 def reformat_one(
483     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
484 ) -> None:
485     """Reformat a single file under `src` without spawning child processes.
486
487     `fast`, `write_back`, and `mode` options are passed to
488     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
489     """
490     try:
491         changed = Changed.NO
492         if not src.is_file() and str(src) == "-":
493             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
494                 changed = Changed.YES
495         else:
496             cache: Cache = {}
497             if write_back != WriteBack.DIFF:
498                 cache = read_cache(mode)
499                 res_src = src.resolve()
500                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
501                     changed = Changed.CACHED
502             if changed is not Changed.CACHED and format_file_in_place(
503                 src, fast=fast, write_back=write_back, mode=mode
504             ):
505                 changed = Changed.YES
506             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
507                 write_back is WriteBack.CHECK and changed is Changed.NO
508             ):
509                 write_cache(cache, [src], mode)
510         report.done(src, changed)
511     except Exception as exc:
512         report.failed(src, str(exc))
513
514
515 def reformat_many(
516     sources: Set[Path],
517     fast: bool,
518     write_back: WriteBack,
519     mode: FileMode,
520     report: "Report",
521 ) -> None:
522     """Reformat multiple files using a ProcessPoolExecutor."""
523     loop = asyncio.get_event_loop()
524     worker_count = os.cpu_count()
525     if sys.platform == "win32":
526         # Work around https://bugs.python.org/issue26903
527         worker_count = min(worker_count, 61)
528     executor = ProcessPoolExecutor(max_workers=worker_count)
529     try:
530         loop.run_until_complete(
531             schedule_formatting(
532                 sources=sources,
533                 fast=fast,
534                 write_back=write_back,
535                 mode=mode,
536                 report=report,
537                 loop=loop,
538                 executor=executor,
539             )
540         )
541     finally:
542         shutdown(loop)
543         executor.shutdown()
544
545
546 async def schedule_formatting(
547     sources: Set[Path],
548     fast: bool,
549     write_back: WriteBack,
550     mode: FileMode,
551     report: "Report",
552     loop: asyncio.AbstractEventLoop,
553     executor: Executor,
554 ) -> None:
555     """Run formatting of `sources` in parallel using the provided `executor`.
556
557     (Use ProcessPoolExecutors for actual parallelism.)
558
559     `write_back`, `fast`, and `mode` options are passed to
560     :func:`format_file_in_place`.
561     """
562     cache: Cache = {}
563     if write_back != WriteBack.DIFF:
564         cache = read_cache(mode)
565         sources, cached = filter_cached(cache, sources)
566         for src in sorted(cached):
567             report.done(src, Changed.CACHED)
568     if not sources:
569         return
570
571     cancelled = []
572     sources_to_cache = []
573     lock = None
574     if write_back == WriteBack.DIFF:
575         # For diff output, we need locks to ensure we don't interleave output
576         # from different processes.
577         manager = Manager()
578         lock = manager.Lock()
579     tasks = {
580         asyncio.ensure_future(
581             loop.run_in_executor(
582                 executor, format_file_in_place, src, fast, mode, write_back, lock
583             )
584         ): src
585         for src in sorted(sources)
586     }
587     pending: Iterable[asyncio.Future] = tasks.keys()
588     try:
589         loop.add_signal_handler(signal.SIGINT, cancel, pending)
590         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
591     except NotImplementedError:
592         # There are no good alternatives for these on Windows.
593         pass
594     while pending:
595         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
596         for task in done:
597             src = tasks.pop(task)
598             if task.cancelled():
599                 cancelled.append(task)
600             elif task.exception():
601                 report.failed(src, str(task.exception()))
602             else:
603                 changed = Changed.YES if task.result() else Changed.NO
604                 # If the file was written back or was successfully checked as
605                 # well-formatted, store this information in the cache.
606                 if write_back is WriteBack.YES or (
607                     write_back is WriteBack.CHECK and changed is Changed.NO
608                 ):
609                     sources_to_cache.append(src)
610                 report.done(src, changed)
611     if cancelled:
612         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
613     if sources_to_cache:
614         write_cache(cache, sources_to_cache, mode)
615
616
617 def format_file_in_place(
618     src: Path,
619     fast: bool,
620     mode: FileMode,
621     write_back: WriteBack = WriteBack.NO,
622     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
623 ) -> bool:
624     """Format file under `src` path. Return True if changed.
625
626     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
627     code to the file.
628     `mode` and `fast` options are passed to :func:`format_file_contents`.
629     """
630     if src.suffix == ".pyi":
631         mode = evolve(mode, is_pyi=True)
632
633     then = datetime.utcfromtimestamp(src.stat().st_mtime)
634     with open(src, "rb") as buf:
635         src_contents, encoding, newline = decode_bytes(buf.read())
636     try:
637         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
638     except NothingChanged:
639         return False
640
641     if write_back == write_back.YES:
642         with open(src, "w", encoding=encoding, newline=newline) as f:
643             f.write(dst_contents)
644     elif write_back == write_back.DIFF:
645         now = datetime.utcnow()
646         src_name = f"{src}\t{then} +0000"
647         dst_name = f"{src}\t{now} +0000"
648         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
649
650         with lock or nullcontext():
651             f = io.TextIOWrapper(
652                 sys.stdout.buffer,
653                 encoding=encoding,
654                 newline=newline,
655                 write_through=True,
656             )
657             f.write(diff_contents)
658             f.detach()
659
660     return True
661
662
663 def format_stdin_to_stdout(
664     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
665 ) -> bool:
666     """Format file on stdin. Return True if changed.
667
668     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
669     write a diff to stdout. The `mode` argument is passed to
670     :func:`format_file_contents`.
671     """
672     then = datetime.utcnow()
673     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
674     dst = src
675     try:
676         dst = format_file_contents(src, fast=fast, mode=mode)
677         return True
678
679     except NothingChanged:
680         return False
681
682     finally:
683         f = io.TextIOWrapper(
684             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
685         )
686         if write_back == WriteBack.YES:
687             f.write(dst)
688         elif write_back == WriteBack.DIFF:
689             now = datetime.utcnow()
690             src_name = f"STDIN\t{then} +0000"
691             dst_name = f"STDOUT\t{now} +0000"
692             f.write(diff(src, dst, src_name, dst_name))
693         f.detach()
694
695
696 def format_file_contents(
697     src_contents: str, *, fast: bool, mode: FileMode
698 ) -> FileContent:
699     """Reformat contents a file and return new contents.
700
701     If `fast` is False, additionally confirm that the reformatted code is
702     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
703     `mode` is passed to :func:`format_str`.
704     """
705     if src_contents.strip() == "":
706         raise NothingChanged
707
708     dst_contents = format_str(src_contents, mode=mode)
709     if src_contents == dst_contents:
710         raise NothingChanged
711
712     if not fast:
713         assert_equivalent(src_contents, dst_contents)
714         assert_stable(src_contents, dst_contents, mode=mode)
715     return dst_contents
716
717
718 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
719     """Reformat a string and return new contents.
720
721     `mode` determines formatting options, such as how many characters per line are
722     allowed.
723     """
724     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
725     dst_contents = []
726     future_imports = get_future_imports(src_node)
727     if mode.target_versions:
728         versions = mode.target_versions
729     else:
730         versions = detect_target_versions(src_node)
731     normalize_fmt_off(src_node)
732     lines = LineGenerator(
733         remove_u_prefix="unicode_literals" in future_imports
734         or supports_feature(versions, Feature.UNICODE_LITERALS),
735         is_pyi=mode.is_pyi,
736         normalize_strings=mode.string_normalization,
737     )
738     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
739     empty_line = Line()
740     after = 0
741     split_line_features = {
742         feature
743         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
744         if supports_feature(versions, feature)
745     }
746     for current_line in lines.visit(src_node):
747         for _ in range(after):
748             dst_contents.append(str(empty_line))
749         before, after = elt.maybe_empty_lines(current_line)
750         for _ in range(before):
751             dst_contents.append(str(empty_line))
752         for line in split_line(
753             current_line, line_length=mode.line_length, features=split_line_features
754         ):
755             dst_contents.append(str(line))
756     return "".join(dst_contents)
757
758
759 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
760     """Return a tuple of (decoded_contents, encoding, newline).
761
762     `newline` is either CRLF or LF but `decoded_contents` is decoded with
763     universal newlines (i.e. only contains LF).
764     """
765     srcbuf = io.BytesIO(src)
766     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
767     if not lines:
768         return "", encoding, "\n"
769
770     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
771     srcbuf.seek(0)
772     with io.TextIOWrapper(srcbuf, encoding) as tiow:
773         return tiow.read(), encoding, newline
774
775
776 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
777     if not target_versions:
778         # No target_version specified, so try all grammars.
779         return [
780             # Python 3.7+
781             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
782             # Python 3.0-3.6
783             pygram.python_grammar_no_print_statement_no_exec_statement,
784             # Python 2.7 with future print_function import
785             pygram.python_grammar_no_print_statement,
786             # Python 2.7
787             pygram.python_grammar,
788         ]
789     elif all(version.is_python2() for version in target_versions):
790         # Python 2-only code, so try Python 2 grammars.
791         return [
792             # Python 2.7 with future print_function import
793             pygram.python_grammar_no_print_statement,
794             # Python 2.7
795             pygram.python_grammar,
796         ]
797     else:
798         # Python 3-compatible code, so only try Python 3 grammar.
799         grammars = []
800         # If we have to parse both, try to parse async as a keyword first
801         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
802             # Python 3.7+
803             grammars.append(
804                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
805             )
806         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
807             # Python 3.0-3.6
808             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
809         # At least one of the above branches must have been taken, because every Python
810         # version has exactly one of the two 'ASYNC_*' flags
811         return grammars
812
813
814 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
815     """Given a string with source, return the lib2to3 Node."""
816     if src_txt[-1:] != "\n":
817         src_txt += "\n"
818
819     for grammar in get_grammars(set(target_versions)):
820         drv = driver.Driver(grammar, pytree.convert)
821         try:
822             result = drv.parse_string(src_txt, True)
823             break
824
825         except ParseError as pe:
826             lineno, column = pe.context[1]
827             lines = src_txt.splitlines()
828             try:
829                 faulty_line = lines[lineno - 1]
830             except IndexError:
831                 faulty_line = "<line number missing in source>"
832             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
833     else:
834         raise exc from None
835
836     if isinstance(result, Leaf):
837         result = Node(syms.file_input, [result])
838     return result
839
840
841 def lib2to3_unparse(node: Node) -> str:
842     """Given a lib2to3 node, return its string representation."""
843     code = str(node)
844     return code
845
846
847 T = TypeVar("T")
848
849
850 class Visitor(Generic[T]):
851     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
852
853     def visit(self, node: LN) -> Iterator[T]:
854         """Main method to visit `node` and its children.
855
856         It tries to find a `visit_*()` method for the given `node.type`, like
857         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
858         If no dedicated `visit_*()` method is found, chooses `visit_default()`
859         instead.
860
861         Then yields objects of type `T` from the selected visitor.
862         """
863         if node.type < 256:
864             name = token.tok_name[node.type]
865         else:
866             name = type_repr(node.type)
867         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
868
869     def visit_default(self, node: LN) -> Iterator[T]:
870         """Default `visit_*()` implementation. Recurses to children of `node`."""
871         if isinstance(node, Node):
872             for child in node.children:
873                 yield from self.visit(child)
874
875
876 @dataclass
877 class DebugVisitor(Visitor[T]):
878     tree_depth: int = 0
879
880     def visit_default(self, node: LN) -> Iterator[T]:
881         indent = " " * (2 * self.tree_depth)
882         if isinstance(node, Node):
883             _type = type_repr(node.type)
884             out(f"{indent}{_type}", fg="yellow")
885             self.tree_depth += 1
886             for child in node.children:
887                 yield from self.visit(child)
888
889             self.tree_depth -= 1
890             out(f"{indent}/{_type}", fg="yellow", bold=False)
891         else:
892             _type = token.tok_name.get(node.type, str(node.type))
893             out(f"{indent}{_type}", fg="blue", nl=False)
894             if node.prefix:
895                 # We don't have to handle prefixes for `Node` objects since
896                 # that delegates to the first child anyway.
897                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
898             out(f" {node.value!r}", fg="blue", bold=False)
899
900     @classmethod
901     def show(cls, code: Union[str, Leaf, Node]) -> None:
902         """Pretty-print the lib2to3 AST of a given string of `code`.
903
904         Convenience method for debugging.
905         """
906         v: DebugVisitor[None] = DebugVisitor()
907         if isinstance(code, str):
908             code = lib2to3_parse(code)
909         list(v.visit(code))
910
911
912 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
913 STATEMENT = {
914     syms.if_stmt,
915     syms.while_stmt,
916     syms.for_stmt,
917     syms.try_stmt,
918     syms.except_clause,
919     syms.with_stmt,
920     syms.funcdef,
921     syms.classdef,
922 }
923 STANDALONE_COMMENT = 153
924 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
925 LOGIC_OPERATORS = {"and", "or"}
926 COMPARATORS = {
927     token.LESS,
928     token.GREATER,
929     token.EQEQUAL,
930     token.NOTEQUAL,
931     token.LESSEQUAL,
932     token.GREATEREQUAL,
933 }
934 MATH_OPERATORS = {
935     token.VBAR,
936     token.CIRCUMFLEX,
937     token.AMPER,
938     token.LEFTSHIFT,
939     token.RIGHTSHIFT,
940     token.PLUS,
941     token.MINUS,
942     token.STAR,
943     token.SLASH,
944     token.DOUBLESLASH,
945     token.PERCENT,
946     token.AT,
947     token.TILDE,
948     token.DOUBLESTAR,
949 }
950 STARS = {token.STAR, token.DOUBLESTAR}
951 VARARGS_SPECIALS = STARS | {token.SLASH}
952 VARARGS_PARENTS = {
953     syms.arglist,
954     syms.argument,  # double star in arglist
955     syms.trailer,  # single argument to call
956     syms.typedargslist,
957     syms.varargslist,  # lambdas
958 }
959 UNPACKING_PARENTS = {
960     syms.atom,  # single element of a list or set literal
961     syms.dictsetmaker,
962     syms.listmaker,
963     syms.testlist_gexp,
964     syms.testlist_star_expr,
965 }
966 TEST_DESCENDANTS = {
967     syms.test,
968     syms.lambdef,
969     syms.or_test,
970     syms.and_test,
971     syms.not_test,
972     syms.comparison,
973     syms.star_expr,
974     syms.expr,
975     syms.xor_expr,
976     syms.and_expr,
977     syms.shift_expr,
978     syms.arith_expr,
979     syms.trailer,
980     syms.term,
981     syms.power,
982 }
983 ASSIGNMENTS = {
984     "=",
985     "+=",
986     "-=",
987     "*=",
988     "@=",
989     "/=",
990     "%=",
991     "&=",
992     "|=",
993     "^=",
994     "<<=",
995     ">>=",
996     "**=",
997     "//=",
998 }
999 COMPREHENSION_PRIORITY = 20
1000 COMMA_PRIORITY = 18
1001 TERNARY_PRIORITY = 16
1002 LOGIC_PRIORITY = 14
1003 STRING_PRIORITY = 12
1004 COMPARATOR_PRIORITY = 10
1005 MATH_PRIORITIES = {
1006     token.VBAR: 9,
1007     token.CIRCUMFLEX: 8,
1008     token.AMPER: 7,
1009     token.LEFTSHIFT: 6,
1010     token.RIGHTSHIFT: 6,
1011     token.PLUS: 5,
1012     token.MINUS: 5,
1013     token.STAR: 4,
1014     token.SLASH: 4,
1015     token.DOUBLESLASH: 4,
1016     token.PERCENT: 4,
1017     token.AT: 4,
1018     token.TILDE: 3,
1019     token.DOUBLESTAR: 2,
1020 }
1021 DOT_PRIORITY = 1
1022
1023
1024 @dataclass
1025 class BracketTracker:
1026     """Keeps track of brackets on a line."""
1027
1028     depth: int = 0
1029     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1030     delimiters: Dict[LeafID, Priority] = Factory(dict)
1031     previous: Optional[Leaf] = None
1032     _for_loop_depths: List[int] = Factory(list)
1033     _lambda_argument_depths: List[int] = Factory(list)
1034
1035     def mark(self, leaf: Leaf) -> None:
1036         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1037
1038         All leaves receive an int `bracket_depth` field that stores how deep
1039         within brackets a given leaf is. 0 means there are no enclosing brackets
1040         that started on this line.
1041
1042         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1043         field that it forms a pair with. This is a one-directional link to
1044         avoid reference cycles.
1045
1046         If a leaf is a delimiter (a token on which Black can split the line if
1047         needed) and it's on depth 0, its `id()` is stored in the tracker's
1048         `delimiters` field.
1049         """
1050         if leaf.type == token.COMMENT:
1051             return
1052
1053         self.maybe_decrement_after_for_loop_variable(leaf)
1054         self.maybe_decrement_after_lambda_arguments(leaf)
1055         if leaf.type in CLOSING_BRACKETS:
1056             self.depth -= 1
1057             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1058             leaf.opening_bracket = opening_bracket
1059         leaf.bracket_depth = self.depth
1060         if self.depth == 0:
1061             delim = is_split_before_delimiter(leaf, self.previous)
1062             if delim and self.previous is not None:
1063                 self.delimiters[id(self.previous)] = delim
1064             else:
1065                 delim = is_split_after_delimiter(leaf, self.previous)
1066                 if delim:
1067                     self.delimiters[id(leaf)] = delim
1068         if leaf.type in OPENING_BRACKETS:
1069             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1070             self.depth += 1
1071         self.previous = leaf
1072         self.maybe_increment_lambda_arguments(leaf)
1073         self.maybe_increment_for_loop_variable(leaf)
1074
1075     def any_open_brackets(self) -> bool:
1076         """Return True if there is an yet unmatched open bracket on the line."""
1077         return bool(self.bracket_match)
1078
1079     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1080         """Return the highest priority of a delimiter found on the line.
1081
1082         Values are consistent with what `is_split_*_delimiter()` return.
1083         Raises ValueError on no delimiters.
1084         """
1085         return max(v for k, v in self.delimiters.items() if k not in exclude)
1086
1087     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1088         """Return the number of delimiters with the given `priority`.
1089
1090         If no `priority` is passed, defaults to max priority on the line.
1091         """
1092         if not self.delimiters:
1093             return 0
1094
1095         priority = priority or self.max_delimiter_priority()
1096         return sum(1 for p in self.delimiters.values() if p == priority)
1097
1098     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1099         """In a for loop, or comprehension, the variables are often unpacks.
1100
1101         To avoid splitting on the comma in this situation, increase the depth of
1102         tokens between `for` and `in`.
1103         """
1104         if leaf.type == token.NAME and leaf.value == "for":
1105             self.depth += 1
1106             self._for_loop_depths.append(self.depth)
1107             return True
1108
1109         return False
1110
1111     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1112         """See `maybe_increment_for_loop_variable` above for explanation."""
1113         if (
1114             self._for_loop_depths
1115             and self._for_loop_depths[-1] == self.depth
1116             and leaf.type == token.NAME
1117             and leaf.value == "in"
1118         ):
1119             self.depth -= 1
1120             self._for_loop_depths.pop()
1121             return True
1122
1123         return False
1124
1125     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1126         """In a lambda expression, there might be more than one argument.
1127
1128         To avoid splitting on the comma in this situation, increase the depth of
1129         tokens between `lambda` and `:`.
1130         """
1131         if leaf.type == token.NAME and leaf.value == "lambda":
1132             self.depth += 1
1133             self._lambda_argument_depths.append(self.depth)
1134             return True
1135
1136         return False
1137
1138     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1139         """See `maybe_increment_lambda_arguments` above for explanation."""
1140         if (
1141             self._lambda_argument_depths
1142             and self._lambda_argument_depths[-1] == self.depth
1143             and leaf.type == token.COLON
1144         ):
1145             self.depth -= 1
1146             self._lambda_argument_depths.pop()
1147             return True
1148
1149         return False
1150
1151     def get_open_lsqb(self) -> Optional[Leaf]:
1152         """Return the most recent opening square bracket (if any)."""
1153         return self.bracket_match.get((self.depth - 1, token.RSQB))
1154
1155
1156 @dataclass
1157 class Line:
1158     """Holds leaves and comments. Can be printed with `str(line)`."""
1159
1160     depth: int = 0
1161     leaves: List[Leaf] = Factory(list)
1162     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1163     bracket_tracker: BracketTracker = Factory(BracketTracker)
1164     inside_brackets: bool = False
1165     should_explode: bool = False
1166
1167     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1168         """Add a new `leaf` to the end of the line.
1169
1170         Unless `preformatted` is True, the `leaf` will receive a new consistent
1171         whitespace prefix and metadata applied by :class:`BracketTracker`.
1172         Trailing commas are maybe removed, unpacked for loop variables are
1173         demoted from being delimiters.
1174
1175         Inline comments are put aside.
1176         """
1177         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1178         if not has_value:
1179             return
1180
1181         if token.COLON == leaf.type and self.is_class_paren_empty:
1182             del self.leaves[-2:]
1183         if self.leaves and not preformatted:
1184             # Note: at this point leaf.prefix should be empty except for
1185             # imports, for which we only preserve newlines.
1186             leaf.prefix += whitespace(
1187                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1188             )
1189         if self.inside_brackets or not preformatted:
1190             self.bracket_tracker.mark(leaf)
1191             self.maybe_remove_trailing_comma(leaf)
1192         if not self.append_comment(leaf):
1193             self.leaves.append(leaf)
1194
1195     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1196         """Like :func:`append()` but disallow invalid standalone comment structure.
1197
1198         Raises ValueError when any `leaf` is appended after a standalone comment
1199         or when a standalone comment is not the first leaf on the line.
1200         """
1201         if self.bracket_tracker.depth == 0:
1202             if self.is_comment:
1203                 raise ValueError("cannot append to standalone comments")
1204
1205             if self.leaves and leaf.type == STANDALONE_COMMENT:
1206                 raise ValueError(
1207                     "cannot append standalone comments to a populated line"
1208                 )
1209
1210         self.append(leaf, preformatted=preformatted)
1211
1212     @property
1213     def is_comment(self) -> bool:
1214         """Is this line a standalone comment?"""
1215         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1216
1217     @property
1218     def is_decorator(self) -> bool:
1219         """Is this line a decorator?"""
1220         return bool(self) and self.leaves[0].type == token.AT
1221
1222     @property
1223     def is_import(self) -> bool:
1224         """Is this an import line?"""
1225         return bool(self) and is_import(self.leaves[0])
1226
1227     @property
1228     def is_class(self) -> bool:
1229         """Is this line a class definition?"""
1230         return (
1231             bool(self)
1232             and self.leaves[0].type == token.NAME
1233             and self.leaves[0].value == "class"
1234         )
1235
1236     @property
1237     def is_stub_class(self) -> bool:
1238         """Is this line a class definition with a body consisting only of "..."?"""
1239         return self.is_class and self.leaves[-3:] == [
1240             Leaf(token.DOT, ".") for _ in range(3)
1241         ]
1242
1243     @property
1244     def is_collection_with_optional_trailing_comma(self) -> bool:
1245         """Is this line a collection literal with a trailing comma that's optional?
1246
1247         Note that the trailing comma in a 1-tuple is not optional.
1248         """
1249         if not self.leaves or len(self.leaves) < 4:
1250             return False
1251         # Look for and address a trailing colon.
1252         if self.leaves[-1].type == token.COLON:
1253             closer = self.leaves[-2]
1254             close_index = -2
1255         else:
1256             closer = self.leaves[-1]
1257             close_index = -1
1258         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1259             return False
1260         if closer.type == token.RPAR:
1261             # Tuples require an extra check, because if there's only
1262             # one element in the tuple removing the comma unmakes the
1263             # tuple.
1264             #
1265             # We also check for parens before looking for the trailing
1266             # comma because in some cases (eg assigning a dict
1267             # literal) the literal gets wrapped in temporary parens
1268             # during parsing. This case is covered by the
1269             # collections.py test data.
1270             opener = closer.opening_bracket
1271             for _open_index, leaf in enumerate(self.leaves):
1272                 if leaf is opener:
1273                     break
1274             else:
1275                 # Couldn't find the matching opening paren, play it safe.
1276                 return False
1277             commas = 0
1278             comma_depth = self.leaves[close_index - 1].bracket_depth
1279             for leaf in self.leaves[_open_index + 1 : close_index]:
1280                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1281                     commas += 1
1282             if commas > 1:
1283                 # We haven't looked yet for the trailing comma because
1284                 # we might also have caught noop parens.
1285                 return self.leaves[close_index - 1].type == token.COMMA
1286             elif commas == 1:
1287                 return False  # it's either a one-tuple or didn't have a trailing comma
1288             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1289                 close_index -= 1
1290                 closer = self.leaves[close_index]
1291                 if closer.type == token.RPAR:
1292                     # TODO: this is a gut feeling. Will we ever see this?
1293                     return False
1294         if self.leaves[close_index - 1].type != token.COMMA:
1295             return False
1296         return True
1297
1298     @property
1299     def is_def(self) -> bool:
1300         """Is this a function definition? (Also returns True for async defs.)"""
1301         try:
1302             first_leaf = self.leaves[0]
1303         except IndexError:
1304             return False
1305
1306         try:
1307             second_leaf: Optional[Leaf] = self.leaves[1]
1308         except IndexError:
1309             second_leaf = None
1310         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1311             first_leaf.type == token.ASYNC
1312             and second_leaf is not None
1313             and second_leaf.type == token.NAME
1314             and second_leaf.value == "def"
1315         )
1316
1317     @property
1318     def is_class_paren_empty(self) -> bool:
1319         """Is this a class with no base classes but using parentheses?
1320
1321         Those are unnecessary and should be removed.
1322         """
1323         return (
1324             bool(self)
1325             and len(self.leaves) == 4
1326             and self.is_class
1327             and self.leaves[2].type == token.LPAR
1328             and self.leaves[2].value == "("
1329             and self.leaves[3].type == token.RPAR
1330             and self.leaves[3].value == ")"
1331         )
1332
1333     @property
1334     def is_triple_quoted_string(self) -> bool:
1335         """Is the line a triple quoted string?"""
1336         return (
1337             bool(self)
1338             and self.leaves[0].type == token.STRING
1339             and self.leaves[0].value.startswith(('"""', "'''"))
1340         )
1341
1342     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1343         """If so, needs to be split before emitting."""
1344         for leaf in self.leaves:
1345             if leaf.type == STANDALONE_COMMENT:
1346                 if leaf.bracket_depth <= depth_limit:
1347                     return True
1348         return False
1349
1350     def contains_uncollapsable_type_comments(self) -> bool:
1351         ignored_ids = set()
1352         try:
1353             last_leaf = self.leaves[-1]
1354             ignored_ids.add(id(last_leaf))
1355             if last_leaf.type == token.COMMA or (
1356                 last_leaf.type == token.RPAR and not last_leaf.value
1357             ):
1358                 # When trailing commas or optional parens are inserted by Black for
1359                 # consistency, comments after the previous last element are not moved
1360                 # (they don't have to, rendering will still be correct).  So we ignore
1361                 # trailing commas and invisible.
1362                 last_leaf = self.leaves[-2]
1363                 ignored_ids.add(id(last_leaf))
1364         except IndexError:
1365             return False
1366
1367         # A type comment is uncollapsable if it is attached to a leaf
1368         # that isn't at the end of the line (since that could cause it
1369         # to get associated to a different argument) or if there are
1370         # comments before it (since that could cause it to get hidden
1371         # behind a comment.
1372         comment_seen = False
1373         for leaf_id, comments in self.comments.items():
1374             for comment in comments:
1375                 if is_type_comment(comment):
1376                     if leaf_id not in ignored_ids or comment_seen:
1377                         return True
1378
1379                 comment_seen = True
1380
1381         return False
1382
1383     def contains_unsplittable_type_ignore(self) -> bool:
1384         if not self.leaves:
1385             return False
1386
1387         # If a 'type: ignore' is attached to the end of a line, we
1388         # can't split the line, because we can't know which of the
1389         # subexpressions the ignore was meant to apply to.
1390         #
1391         # We only want this to apply to actual physical lines from the
1392         # original source, though: we don't want the presence of a
1393         # 'type: ignore' at the end of a multiline expression to
1394         # justify pushing it all onto one line. Thus we
1395         # (unfortunately) need to check the actual source lines and
1396         # only report an unsplittable 'type: ignore' if this line was
1397         # one line in the original code.
1398
1399         # Grab the first and last line numbers, skipping generated leaves
1400         first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0)
1401         last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0)
1402
1403         if first_line == last_line:
1404             # We look at the last two leaves since a comma or an
1405             # invisible paren could have been added at the end of the
1406             # line.
1407             for node in self.leaves[-2:]:
1408                 for comment in self.comments.get(id(node), []):
1409                     if is_type_comment(comment, " ignore"):
1410                         return True
1411
1412         return False
1413
1414     def contains_multiline_strings(self) -> bool:
1415         for leaf in self.leaves:
1416             if is_multiline_string(leaf):
1417                 return True
1418
1419         return False
1420
1421     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1422         """Remove trailing comma if there is one and it's safe."""
1423         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1424             return False
1425         # We remove trailing commas only in the case of importing a
1426         # single name from a module.
1427         if not (
1428             self.leaves
1429             and self.is_import
1430             and len(self.leaves) > 4
1431             and self.leaves[-1].type == token.COMMA
1432             and closing.type in CLOSING_BRACKETS
1433             and self.leaves[-4].type == token.NAME
1434             and (
1435                 # regular `from foo import bar,`
1436                 self.leaves[-4].value == "import"
1437                 # `from foo import (bar as baz,)
1438                 or (
1439                     len(self.leaves) > 6
1440                     and self.leaves[-6].value == "import"
1441                     and self.leaves[-3].value == "as"
1442                 )
1443                 # `from foo import bar as baz,`
1444                 or (
1445                     len(self.leaves) > 5
1446                     and self.leaves[-5].value == "import"
1447                     and self.leaves[-3].value == "as"
1448                 )
1449             )
1450             and closing.type == token.RPAR
1451         ):
1452             return False
1453
1454         self.remove_trailing_comma()
1455         return True
1456
1457     def append_comment(self, comment: Leaf) -> bool:
1458         """Add an inline or standalone comment to the line."""
1459         if (
1460             comment.type == STANDALONE_COMMENT
1461             and self.bracket_tracker.any_open_brackets()
1462         ):
1463             comment.prefix = ""
1464             return False
1465
1466         if comment.type != token.COMMENT:
1467             return False
1468
1469         if not self.leaves:
1470             comment.type = STANDALONE_COMMENT
1471             comment.prefix = ""
1472             return False
1473
1474         last_leaf = self.leaves[-1]
1475         if (
1476             last_leaf.type == token.RPAR
1477             and not last_leaf.value
1478             and last_leaf.parent
1479             and len(list(last_leaf.parent.leaves())) <= 3
1480             and not is_type_comment(comment)
1481         ):
1482             # Comments on an optional parens wrapping a single leaf should belong to
1483             # the wrapped node except if it's a type comment. Pinning the comment like
1484             # this avoids unstable formatting caused by comment migration.
1485             if len(self.leaves) < 2:
1486                 comment.type = STANDALONE_COMMENT
1487                 comment.prefix = ""
1488                 return False
1489             last_leaf = self.leaves[-2]
1490         self.comments.setdefault(id(last_leaf), []).append(comment)
1491         return True
1492
1493     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1494         """Generate comments that should appear directly after `leaf`."""
1495         return self.comments.get(id(leaf), [])
1496
1497     def remove_trailing_comma(self) -> None:
1498         """Remove the trailing comma and moves the comments attached to it."""
1499         trailing_comma = self.leaves.pop()
1500         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1501         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1502             trailing_comma_comments
1503         )
1504
1505     def is_complex_subscript(self, leaf: Leaf) -> bool:
1506         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1507         open_lsqb = self.bracket_tracker.get_open_lsqb()
1508         if open_lsqb is None:
1509             return False
1510
1511         subscript_start = open_lsqb.next_sibling
1512
1513         if isinstance(subscript_start, Node):
1514             if subscript_start.type == syms.listmaker:
1515                 return False
1516
1517             if subscript_start.type == syms.subscriptlist:
1518                 subscript_start = child_towards(subscript_start, leaf)
1519         return subscript_start is not None and any(
1520             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1521         )
1522
1523     def __str__(self) -> str:
1524         """Render the line."""
1525         if not self:
1526             return "\n"
1527
1528         indent = "    " * self.depth
1529         leaves = iter(self.leaves)
1530         first = next(leaves)
1531         res = f"{first.prefix}{indent}{first.value}"
1532         for leaf in leaves:
1533             res += str(leaf)
1534         for comment in itertools.chain.from_iterable(self.comments.values()):
1535             res += str(comment)
1536         return res + "\n"
1537
1538     def __bool__(self) -> bool:
1539         """Return True if the line has leaves or comments."""
1540         return bool(self.leaves or self.comments)
1541
1542
1543 @dataclass
1544 class EmptyLineTracker:
1545     """Provides a stateful method that returns the number of potential extra
1546     empty lines needed before and after the currently processed line.
1547
1548     Note: this tracker works on lines that haven't been split yet.  It assumes
1549     the prefix of the first leaf consists of optional newlines.  Those newlines
1550     are consumed by `maybe_empty_lines()` and included in the computation.
1551     """
1552
1553     is_pyi: bool = False
1554     previous_line: Optional[Line] = None
1555     previous_after: int = 0
1556     previous_defs: List[int] = Factory(list)
1557
1558     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1559         """Return the number of extra empty lines before and after the `current_line`.
1560
1561         This is for separating `def`, `async def` and `class` with extra empty
1562         lines (two on module-level).
1563         """
1564         before, after = self._maybe_empty_lines(current_line)
1565         before = (
1566             # Black should not insert empty lines at the beginning
1567             # of the file
1568             0
1569             if self.previous_line is None
1570             else before - self.previous_after
1571         )
1572         self.previous_after = after
1573         self.previous_line = current_line
1574         return before, after
1575
1576     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1577         max_allowed = 1
1578         if current_line.depth == 0:
1579             max_allowed = 1 if self.is_pyi else 2
1580         if current_line.leaves:
1581             # Consume the first leaf's extra newlines.
1582             first_leaf = current_line.leaves[0]
1583             before = first_leaf.prefix.count("\n")
1584             before = min(before, max_allowed)
1585             first_leaf.prefix = ""
1586         else:
1587             before = 0
1588         depth = current_line.depth
1589         while self.previous_defs and self.previous_defs[-1] >= depth:
1590             self.previous_defs.pop()
1591             if self.is_pyi:
1592                 before = 0 if depth else 1
1593             else:
1594                 before = 1 if depth else 2
1595         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1596             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1597
1598         if (
1599             self.previous_line
1600             and self.previous_line.is_import
1601             and not current_line.is_import
1602             and depth == self.previous_line.depth
1603         ):
1604             return (before or 1), 0
1605
1606         if (
1607             self.previous_line
1608             and self.previous_line.is_class
1609             and current_line.is_triple_quoted_string
1610         ):
1611             return before, 1
1612
1613         return before, 0
1614
1615     def _maybe_empty_lines_for_class_or_def(
1616         self, current_line: Line, before: int
1617     ) -> Tuple[int, int]:
1618         if not current_line.is_decorator:
1619             self.previous_defs.append(current_line.depth)
1620         if self.previous_line is None:
1621             # Don't insert empty lines before the first line in the file.
1622             return 0, 0
1623
1624         if self.previous_line.is_decorator:
1625             return 0, 0
1626
1627         if self.previous_line.depth < current_line.depth and (
1628             self.previous_line.is_class or self.previous_line.is_def
1629         ):
1630             return 0, 0
1631
1632         if (
1633             self.previous_line.is_comment
1634             and self.previous_line.depth == current_line.depth
1635             and before == 0
1636         ):
1637             return 0, 0
1638
1639         if self.is_pyi:
1640             if self.previous_line.depth > current_line.depth:
1641                 newlines = 1
1642             elif current_line.is_class or self.previous_line.is_class:
1643                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1644                     # No blank line between classes with an empty body
1645                     newlines = 0
1646                 else:
1647                     newlines = 1
1648             elif current_line.is_def and not self.previous_line.is_def:
1649                 # Blank line between a block of functions and a block of non-functions
1650                 newlines = 1
1651             else:
1652                 newlines = 0
1653         else:
1654             newlines = 2
1655         if current_line.depth and newlines:
1656             newlines -= 1
1657         return newlines, 0
1658
1659
1660 @dataclass
1661 class LineGenerator(Visitor[Line]):
1662     """Generates reformatted Line objects.  Empty lines are not emitted.
1663
1664     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1665     in ways that will no longer stringify to valid Python code on the tree.
1666     """
1667
1668     is_pyi: bool = False
1669     normalize_strings: bool = True
1670     current_line: Line = Factory(Line)
1671     remove_u_prefix: bool = False
1672
1673     def line(self, indent: int = 0) -> Iterator[Line]:
1674         """Generate a line.
1675
1676         If the line is empty, only emit if it makes sense.
1677         If the line is too long, split it first and then generate.
1678
1679         If any lines were generated, set up a new current_line.
1680         """
1681         if not self.current_line:
1682             self.current_line.depth += indent
1683             return  # Line is empty, don't emit. Creating a new one unnecessary.
1684
1685         complete_line = self.current_line
1686         self.current_line = Line(depth=complete_line.depth + indent)
1687         yield complete_line
1688
1689     def visit_default(self, node: LN) -> Iterator[Line]:
1690         """Default `visit_*()` implementation. Recurses to children of `node`."""
1691         if isinstance(node, Leaf):
1692             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1693             for comment in generate_comments(node):
1694                 if any_open_brackets:
1695                     # any comment within brackets is subject to splitting
1696                     self.current_line.append(comment)
1697                 elif comment.type == token.COMMENT:
1698                     # regular trailing comment
1699                     self.current_line.append(comment)
1700                     yield from self.line()
1701
1702                 else:
1703                     # regular standalone comment
1704                     yield from self.line()
1705
1706                     self.current_line.append(comment)
1707                     yield from self.line()
1708
1709             normalize_prefix(node, inside_brackets=any_open_brackets)
1710             if self.normalize_strings and node.type == token.STRING:
1711                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1712                 normalize_string_quotes(node)
1713             if node.type == token.NUMBER:
1714                 normalize_numeric_literal(node)
1715             if node.type not in WHITESPACE:
1716                 self.current_line.append(node)
1717         yield from super().visit_default(node)
1718
1719     def visit_atom(self, node: Node) -> Iterator[Line]:
1720         # Always make parentheses invisible around a single node, because it should
1721         # not be needed (except in the case of yield, where removing the parentheses
1722         # produces a SyntaxError).
1723         if (
1724             len(node.children) == 3
1725             and isinstance(node.children[0], Leaf)
1726             and node.children[0].type == token.LPAR
1727             and isinstance(node.children[2], Leaf)
1728             and node.children[2].type == token.RPAR
1729             and isinstance(node.children[1], Leaf)
1730             and not (
1731                 node.children[1].type == token.NAME
1732                 and node.children[1].value == "yield"
1733             )
1734         ):
1735             node.children[0].value = ""
1736             node.children[2].value = ""
1737         yield from super().visit_default(node)
1738
1739     def visit_factor(self, node: Node) -> Iterator[Line]:
1740         """Force parentheses between a unary op and a binary power:
1741
1742         -2 ** 8 -> -(2 ** 8)
1743         """
1744         child = node.children[1]
1745         if child.type == syms.power and len(child.children) == 3:
1746             lpar = Leaf(token.LPAR, "(")
1747             rpar = Leaf(token.RPAR, ")")
1748             index = child.remove() or 0
1749             node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1750         yield from self.visit_default(node)
1751
1752     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1753         """Increase indentation level, maybe yield a line."""
1754         # In blib2to3 INDENT never holds comments.
1755         yield from self.line(+1)
1756         yield from self.visit_default(node)
1757
1758     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1759         """Decrease indentation level, maybe yield a line."""
1760         # The current line might still wait for trailing comments.  At DEDENT time
1761         # there won't be any (they would be prefixes on the preceding NEWLINE).
1762         # Emit the line then.
1763         yield from self.line()
1764
1765         # While DEDENT has no value, its prefix may contain standalone comments
1766         # that belong to the current indentation level.  Get 'em.
1767         yield from self.visit_default(node)
1768
1769         # Finally, emit the dedent.
1770         yield from self.line(-1)
1771
1772     def visit_stmt(
1773         self, node: Node, keywords: Set[str], parens: Set[str]
1774     ) -> Iterator[Line]:
1775         """Visit a statement.
1776
1777         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1778         `def`, `with`, `class`, `assert` and assignments.
1779
1780         The relevant Python language `keywords` for a given statement will be
1781         NAME leaves within it. This methods puts those on a separate line.
1782
1783         `parens` holds a set of string leaf values immediately after which
1784         invisible parens should be put.
1785         """
1786         normalize_invisible_parens(node, parens_after=parens)
1787         for child in node.children:
1788             if child.type == token.NAME and child.value in keywords:  # type: ignore
1789                 yield from self.line()
1790
1791             yield from self.visit(child)
1792
1793     def visit_suite(self, node: Node) -> Iterator[Line]:
1794         """Visit a suite."""
1795         if self.is_pyi and is_stub_suite(node):
1796             yield from self.visit(node.children[2])
1797         else:
1798             yield from self.visit_default(node)
1799
1800     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1801         """Visit a statement without nested statements."""
1802         is_suite_like = node.parent and node.parent.type in STATEMENT
1803         if is_suite_like:
1804             if self.is_pyi and is_stub_body(node):
1805                 yield from self.visit_default(node)
1806             else:
1807                 yield from self.line(+1)
1808                 yield from self.visit_default(node)
1809                 yield from self.line(-1)
1810
1811         else:
1812             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1813                 yield from self.line()
1814             yield from self.visit_default(node)
1815
1816     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1817         """Visit `async def`, `async for`, `async with`."""
1818         yield from self.line()
1819
1820         children = iter(node.children)
1821         for child in children:
1822             yield from self.visit(child)
1823
1824             if child.type == token.ASYNC:
1825                 break
1826
1827         internal_stmt = next(children)
1828         for child in internal_stmt.children:
1829             yield from self.visit(child)
1830
1831     def visit_decorators(self, node: Node) -> Iterator[Line]:
1832         """Visit decorators."""
1833         for child in node.children:
1834             yield from self.line()
1835             yield from self.visit(child)
1836
1837     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1838         """Remove a semicolon and put the other statement on a separate line."""
1839         yield from self.line()
1840
1841     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1842         """End of file. Process outstanding comments and end with a newline."""
1843         yield from self.visit_default(leaf)
1844         yield from self.line()
1845
1846     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1847         if not self.current_line.bracket_tracker.any_open_brackets():
1848             yield from self.line()
1849         yield from self.visit_default(leaf)
1850
1851     def __attrs_post_init__(self) -> None:
1852         """You are in a twisty little maze of passages."""
1853         v = self.visit_stmt
1854         Ø: Set[str] = set()
1855         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1856         self.visit_if_stmt = partial(
1857             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1858         )
1859         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1860         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1861         self.visit_try_stmt = partial(
1862             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1863         )
1864         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1865         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1866         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1867         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1868         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1869         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1870         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1871         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1872         self.visit_async_funcdef = self.visit_async_stmt
1873         self.visit_decorated = self.visit_decorators
1874
1875
1876 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1877 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1878 OPENING_BRACKETS = set(BRACKET.keys())
1879 CLOSING_BRACKETS = set(BRACKET.values())
1880 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1881 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1882
1883
1884 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1885     """Return whitespace prefix if needed for the given `leaf`.
1886
1887     `complex_subscript` signals whether the given leaf is part of a subscription
1888     which has non-trivial arguments, like arithmetic expressions or function calls.
1889     """
1890     NO = ""
1891     SPACE = " "
1892     DOUBLESPACE = "  "
1893     t = leaf.type
1894     p = leaf.parent
1895     v = leaf.value
1896     if t in ALWAYS_NO_SPACE:
1897         return NO
1898
1899     if t == token.COMMENT:
1900         return DOUBLESPACE
1901
1902     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1903     if t == token.COLON and p.type not in {
1904         syms.subscript,
1905         syms.subscriptlist,
1906         syms.sliceop,
1907     }:
1908         return NO
1909
1910     prev = leaf.prev_sibling
1911     if not prev:
1912         prevp = preceding_leaf(p)
1913         if not prevp or prevp.type in OPENING_BRACKETS:
1914             return NO
1915
1916         if t == token.COLON:
1917             if prevp.type == token.COLON:
1918                 return NO
1919
1920             elif prevp.type != token.COMMA and not complex_subscript:
1921                 return NO
1922
1923             return SPACE
1924
1925         if prevp.type == token.EQUAL:
1926             if prevp.parent:
1927                 if prevp.parent.type in {
1928                     syms.arglist,
1929                     syms.argument,
1930                     syms.parameters,
1931                     syms.varargslist,
1932                 }:
1933                     return NO
1934
1935                 elif prevp.parent.type == syms.typedargslist:
1936                     # A bit hacky: if the equal sign has whitespace, it means we
1937                     # previously found it's a typed argument.  So, we're using
1938                     # that, too.
1939                     return prevp.prefix
1940
1941         elif prevp.type in VARARGS_SPECIALS:
1942             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1943                 return NO
1944
1945         elif prevp.type == token.COLON:
1946             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1947                 return SPACE if complex_subscript else NO
1948
1949         elif (
1950             prevp.parent
1951             and prevp.parent.type == syms.factor
1952             and prevp.type in MATH_OPERATORS
1953         ):
1954             return NO
1955
1956         elif (
1957             prevp.type == token.RIGHTSHIFT
1958             and prevp.parent
1959             and prevp.parent.type == syms.shift_expr
1960             and prevp.prev_sibling
1961             and prevp.prev_sibling.type == token.NAME
1962             and prevp.prev_sibling.value == "print"  # type: ignore
1963         ):
1964             # Python 2 print chevron
1965             return NO
1966
1967     elif prev.type in OPENING_BRACKETS:
1968         return NO
1969
1970     if p.type in {syms.parameters, syms.arglist}:
1971         # untyped function signatures or calls
1972         if not prev or prev.type != token.COMMA:
1973             return NO
1974
1975     elif p.type == syms.varargslist:
1976         # lambdas
1977         if prev and prev.type != token.COMMA:
1978             return NO
1979
1980     elif p.type == syms.typedargslist:
1981         # typed function signatures
1982         if not prev:
1983             return NO
1984
1985         if t == token.EQUAL:
1986             if prev.type != syms.tname:
1987                 return NO
1988
1989         elif prev.type == token.EQUAL:
1990             # A bit hacky: if the equal sign has whitespace, it means we
1991             # previously found it's a typed argument.  So, we're using that, too.
1992             return prev.prefix
1993
1994         elif prev.type != token.COMMA:
1995             return NO
1996
1997     elif p.type == syms.tname:
1998         # type names
1999         if not prev:
2000             prevp = preceding_leaf(p)
2001             if not prevp or prevp.type != token.COMMA:
2002                 return NO
2003
2004     elif p.type == syms.trailer:
2005         # attributes and calls
2006         if t == token.LPAR or t == token.RPAR:
2007             return NO
2008
2009         if not prev:
2010             if t == token.DOT:
2011                 prevp = preceding_leaf(p)
2012                 if not prevp or prevp.type != token.NUMBER:
2013                     return NO
2014
2015             elif t == token.LSQB:
2016                 return NO
2017
2018         elif prev.type != token.COMMA:
2019             return NO
2020
2021     elif p.type == syms.argument:
2022         # single argument
2023         if t == token.EQUAL:
2024             return NO
2025
2026         if not prev:
2027             prevp = preceding_leaf(p)
2028             if not prevp or prevp.type == token.LPAR:
2029                 return NO
2030
2031         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2032             return NO
2033
2034     elif p.type == syms.decorator:
2035         # decorators
2036         return NO
2037
2038     elif p.type == syms.dotted_name:
2039         if prev:
2040             return NO
2041
2042         prevp = preceding_leaf(p)
2043         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2044             return NO
2045
2046     elif p.type == syms.classdef:
2047         if t == token.LPAR:
2048             return NO
2049
2050         if prev and prev.type == token.LPAR:
2051             return NO
2052
2053     elif p.type in {syms.subscript, syms.sliceop}:
2054         # indexing
2055         if not prev:
2056             assert p.parent is not None, "subscripts are always parented"
2057             if p.parent.type == syms.subscriptlist:
2058                 return SPACE
2059
2060             return NO
2061
2062         elif not complex_subscript:
2063             return NO
2064
2065     elif p.type == syms.atom:
2066         if prev and t == token.DOT:
2067             # dots, but not the first one.
2068             return NO
2069
2070     elif p.type == syms.dictsetmaker:
2071         # dict unpacking
2072         if prev and prev.type == token.DOUBLESTAR:
2073             return NO
2074
2075     elif p.type in {syms.factor, syms.star_expr}:
2076         # unary ops
2077         if not prev:
2078             prevp = preceding_leaf(p)
2079             if not prevp or prevp.type in OPENING_BRACKETS:
2080                 return NO
2081
2082             prevp_parent = prevp.parent
2083             assert prevp_parent is not None
2084             if prevp.type == token.COLON and prevp_parent.type in {
2085                 syms.subscript,
2086                 syms.sliceop,
2087             }:
2088                 return NO
2089
2090             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2091                 return NO
2092
2093         elif t in {token.NAME, token.NUMBER, token.STRING}:
2094             return NO
2095
2096     elif p.type == syms.import_from:
2097         if t == token.DOT:
2098             if prev and prev.type == token.DOT:
2099                 return NO
2100
2101         elif t == token.NAME:
2102             if v == "import":
2103                 return SPACE
2104
2105             if prev and prev.type == token.DOT:
2106                 return NO
2107
2108     elif p.type == syms.sliceop:
2109         return NO
2110
2111     return SPACE
2112
2113
2114 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2115     """Return the first leaf that precedes `node`, if any."""
2116     while node:
2117         res = node.prev_sibling
2118         if res:
2119             if isinstance(res, Leaf):
2120                 return res
2121
2122             try:
2123                 return list(res.leaves())[-1]
2124
2125             except IndexError:
2126                 return None
2127
2128         node = node.parent
2129     return None
2130
2131
2132 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2133     """Return the child of `ancestor` that contains `descendant`."""
2134     node: Optional[LN] = descendant
2135     while node and node.parent != ancestor:
2136         node = node.parent
2137     return node
2138
2139
2140 def container_of(leaf: Leaf) -> LN:
2141     """Return `leaf` or one of its ancestors that is the topmost container of it.
2142
2143     By "container" we mean a node where `leaf` is the very first child.
2144     """
2145     same_prefix = leaf.prefix
2146     container: LN = leaf
2147     while container:
2148         parent = container.parent
2149         if parent is None:
2150             break
2151
2152         if parent.children[0].prefix != same_prefix:
2153             break
2154
2155         if parent.type == syms.file_input:
2156             break
2157
2158         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2159             break
2160
2161         container = parent
2162     return container
2163
2164
2165 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2166     """Return the priority of the `leaf` delimiter, given a line break after it.
2167
2168     The delimiter priorities returned here are from those delimiters that would
2169     cause a line break after themselves.
2170
2171     Higher numbers are higher priority.
2172     """
2173     if leaf.type == token.COMMA:
2174         return COMMA_PRIORITY
2175
2176     return 0
2177
2178
2179 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2180     """Return the priority of the `leaf` delimiter, given a line break before it.
2181
2182     The delimiter priorities returned here are from those delimiters that would
2183     cause a line break before themselves.
2184
2185     Higher numbers are higher priority.
2186     """
2187     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2188         # * and ** might also be MATH_OPERATORS but in this case they are not.
2189         # Don't treat them as a delimiter.
2190         return 0
2191
2192     if (
2193         leaf.type == token.DOT
2194         and leaf.parent
2195         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2196         and (previous is None or previous.type in CLOSING_BRACKETS)
2197     ):
2198         return DOT_PRIORITY
2199
2200     if (
2201         leaf.type in MATH_OPERATORS
2202         and leaf.parent
2203         and leaf.parent.type not in {syms.factor, syms.star_expr}
2204     ):
2205         return MATH_PRIORITIES[leaf.type]
2206
2207     if leaf.type in COMPARATORS:
2208         return COMPARATOR_PRIORITY
2209
2210     if (
2211         leaf.type == token.STRING
2212         and previous is not None
2213         and previous.type == token.STRING
2214     ):
2215         return STRING_PRIORITY
2216
2217     if leaf.type not in {token.NAME, token.ASYNC}:
2218         return 0
2219
2220     if (
2221         leaf.value == "for"
2222         and leaf.parent
2223         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2224         or leaf.type == token.ASYNC
2225     ):
2226         if (
2227             not isinstance(leaf.prev_sibling, Leaf)
2228             or leaf.prev_sibling.value != "async"
2229         ):
2230             return COMPREHENSION_PRIORITY
2231
2232     if (
2233         leaf.value == "if"
2234         and leaf.parent
2235         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2236     ):
2237         return COMPREHENSION_PRIORITY
2238
2239     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2240         return TERNARY_PRIORITY
2241
2242     if leaf.value == "is":
2243         return COMPARATOR_PRIORITY
2244
2245     if (
2246         leaf.value == "in"
2247         and leaf.parent
2248         and leaf.parent.type in {syms.comp_op, syms.comparison}
2249         and not (
2250             previous is not None
2251             and previous.type == token.NAME
2252             and previous.value == "not"
2253         )
2254     ):
2255         return COMPARATOR_PRIORITY
2256
2257     if (
2258         leaf.value == "not"
2259         and leaf.parent
2260         and leaf.parent.type == syms.comp_op
2261         and not (
2262             previous is not None
2263             and previous.type == token.NAME
2264             and previous.value == "is"
2265         )
2266     ):
2267         return COMPARATOR_PRIORITY
2268
2269     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2270         return LOGIC_PRIORITY
2271
2272     return 0
2273
2274
2275 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2276 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2277
2278
2279 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2280     """Clean the prefix of the `leaf` and generate comments from it, if any.
2281
2282     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2283     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2284     move because it does away with modifying the grammar to include all the
2285     possible places in which comments can be placed.
2286
2287     The sad consequence for us though is that comments don't "belong" anywhere.
2288     This is why this function generates simple parentless Leaf objects for
2289     comments.  We simply don't know what the correct parent should be.
2290
2291     No matter though, we can live without this.  We really only need to
2292     differentiate between inline and standalone comments.  The latter don't
2293     share the line with any code.
2294
2295     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2296     are emitted with a fake STANDALONE_COMMENT token identifier.
2297     """
2298     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2299         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2300
2301
2302 @dataclass
2303 class ProtoComment:
2304     """Describes a piece of syntax that is a comment.
2305
2306     It's not a :class:`blib2to3.pytree.Leaf` so that:
2307
2308     * it can be cached (`Leaf` objects should not be reused more than once as
2309       they store their lineno, column, prefix, and parent information);
2310     * `newlines` and `consumed` fields are kept separate from the `value`. This
2311       simplifies handling of special marker comments like ``# fmt: off/on``.
2312     """
2313
2314     type: int  # token.COMMENT or STANDALONE_COMMENT
2315     value: str  # content of the comment
2316     newlines: int  # how many newlines before the comment
2317     consumed: int  # how many characters of the original leaf's prefix did we consume
2318
2319
2320 @lru_cache(maxsize=4096)
2321 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2322     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2323     result: List[ProtoComment] = []
2324     if not prefix or "#" not in prefix:
2325         return result
2326
2327     consumed = 0
2328     nlines = 0
2329     ignored_lines = 0
2330     for index, line in enumerate(prefix.split("\n")):
2331         consumed += len(line) + 1  # adding the length of the split '\n'
2332         line = line.lstrip()
2333         if not line:
2334             nlines += 1
2335         if not line.startswith("#"):
2336             # Escaped newlines outside of a comment are not really newlines at
2337             # all. We treat a single-line comment following an escaped newline
2338             # as a simple trailing comment.
2339             if line.endswith("\\"):
2340                 ignored_lines += 1
2341             continue
2342
2343         if index == ignored_lines and not is_endmarker:
2344             comment_type = token.COMMENT  # simple trailing comment
2345         else:
2346             comment_type = STANDALONE_COMMENT
2347         comment = make_comment(line)
2348         result.append(
2349             ProtoComment(
2350                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2351             )
2352         )
2353         nlines = 0
2354     return result
2355
2356
2357 def make_comment(content: str) -> str:
2358     """Return a consistently formatted comment from the given `content` string.
2359
2360     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2361     space between the hash sign and the content.
2362
2363     If `content` didn't start with a hash sign, one is provided.
2364     """
2365     content = content.rstrip()
2366     if not content:
2367         return "#"
2368
2369     if content[0] == "#":
2370         content = content[1:]
2371     if content and content[0] not in " !:#'%":
2372         content = " " + content
2373     return "#" + content
2374
2375
2376 def split_line(
2377     line: Line,
2378     line_length: int,
2379     inner: bool = False,
2380     features: Collection[Feature] = (),
2381 ) -> Iterator[Line]:
2382     """Split a `line` into potentially many lines.
2383
2384     They should fit in the allotted `line_length` but might not be able to.
2385     `inner` signifies that there were a pair of brackets somewhere around the
2386     current `line`, possibly transitively. This means we can fallback to splitting
2387     by delimiters if the LHS/RHS don't yield any results.
2388
2389     `features` are syntactical features that may be used in the output.
2390     """
2391     if line.is_comment:
2392         yield line
2393         return
2394
2395     line_str = str(line).strip("\n")
2396
2397     if (
2398         not line.contains_uncollapsable_type_comments()
2399         and not line.should_explode
2400         and not line.is_collection_with_optional_trailing_comma
2401         and (
2402             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2403             or line.contains_unsplittable_type_ignore()
2404         )
2405     ):
2406         yield line
2407         return
2408
2409     split_funcs: List[SplitFunc]
2410     if line.is_def:
2411         split_funcs = [left_hand_split]
2412     else:
2413
2414         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2415             for omit in generate_trailers_to_omit(line, line_length):
2416                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2417                 if is_line_short_enough(lines[0], line_length=line_length):
2418                     yield from lines
2419                     return
2420
2421             # All splits failed, best effort split with no omits.
2422             # This mostly happens to multiline strings that are by definition
2423             # reported as not fitting a single line.
2424             yield from right_hand_split(line, line_length, features=features)
2425
2426         if line.inside_brackets:
2427             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2428         else:
2429             split_funcs = [rhs]
2430     for split_func in split_funcs:
2431         # We are accumulating lines in `result` because we might want to abort
2432         # mission and return the original line in the end, or attempt a different
2433         # split altogether.
2434         result: List[Line] = []
2435         try:
2436             for l in split_func(line, features):
2437                 if str(l).strip("\n") == line_str:
2438                     raise CannotSplit("Split function returned an unchanged result")
2439
2440                 result.extend(
2441                     split_line(
2442                         l, line_length=line_length, inner=True, features=features
2443                     )
2444                 )
2445         except CannotSplit:
2446             continue
2447
2448         else:
2449             yield from result
2450             break
2451
2452     else:
2453         yield line
2454
2455
2456 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2457     """Split line into many lines, starting with the first matching bracket pair.
2458
2459     Note: this usually looks weird, only use this for function definitions.
2460     Prefer RHS otherwise.  This is why this function is not symmetrical with
2461     :func:`right_hand_split` which also handles optional parentheses.
2462     """
2463     tail_leaves: List[Leaf] = []
2464     body_leaves: List[Leaf] = []
2465     head_leaves: List[Leaf] = []
2466     current_leaves = head_leaves
2467     matching_bracket = None
2468     for leaf in line.leaves:
2469         if (
2470             current_leaves is body_leaves
2471             and leaf.type in CLOSING_BRACKETS
2472             and leaf.opening_bracket is matching_bracket
2473         ):
2474             current_leaves = tail_leaves if body_leaves else head_leaves
2475         current_leaves.append(leaf)
2476         if current_leaves is head_leaves:
2477             if leaf.type in OPENING_BRACKETS:
2478                 matching_bracket = leaf
2479                 current_leaves = body_leaves
2480     if not matching_bracket:
2481         raise CannotSplit("No brackets found")
2482
2483     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2484     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2485     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2486     bracket_split_succeeded_or_raise(head, body, tail)
2487     for result in (head, body, tail):
2488         if result:
2489             yield result
2490
2491
2492 def right_hand_split(
2493     line: Line,
2494     line_length: int,
2495     features: Collection[Feature] = (),
2496     omit: Collection[LeafID] = (),
2497 ) -> Iterator[Line]:
2498     """Split line into many lines, starting with the last matching bracket pair.
2499
2500     If the split was by optional parentheses, attempt splitting without them, too.
2501     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2502     this split.
2503
2504     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2505     """
2506     tail_leaves: List[Leaf] = []
2507     body_leaves: List[Leaf] = []
2508     head_leaves: List[Leaf] = []
2509     current_leaves = tail_leaves
2510     opening_bracket = None
2511     closing_bracket = None
2512     for leaf in reversed(line.leaves):
2513         if current_leaves is body_leaves:
2514             if leaf is opening_bracket:
2515                 current_leaves = head_leaves if body_leaves else tail_leaves
2516         current_leaves.append(leaf)
2517         if current_leaves is tail_leaves:
2518             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2519                 opening_bracket = leaf.opening_bracket
2520                 closing_bracket = leaf
2521                 current_leaves = body_leaves
2522     if not (opening_bracket and closing_bracket and head_leaves):
2523         # If there is no opening or closing_bracket that means the split failed and
2524         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2525         # the matching `opening_bracket` wasn't available on `line` anymore.
2526         raise CannotSplit("No brackets found")
2527
2528     tail_leaves.reverse()
2529     body_leaves.reverse()
2530     head_leaves.reverse()
2531     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2532     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2533     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2534     bracket_split_succeeded_or_raise(head, body, tail)
2535     if (
2536         # the body shouldn't be exploded
2537         not body.should_explode
2538         # the opening bracket is an optional paren
2539         and opening_bracket.type == token.LPAR
2540         and not opening_bracket.value
2541         # the closing bracket is an optional paren
2542         and closing_bracket.type == token.RPAR
2543         and not closing_bracket.value
2544         # it's not an import (optional parens are the only thing we can split on
2545         # in this case; attempting a split without them is a waste of time)
2546         and not line.is_import
2547         # there are no standalone comments in the body
2548         and not body.contains_standalone_comments(0)
2549         # and we can actually remove the parens
2550         and can_omit_invisible_parens(body, line_length)
2551     ):
2552         omit = {id(closing_bracket), *omit}
2553         try:
2554             yield from right_hand_split(line, line_length, features=features, omit=omit)
2555             return
2556
2557         except CannotSplit:
2558             if not (
2559                 can_be_split(body)
2560                 or is_line_short_enough(body, line_length=line_length)
2561             ):
2562                 raise CannotSplit(
2563                     "Splitting failed, body is still too long and can't be split."
2564                 )
2565
2566             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2567                 raise CannotSplit(
2568                     "The current optional pair of parentheses is bound to fail to "
2569                     "satisfy the splitting algorithm because the head or the tail "
2570                     "contains multiline strings which by definition never fit one "
2571                     "line."
2572                 )
2573
2574     ensure_visible(opening_bracket)
2575     ensure_visible(closing_bracket)
2576     for result in (head, body, tail):
2577         if result:
2578             yield result
2579
2580
2581 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2582     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2583
2584     Do nothing otherwise.
2585
2586     A left- or right-hand split is based on a pair of brackets. Content before
2587     (and including) the opening bracket is left on one line, content inside the
2588     brackets is put on a separate line, and finally content starting with and
2589     following the closing bracket is put on a separate line.
2590
2591     Those are called `head`, `body`, and `tail`, respectively. If the split
2592     produced the same line (all content in `head`) or ended up with an empty `body`
2593     and the `tail` is just the closing bracket, then it's considered failed.
2594     """
2595     tail_len = len(str(tail).strip())
2596     if not body:
2597         if tail_len == 0:
2598             raise CannotSplit("Splitting brackets produced the same line")
2599
2600         elif tail_len < 3:
2601             raise CannotSplit(
2602                 f"Splitting brackets on an empty body to save "
2603                 f"{tail_len} characters is not worth it"
2604             )
2605
2606
2607 def bracket_split_build_line(
2608     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2609 ) -> Line:
2610     """Return a new line with given `leaves` and respective comments from `original`.
2611
2612     If `is_body` is True, the result line is one-indented inside brackets and as such
2613     has its first leaf's prefix normalized and a trailing comma added when expected.
2614     """
2615     result = Line(depth=original.depth)
2616     if is_body:
2617         result.inside_brackets = True
2618         result.depth += 1
2619         if leaves:
2620             # Since body is a new indent level, remove spurious leading whitespace.
2621             normalize_prefix(leaves[0], inside_brackets=True)
2622             # Ensure a trailing comma for imports and standalone function arguments, but
2623             # be careful not to add one after any comments.
2624             no_commas = original.is_def and not any(
2625                 l.type == token.COMMA for l in leaves
2626             )
2627
2628             if original.is_import or no_commas:
2629                 for i in range(len(leaves) - 1, -1, -1):
2630                     if leaves[i].type == STANDALONE_COMMENT:
2631                         continue
2632                     elif leaves[i].type == token.COMMA:
2633                         break
2634                     else:
2635                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2636                         break
2637     # Populate the line
2638     for leaf in leaves:
2639         result.append(leaf, preformatted=True)
2640         for comment_after in original.comments_after(leaf):
2641             result.append(comment_after, preformatted=True)
2642     if is_body:
2643         result.should_explode = should_explode(result, opening_bracket)
2644     return result
2645
2646
2647 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2648     """Normalize prefix of the first leaf in every line returned by `split_func`.
2649
2650     This is a decorator over relevant split functions.
2651     """
2652
2653     @wraps(split_func)
2654     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2655         for l in split_func(line, features):
2656             normalize_prefix(l.leaves[0], inside_brackets=True)
2657             yield l
2658
2659     return split_wrapper
2660
2661
2662 @dont_increase_indentation
2663 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2664     """Split according to delimiters of the highest priority.
2665
2666     If the appropriate Features are given, the split will add trailing commas
2667     also in function signatures and calls that contain `*` and `**`.
2668     """
2669     try:
2670         last_leaf = line.leaves[-1]
2671     except IndexError:
2672         raise CannotSplit("Line empty")
2673
2674     bt = line.bracket_tracker
2675     try:
2676         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2677     except ValueError:
2678         raise CannotSplit("No delimiters found")
2679
2680     if delimiter_priority == DOT_PRIORITY:
2681         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2682             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2683
2684     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2685     lowest_depth = sys.maxsize
2686     trailing_comma_safe = True
2687
2688     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2689         """Append `leaf` to current line or to new line if appending impossible."""
2690         nonlocal current_line
2691         try:
2692             current_line.append_safe(leaf, preformatted=True)
2693         except ValueError:
2694             yield current_line
2695
2696             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2697             current_line.append(leaf)
2698
2699     for leaf in line.leaves:
2700         yield from append_to_line(leaf)
2701
2702         for comment_after in line.comments_after(leaf):
2703             yield from append_to_line(comment_after)
2704
2705         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2706         if leaf.bracket_depth == lowest_depth:
2707             if is_vararg(leaf, within={syms.typedargslist}):
2708                 trailing_comma_safe = (
2709                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2710                 )
2711             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2712                 trailing_comma_safe = (
2713                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2714                 )
2715
2716         leaf_priority = bt.delimiters.get(id(leaf))
2717         if leaf_priority == delimiter_priority:
2718             yield current_line
2719
2720             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2721     if current_line:
2722         if (
2723             trailing_comma_safe
2724             and delimiter_priority == COMMA_PRIORITY
2725             and current_line.leaves[-1].type != token.COMMA
2726             and current_line.leaves[-1].type != STANDALONE_COMMENT
2727         ):
2728             current_line.append(Leaf(token.COMMA, ","))
2729         yield current_line
2730
2731
2732 @dont_increase_indentation
2733 def standalone_comment_split(
2734     line: Line, features: Collection[Feature] = ()
2735 ) -> Iterator[Line]:
2736     """Split standalone comments from the rest of the line."""
2737     if not line.contains_standalone_comments(0):
2738         raise CannotSplit("Line does not have any standalone comments")
2739
2740     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2741
2742     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2743         """Append `leaf` to current line or to new line if appending impossible."""
2744         nonlocal current_line
2745         try:
2746             current_line.append_safe(leaf, preformatted=True)
2747         except ValueError:
2748             yield current_line
2749
2750             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2751             current_line.append(leaf)
2752
2753     for leaf in line.leaves:
2754         yield from append_to_line(leaf)
2755
2756         for comment_after in line.comments_after(leaf):
2757             yield from append_to_line(comment_after)
2758
2759     if current_line:
2760         yield current_line
2761
2762
2763 def is_import(leaf: Leaf) -> bool:
2764     """Return True if the given leaf starts an import statement."""
2765     p = leaf.parent
2766     t = leaf.type
2767     v = leaf.value
2768     return bool(
2769         t == token.NAME
2770         and (
2771             (v == "import" and p and p.type == syms.import_name)
2772             or (v == "from" and p and p.type == syms.import_from)
2773         )
2774     )
2775
2776
2777 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2778     """Return True if the given leaf is a special comment.
2779     Only returns true for type comments for now."""
2780     t = leaf.type
2781     v = leaf.value
2782     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith(
2783         "# type:" + suffix
2784     )
2785
2786
2787 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2788     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2789     else.
2790
2791     Note: don't use backslashes for formatting or you'll lose your voting rights.
2792     """
2793     if not inside_brackets:
2794         spl = leaf.prefix.split("#")
2795         if "\\" not in spl[0]:
2796             nl_count = spl[-1].count("\n")
2797             if len(spl) > 1:
2798                 nl_count -= 1
2799             leaf.prefix = "\n" * nl_count
2800             return
2801
2802     leaf.prefix = ""
2803
2804
2805 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2806     """Make all string prefixes lowercase.
2807
2808     If remove_u_prefix is given, also removes any u prefix from the string.
2809
2810     Note: Mutates its argument.
2811     """
2812     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2813     assert match is not None, f"failed to match string {leaf.value!r}"
2814     orig_prefix = match.group(1)
2815     new_prefix = orig_prefix.lower()
2816     if remove_u_prefix:
2817         new_prefix = new_prefix.replace("u", "")
2818     leaf.value = f"{new_prefix}{match.group(2)}"
2819
2820
2821 def normalize_string_quotes(leaf: Leaf) -> None:
2822     """Prefer double quotes but only if it doesn't cause more escaping.
2823
2824     Adds or removes backslashes as appropriate. Doesn't parse and fix
2825     strings nested in f-strings (yet).
2826
2827     Note: Mutates its argument.
2828     """
2829     value = leaf.value.lstrip("furbFURB")
2830     if value[:3] == '"""':
2831         return
2832
2833     elif value[:3] == "'''":
2834         orig_quote = "'''"
2835         new_quote = '"""'
2836     elif value[0] == '"':
2837         orig_quote = '"'
2838         new_quote = "'"
2839     else:
2840         orig_quote = "'"
2841         new_quote = '"'
2842     first_quote_pos = leaf.value.find(orig_quote)
2843     if first_quote_pos == -1:
2844         return  # There's an internal error
2845
2846     prefix = leaf.value[:first_quote_pos]
2847     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2848     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2849     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2850     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2851     if "r" in prefix.casefold():
2852         if unescaped_new_quote.search(body):
2853             # There's at least one unescaped new_quote in this raw string
2854             # so converting is impossible
2855             return
2856
2857         # Do not introduce or remove backslashes in raw strings
2858         new_body = body
2859     else:
2860         # remove unnecessary escapes
2861         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2862         if body != new_body:
2863             # Consider the string without unnecessary escapes as the original
2864             body = new_body
2865             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2866         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2867         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2868     if "f" in prefix.casefold():
2869         matches = re.findall(
2870             r"""
2871             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2872                 ([^{].*?)  # contents of the brackets except if begins with {{
2873             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2874             """,
2875             new_body,
2876             re.VERBOSE,
2877         )
2878         for m in matches:
2879             if "\\" in str(m):
2880                 # Do not introduce backslashes in interpolated expressions
2881                 return
2882     if new_quote == '"""' and new_body[-1:] == '"':
2883         # edge case:
2884         new_body = new_body[:-1] + '\\"'
2885     orig_escape_count = body.count("\\")
2886     new_escape_count = new_body.count("\\")
2887     if new_escape_count > orig_escape_count:
2888         return  # Do not introduce more escaping
2889
2890     if new_escape_count == orig_escape_count and orig_quote == '"':
2891         return  # Prefer double quotes
2892
2893     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2894
2895
2896 def normalize_numeric_literal(leaf: Leaf) -> None:
2897     """Normalizes numeric (float, int, and complex) literals.
2898
2899     All letters used in the representation are normalized to lowercase (except
2900     in Python 2 long literals).
2901     """
2902     text = leaf.value.lower()
2903     if text.startswith(("0o", "0b")):
2904         # Leave octal and binary literals alone.
2905         pass
2906     elif text.startswith("0x"):
2907         # Change hex literals to upper case.
2908         before, after = text[:2], text[2:]
2909         text = f"{before}{after.upper()}"
2910     elif "e" in text:
2911         before, after = text.split("e")
2912         sign = ""
2913         if after.startswith("-"):
2914             after = after[1:]
2915             sign = "-"
2916         elif after.startswith("+"):
2917             after = after[1:]
2918         before = format_float_or_int_string(before)
2919         text = f"{before}e{sign}{after}"
2920     elif text.endswith(("j", "l")):
2921         number = text[:-1]
2922         suffix = text[-1]
2923         # Capitalize in "2L" because "l" looks too similar to "1".
2924         if suffix == "l":
2925             suffix = "L"
2926         text = f"{format_float_or_int_string(number)}{suffix}"
2927     else:
2928         text = format_float_or_int_string(text)
2929     leaf.value = text
2930
2931
2932 def format_float_or_int_string(text: str) -> str:
2933     """Formats a float string like "1.0"."""
2934     if "." not in text:
2935         return text
2936
2937     before, after = text.split(".")
2938     return f"{before or 0}.{after or 0}"
2939
2940
2941 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2942     """Make existing optional parentheses invisible or create new ones.
2943
2944     `parens_after` is a set of string leaf values immediately after which parens
2945     should be put.
2946
2947     Standardizes on visible parentheses for single-element tuples, and keeps
2948     existing visible parentheses for other tuples and generator expressions.
2949     """
2950     for pc in list_comments(node.prefix, is_endmarker=False):
2951         if pc.value in FMT_OFF:
2952             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2953             return
2954
2955     check_lpar = False
2956     for index, child in enumerate(list(node.children)):
2957         # Add parentheses around long tuple unpacking in assignments.
2958         if (
2959             index == 0
2960             and isinstance(child, Node)
2961             and child.type == syms.testlist_star_expr
2962         ):
2963             check_lpar = True
2964
2965         if check_lpar:
2966             if is_walrus_assignment(child):
2967                 continue
2968             if child.type == syms.atom:
2969                 # Determines if the underlying atom should be surrounded with
2970                 # invisible params - also makes parens invisible recursively
2971                 # within the atom and removes repeated invisible parens within
2972                 # the atom
2973                 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2974                     child, parent=node
2975                 )
2976
2977                 if should_surround_with_parens:
2978                     lpar = Leaf(token.LPAR, "")
2979                     rpar = Leaf(token.RPAR, "")
2980                     index = child.remove() or 0
2981                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2982             elif is_one_tuple(child):
2983                 # wrap child in visible parentheses
2984                 lpar = Leaf(token.LPAR, "(")
2985                 rpar = Leaf(token.RPAR, ")")
2986                 child.remove()
2987                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2988             elif node.type == syms.import_from:
2989                 # "import from" nodes store parentheses directly as part of
2990                 # the statement
2991                 if child.type == token.LPAR:
2992                     # make parentheses invisible
2993                     child.value = ""  # type: ignore
2994                     node.children[-1].value = ""  # type: ignore
2995                 elif child.type != token.STAR:
2996                     # insert invisible parentheses
2997                     node.insert_child(index, Leaf(token.LPAR, ""))
2998                     node.append_child(Leaf(token.RPAR, ""))
2999                 break
3000
3001             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
3002                 # wrap child in invisible parentheses
3003                 lpar = Leaf(token.LPAR, "")
3004                 rpar = Leaf(token.RPAR, "")
3005                 index = child.remove() or 0
3006                 prefix = child.prefix
3007                 child.prefix = ""
3008                 new_child = Node(syms.atom, [lpar, child, rpar])
3009                 new_child.prefix = prefix
3010                 node.insert_child(index, new_child)
3011
3012         check_lpar = isinstance(child, Leaf) and child.value in parens_after
3013
3014
3015 def normalize_fmt_off(node: Node) -> None:
3016     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
3017     try_again = True
3018     while try_again:
3019         try_again = convert_one_fmt_off_pair(node)
3020
3021
3022 def convert_one_fmt_off_pair(node: Node) -> bool:
3023     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
3024
3025     Returns True if a pair was converted.
3026     """
3027     for leaf in node.leaves():
3028         previous_consumed = 0
3029         for comment in list_comments(leaf.prefix, is_endmarker=False):
3030             if comment.value in FMT_OFF:
3031                 # We only want standalone comments. If there's no previous leaf or
3032                 # the previous leaf is indentation, it's a standalone comment in
3033                 # disguise.
3034                 if comment.type != STANDALONE_COMMENT:
3035                     prev = preceding_leaf(leaf)
3036                     if prev and prev.type not in WHITESPACE:
3037                         continue
3038
3039                 ignored_nodes = list(generate_ignored_nodes(leaf))
3040                 if not ignored_nodes:
3041                     continue
3042
3043                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
3044                 parent = first.parent
3045                 prefix = first.prefix
3046                 first.prefix = prefix[comment.consumed :]
3047                 hidden_value = (
3048                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3049                 )
3050                 if hidden_value.endswith("\n"):
3051                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
3052                     # leaf (possibly followed by a DEDENT).
3053                     hidden_value = hidden_value[:-1]
3054                 first_idx = None
3055                 for ignored in ignored_nodes:
3056                     index = ignored.remove()
3057                     if first_idx is None:
3058                         first_idx = index
3059                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3060                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3061                 parent.insert_child(
3062                     first_idx,
3063                     Leaf(
3064                         STANDALONE_COMMENT,
3065                         hidden_value,
3066                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3067                     ),
3068                 )
3069                 return True
3070
3071             previous_consumed = comment.consumed
3072
3073     return False
3074
3075
3076 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3077     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3078
3079     Stops at the end of the block.
3080     """
3081     container: Optional[LN] = container_of(leaf)
3082     while container is not None and container.type != token.ENDMARKER:
3083         for comment in list_comments(container.prefix, is_endmarker=False):
3084             if comment.value in FMT_ON:
3085                 return
3086
3087         yield container
3088
3089         container = container.next_sibling
3090
3091
3092 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3093     """If it's safe, make the parens in the atom `node` invisible, recursively.
3094     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3095     as they are redundant.
3096
3097     Returns whether the node should itself be wrapped in invisible parentheses.
3098
3099     """
3100     if (
3101         node.type != syms.atom
3102         or is_empty_tuple(node)
3103         or is_one_tuple(node)
3104         or (is_yield(node) and parent.type != syms.expr_stmt)
3105         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3106     ):
3107         return False
3108
3109     first = node.children[0]
3110     last = node.children[-1]
3111     if first.type == token.LPAR and last.type == token.RPAR:
3112         middle = node.children[1]
3113         # make parentheses invisible
3114         first.value = ""  # type: ignore
3115         last.value = ""  # type: ignore
3116         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3117
3118         if is_atom_with_invisible_parens(middle):
3119             # Strip the invisible parens from `middle` by replacing
3120             # it with the child in-between the invisible parens
3121             middle.replace(middle.children[1])
3122
3123         return False
3124
3125     return True
3126
3127
3128 def is_atom_with_invisible_parens(node: LN) -> bool:
3129     """Given a `LN`, determines whether it's an atom `node` with invisible
3130     parens. Useful in dedupe-ing and normalizing parens.
3131     """
3132     if isinstance(node, Leaf) or node.type != syms.atom:
3133         return False
3134
3135     first, last = node.children[0], node.children[-1]
3136     return (
3137         isinstance(first, Leaf)
3138         and first.type == token.LPAR
3139         and first.value == ""
3140         and isinstance(last, Leaf)
3141         and last.type == token.RPAR
3142         and last.value == ""
3143     )
3144
3145
3146 def is_empty_tuple(node: LN) -> bool:
3147     """Return True if `node` holds an empty tuple."""
3148     return (
3149         node.type == syms.atom
3150         and len(node.children) == 2
3151         and node.children[0].type == token.LPAR
3152         and node.children[1].type == token.RPAR
3153     )
3154
3155
3156 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3157     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3158
3159     Parenthesis can be optional. Returns None otherwise"""
3160     if len(node.children) != 3:
3161         return None
3162     lpar, wrapped, rpar = node.children
3163     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3164         return None
3165
3166     return wrapped
3167
3168
3169 def is_one_tuple(node: LN) -> bool:
3170     """Return True if `node` holds a tuple with one element, with or without parens."""
3171     if node.type == syms.atom:
3172         gexp = unwrap_singleton_parenthesis(node)
3173         if gexp is None or gexp.type != syms.testlist_gexp:
3174             return False
3175
3176         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3177
3178     return (
3179         node.type in IMPLICIT_TUPLE
3180         and len(node.children) == 2
3181         and node.children[1].type == token.COMMA
3182     )
3183
3184
3185 def is_walrus_assignment(node: LN) -> bool:
3186     """Return True iff `node` is of the shape ( test := test )"""
3187     inner = unwrap_singleton_parenthesis(node)
3188     return inner is not None and inner.type == syms.namedexpr_test
3189
3190
3191 def is_yield(node: LN) -> bool:
3192     """Return True if `node` holds a `yield` or `yield from` expression."""
3193     if node.type == syms.yield_expr:
3194         return True
3195
3196     if node.type == token.NAME and node.value == "yield":  # type: ignore
3197         return True
3198
3199     if node.type != syms.atom:
3200         return False
3201
3202     if len(node.children) != 3:
3203         return False
3204
3205     lpar, expr, rpar = node.children
3206     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3207         return is_yield(expr)
3208
3209     return False
3210
3211
3212 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3213     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3214
3215     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3216     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3217     extended iterable unpacking (PEP 3132) and additional unpacking
3218     generalizations (PEP 448).
3219     """
3220     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3221         return False
3222
3223     p = leaf.parent
3224     if p.type == syms.star_expr:
3225         # Star expressions are also used as assignment targets in extended
3226         # iterable unpacking (PEP 3132).  See what its parent is instead.
3227         if not p.parent:
3228             return False
3229
3230         p = p.parent
3231
3232     return p.type in within
3233
3234
3235 def is_multiline_string(leaf: Leaf) -> bool:
3236     """Return True if `leaf` is a multiline string that actually spans many lines."""
3237     value = leaf.value.lstrip("furbFURB")
3238     return value[:3] in {'"""', "'''"} and "\n" in value
3239
3240
3241 def is_stub_suite(node: Node) -> bool:
3242     """Return True if `node` is a suite with a stub body."""
3243     if (
3244         len(node.children) != 4
3245         or node.children[0].type != token.NEWLINE
3246         or node.children[1].type != token.INDENT
3247         or node.children[3].type != token.DEDENT
3248     ):
3249         return False
3250
3251     return is_stub_body(node.children[2])
3252
3253
3254 def is_stub_body(node: LN) -> bool:
3255     """Return True if `node` is a simple statement containing an ellipsis."""
3256     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3257         return False
3258
3259     if len(node.children) != 2:
3260         return False
3261
3262     child = node.children[0]
3263     return (
3264         child.type == syms.atom
3265         and len(child.children) == 3
3266         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3267     )
3268
3269
3270 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3271     """Return maximum delimiter priority inside `node`.
3272
3273     This is specific to atoms with contents contained in a pair of parentheses.
3274     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3275     """
3276     if node.type != syms.atom:
3277         return 0
3278
3279     first = node.children[0]
3280     last = node.children[-1]
3281     if not (first.type == token.LPAR and last.type == token.RPAR):
3282         return 0
3283
3284     bt = BracketTracker()
3285     for c in node.children[1:-1]:
3286         if isinstance(c, Leaf):
3287             bt.mark(c)
3288         else:
3289             for leaf in c.leaves():
3290                 bt.mark(leaf)
3291     try:
3292         return bt.max_delimiter_priority()
3293
3294     except ValueError:
3295         return 0
3296
3297
3298 def ensure_visible(leaf: Leaf) -> None:
3299     """Make sure parentheses are visible.
3300
3301     They could be invisible as part of some statements (see
3302     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3303     """
3304     if leaf.type == token.LPAR:
3305         leaf.value = "("
3306     elif leaf.type == token.RPAR:
3307         leaf.value = ")"
3308
3309
3310 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3311     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3312
3313     if not (
3314         opening_bracket.parent
3315         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3316         and opening_bracket.value in "[{("
3317     ):
3318         return False
3319
3320     try:
3321         last_leaf = line.leaves[-1]
3322         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3323         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3324     except (IndexError, ValueError):
3325         return False
3326
3327     return max_priority == COMMA_PRIORITY
3328
3329
3330 def get_features_used(node: Node) -> Set[Feature]:
3331     """Return a set of (relatively) new Python features used in this file.
3332
3333     Currently looking for:
3334     - f-strings;
3335     - underscores in numeric literals;
3336     - trailing commas after * or ** in function signatures and calls;
3337     - positional only arguments in function signatures and lambdas;
3338     """
3339     features: Set[Feature] = set()
3340     for n in node.pre_order():
3341         if n.type == token.STRING:
3342             value_head = n.value[:2]  # type: ignore
3343             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3344                 features.add(Feature.F_STRINGS)
3345
3346         elif n.type == token.NUMBER:
3347             if "_" in n.value:  # type: ignore
3348                 features.add(Feature.NUMERIC_UNDERSCORES)
3349
3350         elif n.type == token.SLASH:
3351             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3352                 features.add(Feature.POS_ONLY_ARGUMENTS)
3353
3354         elif n.type == token.COLONEQUAL:
3355             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3356
3357         elif (
3358             n.type in {syms.typedargslist, syms.arglist}
3359             and n.children
3360             and n.children[-1].type == token.COMMA
3361         ):
3362             if n.type == syms.typedargslist:
3363                 feature = Feature.TRAILING_COMMA_IN_DEF
3364             else:
3365                 feature = Feature.TRAILING_COMMA_IN_CALL
3366
3367             for ch in n.children:
3368                 if ch.type in STARS:
3369                     features.add(feature)
3370
3371                 if ch.type == syms.argument:
3372                     for argch in ch.children:
3373                         if argch.type in STARS:
3374                             features.add(feature)
3375
3376     return features
3377
3378
3379 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3380     """Detect the version to target based on the nodes used."""
3381     features = get_features_used(node)
3382     return {
3383         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3384     }
3385
3386
3387 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3388     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3389
3390     Brackets can be omitted if the entire trailer up to and including
3391     a preceding closing bracket fits in one line.
3392
3393     Yielded sets are cumulative (contain results of previous yields, too).  First
3394     set is empty.
3395     """
3396
3397     omit: Set[LeafID] = set()
3398     yield omit
3399
3400     length = 4 * line.depth
3401     opening_bracket = None
3402     closing_bracket = None
3403     inner_brackets: Set[LeafID] = set()
3404     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3405         length += leaf_length
3406         if length > line_length:
3407             break
3408
3409         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3410         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3411             break
3412
3413         if opening_bracket:
3414             if leaf is opening_bracket:
3415                 opening_bracket = None
3416             elif leaf.type in CLOSING_BRACKETS:
3417                 inner_brackets.add(id(leaf))
3418         elif leaf.type in CLOSING_BRACKETS:
3419             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3420                 # Empty brackets would fail a split so treat them as "inner"
3421                 # brackets (e.g. only add them to the `omit` set if another
3422                 # pair of brackets was good enough.
3423                 inner_brackets.add(id(leaf))
3424                 continue
3425
3426             if closing_bracket:
3427                 omit.add(id(closing_bracket))
3428                 omit.update(inner_brackets)
3429                 inner_brackets.clear()
3430                 yield omit
3431
3432             if leaf.value:
3433                 opening_bracket = leaf.opening_bracket
3434                 closing_bracket = leaf
3435
3436
3437 def get_future_imports(node: Node) -> Set[str]:
3438     """Return a set of __future__ imports in the file."""
3439     imports: Set[str] = set()
3440
3441     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3442         for child in children:
3443             if isinstance(child, Leaf):
3444                 if child.type == token.NAME:
3445                     yield child.value
3446             elif child.type == syms.import_as_name:
3447                 orig_name = child.children[0]
3448                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3449                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3450                 yield orig_name.value
3451             elif child.type == syms.import_as_names:
3452                 yield from get_imports_from_children(child.children)
3453             else:
3454                 raise AssertionError("Invalid syntax parsing imports")
3455
3456     for child in node.children:
3457         if child.type != syms.simple_stmt:
3458             break
3459         first_child = child.children[0]
3460         if isinstance(first_child, Leaf):
3461             # Continue looking if we see a docstring; otherwise stop.
3462             if (
3463                 len(child.children) == 2
3464                 and first_child.type == token.STRING
3465                 and child.children[1].type == token.NEWLINE
3466             ):
3467                 continue
3468             else:
3469                 break
3470         elif first_child.type == syms.import_from:
3471             module_name = first_child.children[1]
3472             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3473                 break
3474             imports |= set(get_imports_from_children(first_child.children[3:]))
3475         else:
3476             break
3477     return imports
3478
3479
3480 def gen_python_files_in_dir(
3481     path: Path,
3482     root: Path,
3483     include: Pattern[str],
3484     exclude: Pattern[str],
3485     report: "Report",
3486 ) -> Iterator[Path]:
3487     """Generate all files under `path` whose paths are not excluded by the
3488     `exclude` regex, but are included by the `include` regex.
3489
3490     Symbolic links pointing outside of the `root` directory are ignored.
3491
3492     `report` is where output about exclusions goes.
3493     """
3494     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3495     for child in path.iterdir():
3496         try:
3497             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3498         except ValueError:
3499             if child.is_symlink():
3500                 report.path_ignored(
3501                     child, f"is a symbolic link that points outside {root}"
3502                 )
3503                 continue
3504
3505             raise
3506
3507         if child.is_dir():
3508             normalized_path += "/"
3509         exclude_match = exclude.search(normalized_path)
3510         if exclude_match and exclude_match.group(0):
3511             report.path_ignored(child, f"matches the --exclude regular expression")
3512             continue
3513
3514         if child.is_dir():
3515             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3516
3517         elif child.is_file():
3518             include_match = include.search(normalized_path)
3519             if include_match:
3520                 yield child
3521
3522
3523 @lru_cache()
3524 def find_project_root(srcs: Iterable[str]) -> Path:
3525     """Return a directory containing .git, .hg, or pyproject.toml.
3526
3527     That directory can be one of the directories passed in `srcs` or their
3528     common parent.
3529
3530     If no directory in the tree contains a marker that would specify it's the
3531     project root, the root of the file system is returned.
3532     """
3533     if not srcs:
3534         return Path("/").resolve()
3535
3536     common_base = min(Path(src).resolve() for src in srcs)
3537     if common_base.is_dir():
3538         # Append a fake file so `parents` below returns `common_base_dir`, too.
3539         common_base /= "fake-file"
3540     for directory in common_base.parents:
3541         if (directory / ".git").is_dir():
3542             return directory
3543
3544         if (directory / ".hg").is_dir():
3545             return directory
3546
3547         if (directory / "pyproject.toml").is_file():
3548             return directory
3549
3550     return directory
3551
3552
3553 @dataclass
3554 class Report:
3555     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3556
3557     check: bool = False
3558     quiet: bool = False
3559     verbose: bool = False
3560     change_count: int = 0
3561     same_count: int = 0
3562     failure_count: int = 0
3563
3564     def done(self, src: Path, changed: Changed) -> None:
3565         """Increment the counter for successful reformatting. Write out a message."""
3566         if changed is Changed.YES:
3567             reformatted = "would reformat" if self.check else "reformatted"
3568             if self.verbose or not self.quiet:
3569                 out(f"{reformatted} {src}")
3570             self.change_count += 1
3571         else:
3572             if self.verbose:
3573                 if changed is Changed.NO:
3574                     msg = f"{src} already well formatted, good job."
3575                 else:
3576                     msg = f"{src} wasn't modified on disk since last run."
3577                 out(msg, bold=False)
3578             self.same_count += 1
3579
3580     def failed(self, src: Path, message: str) -> None:
3581         """Increment the counter for failed reformatting. Write out a message."""
3582         err(f"error: cannot format {src}: {message}")
3583         self.failure_count += 1
3584
3585     def path_ignored(self, path: Path, message: str) -> None:
3586         if self.verbose:
3587             out(f"{path} ignored: {message}", bold=False)
3588
3589     @property
3590     def return_code(self) -> int:
3591         """Return the exit code that the app should use.
3592
3593         This considers the current state of changed files and failures:
3594         - if there were any failures, return 123;
3595         - if any files were changed and --check is being used, return 1;
3596         - otherwise return 0.
3597         """
3598         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3599         # 126 we have special return codes reserved by the shell.
3600         if self.failure_count:
3601             return 123
3602
3603         elif self.change_count and self.check:
3604             return 1
3605
3606         return 0
3607
3608     def __str__(self) -> str:
3609         """Render a color report of the current state.
3610
3611         Use `click.unstyle` to remove colors.
3612         """
3613         if self.check:
3614             reformatted = "would be reformatted"
3615             unchanged = "would be left unchanged"
3616             failed = "would fail to reformat"
3617         else:
3618             reformatted = "reformatted"
3619             unchanged = "left unchanged"
3620             failed = "failed to reformat"
3621         report = []
3622         if self.change_count:
3623             s = "s" if self.change_count > 1 else ""
3624             report.append(
3625                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3626             )
3627         if self.same_count:
3628             s = "s" if self.same_count > 1 else ""
3629             report.append(f"{self.same_count} file{s} {unchanged}")
3630         if self.failure_count:
3631             s = "s" if self.failure_count > 1 else ""
3632             report.append(
3633                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3634             )
3635         return ", ".join(report) + "."
3636
3637
3638 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3639     filename = "<unknown>"
3640     if sys.version_info >= (3, 8):
3641         # TODO: support Python 4+ ;)
3642         for minor_version in range(sys.version_info[1], 4, -1):
3643             try:
3644                 return ast.parse(src, filename, feature_version=(3, minor_version))
3645             except SyntaxError:
3646                 continue
3647     else:
3648         for feature_version in (7, 6):
3649             try:
3650                 return ast3.parse(src, filename, feature_version=feature_version)
3651             except SyntaxError:
3652                 continue
3653
3654     return ast27.parse(src)
3655
3656
3657 def _fixup_ast_constants(
3658     node: Union[ast.AST, ast3.AST, ast27.AST]
3659 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3660     """Map ast nodes deprecated in 3.8 to Constant."""
3661     # casts are required until this is released:
3662     # https://github.com/python/typeshed/pull/3142
3663     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3664         return cast(ast.AST, ast.Constant(value=node.s))
3665     elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3666         return cast(ast.AST, ast.Constant(value=node.n))
3667     elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3668         return cast(ast.AST, ast.Constant(value=node.value))
3669     return node
3670
3671
3672 def assert_equivalent(src: str, dst: str) -> None:
3673     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3674
3675     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3676         """Simple visitor generating strings to compare ASTs by content."""
3677
3678         node = _fixup_ast_constants(node)
3679
3680         yield f"{'  ' * depth}{node.__class__.__name__}("
3681
3682         for field in sorted(node._fields):
3683             # TypeIgnore has only one field 'lineno' which breaks this comparison
3684             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3685             if sys.version_info >= (3, 8):
3686                 type_ignore_classes += (ast.TypeIgnore,)
3687             if isinstance(node, type_ignore_classes):
3688                 break
3689
3690             try:
3691                 value = getattr(node, field)
3692             except AttributeError:
3693                 continue
3694
3695             yield f"{'  ' * (depth+1)}{field}="
3696
3697             if isinstance(value, list):
3698                 for item in value:
3699                     # Ignore nested tuples within del statements, because we may insert
3700                     # parentheses and they change the AST.
3701                     if (
3702                         field == "targets"
3703                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3704                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3705                     ):
3706                         for item in item.elts:
3707                             yield from _v(item, depth + 2)
3708                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3709                         yield from _v(item, depth + 2)
3710
3711             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3712                 yield from _v(value, depth + 2)
3713
3714             else:
3715                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3716
3717         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3718
3719     try:
3720         src_ast = parse_ast(src)
3721     except Exception as exc:
3722         raise AssertionError(
3723             f"cannot use --safe with this file; failed to parse source file.  "
3724             f"AST error message: {exc}"
3725         )
3726
3727     try:
3728         dst_ast = parse_ast(dst)
3729     except Exception as exc:
3730         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3731         raise AssertionError(
3732             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3733             f"Please report a bug on https://github.com/psf/black/issues.  "
3734             f"This invalid output might be helpful: {log}"
3735         ) from None
3736
3737     src_ast_str = "\n".join(_v(src_ast))
3738     dst_ast_str = "\n".join(_v(dst_ast))
3739     if src_ast_str != dst_ast_str:
3740         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3741         raise AssertionError(
3742             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3743             f"the source.  "
3744             f"Please report a bug on https://github.com/psf/black/issues.  "
3745             f"This diff might be helpful: {log}"
3746         ) from None
3747
3748
3749 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3750     """Raise AssertionError if `dst` reformats differently the second time."""
3751     newdst = format_str(dst, mode=mode)
3752     if dst != newdst:
3753         log = dump_to_file(
3754             diff(src, dst, "source", "first pass"),
3755             diff(dst, newdst, "first pass", "second pass"),
3756         )
3757         raise AssertionError(
3758             f"INTERNAL ERROR: Black produced different code on the second pass "
3759             f"of the formatter.  "
3760             f"Please report a bug on https://github.com/psf/black/issues.  "
3761             f"This diff might be helpful: {log}"
3762         ) from None
3763
3764
3765 def dump_to_file(*output: str) -> str:
3766     """Dump `output` to a temporary file. Return path to the file."""
3767     with tempfile.NamedTemporaryFile(
3768         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3769     ) as f:
3770         for lines in output:
3771             f.write(lines)
3772             if lines and lines[-1] != "\n":
3773                 f.write("\n")
3774     return f.name
3775
3776
3777 @contextmanager
3778 def nullcontext() -> Iterator[None]:
3779     """Return context manager that does nothing.
3780     Similar to `nullcontext` from python 3.7"""
3781     yield
3782
3783
3784 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3785     """Return a unified diff string between strings `a` and `b`."""
3786     import difflib
3787
3788     a_lines = [line + "\n" for line in a.split("\n")]
3789     b_lines = [line + "\n" for line in b.split("\n")]
3790     return "".join(
3791         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3792     )
3793
3794
3795 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3796     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3797     err("Aborted!")
3798     for task in tasks:
3799         task.cancel()
3800
3801
3802 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3803     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3804     try:
3805         if sys.version_info[:2] >= (3, 7):
3806             all_tasks = asyncio.all_tasks
3807         else:
3808             all_tasks = asyncio.Task.all_tasks
3809         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3810         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3811         if not to_cancel:
3812             return
3813
3814         for task in to_cancel:
3815             task.cancel()
3816         loop.run_until_complete(
3817             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3818         )
3819     finally:
3820         # `concurrent.futures.Future` objects cannot be cancelled once they
3821         # are already running. There might be some when the `shutdown()` happened.
3822         # Silence their logger's spew about the event loop being closed.
3823         cf_logger = logging.getLogger("concurrent.futures")
3824         cf_logger.setLevel(logging.CRITICAL)
3825         loop.close()
3826
3827
3828 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3829     """Replace `regex` with `replacement` twice on `original`.
3830
3831     This is used by string normalization to perform replaces on
3832     overlapping matches.
3833     """
3834     return regex.sub(replacement, regex.sub(replacement, original))
3835
3836
3837 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3838     """Compile a regular expression string in `regex`.
3839
3840     If it contains newlines, use verbose mode.
3841     """
3842     if "\n" in regex:
3843         regex = "(?x)" + regex
3844     compiled: Pattern[str] = re.compile(regex)
3845     return compiled
3846
3847
3848 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3849     """Like `reversed(enumerate(sequence))` if that were possible."""
3850     index = len(sequence) - 1
3851     for element in reversed(sequence):
3852         yield (index, element)
3853         index -= 1
3854
3855
3856 def enumerate_with_length(
3857     line: Line, reversed: bool = False
3858 ) -> Iterator[Tuple[Index, Leaf, int]]:
3859     """Return an enumeration of leaves with their length.
3860
3861     Stops prematurely on multiline strings and standalone comments.
3862     """
3863     op = cast(
3864         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3865         enumerate_reversed if reversed else enumerate,
3866     )
3867     for index, leaf in op(line.leaves):
3868         length = len(leaf.prefix) + len(leaf.value)
3869         if "\n" in leaf.value:
3870             return  # Multiline strings, we can't continue.
3871
3872         for comment in line.comments_after(leaf):
3873             length += len(comment.value)
3874
3875         yield index, leaf, length
3876
3877
3878 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3879     """Return True if `line` is no longer than `line_length`.
3880
3881     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3882     """
3883     if not line_str:
3884         line_str = str(line).strip("\n")
3885     return (
3886         len(line_str) <= line_length
3887         and "\n" not in line_str  # multiline strings
3888         and not line.contains_standalone_comments()
3889     )
3890
3891
3892 def can_be_split(line: Line) -> bool:
3893     """Return False if the line cannot be split *for sure*.
3894
3895     This is not an exhaustive search but a cheap heuristic that we can use to
3896     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3897     in unnecessary parentheses).
3898     """
3899     leaves = line.leaves
3900     if len(leaves) < 2:
3901         return False
3902
3903     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3904         call_count = 0
3905         dot_count = 0
3906         next = leaves[-1]
3907         for leaf in leaves[-2::-1]:
3908             if leaf.type in OPENING_BRACKETS:
3909                 if next.type not in CLOSING_BRACKETS:
3910                     return False
3911
3912                 call_count += 1
3913             elif leaf.type == token.DOT:
3914                 dot_count += 1
3915             elif leaf.type == token.NAME:
3916                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3917                     return False
3918
3919             elif leaf.type not in CLOSING_BRACKETS:
3920                 return False
3921
3922             if dot_count > 1 and call_count > 1:
3923                 return False
3924
3925     return True
3926
3927
3928 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3929     """Does `line` have a shape safe to reformat without optional parens around it?
3930
3931     Returns True for only a subset of potentially nice looking formattings but
3932     the point is to not return false positives that end up producing lines that
3933     are too long.
3934     """
3935     bt = line.bracket_tracker
3936     if not bt.delimiters:
3937         # Without delimiters the optional parentheses are useless.
3938         return True
3939
3940     max_priority = bt.max_delimiter_priority()
3941     if bt.delimiter_count_with_priority(max_priority) > 1:
3942         # With more than one delimiter of a kind the optional parentheses read better.
3943         return False
3944
3945     if max_priority == DOT_PRIORITY:
3946         # A single stranded method call doesn't require optional parentheses.
3947         return True
3948
3949     assert len(line.leaves) >= 2, "Stranded delimiter"
3950
3951     first = line.leaves[0]
3952     second = line.leaves[1]
3953     penultimate = line.leaves[-2]
3954     last = line.leaves[-1]
3955
3956     # With a single delimiter, omit if the expression starts or ends with
3957     # a bracket.
3958     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3959         remainder = False
3960         length = 4 * line.depth
3961         for _index, leaf, leaf_length in enumerate_with_length(line):
3962             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3963                 remainder = True
3964             if remainder:
3965                 length += leaf_length
3966                 if length > line_length:
3967                     break
3968
3969                 if leaf.type in OPENING_BRACKETS:
3970                     # There are brackets we can further split on.
3971                     remainder = False
3972
3973         else:
3974             # checked the entire string and line length wasn't exceeded
3975             if len(line.leaves) == _index + 1:
3976                 return True
3977
3978         # Note: we are not returning False here because a line might have *both*
3979         # a leading opening bracket and a trailing closing bracket.  If the
3980         # opening bracket doesn't match our rule, maybe the closing will.
3981
3982     if (
3983         last.type == token.RPAR
3984         or last.type == token.RBRACE
3985         or (
3986             # don't use indexing for omitting optional parentheses;
3987             # it looks weird
3988             last.type == token.RSQB
3989             and last.parent
3990             and last.parent.type != syms.trailer
3991         )
3992     ):
3993         if penultimate.type in OPENING_BRACKETS:
3994             # Empty brackets don't help.
3995             return False
3996
3997         if is_multiline_string(first):
3998             # Additional wrapping of a multiline string in this situation is
3999             # unnecessary.
4000             return True
4001
4002         length = 4 * line.depth
4003         seen_other_brackets = False
4004         for _index, leaf, leaf_length in enumerate_with_length(line):
4005             length += leaf_length
4006             if leaf is last.opening_bracket:
4007                 if seen_other_brackets or length <= line_length:
4008                     return True
4009
4010             elif leaf.type in OPENING_BRACKETS:
4011                 # There are brackets we can further split on.
4012                 seen_other_brackets = True
4013
4014     return False
4015
4016
4017 def get_cache_file(mode: FileMode) -> Path:
4018     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
4019
4020
4021 def read_cache(mode: FileMode) -> Cache:
4022     """Read the cache if it exists and is well formed.
4023
4024     If it is not well formed, the call to write_cache later should resolve the issue.
4025     """
4026     cache_file = get_cache_file(mode)
4027     if not cache_file.exists():
4028         return {}
4029
4030     with cache_file.open("rb") as fobj:
4031         try:
4032             cache: Cache = pickle.load(fobj)
4033         except pickle.UnpicklingError:
4034             return {}
4035
4036     return cache
4037
4038
4039 def get_cache_info(path: Path) -> CacheInfo:
4040     """Return the information used to check if a file is already formatted or not."""
4041     stat = path.stat()
4042     return stat.st_mtime, stat.st_size
4043
4044
4045 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4046     """Split an iterable of paths in `sources` into two sets.
4047
4048     The first contains paths of files that modified on disk or are not in the
4049     cache. The other contains paths to non-modified files.
4050     """
4051     todo, done = set(), set()
4052     for src in sources:
4053         src = src.resolve()
4054         if cache.get(src) != get_cache_info(src):
4055             todo.add(src)
4056         else:
4057             done.add(src)
4058     return todo, done
4059
4060
4061 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4062     """Update the cache file."""
4063     cache_file = get_cache_file(mode)
4064     try:
4065         CACHE_DIR.mkdir(parents=True, exist_ok=True)
4066         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4067         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4068             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
4069         os.replace(f.name, cache_file)
4070     except OSError:
4071         pass
4072
4073
4074 def patch_click() -> None:
4075     """Make Click not crash.
4076
4077     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4078     default which restricts paths that it can access during the lifetime of the
4079     application.  Click refuses to work in this scenario by raising a RuntimeError.
4080
4081     In case of Black the likelihood that non-ASCII characters are going to be used in
4082     file paths is minimal since it's Python source code.  Moreover, this crash was
4083     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4084     """
4085     try:
4086         from click import core
4087         from click import _unicodefun  # type: ignore
4088     except ModuleNotFoundError:
4089         return
4090
4091     for module in (core, _unicodefun):
4092         if hasattr(module, "_verify_python3_env"):
4093             module._verify_python3_env = lambda: None
4094
4095
4096 def patched_main() -> None:
4097     freeze_support()
4098     patch_click()
4099     main()
4100
4101
4102 if __name__ == "__main__":
4103     patched_main()