]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Pin comment to single leaf in invisible parens (#872)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 import traceback
20 from typing import (
21     Any,
22     Callable,
23     Collection,
24     Dict,
25     Generator,
26     Generic,
27     Iterable,
28     Iterator,
29     List,
30     Optional,
31     Pattern,
32     Sequence,
33     Set,
34     Tuple,
35     TypeVar,
36     Union,
37     cast,
38 )
39
40 from appdirs import user_cache_dir
41 from attr import dataclass, evolve, Factory
42 import click
43 import toml
44 from typed_ast import ast3, ast27
45
46 # lib2to3 fork
47 from blib2to3.pytree import Node, Leaf, type_repr
48 from blib2to3 import pygram, pytree
49 from blib2to3.pgen2 import driver, token
50 from blib2to3.pgen2.grammar import Grammar
51 from blib2to3.pgen2.parse import ParseError
52
53
54 __version__ = "19.3b0"
55 DEFAULT_LINE_LENGTH = 88
56 DEFAULT_EXCLUDES = (
57     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
58 )
59 DEFAULT_INCLUDES = r"\.pyi?$"
60 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
61
62
63 # types
64 FileContent = str
65 Encoding = str
66 NewLine = str
67 Depth = int
68 NodeType = int
69 LeafID = int
70 Priority = int
71 Index = int
72 LN = Union[Leaf, Node]
73 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
74 Timestamp = float
75 FileSize = int
76 CacheInfo = Tuple[Timestamp, FileSize]
77 Cache = Dict[Path, CacheInfo]
78 out = partial(click.secho, bold=True, err=True)
79 err = partial(click.secho, fg="red", err=True)
80
81 pygram.initialize(CACHE_DIR)
82 syms = pygram.python_symbols
83
84
85 class NothingChanged(UserWarning):
86     """Raised when reformatted code is the same as source."""
87
88
89 class CannotSplit(Exception):
90     """A readable split that fits the allotted line length is impossible."""
91
92
93 class InvalidInput(ValueError):
94     """Raised when input source code fails all parse attempts."""
95
96
97 class WriteBack(Enum):
98     NO = 0
99     YES = 1
100     DIFF = 2
101     CHECK = 3
102
103     @classmethod
104     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
105         if check and not diff:
106             return cls.CHECK
107
108         return cls.DIFF if diff else cls.YES
109
110
111 class Changed(Enum):
112     NO = 0
113     CACHED = 1
114     YES = 2
115
116
117 class TargetVersion(Enum):
118     PY27 = 2
119     PY33 = 3
120     PY34 = 4
121     PY35 = 5
122     PY36 = 6
123     PY37 = 7
124     PY38 = 8
125
126     def is_python2(self) -> bool:
127         return self is TargetVersion.PY27
128
129
130 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
131
132
133 class Feature(Enum):
134     # All string literals are unicode
135     UNICODE_LITERALS = 1
136     F_STRINGS = 2
137     NUMERIC_UNDERSCORES = 3
138     TRAILING_COMMA_IN_CALL = 4
139     TRAILING_COMMA_IN_DEF = 5
140     # The following two feature-flags are mutually exclusive, and exactly one should be
141     # set for every version of python.
142     ASYNC_IDENTIFIERS = 6
143     ASYNC_KEYWORDS = 7
144
145
146 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
147     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
148     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
149     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
150     TargetVersion.PY35: {
151         Feature.UNICODE_LITERALS,
152         Feature.TRAILING_COMMA_IN_CALL,
153         Feature.ASYNC_IDENTIFIERS,
154     },
155     TargetVersion.PY36: {
156         Feature.UNICODE_LITERALS,
157         Feature.F_STRINGS,
158         Feature.NUMERIC_UNDERSCORES,
159         Feature.TRAILING_COMMA_IN_CALL,
160         Feature.TRAILING_COMMA_IN_DEF,
161         Feature.ASYNC_IDENTIFIERS,
162     },
163     TargetVersion.PY37: {
164         Feature.UNICODE_LITERALS,
165         Feature.F_STRINGS,
166         Feature.NUMERIC_UNDERSCORES,
167         Feature.TRAILING_COMMA_IN_CALL,
168         Feature.TRAILING_COMMA_IN_DEF,
169         Feature.ASYNC_KEYWORDS,
170     },
171     TargetVersion.PY38: {
172         Feature.UNICODE_LITERALS,
173         Feature.F_STRINGS,
174         Feature.NUMERIC_UNDERSCORES,
175         Feature.TRAILING_COMMA_IN_CALL,
176         Feature.TRAILING_COMMA_IN_DEF,
177         Feature.ASYNC_KEYWORDS,
178     },
179 }
180
181
182 @dataclass
183 class FileMode:
184     target_versions: Set[TargetVersion] = Factory(set)
185     line_length: int = DEFAULT_LINE_LENGTH
186     string_normalization: bool = True
187     is_pyi: bool = False
188
189     def get_cache_key(self) -> str:
190         if self.target_versions:
191             version_str = ",".join(
192                 str(version.value)
193                 for version in sorted(self.target_versions, key=lambda v: v.value)
194             )
195         else:
196             version_str = "-"
197         parts = [
198             version_str,
199             str(self.line_length),
200             str(int(self.string_normalization)),
201             str(int(self.is_pyi)),
202         ]
203         return ".".join(parts)
204
205
206 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
207     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
208
209
210 def read_pyproject_toml(
211     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
212 ) -> Optional[str]:
213     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
214
215     Returns the path to a successfully found and read configuration file, None
216     otherwise.
217     """
218     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
219     if not value:
220         root = find_project_root(ctx.params.get("src", ()))
221         path = root / "pyproject.toml"
222         if path.is_file():
223             value = str(path)
224         else:
225             return None
226
227     try:
228         pyproject_toml = toml.load(value)
229         config = pyproject_toml.get("tool", {}).get("black", {})
230     except (toml.TomlDecodeError, OSError) as e:
231         raise click.FileError(
232             filename=value, hint=f"Error reading configuration file: {e}"
233         )
234
235     if not config:
236         return None
237
238     if ctx.default_map is None:
239         ctx.default_map = {}
240     ctx.default_map.update(  # type: ignore  # bad types in .pyi
241         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
242     )
243     return value
244
245
246 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
247 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
248 @click.option(
249     "-l",
250     "--line-length",
251     type=int,
252     default=DEFAULT_LINE_LENGTH,
253     help="How many characters per line to allow.",
254     show_default=True,
255 )
256 @click.option(
257     "-t",
258     "--target-version",
259     type=click.Choice([v.name.lower() for v in TargetVersion]),
260     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
261     multiple=True,
262     help=(
263         "Python versions that should be supported by Black's output. [default: "
264         "per-file auto-detection]"
265     ),
266 )
267 @click.option(
268     "--py36",
269     is_flag=True,
270     help=(
271         "Allow using Python 3.6-only syntax on all input files.  This will put "
272         "trailing commas in function signatures and calls also after *args and "
273         "**kwargs. Deprecated; use --target-version instead. "
274         "[default: per-file auto-detection]"
275     ),
276 )
277 @click.option(
278     "--pyi",
279     is_flag=True,
280     help=(
281         "Format all input files like typing stubs regardless of file extension "
282         "(useful when piping source on standard input)."
283     ),
284 )
285 @click.option(
286     "-S",
287     "--skip-string-normalization",
288     is_flag=True,
289     help="Don't normalize string quotes or prefixes.",
290 )
291 @click.option(
292     "--check",
293     is_flag=True,
294     help=(
295         "Don't write the files back, just return the status.  Return code 0 "
296         "means nothing would change.  Return code 1 means some files would be "
297         "reformatted.  Return code 123 means there was an internal error."
298     ),
299 )
300 @click.option(
301     "--diff",
302     is_flag=True,
303     help="Don't write the files back, just output a diff for each file on stdout.",
304 )
305 @click.option(
306     "--fast/--safe",
307     is_flag=True,
308     help="If --fast given, skip temporary sanity checks. [default: --safe]",
309 )
310 @click.option(
311     "--include",
312     type=str,
313     default=DEFAULT_INCLUDES,
314     help=(
315         "A regular expression that matches files and directories that should be "
316         "included on recursive searches.  An empty value means all files are "
317         "included regardless of the name.  Use forward slashes for directories on "
318         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
319         "later."
320     ),
321     show_default=True,
322 )
323 @click.option(
324     "--exclude",
325     type=str,
326     default=DEFAULT_EXCLUDES,
327     help=(
328         "A regular expression that matches files and directories that should be "
329         "excluded on recursive searches.  An empty value means no paths are excluded. "
330         "Use forward slashes for directories on all platforms (Windows, too).  "
331         "Exclusions are calculated first, inclusions later."
332     ),
333     show_default=True,
334 )
335 @click.option(
336     "-q",
337     "--quiet",
338     is_flag=True,
339     help=(
340         "Don't emit non-error messages to stderr. Errors are still emitted; "
341         "silence those with 2>/dev/null."
342     ),
343 )
344 @click.option(
345     "-v",
346     "--verbose",
347     is_flag=True,
348     help=(
349         "Also emit messages to stderr about files that were not changed or were "
350         "ignored due to --exclude=."
351     ),
352 )
353 @click.version_option(version=__version__)
354 @click.argument(
355     "src",
356     nargs=-1,
357     type=click.Path(
358         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
359     ),
360     is_eager=True,
361 )
362 @click.option(
363     "--config",
364     type=click.Path(
365         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
366     ),
367     is_eager=True,
368     callback=read_pyproject_toml,
369     help="Read configuration from PATH.",
370 )
371 @click.pass_context
372 def main(
373     ctx: click.Context,
374     code: Optional[str],
375     line_length: int,
376     target_version: List[TargetVersion],
377     check: bool,
378     diff: bool,
379     fast: bool,
380     pyi: bool,
381     py36: bool,
382     skip_string_normalization: bool,
383     quiet: bool,
384     verbose: bool,
385     include: str,
386     exclude: str,
387     src: Tuple[str],
388     config: Optional[str],
389 ) -> None:
390     """The uncompromising code formatter."""
391     write_back = WriteBack.from_configuration(check=check, diff=diff)
392     if target_version:
393         if py36:
394             err(f"Cannot use both --target-version and --py36")
395             ctx.exit(2)
396         else:
397             versions = set(target_version)
398     elif py36:
399         err(
400             "--py36 is deprecated and will be removed in a future version. "
401             "Use --target-version py36 instead."
402         )
403         versions = PY36_VERSIONS
404     else:
405         # We'll autodetect later.
406         versions = set()
407     mode = FileMode(
408         target_versions=versions,
409         line_length=line_length,
410         is_pyi=pyi,
411         string_normalization=not skip_string_normalization,
412     )
413     if config and verbose:
414         out(f"Using configuration from {config}.", bold=False, fg="blue")
415     if code is not None:
416         print(format_str(code, mode=mode))
417         ctx.exit(0)
418     try:
419         include_regex = re_compile_maybe_verbose(include)
420     except re.error:
421         err(f"Invalid regular expression for include given: {include!r}")
422         ctx.exit(2)
423     try:
424         exclude_regex = re_compile_maybe_verbose(exclude)
425     except re.error:
426         err(f"Invalid regular expression for exclude given: {exclude!r}")
427         ctx.exit(2)
428     report = Report(check=check, quiet=quiet, verbose=verbose)
429     root = find_project_root(src)
430     sources: Set[Path] = set()
431     for s in src:
432         p = Path(s)
433         if p.is_dir():
434             sources.update(
435                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
436             )
437         elif p.is_file() or s == "-":
438             # if a file was explicitly given, we don't care about its extension
439             sources.add(p)
440         else:
441             err(f"invalid path: {s}")
442     if len(sources) == 0:
443         if verbose or not quiet:
444             out("No paths given. Nothing to do 😴")
445         ctx.exit(0)
446
447     if len(sources) == 1:
448         reformat_one(
449             src=sources.pop(),
450             fast=fast,
451             write_back=write_back,
452             mode=mode,
453             report=report,
454         )
455     else:
456         reformat_many(
457             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
458         )
459
460     if verbose or not quiet:
461         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
462         click.secho(str(report), err=True)
463     ctx.exit(report.return_code)
464
465
466 def reformat_one(
467     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
468 ) -> None:
469     """Reformat a single file under `src` without spawning child processes.
470
471     `fast`, `write_back`, and `mode` options are passed to
472     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
473     """
474     try:
475         changed = Changed.NO
476         if not src.is_file() and str(src) == "-":
477             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
478                 changed = Changed.YES
479         else:
480             cache: Cache = {}
481             if write_back != WriteBack.DIFF:
482                 cache = read_cache(mode)
483                 res_src = src.resolve()
484                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
485                     changed = Changed.CACHED
486             if changed is not Changed.CACHED and format_file_in_place(
487                 src, fast=fast, write_back=write_back, mode=mode
488             ):
489                 changed = Changed.YES
490             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
491                 write_back is WriteBack.CHECK and changed is Changed.NO
492             ):
493                 write_cache(cache, [src], mode)
494         report.done(src, changed)
495     except Exception as exc:
496         report.failed(src, str(exc))
497
498
499 def reformat_many(
500     sources: Set[Path],
501     fast: bool,
502     write_back: WriteBack,
503     mode: FileMode,
504     report: "Report",
505 ) -> None:
506     """Reformat multiple files using a ProcessPoolExecutor."""
507     loop = asyncio.get_event_loop()
508     worker_count = os.cpu_count()
509     if sys.platform == "win32":
510         # Work around https://bugs.python.org/issue26903
511         worker_count = min(worker_count, 61)
512     executor = ProcessPoolExecutor(max_workers=worker_count)
513     try:
514         loop.run_until_complete(
515             schedule_formatting(
516                 sources=sources,
517                 fast=fast,
518                 write_back=write_back,
519                 mode=mode,
520                 report=report,
521                 loop=loop,
522                 executor=executor,
523             )
524         )
525     finally:
526         shutdown(loop)
527
528
529 async def schedule_formatting(
530     sources: Set[Path],
531     fast: bool,
532     write_back: WriteBack,
533     mode: FileMode,
534     report: "Report",
535     loop: BaseEventLoop,
536     executor: Executor,
537 ) -> None:
538     """Run formatting of `sources` in parallel using the provided `executor`.
539
540     (Use ProcessPoolExecutors for actual parallelism.)
541
542     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
543     :func:`format_file_in_place`.
544     """
545     cache: Cache = {}
546     if write_back != WriteBack.DIFF:
547         cache = read_cache(mode)
548         sources, cached = filter_cached(cache, sources)
549         for src in sorted(cached):
550             report.done(src, Changed.CACHED)
551     if not sources:
552         return
553
554     cancelled = []
555     sources_to_cache = []
556     lock = None
557     if write_back == WriteBack.DIFF:
558         # For diff output, we need locks to ensure we don't interleave output
559         # from different processes.
560         manager = Manager()
561         lock = manager.Lock()
562     tasks = {
563         asyncio.ensure_future(
564             loop.run_in_executor(
565                 executor, format_file_in_place, src, fast, mode, write_back, lock
566             )
567         ): src
568         for src in sorted(sources)
569     }
570     pending: Iterable[asyncio.Future] = tasks.keys()
571     try:
572         loop.add_signal_handler(signal.SIGINT, cancel, pending)
573         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
574     except NotImplementedError:
575         # There are no good alternatives for these on Windows.
576         pass
577     while pending:
578         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
579         for task in done:
580             src = tasks.pop(task)
581             if task.cancelled():
582                 cancelled.append(task)
583             elif task.exception():
584                 report.failed(src, str(task.exception()))
585             else:
586                 changed = Changed.YES if task.result() else Changed.NO
587                 # If the file was written back or was successfully checked as
588                 # well-formatted, store this information in the cache.
589                 if write_back is WriteBack.YES or (
590                     write_back is WriteBack.CHECK and changed is Changed.NO
591                 ):
592                     sources_to_cache.append(src)
593                 report.done(src, changed)
594     if cancelled:
595         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
596     if sources_to_cache:
597         write_cache(cache, sources_to_cache, mode)
598
599
600 def format_file_in_place(
601     src: Path,
602     fast: bool,
603     mode: FileMode,
604     write_back: WriteBack = WriteBack.NO,
605     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
606 ) -> bool:
607     """Format file under `src` path. Return True if changed.
608
609     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
610     code to the file.
611     `mode` and `fast` options are passed to :func:`format_file_contents`.
612     """
613     if src.suffix == ".pyi":
614         mode = evolve(mode, is_pyi=True)
615
616     then = datetime.utcfromtimestamp(src.stat().st_mtime)
617     with open(src, "rb") as buf:
618         src_contents, encoding, newline = decode_bytes(buf.read())
619     try:
620         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
621     except NothingChanged:
622         return False
623
624     if write_back == write_back.YES:
625         with open(src, "w", encoding=encoding, newline=newline) as f:
626             f.write(dst_contents)
627     elif write_back == write_back.DIFF:
628         now = datetime.utcnow()
629         src_name = f"{src}\t{then} +0000"
630         dst_name = f"{src}\t{now} +0000"
631         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
632         if lock:
633             lock.acquire()
634         try:
635             f = io.TextIOWrapper(
636                 sys.stdout.buffer,
637                 encoding=encoding,
638                 newline=newline,
639                 write_through=True,
640             )
641             f.write(diff_contents)
642             f.detach()
643         finally:
644             if lock:
645                 lock.release()
646     return True
647
648
649 def format_stdin_to_stdout(
650     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
651 ) -> bool:
652     """Format file on stdin. Return True if changed.
653
654     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
655     write a diff to stdout. The `mode` argument is passed to
656     :func:`format_file_contents`.
657     """
658     then = datetime.utcnow()
659     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
660     dst = src
661     try:
662         dst = format_file_contents(src, fast=fast, mode=mode)
663         return True
664
665     except NothingChanged:
666         return False
667
668     finally:
669         f = io.TextIOWrapper(
670             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
671         )
672         if write_back == WriteBack.YES:
673             f.write(dst)
674         elif write_back == WriteBack.DIFF:
675             now = datetime.utcnow()
676             src_name = f"STDIN\t{then} +0000"
677             dst_name = f"STDOUT\t{now} +0000"
678             f.write(diff(src, dst, src_name, dst_name))
679         f.detach()
680
681
682 def format_file_contents(
683     src_contents: str, *, fast: bool, mode: FileMode
684 ) -> FileContent:
685     """Reformat contents a file and return new contents.
686
687     If `fast` is False, additionally confirm that the reformatted code is
688     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
689     `mode` is passed to :func:`format_str`.
690     """
691     if src_contents.strip() == "":
692         raise NothingChanged
693
694     dst_contents = format_str(src_contents, mode=mode)
695     if src_contents == dst_contents:
696         raise NothingChanged
697
698     if not fast:
699         assert_equivalent(src_contents, dst_contents)
700         assert_stable(src_contents, dst_contents, mode=mode)
701     return dst_contents
702
703
704 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
705     """Reformat a string and return new contents.
706
707     `mode` determines formatting options, such as how many characters per line are
708     allowed.
709     """
710     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
711     dst_contents = []
712     future_imports = get_future_imports(src_node)
713     if mode.target_versions:
714         versions = mode.target_versions
715     else:
716         versions = detect_target_versions(src_node)
717     normalize_fmt_off(src_node)
718     lines = LineGenerator(
719         remove_u_prefix="unicode_literals" in future_imports
720         or supports_feature(versions, Feature.UNICODE_LITERALS),
721         is_pyi=mode.is_pyi,
722         normalize_strings=mode.string_normalization,
723     )
724     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
725     empty_line = Line()
726     after = 0
727     split_line_features = {
728         feature
729         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
730         if supports_feature(versions, feature)
731     }
732     for current_line in lines.visit(src_node):
733         for _ in range(after):
734             dst_contents.append(str(empty_line))
735         before, after = elt.maybe_empty_lines(current_line)
736         for _ in range(before):
737             dst_contents.append(str(empty_line))
738         for line in split_line(
739             current_line, line_length=mode.line_length, features=split_line_features
740         ):
741             dst_contents.append(str(line))
742     return "".join(dst_contents)
743
744
745 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
746     """Return a tuple of (decoded_contents, encoding, newline).
747
748     `newline` is either CRLF or LF but `decoded_contents` is decoded with
749     universal newlines (i.e. only contains LF).
750     """
751     srcbuf = io.BytesIO(src)
752     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
753     if not lines:
754         return "", encoding, "\n"
755
756     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
757     srcbuf.seek(0)
758     with io.TextIOWrapper(srcbuf, encoding) as tiow:
759         return tiow.read(), encoding, newline
760
761
762 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
763     if not target_versions:
764         # No target_version specified, so try all grammars.
765         return [
766             # Python 3.7+
767             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
768             # Python 3.0-3.6
769             pygram.python_grammar_no_print_statement_no_exec_statement,
770             # Python 2.7 with future print_function import
771             pygram.python_grammar_no_print_statement,
772             # Python 2.7
773             pygram.python_grammar,
774         ]
775     elif all(version.is_python2() for version in target_versions):
776         # Python 2-only code, so try Python 2 grammars.
777         return [
778             # Python 2.7 with future print_function import
779             pygram.python_grammar_no_print_statement,
780             # Python 2.7
781             pygram.python_grammar,
782         ]
783     else:
784         # Python 3-compatible code, so only try Python 3 grammar.
785         grammars = []
786         # If we have to parse both, try to parse async as a keyword first
787         if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
788             # Python 3.7+
789             grammars.append(
790                 pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
791             )
792         if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
793             # Python 3.0-3.6
794             grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
795         # At least one of the above branches must have been taken, because every Python
796         # version has exactly one of the two 'ASYNC_*' flags
797         return grammars
798
799
800 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
801     """Given a string with source, return the lib2to3 Node."""
802     if src_txt[-1:] != "\n":
803         src_txt += "\n"
804
805     for grammar in get_grammars(set(target_versions)):
806         drv = driver.Driver(grammar, pytree.convert)
807         try:
808             result = drv.parse_string(src_txt, True)
809             break
810
811         except ParseError as pe:
812             lineno, column = pe.context[1]
813             lines = src_txt.splitlines()
814             try:
815                 faulty_line = lines[lineno - 1]
816             except IndexError:
817                 faulty_line = "<line number missing in source>"
818             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
819     else:
820         raise exc from None
821
822     if isinstance(result, Leaf):
823         result = Node(syms.file_input, [result])
824     return result
825
826
827 def lib2to3_unparse(node: Node) -> str:
828     """Given a lib2to3 node, return its string representation."""
829     code = str(node)
830     return code
831
832
833 T = TypeVar("T")
834
835
836 class Visitor(Generic[T]):
837     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
838
839     def visit(self, node: LN) -> Iterator[T]:
840         """Main method to visit `node` and its children.
841
842         It tries to find a `visit_*()` method for the given `node.type`, like
843         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
844         If no dedicated `visit_*()` method is found, chooses `visit_default()`
845         instead.
846
847         Then yields objects of type `T` from the selected visitor.
848         """
849         if node.type < 256:
850             name = token.tok_name[node.type]
851         else:
852             name = type_repr(node.type)
853         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
854
855     def visit_default(self, node: LN) -> Iterator[T]:
856         """Default `visit_*()` implementation. Recurses to children of `node`."""
857         if isinstance(node, Node):
858             for child in node.children:
859                 yield from self.visit(child)
860
861
862 @dataclass
863 class DebugVisitor(Visitor[T]):
864     tree_depth: int = 0
865
866     def visit_default(self, node: LN) -> Iterator[T]:
867         indent = " " * (2 * self.tree_depth)
868         if isinstance(node, Node):
869             _type = type_repr(node.type)
870             out(f"{indent}{_type}", fg="yellow")
871             self.tree_depth += 1
872             for child in node.children:
873                 yield from self.visit(child)
874
875             self.tree_depth -= 1
876             out(f"{indent}/{_type}", fg="yellow", bold=False)
877         else:
878             _type = token.tok_name.get(node.type, str(node.type))
879             out(f"{indent}{_type}", fg="blue", nl=False)
880             if node.prefix:
881                 # We don't have to handle prefixes for `Node` objects since
882                 # that delegates to the first child anyway.
883                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
884             out(f" {node.value!r}", fg="blue", bold=False)
885
886     @classmethod
887     def show(cls, code: Union[str, Leaf, Node]) -> None:
888         """Pretty-print the lib2to3 AST of a given string of `code`.
889
890         Convenience method for debugging.
891         """
892         v: DebugVisitor[None] = DebugVisitor()
893         if isinstance(code, str):
894             code = lib2to3_parse(code)
895         list(v.visit(code))
896
897
898 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
899 STATEMENT = {
900     syms.if_stmt,
901     syms.while_stmt,
902     syms.for_stmt,
903     syms.try_stmt,
904     syms.except_clause,
905     syms.with_stmt,
906     syms.funcdef,
907     syms.classdef,
908 }
909 STANDALONE_COMMENT = 153
910 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
911 LOGIC_OPERATORS = {"and", "or"}
912 COMPARATORS = {
913     token.LESS,
914     token.GREATER,
915     token.EQEQUAL,
916     token.NOTEQUAL,
917     token.LESSEQUAL,
918     token.GREATEREQUAL,
919 }
920 MATH_OPERATORS = {
921     token.VBAR,
922     token.CIRCUMFLEX,
923     token.AMPER,
924     token.LEFTSHIFT,
925     token.RIGHTSHIFT,
926     token.PLUS,
927     token.MINUS,
928     token.STAR,
929     token.SLASH,
930     token.DOUBLESLASH,
931     token.PERCENT,
932     token.AT,
933     token.TILDE,
934     token.DOUBLESTAR,
935 }
936 STARS = {token.STAR, token.DOUBLESTAR}
937 VARARGS_PARENTS = {
938     syms.arglist,
939     syms.argument,  # double star in arglist
940     syms.trailer,  # single argument to call
941     syms.typedargslist,
942     syms.varargslist,  # lambdas
943 }
944 UNPACKING_PARENTS = {
945     syms.atom,  # single element of a list or set literal
946     syms.dictsetmaker,
947     syms.listmaker,
948     syms.testlist_gexp,
949     syms.testlist_star_expr,
950 }
951 TEST_DESCENDANTS = {
952     syms.test,
953     syms.lambdef,
954     syms.or_test,
955     syms.and_test,
956     syms.not_test,
957     syms.comparison,
958     syms.star_expr,
959     syms.expr,
960     syms.xor_expr,
961     syms.and_expr,
962     syms.shift_expr,
963     syms.arith_expr,
964     syms.trailer,
965     syms.term,
966     syms.power,
967 }
968 ASSIGNMENTS = {
969     "=",
970     "+=",
971     "-=",
972     "*=",
973     "@=",
974     "/=",
975     "%=",
976     "&=",
977     "|=",
978     "^=",
979     "<<=",
980     ">>=",
981     "**=",
982     "//=",
983 }
984 COMPREHENSION_PRIORITY = 20
985 COMMA_PRIORITY = 18
986 TERNARY_PRIORITY = 16
987 LOGIC_PRIORITY = 14
988 STRING_PRIORITY = 12
989 COMPARATOR_PRIORITY = 10
990 MATH_PRIORITIES = {
991     token.VBAR: 9,
992     token.CIRCUMFLEX: 8,
993     token.AMPER: 7,
994     token.LEFTSHIFT: 6,
995     token.RIGHTSHIFT: 6,
996     token.PLUS: 5,
997     token.MINUS: 5,
998     token.STAR: 4,
999     token.SLASH: 4,
1000     token.DOUBLESLASH: 4,
1001     token.PERCENT: 4,
1002     token.AT: 4,
1003     token.TILDE: 3,
1004     token.DOUBLESTAR: 2,
1005 }
1006 DOT_PRIORITY = 1
1007
1008
1009 @dataclass
1010 class BracketTracker:
1011     """Keeps track of brackets on a line."""
1012
1013     depth: int = 0
1014     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
1015     delimiters: Dict[LeafID, Priority] = Factory(dict)
1016     previous: Optional[Leaf] = None
1017     _for_loop_depths: List[int] = Factory(list)
1018     _lambda_argument_depths: List[int] = Factory(list)
1019
1020     def mark(self, leaf: Leaf) -> None:
1021         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1022
1023         All leaves receive an int `bracket_depth` field that stores how deep
1024         within brackets a given leaf is. 0 means there are no enclosing brackets
1025         that started on this line.
1026
1027         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1028         field that it forms a pair with. This is a one-directional link to
1029         avoid reference cycles.
1030
1031         If a leaf is a delimiter (a token on which Black can split the line if
1032         needed) and it's on depth 0, its `id()` is stored in the tracker's
1033         `delimiters` field.
1034         """
1035         if leaf.type == token.COMMENT:
1036             return
1037
1038         self.maybe_decrement_after_for_loop_variable(leaf)
1039         self.maybe_decrement_after_lambda_arguments(leaf)
1040         if leaf.type in CLOSING_BRACKETS:
1041             self.depth -= 1
1042             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1043             leaf.opening_bracket = opening_bracket
1044         leaf.bracket_depth = self.depth
1045         if self.depth == 0:
1046             delim = is_split_before_delimiter(leaf, self.previous)
1047             if delim and self.previous is not None:
1048                 self.delimiters[id(self.previous)] = delim
1049             else:
1050                 delim = is_split_after_delimiter(leaf, self.previous)
1051                 if delim:
1052                     self.delimiters[id(leaf)] = delim
1053         if leaf.type in OPENING_BRACKETS:
1054             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1055             self.depth += 1
1056         self.previous = leaf
1057         self.maybe_increment_lambda_arguments(leaf)
1058         self.maybe_increment_for_loop_variable(leaf)
1059
1060     def any_open_brackets(self) -> bool:
1061         """Return True if there is an yet unmatched open bracket on the line."""
1062         return bool(self.bracket_match)
1063
1064     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1065         """Return the highest priority of a delimiter found on the line.
1066
1067         Values are consistent with what `is_split_*_delimiter()` return.
1068         Raises ValueError on no delimiters.
1069         """
1070         return max(v for k, v in self.delimiters.items() if k not in exclude)
1071
1072     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1073         """Return the number of delimiters with the given `priority`.
1074
1075         If no `priority` is passed, defaults to max priority on the line.
1076         """
1077         if not self.delimiters:
1078             return 0
1079
1080         priority = priority or self.max_delimiter_priority()
1081         return sum(1 for p in self.delimiters.values() if p == priority)
1082
1083     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1084         """In a for loop, or comprehension, the variables are often unpacks.
1085
1086         To avoid splitting on the comma in this situation, increase the depth of
1087         tokens between `for` and `in`.
1088         """
1089         if leaf.type == token.NAME and leaf.value == "for":
1090             self.depth += 1
1091             self._for_loop_depths.append(self.depth)
1092             return True
1093
1094         return False
1095
1096     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1097         """See `maybe_increment_for_loop_variable` above for explanation."""
1098         if (
1099             self._for_loop_depths
1100             and self._for_loop_depths[-1] == self.depth
1101             and leaf.type == token.NAME
1102             and leaf.value == "in"
1103         ):
1104             self.depth -= 1
1105             self._for_loop_depths.pop()
1106             return True
1107
1108         return False
1109
1110     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1111         """In a lambda expression, there might be more than one argument.
1112
1113         To avoid splitting on the comma in this situation, increase the depth of
1114         tokens between `lambda` and `:`.
1115         """
1116         if leaf.type == token.NAME and leaf.value == "lambda":
1117             self.depth += 1
1118             self._lambda_argument_depths.append(self.depth)
1119             return True
1120
1121         return False
1122
1123     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1124         """See `maybe_increment_lambda_arguments` above for explanation."""
1125         if (
1126             self._lambda_argument_depths
1127             and self._lambda_argument_depths[-1] == self.depth
1128             and leaf.type == token.COLON
1129         ):
1130             self.depth -= 1
1131             self._lambda_argument_depths.pop()
1132             return True
1133
1134         return False
1135
1136     def get_open_lsqb(self) -> Optional[Leaf]:
1137         """Return the most recent opening square bracket (if any)."""
1138         return self.bracket_match.get((self.depth - 1, token.RSQB))
1139
1140
1141 @dataclass
1142 class Line:
1143     """Holds leaves and comments. Can be printed with `str(line)`."""
1144
1145     depth: int = 0
1146     leaves: List[Leaf] = Factory(list)
1147     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1148     bracket_tracker: BracketTracker = Factory(BracketTracker)
1149     inside_brackets: bool = False
1150     should_explode: bool = False
1151
1152     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1153         """Add a new `leaf` to the end of the line.
1154
1155         Unless `preformatted` is True, the `leaf` will receive a new consistent
1156         whitespace prefix and metadata applied by :class:`BracketTracker`.
1157         Trailing commas are maybe removed, unpacked for loop variables are
1158         demoted from being delimiters.
1159
1160         Inline comments are put aside.
1161         """
1162         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1163         if not has_value:
1164             return
1165
1166         if token.COLON == leaf.type and self.is_class_paren_empty:
1167             del self.leaves[-2:]
1168         if self.leaves and not preformatted:
1169             # Note: at this point leaf.prefix should be empty except for
1170             # imports, for which we only preserve newlines.
1171             leaf.prefix += whitespace(
1172                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1173             )
1174         if self.inside_brackets or not preformatted:
1175             self.bracket_tracker.mark(leaf)
1176             self.maybe_remove_trailing_comma(leaf)
1177         if not self.append_comment(leaf):
1178             self.leaves.append(leaf)
1179
1180     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1181         """Like :func:`append()` but disallow invalid standalone comment structure.
1182
1183         Raises ValueError when any `leaf` is appended after a standalone comment
1184         or when a standalone comment is not the first leaf on the line.
1185         """
1186         if self.bracket_tracker.depth == 0:
1187             if self.is_comment:
1188                 raise ValueError("cannot append to standalone comments")
1189
1190             if self.leaves and leaf.type == STANDALONE_COMMENT:
1191                 raise ValueError(
1192                     "cannot append standalone comments to a populated line"
1193                 )
1194
1195         self.append(leaf, preformatted=preformatted)
1196
1197     @property
1198     def is_comment(self) -> bool:
1199         """Is this line a standalone comment?"""
1200         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1201
1202     @property
1203     def is_decorator(self) -> bool:
1204         """Is this line a decorator?"""
1205         return bool(self) and self.leaves[0].type == token.AT
1206
1207     @property
1208     def is_import(self) -> bool:
1209         """Is this an import line?"""
1210         return bool(self) and is_import(self.leaves[0])
1211
1212     @property
1213     def is_class(self) -> bool:
1214         """Is this line a class definition?"""
1215         return (
1216             bool(self)
1217             and self.leaves[0].type == token.NAME
1218             and self.leaves[0].value == "class"
1219         )
1220
1221     @property
1222     def is_stub_class(self) -> bool:
1223         """Is this line a class definition with a body consisting only of "..."?"""
1224         return self.is_class and self.leaves[-3:] == [
1225             Leaf(token.DOT, ".") for _ in range(3)
1226         ]
1227
1228     @property
1229     def is_def(self) -> bool:
1230         """Is this a function definition? (Also returns True for async defs.)"""
1231         try:
1232             first_leaf = self.leaves[0]
1233         except IndexError:
1234             return False
1235
1236         try:
1237             second_leaf: Optional[Leaf] = self.leaves[1]
1238         except IndexError:
1239             second_leaf = None
1240         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1241             first_leaf.type == token.ASYNC
1242             and second_leaf is not None
1243             and second_leaf.type == token.NAME
1244             and second_leaf.value == "def"
1245         )
1246
1247     @property
1248     def is_class_paren_empty(self) -> bool:
1249         """Is this a class with no base classes but using parentheses?
1250
1251         Those are unnecessary and should be removed.
1252         """
1253         return (
1254             bool(self)
1255             and len(self.leaves) == 4
1256             and self.is_class
1257             and self.leaves[2].type == token.LPAR
1258             and self.leaves[2].value == "("
1259             and self.leaves[3].type == token.RPAR
1260             and self.leaves[3].value == ")"
1261         )
1262
1263     @property
1264     def is_triple_quoted_string(self) -> bool:
1265         """Is the line a triple quoted string?"""
1266         return (
1267             bool(self)
1268             and self.leaves[0].type == token.STRING
1269             and self.leaves[0].value.startswith(('"""', "'''"))
1270         )
1271
1272     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1273         """If so, needs to be split before emitting."""
1274         for leaf in self.leaves:
1275             if leaf.type == STANDALONE_COMMENT:
1276                 if leaf.bracket_depth <= depth_limit:
1277                     return True
1278         return False
1279
1280     def contains_inner_type_comments(self) -> bool:
1281         ignored_ids = set()
1282         try:
1283             last_leaf = self.leaves[-1]
1284             ignored_ids.add(id(last_leaf))
1285             if last_leaf.type == token.COMMA or (
1286                 last_leaf.type == token.RPAR and not last_leaf.value
1287             ):
1288                 # When trailing commas or optional parens are inserted by Black for
1289                 # consistency, comments after the previous last element are not moved
1290                 # (they don't have to, rendering will still be correct).  So we ignore
1291                 # trailing commas and invisible.
1292                 last_leaf = self.leaves[-2]
1293                 ignored_ids.add(id(last_leaf))
1294         except IndexError:
1295             return False
1296
1297         for leaf_id, comments in self.comments.items():
1298             if leaf_id in ignored_ids:
1299                 continue
1300
1301             for comment in comments:
1302                 if is_type_comment(comment):
1303                     return True
1304
1305         return False
1306
1307     def contains_multiline_strings(self) -> bool:
1308         for leaf in self.leaves:
1309             if is_multiline_string(leaf):
1310                 return True
1311
1312         return False
1313
1314     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1315         """Remove trailing comma if there is one and it's safe."""
1316         if not (
1317             self.leaves
1318             and self.leaves[-1].type == token.COMMA
1319             and closing.type in CLOSING_BRACKETS
1320         ):
1321             return False
1322
1323         if closing.type == token.RBRACE:
1324             self.remove_trailing_comma()
1325             return True
1326
1327         if closing.type == token.RSQB:
1328             comma = self.leaves[-1]
1329             if comma.parent and comma.parent.type == syms.listmaker:
1330                 self.remove_trailing_comma()
1331                 return True
1332
1333         # For parens let's check if it's safe to remove the comma.
1334         # Imports are always safe.
1335         if self.is_import:
1336             self.remove_trailing_comma()
1337             return True
1338
1339         # Otherwise, if the trailing one is the only one, we might mistakenly
1340         # change a tuple into a different type by removing the comma.
1341         depth = closing.bracket_depth + 1
1342         commas = 0
1343         opening = closing.opening_bracket
1344         for _opening_index, leaf in enumerate(self.leaves):
1345             if leaf is opening:
1346                 break
1347
1348         else:
1349             return False
1350
1351         for leaf in self.leaves[_opening_index + 1 :]:
1352             if leaf is closing:
1353                 break
1354
1355             bracket_depth = leaf.bracket_depth
1356             if bracket_depth == depth and leaf.type == token.COMMA:
1357                 commas += 1
1358                 if leaf.parent and leaf.parent.type in {
1359                     syms.arglist,
1360                     syms.typedargslist,
1361                 }:
1362                     commas += 1
1363                     break
1364
1365         if commas > 1:
1366             self.remove_trailing_comma()
1367             return True
1368
1369         return False
1370
1371     def append_comment(self, comment: Leaf) -> bool:
1372         """Add an inline or standalone comment to the line."""
1373         if (
1374             comment.type == STANDALONE_COMMENT
1375             and self.bracket_tracker.any_open_brackets()
1376         ):
1377             comment.prefix = ""
1378             return False
1379
1380         if comment.type != token.COMMENT:
1381             return False
1382
1383         if not self.leaves:
1384             comment.type = STANDALONE_COMMENT
1385             comment.prefix = ""
1386             return False
1387
1388         last_leaf = self.leaves[-1]
1389         if (
1390             last_leaf.type == token.RPAR
1391             and not last_leaf.value
1392             and last_leaf.parent
1393             and len(list(last_leaf.parent.leaves())) <= 3
1394             and not is_type_comment(comment)
1395         ):
1396             # Comments on an optional parens wrapping a single leaf should belong to
1397             # the wrapped node except if it's a type comment. Pinning the comment like
1398             # this avoids unstable formatting caused by comment migration.
1399             if len(self.leaves) < 2:
1400                 comment.type = STANDALONE_COMMENT
1401                 comment.prefix = ""
1402                 return False
1403             last_leaf = self.leaves[-2]
1404         self.comments.setdefault(id(last_leaf), []).append(comment)
1405         return True
1406
1407     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1408         """Generate comments that should appear directly after `leaf`."""
1409         return self.comments.get(id(leaf), [])
1410
1411     def remove_trailing_comma(self) -> None:
1412         """Remove the trailing comma and moves the comments attached to it."""
1413         trailing_comma = self.leaves.pop()
1414         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1415         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1416             trailing_comma_comments
1417         )
1418
1419     def is_complex_subscript(self, leaf: Leaf) -> bool:
1420         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1421         open_lsqb = self.bracket_tracker.get_open_lsqb()
1422         if open_lsqb is None:
1423             return False
1424
1425         subscript_start = open_lsqb.next_sibling
1426
1427         if isinstance(subscript_start, Node):
1428             if subscript_start.type == syms.listmaker:
1429                 return False
1430
1431             if subscript_start.type == syms.subscriptlist:
1432                 subscript_start = child_towards(subscript_start, leaf)
1433         return subscript_start is not None and any(
1434             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1435         )
1436
1437     def __str__(self) -> str:
1438         """Render the line."""
1439         if not self:
1440             return "\n"
1441
1442         indent = "    " * self.depth
1443         leaves = iter(self.leaves)
1444         first = next(leaves)
1445         res = f"{first.prefix}{indent}{first.value}"
1446         for leaf in leaves:
1447             res += str(leaf)
1448         for comment in itertools.chain.from_iterable(self.comments.values()):
1449             res += str(comment)
1450         return res + "\n"
1451
1452     def __bool__(self) -> bool:
1453         """Return True if the line has leaves or comments."""
1454         return bool(self.leaves or self.comments)
1455
1456
1457 @dataclass
1458 class EmptyLineTracker:
1459     """Provides a stateful method that returns the number of potential extra
1460     empty lines needed before and after the currently processed line.
1461
1462     Note: this tracker works on lines that haven't been split yet.  It assumes
1463     the prefix of the first leaf consists of optional newlines.  Those newlines
1464     are consumed by `maybe_empty_lines()` and included in the computation.
1465     """
1466
1467     is_pyi: bool = False
1468     previous_line: Optional[Line] = None
1469     previous_after: int = 0
1470     previous_defs: List[int] = Factory(list)
1471
1472     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1473         """Return the number of extra empty lines before and after the `current_line`.
1474
1475         This is for separating `def`, `async def` and `class` with extra empty
1476         lines (two on module-level).
1477         """
1478         before, after = self._maybe_empty_lines(current_line)
1479         before -= self.previous_after
1480         self.previous_after = after
1481         self.previous_line = current_line
1482         return before, after
1483
1484     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1485         max_allowed = 1
1486         if current_line.depth == 0:
1487             max_allowed = 1 if self.is_pyi else 2
1488         if current_line.leaves:
1489             # Consume the first leaf's extra newlines.
1490             first_leaf = current_line.leaves[0]
1491             before = first_leaf.prefix.count("\n")
1492             before = min(before, max_allowed)
1493             first_leaf.prefix = ""
1494         else:
1495             before = 0
1496         depth = current_line.depth
1497         while self.previous_defs and self.previous_defs[-1] >= depth:
1498             self.previous_defs.pop()
1499             if self.is_pyi:
1500                 before = 0 if depth else 1
1501             else:
1502                 before = 1 if depth else 2
1503         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1504             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1505
1506         if (
1507             self.previous_line
1508             and self.previous_line.is_import
1509             and not current_line.is_import
1510             and depth == self.previous_line.depth
1511         ):
1512             return (before or 1), 0
1513
1514         if (
1515             self.previous_line
1516             and self.previous_line.is_class
1517             and current_line.is_triple_quoted_string
1518         ):
1519             return before, 1
1520
1521         return before, 0
1522
1523     def _maybe_empty_lines_for_class_or_def(
1524         self, current_line: Line, before: int
1525     ) -> Tuple[int, int]:
1526         if not current_line.is_decorator:
1527             self.previous_defs.append(current_line.depth)
1528         if self.previous_line is None:
1529             # Don't insert empty lines before the first line in the file.
1530             return 0, 0
1531
1532         if self.previous_line.is_decorator:
1533             return 0, 0
1534
1535         if self.previous_line.depth < current_line.depth and (
1536             self.previous_line.is_class or self.previous_line.is_def
1537         ):
1538             return 0, 0
1539
1540         if (
1541             self.previous_line.is_comment
1542             and self.previous_line.depth == current_line.depth
1543             and before == 0
1544         ):
1545             return 0, 0
1546
1547         if self.is_pyi:
1548             if self.previous_line.depth > current_line.depth:
1549                 newlines = 1
1550             elif current_line.is_class or self.previous_line.is_class:
1551                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1552                     # No blank line between classes with an empty body
1553                     newlines = 0
1554                 else:
1555                     newlines = 1
1556             elif current_line.is_def and not self.previous_line.is_def:
1557                 # Blank line between a block of functions and a block of non-functions
1558                 newlines = 1
1559             else:
1560                 newlines = 0
1561         else:
1562             newlines = 2
1563         if current_line.depth and newlines:
1564             newlines -= 1
1565         return newlines, 0
1566
1567
1568 @dataclass
1569 class LineGenerator(Visitor[Line]):
1570     """Generates reformatted Line objects.  Empty lines are not emitted.
1571
1572     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1573     in ways that will no longer stringify to valid Python code on the tree.
1574     """
1575
1576     is_pyi: bool = False
1577     normalize_strings: bool = True
1578     current_line: Line = Factory(Line)
1579     remove_u_prefix: bool = False
1580
1581     def line(self, indent: int = 0) -> Iterator[Line]:
1582         """Generate a line.
1583
1584         If the line is empty, only emit if it makes sense.
1585         If the line is too long, split it first and then generate.
1586
1587         If any lines were generated, set up a new current_line.
1588         """
1589         if not self.current_line:
1590             self.current_line.depth += indent
1591             return  # Line is empty, don't emit. Creating a new one unnecessary.
1592
1593         complete_line = self.current_line
1594         self.current_line = Line(depth=complete_line.depth + indent)
1595         yield complete_line
1596
1597     def visit_default(self, node: LN) -> Iterator[Line]:
1598         """Default `visit_*()` implementation. Recurses to children of `node`."""
1599         if isinstance(node, Leaf):
1600             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1601             for comment in generate_comments(node):
1602                 if any_open_brackets:
1603                     # any comment within brackets is subject to splitting
1604                     self.current_line.append(comment)
1605                 elif comment.type == token.COMMENT:
1606                     # regular trailing comment
1607                     self.current_line.append(comment)
1608                     yield from self.line()
1609
1610                 else:
1611                     # regular standalone comment
1612                     yield from self.line()
1613
1614                     self.current_line.append(comment)
1615                     yield from self.line()
1616
1617             normalize_prefix(node, inside_brackets=any_open_brackets)
1618             if self.normalize_strings and node.type == token.STRING:
1619                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1620                 normalize_string_quotes(node)
1621             if node.type == token.NUMBER:
1622                 normalize_numeric_literal(node)
1623             if node.type not in WHITESPACE:
1624                 self.current_line.append(node)
1625         yield from super().visit_default(node)
1626
1627     def visit_atom(self, node: Node) -> Iterator[Line]:
1628         # Always make parentheses invisible around a single node, because it should
1629         # not be needed (except in the case of yield, where removing the parentheses
1630         # produces a SyntaxError).
1631         if (
1632             len(node.children) == 3
1633             and isinstance(node.children[0], Leaf)
1634             and node.children[0].type == token.LPAR
1635             and isinstance(node.children[2], Leaf)
1636             and node.children[2].type == token.RPAR
1637             and isinstance(node.children[1], Leaf)
1638             and not (
1639                 node.children[1].type == token.NAME
1640                 and node.children[1].value == "yield"
1641             )
1642         ):
1643             node.children[0].value = ""
1644             node.children[2].value = ""
1645         yield from super().visit_default(node)
1646
1647     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1648         """Increase indentation level, maybe yield a line."""
1649         # In blib2to3 INDENT never holds comments.
1650         yield from self.line(+1)
1651         yield from self.visit_default(node)
1652
1653     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1654         """Decrease indentation level, maybe yield a line."""
1655         # The current line might still wait for trailing comments.  At DEDENT time
1656         # there won't be any (they would be prefixes on the preceding NEWLINE).
1657         # Emit the line then.
1658         yield from self.line()
1659
1660         # While DEDENT has no value, its prefix may contain standalone comments
1661         # that belong to the current indentation level.  Get 'em.
1662         yield from self.visit_default(node)
1663
1664         # Finally, emit the dedent.
1665         yield from self.line(-1)
1666
1667     def visit_stmt(
1668         self, node: Node, keywords: Set[str], parens: Set[str]
1669     ) -> Iterator[Line]:
1670         """Visit a statement.
1671
1672         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1673         `def`, `with`, `class`, `assert` and assignments.
1674
1675         The relevant Python language `keywords` for a given statement will be
1676         NAME leaves within it. This methods puts those on a separate line.
1677
1678         `parens` holds a set of string leaf values immediately after which
1679         invisible parens should be put.
1680         """
1681         normalize_invisible_parens(node, parens_after=parens)
1682         for child in node.children:
1683             if child.type == token.NAME and child.value in keywords:  # type: ignore
1684                 yield from self.line()
1685
1686             yield from self.visit(child)
1687
1688     def visit_suite(self, node: Node) -> Iterator[Line]:
1689         """Visit a suite."""
1690         if self.is_pyi and is_stub_suite(node):
1691             yield from self.visit(node.children[2])
1692         else:
1693             yield from self.visit_default(node)
1694
1695     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1696         """Visit a statement without nested statements."""
1697         is_suite_like = node.parent and node.parent.type in STATEMENT
1698         if is_suite_like:
1699             if self.is_pyi and is_stub_body(node):
1700                 yield from self.visit_default(node)
1701             else:
1702                 yield from self.line(+1)
1703                 yield from self.visit_default(node)
1704                 yield from self.line(-1)
1705
1706         else:
1707             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1708                 yield from self.line()
1709             yield from self.visit_default(node)
1710
1711     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1712         """Visit `async def`, `async for`, `async with`."""
1713         yield from self.line()
1714
1715         children = iter(node.children)
1716         for child in children:
1717             yield from self.visit(child)
1718
1719             if child.type == token.ASYNC:
1720                 break
1721
1722         internal_stmt = next(children)
1723         for child in internal_stmt.children:
1724             yield from self.visit(child)
1725
1726     def visit_decorators(self, node: Node) -> Iterator[Line]:
1727         """Visit decorators."""
1728         for child in node.children:
1729             yield from self.line()
1730             yield from self.visit(child)
1731
1732     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1733         """Remove a semicolon and put the other statement on a separate line."""
1734         yield from self.line()
1735
1736     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1737         """End of file. Process outstanding comments and end with a newline."""
1738         yield from self.visit_default(leaf)
1739         yield from self.line()
1740
1741     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1742         if not self.current_line.bracket_tracker.any_open_brackets():
1743             yield from self.line()
1744         yield from self.visit_default(leaf)
1745
1746     def __attrs_post_init__(self) -> None:
1747         """You are in a twisty little maze of passages."""
1748         v = self.visit_stmt
1749         Ø: Set[str] = set()
1750         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1751         self.visit_if_stmt = partial(
1752             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1753         )
1754         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1755         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1756         self.visit_try_stmt = partial(
1757             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1758         )
1759         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1760         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1761         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1762         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1763         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1764         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1765         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1766         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1767         self.visit_async_funcdef = self.visit_async_stmt
1768         self.visit_decorated = self.visit_decorators
1769
1770
1771 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1772 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1773 OPENING_BRACKETS = set(BRACKET.keys())
1774 CLOSING_BRACKETS = set(BRACKET.values())
1775 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1776 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1777
1778
1779 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1780     """Return whitespace prefix if needed for the given `leaf`.
1781
1782     `complex_subscript` signals whether the given leaf is part of a subscription
1783     which has non-trivial arguments, like arithmetic expressions or function calls.
1784     """
1785     NO = ""
1786     SPACE = " "
1787     DOUBLESPACE = "  "
1788     t = leaf.type
1789     p = leaf.parent
1790     v = leaf.value
1791     if t in ALWAYS_NO_SPACE:
1792         return NO
1793
1794     if t == token.COMMENT:
1795         return DOUBLESPACE
1796
1797     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1798     if t == token.COLON and p.type not in {
1799         syms.subscript,
1800         syms.subscriptlist,
1801         syms.sliceop,
1802     }:
1803         return NO
1804
1805     prev = leaf.prev_sibling
1806     if not prev:
1807         prevp = preceding_leaf(p)
1808         if not prevp or prevp.type in OPENING_BRACKETS:
1809             return NO
1810
1811         if t == token.COLON:
1812             if prevp.type == token.COLON:
1813                 return NO
1814
1815             elif prevp.type != token.COMMA and not complex_subscript:
1816                 return NO
1817
1818             return SPACE
1819
1820         if prevp.type == token.EQUAL:
1821             if prevp.parent:
1822                 if prevp.parent.type in {
1823                     syms.arglist,
1824                     syms.argument,
1825                     syms.parameters,
1826                     syms.varargslist,
1827                 }:
1828                     return NO
1829
1830                 elif prevp.parent.type == syms.typedargslist:
1831                     # A bit hacky: if the equal sign has whitespace, it means we
1832                     # previously found it's a typed argument.  So, we're using
1833                     # that, too.
1834                     return prevp.prefix
1835
1836         elif prevp.type in STARS:
1837             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1838                 return NO
1839
1840         elif prevp.type == token.COLON:
1841             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1842                 return SPACE if complex_subscript else NO
1843
1844         elif (
1845             prevp.parent
1846             and prevp.parent.type == syms.factor
1847             and prevp.type in MATH_OPERATORS
1848         ):
1849             return NO
1850
1851         elif (
1852             prevp.type == token.RIGHTSHIFT
1853             and prevp.parent
1854             and prevp.parent.type == syms.shift_expr
1855             and prevp.prev_sibling
1856             and prevp.prev_sibling.type == token.NAME
1857             and prevp.prev_sibling.value == "print"  # type: ignore
1858         ):
1859             # Python 2 print chevron
1860             return NO
1861
1862     elif prev.type in OPENING_BRACKETS:
1863         return NO
1864
1865     if p.type in {syms.parameters, syms.arglist}:
1866         # untyped function signatures or calls
1867         if not prev or prev.type != token.COMMA:
1868             return NO
1869
1870     elif p.type == syms.varargslist:
1871         # lambdas
1872         if prev and prev.type != token.COMMA:
1873             return NO
1874
1875     elif p.type == syms.typedargslist:
1876         # typed function signatures
1877         if not prev:
1878             return NO
1879
1880         if t == token.EQUAL:
1881             if prev.type != syms.tname:
1882                 return NO
1883
1884         elif prev.type == token.EQUAL:
1885             # A bit hacky: if the equal sign has whitespace, it means we
1886             # previously found it's a typed argument.  So, we're using that, too.
1887             return prev.prefix
1888
1889         elif prev.type != token.COMMA:
1890             return NO
1891
1892     elif p.type == syms.tname:
1893         # type names
1894         if not prev:
1895             prevp = preceding_leaf(p)
1896             if not prevp or prevp.type != token.COMMA:
1897                 return NO
1898
1899     elif p.type == syms.trailer:
1900         # attributes and calls
1901         if t == token.LPAR or t == token.RPAR:
1902             return NO
1903
1904         if not prev:
1905             if t == token.DOT:
1906                 prevp = preceding_leaf(p)
1907                 if not prevp or prevp.type != token.NUMBER:
1908                     return NO
1909
1910             elif t == token.LSQB:
1911                 return NO
1912
1913         elif prev.type != token.COMMA:
1914             return NO
1915
1916     elif p.type == syms.argument:
1917         # single argument
1918         if t == token.EQUAL:
1919             return NO
1920
1921         if not prev:
1922             prevp = preceding_leaf(p)
1923             if not prevp or prevp.type == token.LPAR:
1924                 return NO
1925
1926         elif prev.type in {token.EQUAL} | STARS:
1927             return NO
1928
1929     elif p.type == syms.decorator:
1930         # decorators
1931         return NO
1932
1933     elif p.type == syms.dotted_name:
1934         if prev:
1935             return NO
1936
1937         prevp = preceding_leaf(p)
1938         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1939             return NO
1940
1941     elif p.type == syms.classdef:
1942         if t == token.LPAR:
1943             return NO
1944
1945         if prev and prev.type == token.LPAR:
1946             return NO
1947
1948     elif p.type in {syms.subscript, syms.sliceop}:
1949         # indexing
1950         if not prev:
1951             assert p.parent is not None, "subscripts are always parented"
1952             if p.parent.type == syms.subscriptlist:
1953                 return SPACE
1954
1955             return NO
1956
1957         elif not complex_subscript:
1958             return NO
1959
1960     elif p.type == syms.atom:
1961         if prev and t == token.DOT:
1962             # dots, but not the first one.
1963             return NO
1964
1965     elif p.type == syms.dictsetmaker:
1966         # dict unpacking
1967         if prev and prev.type == token.DOUBLESTAR:
1968             return NO
1969
1970     elif p.type in {syms.factor, syms.star_expr}:
1971         # unary ops
1972         if not prev:
1973             prevp = preceding_leaf(p)
1974             if not prevp or prevp.type in OPENING_BRACKETS:
1975                 return NO
1976
1977             prevp_parent = prevp.parent
1978             assert prevp_parent is not None
1979             if prevp.type == token.COLON and prevp_parent.type in {
1980                 syms.subscript,
1981                 syms.sliceop,
1982             }:
1983                 return NO
1984
1985             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1986                 return NO
1987
1988         elif t in {token.NAME, token.NUMBER, token.STRING}:
1989             return NO
1990
1991     elif p.type == syms.import_from:
1992         if t == token.DOT:
1993             if prev and prev.type == token.DOT:
1994                 return NO
1995
1996         elif t == token.NAME:
1997             if v == "import":
1998                 return SPACE
1999
2000             if prev and prev.type == token.DOT:
2001                 return NO
2002
2003     elif p.type == syms.sliceop:
2004         return NO
2005
2006     return SPACE
2007
2008
2009 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2010     """Return the first leaf that precedes `node`, if any."""
2011     while node:
2012         res = node.prev_sibling
2013         if res:
2014             if isinstance(res, Leaf):
2015                 return res
2016
2017             try:
2018                 return list(res.leaves())[-1]
2019
2020             except IndexError:
2021                 return None
2022
2023         node = node.parent
2024     return None
2025
2026
2027 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2028     """Return the child of `ancestor` that contains `descendant`."""
2029     node: Optional[LN] = descendant
2030     while node and node.parent != ancestor:
2031         node = node.parent
2032     return node
2033
2034
2035 def container_of(leaf: Leaf) -> LN:
2036     """Return `leaf` or one of its ancestors that is the topmost container of it.
2037
2038     By "container" we mean a node where `leaf` is the very first child.
2039     """
2040     same_prefix = leaf.prefix
2041     container: LN = leaf
2042     while container:
2043         parent = container.parent
2044         if parent is None:
2045             break
2046
2047         if parent.children[0].prefix != same_prefix:
2048             break
2049
2050         if parent.type == syms.file_input:
2051             break
2052
2053         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2054             break
2055
2056         container = parent
2057     return container
2058
2059
2060 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2061     """Return the priority of the `leaf` delimiter, given a line break after it.
2062
2063     The delimiter priorities returned here are from those delimiters that would
2064     cause a line break after themselves.
2065
2066     Higher numbers are higher priority.
2067     """
2068     if leaf.type == token.COMMA:
2069         return COMMA_PRIORITY
2070
2071     return 0
2072
2073
2074 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2075     """Return the priority of the `leaf` delimiter, given a line break before it.
2076
2077     The delimiter priorities returned here are from those delimiters that would
2078     cause a line break before themselves.
2079
2080     Higher numbers are higher priority.
2081     """
2082     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2083         # * and ** might also be MATH_OPERATORS but in this case they are not.
2084         # Don't treat them as a delimiter.
2085         return 0
2086
2087     if (
2088         leaf.type == token.DOT
2089         and leaf.parent
2090         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2091         and (previous is None or previous.type in CLOSING_BRACKETS)
2092     ):
2093         return DOT_PRIORITY
2094
2095     if (
2096         leaf.type in MATH_OPERATORS
2097         and leaf.parent
2098         and leaf.parent.type not in {syms.factor, syms.star_expr}
2099     ):
2100         return MATH_PRIORITIES[leaf.type]
2101
2102     if leaf.type in COMPARATORS:
2103         return COMPARATOR_PRIORITY
2104
2105     if (
2106         leaf.type == token.STRING
2107         and previous is not None
2108         and previous.type == token.STRING
2109     ):
2110         return STRING_PRIORITY
2111
2112     if leaf.type not in {token.NAME, token.ASYNC}:
2113         return 0
2114
2115     if (
2116         leaf.value == "for"
2117         and leaf.parent
2118         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2119         or leaf.type == token.ASYNC
2120     ):
2121         if (
2122             not isinstance(leaf.prev_sibling, Leaf)
2123             or leaf.prev_sibling.value != "async"
2124         ):
2125             return COMPREHENSION_PRIORITY
2126
2127     if (
2128         leaf.value == "if"
2129         and leaf.parent
2130         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2131     ):
2132         return COMPREHENSION_PRIORITY
2133
2134     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2135         return TERNARY_PRIORITY
2136
2137     if leaf.value == "is":
2138         return COMPARATOR_PRIORITY
2139
2140     if (
2141         leaf.value == "in"
2142         and leaf.parent
2143         and leaf.parent.type in {syms.comp_op, syms.comparison}
2144         and not (
2145             previous is not None
2146             and previous.type == token.NAME
2147             and previous.value == "not"
2148         )
2149     ):
2150         return COMPARATOR_PRIORITY
2151
2152     if (
2153         leaf.value == "not"
2154         and leaf.parent
2155         and leaf.parent.type == syms.comp_op
2156         and not (
2157             previous is not None
2158             and previous.type == token.NAME
2159             and previous.value == "is"
2160         )
2161     ):
2162         return COMPARATOR_PRIORITY
2163
2164     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2165         return LOGIC_PRIORITY
2166
2167     return 0
2168
2169
2170 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2171 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2172
2173
2174 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2175     """Clean the prefix of the `leaf` and generate comments from it, if any.
2176
2177     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2178     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2179     move because it does away with modifying the grammar to include all the
2180     possible places in which comments can be placed.
2181
2182     The sad consequence for us though is that comments don't "belong" anywhere.
2183     This is why this function generates simple parentless Leaf objects for
2184     comments.  We simply don't know what the correct parent should be.
2185
2186     No matter though, we can live without this.  We really only need to
2187     differentiate between inline and standalone comments.  The latter don't
2188     share the line with any code.
2189
2190     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2191     are emitted with a fake STANDALONE_COMMENT token identifier.
2192     """
2193     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2194         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2195
2196
2197 @dataclass
2198 class ProtoComment:
2199     """Describes a piece of syntax that is a comment.
2200
2201     It's not a :class:`blib2to3.pytree.Leaf` so that:
2202
2203     * it can be cached (`Leaf` objects should not be reused more than once as
2204       they store their lineno, column, prefix, and parent information);
2205     * `newlines` and `consumed` fields are kept separate from the `value`. This
2206       simplifies handling of special marker comments like ``# fmt: off/on``.
2207     """
2208
2209     type: int  # token.COMMENT or STANDALONE_COMMENT
2210     value: str  # content of the comment
2211     newlines: int  # how many newlines before the comment
2212     consumed: int  # how many characters of the original leaf's prefix did we consume
2213
2214
2215 @lru_cache(maxsize=4096)
2216 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2217     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2218     result: List[ProtoComment] = []
2219     if not prefix or "#" not in prefix:
2220         return result
2221
2222     consumed = 0
2223     nlines = 0
2224     ignored_lines = 0
2225     for index, line in enumerate(prefix.split("\n")):
2226         consumed += len(line) + 1  # adding the length of the split '\n'
2227         line = line.lstrip()
2228         if not line:
2229             nlines += 1
2230         if not line.startswith("#"):
2231             # Escaped newlines outside of a comment are not really newlines at
2232             # all. We treat a single-line comment following an escaped newline
2233             # as a simple trailing comment.
2234             if line.endswith("\\"):
2235                 ignored_lines += 1
2236             continue
2237
2238         if index == ignored_lines and not is_endmarker:
2239             comment_type = token.COMMENT  # simple trailing comment
2240         else:
2241             comment_type = STANDALONE_COMMENT
2242         comment = make_comment(line)
2243         result.append(
2244             ProtoComment(
2245                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2246             )
2247         )
2248         nlines = 0
2249     return result
2250
2251
2252 def make_comment(content: str) -> str:
2253     """Return a consistently formatted comment from the given `content` string.
2254
2255     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2256     space between the hash sign and the content.
2257
2258     If `content` didn't start with a hash sign, one is provided.
2259     """
2260     content = content.rstrip()
2261     if not content:
2262         return "#"
2263
2264     if content[0] == "#":
2265         content = content[1:]
2266     if content and content[0] not in " !:#'%":
2267         content = " " + content
2268     return "#" + content
2269
2270
2271 def split_line(
2272     line: Line,
2273     line_length: int,
2274     inner: bool = False,
2275     features: Collection[Feature] = (),
2276 ) -> Iterator[Line]:
2277     """Split a `line` into potentially many lines.
2278
2279     They should fit in the allotted `line_length` but might not be able to.
2280     `inner` signifies that there were a pair of brackets somewhere around the
2281     current `line`, possibly transitively. This means we can fallback to splitting
2282     by delimiters if the LHS/RHS don't yield any results.
2283
2284     `features` are syntactical features that may be used in the output.
2285     """
2286     if line.is_comment:
2287         yield line
2288         return
2289
2290     line_str = str(line).strip("\n")
2291
2292     if (
2293         not line.contains_inner_type_comments()
2294         and not line.should_explode
2295         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2296     ):
2297         yield line
2298         return
2299
2300     split_funcs: List[SplitFunc]
2301     if line.is_def:
2302         split_funcs = [left_hand_split]
2303     else:
2304
2305         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2306             for omit in generate_trailers_to_omit(line, line_length):
2307                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2308                 if is_line_short_enough(lines[0], line_length=line_length):
2309                     yield from lines
2310                     return
2311
2312             # All splits failed, best effort split with no omits.
2313             # This mostly happens to multiline strings that are by definition
2314             # reported as not fitting a single line.
2315             yield from right_hand_split(line, line_length, features=features)
2316
2317         if line.inside_brackets:
2318             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2319         else:
2320             split_funcs = [rhs]
2321     for split_func in split_funcs:
2322         # We are accumulating lines in `result` because we might want to abort
2323         # mission and return the original line in the end, or attempt a different
2324         # split altogether.
2325         result: List[Line] = []
2326         try:
2327             for l in split_func(line, features):
2328                 if str(l).strip("\n") == line_str:
2329                     raise CannotSplit("Split function returned an unchanged result")
2330
2331                 result.extend(
2332                     split_line(
2333                         l, line_length=line_length, inner=True, features=features
2334                     )
2335                 )
2336         except CannotSplit:
2337             continue
2338
2339         else:
2340             yield from result
2341             break
2342
2343     else:
2344         yield line
2345
2346
2347 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2348     """Split line into many lines, starting with the first matching bracket pair.
2349
2350     Note: this usually looks weird, only use this for function definitions.
2351     Prefer RHS otherwise.  This is why this function is not symmetrical with
2352     :func:`right_hand_split` which also handles optional parentheses.
2353     """
2354     tail_leaves: List[Leaf] = []
2355     body_leaves: List[Leaf] = []
2356     head_leaves: List[Leaf] = []
2357     current_leaves = head_leaves
2358     matching_bracket = None
2359     for leaf in line.leaves:
2360         if (
2361             current_leaves is body_leaves
2362             and leaf.type in CLOSING_BRACKETS
2363             and leaf.opening_bracket is matching_bracket
2364         ):
2365             current_leaves = tail_leaves if body_leaves else head_leaves
2366         current_leaves.append(leaf)
2367         if current_leaves is head_leaves:
2368             if leaf.type in OPENING_BRACKETS:
2369                 matching_bracket = leaf
2370                 current_leaves = body_leaves
2371     if not matching_bracket:
2372         raise CannotSplit("No brackets found")
2373
2374     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2375     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2376     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2377     bracket_split_succeeded_or_raise(head, body, tail)
2378     for result in (head, body, tail):
2379         if result:
2380             yield result
2381
2382
2383 def right_hand_split(
2384     line: Line,
2385     line_length: int,
2386     features: Collection[Feature] = (),
2387     omit: Collection[LeafID] = (),
2388 ) -> Iterator[Line]:
2389     """Split line into many lines, starting with the last matching bracket pair.
2390
2391     If the split was by optional parentheses, attempt splitting without them, too.
2392     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2393     this split.
2394
2395     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2396     """
2397     tail_leaves: List[Leaf] = []
2398     body_leaves: List[Leaf] = []
2399     head_leaves: List[Leaf] = []
2400     current_leaves = tail_leaves
2401     opening_bracket = None
2402     closing_bracket = None
2403     for leaf in reversed(line.leaves):
2404         if current_leaves is body_leaves:
2405             if leaf is opening_bracket:
2406                 current_leaves = head_leaves if body_leaves else tail_leaves
2407         current_leaves.append(leaf)
2408         if current_leaves is tail_leaves:
2409             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2410                 opening_bracket = leaf.opening_bracket
2411                 closing_bracket = leaf
2412                 current_leaves = body_leaves
2413     if not (opening_bracket and closing_bracket and head_leaves):
2414         # If there is no opening or closing_bracket that means the split failed and
2415         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2416         # the matching `opening_bracket` wasn't available on `line` anymore.
2417         raise CannotSplit("No brackets found")
2418
2419     tail_leaves.reverse()
2420     body_leaves.reverse()
2421     head_leaves.reverse()
2422     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2423     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2424     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2425     bracket_split_succeeded_or_raise(head, body, tail)
2426     if (
2427         # the body shouldn't be exploded
2428         not body.should_explode
2429         # the opening bracket is an optional paren
2430         and opening_bracket.type == token.LPAR
2431         and not opening_bracket.value
2432         # the closing bracket is an optional paren
2433         and closing_bracket.type == token.RPAR
2434         and not closing_bracket.value
2435         # it's not an import (optional parens are the only thing we can split on
2436         # in this case; attempting a split without them is a waste of time)
2437         and not line.is_import
2438         # there are no standalone comments in the body
2439         and not body.contains_standalone_comments(0)
2440         # and we can actually remove the parens
2441         and can_omit_invisible_parens(body, line_length)
2442     ):
2443         omit = {id(closing_bracket), *omit}
2444         try:
2445             yield from right_hand_split(line, line_length, features=features, omit=omit)
2446             return
2447
2448         except CannotSplit:
2449             if not (
2450                 can_be_split(body)
2451                 or is_line_short_enough(body, line_length=line_length)
2452             ):
2453                 raise CannotSplit(
2454                     "Splitting failed, body is still too long and can't be split."
2455                 )
2456
2457             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2458                 raise CannotSplit(
2459                     "The current optional pair of parentheses is bound to fail to "
2460                     "satisfy the splitting algorithm because the head or the tail "
2461                     "contains multiline strings which by definition never fit one "
2462                     "line."
2463                 )
2464
2465     ensure_visible(opening_bracket)
2466     ensure_visible(closing_bracket)
2467     for result in (head, body, tail):
2468         if result:
2469             yield result
2470
2471
2472 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2473     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2474
2475     Do nothing otherwise.
2476
2477     A left- or right-hand split is based on a pair of brackets. Content before
2478     (and including) the opening bracket is left on one line, content inside the
2479     brackets is put on a separate line, and finally content starting with and
2480     following the closing bracket is put on a separate line.
2481
2482     Those are called `head`, `body`, and `tail`, respectively. If the split
2483     produced the same line (all content in `head`) or ended up with an empty `body`
2484     and the `tail` is just the closing bracket, then it's considered failed.
2485     """
2486     tail_len = len(str(tail).strip())
2487     if not body:
2488         if tail_len == 0:
2489             raise CannotSplit("Splitting brackets produced the same line")
2490
2491         elif tail_len < 3:
2492             raise CannotSplit(
2493                 f"Splitting brackets on an empty body to save "
2494                 f"{tail_len} characters is not worth it"
2495             )
2496
2497
2498 def bracket_split_build_line(
2499     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2500 ) -> Line:
2501     """Return a new line with given `leaves` and respective comments from `original`.
2502
2503     If `is_body` is True, the result line is one-indented inside brackets and as such
2504     has its first leaf's prefix normalized and a trailing comma added when expected.
2505     """
2506     result = Line(depth=original.depth)
2507     if is_body:
2508         result.inside_brackets = True
2509         result.depth += 1
2510         if leaves:
2511             # Since body is a new indent level, remove spurious leading whitespace.
2512             normalize_prefix(leaves[0], inside_brackets=True)
2513             # Ensure a trailing comma for imports and standalone function arguments, but
2514             # be careful not to add one after any comments.
2515             no_commas = original.is_def and not any(
2516                 l.type == token.COMMA for l in leaves
2517             )
2518
2519             if original.is_import or no_commas:
2520                 for i in range(len(leaves) - 1, -1, -1):
2521                     if leaves[i].type == STANDALONE_COMMENT:
2522                         continue
2523                     elif leaves[i].type == token.COMMA:
2524                         break
2525                     else:
2526                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2527                         break
2528     # Populate the line
2529     for leaf in leaves:
2530         result.append(leaf, preformatted=True)
2531         for comment_after in original.comments_after(leaf):
2532             result.append(comment_after, preformatted=True)
2533     if is_body:
2534         result.should_explode = should_explode(result, opening_bracket)
2535     return result
2536
2537
2538 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2539     """Normalize prefix of the first leaf in every line returned by `split_func`.
2540
2541     This is a decorator over relevant split functions.
2542     """
2543
2544     @wraps(split_func)
2545     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2546         for l in split_func(line, features):
2547             normalize_prefix(l.leaves[0], inside_brackets=True)
2548             yield l
2549
2550     return split_wrapper
2551
2552
2553 @dont_increase_indentation
2554 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2555     """Split according to delimiters of the highest priority.
2556
2557     If the appropriate Features are given, the split will add trailing commas
2558     also in function signatures and calls that contain `*` and `**`.
2559     """
2560     try:
2561         last_leaf = line.leaves[-1]
2562     except IndexError:
2563         raise CannotSplit("Line empty")
2564
2565     bt = line.bracket_tracker
2566     try:
2567         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2568     except ValueError:
2569         raise CannotSplit("No delimiters found")
2570
2571     if delimiter_priority == DOT_PRIORITY:
2572         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2573             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2574
2575     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2576     lowest_depth = sys.maxsize
2577     trailing_comma_safe = True
2578
2579     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2580         """Append `leaf` to current line or to new line if appending impossible."""
2581         nonlocal current_line
2582         try:
2583             current_line.append_safe(leaf, preformatted=True)
2584         except ValueError:
2585             yield current_line
2586
2587             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2588             current_line.append(leaf)
2589
2590     for leaf in line.leaves:
2591         yield from append_to_line(leaf)
2592
2593         for comment_after in line.comments_after(leaf):
2594             yield from append_to_line(comment_after)
2595
2596         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2597         if leaf.bracket_depth == lowest_depth:
2598             if is_vararg(leaf, within={syms.typedargslist}):
2599                 trailing_comma_safe = (
2600                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2601                 )
2602             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2603                 trailing_comma_safe = (
2604                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2605                 )
2606
2607         leaf_priority = bt.delimiters.get(id(leaf))
2608         if leaf_priority == delimiter_priority:
2609             yield current_line
2610
2611             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2612     if current_line:
2613         if (
2614             trailing_comma_safe
2615             and delimiter_priority == COMMA_PRIORITY
2616             and current_line.leaves[-1].type != token.COMMA
2617             and current_line.leaves[-1].type != STANDALONE_COMMENT
2618         ):
2619             current_line.append(Leaf(token.COMMA, ","))
2620         yield current_line
2621
2622
2623 @dont_increase_indentation
2624 def standalone_comment_split(
2625     line: Line, features: Collection[Feature] = ()
2626 ) -> Iterator[Line]:
2627     """Split standalone comments from the rest of the line."""
2628     if not line.contains_standalone_comments(0):
2629         raise CannotSplit("Line does not have any standalone comments")
2630
2631     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2632
2633     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2634         """Append `leaf` to current line or to new line if appending impossible."""
2635         nonlocal current_line
2636         try:
2637             current_line.append_safe(leaf, preformatted=True)
2638         except ValueError:
2639             yield current_line
2640
2641             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2642             current_line.append(leaf)
2643
2644     for leaf in line.leaves:
2645         yield from append_to_line(leaf)
2646
2647         for comment_after in line.comments_after(leaf):
2648             yield from append_to_line(comment_after)
2649
2650     if current_line:
2651         yield current_line
2652
2653
2654 def is_import(leaf: Leaf) -> bool:
2655     """Return True if the given leaf starts an import statement."""
2656     p = leaf.parent
2657     t = leaf.type
2658     v = leaf.value
2659     return bool(
2660         t == token.NAME
2661         and (
2662             (v == "import" and p and p.type == syms.import_name)
2663             or (v == "from" and p and p.type == syms.import_from)
2664         )
2665     )
2666
2667
2668 def is_type_comment(leaf: Leaf) -> bool:
2669     """Return True if the given leaf is a special comment.
2670     Only returns true for type comments for now."""
2671     t = leaf.type
2672     v = leaf.value
2673     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2674
2675
2676 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2677     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2678     else.
2679
2680     Note: don't use backslashes for formatting or you'll lose your voting rights.
2681     """
2682     if not inside_brackets:
2683         spl = leaf.prefix.split("#")
2684         if "\\" not in spl[0]:
2685             nl_count = spl[-1].count("\n")
2686             if len(spl) > 1:
2687                 nl_count -= 1
2688             leaf.prefix = "\n" * nl_count
2689             return
2690
2691     leaf.prefix = ""
2692
2693
2694 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2695     """Make all string prefixes lowercase.
2696
2697     If remove_u_prefix is given, also removes any u prefix from the string.
2698
2699     Note: Mutates its argument.
2700     """
2701     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2702     assert match is not None, f"failed to match string {leaf.value!r}"
2703     orig_prefix = match.group(1)
2704     new_prefix = orig_prefix.lower()
2705     if remove_u_prefix:
2706         new_prefix = new_prefix.replace("u", "")
2707     leaf.value = f"{new_prefix}{match.group(2)}"
2708
2709
2710 def normalize_string_quotes(leaf: Leaf) -> None:
2711     """Prefer double quotes but only if it doesn't cause more escaping.
2712
2713     Adds or removes backslashes as appropriate. Doesn't parse and fix
2714     strings nested in f-strings (yet).
2715
2716     Note: Mutates its argument.
2717     """
2718     value = leaf.value.lstrip("furbFURB")
2719     if value[:3] == '"""':
2720         return
2721
2722     elif value[:3] == "'''":
2723         orig_quote = "'''"
2724         new_quote = '"""'
2725     elif value[0] == '"':
2726         orig_quote = '"'
2727         new_quote = "'"
2728     else:
2729         orig_quote = "'"
2730         new_quote = '"'
2731     first_quote_pos = leaf.value.find(orig_quote)
2732     if first_quote_pos == -1:
2733         return  # There's an internal error
2734
2735     prefix = leaf.value[:first_quote_pos]
2736     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2737     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2738     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2739     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2740     if "r" in prefix.casefold():
2741         if unescaped_new_quote.search(body):
2742             # There's at least one unescaped new_quote in this raw string
2743             # so converting is impossible
2744             return
2745
2746         # Do not introduce or remove backslashes in raw strings
2747         new_body = body
2748     else:
2749         # remove unnecessary escapes
2750         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2751         if body != new_body:
2752             # Consider the string without unnecessary escapes as the original
2753             body = new_body
2754             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2755         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2756         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2757     if "f" in prefix.casefold():
2758         matches = re.findall(
2759             r"""
2760             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2761                 ([^{].*?)  # contents of the brackets except if begins with {{
2762             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2763             """,
2764             new_body,
2765             re.VERBOSE,
2766         )
2767         for m in matches:
2768             if "\\" in str(m):
2769                 # Do not introduce backslashes in interpolated expressions
2770                 return
2771     if new_quote == '"""' and new_body[-1:] == '"':
2772         # edge case:
2773         new_body = new_body[:-1] + '\\"'
2774     orig_escape_count = body.count("\\")
2775     new_escape_count = new_body.count("\\")
2776     if new_escape_count > orig_escape_count:
2777         return  # Do not introduce more escaping
2778
2779     if new_escape_count == orig_escape_count and orig_quote == '"':
2780         return  # Prefer double quotes
2781
2782     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2783
2784
2785 def normalize_numeric_literal(leaf: Leaf) -> None:
2786     """Normalizes numeric (float, int, and complex) literals.
2787
2788     All letters used in the representation are normalized to lowercase (except
2789     in Python 2 long literals).
2790     """
2791     text = leaf.value.lower()
2792     if text.startswith(("0o", "0b")):
2793         # Leave octal and binary literals alone.
2794         pass
2795     elif text.startswith("0x"):
2796         # Change hex literals to upper case.
2797         before, after = text[:2], text[2:]
2798         text = f"{before}{after.upper()}"
2799     elif "e" in text:
2800         before, after = text.split("e")
2801         sign = ""
2802         if after.startswith("-"):
2803             after = after[1:]
2804             sign = "-"
2805         elif after.startswith("+"):
2806             after = after[1:]
2807         before = format_float_or_int_string(before)
2808         text = f"{before}e{sign}{after}"
2809     elif text.endswith(("j", "l")):
2810         number = text[:-1]
2811         suffix = text[-1]
2812         # Capitalize in "2L" because "l" looks too similar to "1".
2813         if suffix == "l":
2814             suffix = "L"
2815         text = f"{format_float_or_int_string(number)}{suffix}"
2816     else:
2817         text = format_float_or_int_string(text)
2818     leaf.value = text
2819
2820
2821 def format_float_or_int_string(text: str) -> str:
2822     """Formats a float string like "1.0"."""
2823     if "." not in text:
2824         return text
2825
2826     before, after = text.split(".")
2827     return f"{before or 0}.{after or 0}"
2828
2829
2830 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2831     """Make existing optional parentheses invisible or create new ones.
2832
2833     `parens_after` is a set of string leaf values immediately after which parens
2834     should be put.
2835
2836     Standardizes on visible parentheses for single-element tuples, and keeps
2837     existing visible parentheses for other tuples and generator expressions.
2838     """
2839     for pc in list_comments(node.prefix, is_endmarker=False):
2840         if pc.value in FMT_OFF:
2841             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2842             return
2843
2844     check_lpar = False
2845     for index, child in enumerate(list(node.children)):
2846         # Add parentheses around long tuple unpacking in assignments.
2847         if (
2848             index == 0
2849             and isinstance(child, Node)
2850             and child.type == syms.testlist_star_expr
2851         ):
2852             check_lpar = True
2853
2854         if check_lpar:
2855             if child.type == syms.atom:
2856                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2857                     lpar = Leaf(token.LPAR, "")
2858                     rpar = Leaf(token.RPAR, "")
2859                     index = child.remove() or 0
2860                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2861             elif is_one_tuple(child):
2862                 # wrap child in visible parentheses
2863                 lpar = Leaf(token.LPAR, "(")
2864                 rpar = Leaf(token.RPAR, ")")
2865                 child.remove()
2866                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2867             elif node.type == syms.import_from:
2868                 # "import from" nodes store parentheses directly as part of
2869                 # the statement
2870                 if child.type == token.LPAR:
2871                     # make parentheses invisible
2872                     child.value = ""  # type: ignore
2873                     node.children[-1].value = ""  # type: ignore
2874                 elif child.type != token.STAR:
2875                     # insert invisible parentheses
2876                     node.insert_child(index, Leaf(token.LPAR, ""))
2877                     node.append_child(Leaf(token.RPAR, ""))
2878                 break
2879
2880             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2881                 # wrap child in invisible parentheses
2882                 lpar = Leaf(token.LPAR, "")
2883                 rpar = Leaf(token.RPAR, "")
2884                 index = child.remove() or 0
2885                 prefix = child.prefix
2886                 child.prefix = ""
2887                 new_child = Node(syms.atom, [lpar, child, rpar])
2888                 new_child.prefix = prefix
2889                 node.insert_child(index, new_child)
2890
2891         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2892
2893
2894 def normalize_fmt_off(node: Node) -> None:
2895     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2896     try_again = True
2897     while try_again:
2898         try_again = convert_one_fmt_off_pair(node)
2899
2900
2901 def convert_one_fmt_off_pair(node: Node) -> bool:
2902     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2903
2904     Returns True if a pair was converted.
2905     """
2906     for leaf in node.leaves():
2907         previous_consumed = 0
2908         for comment in list_comments(leaf.prefix, is_endmarker=False):
2909             if comment.value in FMT_OFF:
2910                 # We only want standalone comments. If there's no previous leaf or
2911                 # the previous leaf is indentation, it's a standalone comment in
2912                 # disguise.
2913                 if comment.type != STANDALONE_COMMENT:
2914                     prev = preceding_leaf(leaf)
2915                     if prev and prev.type not in WHITESPACE:
2916                         continue
2917
2918                 ignored_nodes = list(generate_ignored_nodes(leaf))
2919                 if not ignored_nodes:
2920                     continue
2921
2922                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2923                 parent = first.parent
2924                 prefix = first.prefix
2925                 first.prefix = prefix[comment.consumed :]
2926                 hidden_value = (
2927                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2928                 )
2929                 if hidden_value.endswith("\n"):
2930                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2931                     # leaf (possibly followed by a DEDENT).
2932                     hidden_value = hidden_value[:-1]
2933                 first_idx = None
2934                 for ignored in ignored_nodes:
2935                     index = ignored.remove()
2936                     if first_idx is None:
2937                         first_idx = index
2938                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2939                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2940                 parent.insert_child(
2941                     first_idx,
2942                     Leaf(
2943                         STANDALONE_COMMENT,
2944                         hidden_value,
2945                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2946                     ),
2947                 )
2948                 return True
2949
2950             previous_consumed = comment.consumed
2951
2952     return False
2953
2954
2955 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2956     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2957
2958     Stops at the end of the block.
2959     """
2960     container: Optional[LN] = container_of(leaf)
2961     while container is not None and container.type != token.ENDMARKER:
2962         for comment in list_comments(container.prefix, is_endmarker=False):
2963             if comment.value in FMT_ON:
2964                 return
2965
2966         yield container
2967
2968         container = container.next_sibling
2969
2970
2971 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2972     """If it's safe, make the parens in the atom `node` invisible, recursively.
2973
2974     Returns whether the node should itself be wrapped in invisible parentheses.
2975
2976     """
2977     if (
2978         node.type != syms.atom
2979         or is_empty_tuple(node)
2980         or is_one_tuple(node)
2981         or (is_yield(node) and parent.type != syms.expr_stmt)
2982         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2983     ):
2984         return False
2985
2986     first = node.children[0]
2987     last = node.children[-1]
2988     if first.type == token.LPAR and last.type == token.RPAR:
2989         # make parentheses invisible
2990         first.value = ""  # type: ignore
2991         last.value = ""  # type: ignore
2992         if len(node.children) > 1:
2993             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2994         return False
2995
2996     return True
2997
2998
2999 def is_empty_tuple(node: LN) -> bool:
3000     """Return True if `node` holds an empty tuple."""
3001     return (
3002         node.type == syms.atom
3003         and len(node.children) == 2
3004         and node.children[0].type == token.LPAR
3005         and node.children[1].type == token.RPAR
3006     )
3007
3008
3009 def is_one_tuple(node: LN) -> bool:
3010     """Return True if `node` holds a tuple with one element, with or without parens."""
3011     if node.type == syms.atom:
3012         if len(node.children) != 3:
3013             return False
3014
3015         lpar, gexp, rpar = node.children
3016         if not (
3017             lpar.type == token.LPAR
3018             and gexp.type == syms.testlist_gexp
3019             and rpar.type == token.RPAR
3020         ):
3021             return False
3022
3023         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3024
3025     return (
3026         node.type in IMPLICIT_TUPLE
3027         and len(node.children) == 2
3028         and node.children[1].type == token.COMMA
3029     )
3030
3031
3032 def is_yield(node: LN) -> bool:
3033     """Return True if `node` holds a `yield` or `yield from` expression."""
3034     if node.type == syms.yield_expr:
3035         return True
3036
3037     if node.type == token.NAME and node.value == "yield":  # type: ignore
3038         return True
3039
3040     if node.type != syms.atom:
3041         return False
3042
3043     if len(node.children) != 3:
3044         return False
3045
3046     lpar, expr, rpar = node.children
3047     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3048         return is_yield(expr)
3049
3050     return False
3051
3052
3053 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3054     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3055
3056     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3057     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3058     extended iterable unpacking (PEP 3132) and additional unpacking
3059     generalizations (PEP 448).
3060     """
3061     if leaf.type not in STARS or not leaf.parent:
3062         return False
3063
3064     p = leaf.parent
3065     if p.type == syms.star_expr:
3066         # Star expressions are also used as assignment targets in extended
3067         # iterable unpacking (PEP 3132).  See what its parent is instead.
3068         if not p.parent:
3069             return False
3070
3071         p = p.parent
3072
3073     return p.type in within
3074
3075
3076 def is_multiline_string(leaf: Leaf) -> bool:
3077     """Return True if `leaf` is a multiline string that actually spans many lines."""
3078     value = leaf.value.lstrip("furbFURB")
3079     return value[:3] in {'"""', "'''"} and "\n" in value
3080
3081
3082 def is_stub_suite(node: Node) -> bool:
3083     """Return True if `node` is a suite with a stub body."""
3084     if (
3085         len(node.children) != 4
3086         or node.children[0].type != token.NEWLINE
3087         or node.children[1].type != token.INDENT
3088         or node.children[3].type != token.DEDENT
3089     ):
3090         return False
3091
3092     return is_stub_body(node.children[2])
3093
3094
3095 def is_stub_body(node: LN) -> bool:
3096     """Return True if `node` is a simple statement containing an ellipsis."""
3097     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3098         return False
3099
3100     if len(node.children) != 2:
3101         return False
3102
3103     child = node.children[0]
3104     return (
3105         child.type == syms.atom
3106         and len(child.children) == 3
3107         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3108     )
3109
3110
3111 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3112     """Return maximum delimiter priority inside `node`.
3113
3114     This is specific to atoms with contents contained in a pair of parentheses.
3115     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3116     """
3117     if node.type != syms.atom:
3118         return 0
3119
3120     first = node.children[0]
3121     last = node.children[-1]
3122     if not (first.type == token.LPAR and last.type == token.RPAR):
3123         return 0
3124
3125     bt = BracketTracker()
3126     for c in node.children[1:-1]:
3127         if isinstance(c, Leaf):
3128             bt.mark(c)
3129         else:
3130             for leaf in c.leaves():
3131                 bt.mark(leaf)
3132     try:
3133         return bt.max_delimiter_priority()
3134
3135     except ValueError:
3136         return 0
3137
3138
3139 def ensure_visible(leaf: Leaf) -> None:
3140     """Make sure parentheses are visible.
3141
3142     They could be invisible as part of some statements (see
3143     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3144     """
3145     if leaf.type == token.LPAR:
3146         leaf.value = "("
3147     elif leaf.type == token.RPAR:
3148         leaf.value = ")"
3149
3150
3151 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3152     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3153
3154     if not (
3155         opening_bracket.parent
3156         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3157         and opening_bracket.value in "[{("
3158     ):
3159         return False
3160
3161     try:
3162         last_leaf = line.leaves[-1]
3163         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3164         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3165     except (IndexError, ValueError):
3166         return False
3167
3168     return max_priority == COMMA_PRIORITY
3169
3170
3171 def get_features_used(node: Node) -> Set[Feature]:
3172     """Return a set of (relatively) new Python features used in this file.
3173
3174     Currently looking for:
3175     - f-strings;
3176     - underscores in numeric literals; and
3177     - trailing commas after * or ** in function signatures and calls.
3178     """
3179     features: Set[Feature] = set()
3180     for n in node.pre_order():
3181         if n.type == token.STRING:
3182             value_head = n.value[:2]  # type: ignore
3183             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3184                 features.add(Feature.F_STRINGS)
3185
3186         elif n.type == token.NUMBER:
3187             if "_" in n.value:  # type: ignore
3188                 features.add(Feature.NUMERIC_UNDERSCORES)
3189
3190         elif (
3191             n.type in {syms.typedargslist, syms.arglist}
3192             and n.children
3193             and n.children[-1].type == token.COMMA
3194         ):
3195             if n.type == syms.typedargslist:
3196                 feature = Feature.TRAILING_COMMA_IN_DEF
3197             else:
3198                 feature = Feature.TRAILING_COMMA_IN_CALL
3199
3200             for ch in n.children:
3201                 if ch.type in STARS:
3202                     features.add(feature)
3203
3204                 if ch.type == syms.argument:
3205                     for argch in ch.children:
3206                         if argch.type in STARS:
3207                             features.add(feature)
3208
3209     return features
3210
3211
3212 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3213     """Detect the version to target based on the nodes used."""
3214     features = get_features_used(node)
3215     return {
3216         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3217     }
3218
3219
3220 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3221     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3222
3223     Brackets can be omitted if the entire trailer up to and including
3224     a preceding closing bracket fits in one line.
3225
3226     Yielded sets are cumulative (contain results of previous yields, too).  First
3227     set is empty.
3228     """
3229
3230     omit: Set[LeafID] = set()
3231     yield omit
3232
3233     length = 4 * line.depth
3234     opening_bracket = None
3235     closing_bracket = None
3236     inner_brackets: Set[LeafID] = set()
3237     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3238         length += leaf_length
3239         if length > line_length:
3240             break
3241
3242         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3243         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3244             break
3245
3246         if opening_bracket:
3247             if leaf is opening_bracket:
3248                 opening_bracket = None
3249             elif leaf.type in CLOSING_BRACKETS:
3250                 inner_brackets.add(id(leaf))
3251         elif leaf.type in CLOSING_BRACKETS:
3252             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3253                 # Empty brackets would fail a split so treat them as "inner"
3254                 # brackets (e.g. only add them to the `omit` set if another
3255                 # pair of brackets was good enough.
3256                 inner_brackets.add(id(leaf))
3257                 continue
3258
3259             if closing_bracket:
3260                 omit.add(id(closing_bracket))
3261                 omit.update(inner_brackets)
3262                 inner_brackets.clear()
3263                 yield omit
3264
3265             if leaf.value:
3266                 opening_bracket = leaf.opening_bracket
3267                 closing_bracket = leaf
3268
3269
3270 def get_future_imports(node: Node) -> Set[str]:
3271     """Return a set of __future__ imports in the file."""
3272     imports: Set[str] = set()
3273
3274     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3275         for child in children:
3276             if isinstance(child, Leaf):
3277                 if child.type == token.NAME:
3278                     yield child.value
3279             elif child.type == syms.import_as_name:
3280                 orig_name = child.children[0]
3281                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3282                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3283                 yield orig_name.value
3284             elif child.type == syms.import_as_names:
3285                 yield from get_imports_from_children(child.children)
3286             else:
3287                 raise AssertionError("Invalid syntax parsing imports")
3288
3289     for child in node.children:
3290         if child.type != syms.simple_stmt:
3291             break
3292         first_child = child.children[0]
3293         if isinstance(first_child, Leaf):
3294             # Continue looking if we see a docstring; otherwise stop.
3295             if (
3296                 len(child.children) == 2
3297                 and first_child.type == token.STRING
3298                 and child.children[1].type == token.NEWLINE
3299             ):
3300                 continue
3301             else:
3302                 break
3303         elif first_child.type == syms.import_from:
3304             module_name = first_child.children[1]
3305             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3306                 break
3307             imports |= set(get_imports_from_children(first_child.children[3:]))
3308         else:
3309             break
3310     return imports
3311
3312
3313 def gen_python_files_in_dir(
3314     path: Path,
3315     root: Path,
3316     include: Pattern[str],
3317     exclude: Pattern[str],
3318     report: "Report",
3319 ) -> Iterator[Path]:
3320     """Generate all files under `path` whose paths are not excluded by the
3321     `exclude` regex, but are included by the `include` regex.
3322
3323     Symbolic links pointing outside of the `root` directory are ignored.
3324
3325     `report` is where output about exclusions goes.
3326     """
3327     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3328     for child in path.iterdir():
3329         try:
3330             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3331         except ValueError:
3332             if child.is_symlink():
3333                 report.path_ignored(
3334                     child, f"is a symbolic link that points outside {root}"
3335                 )
3336                 continue
3337
3338             raise
3339
3340         if child.is_dir():
3341             normalized_path += "/"
3342         exclude_match = exclude.search(normalized_path)
3343         if exclude_match and exclude_match.group(0):
3344             report.path_ignored(child, f"matches the --exclude regular expression")
3345             continue
3346
3347         if child.is_dir():
3348             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3349
3350         elif child.is_file():
3351             include_match = include.search(normalized_path)
3352             if include_match:
3353                 yield child
3354
3355
3356 @lru_cache()
3357 def find_project_root(srcs: Iterable[str]) -> Path:
3358     """Return a directory containing .git, .hg, or pyproject.toml.
3359
3360     That directory can be one of the directories passed in `srcs` or their
3361     common parent.
3362
3363     If no directory in the tree contains a marker that would specify it's the
3364     project root, the root of the file system is returned.
3365     """
3366     if not srcs:
3367         return Path("/").resolve()
3368
3369     common_base = min(Path(src).resolve() for src in srcs)
3370     if common_base.is_dir():
3371         # Append a fake file so `parents` below returns `common_base_dir`, too.
3372         common_base /= "fake-file"
3373     for directory in common_base.parents:
3374         if (directory / ".git").is_dir():
3375             return directory
3376
3377         if (directory / ".hg").is_dir():
3378             return directory
3379
3380         if (directory / "pyproject.toml").is_file():
3381             return directory
3382
3383     return directory
3384
3385
3386 @dataclass
3387 class Report:
3388     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3389
3390     check: bool = False
3391     quiet: bool = False
3392     verbose: bool = False
3393     change_count: int = 0
3394     same_count: int = 0
3395     failure_count: int = 0
3396
3397     def done(self, src: Path, changed: Changed) -> None:
3398         """Increment the counter for successful reformatting. Write out a message."""
3399         if changed is Changed.YES:
3400             reformatted = "would reformat" if self.check else "reformatted"
3401             if self.verbose or not self.quiet:
3402                 out(f"{reformatted} {src}")
3403             self.change_count += 1
3404         else:
3405             if self.verbose:
3406                 if changed is Changed.NO:
3407                     msg = f"{src} already well formatted, good job."
3408                 else:
3409                     msg = f"{src} wasn't modified on disk since last run."
3410                 out(msg, bold=False)
3411             self.same_count += 1
3412
3413     def failed(self, src: Path, message: str) -> None:
3414         """Increment the counter for failed reformatting. Write out a message."""
3415         err(f"error: cannot format {src}: {message}")
3416         self.failure_count += 1
3417
3418     def path_ignored(self, path: Path, message: str) -> None:
3419         if self.verbose:
3420             out(f"{path} ignored: {message}", bold=False)
3421
3422     @property
3423     def return_code(self) -> int:
3424         """Return the exit code that the app should use.
3425
3426         This considers the current state of changed files and failures:
3427         - if there were any failures, return 123;
3428         - if any files were changed and --check is being used, return 1;
3429         - otherwise return 0.
3430         """
3431         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3432         # 126 we have special return codes reserved by the shell.
3433         if self.failure_count:
3434             return 123
3435
3436         elif self.change_count and self.check:
3437             return 1
3438
3439         return 0
3440
3441     def __str__(self) -> str:
3442         """Render a color report of the current state.
3443
3444         Use `click.unstyle` to remove colors.
3445         """
3446         if self.check:
3447             reformatted = "would be reformatted"
3448             unchanged = "would be left unchanged"
3449             failed = "would fail to reformat"
3450         else:
3451             reformatted = "reformatted"
3452             unchanged = "left unchanged"
3453             failed = "failed to reformat"
3454         report = []
3455         if self.change_count:
3456             s = "s" if self.change_count > 1 else ""
3457             report.append(
3458                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3459             )
3460         if self.same_count:
3461             s = "s" if self.same_count > 1 else ""
3462             report.append(f"{self.same_count} file{s} {unchanged}")
3463         if self.failure_count:
3464             s = "s" if self.failure_count > 1 else ""
3465             report.append(
3466                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3467             )
3468         return ", ".join(report) + "."
3469
3470
3471 def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
3472     for feature_version in (7, 6):
3473         try:
3474             return ast3.parse(src, feature_version=feature_version)
3475         except SyntaxError:
3476             continue
3477
3478     return ast27.parse(src)
3479
3480
3481 def assert_equivalent(src: str, dst: str) -> None:
3482     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3483
3484     def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3485         """Simple visitor generating strings to compare ASTs by content."""
3486         yield f"{'  ' * depth}{node.__class__.__name__}("
3487
3488         for field in sorted(node._fields):
3489             # TypeIgnore has only one field 'lineno' which breaks this comparison
3490             if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
3491                 break
3492
3493             # Ignore str kind which is case sensitive / and ignores unicode_literals
3494             if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
3495                 continue
3496
3497             try:
3498                 value = getattr(node, field)
3499             except AttributeError:
3500                 continue
3501
3502             yield f"{'  ' * (depth+1)}{field}="
3503
3504             if isinstance(value, list):
3505                 for item in value:
3506                     # Ignore nested tuples within del statements, because we may insert
3507                     # parentheses and they change the AST.
3508                     if (
3509                         field == "targets"
3510                         and isinstance(node, (ast3.Delete, ast27.Delete))
3511                         and isinstance(item, (ast3.Tuple, ast27.Tuple))
3512                     ):
3513                         for item in item.elts:
3514                             yield from _v(item, depth + 2)
3515                     elif isinstance(item, (ast3.AST, ast27.AST)):
3516                         yield from _v(item, depth + 2)
3517
3518             elif isinstance(value, (ast3.AST, ast27.AST)):
3519                 yield from _v(value, depth + 2)
3520
3521             else:
3522                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3523
3524         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3525
3526     try:
3527         src_ast = parse_ast(src)
3528     except Exception as exc:
3529         raise AssertionError(
3530             f"cannot use --safe with this file; failed to parse source file.  "
3531             f"AST error message: {exc}"
3532         )
3533
3534     try:
3535         dst_ast = parse_ast(dst)
3536     except Exception as exc:
3537         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3538         raise AssertionError(
3539             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3540             f"Please report a bug on https://github.com/python/black/issues.  "
3541             f"This invalid output might be helpful: {log}"
3542         ) from None
3543
3544     src_ast_str = "\n".join(_v(src_ast))
3545     dst_ast_str = "\n".join(_v(dst_ast))
3546     if src_ast_str != dst_ast_str:
3547         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3548         raise AssertionError(
3549             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3550             f"the source.  "
3551             f"Please report a bug on https://github.com/python/black/issues.  "
3552             f"This diff might be helpful: {log}"
3553         ) from None
3554
3555
3556 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3557     """Raise AssertionError if `dst` reformats differently the second time."""
3558     newdst = format_str(dst, mode=mode)
3559     if dst != newdst:
3560         log = dump_to_file(
3561             diff(src, dst, "source", "first pass"),
3562             diff(dst, newdst, "first pass", "second pass"),
3563         )
3564         raise AssertionError(
3565             f"INTERNAL ERROR: Black produced different code on the second pass "
3566             f"of the formatter.  "
3567             f"Please report a bug on https://github.com/python/black/issues.  "
3568             f"This diff might be helpful: {log}"
3569         ) from None
3570
3571
3572 def dump_to_file(*output: str) -> str:
3573     """Dump `output` to a temporary file. Return path to the file."""
3574     with tempfile.NamedTemporaryFile(
3575         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3576     ) as f:
3577         for lines in output:
3578             f.write(lines)
3579             if lines and lines[-1] != "\n":
3580                 f.write("\n")
3581     return f.name
3582
3583
3584 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3585     """Return a unified diff string between strings `a` and `b`."""
3586     import difflib
3587
3588     a_lines = [line + "\n" for line in a.split("\n")]
3589     b_lines = [line + "\n" for line in b.split("\n")]
3590     return "".join(
3591         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3592     )
3593
3594
3595 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3596     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3597     err("Aborted!")
3598     for task in tasks:
3599         task.cancel()
3600
3601
3602 def shutdown(loop: BaseEventLoop) -> None:
3603     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3604     try:
3605         if sys.version_info[:2] >= (3, 7):
3606             all_tasks = asyncio.all_tasks
3607         else:
3608             all_tasks = asyncio.Task.all_tasks
3609         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3610         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3611         if not to_cancel:
3612             return
3613
3614         for task in to_cancel:
3615             task.cancel()
3616         loop.run_until_complete(
3617             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3618         )
3619     finally:
3620         # `concurrent.futures.Future` objects cannot be cancelled once they
3621         # are already running. There might be some when the `shutdown()` happened.
3622         # Silence their logger's spew about the event loop being closed.
3623         cf_logger = logging.getLogger("concurrent.futures")
3624         cf_logger.setLevel(logging.CRITICAL)
3625         loop.close()
3626
3627
3628 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3629     """Replace `regex` with `replacement` twice on `original`.
3630
3631     This is used by string normalization to perform replaces on
3632     overlapping matches.
3633     """
3634     return regex.sub(replacement, regex.sub(replacement, original))
3635
3636
3637 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3638     """Compile a regular expression string in `regex`.
3639
3640     If it contains newlines, use verbose mode.
3641     """
3642     if "\n" in regex:
3643         regex = "(?x)" + regex
3644     return re.compile(regex)
3645
3646
3647 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3648     """Like `reversed(enumerate(sequence))` if that were possible."""
3649     index = len(sequence) - 1
3650     for element in reversed(sequence):
3651         yield (index, element)
3652         index -= 1
3653
3654
3655 def enumerate_with_length(
3656     line: Line, reversed: bool = False
3657 ) -> Iterator[Tuple[Index, Leaf, int]]:
3658     """Return an enumeration of leaves with their length.
3659
3660     Stops prematurely on multiline strings and standalone comments.
3661     """
3662     op = cast(
3663         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3664         enumerate_reversed if reversed else enumerate,
3665     )
3666     for index, leaf in op(line.leaves):
3667         length = len(leaf.prefix) + len(leaf.value)
3668         if "\n" in leaf.value:
3669             return  # Multiline strings, we can't continue.
3670
3671         for comment in line.comments_after(leaf):
3672             length += len(comment.value)
3673
3674         yield index, leaf, length
3675
3676
3677 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3678     """Return True if `line` is no longer than `line_length`.
3679
3680     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3681     """
3682     if not line_str:
3683         line_str = str(line).strip("\n")
3684     return (
3685         len(line_str) <= line_length
3686         and "\n" not in line_str  # multiline strings
3687         and not line.contains_standalone_comments()
3688     )
3689
3690
3691 def can_be_split(line: Line) -> bool:
3692     """Return False if the line cannot be split *for sure*.
3693
3694     This is not an exhaustive search but a cheap heuristic that we can use to
3695     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3696     in unnecessary parentheses).
3697     """
3698     leaves = line.leaves
3699     if len(leaves) < 2:
3700         return False
3701
3702     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3703         call_count = 0
3704         dot_count = 0
3705         next = leaves[-1]
3706         for leaf in leaves[-2::-1]:
3707             if leaf.type in OPENING_BRACKETS:
3708                 if next.type not in CLOSING_BRACKETS:
3709                     return False
3710
3711                 call_count += 1
3712             elif leaf.type == token.DOT:
3713                 dot_count += 1
3714             elif leaf.type == token.NAME:
3715                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3716                     return False
3717
3718             elif leaf.type not in CLOSING_BRACKETS:
3719                 return False
3720
3721             if dot_count > 1 and call_count > 1:
3722                 return False
3723
3724     return True
3725
3726
3727 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3728     """Does `line` have a shape safe to reformat without optional parens around it?
3729
3730     Returns True for only a subset of potentially nice looking formattings but
3731     the point is to not return false positives that end up producing lines that
3732     are too long.
3733     """
3734     bt = line.bracket_tracker
3735     if not bt.delimiters:
3736         # Without delimiters the optional parentheses are useless.
3737         return True
3738
3739     max_priority = bt.max_delimiter_priority()
3740     if bt.delimiter_count_with_priority(max_priority) > 1:
3741         # With more than one delimiter of a kind the optional parentheses read better.
3742         return False
3743
3744     if max_priority == DOT_PRIORITY:
3745         # A single stranded method call doesn't require optional parentheses.
3746         return True
3747
3748     assert len(line.leaves) >= 2, "Stranded delimiter"
3749
3750     first = line.leaves[0]
3751     second = line.leaves[1]
3752     penultimate = line.leaves[-2]
3753     last = line.leaves[-1]
3754
3755     # With a single delimiter, omit if the expression starts or ends with
3756     # a bracket.
3757     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3758         remainder = False
3759         length = 4 * line.depth
3760         for _index, leaf, leaf_length in enumerate_with_length(line):
3761             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3762                 remainder = True
3763             if remainder:
3764                 length += leaf_length
3765                 if length > line_length:
3766                     break
3767
3768                 if leaf.type in OPENING_BRACKETS:
3769                     # There are brackets we can further split on.
3770                     remainder = False
3771
3772         else:
3773             # checked the entire string and line length wasn't exceeded
3774             if len(line.leaves) == _index + 1:
3775                 return True
3776
3777         # Note: we are not returning False here because a line might have *both*
3778         # a leading opening bracket and a trailing closing bracket.  If the
3779         # opening bracket doesn't match our rule, maybe the closing will.
3780
3781     if (
3782         last.type == token.RPAR
3783         or last.type == token.RBRACE
3784         or (
3785             # don't use indexing for omitting optional parentheses;
3786             # it looks weird
3787             last.type == token.RSQB
3788             and last.parent
3789             and last.parent.type != syms.trailer
3790         )
3791     ):
3792         if penultimate.type in OPENING_BRACKETS:
3793             # Empty brackets don't help.
3794             return False
3795
3796         if is_multiline_string(first):
3797             # Additional wrapping of a multiline string in this situation is
3798             # unnecessary.
3799             return True
3800
3801         length = 4 * line.depth
3802         seen_other_brackets = False
3803         for _index, leaf, leaf_length in enumerate_with_length(line):
3804             length += leaf_length
3805             if leaf is last.opening_bracket:
3806                 if seen_other_brackets or length <= line_length:
3807                     return True
3808
3809             elif leaf.type in OPENING_BRACKETS:
3810                 # There are brackets we can further split on.
3811                 seen_other_brackets = True
3812
3813     return False
3814
3815
3816 def get_cache_file(mode: FileMode) -> Path:
3817     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3818
3819
3820 def read_cache(mode: FileMode) -> Cache:
3821     """Read the cache if it exists and is well formed.
3822
3823     If it is not well formed, the call to write_cache later should resolve the issue.
3824     """
3825     cache_file = get_cache_file(mode)
3826     if not cache_file.exists():
3827         return {}
3828
3829     with cache_file.open("rb") as fobj:
3830         try:
3831             cache: Cache = pickle.load(fobj)
3832         except pickle.UnpicklingError:
3833             return {}
3834
3835     return cache
3836
3837
3838 def get_cache_info(path: Path) -> CacheInfo:
3839     """Return the information used to check if a file is already formatted or not."""
3840     stat = path.stat()
3841     return stat.st_mtime, stat.st_size
3842
3843
3844 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3845     """Split an iterable of paths in `sources` into two sets.
3846
3847     The first contains paths of files that modified on disk or are not in the
3848     cache. The other contains paths to non-modified files.
3849     """
3850     todo, done = set(), set()
3851     for src in sources:
3852         src = src.resolve()
3853         if cache.get(src) != get_cache_info(src):
3854             todo.add(src)
3855         else:
3856             done.add(src)
3857     return todo, done
3858
3859
3860 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3861     """Update the cache file."""
3862     cache_file = get_cache_file(mode)
3863     try:
3864         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3865         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3866         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3867             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3868         os.replace(f.name, cache_file)
3869     except OSError:
3870         pass
3871
3872
3873 def patch_click() -> None:
3874     """Make Click not crash.
3875
3876     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3877     default which restricts paths that it can access during the lifetime of the
3878     application.  Click refuses to work in this scenario by raising a RuntimeError.
3879
3880     In case of Black the likelihood that non-ASCII characters are going to be used in
3881     file paths is minimal since it's Python source code.  Moreover, this crash was
3882     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3883     """
3884     try:
3885         from click import core
3886         from click import _unicodefun  # type: ignore
3887     except ModuleNotFoundError:
3888         return
3889
3890     for module in (core, _unicodefun):
3891         if hasattr(module, "_verify_python3_env"):
3892             module._verify_python3_env = lambda: None
3893
3894
3895 def patched_main() -> None:
3896     freeze_support()
3897     patch_click()
3898     main()
3899
3900
3901 if __name__ == "__main__":
3902     patched_main()