]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

200e15fbd071f897e4161f9cc5923fbbb96c9b8f
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "18.9b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA = 4
137
138
139 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
140     TargetVersion.PY27: set(),
141     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
142     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
144     TargetVersion.PY36: {
145         Feature.UNICODE_LITERALS,
146         Feature.F_STRINGS,
147         Feature.NUMERIC_UNDERSCORES,
148         Feature.TRAILING_COMMA,
149     },
150     TargetVersion.PY37: {
151         Feature.UNICODE_LITERALS,
152         Feature.F_STRINGS,
153         Feature.NUMERIC_UNDERSCORES,
154         Feature.TRAILING_COMMA,
155     },
156     TargetVersion.PY38: {
157         Feature.UNICODE_LITERALS,
158         Feature.F_STRINGS,
159         Feature.NUMERIC_UNDERSCORES,
160         Feature.TRAILING_COMMA,
161     },
162 }
163
164
165 @dataclass
166 class FileMode:
167     target_versions: Set[TargetVersion] = Factory(set)
168     line_length: int = DEFAULT_LINE_LENGTH
169     string_normalization: bool = True
170     is_pyi: bool = False
171
172     def get_cache_key(self) -> str:
173         if self.target_versions:
174             version_str = ",".join(
175                 str(version.value)
176                 for version in sorted(self.target_versions, key=lambda v: v.value)
177             )
178         else:
179             version_str = "-"
180         parts = [
181             version_str,
182             str(self.line_length),
183             str(int(self.string_normalization)),
184             str(int(self.is_pyi)),
185         ]
186         return ".".join(parts)
187
188
189 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
190     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
191
192
193 def read_pyproject_toml(
194     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
195 ) -> Optional[str]:
196     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
197
198     Returns the path to a successfully found and read configuration file, None
199     otherwise.
200     """
201     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
202     if not value:
203         root = find_project_root(ctx.params.get("src", ()))
204         path = root / "pyproject.toml"
205         if path.is_file():
206             value = str(path)
207         else:
208             return None
209
210     try:
211         pyproject_toml = toml.load(value)
212         config = pyproject_toml.get("tool", {}).get("black", {})
213     except (toml.TomlDecodeError, OSError) as e:
214         raise click.FileError(
215             filename=value, hint=f"Error reading configuration file: {e}"
216         )
217
218     if not config:
219         return None
220
221     if ctx.default_map is None:
222         ctx.default_map = {}
223     ctx.default_map.update(  # type: ignore  # bad types in .pyi
224         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
225     )
226     return value
227
228
229 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
230 @click.option(
231     "-l",
232     "--line-length",
233     type=int,
234     default=DEFAULT_LINE_LENGTH,
235     help="How many characters per line to allow.",
236     show_default=True,
237 )
238 @click.option(
239     "-t",
240     "--target-version",
241     type=click.Choice([v.name.lower() for v in TargetVersion]),
242     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
243     multiple=True,
244     help=(
245         "Python versions that should be supported by Black's output. [default: "
246         "per-file auto-detection]"
247     ),
248 )
249 @click.option(
250     "--pyi",
251     is_flag=True,
252     help=(
253         "Format all input files like typing stubs regardless of file extension "
254         "(useful when piping source on standard input)."
255     ),
256 )
257 @click.option(
258     "-S",
259     "--skip-string-normalization",
260     is_flag=True,
261     help="Don't normalize string quotes or prefixes.",
262 )
263 @click.option(
264     "--check",
265     is_flag=True,
266     help=(
267         "Don't write the files back, just return the status.  Return code 0 "
268         "means nothing would change.  Return code 1 means some files would be "
269         "reformatted.  Return code 123 means there was an internal error."
270     ),
271 )
272 @click.option(
273     "--diff",
274     is_flag=True,
275     help="Don't write the files back, just output a diff for each file on stdout.",
276 )
277 @click.option(
278     "--fast/--safe",
279     is_flag=True,
280     help="If --fast given, skip temporary sanity checks. [default: --safe]",
281 )
282 @click.option(
283     "--include",
284     type=str,
285     default=DEFAULT_INCLUDES,
286     help=(
287         "A regular expression that matches files and directories that should be "
288         "included on recursive searches.  An empty value means all files are "
289         "included regardless of the name.  Use forward slashes for directories on "
290         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
291         "later."
292     ),
293     show_default=True,
294 )
295 @click.option(
296     "--exclude",
297     type=str,
298     default=DEFAULT_EXCLUDES,
299     help=(
300         "A regular expression that matches files and directories that should be "
301         "excluded on recursive searches.  An empty value means no paths are excluded. "
302         "Use forward slashes for directories on all platforms (Windows, too).  "
303         "Exclusions are calculated first, inclusions later."
304     ),
305     show_default=True,
306 )
307 @click.option(
308     "-q",
309     "--quiet",
310     is_flag=True,
311     help=(
312         "Don't emit non-error messages to stderr. Errors are still emitted, "
313         "silence those with 2>/dev/null."
314     ),
315 )
316 @click.option(
317     "-v",
318     "--verbose",
319     is_flag=True,
320     help=(
321         "Also emit messages to stderr about files that were not changed or were "
322         "ignored due to --exclude=."
323     ),
324 )
325 @click.version_option(version=__version__)
326 @click.argument(
327     "src",
328     nargs=-1,
329     type=click.Path(
330         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
331     ),
332     is_eager=True,
333 )
334 @click.option(
335     "--config",
336     type=click.Path(
337         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
338     ),
339     is_eager=True,
340     callback=read_pyproject_toml,
341     help="Read configuration from PATH.",
342 )
343 @click.pass_context
344 def main(
345     ctx: click.Context,
346     line_length: int,
347     target_version: List[TargetVersion],
348     check: bool,
349     diff: bool,
350     fast: bool,
351     pyi: bool,
352     skip_string_normalization: bool,
353     quiet: bool,
354     verbose: bool,
355     include: str,
356     exclude: str,
357     src: Tuple[str],
358     config: Optional[str],
359 ) -> None:
360     """The uncompromising code formatter."""
361     write_back = WriteBack.from_configuration(check=check, diff=diff)
362     if target_version:
363         versions = set(target_version)
364     else:
365         # We'll autodetect later.
366         versions = set()
367     mode = FileMode(
368         target_versions=versions,
369         line_length=line_length,
370         is_pyi=pyi,
371         string_normalization=not skip_string_normalization,
372     )
373     if config and verbose:
374         out(f"Using configuration from {config}.", bold=False, fg="blue")
375     try:
376         include_regex = re_compile_maybe_verbose(include)
377     except re.error:
378         err(f"Invalid regular expression for include given: {include!r}")
379         ctx.exit(2)
380     try:
381         exclude_regex = re_compile_maybe_verbose(exclude)
382     except re.error:
383         err(f"Invalid regular expression for exclude given: {exclude!r}")
384         ctx.exit(2)
385     report = Report(check=check, quiet=quiet, verbose=verbose)
386     root = find_project_root(src)
387     sources: Set[Path] = set()
388     for s in src:
389         p = Path(s)
390         if p.is_dir():
391             sources.update(
392                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
393             )
394         elif p.is_file() or s == "-":
395             # if a file was explicitly given, we don't care about its extension
396             sources.add(p)
397         else:
398             err(f"invalid path: {s}")
399     if len(sources) == 0:
400         if verbose or not quiet:
401             out("No paths given. Nothing to do 😴")
402         ctx.exit(0)
403
404     if len(sources) == 1:
405         reformat_one(
406             src=sources.pop(),
407             fast=fast,
408             write_back=write_back,
409             mode=mode,
410             report=report,
411         )
412     else:
413         loop = asyncio.get_event_loop()
414         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
415         try:
416             loop.run_until_complete(
417                 schedule_formatting(
418                     sources=sources,
419                     fast=fast,
420                     write_back=write_back,
421                     mode=mode,
422                     report=report,
423                     loop=loop,
424                     executor=executor,
425                 )
426             )
427         finally:
428             shutdown(loop)
429     if verbose or not quiet:
430         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
431         out(f"All done! {bang}")
432         click.secho(str(report), err=True)
433     ctx.exit(report.return_code)
434
435
436 def reformat_one(
437     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
438 ) -> None:
439     """Reformat a single file under `src` without spawning child processes.
440
441     If `quiet` is True, non-error messages are not output. `line_length`,
442     `write_back`, `fast` and `pyi` options are passed to
443     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
444     """
445     try:
446         changed = Changed.NO
447         if not src.is_file() and str(src) == "-":
448             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
449                 changed = Changed.YES
450         else:
451             cache: Cache = {}
452             if write_back != WriteBack.DIFF:
453                 cache = read_cache(mode)
454                 res_src = src.resolve()
455                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
456                     changed = Changed.CACHED
457             if changed is not Changed.CACHED and format_file_in_place(
458                 src, fast=fast, write_back=write_back, mode=mode
459             ):
460                 changed = Changed.YES
461             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
462                 write_back is WriteBack.CHECK and changed is Changed.NO
463             ):
464                 write_cache(cache, [src], mode)
465         report.done(src, changed)
466     except Exception as exc:
467         report.failed(src, str(exc))
468
469
470 async def schedule_formatting(
471     sources: Set[Path],
472     fast: bool,
473     write_back: WriteBack,
474     mode: FileMode,
475     report: "Report",
476     loop: BaseEventLoop,
477     executor: Executor,
478 ) -> None:
479     """Run formatting of `sources` in parallel using the provided `executor`.
480
481     (Use ProcessPoolExecutors for actual parallelism.)
482
483     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
484     :func:`format_file_in_place`.
485     """
486     cache: Cache = {}
487     if write_back != WriteBack.DIFF:
488         cache = read_cache(mode)
489         sources, cached = filter_cached(cache, sources)
490         for src in sorted(cached):
491             report.done(src, Changed.CACHED)
492     if not sources:
493         return
494
495     cancelled = []
496     sources_to_cache = []
497     lock = None
498     if write_back == WriteBack.DIFF:
499         # For diff output, we need locks to ensure we don't interleave output
500         # from different processes.
501         manager = Manager()
502         lock = manager.Lock()
503     tasks = {
504         loop.run_in_executor(
505             executor, format_file_in_place, src, fast, mode, write_back, lock
506         ): src
507         for src in sorted(sources)
508     }
509     pending: Iterable[asyncio.Task] = tasks.keys()
510     try:
511         loop.add_signal_handler(signal.SIGINT, cancel, pending)
512         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
513     except NotImplementedError:
514         # There are no good alternatives for these on Windows.
515         pass
516     while pending:
517         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
518         for task in done:
519             src = tasks.pop(task)
520             if task.cancelled():
521                 cancelled.append(task)
522             elif task.exception():
523                 report.failed(src, str(task.exception()))
524             else:
525                 changed = Changed.YES if task.result() else Changed.NO
526                 # If the file was written back or was successfully checked as
527                 # well-formatted, store this information in the cache.
528                 if write_back is WriteBack.YES or (
529                     write_back is WriteBack.CHECK and changed is Changed.NO
530                 ):
531                     sources_to_cache.append(src)
532                 report.done(src, changed)
533     if cancelled:
534         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
535     if sources_to_cache:
536         write_cache(cache, sources_to_cache, mode)
537
538
539 def format_file_in_place(
540     src: Path,
541     fast: bool,
542     mode: FileMode,
543     write_back: WriteBack = WriteBack.NO,
544     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
545 ) -> bool:
546     """Format file under `src` path. Return True if changed.
547
548     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
549     code to the file.
550     `line_length` and `fast` options are passed to :func:`format_file_contents`.
551     """
552     if src.suffix == ".pyi":
553         mode = evolve(mode, is_pyi=True)
554
555     then = datetime.utcfromtimestamp(src.stat().st_mtime)
556     with open(src, "rb") as buf:
557         src_contents, encoding, newline = decode_bytes(buf.read())
558     try:
559         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
560     except NothingChanged:
561         return False
562
563     if write_back == write_back.YES:
564         with open(src, "w", encoding=encoding, newline=newline) as f:
565             f.write(dst_contents)
566     elif write_back == write_back.DIFF:
567         now = datetime.utcnow()
568         src_name = f"{src}\t{then} +0000"
569         dst_name = f"{src}\t{now} +0000"
570         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
571         if lock:
572             lock.acquire()
573         try:
574             f = io.TextIOWrapper(
575                 sys.stdout.buffer,
576                 encoding=encoding,
577                 newline=newline,
578                 write_through=True,
579             )
580             f.write(diff_contents)
581             f.detach()
582         finally:
583             if lock:
584                 lock.release()
585     return True
586
587
588 def format_stdin_to_stdout(
589     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
590 ) -> bool:
591     """Format file on stdin. Return True if changed.
592
593     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
594     write a diff to stdout. The `mode` argument is passed to
595     :func:`format_file_contents`.
596     """
597     then = datetime.utcnow()
598     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
599     dst = src
600     try:
601         dst = format_file_contents(src, fast=fast, mode=mode)
602         return True
603
604     except NothingChanged:
605         return False
606
607     finally:
608         f = io.TextIOWrapper(
609             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
610         )
611         if write_back == WriteBack.YES:
612             f.write(dst)
613         elif write_back == WriteBack.DIFF:
614             now = datetime.utcnow()
615             src_name = f"STDIN\t{then} +0000"
616             dst_name = f"STDOUT\t{now} +0000"
617             f.write(diff(src, dst, src_name, dst_name))
618         f.detach()
619
620
621 def format_file_contents(
622     src_contents: str, *, fast: bool, mode: FileMode
623 ) -> FileContent:
624     """Reformat contents a file and return new contents.
625
626     If `fast` is False, additionally confirm that the reformatted code is
627     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
628     `line_length` is passed to :func:`format_str`.
629     """
630     if src_contents.strip() == "":
631         raise NothingChanged
632
633     dst_contents = format_str(src_contents, mode=mode)
634     if src_contents == dst_contents:
635         raise NothingChanged
636
637     if not fast:
638         assert_equivalent(src_contents, dst_contents)
639         assert_stable(src_contents, dst_contents, mode=mode)
640     return dst_contents
641
642
643 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
644     """Reformat a string and return new contents.
645
646     `line_length` determines how many characters per line are allowed.
647     """
648     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
649     dst_contents = ""
650     future_imports = get_future_imports(src_node)
651     if mode.target_versions:
652         versions = mode.target_versions
653     else:
654         versions = detect_target_versions(src_node)
655     normalize_fmt_off(src_node)
656     lines = LineGenerator(
657         remove_u_prefix="unicode_literals" in future_imports
658         or supports_feature(versions, Feature.UNICODE_LITERALS),
659         is_pyi=mode.is_pyi,
660         normalize_strings=mode.string_normalization,
661     )
662     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
663     empty_line = Line()
664     after = 0
665     for current_line in lines.visit(src_node):
666         for _ in range(after):
667             dst_contents += str(empty_line)
668         before, after = elt.maybe_empty_lines(current_line)
669         for _ in range(before):
670             dst_contents += str(empty_line)
671         for line in split_line(
672             current_line,
673             line_length=mode.line_length,
674             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
675         ):
676             dst_contents += str(line)
677     return dst_contents
678
679
680 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
681     """Return a tuple of (decoded_contents, encoding, newline).
682
683     `newline` is either CRLF or LF but `decoded_contents` is decoded with
684     universal newlines (i.e. only contains LF).
685     """
686     srcbuf = io.BytesIO(src)
687     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
688     if not lines:
689         return "", encoding, "\n"
690
691     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
692     srcbuf.seek(0)
693     with io.TextIOWrapper(srcbuf, encoding) as tiow:
694         return tiow.read(), encoding, newline
695
696
697 GRAMMARS = [
698     pygram.python_grammar_no_print_statement_no_exec_statement,
699     pygram.python_grammar_no_print_statement,
700     pygram.python_grammar,
701 ]
702
703
704 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
705     if not target_versions:
706         return GRAMMARS
707     elif all(not version.is_python2() for version in target_versions):
708         # Python 2-compatible code, so don't try Python 3 grammar.
709         return [
710             pygram.python_grammar_no_print_statement_no_exec_statement,
711             pygram.python_grammar_no_print_statement,
712         ]
713     else:
714         return [pygram.python_grammar]
715
716
717 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
718     """Given a string with source, return the lib2to3 Node."""
719     if src_txt[-1:] != "\n":
720         src_txt += "\n"
721
722     for grammar in get_grammars(set(target_versions)):
723         drv = driver.Driver(grammar, pytree.convert)
724         try:
725             result = drv.parse_string(src_txt, True)
726             break
727
728         except ParseError as pe:
729             lineno, column = pe.context[1]
730             lines = src_txt.splitlines()
731             try:
732                 faulty_line = lines[lineno - 1]
733             except IndexError:
734                 faulty_line = "<line number missing in source>"
735             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
736     else:
737         raise exc from None
738
739     if isinstance(result, Leaf):
740         result = Node(syms.file_input, [result])
741     return result
742
743
744 def lib2to3_unparse(node: Node) -> str:
745     """Given a lib2to3 node, return its string representation."""
746     code = str(node)
747     return code
748
749
750 T = TypeVar("T")
751
752
753 class Visitor(Generic[T]):
754     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
755
756     def visit(self, node: LN) -> Iterator[T]:
757         """Main method to visit `node` and its children.
758
759         It tries to find a `visit_*()` method for the given `node.type`, like
760         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
761         If no dedicated `visit_*()` method is found, chooses `visit_default()`
762         instead.
763
764         Then yields objects of type `T` from the selected visitor.
765         """
766         if node.type < 256:
767             name = token.tok_name[node.type]
768         else:
769             name = type_repr(node.type)
770         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
771
772     def visit_default(self, node: LN) -> Iterator[T]:
773         """Default `visit_*()` implementation. Recurses to children of `node`."""
774         if isinstance(node, Node):
775             for child in node.children:
776                 yield from self.visit(child)
777
778
779 @dataclass
780 class DebugVisitor(Visitor[T]):
781     tree_depth: int = 0
782
783     def visit_default(self, node: LN) -> Iterator[T]:
784         indent = " " * (2 * self.tree_depth)
785         if isinstance(node, Node):
786             _type = type_repr(node.type)
787             out(f"{indent}{_type}", fg="yellow")
788             self.tree_depth += 1
789             for child in node.children:
790                 yield from self.visit(child)
791
792             self.tree_depth -= 1
793             out(f"{indent}/{_type}", fg="yellow", bold=False)
794         else:
795             _type = token.tok_name.get(node.type, str(node.type))
796             out(f"{indent}{_type}", fg="blue", nl=False)
797             if node.prefix:
798                 # We don't have to handle prefixes for `Node` objects since
799                 # that delegates to the first child anyway.
800                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
801             out(f" {node.value!r}", fg="blue", bold=False)
802
803     @classmethod
804     def show(cls, code: Union[str, Leaf, Node]) -> None:
805         """Pretty-print the lib2to3 AST of a given string of `code`.
806
807         Convenience method for debugging.
808         """
809         v: DebugVisitor[None] = DebugVisitor()
810         if isinstance(code, str):
811             code = lib2to3_parse(code)
812         list(v.visit(code))
813
814
815 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
816 STATEMENT = {
817     syms.if_stmt,
818     syms.while_stmt,
819     syms.for_stmt,
820     syms.try_stmt,
821     syms.except_clause,
822     syms.with_stmt,
823     syms.funcdef,
824     syms.classdef,
825 }
826 STANDALONE_COMMENT = 153
827 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
828 LOGIC_OPERATORS = {"and", "or"}
829 COMPARATORS = {
830     token.LESS,
831     token.GREATER,
832     token.EQEQUAL,
833     token.NOTEQUAL,
834     token.LESSEQUAL,
835     token.GREATEREQUAL,
836 }
837 MATH_OPERATORS = {
838     token.VBAR,
839     token.CIRCUMFLEX,
840     token.AMPER,
841     token.LEFTSHIFT,
842     token.RIGHTSHIFT,
843     token.PLUS,
844     token.MINUS,
845     token.STAR,
846     token.SLASH,
847     token.DOUBLESLASH,
848     token.PERCENT,
849     token.AT,
850     token.TILDE,
851     token.DOUBLESTAR,
852 }
853 STARS = {token.STAR, token.DOUBLESTAR}
854 VARARGS_PARENTS = {
855     syms.arglist,
856     syms.argument,  # double star in arglist
857     syms.trailer,  # single argument to call
858     syms.typedargslist,
859     syms.varargslist,  # lambdas
860 }
861 UNPACKING_PARENTS = {
862     syms.atom,  # single element of a list or set literal
863     syms.dictsetmaker,
864     syms.listmaker,
865     syms.testlist_gexp,
866     syms.testlist_star_expr,
867 }
868 TEST_DESCENDANTS = {
869     syms.test,
870     syms.lambdef,
871     syms.or_test,
872     syms.and_test,
873     syms.not_test,
874     syms.comparison,
875     syms.star_expr,
876     syms.expr,
877     syms.xor_expr,
878     syms.and_expr,
879     syms.shift_expr,
880     syms.arith_expr,
881     syms.trailer,
882     syms.term,
883     syms.power,
884 }
885 ASSIGNMENTS = {
886     "=",
887     "+=",
888     "-=",
889     "*=",
890     "@=",
891     "/=",
892     "%=",
893     "&=",
894     "|=",
895     "^=",
896     "<<=",
897     ">>=",
898     "**=",
899     "//=",
900 }
901 COMPREHENSION_PRIORITY = 20
902 COMMA_PRIORITY = 18
903 TERNARY_PRIORITY = 16
904 LOGIC_PRIORITY = 14
905 STRING_PRIORITY = 12
906 COMPARATOR_PRIORITY = 10
907 MATH_PRIORITIES = {
908     token.VBAR: 9,
909     token.CIRCUMFLEX: 8,
910     token.AMPER: 7,
911     token.LEFTSHIFT: 6,
912     token.RIGHTSHIFT: 6,
913     token.PLUS: 5,
914     token.MINUS: 5,
915     token.STAR: 4,
916     token.SLASH: 4,
917     token.DOUBLESLASH: 4,
918     token.PERCENT: 4,
919     token.AT: 4,
920     token.TILDE: 3,
921     token.DOUBLESTAR: 2,
922 }
923 DOT_PRIORITY = 1
924
925
926 @dataclass
927 class BracketTracker:
928     """Keeps track of brackets on a line."""
929
930     depth: int = 0
931     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
932     delimiters: Dict[LeafID, Priority] = Factory(dict)
933     previous: Optional[Leaf] = None
934     _for_loop_depths: List[int] = Factory(list)
935     _lambda_argument_depths: List[int] = Factory(list)
936
937     def mark(self, leaf: Leaf) -> None:
938         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
939
940         All leaves receive an int `bracket_depth` field that stores how deep
941         within brackets a given leaf is. 0 means there are no enclosing brackets
942         that started on this line.
943
944         If a leaf is itself a closing bracket, it receives an `opening_bracket`
945         field that it forms a pair with. This is a one-directional link to
946         avoid reference cycles.
947
948         If a leaf is a delimiter (a token on which Black can split the line if
949         needed) and it's on depth 0, its `id()` is stored in the tracker's
950         `delimiters` field.
951         """
952         if leaf.type == token.COMMENT:
953             return
954
955         self.maybe_decrement_after_for_loop_variable(leaf)
956         self.maybe_decrement_after_lambda_arguments(leaf)
957         if leaf.type in CLOSING_BRACKETS:
958             self.depth -= 1
959             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
960             leaf.opening_bracket = opening_bracket
961         leaf.bracket_depth = self.depth
962         if self.depth == 0:
963             delim = is_split_before_delimiter(leaf, self.previous)
964             if delim and self.previous is not None:
965                 self.delimiters[id(self.previous)] = delim
966             else:
967                 delim = is_split_after_delimiter(leaf, self.previous)
968                 if delim:
969                     self.delimiters[id(leaf)] = delim
970         if leaf.type in OPENING_BRACKETS:
971             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
972             self.depth += 1
973         self.previous = leaf
974         self.maybe_increment_lambda_arguments(leaf)
975         self.maybe_increment_for_loop_variable(leaf)
976
977     def any_open_brackets(self) -> bool:
978         """Return True if there is an yet unmatched open bracket on the line."""
979         return bool(self.bracket_match)
980
981     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
982         """Return the highest priority of a delimiter found on the line.
983
984         Values are consistent with what `is_split_*_delimiter()` return.
985         Raises ValueError on no delimiters.
986         """
987         return max(v for k, v in self.delimiters.items() if k not in exclude)
988
989     def delimiter_count_with_priority(self, priority: int = 0) -> int:
990         """Return the number of delimiters with the given `priority`.
991
992         If no `priority` is passed, defaults to max priority on the line.
993         """
994         if not self.delimiters:
995             return 0
996
997         priority = priority or self.max_delimiter_priority()
998         return sum(1 for p in self.delimiters.values() if p == priority)
999
1000     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1001         """In a for loop, or comprehension, the variables are often unpacks.
1002
1003         To avoid splitting on the comma in this situation, increase the depth of
1004         tokens between `for` and `in`.
1005         """
1006         if leaf.type == token.NAME and leaf.value == "for":
1007             self.depth += 1
1008             self._for_loop_depths.append(self.depth)
1009             return True
1010
1011         return False
1012
1013     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1014         """See `maybe_increment_for_loop_variable` above for explanation."""
1015         if (
1016             self._for_loop_depths
1017             and self._for_loop_depths[-1] == self.depth
1018             and leaf.type == token.NAME
1019             and leaf.value == "in"
1020         ):
1021             self.depth -= 1
1022             self._for_loop_depths.pop()
1023             return True
1024
1025         return False
1026
1027     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1028         """In a lambda expression, there might be more than one argument.
1029
1030         To avoid splitting on the comma in this situation, increase the depth of
1031         tokens between `lambda` and `:`.
1032         """
1033         if leaf.type == token.NAME and leaf.value == "lambda":
1034             self.depth += 1
1035             self._lambda_argument_depths.append(self.depth)
1036             return True
1037
1038         return False
1039
1040     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1041         """See `maybe_increment_lambda_arguments` above for explanation."""
1042         if (
1043             self._lambda_argument_depths
1044             and self._lambda_argument_depths[-1] == self.depth
1045             and leaf.type == token.COLON
1046         ):
1047             self.depth -= 1
1048             self._lambda_argument_depths.pop()
1049             return True
1050
1051         return False
1052
1053     def get_open_lsqb(self) -> Optional[Leaf]:
1054         """Return the most recent opening square bracket (if any)."""
1055         return self.bracket_match.get((self.depth - 1, token.RSQB))
1056
1057
1058 @dataclass
1059 class Line:
1060     """Holds leaves and comments. Can be printed with `str(line)`."""
1061
1062     depth: int = 0
1063     leaves: List[Leaf] = Factory(list)
1064     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1065     bracket_tracker: BracketTracker = Factory(BracketTracker)
1066     inside_brackets: bool = False
1067     should_explode: bool = False
1068
1069     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1070         """Add a new `leaf` to the end of the line.
1071
1072         Unless `preformatted` is True, the `leaf` will receive a new consistent
1073         whitespace prefix and metadata applied by :class:`BracketTracker`.
1074         Trailing commas are maybe removed, unpacked for loop variables are
1075         demoted from being delimiters.
1076
1077         Inline comments are put aside.
1078         """
1079         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1080         if not has_value:
1081             return
1082
1083         if token.COLON == leaf.type and self.is_class_paren_empty:
1084             del self.leaves[-2:]
1085         if self.leaves and not preformatted:
1086             # Note: at this point leaf.prefix should be empty except for
1087             # imports, for which we only preserve newlines.
1088             leaf.prefix += whitespace(
1089                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1090             )
1091         if self.inside_brackets or not preformatted:
1092             self.bracket_tracker.mark(leaf)
1093             self.maybe_remove_trailing_comma(leaf)
1094         if not self.append_comment(leaf):
1095             self.leaves.append(leaf)
1096
1097     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1098         """Like :func:`append()` but disallow invalid standalone comment structure.
1099
1100         Raises ValueError when any `leaf` is appended after a standalone comment
1101         or when a standalone comment is not the first leaf on the line.
1102         """
1103         if self.bracket_tracker.depth == 0:
1104             if self.is_comment:
1105                 raise ValueError("cannot append to standalone comments")
1106
1107             if self.leaves and leaf.type == STANDALONE_COMMENT:
1108                 raise ValueError(
1109                     "cannot append standalone comments to a populated line"
1110                 )
1111
1112         self.append(leaf, preformatted=preformatted)
1113
1114     @property
1115     def is_comment(self) -> bool:
1116         """Is this line a standalone comment?"""
1117         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1118
1119     @property
1120     def is_decorator(self) -> bool:
1121         """Is this line a decorator?"""
1122         return bool(self) and self.leaves[0].type == token.AT
1123
1124     @property
1125     def is_import(self) -> bool:
1126         """Is this an import line?"""
1127         return bool(self) and is_import(self.leaves[0])
1128
1129     @property
1130     def is_class(self) -> bool:
1131         """Is this line a class definition?"""
1132         return (
1133             bool(self)
1134             and self.leaves[0].type == token.NAME
1135             and self.leaves[0].value == "class"
1136         )
1137
1138     @property
1139     def is_stub_class(self) -> bool:
1140         """Is this line a class definition with a body consisting only of "..."?"""
1141         return self.is_class and self.leaves[-3:] == [
1142             Leaf(token.DOT, ".") for _ in range(3)
1143         ]
1144
1145     @property
1146     def is_def(self) -> bool:
1147         """Is this a function definition? (Also returns True for async defs.)"""
1148         try:
1149             first_leaf = self.leaves[0]
1150         except IndexError:
1151             return False
1152
1153         try:
1154             second_leaf: Optional[Leaf] = self.leaves[1]
1155         except IndexError:
1156             second_leaf = None
1157         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1158             first_leaf.type == token.ASYNC
1159             and second_leaf is not None
1160             and second_leaf.type == token.NAME
1161             and second_leaf.value == "def"
1162         )
1163
1164     @property
1165     def is_class_paren_empty(self) -> bool:
1166         """Is this a class with no base classes but using parentheses?
1167
1168         Those are unnecessary and should be removed.
1169         """
1170         return (
1171             bool(self)
1172             and len(self.leaves) == 4
1173             and self.is_class
1174             and self.leaves[2].type == token.LPAR
1175             and self.leaves[2].value == "("
1176             and self.leaves[3].type == token.RPAR
1177             and self.leaves[3].value == ")"
1178         )
1179
1180     @property
1181     def is_triple_quoted_string(self) -> bool:
1182         """Is the line a triple quoted string?"""
1183         return (
1184             bool(self)
1185             and self.leaves[0].type == token.STRING
1186             and self.leaves[0].value.startswith(('"""', "'''"))
1187         )
1188
1189     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1190         """If so, needs to be split before emitting."""
1191         for leaf in self.leaves:
1192             if leaf.type == STANDALONE_COMMENT:
1193                 if leaf.bracket_depth <= depth_limit:
1194                     return True
1195
1196         return False
1197
1198     def contains_multiline_strings(self) -> bool:
1199         for leaf in self.leaves:
1200             if is_multiline_string(leaf):
1201                 return True
1202
1203         return False
1204
1205     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1206         """Remove trailing comma if there is one and it's safe."""
1207         if not (
1208             self.leaves
1209             and self.leaves[-1].type == token.COMMA
1210             and closing.type in CLOSING_BRACKETS
1211         ):
1212             return False
1213
1214         if closing.type == token.RBRACE:
1215             self.remove_trailing_comma()
1216             return True
1217
1218         if closing.type == token.RSQB:
1219             comma = self.leaves[-1]
1220             if comma.parent and comma.parent.type == syms.listmaker:
1221                 self.remove_trailing_comma()
1222                 return True
1223
1224         # For parens let's check if it's safe to remove the comma.
1225         # Imports are always safe.
1226         if self.is_import:
1227             self.remove_trailing_comma()
1228             return True
1229
1230         # Otherwise, if the trailing one is the only one, we might mistakenly
1231         # change a tuple into a different type by removing the comma.
1232         depth = closing.bracket_depth + 1
1233         commas = 0
1234         opening = closing.opening_bracket
1235         for _opening_index, leaf in enumerate(self.leaves):
1236             if leaf is opening:
1237                 break
1238
1239         else:
1240             return False
1241
1242         for leaf in self.leaves[_opening_index + 1 :]:
1243             if leaf is closing:
1244                 break
1245
1246             bracket_depth = leaf.bracket_depth
1247             if bracket_depth == depth and leaf.type == token.COMMA:
1248                 commas += 1
1249                 if leaf.parent and leaf.parent.type == syms.arglist:
1250                     commas += 1
1251                     break
1252
1253         if commas > 1:
1254             self.remove_trailing_comma()
1255             return True
1256
1257         return False
1258
1259     def append_comment(self, comment: Leaf) -> bool:
1260         """Add an inline or standalone comment to the line."""
1261         if (
1262             comment.type == STANDALONE_COMMENT
1263             and self.bracket_tracker.any_open_brackets()
1264         ):
1265             comment.prefix = ""
1266             return False
1267
1268         if comment.type != token.COMMENT:
1269             return False
1270
1271         if not self.leaves:
1272             comment.type = STANDALONE_COMMENT
1273             comment.prefix = ""
1274             return False
1275
1276         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1277         return True
1278
1279     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1280         """Generate comments that should appear directly after `leaf`."""
1281         return self.comments.get(id(leaf), [])
1282
1283     def remove_trailing_comma(self) -> None:
1284         """Remove the trailing comma and moves the comments attached to it."""
1285         trailing_comma = self.leaves.pop()
1286         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1287         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1288             trailing_comma_comments
1289         )
1290
1291     def is_complex_subscript(self, leaf: Leaf) -> bool:
1292         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1293         open_lsqb = self.bracket_tracker.get_open_lsqb()
1294         if open_lsqb is None:
1295             return False
1296
1297         subscript_start = open_lsqb.next_sibling
1298
1299         if isinstance(subscript_start, Node):
1300             if subscript_start.type == syms.listmaker:
1301                 return False
1302
1303             if subscript_start.type == syms.subscriptlist:
1304                 subscript_start = child_towards(subscript_start, leaf)
1305         return subscript_start is not None and any(
1306             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1307         )
1308
1309     def __str__(self) -> str:
1310         """Render the line."""
1311         if not self:
1312             return "\n"
1313
1314         indent = "    " * self.depth
1315         leaves = iter(self.leaves)
1316         first = next(leaves)
1317         res = f"{first.prefix}{indent}{first.value}"
1318         for leaf in leaves:
1319             res += str(leaf)
1320         for comment in itertools.chain.from_iterable(self.comments.values()):
1321             res += str(comment)
1322         return res + "\n"
1323
1324     def __bool__(self) -> bool:
1325         """Return True if the line has leaves or comments."""
1326         return bool(self.leaves or self.comments)
1327
1328
1329 @dataclass
1330 class EmptyLineTracker:
1331     """Provides a stateful method that returns the number of potential extra
1332     empty lines needed before and after the currently processed line.
1333
1334     Note: this tracker works on lines that haven't been split yet.  It assumes
1335     the prefix of the first leaf consists of optional newlines.  Those newlines
1336     are consumed by `maybe_empty_lines()` and included in the computation.
1337     """
1338
1339     is_pyi: bool = False
1340     previous_line: Optional[Line] = None
1341     previous_after: int = 0
1342     previous_defs: List[int] = Factory(list)
1343
1344     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1345         """Return the number of extra empty lines before and after the `current_line`.
1346
1347         This is for separating `def`, `async def` and `class` with extra empty
1348         lines (two on module-level).
1349         """
1350         before, after = self._maybe_empty_lines(current_line)
1351         before -= self.previous_after
1352         self.previous_after = after
1353         self.previous_line = current_line
1354         return before, after
1355
1356     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1357         max_allowed = 1
1358         if current_line.depth == 0:
1359             max_allowed = 1 if self.is_pyi else 2
1360         if current_line.leaves:
1361             # Consume the first leaf's extra newlines.
1362             first_leaf = current_line.leaves[0]
1363             before = first_leaf.prefix.count("\n")
1364             before = min(before, max_allowed)
1365             first_leaf.prefix = ""
1366         else:
1367             before = 0
1368         depth = current_line.depth
1369         while self.previous_defs and self.previous_defs[-1] >= depth:
1370             self.previous_defs.pop()
1371             if self.is_pyi:
1372                 before = 0 if depth else 1
1373             else:
1374                 before = 1 if depth else 2
1375         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1376             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1377
1378         if (
1379             self.previous_line
1380             and self.previous_line.is_import
1381             and not current_line.is_import
1382             and depth == self.previous_line.depth
1383         ):
1384             return (before or 1), 0
1385
1386         if (
1387             self.previous_line
1388             and self.previous_line.is_class
1389             and current_line.is_triple_quoted_string
1390         ):
1391             return before, 1
1392
1393         return before, 0
1394
1395     def _maybe_empty_lines_for_class_or_def(
1396         self, current_line: Line, before: int
1397     ) -> Tuple[int, int]:
1398         if not current_line.is_decorator:
1399             self.previous_defs.append(current_line.depth)
1400         if self.previous_line is None:
1401             # Don't insert empty lines before the first line in the file.
1402             return 0, 0
1403
1404         if self.previous_line.is_decorator:
1405             return 0, 0
1406
1407         if self.previous_line.depth < current_line.depth and (
1408             self.previous_line.is_class or self.previous_line.is_def
1409         ):
1410             return 0, 0
1411
1412         if (
1413             self.previous_line.is_comment
1414             and self.previous_line.depth == current_line.depth
1415             and before == 0
1416         ):
1417             return 0, 0
1418
1419         if self.is_pyi:
1420             if self.previous_line.depth > current_line.depth:
1421                 newlines = 1
1422             elif current_line.is_class or self.previous_line.is_class:
1423                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1424                     # No blank line between classes with an empty body
1425                     newlines = 0
1426                 else:
1427                     newlines = 1
1428             elif current_line.is_def and not self.previous_line.is_def:
1429                 # Blank line between a block of functions and a block of non-functions
1430                 newlines = 1
1431             else:
1432                 newlines = 0
1433         else:
1434             newlines = 2
1435         if current_line.depth and newlines:
1436             newlines -= 1
1437         return newlines, 0
1438
1439
1440 @dataclass
1441 class LineGenerator(Visitor[Line]):
1442     """Generates reformatted Line objects.  Empty lines are not emitted.
1443
1444     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1445     in ways that will no longer stringify to valid Python code on the tree.
1446     """
1447
1448     is_pyi: bool = False
1449     normalize_strings: bool = True
1450     current_line: Line = Factory(Line)
1451     remove_u_prefix: bool = False
1452
1453     def line(self, indent: int = 0) -> Iterator[Line]:
1454         """Generate a line.
1455
1456         If the line is empty, only emit if it makes sense.
1457         If the line is too long, split it first and then generate.
1458
1459         If any lines were generated, set up a new current_line.
1460         """
1461         if not self.current_line:
1462             self.current_line.depth += indent
1463             return  # Line is empty, don't emit. Creating a new one unnecessary.
1464
1465         complete_line = self.current_line
1466         self.current_line = Line(depth=complete_line.depth + indent)
1467         yield complete_line
1468
1469     def visit_default(self, node: LN) -> Iterator[Line]:
1470         """Default `visit_*()` implementation. Recurses to children of `node`."""
1471         if isinstance(node, Leaf):
1472             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1473             for comment in generate_comments(node):
1474                 if any_open_brackets:
1475                     # any comment within brackets is subject to splitting
1476                     self.current_line.append(comment)
1477                 elif comment.type == token.COMMENT:
1478                     # regular trailing comment
1479                     self.current_line.append(comment)
1480                     yield from self.line()
1481
1482                 else:
1483                     # regular standalone comment
1484                     yield from self.line()
1485
1486                     self.current_line.append(comment)
1487                     yield from self.line()
1488
1489             normalize_prefix(node, inside_brackets=any_open_brackets)
1490             if self.normalize_strings and node.type == token.STRING:
1491                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1492                 normalize_string_quotes(node)
1493             if node.type == token.NUMBER:
1494                 normalize_numeric_literal(node)
1495             if node.type not in WHITESPACE:
1496                 self.current_line.append(node)
1497         yield from super().visit_default(node)
1498
1499     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1500         """Increase indentation level, maybe yield a line."""
1501         # In blib2to3 INDENT never holds comments.
1502         yield from self.line(+1)
1503         yield from self.visit_default(node)
1504
1505     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1506         """Decrease indentation level, maybe yield a line."""
1507         # The current line might still wait for trailing comments.  At DEDENT time
1508         # there won't be any (they would be prefixes on the preceding NEWLINE).
1509         # Emit the line then.
1510         yield from self.line()
1511
1512         # While DEDENT has no value, its prefix may contain standalone comments
1513         # that belong to the current indentation level.  Get 'em.
1514         yield from self.visit_default(node)
1515
1516         # Finally, emit the dedent.
1517         yield from self.line(-1)
1518
1519     def visit_stmt(
1520         self, node: Node, keywords: Set[str], parens: Set[str]
1521     ) -> Iterator[Line]:
1522         """Visit a statement.
1523
1524         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1525         `def`, `with`, `class`, `assert` and assignments.
1526
1527         The relevant Python language `keywords` for a given statement will be
1528         NAME leaves within it. This methods puts those on a separate line.
1529
1530         `parens` holds a set of string leaf values immediately after which
1531         invisible parens should be put.
1532         """
1533         normalize_invisible_parens(node, parens_after=parens)
1534         for child in node.children:
1535             if child.type == token.NAME and child.value in keywords:  # type: ignore
1536                 yield from self.line()
1537
1538             yield from self.visit(child)
1539
1540     def visit_suite(self, node: Node) -> Iterator[Line]:
1541         """Visit a suite."""
1542         if self.is_pyi and is_stub_suite(node):
1543             yield from self.visit(node.children[2])
1544         else:
1545             yield from self.visit_default(node)
1546
1547     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1548         """Visit a statement without nested statements."""
1549         is_suite_like = node.parent and node.parent.type in STATEMENT
1550         if is_suite_like:
1551             if self.is_pyi and is_stub_body(node):
1552                 yield from self.visit_default(node)
1553             else:
1554                 yield from self.line(+1)
1555                 yield from self.visit_default(node)
1556                 yield from self.line(-1)
1557
1558         else:
1559             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1560                 yield from self.line()
1561             yield from self.visit_default(node)
1562
1563     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1564         """Visit `async def`, `async for`, `async with`."""
1565         yield from self.line()
1566
1567         children = iter(node.children)
1568         for child in children:
1569             yield from self.visit(child)
1570
1571             if child.type == token.ASYNC:
1572                 break
1573
1574         internal_stmt = next(children)
1575         for child in internal_stmt.children:
1576             yield from self.visit(child)
1577
1578     def visit_decorators(self, node: Node) -> Iterator[Line]:
1579         """Visit decorators."""
1580         for child in node.children:
1581             yield from self.line()
1582             yield from self.visit(child)
1583
1584     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1585         """Remove a semicolon and put the other statement on a separate line."""
1586         yield from self.line()
1587
1588     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1589         """End of file. Process outstanding comments and end with a newline."""
1590         yield from self.visit_default(leaf)
1591         yield from self.line()
1592
1593     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1594         if not self.current_line.bracket_tracker.any_open_brackets():
1595             yield from self.line()
1596         yield from self.visit_default(leaf)
1597
1598     def __attrs_post_init__(self) -> None:
1599         """You are in a twisty little maze of passages."""
1600         v = self.visit_stmt
1601         Ø: Set[str] = set()
1602         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1603         self.visit_if_stmt = partial(
1604             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1605         )
1606         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1607         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1608         self.visit_try_stmt = partial(
1609             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1610         )
1611         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1612         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1613         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1614         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1615         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1616         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1617         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1618         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1619         self.visit_async_funcdef = self.visit_async_stmt
1620         self.visit_decorated = self.visit_decorators
1621
1622
1623 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1624 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1625 OPENING_BRACKETS = set(BRACKET.keys())
1626 CLOSING_BRACKETS = set(BRACKET.values())
1627 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1628 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1629
1630
1631 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1632     """Return whitespace prefix if needed for the given `leaf`.
1633
1634     `complex_subscript` signals whether the given leaf is part of a subscription
1635     which has non-trivial arguments, like arithmetic expressions or function calls.
1636     """
1637     NO = ""
1638     SPACE = " "
1639     DOUBLESPACE = "  "
1640     t = leaf.type
1641     p = leaf.parent
1642     v = leaf.value
1643     if t in ALWAYS_NO_SPACE:
1644         return NO
1645
1646     if t == token.COMMENT:
1647         return DOUBLESPACE
1648
1649     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1650     if t == token.COLON and p.type not in {
1651         syms.subscript,
1652         syms.subscriptlist,
1653         syms.sliceop,
1654     }:
1655         return NO
1656
1657     prev = leaf.prev_sibling
1658     if not prev:
1659         prevp = preceding_leaf(p)
1660         if not prevp or prevp.type in OPENING_BRACKETS:
1661             return NO
1662
1663         if t == token.COLON:
1664             if prevp.type == token.COLON:
1665                 return NO
1666
1667             elif prevp.type != token.COMMA and not complex_subscript:
1668                 return NO
1669
1670             return SPACE
1671
1672         if prevp.type == token.EQUAL:
1673             if prevp.parent:
1674                 if prevp.parent.type in {
1675                     syms.arglist,
1676                     syms.argument,
1677                     syms.parameters,
1678                     syms.varargslist,
1679                 }:
1680                     return NO
1681
1682                 elif prevp.parent.type == syms.typedargslist:
1683                     # A bit hacky: if the equal sign has whitespace, it means we
1684                     # previously found it's a typed argument.  So, we're using
1685                     # that, too.
1686                     return prevp.prefix
1687
1688         elif prevp.type in STARS:
1689             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1690                 return NO
1691
1692         elif prevp.type == token.COLON:
1693             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1694                 return SPACE if complex_subscript else NO
1695
1696         elif (
1697             prevp.parent
1698             and prevp.parent.type == syms.factor
1699             and prevp.type in MATH_OPERATORS
1700         ):
1701             return NO
1702
1703         elif (
1704             prevp.type == token.RIGHTSHIFT
1705             and prevp.parent
1706             and prevp.parent.type == syms.shift_expr
1707             and prevp.prev_sibling
1708             and prevp.prev_sibling.type == token.NAME
1709             and prevp.prev_sibling.value == "print"  # type: ignore
1710         ):
1711             # Python 2 print chevron
1712             return NO
1713
1714     elif prev.type in OPENING_BRACKETS:
1715         return NO
1716
1717     if p.type in {syms.parameters, syms.arglist}:
1718         # untyped function signatures or calls
1719         if not prev or prev.type != token.COMMA:
1720             return NO
1721
1722     elif p.type == syms.varargslist:
1723         # lambdas
1724         if prev and prev.type != token.COMMA:
1725             return NO
1726
1727     elif p.type == syms.typedargslist:
1728         # typed function signatures
1729         if not prev:
1730             return NO
1731
1732         if t == token.EQUAL:
1733             if prev.type != syms.tname:
1734                 return NO
1735
1736         elif prev.type == token.EQUAL:
1737             # A bit hacky: if the equal sign has whitespace, it means we
1738             # previously found it's a typed argument.  So, we're using that, too.
1739             return prev.prefix
1740
1741         elif prev.type != token.COMMA:
1742             return NO
1743
1744     elif p.type == syms.tname:
1745         # type names
1746         if not prev:
1747             prevp = preceding_leaf(p)
1748             if not prevp or prevp.type != token.COMMA:
1749                 return NO
1750
1751     elif p.type == syms.trailer:
1752         # attributes and calls
1753         if t == token.LPAR or t == token.RPAR:
1754             return NO
1755
1756         if not prev:
1757             if t == token.DOT:
1758                 prevp = preceding_leaf(p)
1759                 if not prevp or prevp.type != token.NUMBER:
1760                     return NO
1761
1762             elif t == token.LSQB:
1763                 return NO
1764
1765         elif prev.type != token.COMMA:
1766             return NO
1767
1768     elif p.type == syms.argument:
1769         # single argument
1770         if t == token.EQUAL:
1771             return NO
1772
1773         if not prev:
1774             prevp = preceding_leaf(p)
1775             if not prevp or prevp.type == token.LPAR:
1776                 return NO
1777
1778         elif prev.type in {token.EQUAL} | STARS:
1779             return NO
1780
1781     elif p.type == syms.decorator:
1782         # decorators
1783         return NO
1784
1785     elif p.type == syms.dotted_name:
1786         if prev:
1787             return NO
1788
1789         prevp = preceding_leaf(p)
1790         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1791             return NO
1792
1793     elif p.type == syms.classdef:
1794         if t == token.LPAR:
1795             return NO
1796
1797         if prev and prev.type == token.LPAR:
1798             return NO
1799
1800     elif p.type in {syms.subscript, syms.sliceop}:
1801         # indexing
1802         if not prev:
1803             assert p.parent is not None, "subscripts are always parented"
1804             if p.parent.type == syms.subscriptlist:
1805                 return SPACE
1806
1807             return NO
1808
1809         elif not complex_subscript:
1810             return NO
1811
1812     elif p.type == syms.atom:
1813         if prev and t == token.DOT:
1814             # dots, but not the first one.
1815             return NO
1816
1817     elif p.type == syms.dictsetmaker:
1818         # dict unpacking
1819         if prev and prev.type == token.DOUBLESTAR:
1820             return NO
1821
1822     elif p.type in {syms.factor, syms.star_expr}:
1823         # unary ops
1824         if not prev:
1825             prevp = preceding_leaf(p)
1826             if not prevp or prevp.type in OPENING_BRACKETS:
1827                 return NO
1828
1829             prevp_parent = prevp.parent
1830             assert prevp_parent is not None
1831             if prevp.type == token.COLON and prevp_parent.type in {
1832                 syms.subscript,
1833                 syms.sliceop,
1834             }:
1835                 return NO
1836
1837             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1838                 return NO
1839
1840         elif t in {token.NAME, token.NUMBER, token.STRING}:
1841             return NO
1842
1843     elif p.type == syms.import_from:
1844         if t == token.DOT:
1845             if prev and prev.type == token.DOT:
1846                 return NO
1847
1848         elif t == token.NAME:
1849             if v == "import":
1850                 return SPACE
1851
1852             if prev and prev.type == token.DOT:
1853                 return NO
1854
1855     elif p.type == syms.sliceop:
1856         return NO
1857
1858     return SPACE
1859
1860
1861 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1862     """Return the first leaf that precedes `node`, if any."""
1863     while node:
1864         res = node.prev_sibling
1865         if res:
1866             if isinstance(res, Leaf):
1867                 return res
1868
1869             try:
1870                 return list(res.leaves())[-1]
1871
1872             except IndexError:
1873                 return None
1874
1875         node = node.parent
1876     return None
1877
1878
1879 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1880     """Return the child of `ancestor` that contains `descendant`."""
1881     node: Optional[LN] = descendant
1882     while node and node.parent != ancestor:
1883         node = node.parent
1884     return node
1885
1886
1887 def container_of(leaf: Leaf) -> LN:
1888     """Return `leaf` or one of its ancestors that is the topmost container of it.
1889
1890     By "container" we mean a node where `leaf` is the very first child.
1891     """
1892     same_prefix = leaf.prefix
1893     container: LN = leaf
1894     while container:
1895         parent = container.parent
1896         if parent is None:
1897             break
1898
1899         if parent.children[0].prefix != same_prefix:
1900             break
1901
1902         if parent.type == syms.file_input:
1903             break
1904
1905         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1906             break
1907
1908         container = parent
1909     return container
1910
1911
1912 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1913     """Return the priority of the `leaf` delimiter, given a line break after it.
1914
1915     The delimiter priorities returned here are from those delimiters that would
1916     cause a line break after themselves.
1917
1918     Higher numbers are higher priority.
1919     """
1920     if leaf.type == token.COMMA:
1921         return COMMA_PRIORITY
1922
1923     return 0
1924
1925
1926 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1927     """Return the priority of the `leaf` delimiter, given a line break before it.
1928
1929     The delimiter priorities returned here are from those delimiters that would
1930     cause a line break before themselves.
1931
1932     Higher numbers are higher priority.
1933     """
1934     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1935         # * and ** might also be MATH_OPERATORS but in this case they are not.
1936         # Don't treat them as a delimiter.
1937         return 0
1938
1939     if (
1940         leaf.type == token.DOT
1941         and leaf.parent
1942         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1943         and (previous is None or previous.type in CLOSING_BRACKETS)
1944     ):
1945         return DOT_PRIORITY
1946
1947     if (
1948         leaf.type in MATH_OPERATORS
1949         and leaf.parent
1950         and leaf.parent.type not in {syms.factor, syms.star_expr}
1951     ):
1952         return MATH_PRIORITIES[leaf.type]
1953
1954     if leaf.type in COMPARATORS:
1955         return COMPARATOR_PRIORITY
1956
1957     if (
1958         leaf.type == token.STRING
1959         and previous is not None
1960         and previous.type == token.STRING
1961     ):
1962         return STRING_PRIORITY
1963
1964     if leaf.type not in {token.NAME, token.ASYNC}:
1965         return 0
1966
1967     if (
1968         leaf.value == "for"
1969         and leaf.parent
1970         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1971         or leaf.type == token.ASYNC
1972     ):
1973         if (
1974             not isinstance(leaf.prev_sibling, Leaf)
1975             or leaf.prev_sibling.value != "async"
1976         ):
1977             return COMPREHENSION_PRIORITY
1978
1979     if (
1980         leaf.value == "if"
1981         and leaf.parent
1982         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1983     ):
1984         return COMPREHENSION_PRIORITY
1985
1986     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1987         return TERNARY_PRIORITY
1988
1989     if leaf.value == "is":
1990         return COMPARATOR_PRIORITY
1991
1992     if (
1993         leaf.value == "in"
1994         and leaf.parent
1995         and leaf.parent.type in {syms.comp_op, syms.comparison}
1996         and not (
1997             previous is not None
1998             and previous.type == token.NAME
1999             and previous.value == "not"
2000         )
2001     ):
2002         return COMPARATOR_PRIORITY
2003
2004     if (
2005         leaf.value == "not"
2006         and leaf.parent
2007         and leaf.parent.type == syms.comp_op
2008         and not (
2009             previous is not None
2010             and previous.type == token.NAME
2011             and previous.value == "is"
2012         )
2013     ):
2014         return COMPARATOR_PRIORITY
2015
2016     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2017         return LOGIC_PRIORITY
2018
2019     return 0
2020
2021
2022 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2023 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2024
2025
2026 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2027     """Clean the prefix of the `leaf` and generate comments from it, if any.
2028
2029     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2030     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2031     move because it does away with modifying the grammar to include all the
2032     possible places in which comments can be placed.
2033
2034     The sad consequence for us though is that comments don't "belong" anywhere.
2035     This is why this function generates simple parentless Leaf objects for
2036     comments.  We simply don't know what the correct parent should be.
2037
2038     No matter though, we can live without this.  We really only need to
2039     differentiate between inline and standalone comments.  The latter don't
2040     share the line with any code.
2041
2042     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2043     are emitted with a fake STANDALONE_COMMENT token identifier.
2044     """
2045     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2046         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2047
2048
2049 @dataclass
2050 class ProtoComment:
2051     """Describes a piece of syntax that is a comment.
2052
2053     It's not a :class:`blib2to3.pytree.Leaf` so that:
2054
2055     * it can be cached (`Leaf` objects should not be reused more than once as
2056       they store their lineno, column, prefix, and parent information);
2057     * `newlines` and `consumed` fields are kept separate from the `value`. This
2058       simplifies handling of special marker comments like ``# fmt: off/on``.
2059     """
2060
2061     type: int  # token.COMMENT or STANDALONE_COMMENT
2062     value: str  # content of the comment
2063     newlines: int  # how many newlines before the comment
2064     consumed: int  # how many characters of the original leaf's prefix did we consume
2065
2066
2067 @lru_cache(maxsize=4096)
2068 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2069     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2070     result: List[ProtoComment] = []
2071     if not prefix or "#" not in prefix:
2072         return result
2073
2074     consumed = 0
2075     nlines = 0
2076     for index, line in enumerate(prefix.split("\n")):
2077         consumed += len(line) + 1  # adding the length of the split '\n'
2078         line = line.lstrip()
2079         if not line:
2080             nlines += 1
2081         if not line.startswith("#"):
2082             continue
2083
2084         if index == 0 and not is_endmarker:
2085             comment_type = token.COMMENT  # simple trailing comment
2086         else:
2087             comment_type = STANDALONE_COMMENT
2088         comment = make_comment(line)
2089         result.append(
2090             ProtoComment(
2091                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2092             )
2093         )
2094         nlines = 0
2095     return result
2096
2097
2098 def make_comment(content: str) -> str:
2099     """Return a consistently formatted comment from the given `content` string.
2100
2101     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2102     space between the hash sign and the content.
2103
2104     If `content` didn't start with a hash sign, one is provided.
2105     """
2106     content = content.rstrip()
2107     if not content:
2108         return "#"
2109
2110     if content[0] == "#":
2111         content = content[1:]
2112     if content and content[0] not in " !:#'%":
2113         content = " " + content
2114     return "#" + content
2115
2116
2117 def split_line(
2118     line: Line,
2119     line_length: int,
2120     inner: bool = False,
2121     supports_trailing_commas: bool = False,
2122 ) -> Iterator[Line]:
2123     """Split a `line` into potentially many lines.
2124
2125     They should fit in the allotted `line_length` but might not be able to.
2126     `inner` signifies that there were a pair of brackets somewhere around the
2127     current `line`, possibly transitively. This means we can fallback to splitting
2128     by delimiters if the LHS/RHS don't yield any results.
2129
2130     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2131     """
2132     if line.is_comment:
2133         yield line
2134         return
2135
2136     line_str = str(line).strip("\n")
2137
2138     # we don't want to split special comments like type annotations
2139     # https://github.com/python/typing/issues/186
2140     has_special_comment = False
2141     for leaf in line.leaves:
2142         for comment in line.comments_after(leaf):
2143             if leaf.type == token.COMMA and is_special_comment(comment):
2144                 has_special_comment = True
2145
2146     if (
2147         not has_special_comment
2148         and not line.should_explode
2149         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2150     ):
2151         yield line
2152         return
2153
2154     split_funcs: List[SplitFunc]
2155     if line.is_def:
2156         split_funcs = [left_hand_split]
2157     else:
2158
2159         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2160             for omit in generate_trailers_to_omit(line, line_length):
2161                 lines = list(
2162                     right_hand_split(
2163                         line, line_length, supports_trailing_commas, omit=omit
2164                     )
2165                 )
2166                 if is_line_short_enough(lines[0], line_length=line_length):
2167                     yield from lines
2168                     return
2169
2170             # All splits failed, best effort split with no omits.
2171             # This mostly happens to multiline strings that are by definition
2172             # reported as not fitting a single line.
2173             yield from right_hand_split(line, supports_trailing_commas)
2174
2175         if line.inside_brackets:
2176             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2177         else:
2178             split_funcs = [rhs]
2179     for split_func in split_funcs:
2180         # We are accumulating lines in `result` because we might want to abort
2181         # mission and return the original line in the end, or attempt a different
2182         # split altogether.
2183         result: List[Line] = []
2184         try:
2185             for l in split_func(line, supports_trailing_commas):
2186                 if str(l).strip("\n") == line_str:
2187                     raise CannotSplit("Split function returned an unchanged result")
2188
2189                 result.extend(
2190                     split_line(
2191                         l,
2192                         line_length=line_length,
2193                         inner=True,
2194                         supports_trailing_commas=supports_trailing_commas,
2195                     )
2196                 )
2197         except CannotSplit:
2198             continue
2199
2200         else:
2201             yield from result
2202             break
2203
2204     else:
2205         yield line
2206
2207
2208 def left_hand_split(
2209     line: Line, supports_trailing_commas: bool = False
2210 ) -> Iterator[Line]:
2211     """Split line into many lines, starting with the first matching bracket pair.
2212
2213     Note: this usually looks weird, only use this for function definitions.
2214     Prefer RHS otherwise.  This is why this function is not symmetrical with
2215     :func:`right_hand_split` which also handles optional parentheses.
2216     """
2217     tail_leaves: List[Leaf] = []
2218     body_leaves: List[Leaf] = []
2219     head_leaves: List[Leaf] = []
2220     current_leaves = head_leaves
2221     matching_bracket = None
2222     for leaf in line.leaves:
2223         if (
2224             current_leaves is body_leaves
2225             and leaf.type in CLOSING_BRACKETS
2226             and leaf.opening_bracket is matching_bracket
2227         ):
2228             current_leaves = tail_leaves if body_leaves else head_leaves
2229         current_leaves.append(leaf)
2230         if current_leaves is head_leaves:
2231             if leaf.type in OPENING_BRACKETS:
2232                 matching_bracket = leaf
2233                 current_leaves = body_leaves
2234     if not matching_bracket:
2235         raise CannotSplit("No brackets found")
2236
2237     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2238     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2239     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2240     bracket_split_succeeded_or_raise(head, body, tail)
2241     for result in (head, body, tail):
2242         if result:
2243             yield result
2244
2245
2246 def right_hand_split(
2247     line: Line,
2248     line_length: int,
2249     supports_trailing_commas: bool = False,
2250     omit: Collection[LeafID] = (),
2251 ) -> Iterator[Line]:
2252     """Split line into many lines, starting with the last matching bracket pair.
2253
2254     If the split was by optional parentheses, attempt splitting without them, too.
2255     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2256     this split.
2257
2258     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2259     """
2260     tail_leaves: List[Leaf] = []
2261     body_leaves: List[Leaf] = []
2262     head_leaves: List[Leaf] = []
2263     current_leaves = tail_leaves
2264     opening_bracket = None
2265     closing_bracket = None
2266     for leaf in reversed(line.leaves):
2267         if current_leaves is body_leaves:
2268             if leaf is opening_bracket:
2269                 current_leaves = head_leaves if body_leaves else tail_leaves
2270         current_leaves.append(leaf)
2271         if current_leaves is tail_leaves:
2272             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2273                 opening_bracket = leaf.opening_bracket
2274                 closing_bracket = leaf
2275                 current_leaves = body_leaves
2276     if not (opening_bracket and closing_bracket and head_leaves):
2277         # If there is no opening or closing_bracket that means the split failed and
2278         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2279         # the matching `opening_bracket` wasn't available on `line` anymore.
2280         raise CannotSplit("No brackets found")
2281
2282     tail_leaves.reverse()
2283     body_leaves.reverse()
2284     head_leaves.reverse()
2285     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2286     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2287     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2288     bracket_split_succeeded_or_raise(head, body, tail)
2289     if (
2290         # the body shouldn't be exploded
2291         not body.should_explode
2292         # the opening bracket is an optional paren
2293         and opening_bracket.type == token.LPAR
2294         and not opening_bracket.value
2295         # the closing bracket is an optional paren
2296         and closing_bracket.type == token.RPAR
2297         and not closing_bracket.value
2298         # it's not an import (optional parens are the only thing we can split on
2299         # in this case; attempting a split without them is a waste of time)
2300         and not line.is_import
2301         # there are no standalone comments in the body
2302         and not body.contains_standalone_comments(0)
2303         # and we can actually remove the parens
2304         and can_omit_invisible_parens(body, line_length)
2305     ):
2306         omit = {id(closing_bracket), *omit}
2307         try:
2308             yield from right_hand_split(
2309                 line,
2310                 line_length,
2311                 supports_trailing_commas=supports_trailing_commas,
2312                 omit=omit,
2313             )
2314             return
2315
2316         except CannotSplit:
2317             if not (
2318                 can_be_split(body)
2319                 or is_line_short_enough(body, line_length=line_length)
2320             ):
2321                 raise CannotSplit(
2322                     "Splitting failed, body is still too long and can't be split."
2323                 )
2324
2325             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2326                 raise CannotSplit(
2327                     "The current optional pair of parentheses is bound to fail to "
2328                     "satisfy the splitting algorithm because the head or the tail "
2329                     "contains multiline strings which by definition never fit one "
2330                     "line."
2331                 )
2332
2333     ensure_visible(opening_bracket)
2334     ensure_visible(closing_bracket)
2335     for result in (head, body, tail):
2336         if result:
2337             yield result
2338
2339
2340 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2341     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2342
2343     Do nothing otherwise.
2344
2345     A left- or right-hand split is based on a pair of brackets. Content before
2346     (and including) the opening bracket is left on one line, content inside the
2347     brackets is put on a separate line, and finally content starting with and
2348     following the closing bracket is put on a separate line.
2349
2350     Those are called `head`, `body`, and `tail`, respectively. If the split
2351     produced the same line (all content in `head`) or ended up with an empty `body`
2352     and the `tail` is just the closing bracket, then it's considered failed.
2353     """
2354     tail_len = len(str(tail).strip())
2355     if not body:
2356         if tail_len == 0:
2357             raise CannotSplit("Splitting brackets produced the same line")
2358
2359         elif tail_len < 3:
2360             raise CannotSplit(
2361                 f"Splitting brackets on an empty body to save "
2362                 f"{tail_len} characters is not worth it"
2363             )
2364
2365
2366 def bracket_split_build_line(
2367     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2368 ) -> Line:
2369     """Return a new line with given `leaves` and respective comments from `original`.
2370
2371     If `is_body` is True, the result line is one-indented inside brackets and as such
2372     has its first leaf's prefix normalized and a trailing comma added when expected.
2373     """
2374     result = Line(depth=original.depth)
2375     if is_body:
2376         result.inside_brackets = True
2377         result.depth += 1
2378         if leaves:
2379             # Since body is a new indent level, remove spurious leading whitespace.
2380             normalize_prefix(leaves[0], inside_brackets=True)
2381             # Ensure a trailing comma when expected.
2382             if original.is_import:
2383                 if leaves[-1].type != token.COMMA:
2384                     leaves.append(Leaf(token.COMMA, ","))
2385     # Populate the line
2386     for leaf in leaves:
2387         result.append(leaf, preformatted=True)
2388         for comment_after in original.comments_after(leaf):
2389             result.append(comment_after, preformatted=True)
2390     if is_body:
2391         result.should_explode = should_explode(result, opening_bracket)
2392     return result
2393
2394
2395 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2396     """Normalize prefix of the first leaf in every line returned by `split_func`.
2397
2398     This is a decorator over relevant split functions.
2399     """
2400
2401     @wraps(split_func)
2402     def split_wrapper(
2403         line: Line, supports_trailing_commas: bool = False
2404     ) -> Iterator[Line]:
2405         for l in split_func(line, supports_trailing_commas):
2406             normalize_prefix(l.leaves[0], inside_brackets=True)
2407             yield l
2408
2409     return split_wrapper
2410
2411
2412 @dont_increase_indentation
2413 def delimiter_split(
2414     line: Line, supports_trailing_commas: bool = False
2415 ) -> Iterator[Line]:
2416     """Split according to delimiters of the highest priority.
2417
2418     If `supports_trailing_commas` is True, the split will add trailing commas
2419     also in function signatures that contain `*` and `**`.
2420     """
2421     try:
2422         last_leaf = line.leaves[-1]
2423     except IndexError:
2424         raise CannotSplit("Line empty")
2425
2426     bt = line.bracket_tracker
2427     try:
2428         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2429     except ValueError:
2430         raise CannotSplit("No delimiters found")
2431
2432     if delimiter_priority == DOT_PRIORITY:
2433         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2434             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2435
2436     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2437     lowest_depth = sys.maxsize
2438     trailing_comma_safe = True
2439
2440     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2441         """Append `leaf` to current line or to new line if appending impossible."""
2442         nonlocal current_line
2443         try:
2444             current_line.append_safe(leaf, preformatted=True)
2445         except ValueError:
2446             yield current_line
2447
2448             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2449             current_line.append(leaf)
2450
2451     for leaf in line.leaves:
2452         yield from append_to_line(leaf)
2453
2454         for comment_after in line.comments_after(leaf):
2455             yield from append_to_line(comment_after)
2456
2457         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2458         if leaf.bracket_depth == lowest_depth and is_vararg(
2459             leaf, within=VARARGS_PARENTS
2460         ):
2461             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2462         leaf_priority = bt.delimiters.get(id(leaf))
2463         if leaf_priority == delimiter_priority:
2464             yield current_line
2465
2466             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2467     if current_line:
2468         if (
2469             trailing_comma_safe
2470             and delimiter_priority == COMMA_PRIORITY
2471             and current_line.leaves[-1].type != token.COMMA
2472             and current_line.leaves[-1].type != STANDALONE_COMMENT
2473         ):
2474             current_line.append(Leaf(token.COMMA, ","))
2475         yield current_line
2476
2477
2478 @dont_increase_indentation
2479 def standalone_comment_split(
2480     line: Line, supports_trailing_commas: bool = False
2481 ) -> Iterator[Line]:
2482     """Split standalone comments from the rest of the line."""
2483     if not line.contains_standalone_comments(0):
2484         raise CannotSplit("Line does not have any standalone comments")
2485
2486     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2487
2488     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2489         """Append `leaf` to current line or to new line if appending impossible."""
2490         nonlocal current_line
2491         try:
2492             current_line.append_safe(leaf, preformatted=True)
2493         except ValueError:
2494             yield current_line
2495
2496             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2497             current_line.append(leaf)
2498
2499     for leaf in line.leaves:
2500         yield from append_to_line(leaf)
2501
2502         for comment_after in line.comments_after(leaf):
2503             yield from append_to_line(comment_after)
2504
2505     if current_line:
2506         yield current_line
2507
2508
2509 def is_import(leaf: Leaf) -> bool:
2510     """Return True if the given leaf starts an import statement."""
2511     p = leaf.parent
2512     t = leaf.type
2513     v = leaf.value
2514     return bool(
2515         t == token.NAME
2516         and (
2517             (v == "import" and p and p.type == syms.import_name)
2518             or (v == "from" and p and p.type == syms.import_from)
2519         )
2520     )
2521
2522
2523 def is_special_comment(leaf: Leaf) -> bool:
2524     """Return True if the given leaf is a special comment.
2525     Only returns true for type comments for now."""
2526     t = leaf.type
2527     v = leaf.value
2528     return bool(
2529         (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
2530     )
2531
2532
2533 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2534     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2535     else.
2536
2537     Note: don't use backslashes for formatting or you'll lose your voting rights.
2538     """
2539     if not inside_brackets:
2540         spl = leaf.prefix.split("#")
2541         if "\\" not in spl[0]:
2542             nl_count = spl[-1].count("\n")
2543             if len(spl) > 1:
2544                 nl_count -= 1
2545             leaf.prefix = "\n" * nl_count
2546             return
2547
2548     leaf.prefix = ""
2549
2550
2551 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2552     """Make all string prefixes lowercase.
2553
2554     If remove_u_prefix is given, also removes any u prefix from the string.
2555
2556     Note: Mutates its argument.
2557     """
2558     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2559     assert match is not None, f"failed to match string {leaf.value!r}"
2560     orig_prefix = match.group(1)
2561     new_prefix = orig_prefix.lower()
2562     if remove_u_prefix:
2563         new_prefix = new_prefix.replace("u", "")
2564     leaf.value = f"{new_prefix}{match.group(2)}"
2565
2566
2567 def normalize_string_quotes(leaf: Leaf) -> None:
2568     """Prefer double quotes but only if it doesn't cause more escaping.
2569
2570     Adds or removes backslashes as appropriate. Doesn't parse and fix
2571     strings nested in f-strings (yet).
2572
2573     Note: Mutates its argument.
2574     """
2575     value = leaf.value.lstrip("furbFURB")
2576     if value[:3] == '"""':
2577         return
2578
2579     elif value[:3] == "'''":
2580         orig_quote = "'''"
2581         new_quote = '"""'
2582     elif value[0] == '"':
2583         orig_quote = '"'
2584         new_quote = "'"
2585     else:
2586         orig_quote = "'"
2587         new_quote = '"'
2588     first_quote_pos = leaf.value.find(orig_quote)
2589     if first_quote_pos == -1:
2590         return  # There's an internal error
2591
2592     prefix = leaf.value[:first_quote_pos]
2593     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2594     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2595     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2596     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2597     if "r" in prefix.casefold():
2598         if unescaped_new_quote.search(body):
2599             # There's at least one unescaped new_quote in this raw string
2600             # so converting is impossible
2601             return
2602
2603         # Do not introduce or remove backslashes in raw strings
2604         new_body = body
2605     else:
2606         # remove unnecessary escapes
2607         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2608         if body != new_body:
2609             # Consider the string without unnecessary escapes as the original
2610             body = new_body
2611             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2612         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2613         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2614     if "f" in prefix.casefold():
2615         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2616         for m in matches:
2617             if "\\" in str(m):
2618                 # Do not introduce backslashes in interpolated expressions
2619                 return
2620     if new_quote == '"""' and new_body[-1:] == '"':
2621         # edge case:
2622         new_body = new_body[:-1] + '\\"'
2623     orig_escape_count = body.count("\\")
2624     new_escape_count = new_body.count("\\")
2625     if new_escape_count > orig_escape_count:
2626         return  # Do not introduce more escaping
2627
2628     if new_escape_count == orig_escape_count and orig_quote == '"':
2629         return  # Prefer double quotes
2630
2631     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2632
2633
2634 def normalize_numeric_literal(leaf: Leaf) -> None:
2635     """Normalizes numeric (float, int, and complex) literals.
2636
2637     All letters used in the representation are normalized to lowercase (except
2638     in Python 2 long literals).
2639     """
2640     text = leaf.value.lower()
2641     if text.startswith(("0o", "0b")):
2642         # Leave octal and binary literals alone.
2643         pass
2644     elif text.startswith("0x"):
2645         # Change hex literals to upper case.
2646         before, after = text[:2], text[2:]
2647         text = f"{before}{after.upper()}"
2648     elif "e" in text:
2649         before, after = text.split("e")
2650         sign = ""
2651         if after.startswith("-"):
2652             after = after[1:]
2653             sign = "-"
2654         elif after.startswith("+"):
2655             after = after[1:]
2656         before = format_float_or_int_string(before)
2657         text = f"{before}e{sign}{after}"
2658     elif text.endswith(("j", "l")):
2659         number = text[:-1]
2660         suffix = text[-1]
2661         # Capitalize in "2L" because "l" looks too similar to "1".
2662         if suffix == "l":
2663             suffix = "L"
2664         text = f"{format_float_or_int_string(number)}{suffix}"
2665     else:
2666         text = format_float_or_int_string(text)
2667     leaf.value = text
2668
2669
2670 def format_float_or_int_string(text: str) -> str:
2671     """Formats a float string like "1.0"."""
2672     if "." not in text:
2673         return text
2674
2675     before, after = text.split(".")
2676     return f"{before or 0}.{after or 0}"
2677
2678
2679 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2680     """Make existing optional parentheses invisible or create new ones.
2681
2682     `parens_after` is a set of string leaf values immeditely after which parens
2683     should be put.
2684
2685     Standardizes on visible parentheses for single-element tuples, and keeps
2686     existing visible parentheses for other tuples and generator expressions.
2687     """
2688     for pc in list_comments(node.prefix, is_endmarker=False):
2689         if pc.value in FMT_OFF:
2690             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2691             return
2692
2693     check_lpar = False
2694     for index, child in enumerate(list(node.children)):
2695         if check_lpar:
2696             if child.type == syms.atom:
2697                 if maybe_make_parens_invisible_in_atom(child):
2698                     lpar = Leaf(token.LPAR, "")
2699                     rpar = Leaf(token.RPAR, "")
2700                     index = child.remove() or 0
2701                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2702             elif is_one_tuple(child):
2703                 # wrap child in visible parentheses
2704                 lpar = Leaf(token.LPAR, "(")
2705                 rpar = Leaf(token.RPAR, ")")
2706                 child.remove()
2707                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2708             elif node.type == syms.import_from:
2709                 # "import from" nodes store parentheses directly as part of
2710                 # the statement
2711                 if child.type == token.LPAR:
2712                     # make parentheses invisible
2713                     child.value = ""  # type: ignore
2714                     node.children[-1].value = ""  # type: ignore
2715                 elif child.type != token.STAR:
2716                     # insert invisible parentheses
2717                     node.insert_child(index, Leaf(token.LPAR, ""))
2718                     node.append_child(Leaf(token.RPAR, ""))
2719                 break
2720
2721             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2722                 # wrap child in invisible parentheses
2723                 lpar = Leaf(token.LPAR, "")
2724                 rpar = Leaf(token.RPAR, "")
2725                 index = child.remove() or 0
2726                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2727
2728         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2729
2730
2731 def normalize_fmt_off(node: Node) -> None:
2732     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2733     try_again = True
2734     while try_again:
2735         try_again = convert_one_fmt_off_pair(node)
2736
2737
2738 def convert_one_fmt_off_pair(node: Node) -> bool:
2739     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2740
2741     Returns True if a pair was converted.
2742     """
2743     for leaf in node.leaves():
2744         previous_consumed = 0
2745         for comment in list_comments(leaf.prefix, is_endmarker=False):
2746             if comment.value in FMT_OFF:
2747                 # We only want standalone comments. If there's no previous leaf or
2748                 # the previous leaf is indentation, it's a standalone comment in
2749                 # disguise.
2750                 if comment.type != STANDALONE_COMMENT:
2751                     prev = preceding_leaf(leaf)
2752                     if prev and prev.type not in WHITESPACE:
2753                         continue
2754
2755                 ignored_nodes = list(generate_ignored_nodes(leaf))
2756                 if not ignored_nodes:
2757                     continue
2758
2759                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2760                 parent = first.parent
2761                 prefix = first.prefix
2762                 first.prefix = prefix[comment.consumed :]
2763                 hidden_value = (
2764                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2765                 )
2766                 if hidden_value.endswith("\n"):
2767                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2768                     # leaf (possibly followed by a DEDENT).
2769                     hidden_value = hidden_value[:-1]
2770                 first_idx = None
2771                 for ignored in ignored_nodes:
2772                     index = ignored.remove()
2773                     if first_idx is None:
2774                         first_idx = index
2775                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2776                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2777                 parent.insert_child(
2778                     first_idx,
2779                     Leaf(
2780                         STANDALONE_COMMENT,
2781                         hidden_value,
2782                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2783                     ),
2784                 )
2785                 return True
2786
2787             previous_consumed = comment.consumed
2788
2789     return False
2790
2791
2792 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2793     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2794
2795     Stops at the end of the block.
2796     """
2797     container: Optional[LN] = container_of(leaf)
2798     while container is not None and container.type != token.ENDMARKER:
2799         for comment in list_comments(container.prefix, is_endmarker=False):
2800             if comment.value in FMT_ON:
2801                 return
2802
2803         yield container
2804
2805         container = container.next_sibling
2806
2807
2808 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2809     """If it's safe, make the parens in the atom `node` invisible, recursively.
2810
2811     Returns whether the node should itself be wrapped in invisible parentheses.
2812
2813     """
2814     if (
2815         node.type != syms.atom
2816         or is_empty_tuple(node)
2817         or is_one_tuple(node)
2818         or is_yield(node)
2819         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2820     ):
2821         return False
2822
2823     first = node.children[0]
2824     last = node.children[-1]
2825     if first.type == token.LPAR and last.type == token.RPAR:
2826         # make parentheses invisible
2827         first.value = ""  # type: ignore
2828         last.value = ""  # type: ignore
2829         if len(node.children) > 1:
2830             maybe_make_parens_invisible_in_atom(node.children[1])
2831         return False
2832
2833     return True
2834
2835
2836 def is_empty_tuple(node: LN) -> bool:
2837     """Return True if `node` holds an empty tuple."""
2838     return (
2839         node.type == syms.atom
2840         and len(node.children) == 2
2841         and node.children[0].type == token.LPAR
2842         and node.children[1].type == token.RPAR
2843     )
2844
2845
2846 def is_one_tuple(node: LN) -> bool:
2847     """Return True if `node` holds a tuple with one element, with or without parens."""
2848     if node.type == syms.atom:
2849         if len(node.children) != 3:
2850             return False
2851
2852         lpar, gexp, rpar = node.children
2853         if not (
2854             lpar.type == token.LPAR
2855             and gexp.type == syms.testlist_gexp
2856             and rpar.type == token.RPAR
2857         ):
2858             return False
2859
2860         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2861
2862     return (
2863         node.type in IMPLICIT_TUPLE
2864         and len(node.children) == 2
2865         and node.children[1].type == token.COMMA
2866     )
2867
2868
2869 def is_yield(node: LN) -> bool:
2870     """Return True if `node` holds a `yield` or `yield from` expression."""
2871     if node.type == syms.yield_expr:
2872         return True
2873
2874     if node.type == token.NAME and node.value == "yield":  # type: ignore
2875         return True
2876
2877     if node.type != syms.atom:
2878         return False
2879
2880     if len(node.children) != 3:
2881         return False
2882
2883     lpar, expr, rpar = node.children
2884     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2885         return is_yield(expr)
2886
2887     return False
2888
2889
2890 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2891     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2892
2893     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2894     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2895     extended iterable unpacking (PEP 3132) and additional unpacking
2896     generalizations (PEP 448).
2897     """
2898     if leaf.type not in STARS or not leaf.parent:
2899         return False
2900
2901     p = leaf.parent
2902     if p.type == syms.star_expr:
2903         # Star expressions are also used as assignment targets in extended
2904         # iterable unpacking (PEP 3132).  See what its parent is instead.
2905         if not p.parent:
2906             return False
2907
2908         p = p.parent
2909
2910     return p.type in within
2911
2912
2913 def is_multiline_string(leaf: Leaf) -> bool:
2914     """Return True if `leaf` is a multiline string that actually spans many lines."""
2915     value = leaf.value.lstrip("furbFURB")
2916     return value[:3] in {'"""', "'''"} and "\n" in value
2917
2918
2919 def is_stub_suite(node: Node) -> bool:
2920     """Return True if `node` is a suite with a stub body."""
2921     if (
2922         len(node.children) != 4
2923         or node.children[0].type != token.NEWLINE
2924         or node.children[1].type != token.INDENT
2925         or node.children[3].type != token.DEDENT
2926     ):
2927         return False
2928
2929     return is_stub_body(node.children[2])
2930
2931
2932 def is_stub_body(node: LN) -> bool:
2933     """Return True if `node` is a simple statement containing an ellipsis."""
2934     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2935         return False
2936
2937     if len(node.children) != 2:
2938         return False
2939
2940     child = node.children[0]
2941     return (
2942         child.type == syms.atom
2943         and len(child.children) == 3
2944         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2945     )
2946
2947
2948 def max_delimiter_priority_in_atom(node: LN) -> int:
2949     """Return maximum delimiter priority inside `node`.
2950
2951     This is specific to atoms with contents contained in a pair of parentheses.
2952     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2953     """
2954     if node.type != syms.atom:
2955         return 0
2956
2957     first = node.children[0]
2958     last = node.children[-1]
2959     if not (first.type == token.LPAR and last.type == token.RPAR):
2960         return 0
2961
2962     bt = BracketTracker()
2963     for c in node.children[1:-1]:
2964         if isinstance(c, Leaf):
2965             bt.mark(c)
2966         else:
2967             for leaf in c.leaves():
2968                 bt.mark(leaf)
2969     try:
2970         return bt.max_delimiter_priority()
2971
2972     except ValueError:
2973         return 0
2974
2975
2976 def ensure_visible(leaf: Leaf) -> None:
2977     """Make sure parentheses are visible.
2978
2979     They could be invisible as part of some statements (see
2980     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2981     """
2982     if leaf.type == token.LPAR:
2983         leaf.value = "("
2984     elif leaf.type == token.RPAR:
2985         leaf.value = ")"
2986
2987
2988 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2989     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2990
2991     if not (
2992         opening_bracket.parent
2993         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2994         and opening_bracket.value in "[{("
2995     ):
2996         return False
2997
2998     try:
2999         last_leaf = line.leaves[-1]
3000         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3001         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3002     except (IndexError, ValueError):
3003         return False
3004
3005     return max_priority == COMMA_PRIORITY
3006
3007
3008 def get_features_used(node: Node) -> Set[Feature]:
3009     """Return a set of (relatively) new Python features used in this file.
3010
3011     Currently looking for:
3012     - f-strings;
3013     - underscores in numeric literals; and
3014     - trailing commas after * or ** in function signatures and calls.
3015     """
3016     features: Set[Feature] = set()
3017     for n in node.pre_order():
3018         if n.type == token.STRING:
3019             value_head = n.value[:2]  # type: ignore
3020             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3021                 features.add(Feature.F_STRINGS)
3022
3023         elif n.type == token.NUMBER:
3024             if "_" in n.value:  # type: ignore
3025                 features.add(Feature.NUMERIC_UNDERSCORES)
3026
3027         elif (
3028             n.type in {syms.typedargslist, syms.arglist}
3029             and n.children
3030             and n.children[-1].type == token.COMMA
3031         ):
3032             for ch in n.children:
3033                 if ch.type in STARS:
3034                     features.add(Feature.TRAILING_COMMA)
3035
3036                 if ch.type == syms.argument:
3037                     for argch in ch.children:
3038                         if argch.type in STARS:
3039                             features.add(Feature.TRAILING_COMMA)
3040
3041     return features
3042
3043
3044 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3045     """Detect the version to target based on the nodes used."""
3046     features = get_features_used(node)
3047     return {
3048         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3049     }
3050
3051
3052 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3053     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3054
3055     Brackets can be omitted if the entire trailer up to and including
3056     a preceding closing bracket fits in one line.
3057
3058     Yielded sets are cumulative (contain results of previous yields, too).  First
3059     set is empty.
3060     """
3061
3062     omit: Set[LeafID] = set()
3063     yield omit
3064
3065     length = 4 * line.depth
3066     opening_bracket = None
3067     closing_bracket = None
3068     inner_brackets: Set[LeafID] = set()
3069     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3070         length += leaf_length
3071         if length > line_length:
3072             break
3073
3074         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3075         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3076             break
3077
3078         if opening_bracket:
3079             if leaf is opening_bracket:
3080                 opening_bracket = None
3081             elif leaf.type in CLOSING_BRACKETS:
3082                 inner_brackets.add(id(leaf))
3083         elif leaf.type in CLOSING_BRACKETS:
3084             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3085                 # Empty brackets would fail a split so treat them as "inner"
3086                 # brackets (e.g. only add them to the `omit` set if another
3087                 # pair of brackets was good enough.
3088                 inner_brackets.add(id(leaf))
3089                 continue
3090
3091             if closing_bracket:
3092                 omit.add(id(closing_bracket))
3093                 omit.update(inner_brackets)
3094                 inner_brackets.clear()
3095                 yield omit
3096
3097             if leaf.value:
3098                 opening_bracket = leaf.opening_bracket
3099                 closing_bracket = leaf
3100
3101
3102 def get_future_imports(node: Node) -> Set[str]:
3103     """Return a set of __future__ imports in the file."""
3104     imports: Set[str] = set()
3105
3106     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3107         for child in children:
3108             if isinstance(child, Leaf):
3109                 if child.type == token.NAME:
3110                     yield child.value
3111             elif child.type == syms.import_as_name:
3112                 orig_name = child.children[0]
3113                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3114                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3115                 yield orig_name.value
3116             elif child.type == syms.import_as_names:
3117                 yield from get_imports_from_children(child.children)
3118             else:
3119                 assert False, "Invalid syntax parsing imports"
3120
3121     for child in node.children:
3122         if child.type != syms.simple_stmt:
3123             break
3124         first_child = child.children[0]
3125         if isinstance(first_child, Leaf):
3126             # Continue looking if we see a docstring; otherwise stop.
3127             if (
3128                 len(child.children) == 2
3129                 and first_child.type == token.STRING
3130                 and child.children[1].type == token.NEWLINE
3131             ):
3132                 continue
3133             else:
3134                 break
3135         elif first_child.type == syms.import_from:
3136             module_name = first_child.children[1]
3137             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3138                 break
3139             imports |= set(get_imports_from_children(first_child.children[3:]))
3140         else:
3141             break
3142     return imports
3143
3144
3145 def gen_python_files_in_dir(
3146     path: Path,
3147     root: Path,
3148     include: Pattern[str],
3149     exclude: Pattern[str],
3150     report: "Report",
3151 ) -> Iterator[Path]:
3152     """Generate all files under `path` whose paths are not excluded by the
3153     `exclude` regex, but are included by the `include` regex.
3154
3155     Symbolic links pointing outside of the `root` directory are ignored.
3156
3157     `report` is where output about exclusions goes.
3158     """
3159     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3160     for child in path.iterdir():
3161         try:
3162             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3163         except ValueError:
3164             if child.is_symlink():
3165                 report.path_ignored(
3166                     child, f"is a symbolic link that points outside {root}"
3167                 )
3168                 continue
3169
3170             raise
3171
3172         if child.is_dir():
3173             normalized_path += "/"
3174         exclude_match = exclude.search(normalized_path)
3175         if exclude_match and exclude_match.group(0):
3176             report.path_ignored(child, f"matches the --exclude regular expression")
3177             continue
3178
3179         if child.is_dir():
3180             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3181
3182         elif child.is_file():
3183             include_match = include.search(normalized_path)
3184             if include_match:
3185                 yield child
3186
3187
3188 @lru_cache()
3189 def find_project_root(srcs: Iterable[str]) -> Path:
3190     """Return a directory containing .git, .hg, or pyproject.toml.
3191
3192     That directory can be one of the directories passed in `srcs` or their
3193     common parent.
3194
3195     If no directory in the tree contains a marker that would specify it's the
3196     project root, the root of the file system is returned.
3197     """
3198     if not srcs:
3199         return Path("/").resolve()
3200
3201     common_base = min(Path(src).resolve() for src in srcs)
3202     if common_base.is_dir():
3203         # Append a fake file so `parents` below returns `common_base_dir`, too.
3204         common_base /= "fake-file"
3205     for directory in common_base.parents:
3206         if (directory / ".git").is_dir():
3207             return directory
3208
3209         if (directory / ".hg").is_dir():
3210             return directory
3211
3212         if (directory / "pyproject.toml").is_file():
3213             return directory
3214
3215     return directory
3216
3217
3218 @dataclass
3219 class Report:
3220     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3221
3222     check: bool = False
3223     quiet: bool = False
3224     verbose: bool = False
3225     change_count: int = 0
3226     same_count: int = 0
3227     failure_count: int = 0
3228
3229     def done(self, src: Path, changed: Changed) -> None:
3230         """Increment the counter for successful reformatting. Write out a message."""
3231         if changed is Changed.YES:
3232             reformatted = "would reformat" if self.check else "reformatted"
3233             if self.verbose or not self.quiet:
3234                 out(f"{reformatted} {src}")
3235             self.change_count += 1
3236         else:
3237             if self.verbose:
3238                 if changed is Changed.NO:
3239                     msg = f"{src} already well formatted, good job."
3240                 else:
3241                     msg = f"{src} wasn't modified on disk since last run."
3242                 out(msg, bold=False)
3243             self.same_count += 1
3244
3245     def failed(self, src: Path, message: str) -> None:
3246         """Increment the counter for failed reformatting. Write out a message."""
3247         err(f"error: cannot format {src}: {message}")
3248         self.failure_count += 1
3249
3250     def path_ignored(self, path: Path, message: str) -> None:
3251         if self.verbose:
3252             out(f"{path} ignored: {message}", bold=False)
3253
3254     @property
3255     def return_code(self) -> int:
3256         """Return the exit code that the app should use.
3257
3258         This considers the current state of changed files and failures:
3259         - if there were any failures, return 123;
3260         - if any files were changed and --check is being used, return 1;
3261         - otherwise return 0.
3262         """
3263         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3264         # 126 we have special return codes reserved by the shell.
3265         if self.failure_count:
3266             return 123
3267
3268         elif self.change_count and self.check:
3269             return 1
3270
3271         return 0
3272
3273     def __str__(self) -> str:
3274         """Render a color report of the current state.
3275
3276         Use `click.unstyle` to remove colors.
3277         """
3278         if self.check:
3279             reformatted = "would be reformatted"
3280             unchanged = "would be left unchanged"
3281             failed = "would fail to reformat"
3282         else:
3283             reformatted = "reformatted"
3284             unchanged = "left unchanged"
3285             failed = "failed to reformat"
3286         report = []
3287         if self.change_count:
3288             s = "s" if self.change_count > 1 else ""
3289             report.append(
3290                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3291             )
3292         if self.same_count:
3293             s = "s" if self.same_count > 1 else ""
3294             report.append(f"{self.same_count} file{s} {unchanged}")
3295         if self.failure_count:
3296             s = "s" if self.failure_count > 1 else ""
3297             report.append(
3298                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3299             )
3300         return ", ".join(report) + "."
3301
3302
3303 def assert_equivalent(src: str, dst: str) -> None:
3304     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3305
3306     import ast
3307     import traceback
3308
3309     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3310         """Simple visitor generating strings to compare ASTs by content."""
3311         yield f"{'  ' * depth}{node.__class__.__name__}("
3312
3313         for field in sorted(node._fields):
3314             try:
3315                 value = getattr(node, field)
3316             except AttributeError:
3317                 continue
3318
3319             yield f"{'  ' * (depth+1)}{field}="
3320
3321             if isinstance(value, list):
3322                 for item in value:
3323                     # Ignore nested tuples within del statements, because we may insert
3324                     # parentheses and they change the AST.
3325                     if (
3326                         field == "targets"
3327                         and isinstance(node, ast.Delete)
3328                         and isinstance(item, ast.Tuple)
3329                     ):
3330                         for item in item.elts:
3331                             yield from _v(item, depth + 2)
3332                     elif isinstance(item, ast.AST):
3333                         yield from _v(item, depth + 2)
3334
3335             elif isinstance(value, ast.AST):
3336                 yield from _v(value, depth + 2)
3337
3338             else:
3339                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3340
3341         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3342
3343     try:
3344         src_ast = ast.parse(src)
3345     except Exception as exc:
3346         major, minor = sys.version_info[:2]
3347         raise AssertionError(
3348             f"cannot use --safe with this file; failed to parse source file "
3349             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3350             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3351         )
3352
3353     try:
3354         dst_ast = ast.parse(dst)
3355     except Exception as exc:
3356         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3357         raise AssertionError(
3358             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3359             f"Please report a bug on https://github.com/ambv/black/issues.  "
3360             f"This invalid output might be helpful: {log}"
3361         ) from None
3362
3363     src_ast_str = "\n".join(_v(src_ast))
3364     dst_ast_str = "\n".join(_v(dst_ast))
3365     if src_ast_str != dst_ast_str:
3366         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3367         raise AssertionError(
3368             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3369             f"the source.  "
3370             f"Please report a bug on https://github.com/ambv/black/issues.  "
3371             f"This diff might be helpful: {log}"
3372         ) from None
3373
3374
3375 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3376     """Raise AssertionError if `dst` reformats differently the second time."""
3377     newdst = format_str(dst, mode=mode)
3378     if dst != newdst:
3379         log = dump_to_file(
3380             diff(src, dst, "source", "first pass"),
3381             diff(dst, newdst, "first pass", "second pass"),
3382         )
3383         raise AssertionError(
3384             f"INTERNAL ERROR: Black produced different code on the second pass "
3385             f"of the formatter.  "
3386             f"Please report a bug on https://github.com/ambv/black/issues.  "
3387             f"This diff might be helpful: {log}"
3388         ) from None
3389
3390
3391 def dump_to_file(*output: str) -> str:
3392     """Dump `output` to a temporary file. Return path to the file."""
3393     import tempfile
3394
3395     with tempfile.NamedTemporaryFile(
3396         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3397     ) as f:
3398         for lines in output:
3399             f.write(lines)
3400             if lines and lines[-1] != "\n":
3401                 f.write("\n")
3402     return f.name
3403
3404
3405 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3406     """Return a unified diff string between strings `a` and `b`."""
3407     import difflib
3408
3409     a_lines = [line + "\n" for line in a.split("\n")]
3410     b_lines = [line + "\n" for line in b.split("\n")]
3411     return "".join(
3412         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3413     )
3414
3415
3416 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3417     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3418     err("Aborted!")
3419     for task in tasks:
3420         task.cancel()
3421
3422
3423 def shutdown(loop: BaseEventLoop) -> None:
3424     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3425     try:
3426         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3427         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3428         if not to_cancel:
3429             return
3430
3431         for task in to_cancel:
3432             task.cancel()
3433         loop.run_until_complete(
3434             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3435         )
3436     finally:
3437         # `concurrent.futures.Future` objects cannot be cancelled once they
3438         # are already running. There might be some when the `shutdown()` happened.
3439         # Silence their logger's spew about the event loop being closed.
3440         cf_logger = logging.getLogger("concurrent.futures")
3441         cf_logger.setLevel(logging.CRITICAL)
3442         loop.close()
3443
3444
3445 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3446     """Replace `regex` with `replacement` twice on `original`.
3447
3448     This is used by string normalization to perform replaces on
3449     overlapping matches.
3450     """
3451     return regex.sub(replacement, regex.sub(replacement, original))
3452
3453
3454 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3455     """Compile a regular expression string in `regex`.
3456
3457     If it contains newlines, use verbose mode.
3458     """
3459     if "\n" in regex:
3460         regex = "(?x)" + regex
3461     return re.compile(regex)
3462
3463
3464 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3465     """Like `reversed(enumerate(sequence))` if that were possible."""
3466     index = len(sequence) - 1
3467     for element in reversed(sequence):
3468         yield (index, element)
3469         index -= 1
3470
3471
3472 def enumerate_with_length(
3473     line: Line, reversed: bool = False
3474 ) -> Iterator[Tuple[Index, Leaf, int]]:
3475     """Return an enumeration of leaves with their length.
3476
3477     Stops prematurely on multiline strings and standalone comments.
3478     """
3479     op = cast(
3480         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3481         enumerate_reversed if reversed else enumerate,
3482     )
3483     for index, leaf in op(line.leaves):
3484         length = len(leaf.prefix) + len(leaf.value)
3485         if "\n" in leaf.value:
3486             return  # Multiline strings, we can't continue.
3487
3488         comment: Optional[Leaf]
3489         for comment in line.comments_after(leaf):
3490             length += len(comment.value)
3491
3492         yield index, leaf, length
3493
3494
3495 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3496     """Return True if `line` is no longer than `line_length`.
3497
3498     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3499     """
3500     if not line_str:
3501         line_str = str(line).strip("\n")
3502     return (
3503         len(line_str) <= line_length
3504         and "\n" not in line_str  # multiline strings
3505         and not line.contains_standalone_comments()
3506     )
3507
3508
3509 def can_be_split(line: Line) -> bool:
3510     """Return False if the line cannot be split *for sure*.
3511
3512     This is not an exhaustive search but a cheap heuristic that we can use to
3513     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3514     in unnecessary parentheses).
3515     """
3516     leaves = line.leaves
3517     if len(leaves) < 2:
3518         return False
3519
3520     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3521         call_count = 0
3522         dot_count = 0
3523         next = leaves[-1]
3524         for leaf in leaves[-2::-1]:
3525             if leaf.type in OPENING_BRACKETS:
3526                 if next.type not in CLOSING_BRACKETS:
3527                     return False
3528
3529                 call_count += 1
3530             elif leaf.type == token.DOT:
3531                 dot_count += 1
3532             elif leaf.type == token.NAME:
3533                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3534                     return False
3535
3536             elif leaf.type not in CLOSING_BRACKETS:
3537                 return False
3538
3539             if dot_count > 1 and call_count > 1:
3540                 return False
3541
3542     return True
3543
3544
3545 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3546     """Does `line` have a shape safe to reformat without optional parens around it?
3547
3548     Returns True for only a subset of potentially nice looking formattings but
3549     the point is to not return false positives that end up producing lines that
3550     are too long.
3551     """
3552     bt = line.bracket_tracker
3553     if not bt.delimiters:
3554         # Without delimiters the optional parentheses are useless.
3555         return True
3556
3557     max_priority = bt.max_delimiter_priority()
3558     if bt.delimiter_count_with_priority(max_priority) > 1:
3559         # With more than one delimiter of a kind the optional parentheses read better.
3560         return False
3561
3562     if max_priority == DOT_PRIORITY:
3563         # A single stranded method call doesn't require optional parentheses.
3564         return True
3565
3566     assert len(line.leaves) >= 2, "Stranded delimiter"
3567
3568     first = line.leaves[0]
3569     second = line.leaves[1]
3570     penultimate = line.leaves[-2]
3571     last = line.leaves[-1]
3572
3573     # With a single delimiter, omit if the expression starts or ends with
3574     # a bracket.
3575     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3576         remainder = False
3577         length = 4 * line.depth
3578         for _index, leaf, leaf_length in enumerate_with_length(line):
3579             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3580                 remainder = True
3581             if remainder:
3582                 length += leaf_length
3583                 if length > line_length:
3584                     break
3585
3586                 if leaf.type in OPENING_BRACKETS:
3587                     # There are brackets we can further split on.
3588                     remainder = False
3589
3590         else:
3591             # checked the entire string and line length wasn't exceeded
3592             if len(line.leaves) == _index + 1:
3593                 return True
3594
3595         # Note: we are not returning False here because a line might have *both*
3596         # a leading opening bracket and a trailing closing bracket.  If the
3597         # opening bracket doesn't match our rule, maybe the closing will.
3598
3599     if (
3600         last.type == token.RPAR
3601         or last.type == token.RBRACE
3602         or (
3603             # don't use indexing for omitting optional parentheses;
3604             # it looks weird
3605             last.type == token.RSQB
3606             and last.parent
3607             and last.parent.type != syms.trailer
3608         )
3609     ):
3610         if penultimate.type in OPENING_BRACKETS:
3611             # Empty brackets don't help.
3612             return False
3613
3614         if is_multiline_string(first):
3615             # Additional wrapping of a multiline string in this situation is
3616             # unnecessary.
3617             return True
3618
3619         length = 4 * line.depth
3620         seen_other_brackets = False
3621         for _index, leaf, leaf_length in enumerate_with_length(line):
3622             length += leaf_length
3623             if leaf is last.opening_bracket:
3624                 if seen_other_brackets or length <= line_length:
3625                     return True
3626
3627             elif leaf.type in OPENING_BRACKETS:
3628                 # There are brackets we can further split on.
3629                 seen_other_brackets = True
3630
3631     return False
3632
3633
3634 def get_cache_file(mode: FileMode) -> Path:
3635     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3636
3637
3638 def read_cache(mode: FileMode) -> Cache:
3639     """Read the cache if it exists and is well formed.
3640
3641     If it is not well formed, the call to write_cache later should resolve the issue.
3642     """
3643     cache_file = get_cache_file(mode)
3644     if not cache_file.exists():
3645         return {}
3646
3647     with cache_file.open("rb") as fobj:
3648         try:
3649             cache: Cache = pickle.load(fobj)
3650         except pickle.UnpicklingError:
3651             return {}
3652
3653     return cache
3654
3655
3656 def get_cache_info(path: Path) -> CacheInfo:
3657     """Return the information used to check if a file is already formatted or not."""
3658     stat = path.stat()
3659     return stat.st_mtime, stat.st_size
3660
3661
3662 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3663     """Split an iterable of paths in `sources` into two sets.
3664
3665     The first contains paths of files that modified on disk or are not in the
3666     cache. The other contains paths to non-modified files.
3667     """
3668     todo, done = set(), set()
3669     for src in sources:
3670         src = src.resolve()
3671         if cache.get(src) != get_cache_info(src):
3672             todo.add(src)
3673         else:
3674             done.add(src)
3675     return todo, done
3676
3677
3678 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3679     """Update the cache file."""
3680     cache_file = get_cache_file(mode)
3681     try:
3682         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3683         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3684         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3685             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3686         os.replace(f.name, cache_file)
3687     except OSError:
3688         pass
3689
3690
3691 def patch_click() -> None:
3692     """Make Click not crash.
3693
3694     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3695     default which restricts paths that it can access during the lifetime of the
3696     application.  Click refuses to work in this scenario by raising a RuntimeError.
3697
3698     In case of Black the likelihood that non-ASCII characters are going to be used in
3699     file paths is minimal since it's Python source code.  Moreover, this crash was
3700     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3701     """
3702     try:
3703         from click import core
3704         from click import _unicodefun  # type: ignore
3705     except ModuleNotFoundError:
3706         return
3707
3708     for module in (core, _unicodefun):
3709         if hasattr(module, "_verify_python3_env"):
3710             module._verify_python3_env = lambda: None
3711
3712
3713 def patched_main() -> None:
3714     freeze_support()
3715     patch_click()
3716     main()
3717
3718
3719 if __name__ == "__main__":
3720     patched_main()