]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

7629d9f4c2d23f09e45b16fdf83eb82f3bbf99a1
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43 from typed_ast import ast3, ast27
44
45 # lib2to3 fork
46 from blib2to3.pytree import Node, Leaf, type_repr
47 from blib2to3 import pygram, pytree
48 from blib2to3.pgen2 import driver, token
49 from blib2to3.pgen2.grammar import Grammar
50 from blib2to3.pgen2.parse import ParseError
51
52
53 __version__ = "19.3b0"
54 DEFAULT_LINE_LENGTH = 88
55 DEFAULT_EXCLUDES = (
56     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
57 )
58 DEFAULT_INCLUDES = r"\.pyi?$"
59 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
60
61
62 # types
63 FileContent = str
64 Encoding = str
65 NewLine = str
66 Depth = int
67 NodeType = int
68 LeafID = int
69 Priority = int
70 Index = int
71 LN = Union[Leaf, Node]
72 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
73 Timestamp = float
74 FileSize = int
75 CacheInfo = Tuple[Timestamp, FileSize]
76 Cache = Dict[Path, CacheInfo]
77 out = partial(click.secho, bold=True, err=True)
78 err = partial(click.secho, fg="red", err=True)
79
80 pygram.initialize(CACHE_DIR)
81 syms = pygram.python_symbols
82
83
84 class NothingChanged(UserWarning):
85     """Raised when reformatted code is the same as source."""
86
87
88 class CannotSplit(Exception):
89     """A readable split that fits the allotted line length is impossible."""
90
91
92 class InvalidInput(ValueError):
93     """Raised when input source code fails all parse attempts."""
94
95
96 class WriteBack(Enum):
97     NO = 0
98     YES = 1
99     DIFF = 2
100     CHECK = 3
101
102     @classmethod
103     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
104         if check and not diff:
105             return cls.CHECK
106
107         return cls.DIFF if diff else cls.YES
108
109
110 class Changed(Enum):
111     NO = 0
112     CACHED = 1
113     YES = 2
114
115
116 class TargetVersion(Enum):
117     PY27 = 2
118     PY33 = 3
119     PY34 = 4
120     PY35 = 5
121     PY36 = 6
122     PY37 = 7
123     PY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.PY27
127
128
129 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA_IN_CALL = 4
138     TRAILING_COMMA_IN_DEF = 5
139     # The following two feature-flags are mutually exclusive, and exactly one should be
140     # set for every version of python.
141     ASYNC_IDENTIFIERS = 6
142     ASYNC_KEYWORDS = 7
143
144
145 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
146     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
147     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
148     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
149     TargetVersion.PY35: {
150         Feature.UNICODE_LITERALS,
151         Feature.TRAILING_COMMA_IN_CALL,
152         Feature.ASYNC_IDENTIFIERS,
153     },
154     TargetVersion.PY36: {
155         Feature.UNICODE_LITERALS,
156         Feature.F_STRINGS,
157         Feature.NUMERIC_UNDERSCORES,
158         Feature.TRAILING_COMMA_IN_CALL,
159         Feature.TRAILING_COMMA_IN_DEF,
160         Feature.ASYNC_IDENTIFIERS,
161     },
162     TargetVersion.PY37: {
163         Feature.UNICODE_LITERALS,
164         Feature.F_STRINGS,
165         Feature.NUMERIC_UNDERSCORES,
166         Feature.TRAILING_COMMA_IN_CALL,
167         Feature.TRAILING_COMMA_IN_DEF,
168         Feature.ASYNC_KEYWORDS,
169     },
170     TargetVersion.PY38: {
171         Feature.UNICODE_LITERALS,
172         Feature.F_STRINGS,
173         Feature.NUMERIC_UNDERSCORES,
174         Feature.TRAILING_COMMA_IN_CALL,
175         Feature.TRAILING_COMMA_IN_DEF,
176         Feature.ASYNC_KEYWORDS,
177     },
178 }
179
180
181 @dataclass
182 class FileMode:
183     target_versions: Set[TargetVersion] = Factory(set)
184     line_length: int = DEFAULT_LINE_LENGTH
185     string_normalization: bool = True
186     is_pyi: bool = False
187
188     def get_cache_key(self) -> str:
189         if self.target_versions:
190             version_str = ",".join(
191                 str(version.value)
192                 for version in sorted(self.target_versions, key=lambda v: v.value)
193             )
194         else:
195             version_str = "-"
196         parts = [
197             version_str,
198             str(self.line_length),
199             str(int(self.string_normalization)),
200             str(int(self.is_pyi)),
201         ]
202         return ".".join(parts)
203
204
205 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
206     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
207
208
209 def read_pyproject_toml(
210     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
211 ) -> Optional[str]:
212     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
213
214     Returns the path to a successfully found and read configuration file, None
215     otherwise.
216     """
217     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
218     if not value:
219         root = find_project_root(ctx.params.get("src", ()))
220         path = root / "pyproject.toml"
221         if path.is_file():
222             value = str(path)
223         else:
224             return None
225
226     try:
227         pyproject_toml = toml.load(value)
228         config = pyproject_toml.get("tool", {}).get("black", {})
229     except (toml.TomlDecodeError, OSError) as e:
230         raise click.FileError(
231             filename=value, hint=f"Error reading configuration file: {e}"
232         )
233
234     if not config:
235         return None
236
237     if ctx.default_map is None:
238         ctx.default_map = {}
239     ctx.default_map.update(  # type: ignore  # bad types in .pyi
240         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
241     )
242     return value
243
244
245 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
246 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
247 @click.option(
248     "-l",
249     "--line-length",
250     type=int,
251     default=DEFAULT_LINE_LENGTH,
252     help="How many characters per line to allow.",
253     show_default=True,
254 )
255 @click.option(
256     "-t",
257     "--target-version",
258     type=click.Choice([v.name.lower() for v in TargetVersion]),
259     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
260     multiple=True,
261     help=(
262         "Python versions that should be supported by Black's output. [default: "
263         "per-file auto-detection]"
264     ),
265 )
266 @click.option(
267     "--py36",
268     is_flag=True,
269     help=(
270         "Allow using Python 3.6-only syntax on all input files.  This will put "
271         "trailing commas in function signatures and calls also after *args and "
272         "**kwargs. Deprecated; use --target-version instead. "
273         "[default: per-file auto-detection]"
274     ),
275 )
276 @click.option(
277     "--pyi",
278     is_flag=True,
279     help=(
280         "Format all input files like typing stubs regardless of file extension "
281         "(useful when piping source on standard input)."
282     ),
283 )
284 @click.option(
285     "-S",
286     "--skip-string-normalization",
287     is_flag=True,
288     help="Don't normalize string quotes or prefixes.",
289 )
290 @click.option(
291     "--check",
292     is_flag=True,
293     help=(
294         "Don't write the files back, just return the status.  Return code 0 "
295         "means nothing would change.  Return code 1 means some files would be "
296         "reformatted.  Return code 123 means there was an internal error."
297     ),
298 )
299 @click.option(
300     "--diff",
301     is_flag=True,
302     help="Don't write the files back, just output a diff for each file on stdout.",
303 )
304 @click.option(
305     "--fast/--safe",
306     is_flag=True,
307     help="If --fast given, skip temporary sanity checks. [default: --safe]",
308 )
309 @click.option(
310     "--include",
311     type=str,
312     default=DEFAULT_INCLUDES,
313     help=(
314         "A regular expression that matches files and directories that should be "
315         "included on recursive searches.  An empty value means all files are "
316         "included regardless of the name.  Use forward slashes for directories on "
317         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
318         "later."
319     ),
320     show_default=True,
321 )
322 @click.option(
323     "--exclude",
324     type=str,
325     default=DEFAULT_EXCLUDES,
326     help=(
327         "A regular expression that matches files and directories that should be "
328         "excluded on recursive searches.  An empty value means no paths are excluded. "
329         "Use forward slashes for directories on all platforms (Windows, too).  "
330         "Exclusions are calculated first, inclusions later."
331     ),
332     show_default=True,
333 )
334 @click.option(
335     "-q",
336     "--quiet",
337     is_flag=True,
338     help=(
339         "Don't emit non-error messages to stderr. Errors are still emitted, "
340         "silence those with 2>/dev/null."
341     ),
342 )
343 @click.option(
344     "-v",
345     "--verbose",
346     is_flag=True,
347     help=(
348         "Also emit messages to stderr about files that were not changed or were "
349         "ignored due to --exclude=."
350     ),
351 )
352 @click.version_option(version=__version__)
353 @click.argument(
354     "src",
355     nargs=-1,
356     type=click.Path(
357         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
358     ),
359     is_eager=True,
360 )
361 @click.option(
362     "--config",
363     type=click.Path(
364         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
365     ),
366     is_eager=True,
367     callback=read_pyproject_toml,
368     help="Read configuration from PATH.",
369 )
370 @click.pass_context
371 def main(
372     ctx: click.Context,
373     code: Optional[str],
374     line_length: int,
375     target_version: List[TargetVersion],
376     check: bool,
377     diff: bool,
378     fast: bool,
379     pyi: bool,
380     py36: bool,
381     skip_string_normalization: bool,
382     quiet: bool,
383     verbose: bool,
384     include: str,
385     exclude: str,
386     src: Tuple[str],
387     config: Optional[str],
388 ) -> None:
389     """The uncompromising code formatter."""
390     write_back = WriteBack.from_configuration(check=check, diff=diff)
391     if target_version:
392         if py36:
393             err(f"Cannot use both --target-version and --py36")
394             ctx.exit(2)
395         else:
396             versions = set(target_version)
397     elif py36:
398         err(
399             "--py36 is deprecated and will be removed in a future version. "
400             "Use --target-version py36 instead."
401         )
402         versions = PY36_VERSIONS
403     else:
404         # We'll autodetect later.
405         versions = set()
406     mode = FileMode(
407         target_versions=versions,
408         line_length=line_length,
409         is_pyi=pyi,
410         string_normalization=not skip_string_normalization,
411     )
412     if config and verbose:
413         out(f"Using configuration from {config}.", bold=False, fg="blue")
414     if code is not None:
415         print(format_str(code, mode=mode))
416         ctx.exit(0)
417     try:
418         include_regex = re_compile_maybe_verbose(include)
419     except re.error:
420         err(f"Invalid regular expression for include given: {include!r}")
421         ctx.exit(2)
422     try:
423         exclude_regex = re_compile_maybe_verbose(exclude)
424     except re.error:
425         err(f"Invalid regular expression for exclude given: {exclude!r}")
426         ctx.exit(2)
427     report = Report(check=check, quiet=quiet, verbose=verbose)
428     root = find_project_root(src)
429     sources: Set[Path] = set()
430     for s in src:
431         p = Path(s)
432         if p.is_dir():
433             sources.update(
434                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
435             )
436         elif p.is_file() or s == "-":
437             # if a file was explicitly given, we don't care about its extension
438             sources.add(p)
439         else:
440             err(f"invalid path: {s}")
441     if len(sources) == 0:
442         if verbose or not quiet:
443             out("No paths given. Nothing to do 😴")
444         ctx.exit(0)
445
446     if len(sources) == 1:
447         reformat_one(
448             src=sources.pop(),
449             fast=fast,
450             write_back=write_back,
451             mode=mode,
452             report=report,
453         )
454     else:
455         reformat_many(
456             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
457         )
458
459     if verbose or not quiet:
460         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
461         click.secho(str(report), err=True)
462     ctx.exit(report.return_code)
463
464
465 def reformat_one(
466     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
467 ) -> None:
468     """Reformat a single file under `src` without spawning child processes.
469
470     If `quiet` is True, non-error messages are not output. `line_length`,
471     `write_back`, `fast` and `pyi` options are passed to
472     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
473     """
474     try:
475         changed = Changed.NO
476         if not src.is_file() and str(src) == "-":
477             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
478                 changed = Changed.YES
479         else:
480             cache: Cache = {}
481             if write_back != WriteBack.DIFF:
482                 cache = read_cache(mode)
483                 res_src = src.resolve()
484                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
485                     changed = Changed.CACHED
486             if changed is not Changed.CACHED and format_file_in_place(
487                 src, fast=fast, write_back=write_back, mode=mode
488             ):
489                 changed = Changed.YES
490             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
491                 write_back is WriteBack.CHECK and changed is Changed.NO
492             ):
493                 write_cache(cache, [src], mode)
494         report.done(src, changed)
495     except Exception as exc:
496         report.failed(src, str(exc))
497
498
499 def reformat_many(
500     sources: Set[Path],
501     fast: bool,
502     write_back: WriteBack,
503     mode: FileMode,
504     report: "Report",
505 ) -> None:
506     """Reformat multiple files using a ProcessPoolExecutor."""
507     loop = asyncio.get_event_loop()
508     worker_count = os.cpu_count()
509     if sys.platform == "win32":
510         # Work around https://bugs.python.org/issue26903
511         worker_count = min(worker_count, 61)
512     executor = ProcessPoolExecutor(max_workers=worker_count)
513     try:
514         loop.run_until_complete(
515             schedule_formatting(
516                 sources=sources,
517                 fast=fast,
518                 write_back=write_back,
519                 mode=mode,
520                 report=report,
521                 loop=loop,
522                 executor=executor,
523             )
524         )
525     finally:
526         shutdown(loop)
527
528
529 async def schedule_formatting(
530     sources: Set[Path],
531     fast: bool,
532     write_back: WriteBack,
533     mode: FileMode,
534     report: "Report",
535     loop: BaseEventLoop,
536     executor: Executor,
537 ) -> None:
538     """Run formatting of `sources` in parallel using the provided `executor`.
539
540     (Use ProcessPoolExecutors for actual parallelism.)
541
542     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
543     :func:`format_file_in_place`.
544     """
545     cache: Cache = {}
546     if write_back != WriteBack.DIFF:
547         cache = read_cache(mode)
548         sources, cached = filter_cached(cache, sources)
549         for src in sorted(cached):
550             report.done(src, Changed.CACHED)
551     if not sources:
552         return
553
554     cancelled = []
555     sources_to_cache = []
556     lock = None
557     if write_back == WriteBack.DIFF:
558         # For diff output, we need locks to ensure we don't interleave output
559         # from different processes.
560         manager = Manager()
561         lock = manager.Lock()
562     tasks = {
563         asyncio.ensure_future(
564             loop.run_in_executor(
565                 executor, format_file_in_place, src, fast, mode, write_back, lock
566             )
567         ): src
568         for src in sorted(sources)
569     }
570     pending: Iterable[asyncio.Future] = tasks.keys()
571     try:
572         loop.add_signal_handler(signal.SIGINT, cancel, pending)
573         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
574     except NotImplementedError:
575         # There are no good alternatives for these on Windows.
576         pass
577     while pending:
578         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
579         for task in done:
580             src = tasks.pop(task)
581             if task.cancelled():
582                 cancelled.append(task)
583             elif task.exception():
584                 report.failed(src, str(task.exception()))
585             else:
586                 changed = Changed.YES if task.result() else Changed.NO
587                 # If the file was written back or was successfully checked as
588                 # well-formatted, store this information in the cache.
589                 if write_back is WriteBack.YES or (
590                     write_back is WriteBack.CHECK and changed is Changed.NO
591                 ):
592                     sources_to_cache.append(src)
593                 report.done(src, changed)
594     if cancelled:
595         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
596     if sources_to_cache:
597         write_cache(cache, sources_to_cache, mode)
598
599
600 def format_file_in_place(
601     src: Path,
602     fast: bool,
603     mode: FileMode,
604     write_back: WriteBack = WriteBack.NO,
605     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
606 ) -> bool:
607     """Format file under `src` path. Return True if changed.
608
609     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
610     code to the file.
611     `line_length` and `fast` options are passed to :func:`format_file_contents`.
612     """
613     if src.suffix == ".pyi":
614         mode = evolve(mode, is_pyi=True)
615
616     then = datetime.utcfromtimestamp(src.stat().st_mtime)
617     with open(src, "rb") as buf:
618         src_contents, encoding, newline = decode_bytes(buf.read())
619     try:
620         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
621     except NothingChanged:
622         return False
623
624     if write_back == write_back.YES:
625         with open(src, "w", encoding=encoding, newline=newline) as f:
626             f.write(dst_contents)
627     elif write_back == write_back.DIFF:
628         now = datetime.utcnow()
629         src_name = f"{src}\t{then} +0000"
630         dst_name = f"{src}\t{now} +0000"
631         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
632         if lock:
633             lock.acquire()
634         try:
635             f = io.TextIOWrapper(
636                 sys.stdout.buffer,
637                 encoding=encoding,
638                 newline=newline,
639                 write_through=True,
640             )
641             f.write(diff_contents)
642             f.detach()
643         finally:
644             if lock:
645                 lock.release()
646     return True
647
648
649 def format_stdin_to_stdout(
650     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
651 ) -> bool:
652     """Format file on stdin. Return True if changed.
653
654     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
655     write a diff to stdout. The `mode` argument is passed to
656     :func:`format_file_contents`.
657     """
658     then = datetime.utcnow()
659     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
660     dst = src
661     try:
662         dst = format_file_contents(src, fast=fast, mode=mode)
663         return True
664
665     except NothingChanged:
666         return False
667
668     finally:
669         f = io.TextIOWrapper(
670             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
671         )
672         if write_back == WriteBack.YES:
673             f.write(dst)
674         elif write_back == WriteBack.DIFF:
675             now = datetime.utcnow()
676             src_name = f"STDIN\t{then} +0000"
677             dst_name = f"STDOUT\t{now} +0000"
678             f.write(diff(src, dst, src_name, dst_name))
679         f.detach()
680
681
682 def format_file_contents(
683     src_contents: str, *, fast: bool, mode: FileMode
684 ) -> FileContent:
685     """Reformat contents a file and return new contents.
686
687     If `fast` is False, additionally confirm that the reformatted code is
688     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
689     `line_length` is passed to :func:`format_str`.
690     """
691     if src_contents.strip() == "":
692         raise NothingChanged
693
694     dst_contents = format_str(src_contents, mode=mode)
695     if src_contents == dst_contents:
696         raise NothingChanged
697
698     if not fast:
699         assert_equivalent(src_contents, dst_contents)
700         assert_stable(src_contents, dst_contents, mode=mode)
701     return dst_contents
702
703
704 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
705     """Reformat a string and return new contents.
706
707     `line_length` determines how many characters per line are allowed.
708     """
709     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
710     dst_contents = []
711     future_imports = get_future_imports(src_node)
712     if mode.target_versions:
713         versions = mode.target_versions
714     else:
715         versions = detect_target_versions(src_node)
716     normalize_fmt_off(src_node)
717     lines = LineGenerator(
718         remove_u_prefix="unicode_literals" in future_imports
719         or supports_feature(versions, Feature.UNICODE_LITERALS),
720         is_pyi=mode.is_pyi,
721         normalize_strings=mode.string_normalization,
722     )
723     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
724     empty_line = Line()
725     after = 0
726     split_line_features = {
727         feature
728         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
729         if supports_feature(versions, feature)
730     }
731     for current_line in lines.visit(src_node):
732         for _ in range(after):
733             dst_contents.append(str(empty_line))
734         before, after = elt.maybe_empty_lines(current_line)
735         for _ in range(before):
736             dst_contents.append(str(empty_line))
737         for line in split_line(
738             current_line, line_length=mode.line_length, features=split_line_features
739         ):
740             dst_contents.append(str(line))
741     return "".join(dst_contents)
742
743
744 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
745     """Return a tuple of (decoded_contents, encoding, newline).
746
747     `newline` is either CRLF or LF but `decoded_contents` is decoded with
748     universal newlines (i.e. only contains LF).
749     """
750     srcbuf = io.BytesIO(src)
751     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
752     if not lines:
753         return "", encoding, "\n"
754
755     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
756     srcbuf.seek(0)
757     with io.TextIOWrapper(srcbuf, encoding) as tiow:
758         return tiow.read(), encoding, newline
759
760
761 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
762     if not target_versions:
763         # No target_version specified, so try all grammars.
764         return [
765             # Python 3.7+
766             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
767             # Python 3.0-3.6
768             pygram.python_grammar_no_print_statement_no_exec_statement,
769             # Python 2.7 with future print_function import
770             pygram.python_grammar_no_print_statement,
771             # Python 2.7
772             pygram.python_grammar,
773         ]
774     elif all(version.is_python2() for version in target_versions):
775         # Python 2-only code, so try Python 2 grammars.
776         return [
777             # Python 2.7 with future print_function import
778             pygram.python_grammar_no_print_statement,
779             # Python 2.7
780             pygram.python_grammar,
781         ]
782     else:
783         # Python 3-compatible code, so only try Python 3 grammar.
784         grammars = []
785         # If we have to parse both, try to parse async as a keyword first
786         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
787             # Python 3.7+
788             grammars.append(
789                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
790             )
791         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
792             # Python 3.0-3.6
793             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
794         # At least one of the above branches must have been taken, because every Python
795         # version has exactly one of the two 'ASYNC_*' flags
796         return grammars
797
798
799 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
800     """Given a string with source, return the lib2to3 Node."""
801     if src_txt[-1:] != "\n":
802         src_txt += "\n"
803
804     for grammar in get_grammars(set(target_versions)):
805         drv = driver.Driver(grammar, pytree.convert)
806         try:
807             result = drv.parse_string(src_txt, True)
808             break
809
810         except ParseError as pe:
811             lineno, column = pe.context[1]
812             lines = src_txt.splitlines()
813             try:
814                 faulty_line = lines[lineno - 1]
815             except IndexError:
816                 faulty_line = "<line number missing in source>"
817             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
818     else:
819         raise exc from None
820
821     if isinstance(result, Leaf):
822         result = Node(syms.file_input, [result])
823     return result
824
825
826 def lib2to3_unparse(node: Node) -> str:
827     """Given a lib2to3 node, return its string representation."""
828     code = str(node)
829     return code
830
831
832 T = TypeVar("T")
833
834
835 class Visitor(Generic[T]):
836     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
837
838     def visit(self, node: LN) -> Iterator[T]:
839         """Main method to visit `node` and its children.
840
841         It tries to find a `visit_*()` method for the given `node.type`, like
842         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
843         If no dedicated `visit_*()` method is found, chooses `visit_default()`
844         instead.
845
846         Then yields objects of type `T` from the selected visitor.
847         """
848         if node.type < 256:
849             name = token.tok_name[node.type]
850         else:
851             name = type_repr(node.type)
852         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
853
854     def visit_default(self, node: LN) -> Iterator[T]:
855         """Default `visit_*()` implementation. Recurses to children of `node`."""
856         if isinstance(node, Node):
857             for child in node.children:
858                 yield from self.visit(child)
859
860
861 @dataclass
862 class DebugVisitor(Visitor[T]):
863     tree_depth: int = 0
864
865     def visit_default(self, node: LN) -> Iterator[T]:
866         indent = " " * (2 * self.tree_depth)
867         if isinstance(node, Node):
868             _type = type_repr(node.type)
869             out(f"{indent}{_type}", fg="yellow")
870             self.tree_depth += 1
871             for child in node.children:
872                 yield from self.visit(child)
873
874             self.tree_depth -= 1
875             out(f"{indent}/{_type}", fg="yellow", bold=False)
876         else:
877             _type = token.tok_name.get(node.type, str(node.type))
878             out(f"{indent}{_type}", fg="blue", nl=False)
879             if node.prefix:
880                 # We don't have to handle prefixes for `Node` objects since
881                 # that delegates to the first child anyway.
882                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
883             out(f" {node.value!r}", fg="blue", bold=False)
884
885     @classmethod
886     def show(cls, code: Union[str, Leaf, Node]) -> None:
887         """Pretty-print the lib2to3 AST of a given string of `code`.
888
889         Convenience method for debugging.
890         """
891         v: DebugVisitor[None] = DebugVisitor()
892         if isinstance(code, str):
893             code = lib2to3_parse(code)
894         list(v.visit(code))
895
896
897 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
898 STATEMENT = {
899     syms.if_stmt,
900     syms.while_stmt,
901     syms.for_stmt,
902     syms.try_stmt,
903     syms.except_clause,
904     syms.with_stmt,
905     syms.funcdef,
906     syms.classdef,
907 }
908 STANDALONE_COMMENT = 153
909 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
910 LOGIC_OPERATORS = {"and", "or"}
911 COMPARATORS = {
912     token.LESS,
913     token.GREATER,
914     token.EQEQUAL,
915     token.NOTEQUAL,
916     token.LESSEQUAL,
917     token.GREATEREQUAL,
918 }
919 MATH_OPERATORS = {
920     token.VBAR,
921     token.CIRCUMFLEX,
922     token.AMPER,
923     token.LEFTSHIFT,
924     token.RIGHTSHIFT,
925     token.PLUS,
926     token.MINUS,
927     token.STAR,
928     token.SLASH,
929     token.DOUBLESLASH,
930     token.PERCENT,
931     token.AT,
932     token.TILDE,
933     token.DOUBLESTAR,
934 }
935 STARS = {token.STAR, token.DOUBLESTAR}
936 VARARGS_PARENTS = {
937     syms.arglist,
938     syms.argument,  # double star in arglist
939     syms.trailer,  # single argument to call
940     syms.typedargslist,
941     syms.varargslist,  # lambdas
942 }
943 UNPACKING_PARENTS = {
944     syms.atom,  # single element of a list or set literal
945     syms.dictsetmaker,
946     syms.listmaker,
947     syms.testlist_gexp,
948     syms.testlist_star_expr,
949 }
950 TEST_DESCENDANTS = {
951     syms.test,
952     syms.lambdef,
953     syms.or_test,
954     syms.and_test,
955     syms.not_test,
956     syms.comparison,
957     syms.star_expr,
958     syms.expr,
959     syms.xor_expr,
960     syms.and_expr,
961     syms.shift_expr,
962     syms.arith_expr,
963     syms.trailer,
964     syms.term,
965     syms.power,
966 }
967 ASSIGNMENTS = {
968     "=",
969     "+=",
970     "-=",
971     "*=",
972     "@=",
973     "/=",
974     "%=",
975     "&=",
976     "|=",
977     "^=",
978     "<<=",
979     ">>=",
980     "**=",
981     "//=",
982 }
983 COMPREHENSION_PRIORITY = 20
984 COMMA_PRIORITY = 18
985 TERNARY_PRIORITY = 16
986 LOGIC_PRIORITY = 14
987 STRING_PRIORITY = 12
988 COMPARATOR_PRIORITY = 10
989 MATH_PRIORITIES = {
990     token.VBAR: 9,
991     token.CIRCUMFLEX: 8,
992     token.AMPER: 7,
993     token.LEFTSHIFT: 6,
994     token.RIGHTSHIFT: 6,
995     token.PLUS: 5,
996     token.MINUS: 5,
997     token.STAR: 4,
998     token.SLASH: 4,
999     token.DOUBLESLASH: 4,
1000     token.PERCENT: 4,
1001     token.AT: 4,
1002     token.TILDE: 3,
1003     token.DOUBLESTAR: 2,
1004 }
1005 DOT_PRIORITY = 1
1006
1007
1008 @dataclass
1009 class BracketTracker:
1010     """Keeps track of brackets on a line."""
1011
1012     depth: int = 0
1013     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1014     delimiters: Dict[LeafID, Priority] = Factory(dict)
1015     previous: Optional[Leaf] = None
1016     _for_loop_depths: List[int] = Factory(list)
1017     _lambda_argument_depths: List[int] = Factory(list)
1018
1019     def mark(self, leaf: Leaf) -> None:
1020         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1021
1022         All leaves receive an int `bracket_depth` field that stores how deep
1023         within brackets a given leaf is. 0 means there are no enclosing brackets
1024         that started on this line.
1025
1026         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1027         field that it forms a pair with. This is a one-directional link to
1028         avoid reference cycles.
1029
1030         If a leaf is a delimiter (a token on which Black can split the line if
1031         needed) and it's on depth 0, its `id()` is stored in the tracker's
1032         `delimiters` field.
1033         """
1034         if leaf.type == token.COMMENT:
1035             return
1036
1037         self.maybe_decrement_after_for_loop_variable(leaf)
1038         self.maybe_decrement_after_lambda_arguments(leaf)
1039         if leaf.type in CLOSING_BRACKETS:
1040             self.depth -= 1
1041             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1042             leaf.opening_bracket = opening_bracket
1043         leaf.bracket_depth = self.depth
1044         if self.depth == 0:
1045             delim = is_split_before_delimiter(leaf, self.previous)
1046             if delim and self.previous is not None:
1047                 self.delimiters[id(self.previous)] = delim
1048             else:
1049                 delim = is_split_after_delimiter(leaf, self.previous)
1050                 if delim:
1051                     self.delimiters[id(leaf)] = delim
1052         if leaf.type in OPENING_BRACKETS:
1053             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1054             self.depth += 1
1055         self.previous = leaf
1056         self.maybe_increment_lambda_arguments(leaf)
1057         self.maybe_increment_for_loop_variable(leaf)
1058
1059     def any_open_brackets(self) -> bool:
1060         """Return True if there is an yet unmatched open bracket on the line."""
1061         return bool(self.bracket_match)
1062
1063     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1064         """Return the highest priority of a delimiter found on the line.
1065
1066         Values are consistent with what `is_split_*_delimiter()` return.
1067         Raises ValueError on no delimiters.
1068         """
1069         return max(v for k, v in self.delimiters.items() if k not in exclude)
1070
1071     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1072         """Return the number of delimiters with the given `priority`.
1073
1074         If no `priority` is passed, defaults to max priority on the line.
1075         """
1076         if not self.delimiters:
1077             return 0
1078
1079         priority = priority or self.max_delimiter_priority()
1080         return sum(1 for p in self.delimiters.values() if p == priority)
1081
1082     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1083         """In a for loop, or comprehension, the variables are often unpacks.
1084
1085         To avoid splitting on the comma in this situation, increase the depth of
1086         tokens between `for` and `in`.
1087         """
1088         if leaf.type == token.NAME and leaf.value == "for":
1089             self.depth += 1
1090             self._for_loop_depths.append(self.depth)
1091             return True
1092
1093         return False
1094
1095     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1096         """See `maybe_increment_for_loop_variable` above for explanation."""
1097         if (
1098             self._for_loop_depths
1099             and self._for_loop_depths[-1] == self.depth
1100             and leaf.type == token.NAME
1101             and leaf.value == "in"
1102         ):
1103             self.depth -= 1
1104             self._for_loop_depths.pop()
1105             return True
1106
1107         return False
1108
1109     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1110         """In a lambda expression, there might be more than one argument.
1111
1112         To avoid splitting on the comma in this situation, increase the depth of
1113         tokens between `lambda` and `:`.
1114         """
1115         if leaf.type == token.NAME and leaf.value == "lambda":
1116             self.depth += 1
1117             self._lambda_argument_depths.append(self.depth)
1118             return True
1119
1120         return False
1121
1122     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1123         """See `maybe_increment_lambda_arguments` above for explanation."""
1124         if (
1125             self._lambda_argument_depths
1126             and self._lambda_argument_depths[-1] == self.depth
1127             and leaf.type == token.COLON
1128         ):
1129             self.depth -= 1
1130             self._lambda_argument_depths.pop()
1131             return True
1132
1133         return False
1134
1135     def get_open_lsqb(self) -> Optional[Leaf]:
1136         """Return the most recent opening square bracket (if any)."""
1137         return self.bracket_match.get((self.depth - 1, token.RSQB))
1138
1139
1140 @dataclass
1141 class Line:
1142     """Holds leaves and comments. Can be printed with `str(line)`."""
1143
1144     depth: int = 0
1145     leaves: List[Leaf] = Factory(list)
1146     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1147     bracket_tracker: BracketTracker = Factory(BracketTracker)
1148     inside_brackets: bool = False
1149     should_explode: bool = False
1150
1151     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1152         """Add a new `leaf` to the end of the line.
1153
1154         Unless `preformatted` is True, the `leaf` will receive a new consistent
1155         whitespace prefix and metadata applied by :class:`BracketTracker`.
1156         Trailing commas are maybe removed, unpacked for loop variables are
1157         demoted from being delimiters.
1158
1159         Inline comments are put aside.
1160         """
1161         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1162         if not has_value:
1163             return
1164
1165         if token.COLON == leaf.type and self.is_class_paren_empty:
1166             del self.leaves[-2:]
1167         if self.leaves and not preformatted:
1168             # Note: at this point leaf.prefix should be empty except for
1169             # imports, for which we only preserve newlines.
1170             leaf.prefix += whitespace(
1171                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1172             )
1173         if self.inside_brackets or not preformatted:
1174             self.bracket_tracker.mark(leaf)
1175             self.maybe_remove_trailing_comma(leaf)
1176         if not self.append_comment(leaf):
1177             self.leaves.append(leaf)
1178
1179     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1180         """Like :func:`append()` but disallow invalid standalone comment structure.
1181
1182         Raises ValueError when any `leaf` is appended after a standalone comment
1183         or when a standalone comment is not the first leaf on the line.
1184         """
1185         if self.bracket_tracker.depth == 0:
1186             if self.is_comment:
1187                 raise ValueError("cannot append to standalone comments")
1188
1189             if self.leaves and leaf.type == STANDALONE_COMMENT:
1190                 raise ValueError(
1191                     "cannot append standalone comments to a populated line"
1192                 )
1193
1194         self.append(leaf, preformatted=preformatted)
1195
1196     @property
1197     def is_comment(self) -> bool:
1198         """Is this line a standalone comment?"""
1199         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1200
1201     @property
1202     def is_decorator(self) -> bool:
1203         """Is this line a decorator?"""
1204         return bool(self) and self.leaves[0].type == token.AT
1205
1206     @property
1207     def is_import(self) -> bool:
1208         """Is this an import line?"""
1209         return bool(self) and is_import(self.leaves[0])
1210
1211     @property
1212     def is_class(self) -> bool:
1213         """Is this line a class definition?"""
1214         return (
1215             bool(self)
1216             and self.leaves[0].type == token.NAME
1217             and self.leaves[0].value == "class"
1218         )
1219
1220     @property
1221     def is_stub_class(self) -> bool:
1222         """Is this line a class definition with a body consisting only of "..."?"""
1223         return self.is_class and self.leaves[-3:] == [
1224             Leaf(token.DOT, ".") for _ in range(3)
1225         ]
1226
1227     @property
1228     def is_def(self) -> bool:
1229         """Is this a function definition? (Also returns True for async defs.)"""
1230         try:
1231             first_leaf = self.leaves[0]
1232         except IndexError:
1233             return False
1234
1235         try:
1236             second_leaf: Optional[Leaf] = self.leaves[1]
1237         except IndexError:
1238             second_leaf = None
1239         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1240             first_leaf.type == token.ASYNC
1241             and second_leaf is not None
1242             and second_leaf.type == token.NAME
1243             and second_leaf.value == "def"
1244         )
1245
1246     @property
1247     def is_class_paren_empty(self) -> bool:
1248         """Is this a class with no base classes but using parentheses?
1249
1250         Those are unnecessary and should be removed.
1251         """
1252         return (
1253             bool(self)
1254             and len(self.leaves) == 4
1255             and self.is_class
1256             and self.leaves[2].type == token.LPAR
1257             and self.leaves[2].value == "("
1258             and self.leaves[3].type == token.RPAR
1259             and self.leaves[3].value == ")"
1260         )
1261
1262     @property
1263     def is_triple_quoted_string(self) -> bool:
1264         """Is the line a triple quoted string?"""
1265         return (
1266             bool(self)
1267             and self.leaves[0].type == token.STRING
1268             and self.leaves[0].value.startswith(('"""', "'''"))
1269         )
1270
1271     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1272         """If so, needs to be split before emitting."""
1273         for leaf in self.leaves:
1274             if leaf.type == STANDALONE_COMMENT:
1275                 if leaf.bracket_depth <= depth_limit:
1276                     return True
1277         return False
1278
1279     def contains_inner_type_comments(self) -> bool:
1280         ignored_ids = set()
1281         try:
1282             last_leaf = self.leaves[-1]
1283             ignored_ids.add(id(last_leaf))
1284             if last_leaf.type == token.COMMA:
1285                 # When trailing commas are inserted by Black for consistency, comments
1286                 # after the previous last element are not moved (they don't have to,
1287                 # rendering will still be correct).  So we ignore trailing commas.
1288                 last_leaf = self.leaves[-2]
1289                 ignored_ids.add(id(last_leaf))
1290         except IndexError:
1291             return False
1292
1293         for leaf_id, comments in self.comments.items():
1294             if leaf_id in ignored_ids:
1295                 continue
1296
1297             for comment in comments:
1298                 if is_type_comment(comment):
1299                     return True
1300
1301         return False
1302
1303     def contains_multiline_strings(self) -> bool:
1304         for leaf in self.leaves:
1305             if is_multiline_string(leaf):
1306                 return True
1307
1308         return False
1309
1310     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1311         """Remove trailing comma if there is one and it's safe."""
1312         if not (
1313             self.leaves
1314             and self.leaves[-1].type == token.COMMA
1315             and closing.type in CLOSING_BRACKETS
1316         ):
1317             return False
1318
1319         if closing.type == token.RBRACE:
1320             self.remove_trailing_comma()
1321             return True
1322
1323         if closing.type == token.RSQB:
1324             comma = self.leaves[-1]
1325             if comma.parent and comma.parent.type == syms.listmaker:
1326                 self.remove_trailing_comma()
1327                 return True
1328
1329         # For parens let's check if it's safe to remove the comma.
1330         # Imports are always safe.
1331         if self.is_import:
1332             self.remove_trailing_comma()
1333             return True
1334
1335         # Otherwise, if the trailing one is the only one, we might mistakenly
1336         # change a tuple into a different type by removing the comma.
1337         depth = closing.bracket_depth + 1
1338         commas = 0
1339         opening = closing.opening_bracket
1340         for _opening_index, leaf in enumerate(self.leaves):
1341             if leaf is opening:
1342                 break
1343
1344         else:
1345             return False
1346
1347         for leaf in self.leaves[_opening_index + 1 :]:
1348             if leaf is closing:
1349                 break
1350
1351             bracket_depth = leaf.bracket_depth
1352             if bracket_depth == depth and leaf.type == token.COMMA:
1353                 commas += 1
1354                 if leaf.parent and leaf.parent.type == syms.arglist:
1355                     commas += 1
1356                     break
1357
1358         if commas > 1:
1359             self.remove_trailing_comma()
1360             return True
1361
1362         return False
1363
1364     def append_comment(self, comment: Leaf) -> bool:
1365         """Add an inline or standalone comment to the line."""
1366         if (
1367             comment.type == STANDALONE_COMMENT
1368             and self.bracket_tracker.any_open_brackets()
1369         ):
1370             comment.prefix = ""
1371             return False
1372
1373         if comment.type != token.COMMENT:
1374             return False
1375
1376         if not self.leaves:
1377             comment.type = STANDALONE_COMMENT
1378             comment.prefix = ""
1379             return False
1380
1381         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1382         return True
1383
1384     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1385         """Generate comments that should appear directly after `leaf`."""
1386         return self.comments.get(id(leaf), [])
1387
1388     def remove_trailing_comma(self) -> None:
1389         """Remove the trailing comma and moves the comments attached to it."""
1390         trailing_comma = self.leaves.pop()
1391         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1392         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1393             trailing_comma_comments
1394         )
1395
1396     def is_complex_subscript(self, leaf: Leaf) -> bool:
1397         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1398         open_lsqb = self.bracket_tracker.get_open_lsqb()
1399         if open_lsqb is None:
1400             return False
1401
1402         subscript_start = open_lsqb.next_sibling
1403
1404         if isinstance(subscript_start, Node):
1405             if subscript_start.type == syms.listmaker:
1406                 return False
1407
1408             if subscript_start.type == syms.subscriptlist:
1409                 subscript_start = child_towards(subscript_start, leaf)
1410         return subscript_start is not None and any(
1411             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1412         )
1413
1414     def __str__(self) -> str:
1415         """Render the line."""
1416         if not self:
1417             return "\n"
1418
1419         indent = "    " * self.depth
1420         leaves = iter(self.leaves)
1421         first = next(leaves)
1422         res = f"{first.prefix}{indent}{first.value}"
1423         for leaf in leaves:
1424             res += str(leaf)
1425         for comment in itertools.chain.from_iterable(self.comments.values()):
1426             res += str(comment)
1427         return res + "\n"
1428
1429     def __bool__(self) -> bool:
1430         """Return True if the line has leaves or comments."""
1431         return bool(self.leaves or self.comments)
1432
1433
1434 @dataclass
1435 class EmptyLineTracker:
1436     """Provides a stateful method that returns the number of potential extra
1437     empty lines needed before and after the currently processed line.
1438
1439     Note: this tracker works on lines that haven't been split yet.  It assumes
1440     the prefix of the first leaf consists of optional newlines.  Those newlines
1441     are consumed by `maybe_empty_lines()` and included in the computation.
1442     """
1443
1444     is_pyi: bool = False
1445     previous_line: Optional[Line] = None
1446     previous_after: int = 0
1447     previous_defs: List[int] = Factory(list)
1448
1449     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1450         """Return the number of extra empty lines before and after the `current_line`.
1451
1452         This is for separating `def`, `async def` and `class` with extra empty
1453         lines (two on module-level).
1454         """
1455         before, after = self._maybe_empty_lines(current_line)
1456         before -= self.previous_after
1457         self.previous_after = after
1458         self.previous_line = current_line
1459         return before, after
1460
1461     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1462         max_allowed = 1
1463         if current_line.depth == 0:
1464             max_allowed = 1 if self.is_pyi else 2
1465         if current_line.leaves:
1466             # Consume the first leaf's extra newlines.
1467             first_leaf = current_line.leaves[0]
1468             before = first_leaf.prefix.count("\n")
1469             before = min(before, max_allowed)
1470             first_leaf.prefix = ""
1471         else:
1472             before = 0
1473         depth = current_line.depth
1474         while self.previous_defs and self.previous_defs[-1] >= depth:
1475             self.previous_defs.pop()
1476             if self.is_pyi:
1477                 before = 0 if depth else 1
1478             else:
1479                 before = 1 if depth else 2
1480         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1481             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1482
1483         if (
1484             self.previous_line
1485             and self.previous_line.is_import
1486             and not current_line.is_import
1487             and depth == self.previous_line.depth
1488         ):
1489             return (before or 1), 0
1490
1491         if (
1492             self.previous_line
1493             and self.previous_line.is_class
1494             and current_line.is_triple_quoted_string
1495         ):
1496             return before, 1
1497
1498         return before, 0
1499
1500     def _maybe_empty_lines_for_class_or_def(
1501         self, current_line: Line, before: int
1502     ) -> Tuple[int, int]:
1503         if not current_line.is_decorator:
1504             self.previous_defs.append(current_line.depth)
1505         if self.previous_line is None:
1506             # Don't insert empty lines before the first line in the file.
1507             return 0, 0
1508
1509         if self.previous_line.is_decorator:
1510             return 0, 0
1511
1512         if self.previous_line.depth < current_line.depth and (
1513             self.previous_line.is_class or self.previous_line.is_def
1514         ):
1515             return 0, 0
1516
1517         if (
1518             self.previous_line.is_comment
1519             and self.previous_line.depth == current_line.depth
1520             and before == 0
1521         ):
1522             return 0, 0
1523
1524         if self.is_pyi:
1525             if self.previous_line.depth > current_line.depth:
1526                 newlines = 1
1527             elif current_line.is_class or self.previous_line.is_class:
1528                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1529                     # No blank line between classes with an empty body
1530                     newlines = 0
1531                 else:
1532                     newlines = 1
1533             elif current_line.is_def and not self.previous_line.is_def:
1534                 # Blank line between a block of functions and a block of non-functions
1535                 newlines = 1
1536             else:
1537                 newlines = 0
1538         else:
1539             newlines = 2
1540         if current_line.depth and newlines:
1541             newlines -= 1
1542         return newlines, 0
1543
1544
1545 @dataclass
1546 class LineGenerator(Visitor[Line]):
1547     """Generates reformatted Line objects.  Empty lines are not emitted.
1548
1549     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1550     in ways that will no longer stringify to valid Python code on the tree.
1551     """
1552
1553     is_pyi: bool = False
1554     normalize_strings: bool = True
1555     current_line: Line = Factory(Line)
1556     remove_u_prefix: bool = False
1557
1558     def line(self, indent: int = 0) -> Iterator[Line]:
1559         """Generate a line.
1560
1561         If the line is empty, only emit if it makes sense.
1562         If the line is too long, split it first and then generate.
1563
1564         If any lines were generated, set up a new current_line.
1565         """
1566         if not self.current_line:
1567             self.current_line.depth += indent
1568             return  # Line is empty, don't emit. Creating a new one unnecessary.
1569
1570         complete_line = self.current_line
1571         self.current_line = Line(depth=complete_line.depth + indent)
1572         yield complete_line
1573
1574     def visit_default(self, node: LN) -> Iterator[Line]:
1575         """Default `visit_*()` implementation. Recurses to children of `node`."""
1576         if isinstance(node, Leaf):
1577             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1578             for comment in generate_comments(node):
1579                 if any_open_brackets:
1580                     # any comment within brackets is subject to splitting
1581                     self.current_line.append(comment)
1582                 elif comment.type == token.COMMENT:
1583                     # regular trailing comment
1584                     self.current_line.append(comment)
1585                     yield from self.line()
1586
1587                 else:
1588                     # regular standalone comment
1589                     yield from self.line()
1590
1591                     self.current_line.append(comment)
1592                     yield from self.line()
1593
1594             normalize_prefix(node, inside_brackets=any_open_brackets)
1595             if self.normalize_strings and node.type == token.STRING:
1596                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1597                 normalize_string_quotes(node)
1598             if node.type == token.NUMBER:
1599                 normalize_numeric_literal(node)
1600             if node.type not in WHITESPACE:
1601                 self.current_line.append(node)
1602         yield from super().visit_default(node)
1603
1604     def visit_atom(self, node: Node) -> Iterator[Line]:
1605         # Always make parentheses invisible around a single node, because it should
1606         # not be needed (except in the case of yield, where removing the parentheses
1607         # produces a SyntaxError).
1608         if (
1609             len(node.children) == 3
1610             and isinstance(node.children[0], Leaf)
1611             and node.children[0].type == token.LPAR
1612             and isinstance(node.children[2], Leaf)
1613             and node.children[2].type == token.RPAR
1614             and isinstance(node.children[1], Leaf)
1615             and not (
1616                 node.children[1].type == token.NAME
1617                 and node.children[1].value == "yield"
1618             )
1619         ):
1620             node.children[0].value = ""
1621             node.children[2].value = ""
1622         yield from super().visit_default(node)
1623
1624     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1625         """Increase indentation level, maybe yield a line."""
1626         # In blib2to3 INDENT never holds comments.
1627         yield from self.line(+1)
1628         yield from self.visit_default(node)
1629
1630     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1631         """Decrease indentation level, maybe yield a line."""
1632         # The current line might still wait for trailing comments.  At DEDENT time
1633         # there won't be any (they would be prefixes on the preceding NEWLINE).
1634         # Emit the line then.
1635         yield from self.line()
1636
1637         # While DEDENT has no value, its prefix may contain standalone comments
1638         # that belong to the current indentation level.  Get 'em.
1639         yield from self.visit_default(node)
1640
1641         # Finally, emit the dedent.
1642         yield from self.line(-1)
1643
1644     def visit_stmt(
1645         self, node: Node, keywords: Set[str], parens: Set[str]
1646     ) -> Iterator[Line]:
1647         """Visit a statement.
1648
1649         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1650         `def`, `with`, `class`, `assert` and assignments.
1651
1652         The relevant Python language `keywords` for a given statement will be
1653         NAME leaves within it. This methods puts those on a separate line.
1654
1655         `parens` holds a set of string leaf values immediately after which
1656         invisible parens should be put.
1657         """
1658         normalize_invisible_parens(node, parens_after=parens)
1659         for child in node.children:
1660             if child.type == token.NAME and child.value in keywords:  # type: ignore
1661                 yield from self.line()
1662
1663             yield from self.visit(child)
1664
1665     def visit_suite(self, node: Node) -> Iterator[Line]:
1666         """Visit a suite."""
1667         if self.is_pyi and is_stub_suite(node):
1668             yield from self.visit(node.children[2])
1669         else:
1670             yield from self.visit_default(node)
1671
1672     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1673         """Visit a statement without nested statements."""
1674         is_suite_like = node.parent and node.parent.type in STATEMENT
1675         if is_suite_like:
1676             if self.is_pyi and is_stub_body(node):
1677                 yield from self.visit_default(node)
1678             else:
1679                 yield from self.line(+1)
1680                 yield from self.visit_default(node)
1681                 yield from self.line(-1)
1682
1683         else:
1684             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1685                 yield from self.line()
1686             yield from self.visit_default(node)
1687
1688     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1689         """Visit `async def`, `async for`, `async with`."""
1690         yield from self.line()
1691
1692         children = iter(node.children)
1693         for child in children:
1694             yield from self.visit(child)
1695
1696             if child.type == token.ASYNC:
1697                 break
1698
1699         internal_stmt = next(children)
1700         for child in internal_stmt.children:
1701             yield from self.visit(child)
1702
1703     def visit_decorators(self, node: Node) -> Iterator[Line]:
1704         """Visit decorators."""
1705         for child in node.children:
1706             yield from self.line()
1707             yield from self.visit(child)
1708
1709     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1710         """Remove a semicolon and put the other statement on a separate line."""
1711         yield from self.line()
1712
1713     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1714         """End of file. Process outstanding comments and end with a newline."""
1715         yield from self.visit_default(leaf)
1716         yield from self.line()
1717
1718     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1719         if not self.current_line.bracket_tracker.any_open_brackets():
1720             yield from self.line()
1721         yield from self.visit_default(leaf)
1722
1723     def __attrs_post_init__(self) -> None:
1724         """You are in a twisty little maze of passages."""
1725         v = self.visit_stmt
1726         Ø: Set[str] = set()
1727         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1728         self.visit_if_stmt = partial(
1729             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1730         )
1731         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1732         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1733         self.visit_try_stmt = partial(
1734             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1735         )
1736         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1737         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1738         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1739         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1740         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1741         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1742         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1743         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1744         self.visit_async_funcdef = self.visit_async_stmt
1745         self.visit_decorated = self.visit_decorators
1746
1747
1748 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1749 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1750 OPENING_BRACKETS = set(BRACKET.keys())
1751 CLOSING_BRACKETS = set(BRACKET.values())
1752 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1753 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1754
1755
1756 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1757     """Return whitespace prefix if needed for the given `leaf`.
1758
1759     `complex_subscript` signals whether the given leaf is part of a subscription
1760     which has non-trivial arguments, like arithmetic expressions or function calls.
1761     """
1762     NO = ""
1763     SPACE = " "
1764     DOUBLESPACE = "  "
1765     t = leaf.type
1766     p = leaf.parent
1767     v = leaf.value
1768     if t in ALWAYS_NO_SPACE:
1769         return NO
1770
1771     if t == token.COMMENT:
1772         return DOUBLESPACE
1773
1774     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1775     if t == token.COLON and p.type not in {
1776         syms.subscript,
1777         syms.subscriptlist,
1778         syms.sliceop,
1779     }:
1780         return NO
1781
1782     prev = leaf.prev_sibling
1783     if not prev:
1784         prevp = preceding_leaf(p)
1785         if not prevp or prevp.type in OPENING_BRACKETS:
1786             return NO
1787
1788         if t == token.COLON:
1789             if prevp.type == token.COLON:
1790                 return NO
1791
1792             elif prevp.type != token.COMMA and not complex_subscript:
1793                 return NO
1794
1795             return SPACE
1796
1797         if prevp.type == token.EQUAL:
1798             if prevp.parent:
1799                 if prevp.parent.type in {
1800                     syms.arglist,
1801                     syms.argument,
1802                     syms.parameters,
1803                     syms.varargslist,
1804                 }:
1805                     return NO
1806
1807                 elif prevp.parent.type == syms.typedargslist:
1808                     # A bit hacky: if the equal sign has whitespace, it means we
1809                     # previously found it's a typed argument.  So, we're using
1810                     # that, too.
1811                     return prevp.prefix
1812
1813         elif prevp.type in STARS:
1814             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1815                 return NO
1816
1817         elif prevp.type == token.COLON:
1818             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1819                 return SPACE if complex_subscript else NO
1820
1821         elif (
1822             prevp.parent
1823             and prevp.parent.type == syms.factor
1824             and prevp.type in MATH_OPERATORS
1825         ):
1826             return NO
1827
1828         elif (
1829             prevp.type == token.RIGHTSHIFT
1830             and prevp.parent
1831             and prevp.parent.type == syms.shift_expr
1832             and prevp.prev_sibling
1833             and prevp.prev_sibling.type == token.NAME
1834             and prevp.prev_sibling.value == "print"  # type: ignore
1835         ):
1836             # Python 2 print chevron
1837             return NO
1838
1839     elif prev.type in OPENING_BRACKETS:
1840         return NO
1841
1842     if p.type in {syms.parameters, syms.arglist}:
1843         # untyped function signatures or calls
1844         if not prev or prev.type != token.COMMA:
1845             return NO
1846
1847     elif p.type == syms.varargslist:
1848         # lambdas
1849         if prev and prev.type != token.COMMA:
1850             return NO
1851
1852     elif p.type == syms.typedargslist:
1853         # typed function signatures
1854         if not prev:
1855             return NO
1856
1857         if t == token.EQUAL:
1858             if prev.type != syms.tname:
1859                 return NO
1860
1861         elif prev.type == token.EQUAL:
1862             # A bit hacky: if the equal sign has whitespace, it means we
1863             # previously found it's a typed argument.  So, we're using that, too.
1864             return prev.prefix
1865
1866         elif prev.type != token.COMMA:
1867             return NO
1868
1869     elif p.type == syms.tname:
1870         # type names
1871         if not prev:
1872             prevp = preceding_leaf(p)
1873             if not prevp or prevp.type != token.COMMA:
1874                 return NO
1875
1876     elif p.type == syms.trailer:
1877         # attributes and calls
1878         if t == token.LPAR or t == token.RPAR:
1879             return NO
1880
1881         if not prev:
1882             if t == token.DOT:
1883                 prevp = preceding_leaf(p)
1884                 if not prevp or prevp.type != token.NUMBER:
1885                     return NO
1886
1887             elif t == token.LSQB:
1888                 return NO
1889
1890         elif prev.type != token.COMMA:
1891             return NO
1892
1893     elif p.type == syms.argument:
1894         # single argument
1895         if t == token.EQUAL:
1896             return NO
1897
1898         if not prev:
1899             prevp = preceding_leaf(p)
1900             if not prevp or prevp.type == token.LPAR:
1901                 return NO
1902
1903         elif prev.type in {token.EQUAL} | STARS:
1904             return NO
1905
1906     elif p.type == syms.decorator:
1907         # decorators
1908         return NO
1909
1910     elif p.type == syms.dotted_name:
1911         if prev:
1912             return NO
1913
1914         prevp = preceding_leaf(p)
1915         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1916             return NO
1917
1918     elif p.type == syms.classdef:
1919         if t == token.LPAR:
1920             return NO
1921
1922         if prev and prev.type == token.LPAR:
1923             return NO
1924
1925     elif p.type in {syms.subscript, syms.sliceop}:
1926         # indexing
1927         if not prev:
1928             assert p.parent is not None, "subscripts are always parented"
1929             if p.parent.type == syms.subscriptlist:
1930                 return SPACE
1931
1932             return NO
1933
1934         elif not complex_subscript:
1935             return NO
1936
1937     elif p.type == syms.atom:
1938         if prev and t == token.DOT:
1939             # dots, but not the first one.
1940             return NO
1941
1942     elif p.type == syms.dictsetmaker:
1943         # dict unpacking
1944         if prev and prev.type == token.DOUBLESTAR:
1945             return NO
1946
1947     elif p.type in {syms.factor, syms.star_expr}:
1948         # unary ops
1949         if not prev:
1950             prevp = preceding_leaf(p)
1951             if not prevp or prevp.type in OPENING_BRACKETS:
1952                 return NO
1953
1954             prevp_parent = prevp.parent
1955             assert prevp_parent is not None
1956             if prevp.type == token.COLON and prevp_parent.type in {
1957                 syms.subscript,
1958                 syms.sliceop,
1959             }:
1960                 return NO
1961
1962             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1963                 return NO
1964
1965         elif t in {token.NAME, token.NUMBER, token.STRING}:
1966             return NO
1967
1968     elif p.type == syms.import_from:
1969         if t == token.DOT:
1970             if prev and prev.type == token.DOT:
1971                 return NO
1972
1973         elif t == token.NAME:
1974             if v == "import":
1975                 return SPACE
1976
1977             if prev and prev.type == token.DOT:
1978                 return NO
1979
1980     elif p.type == syms.sliceop:
1981         return NO
1982
1983     return SPACE
1984
1985
1986 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1987     """Return the first leaf that precedes `node`, if any."""
1988     while node:
1989         res = node.prev_sibling
1990         if res:
1991             if isinstance(res, Leaf):
1992                 return res
1993
1994             try:
1995                 return list(res.leaves())[-1]
1996
1997             except IndexError:
1998                 return None
1999
2000         node = node.parent
2001     return None
2002
2003
2004 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2005     """Return the child of `ancestor` that contains `descendant`."""
2006     node: Optional[LN] = descendant
2007     while node and node.parent != ancestor:
2008         node = node.parent
2009     return node
2010
2011
2012 def container_of(leaf: Leaf) -> LN:
2013     """Return `leaf` or one of its ancestors that is the topmost container of it.
2014
2015     By "container" we mean a node where `leaf` is the very first child.
2016     """
2017     same_prefix = leaf.prefix
2018     container: LN = leaf
2019     while container:
2020         parent = container.parent
2021         if parent is None:
2022             break
2023
2024         if parent.children[0].prefix != same_prefix:
2025             break
2026
2027         if parent.type == syms.file_input:
2028             break
2029
2030         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2031             break
2032
2033         container = parent
2034     return container
2035
2036
2037 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
2038     """Return the priority of the `leaf` delimiter, given a line break after it.
2039
2040     The delimiter priorities returned here are from those delimiters that would
2041     cause a line break after themselves.
2042
2043     Higher numbers are higher priority.
2044     """
2045     if leaf.type == token.COMMA:
2046         return COMMA_PRIORITY
2047
2048     return 0
2049
2050
2051 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
2052     """Return the priority of the `leaf` delimiter, given a line break before it.
2053
2054     The delimiter priorities returned here are from those delimiters that would
2055     cause a line break before themselves.
2056
2057     Higher numbers are higher priority.
2058     """
2059     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2060         # * and ** might also be MATH_OPERATORS but in this case they are not.
2061         # Don't treat them as a delimiter.
2062         return 0
2063
2064     if (
2065         leaf.type == token.DOT
2066         and leaf.parent
2067         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2068         and (previous is None or previous.type in CLOSING_BRACKETS)
2069     ):
2070         return DOT_PRIORITY
2071
2072     if (
2073         leaf.type in MATH_OPERATORS
2074         and leaf.parent
2075         and leaf.parent.type not in {syms.factor, syms.star_expr}
2076     ):
2077         return MATH_PRIORITIES[leaf.type]
2078
2079     if leaf.type in COMPARATORS:
2080         return COMPARATOR_PRIORITY
2081
2082     if (
2083         leaf.type == token.STRING
2084         and previous is not None
2085         and previous.type == token.STRING
2086     ):
2087         return STRING_PRIORITY
2088
2089     if leaf.type not in {token.NAME, token.ASYNC}:
2090         return 0
2091
2092     if (
2093         leaf.value == "for"
2094         and leaf.parent
2095         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2096         or leaf.type == token.ASYNC
2097     ):
2098         if (
2099             not isinstance(leaf.prev_sibling, Leaf)
2100             or leaf.prev_sibling.value != "async"
2101         ):
2102             return COMPREHENSION_PRIORITY
2103
2104     if (
2105         leaf.value == "if"
2106         and leaf.parent
2107         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2108     ):
2109         return COMPREHENSION_PRIORITY
2110
2111     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2112         return TERNARY_PRIORITY
2113
2114     if leaf.value == "is":
2115         return COMPARATOR_PRIORITY
2116
2117     if (
2118         leaf.value == "in"
2119         and leaf.parent
2120         and leaf.parent.type in {syms.comp_op, syms.comparison}
2121         and not (
2122             previous is not None
2123             and previous.type == token.NAME
2124             and previous.value == "not"
2125         )
2126     ):
2127         return COMPARATOR_PRIORITY
2128
2129     if (
2130         leaf.value == "not"
2131         and leaf.parent
2132         and leaf.parent.type == syms.comp_op
2133         and not (
2134             previous is not None
2135             and previous.type == token.NAME
2136             and previous.value == "is"
2137         )
2138     ):
2139         return COMPARATOR_PRIORITY
2140
2141     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2142         return LOGIC_PRIORITY
2143
2144     return 0
2145
2146
2147 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2148 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2149
2150
2151 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2152     """Clean the prefix of the `leaf` and generate comments from it, if any.
2153
2154     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2155     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2156     move because it does away with modifying the grammar to include all the
2157     possible places in which comments can be placed.
2158
2159     The sad consequence for us though is that comments don't "belong" anywhere.
2160     This is why this function generates simple parentless Leaf objects for
2161     comments.  We simply don't know what the correct parent should be.
2162
2163     No matter though, we can live without this.  We really only need to
2164     differentiate between inline and standalone comments.  The latter don't
2165     share the line with any code.
2166
2167     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2168     are emitted with a fake STANDALONE_COMMENT token identifier.
2169     """
2170     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2171         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2172
2173
2174 @dataclass
2175 class ProtoComment:
2176     """Describes a piece of syntax that is a comment.
2177
2178     It's not a :class:`blib2to3.pytree.Leaf` so that:
2179
2180     * it can be cached (`Leaf` objects should not be reused more than once as
2181       they store their lineno, column, prefix, and parent information);
2182     * `newlines` and `consumed` fields are kept separate from the `value`. This
2183       simplifies handling of special marker comments like ``# fmt: off/on``.
2184     """
2185
2186     type: int  # token.COMMENT or STANDALONE_COMMENT
2187     value: str  # content of the comment
2188     newlines: int  # how many newlines before the comment
2189     consumed: int  # how many characters of the original leaf's prefix did we consume
2190
2191
2192 @lru_cache(maxsize=4096)
2193 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2194     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2195     result: List[ProtoComment] = []
2196     if not prefix or "#" not in prefix:
2197         return result
2198
2199     consumed = 0
2200     nlines = 0
2201     ignored_lines = 0
2202     for index, line in enumerate(prefix.split("\n")):
2203         consumed += len(line) + 1  # adding the length of the split '\n'
2204         line = line.lstrip()
2205         if not line:
2206             nlines += 1
2207         if not line.startswith("#"):
2208             # Escaped newlines outside of a comment are not really newlines at
2209             # all. We treat a single-line comment following an escaped newline
2210             # as a simple trailing comment.
2211             if line.endswith("\\"):
2212                 ignored_lines += 1
2213             continue
2214
2215         if index == ignored_lines and not is_endmarker:
2216             comment_type = token.COMMENT  # simple trailing comment
2217         else:
2218             comment_type = STANDALONE_COMMENT
2219         comment = make_comment(line)
2220         result.append(
2221             ProtoComment(
2222                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2223             )
2224         )
2225         nlines = 0
2226     return result
2227
2228
2229 def make_comment(content: str) -> str:
2230     """Return a consistently formatted comment from the given `content` string.
2231
2232     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2233     space between the hash sign and the content.
2234
2235     If `content` didn't start with a hash sign, one is provided.
2236     """
2237     content = content.rstrip()
2238     if not content:
2239         return "#"
2240
2241     if content[0] == "#":
2242         content = content[1:]
2243     if content and content[0] not in " !:#'%":
2244         content = " " + content
2245     return "#" + content
2246
2247
2248 def split_line(
2249     line: Line,
2250     line_length: int,
2251     inner: bool = False,
2252     features: Collection[Feature] = (),
2253 ) -> Iterator[Line]:
2254     """Split a `line` into potentially many lines.
2255
2256     They should fit in the allotted `line_length` but might not be able to.
2257     `inner` signifies that there were a pair of brackets somewhere around the
2258     current `line`, possibly transitively. This means we can fallback to splitting
2259     by delimiters if the LHS/RHS don't yield any results.
2260
2261     `features` are syntactical features that may be used in the output.
2262     """
2263     if line.is_comment:
2264         yield line
2265         return
2266
2267     line_str = str(line).strip("\n")
2268
2269     if (
2270         not line.contains_inner_type_comments()
2271         and not line.should_explode
2272         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2273     ):
2274         yield line
2275         return
2276
2277     split_funcs: List[SplitFunc]
2278     if line.is_def:
2279         split_funcs = [left_hand_split]
2280     else:
2281
2282         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2283             for omit in generate_trailers_to_omit(line, line_length):
2284                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2285                 if is_line_short_enough(lines[0], line_length=line_length):
2286                     yield from lines
2287                     return
2288
2289             # All splits failed, best effort split with no omits.
2290             # This mostly happens to multiline strings that are by definition
2291             # reported as not fitting a single line.
2292             yield from right_hand_split(line, line_length, features=features)
2293
2294         if line.inside_brackets:
2295             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2296         else:
2297             split_funcs = [rhs]
2298     for split_func in split_funcs:
2299         # We are accumulating lines in `result` because we might want to abort
2300         # mission and return the original line in the end, or attempt a different
2301         # split altogether.
2302         result: List[Line] = []
2303         try:
2304             for l in split_func(line, features):
2305                 if str(l).strip("\n") == line_str:
2306                     raise CannotSplit("Split function returned an unchanged result")
2307
2308                 result.extend(
2309                     split_line(
2310                         l, line_length=line_length, inner=True, features=features
2311                     )
2312                 )
2313         except CannotSplit:
2314             continue
2315
2316         else:
2317             yield from result
2318             break
2319
2320     else:
2321         yield line
2322
2323
2324 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2325     """Split line into many lines, starting with the first matching bracket pair.
2326
2327     Note: this usually looks weird, only use this for function definitions.
2328     Prefer RHS otherwise.  This is why this function is not symmetrical with
2329     :func:`right_hand_split` which also handles optional parentheses.
2330     """
2331     tail_leaves: List[Leaf] = []
2332     body_leaves: List[Leaf] = []
2333     head_leaves: List[Leaf] = []
2334     current_leaves = head_leaves
2335     matching_bracket = None
2336     for leaf in line.leaves:
2337         if (
2338             current_leaves is body_leaves
2339             and leaf.type in CLOSING_BRACKETS
2340             and leaf.opening_bracket is matching_bracket
2341         ):
2342             current_leaves = tail_leaves if body_leaves else head_leaves
2343         current_leaves.append(leaf)
2344         if current_leaves is head_leaves:
2345             if leaf.type in OPENING_BRACKETS:
2346                 matching_bracket = leaf
2347                 current_leaves = body_leaves
2348     if not matching_bracket:
2349         raise CannotSplit("No brackets found")
2350
2351     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2352     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2353     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2354     bracket_split_succeeded_or_raise(head, body, tail)
2355     for result in (head, body, tail):
2356         if result:
2357             yield result
2358
2359
2360 def right_hand_split(
2361     line: Line,
2362     line_length: int,
2363     features: Collection[Feature] = (),
2364     omit: Collection[LeafID] = (),
2365 ) -> Iterator[Line]:
2366     """Split line into many lines, starting with the last matching bracket pair.
2367
2368     If the split was by optional parentheses, attempt splitting without them, too.
2369     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2370     this split.
2371
2372     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2373     """
2374     tail_leaves: List[Leaf] = []
2375     body_leaves: List[Leaf] = []
2376     head_leaves: List[Leaf] = []
2377     current_leaves = tail_leaves
2378     opening_bracket = None
2379     closing_bracket = None
2380     for leaf in reversed(line.leaves):
2381         if current_leaves is body_leaves:
2382             if leaf is opening_bracket:
2383                 current_leaves = head_leaves if body_leaves else tail_leaves
2384         current_leaves.append(leaf)
2385         if current_leaves is tail_leaves:
2386             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2387                 opening_bracket = leaf.opening_bracket
2388                 closing_bracket = leaf
2389                 current_leaves = body_leaves
2390     if not (opening_bracket and closing_bracket and head_leaves):
2391         # If there is no opening or closing_bracket that means the split failed and
2392         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2393         # the matching `opening_bracket` wasn't available on `line` anymore.
2394         raise CannotSplit("No brackets found")
2395
2396     tail_leaves.reverse()
2397     body_leaves.reverse()
2398     head_leaves.reverse()
2399     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2400     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2401     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2402     bracket_split_succeeded_or_raise(head, body, tail)
2403     if (
2404         # the body shouldn't be exploded
2405         not body.should_explode
2406         # the opening bracket is an optional paren
2407         and opening_bracket.type == token.LPAR
2408         and not opening_bracket.value
2409         # the closing bracket is an optional paren
2410         and closing_bracket.type == token.RPAR
2411         and not closing_bracket.value
2412         # it's not an import (optional parens are the only thing we can split on
2413         # in this case; attempting a split without them is a waste of time)
2414         and not line.is_import
2415         # there are no standalone comments in the body
2416         and not body.contains_standalone_comments(0)
2417         # and we can actually remove the parens
2418         and can_omit_invisible_parens(body, line_length)
2419     ):
2420         omit = {id(closing_bracket), *omit}
2421         try:
2422             yield from right_hand_split(line, line_length, features=features, omit=omit)
2423             return
2424
2425         except CannotSplit:
2426             if not (
2427                 can_be_split(body)
2428                 or is_line_short_enough(body, line_length=line_length)
2429             ):
2430                 raise CannotSplit(
2431                     "Splitting failed, body is still too long and can't be split."
2432                 )
2433
2434             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2435                 raise CannotSplit(
2436                     "The current optional pair of parentheses is bound to fail to "
2437                     "satisfy the splitting algorithm because the head or the tail "
2438                     "contains multiline strings which by definition never fit one "
2439                     "line."
2440                 )
2441
2442     ensure_visible(opening_bracket)
2443     ensure_visible(closing_bracket)
2444     for result in (head, body, tail):
2445         if result:
2446             yield result
2447
2448
2449 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2450     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2451
2452     Do nothing otherwise.
2453
2454     A left- or right-hand split is based on a pair of brackets. Content before
2455     (and including) the opening bracket is left on one line, content inside the
2456     brackets is put on a separate line, and finally content starting with and
2457     following the closing bracket is put on a separate line.
2458
2459     Those are called `head`, `body`, and `tail`, respectively. If the split
2460     produced the same line (all content in `head`) or ended up with an empty `body`
2461     and the `tail` is just the closing bracket, then it's considered failed.
2462     """
2463     tail_len = len(str(tail).strip())
2464     if not body:
2465         if tail_len == 0:
2466             raise CannotSplit("Splitting brackets produced the same line")
2467
2468         elif tail_len < 3:
2469             raise CannotSplit(
2470                 f"Splitting brackets on an empty body to save "
2471                 f"{tail_len} characters is not worth it"
2472             )
2473
2474
2475 def bracket_split_build_line(
2476     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2477 ) -> Line:
2478     """Return a new line with given `leaves` and respective comments from `original`.
2479
2480     If `is_body` is True, the result line is one-indented inside brackets and as such
2481     has its first leaf's prefix normalized and a trailing comma added when expected.
2482     """
2483     result = Line(depth=original.depth)
2484     if is_body:
2485         result.inside_brackets = True
2486         result.depth += 1
2487         if leaves:
2488             # Since body is a new indent level, remove spurious leading whitespace.
2489             normalize_prefix(leaves[0], inside_brackets=True)
2490             # Ensure a trailing comma for imports, but be careful not to add one after
2491             # any comments.
2492             if original.is_import:
2493                 for i in range(len(leaves) - 1, -1, -1):
2494                     if leaves[i].type == STANDALONE_COMMENT:
2495                         continue
2496                     elif leaves[i].type == token.COMMA:
2497                         break
2498                     else:
2499                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2500                         break
2501     # Populate the line
2502     for leaf in leaves:
2503         result.append(leaf, preformatted=True)
2504         for comment_after in original.comments_after(leaf):
2505             result.append(comment_after, preformatted=True)
2506     if is_body:
2507         result.should_explode = should_explode(result, opening_bracket)
2508     return result
2509
2510
2511 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2512     """Normalize prefix of the first leaf in every line returned by `split_func`.
2513
2514     This is a decorator over relevant split functions.
2515     """
2516
2517     @wraps(split_func)
2518     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2519         for l in split_func(line, features):
2520             normalize_prefix(l.leaves[0], inside_brackets=True)
2521             yield l
2522
2523     return split_wrapper
2524
2525
2526 @dont_increase_indentation
2527 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2528     """Split according to delimiters of the highest priority.
2529
2530     If the appropriate Features are given, the split will add trailing commas
2531     also in function signatures and calls that contain `*` and `**`.
2532     """
2533     try:
2534         last_leaf = line.leaves[-1]
2535     except IndexError:
2536         raise CannotSplit("Line empty")
2537
2538     bt = line.bracket_tracker
2539     try:
2540         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2541     except ValueError:
2542         raise CannotSplit("No delimiters found")
2543
2544     if delimiter_priority == DOT_PRIORITY:
2545         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2546             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2547
2548     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2549     lowest_depth = sys.maxsize
2550     trailing_comma_safe = True
2551
2552     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2553         """Append `leaf` to current line or to new line if appending impossible."""
2554         nonlocal current_line
2555         try:
2556             current_line.append_safe(leaf, preformatted=True)
2557         except ValueError:
2558             yield current_line
2559
2560             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2561             current_line.append(leaf)
2562
2563     for leaf in line.leaves:
2564         yield from append_to_line(leaf)
2565
2566         for comment_after in line.comments_after(leaf):
2567             yield from append_to_line(comment_after)
2568
2569         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2570         if leaf.bracket_depth == lowest_depth:
2571             if is_vararg(leaf, within={syms.typedargslist}):
2572                 trailing_comma_safe = (
2573                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2574                 )
2575             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2576                 trailing_comma_safe = (
2577                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2578                 )
2579
2580         leaf_priority = bt.delimiters.get(id(leaf))
2581         if leaf_priority == delimiter_priority:
2582             yield current_line
2583
2584             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2585     if current_line:
2586         if (
2587             trailing_comma_safe
2588             and delimiter_priority == COMMA_PRIORITY
2589             and current_line.leaves[-1].type != token.COMMA
2590             and current_line.leaves[-1].type != STANDALONE_COMMENT
2591         ):
2592             current_line.append(Leaf(token.COMMA, ","))
2593         yield current_line
2594
2595
2596 @dont_increase_indentation
2597 def standalone_comment_split(
2598     line: Line, features: Collection[Feature] = ()
2599 ) -> Iterator[Line]:
2600     """Split standalone comments from the rest of the line."""
2601     if not line.contains_standalone_comments(0):
2602         raise CannotSplit("Line does not have any standalone comments")
2603
2604     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2605
2606     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2607         """Append `leaf` to current line or to new line if appending impossible."""
2608         nonlocal current_line
2609         try:
2610             current_line.append_safe(leaf, preformatted=True)
2611         except ValueError:
2612             yield current_line
2613
2614             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2615             current_line.append(leaf)
2616
2617     for leaf in line.leaves:
2618         yield from append_to_line(leaf)
2619
2620         for comment_after in line.comments_after(leaf):
2621             yield from append_to_line(comment_after)
2622
2623     if current_line:
2624         yield current_line
2625
2626
2627 def is_import(leaf: Leaf) -> bool:
2628     """Return True if the given leaf starts an import statement."""
2629     p = leaf.parent
2630     t = leaf.type
2631     v = leaf.value
2632     return bool(
2633         t == token.NAME
2634         and (
2635             (v == "import" and p and p.type == syms.import_name)
2636             or (v == "from" and p and p.type == syms.import_from)
2637         )
2638     )
2639
2640
2641 def is_type_comment(leaf: Leaf) -> bool:
2642     """Return True if the given leaf is a special comment.
2643     Only returns true for type comments for now."""
2644     t = leaf.type
2645     v = leaf.value
2646     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2647
2648
2649 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2650     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2651     else.
2652
2653     Note: don't use backslashes for formatting or you'll lose your voting rights.
2654     """
2655     if not inside_brackets:
2656         spl = leaf.prefix.split("#")
2657         if "\\" not in spl[0]:
2658             nl_count = spl[-1].count("\n")
2659             if len(spl) > 1:
2660                 nl_count -= 1
2661             leaf.prefix = "\n" * nl_count
2662             return
2663
2664     leaf.prefix = ""
2665
2666
2667 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2668     """Make all string prefixes lowercase.
2669
2670     If remove_u_prefix is given, also removes any u prefix from the string.
2671
2672     Note: Mutates its argument.
2673     """
2674     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2675     assert match is not None, f"failed to match string {leaf.value!r}"
2676     orig_prefix = match.group(1)
2677     new_prefix = orig_prefix.lower()
2678     if remove_u_prefix:
2679         new_prefix = new_prefix.replace("u", "")
2680     leaf.value = f"{new_prefix}{match.group(2)}"
2681
2682
2683 def normalize_string_quotes(leaf: Leaf) -> None:
2684     """Prefer double quotes but only if it doesn't cause more escaping.
2685
2686     Adds or removes backslashes as appropriate. Doesn't parse and fix
2687     strings nested in f-strings (yet).
2688
2689     Note: Mutates its argument.
2690     """
2691     value = leaf.value.lstrip("furbFURB")
2692     if value[:3] == '"""':
2693         return
2694
2695     elif value[:3] == "'''":
2696         orig_quote = "'''"
2697         new_quote = '"""'
2698     elif value[0] == '"':
2699         orig_quote = '"'
2700         new_quote = "'"
2701     else:
2702         orig_quote = "'"
2703         new_quote = '"'
2704     first_quote_pos = leaf.value.find(orig_quote)
2705     if first_quote_pos == -1:
2706         return  # There's an internal error
2707
2708     prefix = leaf.value[:first_quote_pos]
2709     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2710     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2711     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2712     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2713     if "r" in prefix.casefold():
2714         if unescaped_new_quote.search(body):
2715             # There's at least one unescaped new_quote in this raw string
2716             # so converting is impossible
2717             return
2718
2719         # Do not introduce or remove backslashes in raw strings
2720         new_body = body
2721     else:
2722         # remove unnecessary escapes
2723         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2724         if body != new_body:
2725             # Consider the string without unnecessary escapes as the original
2726             body = new_body
2727             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2728         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2729         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2730     if "f" in prefix.casefold():
2731         matches = re.findall(
2732             r"""
2733             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2734                 ([^{].*?)  # contents of the brackets except if begins with {{
2735             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2736             """,
2737             new_body,
2738             re.VERBOSE,
2739         )
2740         for m in matches:
2741             if "\\" in str(m):
2742                 # Do not introduce backslashes in interpolated expressions
2743                 return
2744     if new_quote == '"""' and new_body[-1:] == '"':
2745         # edge case:
2746         new_body = new_body[:-1] + '\\"'
2747     orig_escape_count = body.count("\\")
2748     new_escape_count = new_body.count("\\")
2749     if new_escape_count > orig_escape_count:
2750         return  # Do not introduce more escaping
2751
2752     if new_escape_count == orig_escape_count and orig_quote == '"':
2753         return  # Prefer double quotes
2754
2755     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2756
2757
2758 def normalize_numeric_literal(leaf: Leaf) -> None:
2759     """Normalizes numeric (float, int, and complex) literals.
2760
2761     All letters used in the representation are normalized to lowercase (except
2762     in Python 2 long literals).
2763     """
2764     text = leaf.value.lower()
2765     if text.startswith(("0o", "0b")):
2766         # Leave octal and binary literals alone.
2767         pass
2768     elif text.startswith("0x"):
2769         # Change hex literals to upper case.
2770         before, after = text[:2], text[2:]
2771         text = f"{before}{after.upper()}"
2772     elif "e" in text:
2773         before, after = text.split("e")
2774         sign = ""
2775         if after.startswith("-"):
2776             after = after[1:]
2777             sign = "-"
2778         elif after.startswith("+"):
2779             after = after[1:]
2780         before = format_float_or_int_string(before)
2781         text = f"{before}e{sign}{after}"
2782     elif text.endswith(("j", "l")):
2783         number = text[:-1]
2784         suffix = text[-1]
2785         # Capitalize in "2L" because "l" looks too similar to "1".
2786         if suffix == "l":
2787             suffix = "L"
2788         text = f"{format_float_or_int_string(number)}{suffix}"
2789     else:
2790         text = format_float_or_int_string(text)
2791     leaf.value = text
2792
2793
2794 def format_float_or_int_string(text: str) -> str:
2795     """Formats a float string like "1.0"."""
2796     if "." not in text:
2797         return text
2798
2799     before, after = text.split(".")
2800     return f"{before or 0}.{after or 0}"
2801
2802
2803 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2804     """Make existing optional parentheses invisible or create new ones.
2805
2806     `parens_after` is a set of string leaf values immeditely after which parens
2807     should be put.
2808
2809     Standardizes on visible parentheses for single-element tuples, and keeps
2810     existing visible parentheses for other tuples and generator expressions.
2811     """
2812     for pc in list_comments(node.prefix, is_endmarker=False):
2813         if pc.value in FMT_OFF:
2814             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2815             return
2816
2817     check_lpar = False
2818     for index, child in enumerate(list(node.children)):
2819         # Add parentheses around long tuple unpacking in assignments.
2820         if (
2821             index == 0
2822             and isinstance(child, Node)
2823             and child.type == syms.testlist_star_expr
2824         ):
2825             check_lpar = True
2826
2827         if check_lpar:
2828             if child.type == syms.atom:
2829                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2830                     lpar = Leaf(token.LPAR, "")
2831                     rpar = Leaf(token.RPAR, "")
2832                     index = child.remove() or 0
2833                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2834             elif is_one_tuple(child):
2835                 # wrap child in visible parentheses
2836                 lpar = Leaf(token.LPAR, "(")
2837                 rpar = Leaf(token.RPAR, ")")
2838                 child.remove()
2839                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2840             elif node.type == syms.import_from:
2841                 # "import from" nodes store parentheses directly as part of
2842                 # the statement
2843                 if child.type == token.LPAR:
2844                     # make parentheses invisible
2845                     child.value = ""  # type: ignore
2846                     node.children[-1].value = ""  # type: ignore
2847                 elif child.type != token.STAR:
2848                     # insert invisible parentheses
2849                     node.insert_child(index, Leaf(token.LPAR, ""))
2850                     node.append_child(Leaf(token.RPAR, ""))
2851                 break
2852
2853             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2854                 # wrap child in invisible parentheses
2855                 lpar = Leaf(token.LPAR, "")
2856                 rpar = Leaf(token.RPAR, "")
2857                 index = child.remove() or 0
2858                 prefix = child.prefix
2859                 child.prefix = ""
2860                 new_child = Node(syms.atom, [lpar, child, rpar])
2861                 new_child.prefix = prefix
2862                 node.insert_child(index, new_child)
2863
2864         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2865
2866
2867 def normalize_fmt_off(node: Node) -> None:
2868     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2869     try_again = True
2870     while try_again:
2871         try_again = convert_one_fmt_off_pair(node)
2872
2873
2874 def convert_one_fmt_off_pair(node: Node) -> bool:
2875     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2876
2877     Returns True if a pair was converted.
2878     """
2879     for leaf in node.leaves():
2880         previous_consumed = 0
2881         for comment in list_comments(leaf.prefix, is_endmarker=False):
2882             if comment.value in FMT_OFF:
2883                 # We only want standalone comments. If there's no previous leaf or
2884                 # the previous leaf is indentation, it's a standalone comment in
2885                 # disguise.
2886                 if comment.type != STANDALONE_COMMENT:
2887                     prev = preceding_leaf(leaf)
2888                     if prev and prev.type not in WHITESPACE:
2889                         continue
2890
2891                 ignored_nodes = list(generate_ignored_nodes(leaf))
2892                 if not ignored_nodes:
2893                     continue
2894
2895                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2896                 parent = first.parent
2897                 prefix = first.prefix
2898                 first.prefix = prefix[comment.consumed :]
2899                 hidden_value = (
2900                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2901                 )
2902                 if hidden_value.endswith("\n"):
2903                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2904                     # leaf (possibly followed by a DEDENT).
2905                     hidden_value = hidden_value[:-1]
2906                 first_idx = None
2907                 for ignored in ignored_nodes:
2908                     index = ignored.remove()
2909                     if first_idx is None:
2910                         first_idx = index
2911                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2912                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2913                 parent.insert_child(
2914                     first_idx,
2915                     Leaf(
2916                         STANDALONE_COMMENT,
2917                         hidden_value,
2918                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2919                     ),
2920                 )
2921                 return True
2922
2923             previous_consumed = comment.consumed
2924
2925     return False
2926
2927
2928 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2929     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2930
2931     Stops at the end of the block.
2932     """
2933     container: Optional[LN] = container_of(leaf)
2934     while container is not None and container.type != token.ENDMARKER:
2935         for comment in list_comments(container.prefix, is_endmarker=False):
2936             if comment.value in FMT_ON:
2937                 return
2938
2939         yield container
2940
2941         container = container.next_sibling
2942
2943
2944 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2945     """If it's safe, make the parens in the atom `node` invisible, recursively.
2946
2947     Returns whether the node should itself be wrapped in invisible parentheses.
2948
2949     """
2950     if (
2951         node.type != syms.atom
2952         or is_empty_tuple(node)
2953         or is_one_tuple(node)
2954         or (is_yield(node) and parent.type != syms.expr_stmt)
2955         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2956     ):
2957         return False
2958
2959     first = node.children[0]
2960     last = node.children[-1]
2961     if first.type == token.LPAR and last.type == token.RPAR:
2962         # make parentheses invisible
2963         first.value = ""  # type: ignore
2964         last.value = ""  # type: ignore
2965         if len(node.children) > 1:
2966             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2967         return False
2968
2969     return True
2970
2971
2972 def is_empty_tuple(node: LN) -> bool:
2973     """Return True if `node` holds an empty tuple."""
2974     return (
2975         node.type == syms.atom
2976         and len(node.children) == 2
2977         and node.children[0].type == token.LPAR
2978         and node.children[1].type == token.RPAR
2979     )
2980
2981
2982 def is_one_tuple(node: LN) -> bool:
2983     """Return True if `node` holds a tuple with one element, with or without parens."""
2984     if node.type == syms.atom:
2985         if len(node.children) != 3:
2986             return False
2987
2988         lpar, gexp, rpar = node.children
2989         if not (
2990             lpar.type == token.LPAR
2991             and gexp.type == syms.testlist_gexp
2992             and rpar.type == token.RPAR
2993         ):
2994             return False
2995
2996         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2997
2998     return (
2999         node.type in IMPLICIT_TUPLE
3000         and len(node.children) == 2
3001         and node.children[1].type == token.COMMA
3002     )
3003
3004
3005 def is_yield(node: LN) -> bool:
3006     """Return True if `node` holds a `yield` or `yield from` expression."""
3007     if node.type == syms.yield_expr:
3008         return True
3009
3010     if node.type == token.NAME and node.value == "yield":  # type: ignore
3011         return True
3012
3013     if node.type != syms.atom:
3014         return False
3015
3016     if len(node.children) != 3:
3017         return False
3018
3019     lpar, expr, rpar = node.children
3020     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3021         return is_yield(expr)
3022
3023     return False
3024
3025
3026 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3027     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3028
3029     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3030     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3031     extended iterable unpacking (PEP 3132) and additional unpacking
3032     generalizations (PEP 448).
3033     """
3034     if leaf.type not in STARS or not leaf.parent:
3035         return False
3036
3037     p = leaf.parent
3038     if p.type == syms.star_expr:
3039         # Star expressions are also used as assignment targets in extended
3040         # iterable unpacking (PEP 3132).  See what its parent is instead.
3041         if not p.parent:
3042             return False
3043
3044         p = p.parent
3045
3046     return p.type in within
3047
3048
3049 def is_multiline_string(leaf: Leaf) -> bool:
3050     """Return True if `leaf` is a multiline string that actually spans many lines."""
3051     value = leaf.value.lstrip("furbFURB")
3052     return value[:3] in {'"""', "'''"} and "\n" in value
3053
3054
3055 def is_stub_suite(node: Node) -> bool:
3056     """Return True if `node` is a suite with a stub body."""
3057     if (
3058         len(node.children) != 4
3059         or node.children[0].type != token.NEWLINE
3060         or node.children[1].type != token.INDENT
3061         or node.children[3].type != token.DEDENT
3062     ):
3063         return False
3064
3065     return is_stub_body(node.children[2])
3066
3067
3068 def is_stub_body(node: LN) -> bool:
3069     """Return True if `node` is a simple statement containing an ellipsis."""
3070     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3071         return False
3072
3073     if len(node.children) != 2:
3074         return False
3075
3076     child = node.children[0]
3077     return (
3078         child.type == syms.atom
3079         and len(child.children) == 3
3080         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3081     )
3082
3083
3084 def max_delimiter_priority_in_atom(node: LN) -> int:
3085     """Return maximum delimiter priority inside `node`.
3086
3087     This is specific to atoms with contents contained in a pair of parentheses.
3088     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3089     """
3090     if node.type != syms.atom:
3091         return 0
3092
3093     first = node.children[0]
3094     last = node.children[-1]
3095     if not (first.type == token.LPAR and last.type == token.RPAR):
3096         return 0
3097
3098     bt = BracketTracker()
3099     for c in node.children[1:-1]:
3100         if isinstance(c, Leaf):
3101             bt.mark(c)
3102         else:
3103             for leaf in c.leaves():
3104                 bt.mark(leaf)
3105     try:
3106         return bt.max_delimiter_priority()
3107
3108     except ValueError:
3109         return 0
3110
3111
3112 def ensure_visible(leaf: Leaf) -> None:
3113     """Make sure parentheses are visible.
3114
3115     They could be invisible as part of some statements (see
3116     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3117     """
3118     if leaf.type == token.LPAR:
3119         leaf.value = "("
3120     elif leaf.type == token.RPAR:
3121         leaf.value = ")"
3122
3123
3124 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3125     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3126
3127     if not (
3128         opening_bracket.parent
3129         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3130         and opening_bracket.value in "[{("
3131     ):
3132         return False
3133
3134     try:
3135         last_leaf = line.leaves[-1]
3136         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3137         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3138     except (IndexError, ValueError):
3139         return False
3140
3141     return max_priority == COMMA_PRIORITY
3142
3143
3144 def get_features_used(node: Node) -> Set[Feature]:
3145     """Return a set of (relatively) new Python features used in this file.
3146
3147     Currently looking for:
3148     - f-strings;
3149     - underscores in numeric literals; and
3150     - trailing commas after * or ** in function signatures and calls.
3151     """
3152     features: Set[Feature] = set()
3153     for n in node.pre_order():
3154         if n.type == token.STRING:
3155             value_head = n.value[:2]  # type: ignore
3156             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3157                 features.add(Feature.F_STRINGS)
3158
3159         elif n.type == token.NUMBER:
3160             if "_" in n.value:  # type: ignore
3161                 features.add(Feature.NUMERIC_UNDERSCORES)
3162
3163         elif (
3164             n.type in {syms.typedargslist, syms.arglist}
3165             and n.children
3166             and n.children[-1].type == token.COMMA
3167         ):
3168             if n.type == syms.typedargslist:
3169                 feature = Feature.TRAILING_COMMA_IN_DEF
3170             else:
3171                 feature = Feature.TRAILING_COMMA_IN_CALL
3172
3173             for ch in n.children:
3174                 if ch.type in STARS:
3175                     features.add(feature)
3176
3177                 if ch.type == syms.argument:
3178                     for argch in ch.children:
3179                         if argch.type in STARS:
3180                             features.add(feature)
3181
3182     return features
3183
3184
3185 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3186     """Detect the version to target based on the nodes used."""
3187     features = get_features_used(node)
3188     return {
3189         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3190     }
3191
3192
3193 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3194     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3195
3196     Brackets can be omitted if the entire trailer up to and including
3197     a preceding closing bracket fits in one line.
3198
3199     Yielded sets are cumulative (contain results of previous yields, too).  First
3200     set is empty.
3201     """
3202
3203     omit: Set[LeafID] = set()
3204     yield omit
3205
3206     length = 4 * line.depth
3207     opening_bracket = None
3208     closing_bracket = None
3209     inner_brackets: Set[LeafID] = set()
3210     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3211         length += leaf_length
3212         if length > line_length:
3213             break
3214
3215         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3216         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3217             break
3218
3219         if opening_bracket:
3220             if leaf is opening_bracket:
3221                 opening_bracket = None
3222             elif leaf.type in CLOSING_BRACKETS:
3223                 inner_brackets.add(id(leaf))
3224         elif leaf.type in CLOSING_BRACKETS:
3225             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3226                 # Empty brackets would fail a split so treat them as "inner"
3227                 # brackets (e.g. only add them to the `omit` set if another
3228                 # pair of brackets was good enough.
3229                 inner_brackets.add(id(leaf))
3230                 continue
3231
3232             if closing_bracket:
3233                 omit.add(id(closing_bracket))
3234                 omit.update(inner_brackets)
3235                 inner_brackets.clear()
3236                 yield omit
3237
3238             if leaf.value:
3239                 opening_bracket = leaf.opening_bracket
3240                 closing_bracket = leaf
3241
3242
3243 def get_future_imports(node: Node) -> Set[str]:
3244     """Return a set of __future__ imports in the file."""
3245     imports: Set[str] = set()
3246
3247     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3248         for child in children:
3249             if isinstance(child, Leaf):
3250                 if child.type == token.NAME:
3251                     yield child.value
3252             elif child.type == syms.import_as_name:
3253                 orig_name = child.children[0]
3254                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3255                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3256                 yield orig_name.value
3257             elif child.type == syms.import_as_names:
3258                 yield from get_imports_from_children(child.children)
3259             else:
3260                 raise AssertionError("Invalid syntax parsing imports")
3261
3262     for child in node.children:
3263         if child.type != syms.simple_stmt:
3264             break
3265         first_child = child.children[0]
3266         if isinstance(first_child, Leaf):
3267             # Continue looking if we see a docstring; otherwise stop.
3268             if (
3269                 len(child.children) == 2
3270                 and first_child.type == token.STRING
3271                 and child.children[1].type == token.NEWLINE
3272             ):
3273                 continue
3274             else:
3275                 break
3276         elif first_child.type == syms.import_from:
3277             module_name = first_child.children[1]
3278             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3279                 break
3280             imports |= set(get_imports_from_children(first_child.children[3:]))
3281         else:
3282             break
3283     return imports
3284
3285
3286 def gen_python_files_in_dir(
3287     path: Path,
3288     root: Path,
3289     include: Pattern[str],
3290     exclude: Pattern[str],
3291     report: "Report",
3292 ) -> Iterator[Path]:
3293     """Generate all files under `path` whose paths are not excluded by the
3294     `exclude` regex, but are included by the `include` regex.
3295
3296     Symbolic links pointing outside of the `root` directory are ignored.
3297
3298     `report` is where output about exclusions goes.
3299     """
3300     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3301     for child in path.iterdir():
3302         try:
3303             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3304         except ValueError:
3305             if child.is_symlink():
3306                 report.path_ignored(
3307                     child, f"is a symbolic link that points outside {root}"
3308                 )
3309                 continue
3310
3311             raise
3312
3313         if child.is_dir():
3314             normalized_path += "/"
3315         exclude_match = exclude.search(normalized_path)
3316         if exclude_match and exclude_match.group(0):
3317             report.path_ignored(child, f"matches the --exclude regular expression")
3318             continue
3319
3320         if child.is_dir():
3321             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3322
3323         elif child.is_file():
3324             include_match = include.search(normalized_path)
3325             if include_match:
3326                 yield child
3327
3328
3329 @lru_cache()
3330 def find_project_root(srcs: Iterable[str]) -> Path:
3331     """Return a directory containing .git, .hg, or pyproject.toml.
3332
3333     That directory can be one of the directories passed in `srcs` or their
3334     common parent.
3335
3336     If no directory in the tree contains a marker that would specify it's the
3337     project root, the root of the file system is returned.
3338     """
3339     if not srcs:
3340         return Path("/").resolve()
3341
3342     common_base = min(Path(src).resolve() for src in srcs)
3343     if common_base.is_dir():
3344         # Append a fake file so `parents` below returns `common_base_dir`, too.
3345         common_base /= "fake-file"
3346     for directory in common_base.parents:
3347         if (directory / ".git").is_dir():
3348             return directory
3349
3350         if (directory / ".hg").is_dir():
3351             return directory
3352
3353         if (directory / "pyproject.toml").is_file():
3354             return directory
3355
3356     return directory
3357
3358
3359 @dataclass
3360 class Report:
3361     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3362
3363     check: bool = False
3364     quiet: bool = False
3365     verbose: bool = False
3366     change_count: int = 0
3367     same_count: int = 0
3368     failure_count: int = 0
3369
3370     def done(self, src: Path, changed: Changed) -> None:
3371         """Increment the counter for successful reformatting. Write out a message."""
3372         if changed is Changed.YES:
3373             reformatted = "would reformat" if self.check else "reformatted"
3374             if self.verbose or not self.quiet:
3375                 out(f"{reformatted} {src}")
3376             self.change_count += 1
3377         else:
3378             if self.verbose:
3379                 if changed is Changed.NO:
3380                     msg = f"{src} already well formatted, good job."
3381                 else:
3382                     msg = f"{src} wasn't modified on disk since last run."
3383                 out(msg, bold=False)
3384             self.same_count += 1
3385
3386     def failed(self, src: Path, message: str) -> None:
3387         """Increment the counter for failed reformatting. Write out a message."""
3388         err(f"error: cannot format {src}: {message}")
3389         self.failure_count += 1
3390
3391     def path_ignored(self, path: Path, message: str) -> None:
3392         if self.verbose:
3393             out(f"{path} ignored: {message}", bold=False)
3394
3395     @property
3396     def return_code(self) -> int:
3397         """Return the exit code that the app should use.
3398
3399         This considers the current state of changed files and failures:
3400         - if there were any failures, return 123;
3401         - if any files were changed and --check is being used, return 1;
3402         - otherwise return 0.
3403         """
3404         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3405         # 126 we have special return codes reserved by the shell.
3406         if self.failure_count:
3407             return 123
3408
3409         elif self.change_count and self.check:
3410             return 1
3411
3412         return 0
3413
3414     def __str__(self) -> str:
3415         """Render a color report of the current state.
3416
3417         Use `click.unstyle` to remove colors.
3418         """
3419         if self.check:
3420             reformatted = "would be reformatted"
3421             unchanged = "would be left unchanged"
3422             failed = "would fail to reformat"
3423         else:
3424             reformatted = "reformatted"
3425             unchanged = "left unchanged"
3426             failed = "failed to reformat"
3427         report = []
3428         if self.change_count:
3429             s = "s" if self.change_count > 1 else ""
3430             report.append(
3431                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3432             )
3433         if self.same_count:
3434             s = "s" if self.same_count > 1 else ""
3435             report.append(f"{self.same_count} file{s} {unchanged}")
3436         if self.failure_count:
3437             s = "s" if self.failure_count > 1 else ""
3438             report.append(
3439                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3440             )
3441         return ", ".join(report) + "."
3442
3443
3444 def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
3445     for feature_version in (7, 6):
3446         try:
3447             return ast3.parse(src, feature_version=feature_version)
3448         except SyntaxError:
3449             continue
3450
3451     return ast27.parse(src)
3452
3453
3454 def assert_equivalent(src: str, dst: str) -> None:
3455     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3456
3457     import traceback
3458
3459     def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3460         """Simple visitor generating strings to compare ASTs by content."""
3461         yield f"{'  ' * depth}{node.__class__.__name__}("
3462
3463         for field in sorted(node._fields):
3464             # TypeIgnore has only one field 'lineno' which breaks this comparison
3465             if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
3466                 break
3467
3468             # Ignore str kind which is case sensitive / and ignores unicode_literals
3469             if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
3470                 continue
3471
3472             try:
3473                 value = getattr(node, field)
3474             except AttributeError:
3475                 continue
3476
3477             yield f"{'  ' * (depth+1)}{field}="
3478
3479             if isinstance(value, list):
3480                 for item in value:
3481                     # Ignore nested tuples within del statements, because we may insert
3482                     # parentheses and they change the AST.
3483                     if (
3484                         field == "targets"
3485                         and isinstance(node, (ast3.Delete, ast27.Delete))
3486                         and isinstance(item, (ast3.Tuple, ast27.Tuple))
3487                     ):
3488                         for item in item.elts:
3489                             yield from _v(item, depth + 2)
3490                     elif isinstance(item, (ast3.AST, ast27.AST)):
3491                         yield from _v(item, depth + 2)
3492
3493             elif isinstance(value, (ast3.AST, ast27.AST)):
3494                 yield from _v(value, depth + 2)
3495
3496             else:
3497                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3498
3499         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3500
3501     try:
3502         src_ast = parse_ast(src)
3503     except Exception as exc:
3504         raise AssertionError(
3505             f"cannot use --safe with this file; failed to parse source file.  "
3506             f"AST error message: {exc}"
3507         )
3508
3509     try:
3510         dst_ast = parse_ast(dst)
3511     except Exception as exc:
3512         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3513         raise AssertionError(
3514             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3515             f"Please report a bug on https://github.com/python/black/issues.  "
3516             f"This invalid output might be helpful: {log}"
3517         ) from None
3518
3519     src_ast_str = "\n".join(_v(src_ast))
3520     dst_ast_str = "\n".join(_v(dst_ast))
3521     if src_ast_str != dst_ast_str:
3522         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3523         raise AssertionError(
3524             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3525             f"the source.  "
3526             f"Please report a bug on https://github.com/python/black/issues.  "
3527             f"This diff might be helpful: {log}"
3528         ) from None
3529
3530
3531 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3532     """Raise AssertionError if `dst` reformats differently the second time."""
3533     newdst = format_str(dst, mode=mode)
3534     if dst != newdst:
3535         log = dump_to_file(
3536             diff(src, dst, "source", "first pass"),
3537             diff(dst, newdst, "first pass", "second pass"),
3538         )
3539         raise AssertionError(
3540             f"INTERNAL ERROR: Black produced different code on the second pass "
3541             f"of the formatter.  "
3542             f"Please report a bug on https://github.com/python/black/issues.  "
3543             f"This diff might be helpful: {log}"
3544         ) from None
3545
3546
3547 def dump_to_file(*output: str) -> str:
3548     """Dump `output` to a temporary file. Return path to the file."""
3549     import tempfile
3550
3551     with tempfile.NamedTemporaryFile(
3552         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3553     ) as f:
3554         for lines in output:
3555             f.write(lines)
3556             if lines and lines[-1] != "\n":
3557                 f.write("\n")
3558     return f.name
3559
3560
3561 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3562     """Return a unified diff string between strings `a` and `b`."""
3563     import difflib
3564
3565     a_lines = [line + "\n" for line in a.split("\n")]
3566     b_lines = [line + "\n" for line in b.split("\n")]
3567     return "".join(
3568         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3569     )
3570
3571
3572 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3573     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3574     err("Aborted!")
3575     for task in tasks:
3576         task.cancel()
3577
3578
3579 def shutdown(loop: BaseEventLoop) -> None:
3580     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3581     try:
3582         if sys.version_info[:2] >= (3, 7):
3583             all_tasks = asyncio.all_tasks
3584         else:
3585             all_tasks = asyncio.Task.all_tasks
3586         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3587         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3588         if not to_cancel:
3589             return
3590
3591         for task in to_cancel:
3592             task.cancel()
3593         loop.run_until_complete(
3594             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3595         )
3596     finally:
3597         # `concurrent.futures.Future` objects cannot be cancelled once they
3598         # are already running. There might be some when the `shutdown()` happened.
3599         # Silence their logger's spew about the event loop being closed.
3600         cf_logger = logging.getLogger("concurrent.futures")
3601         cf_logger.setLevel(logging.CRITICAL)
3602         loop.close()
3603
3604
3605 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3606     """Replace `regex` with `replacement` twice on `original`.
3607
3608     This is used by string normalization to perform replaces on
3609     overlapping matches.
3610     """
3611     return regex.sub(replacement, regex.sub(replacement, original))
3612
3613
3614 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3615     """Compile a regular expression string in `regex`.
3616
3617     If it contains newlines, use verbose mode.
3618     """
3619     if "\n" in regex:
3620         regex = "(?x)" + regex
3621     return re.compile(regex)
3622
3623
3624 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3625     """Like `reversed(enumerate(sequence))` if that were possible."""
3626     index = len(sequence) - 1
3627     for element in reversed(sequence):
3628         yield (index, element)
3629         index -= 1
3630
3631
3632 def enumerate_with_length(
3633     line: Line, reversed: bool = False
3634 ) -> Iterator[Tuple[Index, Leaf, int]]:
3635     """Return an enumeration of leaves with their length.
3636
3637     Stops prematurely on multiline strings and standalone comments.
3638     """
3639     op = cast(
3640         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3641         enumerate_reversed if reversed else enumerate,
3642     )
3643     for index, leaf in op(line.leaves):
3644         length = len(leaf.prefix) + len(leaf.value)
3645         if "\n" in leaf.value:
3646             return  # Multiline strings, we can't continue.
3647
3648         comment: Optional[Leaf]
3649         for comment in line.comments_after(leaf):
3650             length += len(comment.value)
3651
3652         yield index, leaf, length
3653
3654
3655 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3656     """Return True if `line` is no longer than `line_length`.
3657
3658     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3659     """
3660     if not line_str:
3661         line_str = str(line).strip("\n")
3662     return (
3663         len(line_str) <= line_length
3664         and "\n" not in line_str  # multiline strings
3665         and not line.contains_standalone_comments()
3666     )
3667
3668
3669 def can_be_split(line: Line) -> bool:
3670     """Return False if the line cannot be split *for sure*.
3671
3672     This is not an exhaustive search but a cheap heuristic that we can use to
3673     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3674     in unnecessary parentheses).
3675     """
3676     leaves = line.leaves
3677     if len(leaves) < 2:
3678         return False
3679
3680     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3681         call_count = 0
3682         dot_count = 0
3683         next = leaves[-1]
3684         for leaf in leaves[-2::-1]:
3685             if leaf.type in OPENING_BRACKETS:
3686                 if next.type not in CLOSING_BRACKETS:
3687                     return False
3688
3689                 call_count += 1
3690             elif leaf.type == token.DOT:
3691                 dot_count += 1
3692             elif leaf.type == token.NAME:
3693                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3694                     return False
3695
3696             elif leaf.type not in CLOSING_BRACKETS:
3697                 return False
3698
3699             if dot_count > 1 and call_count > 1:
3700                 return False
3701
3702     return True
3703
3704
3705 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3706     """Does `line` have a shape safe to reformat without optional parens around it?
3707
3708     Returns True for only a subset of potentially nice looking formattings but
3709     the point is to not return false positives that end up producing lines that
3710     are too long.
3711     """
3712     bt = line.bracket_tracker
3713     if not bt.delimiters:
3714         # Without delimiters the optional parentheses are useless.
3715         return True
3716
3717     max_priority = bt.max_delimiter_priority()
3718     if bt.delimiter_count_with_priority(max_priority) > 1:
3719         # With more than one delimiter of a kind the optional parentheses read better.
3720         return False
3721
3722     if max_priority == DOT_PRIORITY:
3723         # A single stranded method call doesn't require optional parentheses.
3724         return True
3725
3726     assert len(line.leaves) >= 2, "Stranded delimiter"
3727
3728     first = line.leaves[0]
3729     second = line.leaves[1]
3730     penultimate = line.leaves[-2]
3731     last = line.leaves[-1]
3732
3733     # With a single delimiter, omit if the expression starts or ends with
3734     # a bracket.
3735     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3736         remainder = False
3737         length = 4 * line.depth
3738         for _index, leaf, leaf_length in enumerate_with_length(line):
3739             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3740                 remainder = True
3741             if remainder:
3742                 length += leaf_length
3743                 if length > line_length:
3744                     break
3745
3746                 if leaf.type in OPENING_BRACKETS:
3747                     # There are brackets we can further split on.
3748                     remainder = False
3749
3750         else:
3751             # checked the entire string and line length wasn't exceeded
3752             if len(line.leaves) == _index + 1:
3753                 return True
3754
3755         # Note: we are not returning False here because a line might have *both*
3756         # a leading opening bracket and a trailing closing bracket.  If the
3757         # opening bracket doesn't match our rule, maybe the closing will.
3758
3759     if (
3760         last.type == token.RPAR
3761         or last.type == token.RBRACE
3762         or (
3763             # don't use indexing for omitting optional parentheses;
3764             # it looks weird
3765             last.type == token.RSQB
3766             and last.parent
3767             and last.parent.type != syms.trailer
3768         )
3769     ):
3770         if penultimate.type in OPENING_BRACKETS:
3771             # Empty brackets don't help.
3772             return False
3773
3774         if is_multiline_string(first):
3775             # Additional wrapping of a multiline string in this situation is
3776             # unnecessary.
3777             return True
3778
3779         length = 4 * line.depth
3780         seen_other_brackets = False
3781         for _index, leaf, leaf_length in enumerate_with_length(line):
3782             length += leaf_length
3783             if leaf is last.opening_bracket:
3784                 if seen_other_brackets or length <= line_length:
3785                     return True
3786
3787             elif leaf.type in OPENING_BRACKETS:
3788                 # There are brackets we can further split on.
3789                 seen_other_brackets = True
3790
3791     return False
3792
3793
3794 def get_cache_file(mode: FileMode) -> Path:
3795     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3796
3797
3798 def read_cache(mode: FileMode) -> Cache:
3799     """Read the cache if it exists and is well formed.
3800
3801     If it is not well formed, the call to write_cache later should resolve the issue.
3802     """
3803     cache_file = get_cache_file(mode)
3804     if not cache_file.exists():
3805         return {}
3806
3807     with cache_file.open("rb") as fobj:
3808         try:
3809             cache: Cache = pickle.load(fobj)
3810         except pickle.UnpicklingError:
3811             return {}
3812
3813     return cache
3814
3815
3816 def get_cache_info(path: Path) -> CacheInfo:
3817     """Return the information used to check if a file is already formatted or not."""
3818     stat = path.stat()
3819     return stat.st_mtime, stat.st_size
3820
3821
3822 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3823     """Split an iterable of paths in `sources` into two sets.
3824
3825     The first contains paths of files that modified on disk or are not in the
3826     cache. The other contains paths to non-modified files.
3827     """
3828     todo, done = set(), set()
3829     for src in sources:
3830         src = src.resolve()
3831         if cache.get(src) != get_cache_info(src):
3832             todo.add(src)
3833         else:
3834             done.add(src)
3835     return todo, done
3836
3837
3838 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3839     """Update the cache file."""
3840     cache_file = get_cache_file(mode)
3841     try:
3842         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3843         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3844         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3845             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3846         os.replace(f.name, cache_file)
3847     except OSError:
3848         pass
3849
3850
3851 def patch_click() -> None:
3852     """Make Click not crash.
3853
3854     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3855     default which restricts paths that it can access during the lifetime of the
3856     application.  Click refuses to work in this scenario by raising a RuntimeError.
3857
3858     In case of Black the likelihood that non-ASCII characters are going to be used in
3859     file paths is minimal since it's Python source code.  Moreover, this crash was
3860     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3861     """
3862     try:
3863         from click import core
3864         from click import _unicodefun  # type: ignore
3865     except ModuleNotFoundError:
3866         return
3867
3868     for module in (core, _unicodefun):
3869         if hasattr(module, "_verify_python3_env"):
3870             module._verify_python3_env = lambda: None
3871
3872
3873 def patched_main() -> None:
3874     freeze_support()
3875     patch_click()
3876     main()
3877
3878
3879 if __name__ == "__main__":
3880     patched_main()