]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add doc clarifying that there is no blackd client (#859)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43 from typed_ast import ast3, ast27
44
45 # lib2to3 fork
46 from blib2to3.pytree import Node, Leaf, type_repr
47 from blib2to3 import pygram, pytree
48 from blib2to3.pgen2 import driver, token
49 from blib2to3.pgen2.grammar import Grammar
50 from blib2to3.pgen2.parse import ParseError
51
52
53 __version__ = "19.3b0"
54 DEFAULT_LINE_LENGTH = 88
55 DEFAULT_EXCLUDES = (
56     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
57 )
58 DEFAULT_INCLUDES = r"\.pyi?$"
59 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
60
61
62 # types
63 FileContent = str
64 Encoding = str
65 NewLine = str
66 Depth = int
67 NodeType = int
68 LeafID = int
69 Priority = int
70 Index = int
71 LN = Union[Leaf, Node]
72 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
73 Timestamp = float
74 FileSize = int
75 CacheInfo = Tuple[Timestamp, FileSize]
76 Cache = Dict[Path, CacheInfo]
77 out = partial(click.secho, bold=True, err=True)
78 err = partial(click.secho, fg="red", err=True)
79
80 pygram.initialize(CACHE_DIR)
81 syms = pygram.python_symbols
82
83
84 class NothingChanged(UserWarning):
85     """Raised when reformatted code is the same as source."""
86
87
88 class CannotSplit(Exception):
89     """A readable split that fits the allotted line length is impossible."""
90
91
92 class InvalidInput(ValueError):
93     """Raised when input source code fails all parse attempts."""
94
95
96 class WriteBack(Enum):
97     NO = 0
98     YES = 1
99     DIFF = 2
100     CHECK = 3
101
102     @classmethod
103     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
104         if check and not diff:
105             return cls.CHECK
106
107         return cls.DIFF if diff else cls.YES
108
109
110 class Changed(Enum):
111     NO = 0
112     CACHED = 1
113     YES = 2
114
115
116 class TargetVersion(Enum):
117     PY27 = 2
118     PY33 = 3
119     PY34 = 4
120     PY35 = 5
121     PY36 = 6
122     PY37 = 7
123     PY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.PY27
127
128
129 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA_IN_CALL = 4
138     TRAILING_COMMA_IN_DEF = 5
139     # The following two feature-flags are mutually exclusive, and exactly one should be
140     # set for every version of python.
141     ASYNC_IDENTIFIERS = 6
142     ASYNC_KEYWORDS = 7
143
144
145 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
146     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
147     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
148     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
149     TargetVersion.PY35: {
150         Feature.UNICODE_LITERALS,
151         Feature.TRAILING_COMMA_IN_CALL,
152         Feature.ASYNC_IDENTIFIERS,
153     },
154     TargetVersion.PY36: {
155         Feature.UNICODE_LITERALS,
156         Feature.F_STRINGS,
157         Feature.NUMERIC_UNDERSCORES,
158         Feature.TRAILING_COMMA_IN_CALL,
159         Feature.TRAILING_COMMA_IN_DEF,
160         Feature.ASYNC_IDENTIFIERS,
161     },
162     TargetVersion.PY37: {
163         Feature.UNICODE_LITERALS,
164         Feature.F_STRINGS,
165         Feature.NUMERIC_UNDERSCORES,
166         Feature.TRAILING_COMMA_IN_CALL,
167         Feature.TRAILING_COMMA_IN_DEF,
168         Feature.ASYNC_KEYWORDS,
169     },
170     TargetVersion.PY38: {
171         Feature.UNICODE_LITERALS,
172         Feature.F_STRINGS,
173         Feature.NUMERIC_UNDERSCORES,
174         Feature.TRAILING_COMMA_IN_CALL,
175         Feature.TRAILING_COMMA_IN_DEF,
176         Feature.ASYNC_KEYWORDS,
177     },
178 }
179
180
181 @dataclass
182 class FileMode:
183     target_versions: Set[TargetVersion] = Factory(set)
184     line_length: int = DEFAULT_LINE_LENGTH
185     string_normalization: bool = True
186     is_pyi: bool = False
187
188     def get_cache_key(self) -> str:
189         if self.target_versions:
190             version_str = ",".join(
191                 str(version.value)
192                 for version in sorted(self.target_versions, key=lambda v: v.value)
193             )
194         else:
195             version_str = "-"
196         parts = [
197             version_str,
198             str(self.line_length),
199             str(int(self.string_normalization)),
200             str(int(self.is_pyi)),
201         ]
202         return ".".join(parts)
203
204
205 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
206     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
207
208
209 def read_pyproject_toml(
210     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
211 ) -> Optional[str]:
212     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
213
214     Returns the path to a successfully found and read configuration file, None
215     otherwise.
216     """
217     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
218     if not value:
219         root = find_project_root(ctx.params.get("src", ()))
220         path = root / "pyproject.toml"
221         if path.is_file():
222             value = str(path)
223         else:
224             return None
225
226     try:
227         pyproject_toml = toml.load(value)
228         config = pyproject_toml.get("tool", {}).get("black", {})
229     except (toml.TomlDecodeError, OSError) as e:
230         raise click.FileError(
231             filename=value, hint=f"Error reading configuration file: {e}"
232         )
233
234     if not config:
235         return None
236
237     if ctx.default_map is None:
238         ctx.default_map = {}
239     ctx.default_map.update(  # type: ignore  # bad types in .pyi
240         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
241     )
242     return value
243
244
245 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
246 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
247 @click.option(
248     "-l",
249     "--line-length",
250     type=int,
251     default=DEFAULT_LINE_LENGTH,
252     help="How many characters per line to allow.",
253     show_default=True,
254 )
255 @click.option(
256     "-t",
257     "--target-version",
258     type=click.Choice([v.name.lower() for v in TargetVersion]),
259     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
260     multiple=True,
261     help=(
262         "Python versions that should be supported by Black's output. [default: "
263         "per-file auto-detection]"
264     ),
265 )
266 @click.option(
267     "--py36",
268     is_flag=True,
269     help=(
270         "Allow using Python 3.6-only syntax on all input files.  This will put "
271         "trailing commas in function signatures and calls also after *args and "
272         "**kwargs. Deprecated; use --target-version instead. "
273         "[default: per-file auto-detection]"
274     ),
275 )
276 @click.option(
277     "--pyi",
278     is_flag=True,
279     help=(
280         "Format all input files like typing stubs regardless of file extension "
281         "(useful when piping source on standard input)."
282     ),
283 )
284 @click.option(
285     "-S",
286     "--skip-string-normalization",
287     is_flag=True,
288     help="Don't normalize string quotes or prefixes.",
289 )
290 @click.option(
291     "--check",
292     is_flag=True,
293     help=(
294         "Don't write the files back, just return the status.  Return code 0 "
295         "means nothing would change.  Return code 1 means some files would be "
296         "reformatted.  Return code 123 means there was an internal error."
297     ),
298 )
299 @click.option(
300     "--diff",
301     is_flag=True,
302     help="Don't write the files back, just output a diff for each file on stdout.",
303 )
304 @click.option(
305     "--fast/--safe",
306     is_flag=True,
307     help="If --fast given, skip temporary sanity checks. [default: --safe]",
308 )
309 @click.option(
310     "--include",
311     type=str,
312     default=DEFAULT_INCLUDES,
313     help=(
314         "A regular expression that matches files and directories that should be "
315         "included on recursive searches.  An empty value means all files are "
316         "included regardless of the name.  Use forward slashes for directories on "
317         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
318         "later."
319     ),
320     show_default=True,
321 )
322 @click.option(
323     "--exclude",
324     type=str,
325     default=DEFAULT_EXCLUDES,
326     help=(
327         "A regular expression that matches files and directories that should be "
328         "excluded on recursive searches.  An empty value means no paths are excluded. "
329         "Use forward slashes for directories on all platforms (Windows, too).  "
330         "Exclusions are calculated first, inclusions later."
331     ),
332     show_default=True,
333 )
334 @click.option(
335     "-q",
336     "--quiet",
337     is_flag=True,
338     help=(
339         "Don't emit non-error messages to stderr. Errors are still emitted, "
340         "silence those with 2>/dev/null."
341     ),
342 )
343 @click.option(
344     "-v",
345     "--verbose",
346     is_flag=True,
347     help=(
348         "Also emit messages to stderr about files that were not changed or were "
349         "ignored due to --exclude=."
350     ),
351 )
352 @click.version_option(version=__version__)
353 @click.argument(
354     "src",
355     nargs=-1,
356     type=click.Path(
357         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
358     ),
359     is_eager=True,
360 )
361 @click.option(
362     "--config",
363     type=click.Path(
364         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
365     ),
366     is_eager=True,
367     callback=read_pyproject_toml,
368     help="Read configuration from PATH.",
369 )
370 @click.pass_context
371 def main(
372     ctx: click.Context,
373     code: Optional[str],
374     line_length: int,
375     target_version: List[TargetVersion],
376     check: bool,
377     diff: bool,
378     fast: bool,
379     pyi: bool,
380     py36: bool,
381     skip_string_normalization: bool,
382     quiet: bool,
383     verbose: bool,
384     include: str,
385     exclude: str,
386     src: Tuple[str],
387     config: Optional[str],
388 ) -> None:
389     """The uncompromising code formatter."""
390     write_back = WriteBack.from_configuration(check=check, diff=diff)
391     if target_version:
392         if py36:
393             err(f"Cannot use both --target-version and --py36")
394             ctx.exit(2)
395         else:
396             versions = set(target_version)
397     elif py36:
398         err(
399             "--py36 is deprecated and will be removed in a future version. "
400             "Use --target-version py36 instead."
401         )
402         versions = PY36_VERSIONS
403     else:
404         # We'll autodetect later.
405         versions = set()
406     mode = FileMode(
407         target_versions=versions,
408         line_length=line_length,
409         is_pyi=pyi,
410         string_normalization=not skip_string_normalization,
411     )
412     if config and verbose:
413         out(f"Using configuration from {config}.", bold=False, fg="blue")
414     if code is not None:
415         print(format_str(code, mode=mode))
416         ctx.exit(0)
417     try:
418         include_regex = re_compile_maybe_verbose(include)
419     except re.error:
420         err(f"Invalid regular expression for include given: {include!r}")
421         ctx.exit(2)
422     try:
423         exclude_regex = re_compile_maybe_verbose(exclude)
424     except re.error:
425         err(f"Invalid regular expression for exclude given: {exclude!r}")
426         ctx.exit(2)
427     report = Report(check=check, quiet=quiet, verbose=verbose)
428     root = find_project_root(src)
429     sources: Set[Path] = set()
430     for s in src:
431         p = Path(s)
432         if p.is_dir():
433             sources.update(
434                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
435             )
436         elif p.is_file() or s == "-":
437             # if a file was explicitly given, we don't care about its extension
438             sources.add(p)
439         else:
440             err(f"invalid path: {s}")
441     if len(sources) == 0:
442         if verbose or not quiet:
443             out("No paths given. Nothing to do 😴")
444         ctx.exit(0)
445
446     if len(sources) == 1:
447         reformat_one(
448             src=sources.pop(),
449             fast=fast,
450             write_back=write_back,
451             mode=mode,
452             report=report,
453         )
454     else:
455         reformat_many(
456             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
457         )
458
459     if verbose or not quiet:
460         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
461         click.secho(str(report), err=True)
462     ctx.exit(report.return_code)
463
464
465 def reformat_one(
466     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
467 ) -> None:
468     """Reformat a single file under `src` without spawning child processes.
469
470     If `quiet` is True, non-error messages are not output. `line_length`,
471     `write_back`, `fast` and `pyi` options are passed to
472     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
473     """
474     try:
475         changed = Changed.NO
476         if not src.is_file() and str(src) == "-":
477             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
478                 changed = Changed.YES
479         else:
480             cache: Cache = {}
481             if write_back != WriteBack.DIFF:
482                 cache = read_cache(mode)
483                 res_src = src.resolve()
484                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
485                     changed = Changed.CACHED
486             if changed is not Changed.CACHED and format_file_in_place(
487                 src, fast=fast, write_back=write_back, mode=mode
488             ):
489                 changed = Changed.YES
490             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
491                 write_back is WriteBack.CHECK and changed is Changed.NO
492             ):
493                 write_cache(cache, [src], mode)
494         report.done(src, changed)
495     except Exception as exc:
496         report.failed(src, str(exc))
497
498
499 def reformat_many(
500     sources: Set[Path],
501     fast: bool,
502     write_back: WriteBack,
503     mode: FileMode,
504     report: "Report",
505 ) -> None:
506     """Reformat multiple files using a ProcessPoolExecutor."""
507     loop = asyncio.get_event_loop()
508     worker_count = os.cpu_count()
509     if sys.platform == "win32":
510         # Work around https://bugs.python.org/issue26903
511         worker_count = min(worker_count, 61)
512     executor = ProcessPoolExecutor(max_workers=worker_count)
513     try:
514         loop.run_until_complete(
515             schedule_formatting(
516                 sources=sources,
517                 fast=fast,
518                 write_back=write_back,
519                 mode=mode,
520                 report=report,
521                 loop=loop,
522                 executor=executor,
523             )
524         )
525     finally:
526         shutdown(loop)
527
528
529 async def schedule_formatting(
530     sources: Set[Path],
531     fast: bool,
532     write_back: WriteBack,
533     mode: FileMode,
534     report: "Report",
535     loop: BaseEventLoop,
536     executor: Executor,
537 ) -> None:
538     """Run formatting of `sources` in parallel using the provided `executor`.
539
540     (Use ProcessPoolExecutors for actual parallelism.)
541
542     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
543     :func:`format_file_in_place`.
544     """
545     cache: Cache = {}
546     if write_back != WriteBack.DIFF:
547         cache = read_cache(mode)
548         sources, cached = filter_cached(cache, sources)
549         for src in sorted(cached):
550             report.done(src, Changed.CACHED)
551     if not sources:
552         return
553
554     cancelled = []
555     sources_to_cache = []
556     lock = None
557     if write_back == WriteBack.DIFF:
558         # For diff output, we need locks to ensure we don't interleave output
559         # from different processes.
560         manager = Manager()
561         lock = manager.Lock()
562     tasks = {
563         asyncio.ensure_future(
564             loop.run_in_executor(
565                 executor, format_file_in_place, src, fast, mode, write_back, lock
566             )
567         ): src
568         for src in sorted(sources)
569     }
570     pending: Iterable[asyncio.Future] = tasks.keys()
571     try:
572         loop.add_signal_handler(signal.SIGINT, cancel, pending)
573         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
574     except NotImplementedError:
575         # There are no good alternatives for these on Windows.
576         pass
577     while pending:
578         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
579         for task in done:
580             src = tasks.pop(task)
581             if task.cancelled():
582                 cancelled.append(task)
583             elif task.exception():
584                 report.failed(src, str(task.exception()))
585             else:
586                 changed = Changed.YES if task.result() else Changed.NO
587                 # If the file was written back or was successfully checked as
588                 # well-formatted, store this information in the cache.
589                 if write_back is WriteBack.YES or (
590                     write_back is WriteBack.CHECK and changed is Changed.NO
591                 ):
592                     sources_to_cache.append(src)
593                 report.done(src, changed)
594     if cancelled:
595         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
596     if sources_to_cache:
597         write_cache(cache, sources_to_cache, mode)
598
599
600 def format_file_in_place(
601     src: Path,
602     fast: bool,
603     mode: FileMode,
604     write_back: WriteBack = WriteBack.NO,
605     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
606 ) -> bool:
607     """Format file under `src` path. Return True if changed.
608
609     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
610     code to the file.
611     `line_length` and `fast` options are passed to :func:`format_file_contents`.
612     """
613     if src.suffix == ".pyi":
614         mode = evolve(mode, is_pyi=True)
615
616     then = datetime.utcfromtimestamp(src.stat().st_mtime)
617     with open(src, "rb") as buf:
618         src_contents, encoding, newline = decode_bytes(buf.read())
619     try:
620         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
621     except NothingChanged:
622         return False
623
624     if write_back == write_back.YES:
625         with open(src, "w", encoding=encoding, newline=newline) as f:
626             f.write(dst_contents)
627     elif write_back == write_back.DIFF:
628         now = datetime.utcnow()
629         src_name = f"{src}\t{then} +0000"
630         dst_name = f"{src}\t{now} +0000"
631         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
632         if lock:
633             lock.acquire()
634         try:
635             f = io.TextIOWrapper(
636                 sys.stdout.buffer,
637                 encoding=encoding,
638                 newline=newline,
639                 write_through=True,
640             )
641             f.write(diff_contents)
642             f.detach()
643         finally:
644             if lock:
645                 lock.release()
646     return True
647
648
649 def format_stdin_to_stdout(
650     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
651 ) -> bool:
652     """Format file on stdin. Return True if changed.
653
654     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
655     write a diff to stdout. The `mode` argument is passed to
656     :func:`format_file_contents`.
657     """
658     then = datetime.utcnow()
659     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
660     dst = src
661     try:
662         dst = format_file_contents(src, fast=fast, mode=mode)
663         return True
664
665     except NothingChanged:
666         return False
667
668     finally:
669         f = io.TextIOWrapper(
670             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
671         )
672         if write_back == WriteBack.YES:
673             f.write(dst)
674         elif write_back == WriteBack.DIFF:
675             now = datetime.utcnow()
676             src_name = f"STDIN\t{then} +0000"
677             dst_name = f"STDOUT\t{now} +0000"
678             f.write(diff(src, dst, src_name, dst_name))
679         f.detach()
680
681
682 def format_file_contents(
683     src_contents: str, *, fast: bool, mode: FileMode
684 ) -> FileContent:
685     """Reformat contents a file and return new contents.
686
687     If `fast` is False, additionally confirm that the reformatted code is
688     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
689     `line_length` is passed to :func:`format_str`.
690     """
691     if src_contents.strip() == "":
692         raise NothingChanged
693
694     dst_contents = format_str(src_contents, mode=mode)
695     if src_contents == dst_contents:
696         raise NothingChanged
697
698     if not fast:
699         assert_equivalent(src_contents, dst_contents)
700         assert_stable(src_contents, dst_contents, mode=mode)
701     return dst_contents
702
703
704 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
705     """Reformat a string and return new contents.
706
707     `line_length` determines how many characters per line are allowed.
708     """
709     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
710     dst_contents = ""
711     future_imports = get_future_imports(src_node)
712     if mode.target_versions:
713         versions = mode.target_versions
714     else:
715         versions = detect_target_versions(src_node)
716     normalize_fmt_off(src_node)
717     lines = LineGenerator(
718         remove_u_prefix="unicode_literals" in future_imports
719         or supports_feature(versions, Feature.UNICODE_LITERALS),
720         is_pyi=mode.is_pyi,
721         normalize_strings=mode.string_normalization,
722     )
723     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
724     empty_line = Line()
725     after = 0
726     split_line_features = {
727         feature
728         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
729         if supports_feature(versions, feature)
730     }
731     for current_line in lines.visit(src_node):
732         for _ in range(after):
733             dst_contents += str(empty_line)
734         before, after = elt.maybe_empty_lines(current_line)
735         for _ in range(before):
736             dst_contents += str(empty_line)
737         for line in split_line(
738             current_line, line_length=mode.line_length, features=split_line_features
739         ):
740             dst_contents += str(line)
741     return dst_contents
742
743
744 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
745     """Return a tuple of (decoded_contents, encoding, newline).
746
747     `newline` is either CRLF or LF but `decoded_contents` is decoded with
748     universal newlines (i.e. only contains LF).
749     """
750     srcbuf = io.BytesIO(src)
751     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
752     if not lines:
753         return "", encoding, "\n"
754
755     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
756     srcbuf.seek(0)
757     with io.TextIOWrapper(srcbuf, encoding) as tiow:
758         return tiow.read(), encoding, newline
759
760
761 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
762     if not target_versions:
763         # No target_version specified, so try all grammars.
764         return [
765             # Python 3.7+
766             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
767             # Python 3.0-3.6
768             pygram.python_grammar_no_print_statement_no_exec_statement,
769             # Python 2.7 with future print_function import
770             pygram.python_grammar_no_print_statement,
771             # Python 2.7
772             pygram.python_grammar,
773         ]
774     elif all(version.is_python2() for version in target_versions):
775         # Python 2-only code, so try Python 2 grammars.
776         return [
777             # Python 2.7 with future print_function import
778             pygram.python_grammar_no_print_statement,
779             # Python 2.7
780             pygram.python_grammar,
781         ]
782     else:
783         # Python 3-compatible code, so only try Python 3 grammar.
784         grammars = []
785         # If we have to parse both, try to parse async as a keyword first
786         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
787             # Python 3.7+
788             grammars.append(
789                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
790             )
791         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
792             # Python 3.0-3.6
793             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
794         # At least one of the above branches must have been taken, because every Python
795         # version has exactly one of the two 'ASYNC_*' flags
796         return grammars
797
798
799 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
800     """Given a string with source, return the lib2to3 Node."""
801     if src_txt[-1:] != "\n":
802         src_txt += "\n"
803
804     for grammar in get_grammars(set(target_versions)):
805         drv = driver.Driver(grammar, pytree.convert)
806         try:
807             result = drv.parse_string(src_txt, True)
808             break
809
810         except ParseError as pe:
811             lineno, column = pe.context[1]
812             lines = src_txt.splitlines()
813             try:
814                 faulty_line = lines[lineno - 1]
815             except IndexError:
816                 faulty_line = "<line number missing in source>"
817             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
818     else:
819         raise exc from None
820
821     if isinstance(result, Leaf):
822         result = Node(syms.file_input, [result])
823     return result
824
825
826 def lib2to3_unparse(node: Node) -> str:
827     """Given a lib2to3 node, return its string representation."""
828     code = str(node)
829     return code
830
831
832 T = TypeVar("T")
833
834
835 class Visitor(Generic[T]):
836     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
837
838     def visit(self, node: LN) -> Iterator[T]:
839         """Main method to visit `node` and its children.
840
841         It tries to find a `visit_*()` method for the given `node.type`, like
842         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
843         If no dedicated `visit_*()` method is found, chooses `visit_default()`
844         instead.
845
846         Then yields objects of type `T` from the selected visitor.
847         """
848         if node.type < 256:
849             name = token.tok_name[node.type]
850         else:
851             name = type_repr(node.type)
852         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
853
854     def visit_default(self, node: LN) -> Iterator[T]:
855         """Default `visit_*()` implementation. Recurses to children of `node`."""
856         if isinstance(node, Node):
857             for child in node.children:
858                 yield from self.visit(child)
859
860
861 @dataclass
862 class DebugVisitor(Visitor[T]):
863     tree_depth: int = 0
864
865     def visit_default(self, node: LN) -> Iterator[T]:
866         indent = " " * (2 * self.tree_depth)
867         if isinstance(node, Node):
868             _type = type_repr(node.type)
869             out(f"{indent}{_type}", fg="yellow")
870             self.tree_depth += 1
871             for child in node.children:
872                 yield from self.visit(child)
873
874             self.tree_depth -= 1
875             out(f"{indent}/{_type}", fg="yellow", bold=False)
876         else:
877             _type = token.tok_name.get(node.type, str(node.type))
878             out(f"{indent}{_type}", fg="blue", nl=False)
879             if node.prefix:
880                 # We don't have to handle prefixes for `Node` objects since
881                 # that delegates to the first child anyway.
882                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
883             out(f" {node.value!r}", fg="blue", bold=False)
884
885     @classmethod
886     def show(cls, code: Union[str, Leaf, Node]) -> None:
887         """Pretty-print the lib2to3 AST of a given string of `code`.
888
889         Convenience method for debugging.
890         """
891         v: DebugVisitor[None] = DebugVisitor()
892         if isinstance(code, str):
893             code = lib2to3_parse(code)
894         list(v.visit(code))
895
896
897 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
898 STATEMENT = {
899     syms.if_stmt,
900     syms.while_stmt,
901     syms.for_stmt,
902     syms.try_stmt,
903     syms.except_clause,
904     syms.with_stmt,
905     syms.funcdef,
906     syms.classdef,
907 }
908 STANDALONE_COMMENT = 153
909 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
910 LOGIC_OPERATORS = {"and", "or"}
911 COMPARATORS = {
912     token.LESS,
913     token.GREATER,
914     token.EQEQUAL,
915     token.NOTEQUAL,
916     token.LESSEQUAL,
917     token.GREATEREQUAL,
918 }
919 MATH_OPERATORS = {
920     token.VBAR,
921     token.CIRCUMFLEX,
922     token.AMPER,
923     token.LEFTSHIFT,
924     token.RIGHTSHIFT,
925     token.PLUS,
926     token.MINUS,
927     token.STAR,
928     token.SLASH,
929     token.DOUBLESLASH,
930     token.PERCENT,
931     token.AT,
932     token.TILDE,
933     token.DOUBLESTAR,
934 }
935 STARS = {token.STAR, token.DOUBLESTAR}
936 VARARGS_PARENTS = {
937     syms.arglist,
938     syms.argument,  # double star in arglist
939     syms.trailer,  # single argument to call
940     syms.typedargslist,
941     syms.varargslist,  # lambdas
942 }
943 UNPACKING_PARENTS = {
944     syms.atom,  # single element of a list or set literal
945     syms.dictsetmaker,
946     syms.listmaker,
947     syms.testlist_gexp,
948     syms.testlist_star_expr,
949 }
950 TEST_DESCENDANTS = {
951     syms.test,
952     syms.lambdef,
953     syms.or_test,
954     syms.and_test,
955     syms.not_test,
956     syms.comparison,
957     syms.star_expr,
958     syms.expr,
959     syms.xor_expr,
960     syms.and_expr,
961     syms.shift_expr,
962     syms.arith_expr,
963     syms.trailer,
964     syms.term,
965     syms.power,
966 }
967 ASSIGNMENTS = {
968     "=",
969     "+=",
970     "-=",
971     "*=",
972     "@=",
973     "/=",
974     "%=",
975     "&=",
976     "|=",
977     "^=",
978     "<<=",
979     ">>=",
980     "**=",
981     "//=",
982 }
983 COMPREHENSION_PRIORITY = 20
984 COMMA_PRIORITY = 18
985 TERNARY_PRIORITY = 16
986 LOGIC_PRIORITY = 14
987 STRING_PRIORITY = 12
988 COMPARATOR_PRIORITY = 10
989 MATH_PRIORITIES = {
990     token.VBAR: 9,
991     token.CIRCUMFLEX: 8,
992     token.AMPER: 7,
993     token.LEFTSHIFT: 6,
994     token.RIGHTSHIFT: 6,
995     token.PLUS: 5,
996     token.MINUS: 5,
997     token.STAR: 4,
998     token.SLASH: 4,
999     token.DOUBLESLASH: 4,
1000     token.PERCENT: 4,
1001     token.AT: 4,
1002     token.TILDE: 3,
1003     token.DOUBLESTAR: 2,
1004 }
1005 DOT_PRIORITY = 1
1006
1007
1008 @dataclass
1009 class BracketTracker:
1010     """Keeps track of brackets on a line."""
1011
1012     depth: int = 0
1013     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1014     delimiters: Dict[LeafID, Priority] = Factory(dict)
1015     previous: Optional[Leaf] = None
1016     _for_loop_depths: List[int] = Factory(list)
1017     _lambda_argument_depths: List[int] = Factory(list)
1018
1019     def mark(self, leaf: Leaf) -> None:
1020         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1021
1022         All leaves receive an int `bracket_depth` field that stores how deep
1023         within brackets a given leaf is. 0 means there are no enclosing brackets
1024         that started on this line.
1025
1026         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1027         field that it forms a pair with. This is a one-directional link to
1028         avoid reference cycles.
1029
1030         If a leaf is a delimiter (a token on which Black can split the line if
1031         needed) and it's on depth 0, its `id()` is stored in the tracker's
1032         `delimiters` field.
1033         """
1034         if leaf.type == token.COMMENT:
1035             return
1036
1037         self.maybe_decrement_after_for_loop_variable(leaf)
1038         self.maybe_decrement_after_lambda_arguments(leaf)
1039         if leaf.type in CLOSING_BRACKETS:
1040             self.depth -= 1
1041             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1042             leaf.opening_bracket = opening_bracket
1043         leaf.bracket_depth = self.depth
1044         if self.depth == 0:
1045             delim = is_split_before_delimiter(leaf, self.previous)
1046             if delim and self.previous is not None:
1047                 self.delimiters[id(self.previous)] = delim
1048             else:
1049                 delim = is_split_after_delimiter(leaf, self.previous)
1050                 if delim:
1051                     self.delimiters[id(leaf)] = delim
1052         if leaf.type in OPENING_BRACKETS:
1053             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1054             self.depth += 1
1055         self.previous = leaf
1056         self.maybe_increment_lambda_arguments(leaf)
1057         self.maybe_increment_for_loop_variable(leaf)
1058
1059     def any_open_brackets(self) -> bool:
1060         """Return True if there is an yet unmatched open bracket on the line."""
1061         return bool(self.bracket_match)
1062
1063     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1064         """Return the highest priority of a delimiter found on the line.
1065
1066         Values are consistent with what `is_split_*_delimiter()` return.
1067         Raises ValueError on no delimiters.
1068         """
1069         return max(v for k, v in self.delimiters.items() if k not in exclude)
1070
1071     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1072         """Return the number of delimiters with the given `priority`.
1073
1074         If no `priority` is passed, defaults to max priority on the line.
1075         """
1076         if not self.delimiters:
1077             return 0
1078
1079         priority = priority or self.max_delimiter_priority()
1080         return sum(1 for p in self.delimiters.values() if p == priority)
1081
1082     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1083         """In a for loop, or comprehension, the variables are often unpacks.
1084
1085         To avoid splitting on the comma in this situation, increase the depth of
1086         tokens between `for` and `in`.
1087         """
1088         if leaf.type == token.NAME and leaf.value == "for":
1089             self.depth += 1
1090             self._for_loop_depths.append(self.depth)
1091             return True
1092
1093         return False
1094
1095     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1096         """See `maybe_increment_for_loop_variable` above for explanation."""
1097         if (
1098             self._for_loop_depths
1099             and self._for_loop_depths[-1] == self.depth
1100             and leaf.type == token.NAME
1101             and leaf.value == "in"
1102         ):
1103             self.depth -= 1
1104             self._for_loop_depths.pop()
1105             return True
1106
1107         return False
1108
1109     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1110         """In a lambda expression, there might be more than one argument.
1111
1112         To avoid splitting on the comma in this situation, increase the depth of
1113         tokens between `lambda` and `:`.
1114         """
1115         if leaf.type == token.NAME and leaf.value == "lambda":
1116             self.depth += 1
1117             self._lambda_argument_depths.append(self.depth)
1118             return True
1119
1120         return False
1121
1122     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1123         """See `maybe_increment_lambda_arguments` above for explanation."""
1124         if (
1125             self._lambda_argument_depths
1126             and self._lambda_argument_depths[-1] == self.depth
1127             and leaf.type == token.COLON
1128         ):
1129             self.depth -= 1
1130             self._lambda_argument_depths.pop()
1131             return True
1132
1133         return False
1134
1135     def get_open_lsqb(self) -> Optional[Leaf]:
1136         """Return the most recent opening square bracket (if any)."""
1137         return self.bracket_match.get((self.depth - 1, token.RSQB))
1138
1139
1140 @dataclass
1141 class Line:
1142     """Holds leaves and comments. Can be printed with `str(line)`."""
1143
1144     depth: int = 0
1145     leaves: List[Leaf] = Factory(list)
1146     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1147     bracket_tracker: BracketTracker = Factory(BracketTracker)
1148     inside_brackets: bool = False
1149     should_explode: bool = False
1150
1151     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1152         """Add a new `leaf` to the end of the line.
1153
1154         Unless `preformatted` is True, the `leaf` will receive a new consistent
1155         whitespace prefix and metadata applied by :class:`BracketTracker`.
1156         Trailing commas are maybe removed, unpacked for loop variables are
1157         demoted from being delimiters.
1158
1159         Inline comments are put aside.
1160         """
1161         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1162         if not has_value:
1163             return
1164
1165         if token.COLON == leaf.type and self.is_class_paren_empty:
1166             del self.leaves[-2:]
1167         if self.leaves and not preformatted:
1168             # Note: at this point leaf.prefix should be empty except for
1169             # imports, for which we only preserve newlines.
1170             leaf.prefix += whitespace(
1171                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1172             )
1173         if self.inside_brackets or not preformatted:
1174             self.bracket_tracker.mark(leaf)
1175             self.maybe_remove_trailing_comma(leaf)
1176         if not self.append_comment(leaf):
1177             self.leaves.append(leaf)
1178
1179     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1180         """Like :func:`append()` but disallow invalid standalone comment structure.
1181
1182         Raises ValueError when any `leaf` is appended after a standalone comment
1183         or when a standalone comment is not the first leaf on the line.
1184         """
1185         if self.bracket_tracker.depth == 0:
1186             if self.is_comment:
1187                 raise ValueError("cannot append to standalone comments")
1188
1189             if self.leaves and leaf.type == STANDALONE_COMMENT:
1190                 raise ValueError(
1191                     "cannot append standalone comments to a populated line"
1192                 )
1193
1194         self.append(leaf, preformatted=preformatted)
1195
1196     @property
1197     def is_comment(self) -> bool:
1198         """Is this line a standalone comment?"""
1199         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1200
1201     @property
1202     def is_decorator(self) -> bool:
1203         """Is this line a decorator?"""
1204         return bool(self) and self.leaves[0].type == token.AT
1205
1206     @property
1207     def is_import(self) -> bool:
1208         """Is this an import line?"""
1209         return bool(self) and is_import(self.leaves[0])
1210
1211     @property
1212     def is_class(self) -> bool:
1213         """Is this line a class definition?"""
1214         return (
1215             bool(self)
1216             and self.leaves[0].type == token.NAME
1217             and self.leaves[0].value == "class"
1218         )
1219
1220     @property
1221     def is_stub_class(self) -> bool:
1222         """Is this line a class definition with a body consisting only of "..."?"""
1223         return self.is_class and self.leaves[-3:] == [
1224             Leaf(token.DOT, ".") for _ in range(3)
1225         ]
1226
1227     @property
1228     def is_def(self) -> bool:
1229         """Is this a function definition? (Also returns True for async defs.)"""
1230         try:
1231             first_leaf = self.leaves[0]
1232         except IndexError:
1233             return False
1234
1235         try:
1236             second_leaf: Optional[Leaf] = self.leaves[1]
1237         except IndexError:
1238             second_leaf = None
1239         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1240             first_leaf.type == token.ASYNC
1241             and second_leaf is not None
1242             and second_leaf.type == token.NAME
1243             and second_leaf.value == "def"
1244         )
1245
1246     @property
1247     def is_class_paren_empty(self) -> bool:
1248         """Is this a class with no base classes but using parentheses?
1249
1250         Those are unnecessary and should be removed.
1251         """
1252         return (
1253             bool(self)
1254             and len(self.leaves) == 4
1255             and self.is_class
1256             and self.leaves[2].type == token.LPAR
1257             and self.leaves[2].value == "("
1258             and self.leaves[3].type == token.RPAR
1259             and self.leaves[3].value == ")"
1260         )
1261
1262     @property
1263     def is_triple_quoted_string(self) -> bool:
1264         """Is the line a triple quoted string?"""
1265         return (
1266             bool(self)
1267             and self.leaves[0].type == token.STRING
1268             and self.leaves[0].value.startswith(('"""', "'''"))
1269         )
1270
1271     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1272         """If so, needs to be split before emitting."""
1273         for leaf in self.leaves:
1274             if leaf.type == STANDALONE_COMMENT:
1275                 if leaf.bracket_depth <= depth_limit:
1276                     return True
1277         return False
1278
1279     def contains_inner_type_comments(self) -> bool:
1280         ignored_ids = set()
1281         try:
1282             last_leaf = self.leaves[-1]
1283             ignored_ids.add(id(last_leaf))
1284             if last_leaf.type == token.COMMA:
1285                 # When trailing commas are inserted by Black for consistency, comments
1286                 # after the previous last element are not moved (they don't have to,
1287                 # rendering will still be correct).  So we ignore trailing commas.
1288                 last_leaf = self.leaves[-2]
1289                 ignored_ids.add(id(last_leaf))
1290         except IndexError:
1291             return False
1292
1293         for leaf_id, comments in self.comments.items():
1294             if leaf_id in ignored_ids:
1295                 continue
1296
1297             for comment in comments:
1298                 if is_type_comment(comment):
1299                     return True
1300
1301         return False
1302
1303     def contains_multiline_strings(self) -> bool:
1304         for leaf in self.leaves:
1305             if is_multiline_string(leaf):
1306                 return True
1307
1308         return False
1309
1310     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1311         """Remove trailing comma if there is one and it's safe."""
1312         if not (
1313             self.leaves
1314             and self.leaves[-1].type == token.COMMA
1315             and closing.type in CLOSING_BRACKETS
1316         ):
1317             return False
1318
1319         if closing.type == token.RBRACE:
1320             self.remove_trailing_comma()
1321             return True
1322
1323         if closing.type == token.RSQB:
1324             comma = self.leaves[-1]
1325             if comma.parent and comma.parent.type == syms.listmaker:
1326                 self.remove_trailing_comma()
1327                 return True
1328
1329         # For parens let's check if it's safe to remove the comma.
1330         # Imports are always safe.
1331         if self.is_import:
1332             self.remove_trailing_comma()
1333             return True
1334
1335         # Otherwise, if the trailing one is the only one, we might mistakenly
1336         # change a tuple into a different type by removing the comma.
1337         depth = closing.bracket_depth + 1
1338         commas = 0
1339         opening = closing.opening_bracket
1340         for _opening_index, leaf in enumerate(self.leaves):
1341             if leaf is opening:
1342                 break
1343
1344         else:
1345             return False
1346
1347         for leaf in self.leaves[_opening_index + 1 :]:
1348             if leaf is closing:
1349                 break
1350
1351             bracket_depth = leaf.bracket_depth
1352             if bracket_depth == depth and leaf.type == token.COMMA:
1353                 commas += 1
1354                 if leaf.parent and leaf.parent.type == syms.arglist:
1355                     commas += 1
1356                     break
1357
1358         if commas > 1:
1359             self.remove_trailing_comma()
1360             return True
1361
1362         return False
1363
1364     def append_comment(self, comment: Leaf) -> bool:
1365         """Add an inline or standalone comment to the line."""
1366         if (
1367             comment.type == STANDALONE_COMMENT
1368             and self.bracket_tracker.any_open_brackets()
1369         ):
1370             comment.prefix = ""
1371             return False
1372
1373         if comment.type != token.COMMENT:
1374             return False
1375
1376         if not self.leaves:
1377             comment.type = STANDALONE_COMMENT
1378             comment.prefix = ""
1379             return False
1380
1381         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1382         return True
1383
1384     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1385         """Generate comments that should appear directly after `leaf`."""
1386         return self.comments.get(id(leaf), [])
1387
1388     def remove_trailing_comma(self) -> None:
1389         """Remove the trailing comma and moves the comments attached to it."""
1390         trailing_comma = self.leaves.pop()
1391         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1392         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1393             trailing_comma_comments
1394         )
1395
1396     def is_complex_subscript(self, leaf: Leaf) -> bool:
1397         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1398         open_lsqb = self.bracket_tracker.get_open_lsqb()
1399         if open_lsqb is None:
1400             return False
1401
1402         subscript_start = open_lsqb.next_sibling
1403
1404         if isinstance(subscript_start, Node):
1405             if subscript_start.type == syms.listmaker:
1406                 return False
1407
1408             if subscript_start.type == syms.subscriptlist:
1409                 subscript_start = child_towards(subscript_start, leaf)
1410         return subscript_start is not None and any(
1411             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1412         )
1413
1414     def __str__(self) -> str:
1415         """Render the line."""
1416         if not self:
1417             return "\n"
1418
1419         indent = "    " * self.depth
1420         leaves = iter(self.leaves)
1421         first = next(leaves)
1422         res = f"{first.prefix}{indent}{first.value}"
1423         for leaf in leaves:
1424             res += str(leaf)
1425         for comment in itertools.chain.from_iterable(self.comments.values()):
1426             res += str(comment)
1427         return res + "\n"
1428
1429     def __bool__(self) -> bool:
1430         """Return True if the line has leaves or comments."""
1431         return bool(self.leaves or self.comments)
1432
1433
1434 @dataclass
1435 class EmptyLineTracker:
1436     """Provides a stateful method that returns the number of potential extra
1437     empty lines needed before and after the currently processed line.
1438
1439     Note: this tracker works on lines that haven't been split yet.  It assumes
1440     the prefix of the first leaf consists of optional newlines.  Those newlines
1441     are consumed by `maybe_empty_lines()` and included in the computation.
1442     """
1443
1444     is_pyi: bool = False
1445     previous_line: Optional[Line] = None
1446     previous_after: int = 0
1447     previous_defs: List[int] = Factory(list)
1448
1449     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1450         """Return the number of extra empty lines before and after the `current_line`.
1451
1452         This is for separating `def`, `async def` and `class` with extra empty
1453         lines (two on module-level).
1454         """
1455         before, after = self._maybe_empty_lines(current_line)
1456         before -= self.previous_after
1457         self.previous_after = after
1458         self.previous_line = current_line
1459         return before, after
1460
1461     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1462         max_allowed = 1
1463         if current_line.depth == 0:
1464             max_allowed = 1 if self.is_pyi else 2
1465         if current_line.leaves:
1466             # Consume the first leaf's extra newlines.
1467             first_leaf = current_line.leaves[0]
1468             before = first_leaf.prefix.count("\n")
1469             before = min(before, max_allowed)
1470             first_leaf.prefix = ""
1471         else:
1472             before = 0
1473         depth = current_line.depth
1474         while self.previous_defs and self.previous_defs[-1] >= depth:
1475             self.previous_defs.pop()
1476             if self.is_pyi:
1477                 before = 0 if depth else 1
1478             else:
1479                 before = 1 if depth else 2
1480         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1481             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1482
1483         if (
1484             self.previous_line
1485             and self.previous_line.is_import
1486             and not current_line.is_import
1487             and depth == self.previous_line.depth
1488         ):
1489             return (before or 1), 0
1490
1491         if (
1492             self.previous_line
1493             and self.previous_line.is_class
1494             and current_line.is_triple_quoted_string
1495         ):
1496             return before, 1
1497
1498         return before, 0
1499
1500     def _maybe_empty_lines_for_class_or_def(
1501         self, current_line: Line, before: int
1502     ) -> Tuple[int, int]:
1503         if not current_line.is_decorator:
1504             self.previous_defs.append(current_line.depth)
1505         if self.previous_line is None:
1506             # Don't insert empty lines before the first line in the file.
1507             return 0, 0
1508
1509         if self.previous_line.is_decorator:
1510             return 0, 0
1511
1512         if self.previous_line.depth < current_line.depth and (
1513             self.previous_line.is_class or self.previous_line.is_def
1514         ):
1515             return 0, 0
1516
1517         if (
1518             self.previous_line.is_comment
1519             and self.previous_line.depth == current_line.depth
1520             and before == 0
1521         ):
1522             return 0, 0
1523
1524         if self.is_pyi:
1525             if self.previous_line.depth > current_line.depth:
1526                 newlines = 1
1527             elif current_line.is_class or self.previous_line.is_class:
1528                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1529                     # No blank line between classes with an empty body
1530                     newlines = 0
1531                 else:
1532                     newlines = 1
1533             elif current_line.is_def and not self.previous_line.is_def:
1534                 # Blank line between a block of functions and a block of non-functions
1535                 newlines = 1
1536             else:
1537                 newlines = 0
1538         else:
1539             newlines = 2
1540         if current_line.depth and newlines:
1541             newlines -= 1
1542         return newlines, 0
1543
1544
1545 @dataclass
1546 class LineGenerator(Visitor[Line]):
1547     """Generates reformatted Line objects.  Empty lines are not emitted.
1548
1549     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1550     in ways that will no longer stringify to valid Python code on the tree.
1551     """
1552
1553     is_pyi: bool = False
1554     normalize_strings: bool = True
1555     current_line: Line = Factory(Line)
1556     remove_u_prefix: bool = False
1557
1558     def line(self, indent: int = 0) -> Iterator[Line]:
1559         """Generate a line.
1560
1561         If the line is empty, only emit if it makes sense.
1562         If the line is too long, split it first and then generate.
1563
1564         If any lines were generated, set up a new current_line.
1565         """
1566         if not self.current_line:
1567             self.current_line.depth += indent
1568             return  # Line is empty, don't emit. Creating a new one unnecessary.
1569
1570         complete_line = self.current_line
1571         self.current_line = Line(depth=complete_line.depth + indent)
1572         yield complete_line
1573
1574     def visit_default(self, node: LN) -> Iterator[Line]:
1575         """Default `visit_*()` implementation. Recurses to children of `node`."""
1576         if isinstance(node, Leaf):
1577             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1578             for comment in generate_comments(node):
1579                 if any_open_brackets:
1580                     # any comment within brackets is subject to splitting
1581                     self.current_line.append(comment)
1582                 elif comment.type == token.COMMENT:
1583                     # regular trailing comment
1584                     self.current_line.append(comment)
1585                     yield from self.line()
1586
1587                 else:
1588                     # regular standalone comment
1589                     yield from self.line()
1590
1591                     self.current_line.append(comment)
1592                     yield from self.line()
1593
1594             normalize_prefix(node, inside_brackets=any_open_brackets)
1595             if self.normalize_strings and node.type == token.STRING:
1596                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1597                 normalize_string_quotes(node)
1598             if node.type == token.NUMBER:
1599                 normalize_numeric_literal(node)
1600             if node.type not in WHITESPACE:
1601                 self.current_line.append(node)
1602         yield from super().visit_default(node)
1603
1604     def visit_atom(self, node: Node) -> Iterator[Line]:
1605         # Always make parentheses invisible around a single node, because it should
1606         # not be needed (except in the case of yield, where removing the parentheses
1607         # produces a SyntaxError).
1608         if (
1609             len(node.children) == 3
1610             and isinstance(node.children[0], Leaf)
1611             and node.children[0].type == token.LPAR
1612             and isinstance(node.children[2], Leaf)
1613             and node.children[2].type == token.RPAR
1614             and isinstance(node.children[1], Leaf)
1615             and not (
1616                 node.children[1].type == token.NAME
1617                 and node.children[1].value == "yield"
1618             )
1619         ):
1620             node.children[0].value = ""
1621             node.children[2].value = ""
1622         yield from super().visit_default(node)
1623
1624     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1625         """Increase indentation level, maybe yield a line."""
1626         # In blib2to3 INDENT never holds comments.
1627         yield from self.line(+1)
1628         yield from self.visit_default(node)
1629
1630     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1631         """Decrease indentation level, maybe yield a line."""
1632         # The current line might still wait for trailing comments.  At DEDENT time
1633         # there won't be any (they would be prefixes on the preceding NEWLINE).
1634         # Emit the line then.
1635         yield from self.line()
1636
1637         # While DEDENT has no value, its prefix may contain standalone comments
1638         # that belong to the current indentation level.  Get 'em.
1639         yield from self.visit_default(node)
1640
1641         # Finally, emit the dedent.
1642         yield from self.line(-1)
1643
1644     def visit_stmt(
1645         self, node: Node, keywords: Set[str], parens: Set[str]
1646     ) -> Iterator[Line]:
1647         """Visit a statement.
1648
1649         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1650         `def`, `with`, `class`, `assert` and assignments.
1651
1652         The relevant Python language `keywords` for a given statement will be
1653         NAME leaves within it. This methods puts those on a separate line.
1654
1655         `parens` holds a set of string leaf values immediately after which
1656         invisible parens should be put.
1657         """
1658         normalize_invisible_parens(node, parens_after=parens)
1659         for child in node.children:
1660             if child.type == token.NAME and child.value in keywords:  # type: ignore
1661                 yield from self.line()
1662
1663             yield from self.visit(child)
1664
1665     def visit_suite(self, node: Node) -> Iterator[Line]:
1666         """Visit a suite."""
1667         if self.is_pyi and is_stub_suite(node):
1668             yield from self.visit(node.children[2])
1669         else:
1670             yield from self.visit_default(node)
1671
1672     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1673         """Visit a statement without nested statements."""
1674         is_suite_like = node.parent and node.parent.type in STATEMENT
1675         if is_suite_like:
1676             if self.is_pyi and is_stub_body(node):
1677                 yield from self.visit_default(node)
1678             else:
1679                 yield from self.line(+1)
1680                 yield from self.visit_default(node)
1681                 yield from self.line(-1)
1682
1683         else:
1684             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1685                 yield from self.line()
1686             yield from self.visit_default(node)
1687
1688     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1689         """Visit `async def`, `async for`, `async with`."""
1690         yield from self.line()
1691
1692         children = iter(node.children)
1693         for child in children:
1694             yield from self.visit(child)
1695
1696             if child.type == token.ASYNC:
1697                 break
1698
1699         internal_stmt = next(children)
1700         for child in internal_stmt.children:
1701             yield from self.visit(child)
1702
1703     def visit_decorators(self, node: Node) -> Iterator[Line]:
1704         """Visit decorators."""
1705         for child in node.children:
1706             yield from self.line()
1707             yield from self.visit(child)
1708
1709     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1710         """Remove a semicolon and put the other statement on a separate line."""
1711         yield from self.line()
1712
1713     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1714         """End of file. Process outstanding comments and end with a newline."""
1715         yield from self.visit_default(leaf)
1716         yield from self.line()
1717
1718     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1719         if not self.current_line.bracket_tracker.any_open_brackets():
1720             yield from self.line()
1721         yield from self.visit_default(leaf)
1722
1723     def __attrs_post_init__(self) -> None:
1724         """You are in a twisty little maze of passages."""
1725         v = self.visit_stmt
1726         Ø: Set[str] = set()
1727         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1728         self.visit_if_stmt = partial(
1729             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1730         )
1731         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1732         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1733         self.visit_try_stmt = partial(
1734             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1735         )
1736         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1737         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1738         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1739         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1740         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1741         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1742         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1743         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1744         self.visit_async_funcdef = self.visit_async_stmt
1745         self.visit_decorated = self.visit_decorators
1746
1747
1748 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1749 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1750 OPENING_BRACKETS = set(BRACKET.keys())
1751 CLOSING_BRACKETS = set(BRACKET.values())
1752 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1753 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1754
1755
1756 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1757     """Return whitespace prefix if needed for the given `leaf`.
1758
1759     `complex_subscript` signals whether the given leaf is part of a subscription
1760     which has non-trivial arguments, like arithmetic expressions or function calls.
1761     """
1762     NO = ""
1763     SPACE = " "
1764     DOUBLESPACE = "  "
1765     t = leaf.type
1766     p = leaf.parent
1767     v = leaf.value
1768     if t in ALWAYS_NO_SPACE:
1769         return NO
1770
1771     if t == token.COMMENT:
1772         return DOUBLESPACE
1773
1774     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1775     if t == token.COLON and p.type not in {
1776         syms.subscript,
1777         syms.subscriptlist,
1778         syms.sliceop,
1779     }:
1780         return NO
1781
1782     prev = leaf.prev_sibling
1783     if not prev:
1784         prevp = preceding_leaf(p)
1785         if not prevp or prevp.type in OPENING_BRACKETS:
1786             return NO
1787
1788         if t == token.COLON:
1789             if prevp.type == token.COLON:
1790                 return NO
1791
1792             elif prevp.type != token.COMMA and not complex_subscript:
1793                 return NO
1794
1795             return SPACE
1796
1797         if prevp.type == token.EQUAL:
1798             if prevp.parent:
1799                 if prevp.parent.type in {
1800                     syms.arglist,
1801                     syms.argument,
1802                     syms.parameters,
1803                     syms.varargslist,
1804                 }:
1805                     return NO
1806
1807                 elif prevp.parent.type == syms.typedargslist:
1808                     # A bit hacky: if the equal sign has whitespace, it means we
1809                     # previously found it's a typed argument.  So, we're using
1810                     # that, too.
1811                     return prevp.prefix
1812
1813         elif prevp.type in STARS:
1814             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1815                 return NO
1816
1817         elif prevp.type == token.COLON:
1818             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1819                 return SPACE if complex_subscript else NO
1820
1821         elif (
1822             prevp.parent
1823             and prevp.parent.type == syms.factor
1824             and prevp.type in MATH_OPERATORS
1825         ):
1826             return NO
1827
1828         elif (
1829             prevp.type == token.RIGHTSHIFT
1830             and prevp.parent
1831             and prevp.parent.type == syms.shift_expr
1832             and prevp.prev_sibling
1833             and prevp.prev_sibling.type == token.NAME
1834             and prevp.prev_sibling.value == "print"  # type: ignore
1835         ):
1836             # Python 2 print chevron
1837             return NO
1838
1839     elif prev.type in OPENING_BRACKETS:
1840         return NO
1841
1842     if p.type in {syms.parameters, syms.arglist}:
1843         # untyped function signatures or calls
1844         if not prev or prev.type != token.COMMA:
1845             return NO
1846
1847     elif p.type == syms.varargslist:
1848         # lambdas
1849         if prev and prev.type != token.COMMA:
1850             return NO
1851
1852     elif p.type == syms.typedargslist:
1853         # typed function signatures
1854         if not prev:
1855             return NO
1856
1857         if t == token.EQUAL:
1858             if prev.type != syms.tname:
1859                 return NO
1860
1861         elif prev.type == token.EQUAL:
1862             # A bit hacky: if the equal sign has whitespace, it means we
1863             # previously found it's a typed argument.  So, we're using that, too.
1864             return prev.prefix
1865
1866         elif prev.type != token.COMMA:
1867             return NO
1868
1869     elif p.type == syms.tname:
1870         # type names
1871         if not prev:
1872             prevp = preceding_leaf(p)
1873             if not prevp or prevp.type != token.COMMA:
1874                 return NO
1875
1876     elif p.type == syms.trailer:
1877         # attributes and calls
1878         if t == token.LPAR or t == token.RPAR:
1879             return NO
1880
1881         if not prev:
1882             if t == token.DOT:
1883                 prevp = preceding_leaf(p)
1884                 if not prevp or prevp.type != token.NUMBER:
1885                     return NO
1886
1887             elif t == token.LSQB:
1888                 return NO
1889
1890         elif prev.type != token.COMMA:
1891             return NO
1892
1893     elif p.type == syms.argument:
1894         # single argument
1895         if t == token.EQUAL:
1896             return NO
1897
1898         if not prev:
1899             prevp = preceding_leaf(p)
1900             if not prevp or prevp.type == token.LPAR:
1901                 return NO
1902
1903         elif prev.type in {token.EQUAL} | STARS:
1904             return NO
1905
1906     elif p.type == syms.decorator:
1907         # decorators
1908         return NO
1909
1910     elif p.type == syms.dotted_name:
1911         if prev:
1912             return NO
1913
1914         prevp = preceding_leaf(p)
1915         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1916             return NO
1917
1918     elif p.type == syms.classdef:
1919         if t == token.LPAR:
1920             return NO
1921
1922         if prev and prev.type == token.LPAR:
1923             return NO
1924
1925     elif p.type in {syms.subscript, syms.sliceop}:
1926         # indexing
1927         if not prev:
1928             assert p.parent is not None, "subscripts are always parented"
1929             if p.parent.type == syms.subscriptlist:
1930                 return SPACE
1931
1932             return NO
1933
1934         elif not complex_subscript:
1935             return NO
1936
1937     elif p.type == syms.atom:
1938         if prev and t == token.DOT:
1939             # dots, but not the first one.
1940             return NO
1941
1942     elif p.type == syms.dictsetmaker:
1943         # dict unpacking
1944         if prev and prev.type == token.DOUBLESTAR:
1945             return NO
1946
1947     elif p.type in {syms.factor, syms.star_expr}:
1948         # unary ops
1949         if not prev:
1950             prevp = preceding_leaf(p)
1951             if not prevp or prevp.type in OPENING_BRACKETS:
1952                 return NO
1953
1954             prevp_parent = prevp.parent
1955             assert prevp_parent is not None
1956             if prevp.type == token.COLON and prevp_parent.type in {
1957                 syms.subscript,
1958                 syms.sliceop,
1959             }:
1960                 return NO
1961
1962             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1963                 return NO
1964
1965         elif t in {token.NAME, token.NUMBER, token.STRING}:
1966             return NO
1967
1968     elif p.type == syms.import_from:
1969         if t == token.DOT:
1970             if prev and prev.type == token.DOT:
1971                 return NO
1972
1973         elif t == token.NAME:
1974             if v == "import":
1975                 return SPACE
1976
1977             if prev and prev.type == token.DOT:
1978                 return NO
1979
1980     elif p.type == syms.sliceop:
1981         return NO
1982
1983     return SPACE
1984
1985
1986 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1987     """Return the first leaf that precedes `node`, if any."""
1988     while node:
1989         res = node.prev_sibling
1990         if res:
1991             if isinstance(res, Leaf):
1992                 return res
1993
1994             try:
1995                 return list(res.leaves())[-1]
1996
1997             except IndexError:
1998                 return None
1999
2000         node = node.parent
2001     return None
2002
2003
2004 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2005     """Return the child of `ancestor` that contains `descendant`."""
2006     node: Optional[LN] = descendant
2007     while node and node.parent != ancestor:
2008         node = node.parent
2009     return node
2010
2011
2012 def container_of(leaf: Leaf) -> LN:
2013     """Return `leaf` or one of its ancestors that is the topmost container of it.
2014
2015     By "container" we mean a node where `leaf` is the very first child.
2016     """
2017     same_prefix = leaf.prefix
2018     container: LN = leaf
2019     while container:
2020         parent = container.parent
2021         if parent is None:
2022             break
2023
2024         if parent.children[0].prefix != same_prefix:
2025             break
2026
2027         if parent.type == syms.file_input:
2028             break
2029
2030         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2031             break
2032
2033         container = parent
2034     return container
2035
2036
2037 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
2038     """Return the priority of the `leaf` delimiter, given a line break after it.
2039
2040     The delimiter priorities returned here are from those delimiters that would
2041     cause a line break after themselves.
2042
2043     Higher numbers are higher priority.
2044     """
2045     if leaf.type == token.COMMA:
2046         return COMMA_PRIORITY
2047
2048     return 0
2049
2050
2051 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
2052     """Return the priority of the `leaf` delimiter, given a line break before it.
2053
2054     The delimiter priorities returned here are from those delimiters that would
2055     cause a line break before themselves.
2056
2057     Higher numbers are higher priority.
2058     """
2059     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2060         # * and ** might also be MATH_OPERATORS but in this case they are not.
2061         # Don't treat them as a delimiter.
2062         return 0
2063
2064     if (
2065         leaf.type == token.DOT
2066         and leaf.parent
2067         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2068         and (previous is None or previous.type in CLOSING_BRACKETS)
2069     ):
2070         return DOT_PRIORITY
2071
2072     if (
2073         leaf.type in MATH_OPERATORS
2074         and leaf.parent
2075         and leaf.parent.type not in {syms.factor, syms.star_expr}
2076     ):
2077         return MATH_PRIORITIES[leaf.type]
2078
2079     if leaf.type in COMPARATORS:
2080         return COMPARATOR_PRIORITY
2081
2082     if (
2083         leaf.type == token.STRING
2084         and previous is not None
2085         and previous.type == token.STRING
2086     ):
2087         return STRING_PRIORITY
2088
2089     if leaf.type not in {token.NAME, token.ASYNC}:
2090         return 0
2091
2092     if (
2093         leaf.value == "for"
2094         and leaf.parent
2095         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2096         or leaf.type == token.ASYNC
2097     ):
2098         if (
2099             not isinstance(leaf.prev_sibling, Leaf)
2100             or leaf.prev_sibling.value != "async"
2101         ):
2102             return COMPREHENSION_PRIORITY
2103
2104     if (
2105         leaf.value == "if"
2106         and leaf.parent
2107         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2108     ):
2109         return COMPREHENSION_PRIORITY
2110
2111     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2112         return TERNARY_PRIORITY
2113
2114     if leaf.value == "is":
2115         return COMPARATOR_PRIORITY
2116
2117     if (
2118         leaf.value == "in"
2119         and leaf.parent
2120         and leaf.parent.type in {syms.comp_op, syms.comparison}
2121         and not (
2122             previous is not None
2123             and previous.type == token.NAME
2124             and previous.value == "not"
2125         )
2126     ):
2127         return COMPARATOR_PRIORITY
2128
2129     if (
2130         leaf.value == "not"
2131         and leaf.parent
2132         and leaf.parent.type == syms.comp_op
2133         and not (
2134             previous is not None
2135             and previous.type == token.NAME
2136             and previous.value == "is"
2137         )
2138     ):
2139         return COMPARATOR_PRIORITY
2140
2141     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2142         return LOGIC_PRIORITY
2143
2144     return 0
2145
2146
2147 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2148 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2149
2150
2151 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2152     """Clean the prefix of the `leaf` and generate comments from it, if any.
2153
2154     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2155     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2156     move because it does away with modifying the grammar to include all the
2157     possible places in which comments can be placed.
2158
2159     The sad consequence for us though is that comments don't "belong" anywhere.
2160     This is why this function generates simple parentless Leaf objects for
2161     comments.  We simply don't know what the correct parent should be.
2162
2163     No matter though, we can live without this.  We really only need to
2164     differentiate between inline and standalone comments.  The latter don't
2165     share the line with any code.
2166
2167     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2168     are emitted with a fake STANDALONE_COMMENT token identifier.
2169     """
2170     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2171         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2172
2173
2174 @dataclass
2175 class ProtoComment:
2176     """Describes a piece of syntax that is a comment.
2177
2178     It's not a :class:`blib2to3.pytree.Leaf` so that:
2179
2180     * it can be cached (`Leaf` objects should not be reused more than once as
2181       they store their lineno, column, prefix, and parent information);
2182     * `newlines` and `consumed` fields are kept separate from the `value`. This
2183       simplifies handling of special marker comments like ``# fmt: off/on``.
2184     """
2185
2186     type: int  # token.COMMENT or STANDALONE_COMMENT
2187     value: str  # content of the comment
2188     newlines: int  # how many newlines before the comment
2189     consumed: int  # how many characters of the original leaf's prefix did we consume
2190
2191
2192 @lru_cache(maxsize=4096)
2193 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2194     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2195     result: List[ProtoComment] = []
2196     if not prefix or "#" not in prefix:
2197         return result
2198
2199     consumed = 0
2200     nlines = 0
2201     ignored_lines = 0
2202     for index, line in enumerate(prefix.split("\n")):
2203         consumed += len(line) + 1  # adding the length of the split '\n'
2204         line = line.lstrip()
2205         if not line:
2206             nlines += 1
2207         if not line.startswith("#"):
2208             # Escaped newlines outside of a comment are not really newlines at
2209             # all. We treat a single-line comment following an escaped newline
2210             # as a simple trailing comment.
2211             if line.endswith("\\"):
2212                 ignored_lines += 1
2213             continue
2214
2215         if index == ignored_lines and not is_endmarker:
2216             comment_type = token.COMMENT  # simple trailing comment
2217         else:
2218             comment_type = STANDALONE_COMMENT
2219         comment = make_comment(line)
2220         result.append(
2221             ProtoComment(
2222                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2223             )
2224         )
2225         nlines = 0
2226     return result
2227
2228
2229 def make_comment(content: str) -> str:
2230     """Return a consistently formatted comment from the given `content` string.
2231
2232     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2233     space between the hash sign and the content.
2234
2235     If `content` didn't start with a hash sign, one is provided.
2236     """
2237     content = content.rstrip()
2238     if not content:
2239         return "#"
2240
2241     if content[0] == "#":
2242         content = content[1:]
2243     if content and content[0] not in " !:#'%":
2244         content = " " + content
2245     return "#" + content
2246
2247
2248 def split_line(
2249     line: Line,
2250     line_length: int,
2251     inner: bool = False,
2252     features: Collection[Feature] = (),
2253 ) -> Iterator[Line]:
2254     """Split a `line` into potentially many lines.
2255
2256     They should fit in the allotted `line_length` but might not be able to.
2257     `inner` signifies that there were a pair of brackets somewhere around the
2258     current `line`, possibly transitively. This means we can fallback to splitting
2259     by delimiters if the LHS/RHS don't yield any results.
2260
2261     `features` are syntactical features that may be used in the output.
2262     """
2263     if line.is_comment:
2264         yield line
2265         return
2266
2267     line_str = str(line).strip("\n")
2268
2269     if (
2270         not line.contains_inner_type_comments()
2271         and not line.should_explode
2272         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2273     ):
2274         yield line
2275         return
2276
2277     split_funcs: List[SplitFunc]
2278     if line.is_def:
2279         split_funcs = [left_hand_split]
2280     else:
2281
2282         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2283             for omit in generate_trailers_to_omit(line, line_length):
2284                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2285                 if is_line_short_enough(lines[0], line_length=line_length):
2286                     yield from lines
2287                     return
2288
2289             # All splits failed, best effort split with no omits.
2290             # This mostly happens to multiline strings that are by definition
2291             # reported as not fitting a single line.
2292             yield from right_hand_split(line, line_length, features=features)
2293
2294         if line.inside_brackets:
2295             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2296         else:
2297             split_funcs = [rhs]
2298     for split_func in split_funcs:
2299         # We are accumulating lines in `result` because we might want to abort
2300         # mission and return the original line in the end, or attempt a different
2301         # split altogether.
2302         result: List[Line] = []
2303         try:
2304             for l in split_func(line, features):
2305                 if str(l).strip("\n") == line_str:
2306                     raise CannotSplit("Split function returned an unchanged result")
2307
2308                 result.extend(
2309                     split_line(
2310                         l, line_length=line_length, inner=True, features=features
2311                     )
2312                 )
2313         except CannotSplit:
2314             continue
2315
2316         else:
2317             yield from result
2318             break
2319
2320     else:
2321         yield line
2322
2323
2324 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2325     """Split line into many lines, starting with the first matching bracket pair.
2326
2327     Note: this usually looks weird, only use this for function definitions.
2328     Prefer RHS otherwise.  This is why this function is not symmetrical with
2329     :func:`right_hand_split` which also handles optional parentheses.
2330     """
2331     tail_leaves: List[Leaf] = []
2332     body_leaves: List[Leaf] = []
2333     head_leaves: List[Leaf] = []
2334     current_leaves = head_leaves
2335     matching_bracket = None
2336     for leaf in line.leaves:
2337         if (
2338             current_leaves is body_leaves
2339             and leaf.type in CLOSING_BRACKETS
2340             and leaf.opening_bracket is matching_bracket
2341         ):
2342             current_leaves = tail_leaves if body_leaves else head_leaves
2343         current_leaves.append(leaf)
2344         if current_leaves is head_leaves:
2345             if leaf.type in OPENING_BRACKETS:
2346                 matching_bracket = leaf
2347                 current_leaves = body_leaves
2348     if not matching_bracket:
2349         raise CannotSplit("No brackets found")
2350
2351     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2352     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2353     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2354     bracket_split_succeeded_or_raise(head, body, tail)
2355     for result in (head, body, tail):
2356         if result:
2357             yield result
2358
2359
2360 def right_hand_split(
2361     line: Line,
2362     line_length: int,
2363     features: Collection[Feature] = (),
2364     omit: Collection[LeafID] = (),
2365 ) -> Iterator[Line]:
2366     """Split line into many lines, starting with the last matching bracket pair.
2367
2368     If the split was by optional parentheses, attempt splitting without them, too.
2369     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2370     this split.
2371
2372     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2373     """
2374     tail_leaves: List[Leaf] = []
2375     body_leaves: List[Leaf] = []
2376     head_leaves: List[Leaf] = []
2377     current_leaves = tail_leaves
2378     opening_bracket = None
2379     closing_bracket = None
2380     for leaf in reversed(line.leaves):
2381         if current_leaves is body_leaves:
2382             if leaf is opening_bracket:
2383                 current_leaves = head_leaves if body_leaves else tail_leaves
2384         current_leaves.append(leaf)
2385         if current_leaves is tail_leaves:
2386             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2387                 opening_bracket = leaf.opening_bracket
2388                 closing_bracket = leaf
2389                 current_leaves = body_leaves
2390     if not (opening_bracket and closing_bracket and head_leaves):
2391         # If there is no opening or closing_bracket that means the split failed and
2392         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2393         # the matching `opening_bracket` wasn't available on `line` anymore.
2394         raise CannotSplit("No brackets found")
2395
2396     tail_leaves.reverse()
2397     body_leaves.reverse()
2398     head_leaves.reverse()
2399     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2400     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2401     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2402     bracket_split_succeeded_or_raise(head, body, tail)
2403     if (
2404         # the body shouldn't be exploded
2405         not body.should_explode
2406         # the opening bracket is an optional paren
2407         and opening_bracket.type == token.LPAR
2408         and not opening_bracket.value
2409         # the closing bracket is an optional paren
2410         and closing_bracket.type == token.RPAR
2411         and not closing_bracket.value
2412         # it's not an import (optional parens are the only thing we can split on
2413         # in this case; attempting a split without them is a waste of time)
2414         and not line.is_import
2415         # there are no standalone comments in the body
2416         and not body.contains_standalone_comments(0)
2417         # and we can actually remove the parens
2418         and can_omit_invisible_parens(body, line_length)
2419     ):
2420         omit = {id(closing_bracket), *omit}
2421         try:
2422             yield from right_hand_split(line, line_length, features=features, omit=omit)
2423             return
2424
2425         except CannotSplit:
2426             if not (
2427                 can_be_split(body)
2428                 or is_line_short_enough(body, line_length=line_length)
2429             ):
2430                 raise CannotSplit(
2431                     "Splitting failed, body is still too long and can't be split."
2432                 )
2433
2434             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2435                 raise CannotSplit(
2436                     "The current optional pair of parentheses is bound to fail to "
2437                     "satisfy the splitting algorithm because the head or the tail "
2438                     "contains multiline strings which by definition never fit one "
2439                     "line."
2440                 )
2441
2442     ensure_visible(opening_bracket)
2443     ensure_visible(closing_bracket)
2444     for result in (head, body, tail):
2445         if result:
2446             yield result
2447
2448
2449 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2450     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2451
2452     Do nothing otherwise.
2453
2454     A left- or right-hand split is based on a pair of brackets. Content before
2455     (and including) the opening bracket is left on one line, content inside the
2456     brackets is put on a separate line, and finally content starting with and
2457     following the closing bracket is put on a separate line.
2458
2459     Those are called `head`, `body`, and `tail`, respectively. If the split
2460     produced the same line (all content in `head`) or ended up with an empty `body`
2461     and the `tail` is just the closing bracket, then it's considered failed.
2462     """
2463     tail_len = len(str(tail).strip())
2464     if not body:
2465         if tail_len == 0:
2466             raise CannotSplit("Splitting brackets produced the same line")
2467
2468         elif tail_len < 3:
2469             raise CannotSplit(
2470                 f"Splitting brackets on an empty body to save "
2471                 f"{tail_len} characters is not worth it"
2472             )
2473
2474
2475 def bracket_split_build_line(
2476     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2477 ) -> Line:
2478     """Return a new line with given `leaves` and respective comments from `original`.
2479
2480     If `is_body` is True, the result line is one-indented inside brackets and as such
2481     has its first leaf's prefix normalized and a trailing comma added when expected.
2482     """
2483     result = Line(depth=original.depth)
2484     if is_body:
2485         result.inside_brackets = True
2486         result.depth += 1
2487         if leaves:
2488             # Since body is a new indent level, remove spurious leading whitespace.
2489             normalize_prefix(leaves[0], inside_brackets=True)
2490             # Ensure a trailing comma for imports, but be careful not to add one after
2491             # any comments.
2492             if original.is_import:
2493                 for i in range(len(leaves) - 1, -1, -1):
2494                     if leaves[i].type == STANDALONE_COMMENT:
2495                         continue
2496                     elif leaves[i].type == token.COMMA:
2497                         break
2498                     else:
2499                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2500                         break
2501     # Populate the line
2502     for leaf in leaves:
2503         result.append(leaf, preformatted=True)
2504         for comment_after in original.comments_after(leaf):
2505             result.append(comment_after, preformatted=True)
2506     if is_body:
2507         result.should_explode = should_explode(result, opening_bracket)
2508     return result
2509
2510
2511 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2512     """Normalize prefix of the first leaf in every line returned by `split_func`.
2513
2514     This is a decorator over relevant split functions.
2515     """
2516
2517     @wraps(split_func)
2518     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2519         for l in split_func(line, features):
2520             normalize_prefix(l.leaves[0], inside_brackets=True)
2521             yield l
2522
2523     return split_wrapper
2524
2525
2526 @dont_increase_indentation
2527 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2528     """Split according to delimiters of the highest priority.
2529
2530     If the appropriate Features are given, the split will add trailing commas
2531     also in function signatures and calls that contain `*` and `**`.
2532     """
2533     try:
2534         last_leaf = line.leaves[-1]
2535     except IndexError:
2536         raise CannotSplit("Line empty")
2537
2538     bt = line.bracket_tracker
2539     try:
2540         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2541     except ValueError:
2542         raise CannotSplit("No delimiters found")
2543
2544     if delimiter_priority == DOT_PRIORITY:
2545         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2546             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2547
2548     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2549     lowest_depth = sys.maxsize
2550     trailing_comma_safe = True
2551
2552     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2553         """Append `leaf` to current line or to new line if appending impossible."""
2554         nonlocal current_line
2555         try:
2556             current_line.append_safe(leaf, preformatted=True)
2557         except ValueError:
2558             yield current_line
2559
2560             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2561             current_line.append(leaf)
2562
2563     for leaf in line.leaves:
2564         yield from append_to_line(leaf)
2565
2566         for comment_after in line.comments_after(leaf):
2567             yield from append_to_line(comment_after)
2568
2569         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2570         if leaf.bracket_depth == lowest_depth:
2571             if is_vararg(leaf, within={syms.typedargslist}):
2572                 trailing_comma_safe = (
2573                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2574                 )
2575             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2576                 trailing_comma_safe = (
2577                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2578                 )
2579
2580         leaf_priority = bt.delimiters.get(id(leaf))
2581         if leaf_priority == delimiter_priority:
2582             yield current_line
2583
2584             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2585     if current_line:
2586         if (
2587             trailing_comma_safe
2588             and delimiter_priority == COMMA_PRIORITY
2589             and current_line.leaves[-1].type != token.COMMA
2590             and current_line.leaves[-1].type != STANDALONE_COMMENT
2591         ):
2592             current_line.append(Leaf(token.COMMA, ","))
2593         yield current_line
2594
2595
2596 @dont_increase_indentation
2597 def standalone_comment_split(
2598     line: Line, features: Collection[Feature] = ()
2599 ) -> Iterator[Line]:
2600     """Split standalone comments from the rest of the line."""
2601     if not line.contains_standalone_comments(0):
2602         raise CannotSplit("Line does not have any standalone comments")
2603
2604     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2605
2606     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2607         """Append `leaf` to current line or to new line if appending impossible."""
2608         nonlocal current_line
2609         try:
2610             current_line.append_safe(leaf, preformatted=True)
2611         except ValueError:
2612             yield current_line
2613
2614             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2615             current_line.append(leaf)
2616
2617     for leaf in line.leaves:
2618         yield from append_to_line(leaf)
2619
2620         for comment_after in line.comments_after(leaf):
2621             yield from append_to_line(comment_after)
2622
2623     if current_line:
2624         yield current_line
2625
2626
2627 def is_import(leaf: Leaf) -> bool:
2628     """Return True if the given leaf starts an import statement."""
2629     p = leaf.parent
2630     t = leaf.type
2631     v = leaf.value
2632     return bool(
2633         t == token.NAME
2634         and (
2635             (v == "import" and p and p.type == syms.import_name)
2636             or (v == "from" and p and p.type == syms.import_from)
2637         )
2638     )
2639
2640
2641 def is_type_comment(leaf: Leaf) -> bool:
2642     """Return True if the given leaf is a special comment.
2643     Only returns true for type comments for now."""
2644     t = leaf.type
2645     v = leaf.value
2646     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2647
2648
2649 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2650     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2651     else.
2652
2653     Note: don't use backslashes for formatting or you'll lose your voting rights.
2654     """
2655     if not inside_brackets:
2656         spl = leaf.prefix.split("#")
2657         if "\\" not in spl[0]:
2658             nl_count = spl[-1].count("\n")
2659             if len(spl) > 1:
2660                 nl_count -= 1
2661             leaf.prefix = "\n" * nl_count
2662             return
2663
2664     leaf.prefix = ""
2665
2666
2667 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2668     """Make all string prefixes lowercase.
2669
2670     If remove_u_prefix is given, also removes any u prefix from the string.
2671
2672     Note: Mutates its argument.
2673     """
2674     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2675     assert match is not None, f"failed to match string {leaf.value!r}"
2676     orig_prefix = match.group(1)
2677     new_prefix = orig_prefix.lower()
2678     if remove_u_prefix:
2679         new_prefix = new_prefix.replace("u", "")
2680     leaf.value = f"{new_prefix}{match.group(2)}"
2681
2682
2683 def normalize_string_quotes(leaf: Leaf) -> None:
2684     """Prefer double quotes but only if it doesn't cause more escaping.
2685
2686     Adds or removes backslashes as appropriate. Doesn't parse and fix
2687     strings nested in f-strings (yet).
2688
2689     Note: Mutates its argument.
2690     """
2691     value = leaf.value.lstrip("furbFURB")
2692     if value[:3] == '"""':
2693         return
2694
2695     elif value[:3] == "'''":
2696         orig_quote = "'''"
2697         new_quote = '"""'
2698     elif value[0] == '"':
2699         orig_quote = '"'
2700         new_quote = "'"
2701     else:
2702         orig_quote = "'"
2703         new_quote = '"'
2704     first_quote_pos = leaf.value.find(orig_quote)
2705     if first_quote_pos == -1:
2706         return  # There's an internal error
2707
2708     prefix = leaf.value[:first_quote_pos]
2709     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2710     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2711     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2712     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2713     if "r" in prefix.casefold():
2714         if unescaped_new_quote.search(body):
2715             # There's at least one unescaped new_quote in this raw string
2716             # so converting is impossible
2717             return
2718
2719         # Do not introduce or remove backslashes in raw strings
2720         new_body = body
2721     else:
2722         # remove unnecessary escapes
2723         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2724         if body != new_body:
2725             # Consider the string without unnecessary escapes as the original
2726             body = new_body
2727             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2728         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2729         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2730     if "f" in prefix.casefold():
2731         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2732         for m in matches:
2733             if "\\" in str(m):
2734                 # Do not introduce backslashes in interpolated expressions
2735                 return
2736     if new_quote == '"""' and new_body[-1:] == '"':
2737         # edge case:
2738         new_body = new_body[:-1] + '\\"'
2739     orig_escape_count = body.count("\\")
2740     new_escape_count = new_body.count("\\")
2741     if new_escape_count > orig_escape_count:
2742         return  # Do not introduce more escaping
2743
2744     if new_escape_count == orig_escape_count and orig_quote == '"':
2745         return  # Prefer double quotes
2746
2747     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2748
2749
2750 def normalize_numeric_literal(leaf: Leaf) -> None:
2751     """Normalizes numeric (float, int, and complex) literals.
2752
2753     All letters used in the representation are normalized to lowercase (except
2754     in Python 2 long literals).
2755     """
2756     text = leaf.value.lower()
2757     if text.startswith(("0o", "0b")):
2758         # Leave octal and binary literals alone.
2759         pass
2760     elif text.startswith("0x"):
2761         # Change hex literals to upper case.
2762         before, after = text[:2], text[2:]
2763         text = f"{before}{after.upper()}"
2764     elif "e" in text:
2765         before, after = text.split("e")
2766         sign = ""
2767         if after.startswith("-"):
2768             after = after[1:]
2769             sign = "-"
2770         elif after.startswith("+"):
2771             after = after[1:]
2772         before = format_float_or_int_string(before)
2773         text = f"{before}e{sign}{after}"
2774     elif text.endswith(("j", "l")):
2775         number = text[:-1]
2776         suffix = text[-1]
2777         # Capitalize in "2L" because "l" looks too similar to "1".
2778         if suffix == "l":
2779             suffix = "L"
2780         text = f"{format_float_or_int_string(number)}{suffix}"
2781     else:
2782         text = format_float_or_int_string(text)
2783     leaf.value = text
2784
2785
2786 def format_float_or_int_string(text: str) -> str:
2787     """Formats a float string like "1.0"."""
2788     if "." not in text:
2789         return text
2790
2791     before, after = text.split(".")
2792     return f"{before or 0}.{after or 0}"
2793
2794
2795 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2796     """Make existing optional parentheses invisible or create new ones.
2797
2798     `parens_after` is a set of string leaf values immeditely after which parens
2799     should be put.
2800
2801     Standardizes on visible parentheses for single-element tuples, and keeps
2802     existing visible parentheses for other tuples and generator expressions.
2803     """
2804     for pc in list_comments(node.prefix, is_endmarker=False):
2805         if pc.value in FMT_OFF:
2806             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2807             return
2808
2809     check_lpar = False
2810     for index, child in enumerate(list(node.children)):
2811         # Add parentheses around long tuple unpacking in assignments.
2812         if (
2813             index == 0
2814             and isinstance(child, Node)
2815             and child.type == syms.testlist_star_expr
2816         ):
2817             check_lpar = True
2818
2819         if check_lpar:
2820             if child.type == syms.atom:
2821                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2822                     lpar = Leaf(token.LPAR, "")
2823                     rpar = Leaf(token.RPAR, "")
2824                     index = child.remove() or 0
2825                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2826             elif is_one_tuple(child):
2827                 # wrap child in visible parentheses
2828                 lpar = Leaf(token.LPAR, "(")
2829                 rpar = Leaf(token.RPAR, ")")
2830                 child.remove()
2831                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2832             elif node.type == syms.import_from:
2833                 # "import from" nodes store parentheses directly as part of
2834                 # the statement
2835                 if child.type == token.LPAR:
2836                     # make parentheses invisible
2837                     child.value = ""  # type: ignore
2838                     node.children[-1].value = ""  # type: ignore
2839                 elif child.type != token.STAR:
2840                     # insert invisible parentheses
2841                     node.insert_child(index, Leaf(token.LPAR, ""))
2842                     node.append_child(Leaf(token.RPAR, ""))
2843                 break
2844
2845             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2846                 # wrap child in invisible parentheses
2847                 lpar = Leaf(token.LPAR, "")
2848                 rpar = Leaf(token.RPAR, "")
2849                 index = child.remove() or 0
2850                 prefix = child.prefix
2851                 child.prefix = ""
2852                 new_child = Node(syms.atom, [lpar, child, rpar])
2853                 new_child.prefix = prefix
2854                 node.insert_child(index, new_child)
2855
2856         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2857
2858
2859 def normalize_fmt_off(node: Node) -> None:
2860     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2861     try_again = True
2862     while try_again:
2863         try_again = convert_one_fmt_off_pair(node)
2864
2865
2866 def convert_one_fmt_off_pair(node: Node) -> bool:
2867     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2868
2869     Returns True if a pair was converted.
2870     """
2871     for leaf in node.leaves():
2872         previous_consumed = 0
2873         for comment in list_comments(leaf.prefix, is_endmarker=False):
2874             if comment.value in FMT_OFF:
2875                 # We only want standalone comments. If there's no previous leaf or
2876                 # the previous leaf is indentation, it's a standalone comment in
2877                 # disguise.
2878                 if comment.type != STANDALONE_COMMENT:
2879                     prev = preceding_leaf(leaf)
2880                     if prev and prev.type not in WHITESPACE:
2881                         continue
2882
2883                 ignored_nodes = list(generate_ignored_nodes(leaf))
2884                 if not ignored_nodes:
2885                     continue
2886
2887                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2888                 parent = first.parent
2889                 prefix = first.prefix
2890                 first.prefix = prefix[comment.consumed :]
2891                 hidden_value = (
2892                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2893                 )
2894                 if hidden_value.endswith("\n"):
2895                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2896                     # leaf (possibly followed by a DEDENT).
2897                     hidden_value = hidden_value[:-1]
2898                 first_idx = None
2899                 for ignored in ignored_nodes:
2900                     index = ignored.remove()
2901                     if first_idx is None:
2902                         first_idx = index
2903                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2904                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2905                 parent.insert_child(
2906                     first_idx,
2907                     Leaf(
2908                         STANDALONE_COMMENT,
2909                         hidden_value,
2910                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2911                     ),
2912                 )
2913                 return True
2914
2915             previous_consumed = comment.consumed
2916
2917     return False
2918
2919
2920 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2921     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2922
2923     Stops at the end of the block.
2924     """
2925     container: Optional[LN] = container_of(leaf)
2926     while container is not None and container.type != token.ENDMARKER:
2927         for comment in list_comments(container.prefix, is_endmarker=False):
2928             if comment.value in FMT_ON:
2929                 return
2930
2931         yield container
2932
2933         container = container.next_sibling
2934
2935
2936 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2937     """If it's safe, make the parens in the atom `node` invisible, recursively.
2938
2939     Returns whether the node should itself be wrapped in invisible parentheses.
2940
2941     """
2942     if (
2943         node.type != syms.atom
2944         or is_empty_tuple(node)
2945         or is_one_tuple(node)
2946         or (is_yield(node) and parent.type != syms.expr_stmt)
2947         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2948     ):
2949         return False
2950
2951     first = node.children[0]
2952     last = node.children[-1]
2953     if first.type == token.LPAR and last.type == token.RPAR:
2954         # make parentheses invisible
2955         first.value = ""  # type: ignore
2956         last.value = ""  # type: ignore
2957         if len(node.children) > 1:
2958             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2959         return False
2960
2961     return True
2962
2963
2964 def is_empty_tuple(node: LN) -> bool:
2965     """Return True if `node` holds an empty tuple."""
2966     return (
2967         node.type == syms.atom
2968         and len(node.children) == 2
2969         and node.children[0].type == token.LPAR
2970         and node.children[1].type == token.RPAR
2971     )
2972
2973
2974 def is_one_tuple(node: LN) -> bool:
2975     """Return True if `node` holds a tuple with one element, with or without parens."""
2976     if node.type == syms.atom:
2977         if len(node.children) != 3:
2978             return False
2979
2980         lpar, gexp, rpar = node.children
2981         if not (
2982             lpar.type == token.LPAR
2983             and gexp.type == syms.testlist_gexp
2984             and rpar.type == token.RPAR
2985         ):
2986             return False
2987
2988         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2989
2990     return (
2991         node.type in IMPLICIT_TUPLE
2992         and len(node.children) == 2
2993         and node.children[1].type == token.COMMA
2994     )
2995
2996
2997 def is_yield(node: LN) -> bool:
2998     """Return True if `node` holds a `yield` or `yield from` expression."""
2999     if node.type == syms.yield_expr:
3000         return True
3001
3002     if node.type == token.NAME and node.value == "yield":  # type: ignore
3003         return True
3004
3005     if node.type != syms.atom:
3006         return False
3007
3008     if len(node.children) != 3:
3009         return False
3010
3011     lpar, expr, rpar = node.children
3012     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3013         return is_yield(expr)
3014
3015     return False
3016
3017
3018 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3019     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3020
3021     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3022     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3023     extended iterable unpacking (PEP 3132) and additional unpacking
3024     generalizations (PEP 448).
3025     """
3026     if leaf.type not in STARS or not leaf.parent:
3027         return False
3028
3029     p = leaf.parent
3030     if p.type == syms.star_expr:
3031         # Star expressions are also used as assignment targets in extended
3032         # iterable unpacking (PEP 3132).  See what its parent is instead.
3033         if not p.parent:
3034             return False
3035
3036         p = p.parent
3037
3038     return p.type in within
3039
3040
3041 def is_multiline_string(leaf: Leaf) -> bool:
3042     """Return True if `leaf` is a multiline string that actually spans many lines."""
3043     value = leaf.value.lstrip("furbFURB")
3044     return value[:3] in {'"""', "'''"} and "\n" in value
3045
3046
3047 def is_stub_suite(node: Node) -> bool:
3048     """Return True if `node` is a suite with a stub body."""
3049     if (
3050         len(node.children) != 4
3051         or node.children[0].type != token.NEWLINE
3052         or node.children[1].type != token.INDENT
3053         or node.children[3].type != token.DEDENT
3054     ):
3055         return False
3056
3057     return is_stub_body(node.children[2])
3058
3059
3060 def is_stub_body(node: LN) -> bool:
3061     """Return True if `node` is a simple statement containing an ellipsis."""
3062     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3063         return False
3064
3065     if len(node.children) != 2:
3066         return False
3067
3068     child = node.children[0]
3069     return (
3070         child.type == syms.atom
3071         and len(child.children) == 3
3072         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3073     )
3074
3075
3076 def max_delimiter_priority_in_atom(node: LN) -> int:
3077     """Return maximum delimiter priority inside `node`.
3078
3079     This is specific to atoms with contents contained in a pair of parentheses.
3080     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3081     """
3082     if node.type != syms.atom:
3083         return 0
3084
3085     first = node.children[0]
3086     last = node.children[-1]
3087     if not (first.type == token.LPAR and last.type == token.RPAR):
3088         return 0
3089
3090     bt = BracketTracker()
3091     for c in node.children[1:-1]:
3092         if isinstance(c, Leaf):
3093             bt.mark(c)
3094         else:
3095             for leaf in c.leaves():
3096                 bt.mark(leaf)
3097     try:
3098         return bt.max_delimiter_priority()
3099
3100     except ValueError:
3101         return 0
3102
3103
3104 def ensure_visible(leaf: Leaf) -> None:
3105     """Make sure parentheses are visible.
3106
3107     They could be invisible as part of some statements (see
3108     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3109     """
3110     if leaf.type == token.LPAR:
3111         leaf.value = "("
3112     elif leaf.type == token.RPAR:
3113         leaf.value = ")"
3114
3115
3116 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3117     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3118
3119     if not (
3120         opening_bracket.parent
3121         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3122         and opening_bracket.value in "[{("
3123     ):
3124         return False
3125
3126     try:
3127         last_leaf = line.leaves[-1]
3128         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3129         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3130     except (IndexError, ValueError):
3131         return False
3132
3133     return max_priority == COMMA_PRIORITY
3134
3135
3136 def get_features_used(node: Node) -> Set[Feature]:
3137     """Return a set of (relatively) new Python features used in this file.
3138
3139     Currently looking for:
3140     - f-strings;
3141     - underscores in numeric literals; and
3142     - trailing commas after * or ** in function signatures and calls.
3143     """
3144     features: Set[Feature] = set()
3145     for n in node.pre_order():
3146         if n.type == token.STRING:
3147             value_head = n.value[:2]  # type: ignore
3148             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3149                 features.add(Feature.F_STRINGS)
3150
3151         elif n.type == token.NUMBER:
3152             if "_" in n.value:  # type: ignore
3153                 features.add(Feature.NUMERIC_UNDERSCORES)
3154
3155         elif (
3156             n.type in {syms.typedargslist, syms.arglist}
3157             and n.children
3158             and n.children[-1].type == token.COMMA
3159         ):
3160             if n.type == syms.typedargslist:
3161                 feature = Feature.TRAILING_COMMA_IN_DEF
3162             else:
3163                 feature = Feature.TRAILING_COMMA_IN_CALL
3164
3165             for ch in n.children:
3166                 if ch.type in STARS:
3167                     features.add(feature)
3168
3169                 if ch.type == syms.argument:
3170                     for argch in ch.children:
3171                         if argch.type in STARS:
3172                             features.add(feature)
3173
3174     return features
3175
3176
3177 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3178     """Detect the version to target based on the nodes used."""
3179     features = get_features_used(node)
3180     return {
3181         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3182     }
3183
3184
3185 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3186     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3187
3188     Brackets can be omitted if the entire trailer up to and including
3189     a preceding closing bracket fits in one line.
3190
3191     Yielded sets are cumulative (contain results of previous yields, too).  First
3192     set is empty.
3193     """
3194
3195     omit: Set[LeafID] = set()
3196     yield omit
3197
3198     length = 4 * line.depth
3199     opening_bracket = None
3200     closing_bracket = None
3201     inner_brackets: Set[LeafID] = set()
3202     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3203         length += leaf_length
3204         if length > line_length:
3205             break
3206
3207         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3208         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3209             break
3210
3211         if opening_bracket:
3212             if leaf is opening_bracket:
3213                 opening_bracket = None
3214             elif leaf.type in CLOSING_BRACKETS:
3215                 inner_brackets.add(id(leaf))
3216         elif leaf.type in CLOSING_BRACKETS:
3217             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3218                 # Empty brackets would fail a split so treat them as "inner"
3219                 # brackets (e.g. only add them to the `omit` set if another
3220                 # pair of brackets was good enough.
3221                 inner_brackets.add(id(leaf))
3222                 continue
3223
3224             if closing_bracket:
3225                 omit.add(id(closing_bracket))
3226                 omit.update(inner_brackets)
3227                 inner_brackets.clear()
3228                 yield omit
3229
3230             if leaf.value:
3231                 opening_bracket = leaf.opening_bracket
3232                 closing_bracket = leaf
3233
3234
3235 def get_future_imports(node: Node) -> Set[str]:
3236     """Return a set of __future__ imports in the file."""
3237     imports: Set[str] = set()
3238
3239     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3240         for child in children:
3241             if isinstance(child, Leaf):
3242                 if child.type == token.NAME:
3243                     yield child.value
3244             elif child.type == syms.import_as_name:
3245                 orig_name = child.children[0]
3246                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3247                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3248                 yield orig_name.value
3249             elif child.type == syms.import_as_names:
3250                 yield from get_imports_from_children(child.children)
3251             else:
3252                 raise AssertionError("Invalid syntax parsing imports")
3253
3254     for child in node.children:
3255         if child.type != syms.simple_stmt:
3256             break
3257         first_child = child.children[0]
3258         if isinstance(first_child, Leaf):
3259             # Continue looking if we see a docstring; otherwise stop.
3260             if (
3261                 len(child.children) == 2
3262                 and first_child.type == token.STRING
3263                 and child.children[1].type == token.NEWLINE
3264             ):
3265                 continue
3266             else:
3267                 break
3268         elif first_child.type == syms.import_from:
3269             module_name = first_child.children[1]
3270             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3271                 break
3272             imports |= set(get_imports_from_children(first_child.children[3:]))
3273         else:
3274             break
3275     return imports
3276
3277
3278 def gen_python_files_in_dir(
3279     path: Path,
3280     root: Path,
3281     include: Pattern[str],
3282     exclude: Pattern[str],
3283     report: "Report",
3284 ) -> Iterator[Path]:
3285     """Generate all files under `path` whose paths are not excluded by the
3286     `exclude` regex, but are included by the `include` regex.
3287
3288     Symbolic links pointing outside of the `root` directory are ignored.
3289
3290     `report` is where output about exclusions goes.
3291     """
3292     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3293     for child in path.iterdir():
3294         try:
3295             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3296         except ValueError:
3297             if child.is_symlink():
3298                 report.path_ignored(
3299                     child, f"is a symbolic link that points outside {root}"
3300                 )
3301                 continue
3302
3303             raise
3304
3305         if child.is_dir():
3306             normalized_path += "/"
3307         exclude_match = exclude.search(normalized_path)
3308         if exclude_match and exclude_match.group(0):
3309             report.path_ignored(child, f"matches the --exclude regular expression")
3310             continue
3311
3312         if child.is_dir():
3313             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3314
3315         elif child.is_file():
3316             include_match = include.search(normalized_path)
3317             if include_match:
3318                 yield child
3319
3320
3321 @lru_cache()
3322 def find_project_root(srcs: Iterable[str]) -> Path:
3323     """Return a directory containing .git, .hg, or pyproject.toml.
3324
3325     That directory can be one of the directories passed in `srcs` or their
3326     common parent.
3327
3328     If no directory in the tree contains a marker that would specify it's the
3329     project root, the root of the file system is returned.
3330     """
3331     if not srcs:
3332         return Path("/").resolve()
3333
3334     common_base = min(Path(src).resolve() for src in srcs)
3335     if common_base.is_dir():
3336         # Append a fake file so `parents` below returns `common_base_dir`, too.
3337         common_base /= "fake-file"
3338     for directory in common_base.parents:
3339         if (directory / ".git").is_dir():
3340             return directory
3341
3342         if (directory / ".hg").is_dir():
3343             return directory
3344
3345         if (directory / "pyproject.toml").is_file():
3346             return directory
3347
3348     return directory
3349
3350
3351 @dataclass
3352 class Report:
3353     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3354
3355     check: bool = False
3356     quiet: bool = False
3357     verbose: bool = False
3358     change_count: int = 0
3359     same_count: int = 0
3360     failure_count: int = 0
3361
3362     def done(self, src: Path, changed: Changed) -> None:
3363         """Increment the counter for successful reformatting. Write out a message."""
3364         if changed is Changed.YES:
3365             reformatted = "would reformat" if self.check else "reformatted"
3366             if self.verbose or not self.quiet:
3367                 out(f"{reformatted} {src}")
3368             self.change_count += 1
3369         else:
3370             if self.verbose:
3371                 if changed is Changed.NO:
3372                     msg = f"{src} already well formatted, good job."
3373                 else:
3374                     msg = f"{src} wasn't modified on disk since last run."
3375                 out(msg, bold=False)
3376             self.same_count += 1
3377
3378     def failed(self, src: Path, message: str) -> None:
3379         """Increment the counter for failed reformatting. Write out a message."""
3380         err(f"error: cannot format {src}: {message}")
3381         self.failure_count += 1
3382
3383     def path_ignored(self, path: Path, message: str) -> None:
3384         if self.verbose:
3385             out(f"{path} ignored: {message}", bold=False)
3386
3387     @property
3388     def return_code(self) -> int:
3389         """Return the exit code that the app should use.
3390
3391         This considers the current state of changed files and failures:
3392         - if there were any failures, return 123;
3393         - if any files were changed and --check is being used, return 1;
3394         - otherwise return 0.
3395         """
3396         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3397         # 126 we have special return codes reserved by the shell.
3398         if self.failure_count:
3399             return 123
3400
3401         elif self.change_count and self.check:
3402             return 1
3403
3404         return 0
3405
3406     def __str__(self) -> str:
3407         """Render a color report of the current state.
3408
3409         Use `click.unstyle` to remove colors.
3410         """
3411         if self.check:
3412             reformatted = "would be reformatted"
3413             unchanged = "would be left unchanged"
3414             failed = "would fail to reformat"
3415         else:
3416             reformatted = "reformatted"
3417             unchanged = "left unchanged"
3418             failed = "failed to reformat"
3419         report = []
3420         if self.change_count:
3421             s = "s" if self.change_count > 1 else ""
3422             report.append(
3423                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3424             )
3425         if self.same_count:
3426             s = "s" if self.same_count > 1 else ""
3427             report.append(f"{self.same_count} file{s} {unchanged}")
3428         if self.failure_count:
3429             s = "s" if self.failure_count > 1 else ""
3430             report.append(
3431                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3432             )
3433         return ", ".join(report) + "."
3434
3435
3436 def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
3437     for feature_version in (7, 6):
3438         try:
3439             return ast3.parse(src, feature_version=feature_version)
3440         except SyntaxError:
3441             continue
3442
3443     return ast27.parse(src)
3444
3445
3446 def assert_equivalent(src: str, dst: str) -> None:
3447     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3448
3449     import traceback
3450
3451     def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3452         """Simple visitor generating strings to compare ASTs by content."""
3453         yield f"{'  ' * depth}{node.__class__.__name__}("
3454
3455         for field in sorted(node._fields):
3456             # TypeIgnore has only one field 'lineno' which breaks this comparison
3457             if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
3458                 break
3459
3460             # Ignore str kind which is case sensitive / and ignores unicode_literals
3461             if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
3462                 continue
3463
3464             try:
3465                 value = getattr(node, field)
3466             except AttributeError:
3467                 continue
3468
3469             yield f"{'  ' * (depth+1)}{field}="
3470
3471             if isinstance(value, list):
3472                 for item in value:
3473                     # Ignore nested tuples within del statements, because we may insert
3474                     # parentheses and they change the AST.
3475                     if (
3476                         field == "targets"
3477                         and isinstance(node, (ast3.Delete, ast27.Delete))
3478                         and isinstance(item, (ast3.Tuple, ast27.Tuple))
3479                     ):
3480                         for item in item.elts:
3481                             yield from _v(item, depth + 2)
3482                     elif isinstance(item, (ast3.AST, ast27.AST)):
3483                         yield from _v(item, depth + 2)
3484
3485             elif isinstance(value, (ast3.AST, ast27.AST)):
3486                 yield from _v(value, depth + 2)
3487
3488             else:
3489                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3490
3491         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3492
3493     try:
3494         src_ast = parse_ast(src)
3495     except Exception as exc:
3496         raise AssertionError(
3497             f"cannot use --safe with this file; failed to parse source file.  "
3498             f"AST error message: {exc}"
3499         )
3500
3501     try:
3502         dst_ast = parse_ast(dst)
3503     except Exception as exc:
3504         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3505         raise AssertionError(
3506             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3507             f"Please report a bug on https://github.com/python/black/issues.  "
3508             f"This invalid output might be helpful: {log}"
3509         ) from None
3510
3511     src_ast_str = "\n".join(_v(src_ast))
3512     dst_ast_str = "\n".join(_v(dst_ast))
3513     if src_ast_str != dst_ast_str:
3514         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3515         raise AssertionError(
3516             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3517             f"the source.  "
3518             f"Please report a bug on https://github.com/python/black/issues.  "
3519             f"This diff might be helpful: {log}"
3520         ) from None
3521
3522
3523 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3524     """Raise AssertionError if `dst` reformats differently the second time."""
3525     newdst = format_str(dst, mode=mode)
3526     if dst != newdst:
3527         log = dump_to_file(
3528             diff(src, dst, "source", "first pass"),
3529             diff(dst, newdst, "first pass", "second pass"),
3530         )
3531         raise AssertionError(
3532             f"INTERNAL ERROR: Black produced different code on the second pass "
3533             f"of the formatter.  "
3534             f"Please report a bug on https://github.com/python/black/issues.  "
3535             f"This diff might be helpful: {log}"
3536         ) from None
3537
3538
3539 def dump_to_file(*output: str) -> str:
3540     """Dump `output` to a temporary file. Return path to the file."""
3541     import tempfile
3542
3543     with tempfile.NamedTemporaryFile(
3544         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3545     ) as f:
3546         for lines in output:
3547             f.write(lines)
3548             if lines and lines[-1] != "\n":
3549                 f.write("\n")
3550     return f.name
3551
3552
3553 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3554     """Return a unified diff string between strings `a` and `b`."""
3555     import difflib
3556
3557     a_lines = [line + "\n" for line in a.split("\n")]
3558     b_lines = [line + "\n" for line in b.split("\n")]
3559     return "".join(
3560         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3561     )
3562
3563
3564 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3565     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3566     err("Aborted!")
3567     for task in tasks:
3568         task.cancel()
3569
3570
3571 def shutdown(loop: BaseEventLoop) -> None:
3572     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3573     try:
3574         if sys.version_info[:2] >= (3, 7):
3575             all_tasks = asyncio.all_tasks
3576         else:
3577             all_tasks = asyncio.Task.all_tasks
3578         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3579         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3580         if not to_cancel:
3581             return
3582
3583         for task in to_cancel:
3584             task.cancel()
3585         loop.run_until_complete(
3586             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3587         )
3588     finally:
3589         # `concurrent.futures.Future` objects cannot be cancelled once they
3590         # are already running. There might be some when the `shutdown()` happened.
3591         # Silence their logger's spew about the event loop being closed.
3592         cf_logger = logging.getLogger("concurrent.futures")
3593         cf_logger.setLevel(logging.CRITICAL)
3594         loop.close()
3595
3596
3597 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3598     """Replace `regex` with `replacement` twice on `original`.
3599
3600     This is used by string normalization to perform replaces on
3601     overlapping matches.
3602     """
3603     return regex.sub(replacement, regex.sub(replacement, original))
3604
3605
3606 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3607     """Compile a regular expression string in `regex`.
3608
3609     If it contains newlines, use verbose mode.
3610     """
3611     if "\n" in regex:
3612         regex = "(?x)" + regex
3613     return re.compile(regex)
3614
3615
3616 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3617     """Like `reversed(enumerate(sequence))` if that were possible."""
3618     index = len(sequence) - 1
3619     for element in reversed(sequence):
3620         yield (index, element)
3621         index -= 1
3622
3623
3624 def enumerate_with_length(
3625     line: Line, reversed: bool = False
3626 ) -> Iterator[Tuple[Index, Leaf, int]]:
3627     """Return an enumeration of leaves with their length.
3628
3629     Stops prematurely on multiline strings and standalone comments.
3630     """
3631     op = cast(
3632         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3633         enumerate_reversed if reversed else enumerate,
3634     )
3635     for index, leaf in op(line.leaves):
3636         length = len(leaf.prefix) + len(leaf.value)
3637         if "\n" in leaf.value:
3638             return  # Multiline strings, we can't continue.
3639
3640         comment: Optional[Leaf]
3641         for comment in line.comments_after(leaf):
3642             length += len(comment.value)
3643
3644         yield index, leaf, length
3645
3646
3647 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3648     """Return True if `line` is no longer than `line_length`.
3649
3650     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3651     """
3652     if not line_str:
3653         line_str = str(line).strip("\n")
3654     return (
3655         len(line_str) <= line_length
3656         and "\n" not in line_str  # multiline strings
3657         and not line.contains_standalone_comments()
3658     )
3659
3660
3661 def can_be_split(line: Line) -> bool:
3662     """Return False if the line cannot be split *for sure*.
3663
3664     This is not an exhaustive search but a cheap heuristic that we can use to
3665     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3666     in unnecessary parentheses).
3667     """
3668     leaves = line.leaves
3669     if len(leaves) < 2:
3670         return False
3671
3672     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3673         call_count = 0
3674         dot_count = 0
3675         next = leaves[-1]
3676         for leaf in leaves[-2::-1]:
3677             if leaf.type in OPENING_BRACKETS:
3678                 if next.type not in CLOSING_BRACKETS:
3679                     return False
3680
3681                 call_count += 1
3682             elif leaf.type == token.DOT:
3683                 dot_count += 1
3684             elif leaf.type == token.NAME:
3685                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3686                     return False
3687
3688             elif leaf.type not in CLOSING_BRACKETS:
3689                 return False
3690
3691             if dot_count > 1 and call_count > 1:
3692                 return False
3693
3694     return True
3695
3696
3697 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3698     """Does `line` have a shape safe to reformat without optional parens around it?
3699
3700     Returns True for only a subset of potentially nice looking formattings but
3701     the point is to not return false positives that end up producing lines that
3702     are too long.
3703     """
3704     bt = line.bracket_tracker
3705     if not bt.delimiters:
3706         # Without delimiters the optional parentheses are useless.
3707         return True
3708
3709     max_priority = bt.max_delimiter_priority()
3710     if bt.delimiter_count_with_priority(max_priority) > 1:
3711         # With more than one delimiter of a kind the optional parentheses read better.
3712         return False
3713
3714     if max_priority == DOT_PRIORITY:
3715         # A single stranded method call doesn't require optional parentheses.
3716         return True
3717
3718     assert len(line.leaves) >= 2, "Stranded delimiter"
3719
3720     first = line.leaves[0]
3721     second = line.leaves[1]
3722     penultimate = line.leaves[-2]
3723     last = line.leaves[-1]
3724
3725     # With a single delimiter, omit if the expression starts or ends with
3726     # a bracket.
3727     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3728         remainder = False
3729         length = 4 * line.depth
3730         for _index, leaf, leaf_length in enumerate_with_length(line):
3731             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3732                 remainder = True
3733             if remainder:
3734                 length += leaf_length
3735                 if length > line_length:
3736                     break
3737
3738                 if leaf.type in OPENING_BRACKETS:
3739                     # There are brackets we can further split on.
3740                     remainder = False
3741
3742         else:
3743             # checked the entire string and line length wasn't exceeded
3744             if len(line.leaves) == _index + 1:
3745                 return True
3746
3747         # Note: we are not returning False here because a line might have *both*
3748         # a leading opening bracket and a trailing closing bracket.  If the
3749         # opening bracket doesn't match our rule, maybe the closing will.
3750
3751     if (
3752         last.type == token.RPAR
3753         or last.type == token.RBRACE
3754         or (
3755             # don't use indexing for omitting optional parentheses;
3756             # it looks weird
3757             last.type == token.RSQB
3758             and last.parent
3759             and last.parent.type != syms.trailer
3760         )
3761     ):
3762         if penultimate.type in OPENING_BRACKETS:
3763             # Empty brackets don't help.
3764             return False
3765
3766         if is_multiline_string(first):
3767             # Additional wrapping of a multiline string in this situation is
3768             # unnecessary.
3769             return True
3770
3771         length = 4 * line.depth
3772         seen_other_brackets = False
3773         for _index, leaf, leaf_length in enumerate_with_length(line):
3774             length += leaf_length
3775             if leaf is last.opening_bracket:
3776                 if seen_other_brackets or length <= line_length:
3777                     return True
3778
3779             elif leaf.type in OPENING_BRACKETS:
3780                 # There are brackets we can further split on.
3781                 seen_other_brackets = True
3782
3783     return False
3784
3785
3786 def get_cache_file(mode: FileMode) -> Path:
3787     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3788
3789
3790 def read_cache(mode: FileMode) -> Cache:
3791     """Read the cache if it exists and is well formed.
3792
3793     If it is not well formed, the call to write_cache later should resolve the issue.
3794     """
3795     cache_file = get_cache_file(mode)
3796     if not cache_file.exists():
3797         return {}
3798
3799     with cache_file.open("rb") as fobj:
3800         try:
3801             cache: Cache = pickle.load(fobj)
3802         except pickle.UnpicklingError:
3803             return {}
3804
3805     return cache
3806
3807
3808 def get_cache_info(path: Path) -> CacheInfo:
3809     """Return the information used to check if a file is already formatted or not."""
3810     stat = path.stat()
3811     return stat.st_mtime, stat.st_size
3812
3813
3814 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3815     """Split an iterable of paths in `sources` into two sets.
3816
3817     The first contains paths of files that modified on disk or are not in the
3818     cache. The other contains paths to non-modified files.
3819     """
3820     todo, done = set(), set()
3821     for src in sources:
3822         src = src.resolve()
3823         if cache.get(src) != get_cache_info(src):
3824             todo.add(src)
3825         else:
3826             done.add(src)
3827     return todo, done
3828
3829
3830 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3831     """Update the cache file."""
3832     cache_file = get_cache_file(mode)
3833     try:
3834         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3835         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3836         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3837             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3838         os.replace(f.name, cache_file)
3839     except OSError:
3840         pass
3841
3842
3843 def patch_click() -> None:
3844     """Make Click not crash.
3845
3846     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3847     default which restricts paths that it can access during the lifetime of the
3848     application.  Click refuses to work in this scenario by raising a RuntimeError.
3849
3850     In case of Black the likelihood that non-ASCII characters are going to be used in
3851     file paths is minimal since it's Python source code.  Moreover, this crash was
3852     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3853     """
3854     try:
3855         from click import core
3856         from click import _unicodefun  # type: ignore
3857     except ModuleNotFoundError:
3858         return
3859
3860     for module in (core, _unicodefun):
3861         if hasattr(module, "_verify_python3_env"):
3862             module._verify_python3_env = lambda: None
3863
3864
3865 def patched_main() -> None:
3866     freeze_support()
3867     patch_click()
3868     main()
3869
3870
3871 if __name__ == "__main__":
3872     patched_main()