]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

a4f36887aecc3630136cb4ed040b3768b814564a
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40
41 from appdirs import user_cache_dir
42 from attr import dataclass, evolve, Factory
43 import click
44 import toml
45 from typed_ast import ast3, ast27
46
47 # lib2to3 fork
48 from blib2to3.pytree import Node, Leaf, type_repr
49 from blib2to3 import pygram, pytree
50 from blib2to3.pgen2 import driver, token
51 from blib2to3.pgen2.grammar import Grammar
52 from blib2to3.pgen2.parse import ParseError
53
54 from _version import get_versions
55
56 v = get_versions()
57 __version__ = v.get("closest-tag", v["version"])
58 __git_version__ = v.get("full-revisionid")
59
60 DEFAULT_LINE_LENGTH = 88
61 DEFAULT_EXCLUDES = (
62     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
63 )
64 DEFAULT_INCLUDES = r"\.pyi?$"
65 CACHE_DIR = Path(user_cache_dir("black", version=__git_version__))
66
67
68 # types
69 FileContent = str
70 Encoding = str
71 NewLine = str
72 Depth = int
73 NodeType = int
74 LeafID = int
75 Priority = int
76 Index = int
77 LN = Union[Leaf, Node]
78 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
79 Timestamp = float
80 FileSize = int
81 CacheInfo = Tuple[Timestamp, FileSize]
82 Cache = Dict[Path, CacheInfo]
83 out = partial(click.secho, bold=True, err=True)
84 err = partial(click.secho, fg="red", err=True)
85
86 pygram.initialize(CACHE_DIR)
87 syms = pygram.python_symbols
88
89
90 class NothingChanged(UserWarning):
91     """Raised when reformatted code is the same as source."""
92
93
94 class CannotSplit(Exception):
95     """A readable split that fits the allotted line length is impossible."""
96
97
98 class InvalidInput(ValueError):
99     """Raised when input source code fails all parse attempts."""
100
101
102 class WriteBack(Enum):
103     NO = 0
104     YES = 1
105     DIFF = 2
106     CHECK = 3
107
108     @classmethod
109     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
110         if check and not diff:
111             return cls.CHECK
112
113         return cls.DIFF if diff else cls.YES
114
115
116 class Changed(Enum):
117     NO = 0
118     CACHED = 1
119     YES = 2
120
121
122 class TargetVersion(Enum):
123     PY27 = 2
124     PY33 = 3
125     PY34 = 4
126     PY35 = 5
127     PY36 = 6
128     PY37 = 7
129     PY38 = 8
130
131     def is_python2(self) -> bool:
132         return self is TargetVersion.PY27
133
134
135 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
136
137
138 class Feature(Enum):
139     # All string literals are unicode
140     UNICODE_LITERALS = 1
141     F_STRINGS = 2
142     NUMERIC_UNDERSCORES = 3
143     TRAILING_COMMA_IN_CALL = 4
144     TRAILING_COMMA_IN_DEF = 5
145     # The following two feature-flags are mutually exclusive, and exactly one should be
146     # set for every version of python.
147     ASYNC_IDENTIFIERS = 6
148     ASYNC_KEYWORDS = 7
149     ASSIGNMENT_EXPRESSIONS = 8
150     POS_ONLY_ARGUMENTS = 9
151
152
153 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
154     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
155     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
156     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
157     TargetVersion.PY35: {
158         Feature.UNICODE_LITERALS,
159         Feature.TRAILING_COMMA_IN_CALL,
160         Feature.ASYNC_IDENTIFIERS,
161     },
162     TargetVersion.PY36: {
163         Feature.UNICODE_LITERALS,
164         Feature.F_STRINGS,
165         Feature.NUMERIC_UNDERSCORES,
166         Feature.TRAILING_COMMA_IN_CALL,
167         Feature.TRAILING_COMMA_IN_DEF,
168         Feature.ASYNC_IDENTIFIERS,
169     },
170     TargetVersion.PY37: {
171         Feature.UNICODE_LITERALS,
172         Feature.F_STRINGS,
173         Feature.NUMERIC_UNDERSCORES,
174         Feature.TRAILING_COMMA_IN_CALL,
175         Feature.TRAILING_COMMA_IN_DEF,
176         Feature.ASYNC_KEYWORDS,
177     },
178     TargetVersion.PY38: {
179         Feature.UNICODE_LITERALS,
180         Feature.F_STRINGS,
181         Feature.NUMERIC_UNDERSCORES,
182         Feature.TRAILING_COMMA_IN_CALL,
183         Feature.TRAILING_COMMA_IN_DEF,
184         Feature.ASYNC_KEYWORDS,
185         Feature.ASSIGNMENT_EXPRESSIONS,
186         Feature.POS_ONLY_ARGUMENTS,
187     },
188 }
189
190
191 @dataclass
192 class FileMode:
193     target_versions: Set[TargetVersion] = Factory(set)
194     line_length: int = DEFAULT_LINE_LENGTH
195     string_normalization: bool = True
196     is_pyi: bool = False
197
198     def get_cache_key(self) -> str:
199         if self.target_versions:
200             version_str = ",".join(
201                 str(version.value)
202                 for version in sorted(self.target_versions, key=lambda v: v.value)
203             )
204         else:
205             version_str = "-"
206         parts = [
207             version_str,
208             str(self.line_length),
209             str(int(self.string_normalization)),
210             str(int(self.is_pyi)),
211         ]
212         return ".".join(parts)
213
214
215 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
216     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
217
218
219 def read_pyproject_toml(
220     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
221 ) -> Optional[str]:
222     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
223
224     Returns the path to a successfully found and read configuration file, None
225     otherwise.
226     """
227     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
228     if not value:
229         root = find_project_root(ctx.params.get("src", ()))
230         path = root / "pyproject.toml"
231         if path.is_file():
232             value = str(path)
233         else:
234             return None
235
236     try:
237         pyproject_toml = toml.load(value)
238         config = pyproject_toml.get("tool", {}).get("black", {})
239     except (toml.TomlDecodeError, OSError) as e:
240         raise click.FileError(
241             filename=value, hint=f"Error reading configuration file: {e}"
242         )
243
244     if not config:
245         return None
246
247     if ctx.default_map is None:
248         ctx.default_map = {}
249     ctx.default_map.update(  # type: ignore  # bad types in .pyi
250         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
251     )
252     return value
253
254
255 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
256 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
257 @click.option(
258     "-l",
259     "--line-length",
260     type=int,
261     default=DEFAULT_LINE_LENGTH,
262     help="How many characters per line to allow.",
263     show_default=True,
264 )
265 @click.option(
266     "-t",
267     "--target-version",
268     type=click.Choice([v.name.lower() for v in TargetVersion]),
269     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
270     multiple=True,
271     help=(
272         "Python versions that should be supported by Black's output. [default: "
273         "per-file auto-detection]"
274     ),
275 )
276 @click.option(
277     "--py36",
278     is_flag=True,
279     help=(
280         "Allow using Python 3.6-only syntax on all input files.  This will put "
281         "trailing commas in function signatures and calls also after *args and "
282         "**kwargs. Deprecated; use --target-version instead. "
283         "[default: per-file auto-detection]"
284     ),
285 )
286 @click.option(
287     "--pyi",
288     is_flag=True,
289     help=(
290         "Format all input files like typing stubs regardless of file extension "
291         "(useful when piping source on standard input)."
292     ),
293 )
294 @click.option(
295     "-S",
296     "--skip-string-normalization",
297     is_flag=True,
298     help="Don't normalize string quotes or prefixes.",
299 )
300 @click.option(
301     "--check",
302     is_flag=True,
303     help=(
304         "Don't write the files back, just return the status.  Return code 0 "
305         "means nothing would change.  Return code 1 means some files would be "
306         "reformatted.  Return code 123 means there was an internal error."
307     ),
308 )
309 @click.option(
310     "--diff",
311     is_flag=True,
312     help="Don't write the files back, just output a diff for each file on stdout.",
313 )
314 @click.option(
315     "--fast/--safe",
316     is_flag=True,
317     help="If --fast given, skip temporary sanity checks. [default: --safe]",
318 )
319 @click.option(
320     "--include",
321     type=str,
322     default=DEFAULT_INCLUDES,
323     help=(
324         "A regular expression that matches files and directories that should be "
325         "included on recursive searches.  An empty value means all files are "
326         "included regardless of the name.  Use forward slashes for directories on "
327         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
328         "later."
329     ),
330     show_default=True,
331 )
332 @click.option(
333     "--exclude",
334     type=str,
335     default=DEFAULT_EXCLUDES,
336     help=(
337         "A regular expression that matches files and directories that should be "
338         "excluded on recursive searches.  An empty value means no paths are excluded. "
339         "Use forward slashes for directories on all platforms (Windows, too).  "
340         "Exclusions are calculated first, inclusions later."
341     ),
342     show_default=True,
343 )
344 @click.option(
345     "-q",
346     "--quiet",
347     is_flag=True,
348     help=(
349         "Don't emit non-error messages to stderr. Errors are still emitted; "
350         "silence those with 2>/dev/null."
351     ),
352 )
353 @click.option(
354     "-v",
355     "--verbose",
356     is_flag=True,
357     help=(
358         "Also emit messages to stderr about files that were not changed or were "
359         "ignored due to --exclude=."
360     ),
361 )
362 @click.version_option(version=__version__)
363 @click.argument(
364     "src",
365     nargs=-1,
366     type=click.Path(
367         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
368     ),
369     is_eager=True,
370 )
371 @click.option(
372     "--config",
373     type=click.Path(
374         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
375     ),
376     is_eager=True,
377     callback=read_pyproject_toml,
378     help="Read configuration from PATH.",
379 )
380 @click.pass_context
381 def main(
382     ctx: click.Context,
383     code: Optional[str],
384     line_length: int,
385     target_version: List[TargetVersion],
386     check: bool,
387     diff: bool,
388     fast: bool,
389     pyi: bool,
390     py36: bool,
391     skip_string_normalization: bool,
392     quiet: bool,
393     verbose: bool,
394     include: str,
395     exclude: str,
396     src: Tuple[str],
397     config: Optional[str],
398 ) -> None:
399     """The uncompromising code formatter."""
400     write_back = WriteBack.from_configuration(check=check, diff=diff)
401     if target_version:
402         if py36:
403             err(f"Cannot use both --target-version and --py36")
404             ctx.exit(2)
405         else:
406             versions = set(target_version)
407     elif py36:
408         err(
409             "--py36 is deprecated and will be removed in a future version. "
410             "Use --target-version py36 instead."
411         )
412         versions = PY36_VERSIONS
413     else:
414         # We'll autodetect later.
415         versions = set()
416     mode = FileMode(
417         target_versions=versions,
418         line_length=line_length,
419         is_pyi=pyi,
420         string_normalization=not skip_string_normalization,
421     )
422     if config and verbose:
423         out(f"Using configuration from {config}.", bold=False, fg="blue")
424     if code is not None:
425         print(format_str(code, mode=mode))
426         ctx.exit(0)
427     try:
428         include_regex = re_compile_maybe_verbose(include)
429     except re.error:
430         err(f"Invalid regular expression for include given: {include!r}")
431         ctx.exit(2)
432     try:
433         exclude_regex = re_compile_maybe_verbose(exclude)
434     except re.error:
435         err(f"Invalid regular expression for exclude given: {exclude!r}")
436         ctx.exit(2)
437     report = Report(check=check, quiet=quiet, verbose=verbose)
438     root = find_project_root(src)
439     sources: Set[Path] = set()
440     path_empty(src, quiet, verbose, ctx)
441     for s in src:
442         p = Path(s)
443         if p.is_dir():
444             sources.update(
445                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
446             )
447         elif p.is_file() or s == "-":
448             # if a file was explicitly given, we don't care about its extension
449             sources.add(p)
450         else:
451             err(f"invalid path: {s}")
452     if len(sources) == 0:
453         if verbose or not quiet:
454             out("No Python files are present to be formatted. Nothing to do 😴")
455         ctx.exit(0)
456
457     if len(sources) == 1:
458         reformat_one(
459             src=sources.pop(),
460             fast=fast,
461             write_back=write_back,
462             mode=mode,
463             report=report,
464         )
465     else:
466         reformat_many(
467             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
468         )
469
470     if verbose or not quiet:
471         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
472         click.secho(str(report), err=True)
473     ctx.exit(report.return_code)
474
475
476 def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
477     """
478     Exit if there is no `src` provided for formatting
479     """
480     if not src:
481         if verbose or not quiet:
482             out("No Path provided. Nothing to do 😴")
483             ctx.exit(0)
484
485
486 def reformat_one(
487     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
488 ) -> None:
489     """Reformat a single file under `src` without spawning child processes.
490
491     `fast`, `write_back`, and `mode` options are passed to
492     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
493     """
494     try:
495         changed = Changed.NO
496         if not src.is_file() and str(src) == "-":
497             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
498                 changed = Changed.YES
499         else:
500             cache: Cache = {}
501             if write_back != WriteBack.DIFF:
502                 cache = read_cache(mode)
503                 res_src = src.resolve()
504                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
505                     changed = Changed.CACHED
506             if changed is not Changed.CACHED and format_file_in_place(
507                 src, fast=fast, write_back=write_back, mode=mode
508             ):
509                 changed = Changed.YES
510             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
511                 write_back is WriteBack.CHECK and changed is Changed.NO
512             ):
513                 write_cache(cache, [src], mode)
514         report.done(src, changed)
515     except Exception as exc:
516         report.failed(src, str(exc))
517
518
519 def reformat_many(
520     sources: Set[Path],
521     fast: bool,
522     write_back: WriteBack,
523     mode: FileMode,
524     report: "Report",
525 ) -> None:
526     """Reformat multiple files using a ProcessPoolExecutor."""
527     loop = asyncio.get_event_loop()
528     worker_count = os.cpu_count()
529     if sys.platform == "win32":
530         # Work around https://bugs.python.org/issue26903
531         worker_count = min(worker_count, 61)
532     executor = ProcessPoolExecutor(max_workers=worker_count)
533     try:
534         loop.run_until_complete(
535             schedule_formatting(
536                 sources=sources,
537                 fast=fast,
538                 write_back=write_back,
539                 mode=mode,
540                 report=report,
541                 loop=loop,
542                 executor=executor,
543             )
544         )
545     finally:
546         shutdown(loop)
547         executor.shutdown()
548
549
550 async def schedule_formatting(
551     sources: Set[Path],
552     fast: bool,
553     write_back: WriteBack,
554     mode: FileMode,
555     report: "Report",
556     loop: asyncio.AbstractEventLoop,
557     executor: Executor,
558 ) -> None:
559     """Run formatting of `sources` in parallel using the provided `executor`.
560
561     (Use ProcessPoolExecutors for actual parallelism.)
562
563     `write_back`, `fast`, and `mode` options are passed to
564     :func:`format_file_in_place`.
565     """
566     cache: Cache = {}
567     if write_back != WriteBack.DIFF:
568         cache = read_cache(mode)
569         sources, cached = filter_cached(cache, sources)
570         for src in sorted(cached):
571             report.done(src, Changed.CACHED)
572     if not sources:
573         return
574
575     cancelled = []
576     sources_to_cache = []
577     lock = None
578     if write_back == WriteBack.DIFF:
579         # For diff output, we need locks to ensure we don't interleave output
580         # from different processes.
581         manager = Manager()
582         lock = manager.Lock()
583     tasks = {
584         asyncio.ensure_future(
585             loop.run_in_executor(
586                 executor, format_file_in_place, src, fast, mode, write_back, lock
587             )
588         ): src
589         for src in sorted(sources)
590     }
591     pending: Iterable[asyncio.Future] = tasks.keys()
592     try:
593         loop.add_signal_handler(signal.SIGINT, cancel, pending)
594         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
595     except NotImplementedError:
596         # There are no good alternatives for these on Windows.
597         pass
598     while pending:
599         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
600         for task in done:
601             src = tasks.pop(task)
602             if task.cancelled():
603                 cancelled.append(task)
604             elif task.exception():
605                 report.failed(src, str(task.exception()))
606             else:
607                 changed = Changed.YES if task.result() else Changed.NO
608                 # If the file was written back or was successfully checked as
609                 # well-formatted, store this information in the cache.
610                 if write_back is WriteBack.YES or (
611                     write_back is WriteBack.CHECK and changed is Changed.NO
612                 ):
613                     sources_to_cache.append(src)
614                 report.done(src, changed)
615     if cancelled:
616         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
617     if sources_to_cache:
618         write_cache(cache, sources_to_cache, mode)
619
620
621 def format_file_in_place(
622     src: Path,
623     fast: bool,
624     mode: FileMode,
625     write_back: WriteBack = WriteBack.NO,
626     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
627 ) -> bool:
628     """Format file under `src` path. Return True if changed.
629
630     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
631     code to the file.
632     `mode` and `fast` options are passed to :func:`format_file_contents`.
633     """
634     if src.suffix == ".pyi":
635         mode = evolve(mode, is_pyi=True)
636
637     then = datetime.utcfromtimestamp(src.stat().st_mtime)
638     with open(src, "rb") as buf:
639         src_contents, encoding, newline = decode_bytes(buf.read())
640     try:
641         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
642     except NothingChanged:
643         return False
644
645     if write_back == write_back.YES:
646         with open(src, "w", encoding=encoding, newline=newline) as f:
647             f.write(dst_contents)
648     elif write_back == write_back.DIFF:
649         now = datetime.utcnow()
650         src_name = f"{src}\t{then} +0000"
651         dst_name = f"{src}\t{now} +0000"
652         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
653
654         with lock or nullcontext():
655             f = io.TextIOWrapper(
656                 sys.stdout.buffer,
657                 encoding=encoding,
658                 newline=newline,
659                 write_through=True,
660             )
661             f.write(diff_contents)
662             f.detach()
663
664     return True
665
666
667 def format_stdin_to_stdout(
668     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
669 ) -> bool:
670     """Format file on stdin. Return True if changed.
671
672     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
673     write a diff to stdout. The `mode` argument is passed to
674     :func:`format_file_contents`.
675     """
676     then = datetime.utcnow()
677     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
678     dst = src
679     try:
680         dst = format_file_contents(src, fast=fast, mode=mode)
681         return True
682
683     except NothingChanged:
684         return False
685
686     finally:
687         f = io.TextIOWrapper(
688             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
689         )
690         if write_back == WriteBack.YES:
691             f.write(dst)
692         elif write_back == WriteBack.DIFF:
693             now = datetime.utcnow()
694             src_name = f"STDIN\t{then} +0000"
695             dst_name = f"STDOUT\t{now} +0000"
696             f.write(diff(src, dst, src_name, dst_name))
697         f.detach()
698
699
700 def format_file_contents(
701     src_contents: str, *, fast: bool, mode: FileMode
702 ) -> FileContent:
703     """Reformat contents a file and return new contents.
704
705     If `fast` is False, additionally confirm that the reformatted code is
706     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
707     `mode` is passed to :func:`format_str`.
708     """
709     if src_contents.strip() == "":
710         raise NothingChanged
711
712     dst_contents = format_str(src_contents, mode=mode)
713     if src_contents == dst_contents:
714         raise NothingChanged
715
716     if not fast:
717         assert_equivalent(src_contents, dst_contents)
718         assert_stable(src_contents, dst_contents, mode=mode)
719     return dst_contents
720
721
722 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
723     """Reformat a string and return new contents.
724
725     `mode` determines formatting options, such as how many characters per line are
726     allowed.
727     """
728     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
729     dst_contents = []
730     future_imports = get_future_imports(src_node)
731     if mode.target_versions:
732         versions = mode.target_versions
733     else:
734         versions = detect_target_versions(src_node)
735     normalize_fmt_off(src_node)
736     lines = LineGenerator(
737         remove_u_prefix="unicode_literals" in future_imports
738         or supports_feature(versions, Feature.UNICODE_LITERALS),
739         is_pyi=mode.is_pyi,
740         normalize_strings=mode.string_normalization,
741     )
742     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
743     empty_line = Line()
744     after = 0
745     split_line_features = {
746         feature
747         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
748         if supports_feature(versions, feature)
749     }
750     for current_line in lines.visit(src_node):
751         for _ in range(after):
752             dst_contents.append(str(empty_line))
753         before, after = elt.maybe_empty_lines(current_line)
754         for _ in range(before):
755             dst_contents.append(str(empty_line))
756         for line in split_line(
757             current_line, line_length=mode.line_length, features=split_line_features
758         ):
759             dst_contents.append(str(line))
760     return "".join(dst_contents)
761
762
763 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
764     """Return a tuple of (decoded_contents, encoding, newline).
765
766     `newline` is either CRLF or LF but `decoded_contents` is decoded with
767     universal newlines (i.e. only contains LF).
768     """
769     srcbuf = io.BytesIO(src)
770     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
771     if not lines:
772         return "", encoding, "\n"
773
774     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
775     srcbuf.seek(0)
776     with io.TextIOWrapper(srcbuf, encoding) as tiow:
777         return tiow.read(), encoding, newline
778
779
780 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
781     if not target_versions:
782         # No target_version specified, so try all grammars.
783         return [
784             # Python 3.7+
785             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
786             # Python 3.0-3.6
787             pygram.python_grammar_no_print_statement_no_exec_statement,
788             # Python 2.7 with future print_function import
789             pygram.python_grammar_no_print_statement,
790             # Python 2.7
791             pygram.python_grammar,
792         ]
793     elif all(version.is_python2() for version in target_versions):
794         # Python 2-only code, so try Python 2 grammars.
795         return [
796             # Python 2.7 with future print_function import
797             pygram.python_grammar_no_print_statement,
798             # Python 2.7
799             pygram.python_grammar,
800         ]
801     else:
802         # Python 3-compatible code, so only try Python 3 grammar.
803         grammars = []
804         # If we have to parse both, try to parse async as a keyword first
805         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
806             # Python 3.7+
807             grammars.append(
808                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
809             )
810         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
811             # Python 3.0-3.6
812             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
813         # At least one of the above branches must have been taken, because every Python
814         # version has exactly one of the two 'ASYNC_*' flags
815         return grammars
816
817
818 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
819     """Given a string with source, return the lib2to3 Node."""
820     if src_txt[-1:] != "\n":
821         src_txt += "\n"
822
823     for grammar in get_grammars(set(target_versions)):
824         drv = driver.Driver(grammar, pytree.convert)
825         try:
826             result = drv.parse_string(src_txt, True)
827             break
828
829         except ParseError as pe:
830             lineno, column = pe.context[1]
831             lines = src_txt.splitlines()
832             try:
833                 faulty_line = lines[lineno - 1]
834             except IndexError:
835                 faulty_line = "<line number missing in source>"
836             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
837     else:
838         raise exc from None
839
840     if isinstance(result, Leaf):
841         result = Node(syms.file_input, [result])
842     return result
843
844
845 def lib2to3_unparse(node: Node) -> str:
846     """Given a lib2to3 node, return its string representation."""
847     code = str(node)
848     return code
849
850
851 T = TypeVar("T")
852
853
854 class Visitor(Generic[T]):
855     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
856
857     def visit(self, node: LN) -> Iterator[T]:
858         """Main method to visit `node` and its children.
859
860         It tries to find a `visit_*()` method for the given `node.type`, like
861         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
862         If no dedicated `visit_*()` method is found, chooses `visit_default()`
863         instead.
864
865         Then yields objects of type `T` from the selected visitor.
866         """
867         if node.type < 256:
868             name = token.tok_name[node.type]
869         else:
870             name = type_repr(node.type)
871         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
872
873     def visit_default(self, node: LN) -> Iterator[T]:
874         """Default `visit_*()` implementation. Recurses to children of `node`."""
875         if isinstance(node, Node):
876             for child in node.children:
877                 yield from self.visit(child)
878
879
880 @dataclass
881 class DebugVisitor(Visitor[T]):
882     tree_depth: int = 0
883
884     def visit_default(self, node: LN) -> Iterator[T]:
885         indent = " " * (2 * self.tree_depth)
886         if isinstance(node, Node):
887             _type = type_repr(node.type)
888             out(f"{indent}{_type}", fg="yellow")
889             self.tree_depth += 1
890             for child in node.children:
891                 yield from self.visit(child)
892
893             self.tree_depth -= 1
894             out(f"{indent}/{_type}", fg="yellow", bold=False)
895         else:
896             _type = token.tok_name.get(node.type, str(node.type))
897             out(f"{indent}{_type}", fg="blue", nl=False)
898             if node.prefix:
899                 # We don't have to handle prefixes for `Node` objects since
900                 # that delegates to the first child anyway.
901                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
902             out(f" {node.value!r}", fg="blue", bold=False)
903
904     @classmethod
905     def show(cls, code: Union[str, Leaf, Node]) -> None:
906         """Pretty-print the lib2to3 AST of a given string of `code`.
907
908         Convenience method for debugging.
909         """
910         v: DebugVisitor[None] = DebugVisitor()
911         if isinstance(code, str):
912             code = lib2to3_parse(code)
913         list(v.visit(code))
914
915
916 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
917 STATEMENT = {
918     syms.if_stmt,
919     syms.while_stmt,
920     syms.for_stmt,
921     syms.try_stmt,
922     syms.except_clause,
923     syms.with_stmt,
924     syms.funcdef,
925     syms.classdef,
926 }
927 STANDALONE_COMMENT = 153
928 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
929 LOGIC_OPERATORS = {"and", "or"}
930 COMPARATORS = {
931     token.LESS,
932     token.GREATER,
933     token.EQEQUAL,
934     token.NOTEQUAL,
935     token.LESSEQUAL,
936     token.GREATEREQUAL,
937 }
938 MATH_OPERATORS = {
939     token.VBAR,
940     token.CIRCUMFLEX,
941     token.AMPER,
942     token.LEFTSHIFT,
943     token.RIGHTSHIFT,
944     token.PLUS,
945     token.MINUS,
946     token.STAR,
947     token.SLASH,
948     token.DOUBLESLASH,
949     token.PERCENT,
950     token.AT,
951     token.TILDE,
952     token.DOUBLESTAR,
953 }
954 STARS = {token.STAR, token.DOUBLESTAR}
955 VARARGS_SPECIALS = STARS | {token.SLASH}
956 VARARGS_PARENTS = {
957     syms.arglist,
958     syms.argument,  # double star in arglist
959     syms.trailer,  # single argument to call
960     syms.typedargslist,
961     syms.varargslist,  # lambdas
962 }
963 UNPACKING_PARENTS = {
964     syms.atom,  # single element of a list or set literal
965     syms.dictsetmaker,
966     syms.listmaker,
967     syms.testlist_gexp,
968     syms.testlist_star_expr,
969 }
970 TEST_DESCENDANTS = {
971     syms.test,
972     syms.lambdef,
973     syms.or_test,
974     syms.and_test,
975     syms.not_test,
976     syms.comparison,
977     syms.star_expr,
978     syms.expr,
979     syms.xor_expr,
980     syms.and_expr,
981     syms.shift_expr,
982     syms.arith_expr,
983     syms.trailer,
984     syms.term,
985     syms.power,
986 }
987 ASSIGNMENTS = {
988     "=",
989     "+=",
990     "-=",
991     "*=",
992     "@=",
993     "/=",
994     "%=",
995     "&=",
996     "|=",
997     "^=",
998     "<<=",
999     ">>=",
1000     "**=",
1001     "//=",
1002 }
1003 COMPREHENSION_PRIORITY = 20
1004 COMMA_PRIORITY = 18
1005 TERNARY_PRIORITY = 16
1006 LOGIC_PRIORITY = 14
1007 STRING_PRIORITY = 12
1008 COMPARATOR_PRIORITY = 10
1009 MATH_PRIORITIES = {
1010     token.VBAR: 9,
1011     token.CIRCUMFLEX: 8,
1012     token.AMPER: 7,
1013     token.LEFTSHIFT: 6,
1014     token.RIGHTSHIFT: 6,
1015     token.PLUS: 5,
1016     token.MINUS: 5,
1017     token.STAR: 4,
1018     token.SLASH: 4,
1019     token.DOUBLESLASH: 4,
1020     token.PERCENT: 4,
1021     token.AT: 4,
1022     token.TILDE: 3,
1023     token.DOUBLESTAR: 2,
1024 }
1025 DOT_PRIORITY = 1
1026
1027
1028 @dataclass
1029 class BracketTracker:
1030     """Keeps track of brackets on a line."""
1031
1032     depth: int = 0
1033     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1034     delimiters: Dict[LeafID, Priority] = Factory(dict)
1035     previous: Optional[Leaf] = None
1036     _for_loop_depths: List[int] = Factory(list)
1037     _lambda_argument_depths: List[int] = Factory(list)
1038
1039     def mark(self, leaf: Leaf) -> None:
1040         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1041
1042         All leaves receive an int `bracket_depth` field that stores how deep
1043         within brackets a given leaf is. 0 means there are no enclosing brackets
1044         that started on this line.
1045
1046         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1047         field that it forms a pair with. This is a one-directional link to
1048         avoid reference cycles.
1049
1050         If a leaf is a delimiter (a token on which Black can split the line if
1051         needed) and it's on depth 0, its `id()` is stored in the tracker's
1052         `delimiters` field.
1053         """
1054         if leaf.type == token.COMMENT:
1055             return
1056
1057         self.maybe_decrement_after_for_loop_variable(leaf)
1058         self.maybe_decrement_after_lambda_arguments(leaf)
1059         if leaf.type in CLOSING_BRACKETS:
1060             self.depth -= 1
1061             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1062             leaf.opening_bracket = opening_bracket
1063         leaf.bracket_depth = self.depth
1064         if self.depth == 0:
1065             delim = is_split_before_delimiter(leaf, self.previous)
1066             if delim and self.previous is not None:
1067                 self.delimiters[id(self.previous)] = delim
1068             else:
1069                 delim = is_split_after_delimiter(leaf, self.previous)
1070                 if delim:
1071                     self.delimiters[id(leaf)] = delim
1072         if leaf.type in OPENING_BRACKETS:
1073             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1074             self.depth += 1
1075         self.previous = leaf
1076         self.maybe_increment_lambda_arguments(leaf)
1077         self.maybe_increment_for_loop_variable(leaf)
1078
1079     def any_open_brackets(self) -> bool:
1080         """Return True if there is an yet unmatched open bracket on the line."""
1081         return bool(self.bracket_match)
1082
1083     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1084         """Return the highest priority of a delimiter found on the line.
1085
1086         Values are consistent with what `is_split_*_delimiter()` return.
1087         Raises ValueError on no delimiters.
1088         """
1089         return max(v for k, v in self.delimiters.items() if k not in exclude)
1090
1091     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1092         """Return the number of delimiters with the given `priority`.
1093
1094         If no `priority` is passed, defaults to max priority on the line.
1095         """
1096         if not self.delimiters:
1097             return 0
1098
1099         priority = priority or self.max_delimiter_priority()
1100         return sum(1 for p in self.delimiters.values() if p == priority)
1101
1102     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1103         """In a for loop, or comprehension, the variables are often unpacks.
1104
1105         To avoid splitting on the comma in this situation, increase the depth of
1106         tokens between `for` and `in`.
1107         """
1108         if leaf.type == token.NAME and leaf.value == "for":
1109             self.depth += 1
1110             self._for_loop_depths.append(self.depth)
1111             return True
1112
1113         return False
1114
1115     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1116         """See `maybe_increment_for_loop_variable` above for explanation."""
1117         if (
1118             self._for_loop_depths
1119             and self._for_loop_depths[-1] == self.depth
1120             and leaf.type == token.NAME
1121             and leaf.value == "in"
1122         ):
1123             self.depth -= 1
1124             self._for_loop_depths.pop()
1125             return True
1126
1127         return False
1128
1129     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1130         """In a lambda expression, there might be more than one argument.
1131
1132         To avoid splitting on the comma in this situation, increase the depth of
1133         tokens between `lambda` and `:`.
1134         """
1135         if leaf.type == token.NAME and leaf.value == "lambda":
1136             self.depth += 1
1137             self._lambda_argument_depths.append(self.depth)
1138             return True
1139
1140         return False
1141
1142     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1143         """See `maybe_increment_lambda_arguments` above for explanation."""
1144         if (
1145             self._lambda_argument_depths
1146             and self._lambda_argument_depths[-1] == self.depth
1147             and leaf.type == token.COLON
1148         ):
1149             self.depth -= 1
1150             self._lambda_argument_depths.pop()
1151             return True
1152
1153         return False
1154
1155     def get_open_lsqb(self) -> Optional[Leaf]:
1156         """Return the most recent opening square bracket (if any)."""
1157         return self.bracket_match.get((self.depth - 1, token.RSQB))
1158
1159
1160 @dataclass
1161 class Line:
1162     """Holds leaves and comments. Can be printed with `str(line)`."""
1163
1164     depth: int = 0
1165     leaves: List[Leaf] = Factory(list)
1166     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1167     bracket_tracker: BracketTracker = Factory(BracketTracker)
1168     inside_brackets: bool = False
1169     should_explode: bool = False
1170
1171     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1172         """Add a new `leaf` to the end of the line.
1173
1174         Unless `preformatted` is True, the `leaf` will receive a new consistent
1175         whitespace prefix and metadata applied by :class:`BracketTracker`.
1176         Trailing commas are maybe removed, unpacked for loop variables are
1177         demoted from being delimiters.
1178
1179         Inline comments are put aside.
1180         """
1181         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1182         if not has_value:
1183             return
1184
1185         if token.COLON == leaf.type and self.is_class_paren_empty:
1186             del self.leaves[-2:]
1187         if self.leaves and not preformatted:
1188             # Note: at this point leaf.prefix should be empty except for
1189             # imports, for which we only preserve newlines.
1190             leaf.prefix += whitespace(
1191                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1192             )
1193         if self.inside_brackets or not preformatted:
1194             self.bracket_tracker.mark(leaf)
1195             self.maybe_remove_trailing_comma(leaf)
1196         if not self.append_comment(leaf):
1197             self.leaves.append(leaf)
1198
1199     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1200         """Like :func:`append()` but disallow invalid standalone comment structure.
1201
1202         Raises ValueError when any `leaf` is appended after a standalone comment
1203         or when a standalone comment is not the first leaf on the line.
1204         """
1205         if self.bracket_tracker.depth == 0:
1206             if self.is_comment:
1207                 raise ValueError("cannot append to standalone comments")
1208
1209             if self.leaves and leaf.type == STANDALONE_COMMENT:
1210                 raise ValueError(
1211                     "cannot append standalone comments to a populated line"
1212                 )
1213
1214         self.append(leaf, preformatted=preformatted)
1215
1216     @property
1217     def is_comment(self) -> bool:
1218         """Is this line a standalone comment?"""
1219         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1220
1221     @property
1222     def is_decorator(self) -> bool:
1223         """Is this line a decorator?"""
1224         return bool(self) and self.leaves[0].type == token.AT
1225
1226     @property
1227     def is_import(self) -> bool:
1228         """Is this an import line?"""
1229         return bool(self) and is_import(self.leaves[0])
1230
1231     @property
1232     def is_class(self) -> bool:
1233         """Is this line a class definition?"""
1234         return (
1235             bool(self)
1236             and self.leaves[0].type == token.NAME
1237             and self.leaves[0].value == "class"
1238         )
1239
1240     @property
1241     def is_stub_class(self) -> bool:
1242         """Is this line a class definition with a body consisting only of "..."?"""
1243         return self.is_class and self.leaves[-3:] == [
1244             Leaf(token.DOT, ".") for _ in range(3)
1245         ]
1246
1247     @property
1248     def is_def(self) -> bool:
1249         """Is this a function definition? (Also returns True for async defs.)"""
1250         try:
1251             first_leaf = self.leaves[0]
1252         except IndexError:
1253             return False
1254
1255         try:
1256             second_leaf: Optional[Leaf] = self.leaves[1]
1257         except IndexError:
1258             second_leaf = None
1259         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1260             first_leaf.type == token.ASYNC
1261             and second_leaf is not None
1262             and second_leaf.type == token.NAME
1263             and second_leaf.value == "def"
1264         )
1265
1266     @property
1267     def is_class_paren_empty(self) -> bool:
1268         """Is this a class with no base classes but using parentheses?
1269
1270         Those are unnecessary and should be removed.
1271         """
1272         return (
1273             bool(self)
1274             and len(self.leaves) == 4
1275             and self.is_class
1276             and self.leaves[2].type == token.LPAR
1277             and self.leaves[2].value == "("
1278             and self.leaves[3].type == token.RPAR
1279             and self.leaves[3].value == ")"
1280         )
1281
1282     @property
1283     def is_triple_quoted_string(self) -> bool:
1284         """Is the line a triple quoted string?"""
1285         return (
1286             bool(self)
1287             and self.leaves[0].type == token.STRING
1288             and self.leaves[0].value.startswith(('"""', "'''"))
1289         )
1290
1291     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1292         """If so, needs to be split before emitting."""
1293         for leaf in self.leaves:
1294             if leaf.type == STANDALONE_COMMENT:
1295                 if leaf.bracket_depth <= depth_limit:
1296                     return True
1297         return False
1298
1299     def contains_uncollapsable_type_comments(self) -> bool:
1300         ignored_ids = set()
1301         try:
1302             last_leaf = self.leaves[-1]
1303             ignored_ids.add(id(last_leaf))
1304             if last_leaf.type == token.COMMA or (
1305                 last_leaf.type == token.RPAR and not last_leaf.value
1306             ):
1307                 # When trailing commas or optional parens are inserted by Black for
1308                 # consistency, comments after the previous last element are not moved
1309                 # (they don't have to, rendering will still be correct).  So we ignore
1310                 # trailing commas and invisible.
1311                 last_leaf = self.leaves[-2]
1312                 ignored_ids.add(id(last_leaf))
1313         except IndexError:
1314             return False
1315
1316         # A type comment is uncollapsable if it is attached to a leaf
1317         # that isn't at the end of the line (since that could cause it
1318         # to get associated to a different argument) or if there are
1319         # comments before it (since that could cause it to get hidden
1320         # behind a comment.
1321         comment_seen = False
1322         for leaf_id, comments in self.comments.items():
1323             for comment in comments:
1324                 if is_type_comment(comment):
1325                     if leaf_id not in ignored_ids or comment_seen:
1326                         return True
1327
1328             comment_seen = True
1329
1330         return False
1331
1332     def contains_multiline_strings(self) -> bool:
1333         for leaf in self.leaves:
1334             if is_multiline_string(leaf):
1335                 return True
1336
1337         return False
1338
1339     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1340         """Remove trailing comma if there is one and it's safe."""
1341         if not (
1342             self.leaves
1343             and self.leaves[-1].type == token.COMMA
1344             and closing.type in CLOSING_BRACKETS
1345         ):
1346             return False
1347
1348         if closing.type == token.RBRACE:
1349             self.remove_trailing_comma()
1350             return True
1351
1352         if closing.type == token.RSQB:
1353             comma = self.leaves[-1]
1354             if comma.parent and comma.parent.type == syms.listmaker:
1355                 self.remove_trailing_comma()
1356                 return True
1357
1358         # For parens let's check if it's safe to remove the comma.
1359         # Imports are always safe.
1360         if self.is_import:
1361             self.remove_trailing_comma()
1362             return True
1363
1364         # Otherwise, if the trailing one is the only one, we might mistakenly
1365         # change a tuple into a different type by removing the comma.
1366         depth = closing.bracket_depth + 1
1367         commas = 0
1368         opening = closing.opening_bracket
1369         for _opening_index, leaf in enumerate(self.leaves):
1370             if leaf is opening:
1371                 break
1372
1373         else:
1374             return False
1375
1376         for leaf in self.leaves[_opening_index + 1 :]:
1377             if leaf is closing:
1378                 break
1379
1380             bracket_depth = leaf.bracket_depth
1381             if bracket_depth == depth and leaf.type == token.COMMA:
1382                 commas += 1
1383                 if leaf.parent and leaf.parent.type in {
1384                     syms.arglist,
1385                     syms.typedargslist,
1386                 }:
1387                     commas += 1
1388                     break
1389
1390         if commas > 1:
1391             self.remove_trailing_comma()
1392             return True
1393
1394         return False
1395
1396     def append_comment(self, comment: Leaf) -> bool:
1397         """Add an inline or standalone comment to the line."""
1398         if (
1399             comment.type == STANDALONE_COMMENT
1400             and self.bracket_tracker.any_open_brackets()
1401         ):
1402             comment.prefix = ""
1403             return False
1404
1405         if comment.type != token.COMMENT:
1406             return False
1407
1408         if not self.leaves:
1409             comment.type = STANDALONE_COMMENT
1410             comment.prefix = ""
1411             return False
1412
1413         last_leaf = self.leaves[-1]
1414         if (
1415             last_leaf.type == token.RPAR
1416             and not last_leaf.value
1417             and last_leaf.parent
1418             and len(list(last_leaf.parent.leaves())) <= 3
1419             and not is_type_comment(comment)
1420         ):
1421             # Comments on an optional parens wrapping a single leaf should belong to
1422             # the wrapped node except if it's a type comment. Pinning the comment like
1423             # this avoids unstable formatting caused by comment migration.
1424             if len(self.leaves) < 2:
1425                 comment.type = STANDALONE_COMMENT
1426                 comment.prefix = ""
1427                 return False
1428             last_leaf = self.leaves[-2]
1429         self.comments.setdefault(id(last_leaf), []).append(comment)
1430         return True
1431
1432     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1433         """Generate comments that should appear directly after `leaf`."""
1434         return self.comments.get(id(leaf), [])
1435
1436     def remove_trailing_comma(self) -> None:
1437         """Remove the trailing comma and moves the comments attached to it."""
1438         trailing_comma = self.leaves.pop()
1439         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1440         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1441             trailing_comma_comments
1442         )
1443
1444     def is_complex_subscript(self, leaf: Leaf) -> bool:
1445         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1446         open_lsqb = self.bracket_tracker.get_open_lsqb()
1447         if open_lsqb is None:
1448             return False
1449
1450         subscript_start = open_lsqb.next_sibling
1451
1452         if isinstance(subscript_start, Node):
1453             if subscript_start.type == syms.listmaker:
1454                 return False
1455
1456             if subscript_start.type == syms.subscriptlist:
1457                 subscript_start = child_towards(subscript_start, leaf)
1458         return subscript_start is not None and any(
1459             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1460         )
1461
1462     def __str__(self) -> str:
1463         """Render the line."""
1464         if not self:
1465             return "\n"
1466
1467         indent = "    " * self.depth
1468         leaves = iter(self.leaves)
1469         first = next(leaves)
1470         res = f"{first.prefix}{indent}{first.value}"
1471         for leaf in leaves:
1472             res += str(leaf)
1473         for comment in itertools.chain.from_iterable(self.comments.values()):
1474             res += str(comment)
1475         return res + "\n"
1476
1477     def __bool__(self) -> bool:
1478         """Return True if the line has leaves or comments."""
1479         return bool(self.leaves or self.comments)
1480
1481
1482 @dataclass
1483 class EmptyLineTracker:
1484     """Provides a stateful method that returns the number of potential extra
1485     empty lines needed before and after the currently processed line.
1486
1487     Note: this tracker works on lines that haven't been split yet.  It assumes
1488     the prefix of the first leaf consists of optional newlines.  Those newlines
1489     are consumed by `maybe_empty_lines()` and included in the computation.
1490     """
1491
1492     is_pyi: bool = False
1493     previous_line: Optional[Line] = None
1494     previous_after: int = 0
1495     previous_defs: List[int] = Factory(list)
1496
1497     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1498         """Return the number of extra empty lines before and after the `current_line`.
1499
1500         This is for separating `def`, `async def` and `class` with extra empty
1501         lines (two on module-level).
1502         """
1503         before, after = self._maybe_empty_lines(current_line)
1504         before = (
1505             # Black should not insert empty lines at the beginning
1506             # of the file
1507             0
1508             if self.previous_line is None
1509             else before - self.previous_after
1510         )
1511         self.previous_after = after
1512         self.previous_line = current_line
1513         return before, after
1514
1515     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1516         max_allowed = 1
1517         if current_line.depth == 0:
1518             max_allowed = 1 if self.is_pyi else 2
1519         if current_line.leaves:
1520             # Consume the first leaf's extra newlines.
1521             first_leaf = current_line.leaves[0]
1522             before = first_leaf.prefix.count("\n")
1523             before = min(before, max_allowed)
1524             first_leaf.prefix = ""
1525         else:
1526             before = 0
1527         depth = current_line.depth
1528         while self.previous_defs and self.previous_defs[-1] >= depth:
1529             self.previous_defs.pop()
1530             if self.is_pyi:
1531                 before = 0 if depth else 1
1532             else:
1533                 before = 1 if depth else 2
1534         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1535             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1536
1537         if (
1538             self.previous_line
1539             and self.previous_line.is_import
1540             and not current_line.is_import
1541             and depth == self.previous_line.depth
1542         ):
1543             return (before or 1), 0
1544
1545         if (
1546             self.previous_line
1547             and self.previous_line.is_class
1548             and current_line.is_triple_quoted_string
1549         ):
1550             return before, 1
1551
1552         return before, 0
1553
1554     def _maybe_empty_lines_for_class_or_def(
1555         self, current_line: Line, before: int
1556     ) -> Tuple[int, int]:
1557         if not current_line.is_decorator:
1558             self.previous_defs.append(current_line.depth)
1559         if self.previous_line is None:
1560             # Don't insert empty lines before the first line in the file.
1561             return 0, 0
1562
1563         if self.previous_line.is_decorator:
1564             return 0, 0
1565
1566         if self.previous_line.depth < current_line.depth and (
1567             self.previous_line.is_class or self.previous_line.is_def
1568         ):
1569             return 0, 0
1570
1571         if (
1572             self.previous_line.is_comment
1573             and self.previous_line.depth == current_line.depth
1574             and before == 0
1575         ):
1576             return 0, 0
1577
1578         if self.is_pyi:
1579             if self.previous_line.depth > current_line.depth:
1580                 newlines = 1
1581             elif current_line.is_class or self.previous_line.is_class:
1582                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1583                     # No blank line between classes with an empty body
1584                     newlines = 0
1585                 else:
1586                     newlines = 1
1587             elif current_line.is_def and not self.previous_line.is_def:
1588                 # Blank line between a block of functions and a block of non-functions
1589                 newlines = 1
1590             else:
1591                 newlines = 0
1592         else:
1593             newlines = 2
1594         if current_line.depth and newlines:
1595             newlines -= 1
1596         return newlines, 0
1597
1598
1599 @dataclass
1600 class LineGenerator(Visitor[Line]):
1601     """Generates reformatted Line objects.  Empty lines are not emitted.
1602
1603     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1604     in ways that will no longer stringify to valid Python code on the tree.
1605     """
1606
1607     is_pyi: bool = False
1608     normalize_strings: bool = True
1609     current_line: Line = Factory(Line)
1610     remove_u_prefix: bool = False
1611
1612     def line(self, indent: int = 0) -> Iterator[Line]:
1613         """Generate a line.
1614
1615         If the line is empty, only emit if it makes sense.
1616         If the line is too long, split it first and then generate.
1617
1618         If any lines were generated, set up a new current_line.
1619         """
1620         if not self.current_line:
1621             self.current_line.depth += indent
1622             return  # Line is empty, don't emit. Creating a new one unnecessary.
1623
1624         complete_line = self.current_line
1625         self.current_line = Line(depth=complete_line.depth + indent)
1626         yield complete_line
1627
1628     def visit_default(self, node: LN) -> Iterator[Line]:
1629         """Default `visit_*()` implementation. Recurses to children of `node`."""
1630         if isinstance(node, Leaf):
1631             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1632             for comment in generate_comments(node):
1633                 if any_open_brackets:
1634                     # any comment within brackets is subject to splitting
1635                     self.current_line.append(comment)
1636                 elif comment.type == token.COMMENT:
1637                     # regular trailing comment
1638                     self.current_line.append(comment)
1639                     yield from self.line()
1640
1641                 else:
1642                     # regular standalone comment
1643                     yield from self.line()
1644
1645                     self.current_line.append(comment)
1646                     yield from self.line()
1647
1648             normalize_prefix(node, inside_brackets=any_open_brackets)
1649             if self.normalize_strings and node.type == token.STRING:
1650                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1651                 normalize_string_quotes(node)
1652             if node.type == token.NUMBER:
1653                 normalize_numeric_literal(node)
1654             if node.type not in WHITESPACE:
1655                 self.current_line.append(node)
1656         yield from super().visit_default(node)
1657
1658     def visit_atom(self, node: Node) -> Iterator[Line]:
1659         # Always make parentheses invisible around a single node, because it should
1660         # not be needed (except in the case of yield, where removing the parentheses
1661         # produces a SyntaxError).
1662         if (
1663             len(node.children) == 3
1664             and isinstance(node.children[0], Leaf)
1665             and node.children[0].type == token.LPAR
1666             and isinstance(node.children[2], Leaf)
1667             and node.children[2].type == token.RPAR
1668             and isinstance(node.children[1], Leaf)
1669             and not (
1670                 node.children[1].type == token.NAME
1671                 and node.children[1].value == "yield"
1672             )
1673         ):
1674             node.children[0].value = ""
1675             node.children[2].value = ""
1676         yield from super().visit_default(node)
1677
1678     def visit_factor(self, node: Node) -> Iterator[Line]:
1679         """Force parentheses between a unary op and a binary power:
1680
1681         -2 ** 8 -> -(2 ** 8)
1682         """
1683         child = node.children[1]
1684         if child.type == syms.power and len(child.children) == 3:
1685             lpar = Leaf(token.LPAR, "(")
1686             rpar = Leaf(token.RPAR, ")")
1687             index = child.remove() or 0
1688             node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
1689         yield from self.visit_default(node)
1690
1691     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1692         """Increase indentation level, maybe yield a line."""
1693         # In blib2to3 INDENT never holds comments.
1694         yield from self.line(+1)
1695         yield from self.visit_default(node)
1696
1697     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1698         """Decrease indentation level, maybe yield a line."""
1699         # The current line might still wait for trailing comments.  At DEDENT time
1700         # there won't be any (they would be prefixes on the preceding NEWLINE).
1701         # Emit the line then.
1702         yield from self.line()
1703
1704         # While DEDENT has no value, its prefix may contain standalone comments
1705         # that belong to the current indentation level.  Get 'em.
1706         yield from self.visit_default(node)
1707
1708         # Finally, emit the dedent.
1709         yield from self.line(-1)
1710
1711     def visit_stmt(
1712         self, node: Node, keywords: Set[str], parens: Set[str]
1713     ) -> Iterator[Line]:
1714         """Visit a statement.
1715
1716         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1717         `def`, `with`, `class`, `assert` and assignments.
1718
1719         The relevant Python language `keywords` for a given statement will be
1720         NAME leaves within it. This methods puts those on a separate line.
1721
1722         `parens` holds a set of string leaf values immediately after which
1723         invisible parens should be put.
1724         """
1725         normalize_invisible_parens(node, parens_after=parens)
1726         for child in node.children:
1727             if child.type == token.NAME and child.value in keywords:  # type: ignore
1728                 yield from self.line()
1729
1730             yield from self.visit(child)
1731
1732     def visit_suite(self, node: Node) -> Iterator[Line]:
1733         """Visit a suite."""
1734         if self.is_pyi and is_stub_suite(node):
1735             yield from self.visit(node.children[2])
1736         else:
1737             yield from self.visit_default(node)
1738
1739     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1740         """Visit a statement without nested statements."""
1741         is_suite_like = node.parent and node.parent.type in STATEMENT
1742         if is_suite_like:
1743             if self.is_pyi and is_stub_body(node):
1744                 yield from self.visit_default(node)
1745             else:
1746                 yield from self.line(+1)
1747                 yield from self.visit_default(node)
1748                 yield from self.line(-1)
1749
1750         else:
1751             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1752                 yield from self.line()
1753             yield from self.visit_default(node)
1754
1755     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1756         """Visit `async def`, `async for`, `async with`."""
1757         yield from self.line()
1758
1759         children = iter(node.children)
1760         for child in children:
1761             yield from self.visit(child)
1762
1763             if child.type == token.ASYNC:
1764                 break
1765
1766         internal_stmt = next(children)
1767         for child in internal_stmt.children:
1768             yield from self.visit(child)
1769
1770     def visit_decorators(self, node: Node) -> Iterator[Line]:
1771         """Visit decorators."""
1772         for child in node.children:
1773             yield from self.line()
1774             yield from self.visit(child)
1775
1776     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1777         """Remove a semicolon and put the other statement on a separate line."""
1778         yield from self.line()
1779
1780     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1781         """End of file. Process outstanding comments and end with a newline."""
1782         yield from self.visit_default(leaf)
1783         yield from self.line()
1784
1785     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1786         if not self.current_line.bracket_tracker.any_open_brackets():
1787             yield from self.line()
1788         yield from self.visit_default(leaf)
1789
1790     def __attrs_post_init__(self) -> None:
1791         """You are in a twisty little maze of passages."""
1792         v = self.visit_stmt
1793         Ø: Set[str] = set()
1794         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1795         self.visit_if_stmt = partial(
1796             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1797         )
1798         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1799         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1800         self.visit_try_stmt = partial(
1801             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1802         )
1803         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1804         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1805         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1806         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1807         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1808         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1809         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1810         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1811         self.visit_async_funcdef = self.visit_async_stmt
1812         self.visit_decorated = self.visit_decorators
1813
1814
1815 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1816 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1817 OPENING_BRACKETS = set(BRACKET.keys())
1818 CLOSING_BRACKETS = set(BRACKET.values())
1819 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1820 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1821
1822
1823 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1824     """Return whitespace prefix if needed for the given `leaf`.
1825
1826     `complex_subscript` signals whether the given leaf is part of a subscription
1827     which has non-trivial arguments, like arithmetic expressions or function calls.
1828     """
1829     NO = ""
1830     SPACE = " "
1831     DOUBLESPACE = "  "
1832     t = leaf.type
1833     p = leaf.parent
1834     v = leaf.value
1835     if t in ALWAYS_NO_SPACE:
1836         return NO
1837
1838     if t == token.COMMENT:
1839         return DOUBLESPACE
1840
1841     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1842     if t == token.COLON and p.type not in {
1843         syms.subscript,
1844         syms.subscriptlist,
1845         syms.sliceop,
1846     }:
1847         return NO
1848
1849     prev = leaf.prev_sibling
1850     if not prev:
1851         prevp = preceding_leaf(p)
1852         if not prevp or prevp.type in OPENING_BRACKETS:
1853             return NO
1854
1855         if t == token.COLON:
1856             if prevp.type == token.COLON:
1857                 return NO
1858
1859             elif prevp.type != token.COMMA and not complex_subscript:
1860                 return NO
1861
1862             return SPACE
1863
1864         if prevp.type == token.EQUAL:
1865             if prevp.parent:
1866                 if prevp.parent.type in {
1867                     syms.arglist,
1868                     syms.argument,
1869                     syms.parameters,
1870                     syms.varargslist,
1871                 }:
1872                     return NO
1873
1874                 elif prevp.parent.type == syms.typedargslist:
1875                     # A bit hacky: if the equal sign has whitespace, it means we
1876                     # previously found it's a typed argument.  So, we're using
1877                     # that, too.
1878                     return prevp.prefix
1879
1880         elif prevp.type in VARARGS_SPECIALS:
1881             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1882                 return NO
1883
1884         elif prevp.type == token.COLON:
1885             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1886                 return SPACE if complex_subscript else NO
1887
1888         elif (
1889             prevp.parent
1890             and prevp.parent.type == syms.factor
1891             and prevp.type in MATH_OPERATORS
1892         ):
1893             return NO
1894
1895         elif (
1896             prevp.type == token.RIGHTSHIFT
1897             and prevp.parent
1898             and prevp.parent.type == syms.shift_expr
1899             and prevp.prev_sibling
1900             and prevp.prev_sibling.type == token.NAME
1901             and prevp.prev_sibling.value == "print"  # type: ignore
1902         ):
1903             # Python 2 print chevron
1904             return NO
1905
1906     elif prev.type in OPENING_BRACKETS:
1907         return NO
1908
1909     if p.type in {syms.parameters, syms.arglist}:
1910         # untyped function signatures or calls
1911         if not prev or prev.type != token.COMMA:
1912             return NO
1913
1914     elif p.type == syms.varargslist:
1915         # lambdas
1916         if prev and prev.type != token.COMMA:
1917             return NO
1918
1919     elif p.type == syms.typedargslist:
1920         # typed function signatures
1921         if not prev:
1922             return NO
1923
1924         if t == token.EQUAL:
1925             if prev.type != syms.tname:
1926                 return NO
1927
1928         elif prev.type == token.EQUAL:
1929             # A bit hacky: if the equal sign has whitespace, it means we
1930             # previously found it's a typed argument.  So, we're using that, too.
1931             return prev.prefix
1932
1933         elif prev.type != token.COMMA:
1934             return NO
1935
1936     elif p.type == syms.tname:
1937         # type names
1938         if not prev:
1939             prevp = preceding_leaf(p)
1940             if not prevp or prevp.type != token.COMMA:
1941                 return NO
1942
1943     elif p.type == syms.trailer:
1944         # attributes and calls
1945         if t == token.LPAR or t == token.RPAR:
1946             return NO
1947
1948         if not prev:
1949             if t == token.DOT:
1950                 prevp = preceding_leaf(p)
1951                 if not prevp or prevp.type != token.NUMBER:
1952                     return NO
1953
1954             elif t == token.LSQB:
1955                 return NO
1956
1957         elif prev.type != token.COMMA:
1958             return NO
1959
1960     elif p.type == syms.argument:
1961         # single argument
1962         if t == token.EQUAL:
1963             return NO
1964
1965         if not prev:
1966             prevp = preceding_leaf(p)
1967             if not prevp or prevp.type == token.LPAR:
1968                 return NO
1969
1970         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
1971             return NO
1972
1973     elif p.type == syms.decorator:
1974         # decorators
1975         return NO
1976
1977     elif p.type == syms.dotted_name:
1978         if prev:
1979             return NO
1980
1981         prevp = preceding_leaf(p)
1982         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1983             return NO
1984
1985     elif p.type == syms.classdef:
1986         if t == token.LPAR:
1987             return NO
1988
1989         if prev and prev.type == token.LPAR:
1990             return NO
1991
1992     elif p.type in {syms.subscript, syms.sliceop}:
1993         # indexing
1994         if not prev:
1995             assert p.parent is not None, "subscripts are always parented"
1996             if p.parent.type == syms.subscriptlist:
1997                 return SPACE
1998
1999             return NO
2000
2001         elif not complex_subscript:
2002             return NO
2003
2004     elif p.type == syms.atom:
2005         if prev and t == token.DOT:
2006             # dots, but not the first one.
2007             return NO
2008
2009     elif p.type == syms.dictsetmaker:
2010         # dict unpacking
2011         if prev and prev.type == token.DOUBLESTAR:
2012             return NO
2013
2014     elif p.type in {syms.factor, syms.star_expr}:
2015         # unary ops
2016         if not prev:
2017             prevp = preceding_leaf(p)
2018             if not prevp or prevp.type in OPENING_BRACKETS:
2019                 return NO
2020
2021             prevp_parent = prevp.parent
2022             assert prevp_parent is not None
2023             if prevp.type == token.COLON and prevp_parent.type in {
2024                 syms.subscript,
2025                 syms.sliceop,
2026             }:
2027                 return NO
2028
2029             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2030                 return NO
2031
2032         elif t in {token.NAME, token.NUMBER, token.STRING}:
2033             return NO
2034
2035     elif p.type == syms.import_from:
2036         if t == token.DOT:
2037             if prev and prev.type == token.DOT:
2038                 return NO
2039
2040         elif t == token.NAME:
2041             if v == "import":
2042                 return SPACE
2043
2044             if prev and prev.type == token.DOT:
2045                 return NO
2046
2047     elif p.type == syms.sliceop:
2048         return NO
2049
2050     return SPACE
2051
2052
2053 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2054     """Return the first leaf that precedes `node`, if any."""
2055     while node:
2056         res = node.prev_sibling
2057         if res:
2058             if isinstance(res, Leaf):
2059                 return res
2060
2061             try:
2062                 return list(res.leaves())[-1]
2063
2064             except IndexError:
2065                 return None
2066
2067         node = node.parent
2068     return None
2069
2070
2071 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2072     """Return the child of `ancestor` that contains `descendant`."""
2073     node: Optional[LN] = descendant
2074     while node and node.parent != ancestor:
2075         node = node.parent
2076     return node
2077
2078
2079 def container_of(leaf: Leaf) -> LN:
2080     """Return `leaf` or one of its ancestors that is the topmost container of it.
2081
2082     By "container" we mean a node where `leaf` is the very first child.
2083     """
2084     same_prefix = leaf.prefix
2085     container: LN = leaf
2086     while container:
2087         parent = container.parent
2088         if parent is None:
2089             break
2090
2091         if parent.children[0].prefix != same_prefix:
2092             break
2093
2094         if parent.type == syms.file_input:
2095             break
2096
2097         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2098             break
2099
2100         container = parent
2101     return container
2102
2103
2104 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2105     """Return the priority of the `leaf` delimiter, given a line break after it.
2106
2107     The delimiter priorities returned here are from those delimiters that would
2108     cause a line break after themselves.
2109
2110     Higher numbers are higher priority.
2111     """
2112     if leaf.type == token.COMMA:
2113         return COMMA_PRIORITY
2114
2115     return 0
2116
2117
2118 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2119     """Return the priority of the `leaf` delimiter, given a line break before it.
2120
2121     The delimiter priorities returned here are from those delimiters that would
2122     cause a line break before themselves.
2123
2124     Higher numbers are higher priority.
2125     """
2126     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2127         # * and ** might also be MATH_OPERATORS but in this case they are not.
2128         # Don't treat them as a delimiter.
2129         return 0
2130
2131     if (
2132         leaf.type == token.DOT
2133         and leaf.parent
2134         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2135         and (previous is None or previous.type in CLOSING_BRACKETS)
2136     ):
2137         return DOT_PRIORITY
2138
2139     if (
2140         leaf.type in MATH_OPERATORS
2141         and leaf.parent
2142         and leaf.parent.type not in {syms.factor, syms.star_expr}
2143     ):
2144         return MATH_PRIORITIES[leaf.type]
2145
2146     if leaf.type in COMPARATORS:
2147         return COMPARATOR_PRIORITY
2148
2149     if (
2150         leaf.type == token.STRING
2151         and previous is not None
2152         and previous.type == token.STRING
2153     ):
2154         return STRING_PRIORITY
2155
2156     if leaf.type not in {token.NAME, token.ASYNC}:
2157         return 0
2158
2159     if (
2160         leaf.value == "for"
2161         and leaf.parent
2162         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2163         or leaf.type == token.ASYNC
2164     ):
2165         if (
2166             not isinstance(leaf.prev_sibling, Leaf)
2167             or leaf.prev_sibling.value != "async"
2168         ):
2169             return COMPREHENSION_PRIORITY
2170
2171     if (
2172         leaf.value == "if"
2173         and leaf.parent
2174         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2175     ):
2176         return COMPREHENSION_PRIORITY
2177
2178     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2179         return TERNARY_PRIORITY
2180
2181     if leaf.value == "is":
2182         return COMPARATOR_PRIORITY
2183
2184     if (
2185         leaf.value == "in"
2186         and leaf.parent
2187         and leaf.parent.type in {syms.comp_op, syms.comparison}
2188         and not (
2189             previous is not None
2190             and previous.type == token.NAME
2191             and previous.value == "not"
2192         )
2193     ):
2194         return COMPARATOR_PRIORITY
2195
2196     if (
2197         leaf.value == "not"
2198         and leaf.parent
2199         and leaf.parent.type == syms.comp_op
2200         and not (
2201             previous is not None
2202             and previous.type == token.NAME
2203             and previous.value == "is"
2204         )
2205     ):
2206         return COMPARATOR_PRIORITY
2207
2208     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2209         return LOGIC_PRIORITY
2210
2211     return 0
2212
2213
2214 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2215 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2216
2217
2218 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2219     """Clean the prefix of the `leaf` and generate comments from it, if any.
2220
2221     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2222     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2223     move because it does away with modifying the grammar to include all the
2224     possible places in which comments can be placed.
2225
2226     The sad consequence for us though is that comments don't "belong" anywhere.
2227     This is why this function generates simple parentless Leaf objects for
2228     comments.  We simply don't know what the correct parent should be.
2229
2230     No matter though, we can live without this.  We really only need to
2231     differentiate between inline and standalone comments.  The latter don't
2232     share the line with any code.
2233
2234     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2235     are emitted with a fake STANDALONE_COMMENT token identifier.
2236     """
2237     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2238         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2239
2240
2241 @dataclass
2242 class ProtoComment:
2243     """Describes a piece of syntax that is a comment.
2244
2245     It's not a :class:`blib2to3.pytree.Leaf` so that:
2246
2247     * it can be cached (`Leaf` objects should not be reused more than once as
2248       they store their lineno, column, prefix, and parent information);
2249     * `newlines` and `consumed` fields are kept separate from the `value`. This
2250       simplifies handling of special marker comments like ``# fmt: off/on``.
2251     """
2252
2253     type: int  # token.COMMENT or STANDALONE_COMMENT
2254     value: str  # content of the comment
2255     newlines: int  # how many newlines before the comment
2256     consumed: int  # how many characters of the original leaf's prefix did we consume
2257
2258
2259 @lru_cache(maxsize=4096)
2260 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2261     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2262     result: List[ProtoComment] = []
2263     if not prefix or "#" not in prefix:
2264         return result
2265
2266     consumed = 0
2267     nlines = 0
2268     ignored_lines = 0
2269     for index, line in enumerate(prefix.split("\n")):
2270         consumed += len(line) + 1  # adding the length of the split '\n'
2271         line = line.lstrip()
2272         if not line:
2273             nlines += 1
2274         if not line.startswith("#"):
2275             # Escaped newlines outside of a comment are not really newlines at
2276             # all. We treat a single-line comment following an escaped newline
2277             # as a simple trailing comment.
2278             if line.endswith("\\"):
2279                 ignored_lines += 1
2280             continue
2281
2282         if index == ignored_lines and not is_endmarker:
2283             comment_type = token.COMMENT  # simple trailing comment
2284         else:
2285             comment_type = STANDALONE_COMMENT
2286         comment = make_comment(line)
2287         result.append(
2288             ProtoComment(
2289                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2290             )
2291         )
2292         nlines = 0
2293     return result
2294
2295
2296 def make_comment(content: str) -> str:
2297     """Return a consistently formatted comment from the given `content` string.
2298
2299     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2300     space between the hash sign and the content.
2301
2302     If `content` didn't start with a hash sign, one is provided.
2303     """
2304     content = content.rstrip()
2305     if not content:
2306         return "#"
2307
2308     if content[0] == "#":
2309         content = content[1:]
2310     if content and content[0] not in " !:#'%":
2311         content = " " + content
2312     return "#" + content
2313
2314
2315 def split_line(
2316     line: Line,
2317     line_length: int,
2318     inner: bool = False,
2319     features: Collection[Feature] = (),
2320 ) -> Iterator[Line]:
2321     """Split a `line` into potentially many lines.
2322
2323     They should fit in the allotted `line_length` but might not be able to.
2324     `inner` signifies that there were a pair of brackets somewhere around the
2325     current `line`, possibly transitively. This means we can fallback to splitting
2326     by delimiters if the LHS/RHS don't yield any results.
2327
2328     `features` are syntactical features that may be used in the output.
2329     """
2330     if line.is_comment:
2331         yield line
2332         return
2333
2334     line_str = str(line).strip("\n")
2335
2336     if (
2337         not line.contains_uncollapsable_type_comments()
2338         and not line.should_explode
2339         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2340     ):
2341         yield line
2342         return
2343
2344     split_funcs: List[SplitFunc]
2345     if line.is_def:
2346         split_funcs = [left_hand_split]
2347     else:
2348
2349         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2350             for omit in generate_trailers_to_omit(line, line_length):
2351                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2352                 if is_line_short_enough(lines[0], line_length=line_length):
2353                     yield from lines
2354                     return
2355
2356             # All splits failed, best effort split with no omits.
2357             # This mostly happens to multiline strings that are by definition
2358             # reported as not fitting a single line.
2359             yield from right_hand_split(line, line_length, features=features)
2360
2361         if line.inside_brackets:
2362             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2363         else:
2364             split_funcs = [rhs]
2365     for split_func in split_funcs:
2366         # We are accumulating lines in `result` because we might want to abort
2367         # mission and return the original line in the end, or attempt a different
2368         # split altogether.
2369         result: List[Line] = []
2370         try:
2371             for l in split_func(line, features):
2372                 if str(l).strip("\n") == line_str:
2373                     raise CannotSplit("Split function returned an unchanged result")
2374
2375                 result.extend(
2376                     split_line(
2377                         l, line_length=line_length, inner=True, features=features
2378                     )
2379                 )
2380         except CannotSplit:
2381             continue
2382
2383         else:
2384             yield from result
2385             break
2386
2387     else:
2388         yield line
2389
2390
2391 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2392     """Split line into many lines, starting with the first matching bracket pair.
2393
2394     Note: this usually looks weird, only use this for function definitions.
2395     Prefer RHS otherwise.  This is why this function is not symmetrical with
2396     :func:`right_hand_split` which also handles optional parentheses.
2397     """
2398     tail_leaves: List[Leaf] = []
2399     body_leaves: List[Leaf] = []
2400     head_leaves: List[Leaf] = []
2401     current_leaves = head_leaves
2402     matching_bracket = None
2403     for leaf in line.leaves:
2404         if (
2405             current_leaves is body_leaves
2406             and leaf.type in CLOSING_BRACKETS
2407             and leaf.opening_bracket is matching_bracket
2408         ):
2409             current_leaves = tail_leaves if body_leaves else head_leaves
2410         current_leaves.append(leaf)
2411         if current_leaves is head_leaves:
2412             if leaf.type in OPENING_BRACKETS:
2413                 matching_bracket = leaf
2414                 current_leaves = body_leaves
2415     if not matching_bracket:
2416         raise CannotSplit("No brackets found")
2417
2418     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2419     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2420     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2421     bracket_split_succeeded_or_raise(head, body, tail)
2422     for result in (head, body, tail):
2423         if result:
2424             yield result
2425
2426
2427 def right_hand_split(
2428     line: Line,
2429     line_length: int,
2430     features: Collection[Feature] = (),
2431     omit: Collection[LeafID] = (),
2432 ) -> Iterator[Line]:
2433     """Split line into many lines, starting with the last matching bracket pair.
2434
2435     If the split was by optional parentheses, attempt splitting without them, too.
2436     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2437     this split.
2438
2439     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2440     """
2441     tail_leaves: List[Leaf] = []
2442     body_leaves: List[Leaf] = []
2443     head_leaves: List[Leaf] = []
2444     current_leaves = tail_leaves
2445     opening_bracket = None
2446     closing_bracket = None
2447     for leaf in reversed(line.leaves):
2448         if current_leaves is body_leaves:
2449             if leaf is opening_bracket:
2450                 current_leaves = head_leaves if body_leaves else tail_leaves
2451         current_leaves.append(leaf)
2452         if current_leaves is tail_leaves:
2453             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2454                 opening_bracket = leaf.opening_bracket
2455                 closing_bracket = leaf
2456                 current_leaves = body_leaves
2457     if not (opening_bracket and closing_bracket and head_leaves):
2458         # If there is no opening or closing_bracket that means the split failed and
2459         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2460         # the matching `opening_bracket` wasn't available on `line` anymore.
2461         raise CannotSplit("No brackets found")
2462
2463     tail_leaves.reverse()
2464     body_leaves.reverse()
2465     head_leaves.reverse()
2466     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2467     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2468     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2469     bracket_split_succeeded_or_raise(head, body, tail)
2470     if (
2471         # the body shouldn't be exploded
2472         not body.should_explode
2473         # the opening bracket is an optional paren
2474         and opening_bracket.type == token.LPAR
2475         and not opening_bracket.value
2476         # the closing bracket is an optional paren
2477         and closing_bracket.type == token.RPAR
2478         and not closing_bracket.value
2479         # it's not an import (optional parens are the only thing we can split on
2480         # in this case; attempting a split without them is a waste of time)
2481         and not line.is_import
2482         # there are no standalone comments in the body
2483         and not body.contains_standalone_comments(0)
2484         # and we can actually remove the parens
2485         and can_omit_invisible_parens(body, line_length)
2486     ):
2487         omit = {id(closing_bracket), *omit}
2488         try:
2489             yield from right_hand_split(line, line_length, features=features, omit=omit)
2490             return
2491
2492         except CannotSplit:
2493             if not (
2494                 can_be_split(body)
2495                 or is_line_short_enough(body, line_length=line_length)
2496             ):
2497                 raise CannotSplit(
2498                     "Splitting failed, body is still too long and can't be split."
2499                 )
2500
2501             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2502                 raise CannotSplit(
2503                     "The current optional pair of parentheses is bound to fail to "
2504                     "satisfy the splitting algorithm because the head or the tail "
2505                     "contains multiline strings which by definition never fit one "
2506                     "line."
2507                 )
2508
2509     ensure_visible(opening_bracket)
2510     ensure_visible(closing_bracket)
2511     for result in (head, body, tail):
2512         if result:
2513             yield result
2514
2515
2516 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2517     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2518
2519     Do nothing otherwise.
2520
2521     A left- or right-hand split is based on a pair of brackets. Content before
2522     (and including) the opening bracket is left on one line, content inside the
2523     brackets is put on a separate line, and finally content starting with and
2524     following the closing bracket is put on a separate line.
2525
2526     Those are called `head`, `body`, and `tail`, respectively. If the split
2527     produced the same line (all content in `head`) or ended up with an empty `body`
2528     and the `tail` is just the closing bracket, then it's considered failed.
2529     """
2530     tail_len = len(str(tail).strip())
2531     if not body:
2532         if tail_len == 0:
2533             raise CannotSplit("Splitting brackets produced the same line")
2534
2535         elif tail_len < 3:
2536             raise CannotSplit(
2537                 f"Splitting brackets on an empty body to save "
2538                 f"{tail_len} characters is not worth it"
2539             )
2540
2541
2542 def bracket_split_build_line(
2543     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2544 ) -> Line:
2545     """Return a new line with given `leaves` and respective comments from `original`.
2546
2547     If `is_body` is True, the result line is one-indented inside brackets and as such
2548     has its first leaf's prefix normalized and a trailing comma added when expected.
2549     """
2550     result = Line(depth=original.depth)
2551     if is_body:
2552         result.inside_brackets = True
2553         result.depth += 1
2554         if leaves:
2555             # Since body is a new indent level, remove spurious leading whitespace.
2556             normalize_prefix(leaves[0], inside_brackets=True)
2557             # Ensure a trailing comma for imports and standalone function arguments, but
2558             # be careful not to add one after any comments.
2559             no_commas = original.is_def and not any(
2560                 l.type == token.COMMA for l in leaves
2561             )
2562
2563             if original.is_import or no_commas:
2564                 for i in range(len(leaves) - 1, -1, -1):
2565                     if leaves[i].type == STANDALONE_COMMENT:
2566                         continue
2567                     elif leaves[i].type == token.COMMA:
2568                         break
2569                     else:
2570                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2571                         break
2572     # Populate the line
2573     for leaf in leaves:
2574         result.append(leaf, preformatted=True)
2575         for comment_after in original.comments_after(leaf):
2576             result.append(comment_after, preformatted=True)
2577     if is_body:
2578         result.should_explode = should_explode(result, opening_bracket)
2579     return result
2580
2581
2582 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2583     """Normalize prefix of the first leaf in every line returned by `split_func`.
2584
2585     This is a decorator over relevant split functions.
2586     """
2587
2588     @wraps(split_func)
2589     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2590         for l in split_func(line, features):
2591             normalize_prefix(l.leaves[0], inside_brackets=True)
2592             yield l
2593
2594     return split_wrapper
2595
2596
2597 @dont_increase_indentation
2598 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2599     """Split according to delimiters of the highest priority.
2600
2601     If the appropriate Features are given, the split will add trailing commas
2602     also in function signatures and calls that contain `*` and `**`.
2603     """
2604     try:
2605         last_leaf = line.leaves[-1]
2606     except IndexError:
2607         raise CannotSplit("Line empty")
2608
2609     bt = line.bracket_tracker
2610     try:
2611         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2612     except ValueError:
2613         raise CannotSplit("No delimiters found")
2614
2615     if delimiter_priority == DOT_PRIORITY:
2616         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2617             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2618
2619     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2620     lowest_depth = sys.maxsize
2621     trailing_comma_safe = True
2622
2623     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2624         """Append `leaf` to current line or to new line if appending impossible."""
2625         nonlocal current_line
2626         try:
2627             current_line.append_safe(leaf, preformatted=True)
2628         except ValueError:
2629             yield current_line
2630
2631             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2632             current_line.append(leaf)
2633
2634     for leaf in line.leaves:
2635         yield from append_to_line(leaf)
2636
2637         for comment_after in line.comments_after(leaf):
2638             yield from append_to_line(comment_after)
2639
2640         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2641         if leaf.bracket_depth == lowest_depth:
2642             if is_vararg(leaf, within={syms.typedargslist}):
2643                 trailing_comma_safe = (
2644                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2645                 )
2646             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2647                 trailing_comma_safe = (
2648                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2649                 )
2650
2651         leaf_priority = bt.delimiters.get(id(leaf))
2652         if leaf_priority == delimiter_priority:
2653             yield current_line
2654
2655             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2656     if current_line:
2657         if (
2658             trailing_comma_safe
2659             and delimiter_priority == COMMA_PRIORITY
2660             and current_line.leaves[-1].type != token.COMMA
2661             and current_line.leaves[-1].type != STANDALONE_COMMENT
2662         ):
2663             current_line.append(Leaf(token.COMMA, ","))
2664         yield current_line
2665
2666
2667 @dont_increase_indentation
2668 def standalone_comment_split(
2669     line: Line, features: Collection[Feature] = ()
2670 ) -> Iterator[Line]:
2671     """Split standalone comments from the rest of the line."""
2672     if not line.contains_standalone_comments(0):
2673         raise CannotSplit("Line does not have any standalone comments")
2674
2675     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2676
2677     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2678         """Append `leaf` to current line or to new line if appending impossible."""
2679         nonlocal current_line
2680         try:
2681             current_line.append_safe(leaf, preformatted=True)
2682         except ValueError:
2683             yield current_line
2684
2685             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2686             current_line.append(leaf)
2687
2688     for leaf in line.leaves:
2689         yield from append_to_line(leaf)
2690
2691         for comment_after in line.comments_after(leaf):
2692             yield from append_to_line(comment_after)
2693
2694     if current_line:
2695         yield current_line
2696
2697
2698 def is_import(leaf: Leaf) -> bool:
2699     """Return True if the given leaf starts an import statement."""
2700     p = leaf.parent
2701     t = leaf.type
2702     v = leaf.value
2703     return bool(
2704         t == token.NAME
2705         and (
2706             (v == "import" and p and p.type == syms.import_name)
2707             or (v == "from" and p and p.type == syms.import_from)
2708         )
2709     )
2710
2711
2712 def is_type_comment(leaf: Leaf) -> bool:
2713     """Return True if the given leaf is a special comment.
2714     Only returns true for type comments for now."""
2715     t = leaf.type
2716     v = leaf.value
2717     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2718
2719
2720 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2721     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2722     else.
2723
2724     Note: don't use backslashes for formatting or you'll lose your voting rights.
2725     """
2726     if not inside_brackets:
2727         spl = leaf.prefix.split("#")
2728         if "\\" not in spl[0]:
2729             nl_count = spl[-1].count("\n")
2730             if len(spl) > 1:
2731                 nl_count -= 1
2732             leaf.prefix = "\n" * nl_count
2733             return
2734
2735     leaf.prefix = ""
2736
2737
2738 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2739     """Make all string prefixes lowercase.
2740
2741     If remove_u_prefix is given, also removes any u prefix from the string.
2742
2743     Note: Mutates its argument.
2744     """
2745     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2746     assert match is not None, f"failed to match string {leaf.value!r}"
2747     orig_prefix = match.group(1)
2748     new_prefix = orig_prefix.lower()
2749     if remove_u_prefix:
2750         new_prefix = new_prefix.replace("u", "")
2751     leaf.value = f"{new_prefix}{match.group(2)}"
2752
2753
2754 def normalize_string_quotes(leaf: Leaf) -> None:
2755     """Prefer double quotes but only if it doesn't cause more escaping.
2756
2757     Adds or removes backslashes as appropriate. Doesn't parse and fix
2758     strings nested in f-strings (yet).
2759
2760     Note: Mutates its argument.
2761     """
2762     value = leaf.value.lstrip("furbFURB")
2763     if value[:3] == '"""':
2764         return
2765
2766     elif value[:3] == "'''":
2767         orig_quote = "'''"
2768         new_quote = '"""'
2769     elif value[0] == '"':
2770         orig_quote = '"'
2771         new_quote = "'"
2772     else:
2773         orig_quote = "'"
2774         new_quote = '"'
2775     first_quote_pos = leaf.value.find(orig_quote)
2776     if first_quote_pos == -1:
2777         return  # There's an internal error
2778
2779     prefix = leaf.value[:first_quote_pos]
2780     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2781     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2782     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2783     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2784     if "r" in prefix.casefold():
2785         if unescaped_new_quote.search(body):
2786             # There's at least one unescaped new_quote in this raw string
2787             # so converting is impossible
2788             return
2789
2790         # Do not introduce or remove backslashes in raw strings
2791         new_body = body
2792     else:
2793         # remove unnecessary escapes
2794         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2795         if body != new_body:
2796             # Consider the string without unnecessary escapes as the original
2797             body = new_body
2798             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2799         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2800         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2801     if "f" in prefix.casefold():
2802         matches = re.findall(
2803             r"""
2804             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2805                 ([^{].*?)  # contents of the brackets except if begins with {{
2806             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2807             """,
2808             new_body,
2809             re.VERBOSE,
2810         )
2811         for m in matches:
2812             if "\\" in str(m):
2813                 # Do not introduce backslashes in interpolated expressions
2814                 return
2815     if new_quote == '"""' and new_body[-1:] == '"':
2816         # edge case:
2817         new_body = new_body[:-1] + '\\"'
2818     orig_escape_count = body.count("\\")
2819     new_escape_count = new_body.count("\\")
2820     if new_escape_count > orig_escape_count:
2821         return  # Do not introduce more escaping
2822
2823     if new_escape_count == orig_escape_count and orig_quote == '"':
2824         return  # Prefer double quotes
2825
2826     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2827
2828
2829 def normalize_numeric_literal(leaf: Leaf) -> None:
2830     """Normalizes numeric (float, int, and complex) literals.
2831
2832     All letters used in the representation are normalized to lowercase (except
2833     in Python 2 long literals).
2834     """
2835     text = leaf.value.lower()
2836     if text.startswith(("0o", "0b")):
2837         # Leave octal and binary literals alone.
2838         pass
2839     elif text.startswith("0x"):
2840         # Change hex literals to upper case.
2841         before, after = text[:2], text[2:]
2842         text = f"{before}{after.upper()}"
2843     elif "e" in text:
2844         before, after = text.split("e")
2845         sign = ""
2846         if after.startswith("-"):
2847             after = after[1:]
2848             sign = "-"
2849         elif after.startswith("+"):
2850             after = after[1:]
2851         before = format_float_or_int_string(before)
2852         text = f"{before}e{sign}{after}"
2853     elif text.endswith(("j", "l")):
2854         number = text[:-1]
2855         suffix = text[-1]
2856         # Capitalize in "2L" because "l" looks too similar to "1".
2857         if suffix == "l":
2858             suffix = "L"
2859         text = f"{format_float_or_int_string(number)}{suffix}"
2860     else:
2861         text = format_float_or_int_string(text)
2862     leaf.value = text
2863
2864
2865 def format_float_or_int_string(text: str) -> str:
2866     """Formats a float string like "1.0"."""
2867     if "." not in text:
2868         return text
2869
2870     before, after = text.split(".")
2871     return f"{before or 0}.{after or 0}"
2872
2873
2874 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2875     """Make existing optional parentheses invisible or create new ones.
2876
2877     `parens_after` is a set of string leaf values immediately after which parens
2878     should be put.
2879
2880     Standardizes on visible parentheses for single-element tuples, and keeps
2881     existing visible parentheses for other tuples and generator expressions.
2882     """
2883     for pc in list_comments(node.prefix, is_endmarker=False):
2884         if pc.value in FMT_OFF:
2885             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2886             return
2887
2888     check_lpar = False
2889     for index, child in enumerate(list(node.children)):
2890         # Add parentheses around long tuple unpacking in assignments.
2891         if (
2892             index == 0
2893             and isinstance(child, Node)
2894             and child.type == syms.testlist_star_expr
2895         ):
2896             check_lpar = True
2897
2898         if check_lpar:
2899             if is_walrus_assignment(child):
2900                 continue
2901             if child.type == syms.atom:
2902                 # Determines if the underlying atom should be surrounded with
2903                 # invisible params - also makes parens invisible recursively
2904                 # within the atom and removes repeated invisible parens within
2905                 # the atom
2906                 should_surround_with_parens = maybe_make_parens_invisible_in_atom(
2907                     child, parent=node
2908                 )
2909
2910                 if should_surround_with_parens:
2911                     lpar = Leaf(token.LPAR, "")
2912                     rpar = Leaf(token.RPAR, "")
2913                     index = child.remove() or 0
2914                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2915             elif is_one_tuple(child):
2916                 # wrap child in visible parentheses
2917                 lpar = Leaf(token.LPAR, "(")
2918                 rpar = Leaf(token.RPAR, ")")
2919                 child.remove()
2920                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2921             elif node.type == syms.import_from:
2922                 # "import from" nodes store parentheses directly as part of
2923                 # the statement
2924                 if child.type == token.LPAR:
2925                     # make parentheses invisible
2926                     child.value = ""  # type: ignore
2927                     node.children[-1].value = ""  # type: ignore
2928                 elif child.type != token.STAR:
2929                     # insert invisible parentheses
2930                     node.insert_child(index, Leaf(token.LPAR, ""))
2931                     node.append_child(Leaf(token.RPAR, ""))
2932                 break
2933
2934             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2935                 # wrap child in invisible parentheses
2936                 lpar = Leaf(token.LPAR, "")
2937                 rpar = Leaf(token.RPAR, "")
2938                 index = child.remove() or 0
2939                 prefix = child.prefix
2940                 child.prefix = ""
2941                 new_child = Node(syms.atom, [lpar, child, rpar])
2942                 new_child.prefix = prefix
2943                 node.insert_child(index, new_child)
2944
2945         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2946
2947
2948 def normalize_fmt_off(node: Node) -> None:
2949     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2950     try_again = True
2951     while try_again:
2952         try_again = convert_one_fmt_off_pair(node)
2953
2954
2955 def convert_one_fmt_off_pair(node: Node) -> bool:
2956     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2957
2958     Returns True if a pair was converted.
2959     """
2960     for leaf in node.leaves():
2961         previous_consumed = 0
2962         for comment in list_comments(leaf.prefix, is_endmarker=False):
2963             if comment.value in FMT_OFF:
2964                 # We only want standalone comments. If there's no previous leaf or
2965                 # the previous leaf is indentation, it's a standalone comment in
2966                 # disguise.
2967                 if comment.type != STANDALONE_COMMENT:
2968                     prev = preceding_leaf(leaf)
2969                     if prev and prev.type not in WHITESPACE:
2970                         continue
2971
2972                 ignored_nodes = list(generate_ignored_nodes(leaf))
2973                 if not ignored_nodes:
2974                     continue
2975
2976                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2977                 parent = first.parent
2978                 prefix = first.prefix
2979                 first.prefix = prefix[comment.consumed :]
2980                 hidden_value = (
2981                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2982                 )
2983                 if hidden_value.endswith("\n"):
2984                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2985                     # leaf (possibly followed by a DEDENT).
2986                     hidden_value = hidden_value[:-1]
2987                 first_idx = None
2988                 for ignored in ignored_nodes:
2989                     index = ignored.remove()
2990                     if first_idx is None:
2991                         first_idx = index
2992                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2993                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2994                 parent.insert_child(
2995                     first_idx,
2996                     Leaf(
2997                         STANDALONE_COMMENT,
2998                         hidden_value,
2999                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3000                     ),
3001                 )
3002                 return True
3003
3004             previous_consumed = comment.consumed
3005
3006     return False
3007
3008
3009 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3010     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3011
3012     Stops at the end of the block.
3013     """
3014     container: Optional[LN] = container_of(leaf)
3015     while container is not None and container.type != token.ENDMARKER:
3016         for comment in list_comments(container.prefix, is_endmarker=False):
3017             if comment.value in FMT_ON:
3018                 return
3019
3020         yield container
3021
3022         container = container.next_sibling
3023
3024
3025 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3026     """If it's safe, make the parens in the atom `node` invisible, recursively.
3027     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3028     as they are redundant.
3029
3030     Returns whether the node should itself be wrapped in invisible parentheses.
3031
3032     """
3033     if (
3034         node.type != syms.atom
3035         or is_empty_tuple(node)
3036         or is_one_tuple(node)
3037         or (is_yield(node) and parent.type != syms.expr_stmt)
3038         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3039     ):
3040         return False
3041
3042     first = node.children[0]
3043     last = node.children[-1]
3044     if first.type == token.LPAR and last.type == token.RPAR:
3045         middle = node.children[1]
3046         # make parentheses invisible
3047         first.value = ""  # type: ignore
3048         last.value = ""  # type: ignore
3049         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3050
3051         if is_atom_with_invisible_parens(middle):
3052             # Strip the invisible parens from `middle` by replacing
3053             # it with the child in-between the invisible parens
3054             middle.replace(middle.children[1])
3055
3056         return False
3057
3058     return True
3059
3060
3061 def is_atom_with_invisible_parens(node: LN) -> bool:
3062     """Given a `LN`, determines whether it's an atom `node` with invisible
3063     parens. Useful in dedupe-ing and normalizing parens.
3064     """
3065     if isinstance(node, Leaf) or node.type != syms.atom:
3066         return False
3067
3068     first, last = node.children[0], node.children[-1]
3069     return (
3070         isinstance(first, Leaf)
3071         and first.type == token.LPAR
3072         and first.value == ""
3073         and isinstance(last, Leaf)
3074         and last.type == token.RPAR
3075         and last.value == ""
3076     )
3077
3078
3079 def is_empty_tuple(node: LN) -> bool:
3080     """Return True if `node` holds an empty tuple."""
3081     return (
3082         node.type == syms.atom
3083         and len(node.children) == 2
3084         and node.children[0].type == token.LPAR
3085         and node.children[1].type == token.RPAR
3086     )
3087
3088
3089 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3090     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3091
3092     Parenthesis can be optional. Returns None otherwise"""
3093     if len(node.children) != 3:
3094         return None
3095     lpar, wrapped, rpar = node.children
3096     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3097         return None
3098
3099     return wrapped
3100
3101
3102 def is_one_tuple(node: LN) -> bool:
3103     """Return True if `node` holds a tuple with one element, with or without parens."""
3104     if node.type == syms.atom:
3105         gexp = unwrap_singleton_parenthesis(node)
3106         if gexp is None or gexp.type != syms.testlist_gexp:
3107             return False
3108
3109         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3110
3111     return (
3112         node.type in IMPLICIT_TUPLE
3113         and len(node.children) == 2
3114         and node.children[1].type == token.COMMA
3115     )
3116
3117
3118 def is_walrus_assignment(node: LN) -> bool:
3119     """Return True iff `node` is of the shape ( test := test )"""
3120     inner = unwrap_singleton_parenthesis(node)
3121     return inner is not None and inner.type == syms.namedexpr_test
3122
3123
3124 def is_yield(node: LN) -> bool:
3125     """Return True if `node` holds a `yield` or `yield from` expression."""
3126     if node.type == syms.yield_expr:
3127         return True
3128
3129     if node.type == token.NAME and node.value == "yield":  # type: ignore
3130         return True
3131
3132     if node.type != syms.atom:
3133         return False
3134
3135     if len(node.children) != 3:
3136         return False
3137
3138     lpar, expr, rpar = node.children
3139     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3140         return is_yield(expr)
3141
3142     return False
3143
3144
3145 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3146     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3147
3148     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3149     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3150     extended iterable unpacking (PEP 3132) and additional unpacking
3151     generalizations (PEP 448).
3152     """
3153     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3154         return False
3155
3156     p = leaf.parent
3157     if p.type == syms.star_expr:
3158         # Star expressions are also used as assignment targets in extended
3159         # iterable unpacking (PEP 3132).  See what its parent is instead.
3160         if not p.parent:
3161             return False
3162
3163         p = p.parent
3164
3165     return p.type in within
3166
3167
3168 def is_multiline_string(leaf: Leaf) -> bool:
3169     """Return True if `leaf` is a multiline string that actually spans many lines."""
3170     value = leaf.value.lstrip("furbFURB")
3171     return value[:3] in {'"""', "'''"} and "\n" in value
3172
3173
3174 def is_stub_suite(node: Node) -> bool:
3175     """Return True if `node` is a suite with a stub body."""
3176     if (
3177         len(node.children) != 4
3178         or node.children[0].type != token.NEWLINE
3179         or node.children[1].type != token.INDENT
3180         or node.children[3].type != token.DEDENT
3181     ):
3182         return False
3183
3184     return is_stub_body(node.children[2])
3185
3186
3187 def is_stub_body(node: LN) -> bool:
3188     """Return True if `node` is a simple statement containing an ellipsis."""
3189     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3190         return False
3191
3192     if len(node.children) != 2:
3193         return False
3194
3195     child = node.children[0]
3196     return (
3197         child.type == syms.atom
3198         and len(child.children) == 3
3199         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3200     )
3201
3202
3203 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3204     """Return maximum delimiter priority inside `node`.
3205
3206     This is specific to atoms with contents contained in a pair of parentheses.
3207     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3208     """
3209     if node.type != syms.atom:
3210         return 0
3211
3212     first = node.children[0]
3213     last = node.children[-1]
3214     if not (first.type == token.LPAR and last.type == token.RPAR):
3215         return 0
3216
3217     bt = BracketTracker()
3218     for c in node.children[1:-1]:
3219         if isinstance(c, Leaf):
3220             bt.mark(c)
3221         else:
3222             for leaf in c.leaves():
3223                 bt.mark(leaf)
3224     try:
3225         return bt.max_delimiter_priority()
3226
3227     except ValueError:
3228         return 0
3229
3230
3231 def ensure_visible(leaf: Leaf) -> None:
3232     """Make sure parentheses are visible.
3233
3234     They could be invisible as part of some statements (see
3235     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3236     """
3237     if leaf.type == token.LPAR:
3238         leaf.value = "("
3239     elif leaf.type == token.RPAR:
3240         leaf.value = ")"
3241
3242
3243 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3244     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3245
3246     if not (
3247         opening_bracket.parent
3248         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3249         and opening_bracket.value in "[{("
3250     ):
3251         return False
3252
3253     try:
3254         last_leaf = line.leaves[-1]
3255         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3256         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3257     except (IndexError, ValueError):
3258         return False
3259
3260     return max_priority == COMMA_PRIORITY
3261
3262
3263 def get_features_used(node: Node) -> Set[Feature]:
3264     """Return a set of (relatively) new Python features used in this file.
3265
3266     Currently looking for:
3267     - f-strings;
3268     - underscores in numeric literals;
3269     - trailing commas after * or ** in function signatures and calls;
3270     - positional only arguments in function signatures and lambdas;
3271     """
3272     features: Set[Feature] = set()
3273     for n in node.pre_order():
3274         if n.type == token.STRING:
3275             value_head = n.value[:2]  # type: ignore
3276             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3277                 features.add(Feature.F_STRINGS)
3278
3279         elif n.type == token.NUMBER:
3280             if "_" in n.value:  # type: ignore
3281                 features.add(Feature.NUMERIC_UNDERSCORES)
3282
3283         elif n.type == token.SLASH:
3284             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3285                 features.add(Feature.POS_ONLY_ARGUMENTS)
3286
3287         elif n.type == token.COLONEQUAL:
3288             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3289
3290         elif (
3291             n.type in {syms.typedargslist, syms.arglist}
3292             and n.children
3293             and n.children[-1].type == token.COMMA
3294         ):
3295             if n.type == syms.typedargslist:
3296                 feature = Feature.TRAILING_COMMA_IN_DEF
3297             else:
3298                 feature = Feature.TRAILING_COMMA_IN_CALL
3299
3300             for ch in n.children:
3301                 if ch.type in STARS:
3302                     features.add(feature)
3303
3304                 if ch.type == syms.argument:
3305                     for argch in ch.children:
3306                         if argch.type in STARS:
3307                             features.add(feature)
3308
3309     return features
3310
3311
3312 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3313     """Detect the version to target based on the nodes used."""
3314     features = get_features_used(node)
3315     return {
3316         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3317     }
3318
3319
3320 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3321     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3322
3323     Brackets can be omitted if the entire trailer up to and including
3324     a preceding closing bracket fits in one line.
3325
3326     Yielded sets are cumulative (contain results of previous yields, too).  First
3327     set is empty.
3328     """
3329
3330     omit: Set[LeafID] = set()
3331     yield omit
3332
3333     length = 4 * line.depth
3334     opening_bracket = None
3335     closing_bracket = None
3336     inner_brackets: Set[LeafID] = set()
3337     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3338         length += leaf_length
3339         if length > line_length:
3340             break
3341
3342         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3343         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3344             break
3345
3346         if opening_bracket:
3347             if leaf is opening_bracket:
3348                 opening_bracket = None
3349             elif leaf.type in CLOSING_BRACKETS:
3350                 inner_brackets.add(id(leaf))
3351         elif leaf.type in CLOSING_BRACKETS:
3352             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3353                 # Empty brackets would fail a split so treat them as "inner"
3354                 # brackets (e.g. only add them to the `omit` set if another
3355                 # pair of brackets was good enough.
3356                 inner_brackets.add(id(leaf))
3357                 continue
3358
3359             if closing_bracket:
3360                 omit.add(id(closing_bracket))
3361                 omit.update(inner_brackets)
3362                 inner_brackets.clear()
3363                 yield omit
3364
3365             if leaf.value:
3366                 opening_bracket = leaf.opening_bracket
3367                 closing_bracket = leaf
3368
3369
3370 def get_future_imports(node: Node) -> Set[str]:
3371     """Return a set of __future__ imports in the file."""
3372     imports: Set[str] = set()
3373
3374     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3375         for child in children:
3376             if isinstance(child, Leaf):
3377                 if child.type == token.NAME:
3378                     yield child.value
3379             elif child.type == syms.import_as_name:
3380                 orig_name = child.children[0]
3381                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3382                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3383                 yield orig_name.value
3384             elif child.type == syms.import_as_names:
3385                 yield from get_imports_from_children(child.children)
3386             else:
3387                 raise AssertionError("Invalid syntax parsing imports")
3388
3389     for child in node.children:
3390         if child.type != syms.simple_stmt:
3391             break
3392         first_child = child.children[0]
3393         if isinstance(first_child, Leaf):
3394             # Continue looking if we see a docstring; otherwise stop.
3395             if (
3396                 len(child.children) == 2
3397                 and first_child.type == token.STRING
3398                 and child.children[1].type == token.NEWLINE
3399             ):
3400                 continue
3401             else:
3402                 break
3403         elif first_child.type == syms.import_from:
3404             module_name = first_child.children[1]
3405             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3406                 break
3407             imports |= set(get_imports_from_children(first_child.children[3:]))
3408         else:
3409             break
3410     return imports
3411
3412
3413 def gen_python_files_in_dir(
3414     path: Path,
3415     root: Path,
3416     include: Pattern[str],
3417     exclude: Pattern[str],
3418     report: "Report",
3419 ) -> Iterator[Path]:
3420     """Generate all files under `path` whose paths are not excluded by the
3421     `exclude` regex, but are included by the `include` regex.
3422
3423     Symbolic links pointing outside of the `root` directory are ignored.
3424
3425     `report` is where output about exclusions goes.
3426     """
3427     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3428     for child in path.iterdir():
3429         try:
3430             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3431         except ValueError:
3432             if child.is_symlink():
3433                 report.path_ignored(
3434                     child, f"is a symbolic link that points outside {root}"
3435                 )
3436                 continue
3437
3438             raise
3439
3440         if child.is_dir():
3441             normalized_path += "/"
3442         exclude_match = exclude.search(normalized_path)
3443         if exclude_match and exclude_match.group(0):
3444             report.path_ignored(child, f"matches the --exclude regular expression")
3445             continue
3446
3447         if child.is_dir():
3448             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3449
3450         elif child.is_file():
3451             include_match = include.search(normalized_path)
3452             if include_match:
3453                 yield child
3454
3455
3456 @lru_cache()
3457 def find_project_root(srcs: Iterable[str]) -> Path:
3458     """Return a directory containing .git, .hg, or pyproject.toml.
3459
3460     That directory can be one of the directories passed in `srcs` or their
3461     common parent.
3462
3463     If no directory in the tree contains a marker that would specify it's the
3464     project root, the root of the file system is returned.
3465     """
3466     if not srcs:
3467         return Path("/").resolve()
3468
3469     common_base = min(Path(src).resolve() for src in srcs)
3470     if common_base.is_dir():
3471         # Append a fake file so `parents` below returns `common_base_dir`, too.
3472         common_base /= "fake-file"
3473     for directory in common_base.parents:
3474         if (directory / ".git").is_dir():
3475             return directory
3476
3477         if (directory / ".hg").is_dir():
3478             return directory
3479
3480         if (directory / "pyproject.toml").is_file():
3481             return directory
3482
3483     return directory
3484
3485
3486 @dataclass
3487 class Report:
3488     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3489
3490     check: bool = False
3491     quiet: bool = False
3492     verbose: bool = False
3493     change_count: int = 0
3494     same_count: int = 0
3495     failure_count: int = 0
3496
3497     def done(self, src: Path, changed: Changed) -> None:
3498         """Increment the counter for successful reformatting. Write out a message."""
3499         if changed is Changed.YES:
3500             reformatted = "would reformat" if self.check else "reformatted"
3501             if self.verbose or not self.quiet:
3502                 out(f"{reformatted} {src}")
3503             self.change_count += 1
3504         else:
3505             if self.verbose:
3506                 if changed is Changed.NO:
3507                     msg = f"{src} already well formatted, good job."
3508                 else:
3509                     msg = f"{src} wasn't modified on disk since last run."
3510                 out(msg, bold=False)
3511             self.same_count += 1
3512
3513     def failed(self, src: Path, message: str) -> None:
3514         """Increment the counter for failed reformatting. Write out a message."""
3515         err(f"error: cannot format {src}: {message}")
3516         self.failure_count += 1
3517
3518     def path_ignored(self, path: Path, message: str) -> None:
3519         if self.verbose:
3520             out(f"{path} ignored: {message}", bold=False)
3521
3522     @property
3523     def return_code(self) -> int:
3524         """Return the exit code that the app should use.
3525
3526         This considers the current state of changed files and failures:
3527         - if there were any failures, return 123;
3528         - if any files were changed and --check is being used, return 1;
3529         - otherwise return 0.
3530         """
3531         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3532         # 126 we have special return codes reserved by the shell.
3533         if self.failure_count:
3534             return 123
3535
3536         elif self.change_count and self.check:
3537             return 1
3538
3539         return 0
3540
3541     def __str__(self) -> str:
3542         """Render a color report of the current state.
3543
3544         Use `click.unstyle` to remove colors.
3545         """
3546         if self.check:
3547             reformatted = "would be reformatted"
3548             unchanged = "would be left unchanged"
3549             failed = "would fail to reformat"
3550         else:
3551             reformatted = "reformatted"
3552             unchanged = "left unchanged"
3553             failed = "failed to reformat"
3554         report = []
3555         if self.change_count:
3556             s = "s" if self.change_count > 1 else ""
3557             report.append(
3558                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3559             )
3560         if self.same_count:
3561             s = "s" if self.same_count > 1 else ""
3562             report.append(f"{self.same_count} file{s} {unchanged}")
3563         if self.failure_count:
3564             s = "s" if self.failure_count > 1 else ""
3565             report.append(
3566                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3567             )
3568         return ", ".join(report) + "."
3569
3570
3571 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3572     filename = "<unknown>"
3573     if sys.version_info >= (3, 8):
3574         # TODO: support Python 4+ ;)
3575         for minor_version in range(sys.version_info[1], 4, -1):
3576             try:
3577                 return ast.parse(src, filename, feature_version=(3, minor_version))
3578             except SyntaxError:
3579                 continue
3580     else:
3581         for feature_version in (7, 6):
3582             try:
3583                 return ast3.parse(src, filename, feature_version=feature_version)
3584             except SyntaxError:
3585                 continue
3586
3587     return ast27.parse(src)
3588
3589
3590 def _fixup_ast_constants(
3591     node: Union[ast.AST, ast3.AST, ast27.AST]
3592 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3593     """Map ast nodes deprecated in 3.8 to Constant."""
3594     # casts are required until this is released:
3595     # https://github.com/python/typeshed/pull/3142
3596     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3597         return cast(ast.AST, ast.Constant(value=node.s))
3598     elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3599         return cast(ast.AST, ast.Constant(value=node.n))
3600     elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3601         return cast(ast.AST, ast.Constant(value=node.value))
3602     return node
3603
3604
3605 def assert_equivalent(src: str, dst: str) -> None:
3606     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3607
3608     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3609         """Simple visitor generating strings to compare ASTs by content."""
3610
3611         node = _fixup_ast_constants(node)
3612
3613         yield f"{'  ' * depth}{node.__class__.__name__}("
3614
3615         for field in sorted(node._fields):
3616             # TypeIgnore has only one field 'lineno' which breaks this comparison
3617             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3618             if sys.version_info >= (3, 8):
3619                 type_ignore_classes += (ast.TypeIgnore,)
3620             if isinstance(node, type_ignore_classes):
3621                 break
3622
3623             try:
3624                 value = getattr(node, field)
3625             except AttributeError:
3626                 continue
3627
3628             yield f"{'  ' * (depth+1)}{field}="
3629
3630             if isinstance(value, list):
3631                 for item in value:
3632                     # Ignore nested tuples within del statements, because we may insert
3633                     # parentheses and they change the AST.
3634                     if (
3635                         field == "targets"
3636                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3637                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3638                     ):
3639                         for item in item.elts:
3640                             yield from _v(item, depth + 2)
3641                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3642                         yield from _v(item, depth + 2)
3643
3644             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3645                 yield from _v(value, depth + 2)
3646
3647             else:
3648                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3649
3650         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3651
3652     try:
3653         src_ast = parse_ast(src)
3654     except Exception as exc:
3655         raise AssertionError(
3656             f"cannot use --safe with this file; failed to parse source file.  "
3657             f"AST error message: {exc}"
3658         )
3659
3660     try:
3661         dst_ast = parse_ast(dst)
3662     except Exception as exc:
3663         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3664         raise AssertionError(
3665             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3666             f"Please report a bug on https://github.com/psf/black/issues.  "
3667             f"This invalid output might be helpful: {log}"
3668         ) from None
3669
3670     src_ast_str = "\n".join(_v(src_ast))
3671     dst_ast_str = "\n".join(_v(dst_ast))
3672     if src_ast_str != dst_ast_str:
3673         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3674         raise AssertionError(
3675             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3676             f"the source.  "
3677             f"Please report a bug on https://github.com/psf/black/issues.  "
3678             f"This diff might be helpful: {log}"
3679         ) from None
3680
3681
3682 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3683     """Raise AssertionError if `dst` reformats differently the second time."""
3684     newdst = format_str(dst, mode=mode)
3685     if dst != newdst:
3686         log = dump_to_file(
3687             diff(src, dst, "source", "first pass"),
3688             diff(dst, newdst, "first pass", "second pass"),
3689         )
3690         raise AssertionError(
3691             f"INTERNAL ERROR: Black produced different code on the second pass "
3692             f"of the formatter.  "
3693             f"Please report a bug on https://github.com/psf/black/issues.  "
3694             f"This diff might be helpful: {log}"
3695         ) from None
3696
3697
3698 def dump_to_file(*output: str) -> str:
3699     """Dump `output` to a temporary file. Return path to the file."""
3700     with tempfile.NamedTemporaryFile(
3701         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3702     ) as f:
3703         for lines in output:
3704             f.write(lines)
3705             if lines and lines[-1] != "\n":
3706                 f.write("\n")
3707     return f.name
3708
3709
3710 @contextmanager
3711 def nullcontext() -> Iterator[None]:
3712     """Return context manager that does nothing.
3713     Similar to `nullcontext` from python 3.7"""
3714     yield
3715
3716
3717 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3718     """Return a unified diff string between strings `a` and `b`."""
3719     import difflib
3720
3721     a_lines = [line + "\n" for line in a.split("\n")]
3722     b_lines = [line + "\n" for line in b.split("\n")]
3723     return "".join(
3724         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3725     )
3726
3727
3728 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3729     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3730     err("Aborted!")
3731     for task in tasks:
3732         task.cancel()
3733
3734
3735 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3736     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3737     try:
3738         if sys.version_info[:2] >= (3, 7):
3739             all_tasks = asyncio.all_tasks
3740         else:
3741             all_tasks = asyncio.Task.all_tasks
3742         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3743         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3744         if not to_cancel:
3745             return
3746
3747         for task in to_cancel:
3748             task.cancel()
3749         loop.run_until_complete(
3750             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3751         )
3752     finally:
3753         # `concurrent.futures.Future` objects cannot be cancelled once they
3754         # are already running. There might be some when the `shutdown()` happened.
3755         # Silence their logger's spew about the event loop being closed.
3756         cf_logger = logging.getLogger("concurrent.futures")
3757         cf_logger.setLevel(logging.CRITICAL)
3758         loop.close()
3759
3760
3761 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3762     """Replace `regex` with `replacement` twice on `original`.
3763
3764     This is used by string normalization to perform replaces on
3765     overlapping matches.
3766     """
3767     return regex.sub(replacement, regex.sub(replacement, original))
3768
3769
3770 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3771     """Compile a regular expression string in `regex`.
3772
3773     If it contains newlines, use verbose mode.
3774     """
3775     if "\n" in regex:
3776         regex = "(?x)" + regex
3777     return re.compile(regex)
3778
3779
3780 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3781     """Like `reversed(enumerate(sequence))` if that were possible."""
3782     index = len(sequence) - 1
3783     for element in reversed(sequence):
3784         yield (index, element)
3785         index -= 1
3786
3787
3788 def enumerate_with_length(
3789     line: Line, reversed: bool = False
3790 ) -> Iterator[Tuple[Index, Leaf, int]]:
3791     """Return an enumeration of leaves with their length.
3792
3793     Stops prematurely on multiline strings and standalone comments.
3794     """
3795     op = cast(
3796         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3797         enumerate_reversed if reversed else enumerate,
3798     )
3799     for index, leaf in op(line.leaves):
3800         length = len(leaf.prefix) + len(leaf.value)
3801         if "\n" in leaf.value:
3802             return  # Multiline strings, we can't continue.
3803
3804         for comment in line.comments_after(leaf):
3805             length += len(comment.value)
3806
3807         yield index, leaf, length
3808
3809
3810 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3811     """Return True if `line` is no longer than `line_length`.
3812
3813     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3814     """
3815     if not line_str:
3816         line_str = str(line).strip("\n")
3817     return (
3818         len(line_str) <= line_length
3819         and "\n" not in line_str  # multiline strings
3820         and not line.contains_standalone_comments()
3821     )
3822
3823
3824 def can_be_split(line: Line) -> bool:
3825     """Return False if the line cannot be split *for sure*.
3826
3827     This is not an exhaustive search but a cheap heuristic that we can use to
3828     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3829     in unnecessary parentheses).
3830     """
3831     leaves = line.leaves
3832     if len(leaves) < 2:
3833         return False
3834
3835     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3836         call_count = 0
3837         dot_count = 0
3838         next = leaves[-1]
3839         for leaf in leaves[-2::-1]:
3840             if leaf.type in OPENING_BRACKETS:
3841                 if next.type not in CLOSING_BRACKETS:
3842                     return False
3843
3844                 call_count += 1
3845             elif leaf.type == token.DOT:
3846                 dot_count += 1
3847             elif leaf.type == token.NAME:
3848                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3849                     return False
3850
3851             elif leaf.type not in CLOSING_BRACKETS:
3852                 return False
3853
3854             if dot_count > 1 and call_count > 1:
3855                 return False
3856
3857     return True
3858
3859
3860 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3861     """Does `line` have a shape safe to reformat without optional parens around it?
3862
3863     Returns True for only a subset of potentially nice looking formattings but
3864     the point is to not return false positives that end up producing lines that
3865     are too long.
3866     """
3867     bt = line.bracket_tracker
3868     if not bt.delimiters:
3869         # Without delimiters the optional parentheses are useless.
3870         return True
3871
3872     max_priority = bt.max_delimiter_priority()
3873     if bt.delimiter_count_with_priority(max_priority) > 1:
3874         # With more than one delimiter of a kind the optional parentheses read better.
3875         return False
3876
3877     if max_priority == DOT_PRIORITY:
3878         # A single stranded method call doesn't require optional parentheses.
3879         return True
3880
3881     assert len(line.leaves) >= 2, "Stranded delimiter"
3882
3883     first = line.leaves[0]
3884     second = line.leaves[1]
3885     penultimate = line.leaves[-2]
3886     last = line.leaves[-1]
3887
3888     # With a single delimiter, omit if the expression starts or ends with
3889     # a bracket.
3890     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3891         remainder = False
3892         length = 4 * line.depth
3893         for _index, leaf, leaf_length in enumerate_with_length(line):
3894             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3895                 remainder = True
3896             if remainder:
3897                 length += leaf_length
3898                 if length > line_length:
3899                     break
3900
3901                 if leaf.type in OPENING_BRACKETS:
3902                     # There are brackets we can further split on.
3903                     remainder = False
3904
3905         else:
3906             # checked the entire string and line length wasn't exceeded
3907             if len(line.leaves) == _index + 1:
3908                 return True
3909
3910         # Note: we are not returning False here because a line might have *both*
3911         # a leading opening bracket and a trailing closing bracket.  If the
3912         # opening bracket doesn't match our rule, maybe the closing will.
3913
3914     if (
3915         last.type == token.RPAR
3916         or last.type == token.RBRACE
3917         or (
3918             # don't use indexing for omitting optional parentheses;
3919             # it looks weird
3920             last.type == token.RSQB
3921             and last.parent
3922             and last.parent.type != syms.trailer
3923         )
3924     ):
3925         if penultimate.type in OPENING_BRACKETS:
3926             # Empty brackets don't help.
3927             return False
3928
3929         if is_multiline_string(first):
3930             # Additional wrapping of a multiline string in this situation is
3931             # unnecessary.
3932             return True
3933
3934         length = 4 * line.depth
3935         seen_other_brackets = False
3936         for _index, leaf, leaf_length in enumerate_with_length(line):
3937             length += leaf_length
3938             if leaf is last.opening_bracket:
3939                 if seen_other_brackets or length <= line_length:
3940                     return True
3941
3942             elif leaf.type in OPENING_BRACKETS:
3943                 # There are brackets we can further split on.
3944                 seen_other_brackets = True
3945
3946     return False
3947
3948
3949 def get_cache_file(mode: FileMode) -> Path:
3950     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3951
3952
3953 def read_cache(mode: FileMode) -> Cache:
3954     """Read the cache if it exists and is well formed.
3955
3956     If it is not well formed, the call to write_cache later should resolve the issue.
3957     """
3958     cache_file = get_cache_file(mode)
3959     if not cache_file.exists():
3960         return {}
3961
3962     with cache_file.open("rb") as fobj:
3963         try:
3964             cache: Cache = pickle.load(fobj)
3965         except pickle.UnpicklingError:
3966             return {}
3967
3968     return cache
3969
3970
3971 def get_cache_info(path: Path) -> CacheInfo:
3972     """Return the information used to check if a file is already formatted or not."""
3973     stat = path.stat()
3974     return stat.st_mtime, stat.st_size
3975
3976
3977 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3978     """Split an iterable of paths in `sources` into two sets.
3979
3980     The first contains paths of files that modified on disk or are not in the
3981     cache. The other contains paths to non-modified files.
3982     """
3983     todo, done = set(), set()
3984     for src in sources:
3985         src = src.resolve()
3986         if cache.get(src) != get_cache_info(src):
3987             todo.add(src)
3988         else:
3989             done.add(src)
3990     return todo, done
3991
3992
3993 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3994     """Update the cache file."""
3995     cache_file = get_cache_file(mode)
3996     try:
3997         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3998         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3999         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4000             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
4001         os.replace(f.name, cache_file)
4002     except OSError:
4003         pass
4004
4005
4006 def patch_click() -> None:
4007     """Make Click not crash.
4008
4009     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4010     default which restricts paths that it can access during the lifetime of the
4011     application.  Click refuses to work in this scenario by raising a RuntimeError.
4012
4013     In case of Black the likelihood that non-ASCII characters are going to be used in
4014     file paths is minimal since it's Python source code.  Moreover, this crash was
4015     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4016     """
4017     try:
4018         from click import core
4019         from click import _unicodefun  # type: ignore
4020     except ModuleNotFoundError:
4021         return
4022
4023     for module in (core, _unicodefun):
4024         if hasattr(module, "_verify_python3_env"):
4025             module._verify_python3_env = lambda: None
4026
4027
4028 def patched_main() -> None:
4029     freeze_support()
4030     patch_click()
4031     main()
4032
4033
4034 if __name__ == "__main__":
4035     patched_main()