]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

906b5064f5992022fbd485e0a57cf19ffeffafe9
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "18.9b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA = 4
137
138
139 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
140     TargetVersion.PY27: set(),
141     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
142     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
144     TargetVersion.PY36: {
145         Feature.UNICODE_LITERALS,
146         Feature.F_STRINGS,
147         Feature.NUMERIC_UNDERSCORES,
148         Feature.TRAILING_COMMA,
149     },
150     TargetVersion.PY37: {
151         Feature.UNICODE_LITERALS,
152         Feature.F_STRINGS,
153         Feature.NUMERIC_UNDERSCORES,
154         Feature.TRAILING_COMMA,
155     },
156     TargetVersion.PY38: {
157         Feature.UNICODE_LITERALS,
158         Feature.F_STRINGS,
159         Feature.NUMERIC_UNDERSCORES,
160         Feature.TRAILING_COMMA,
161     },
162 }
163
164
165 @dataclass
166 class FileMode:
167     target_versions: Set[TargetVersion] = Factory(set)
168     line_length: int = DEFAULT_LINE_LENGTH
169     string_normalization: bool = True
170     is_pyi: bool = False
171
172     def get_cache_key(self) -> str:
173         if self.target_versions:
174             version_str = ",".join(
175                 str(version.value)
176                 for version in sorted(self.target_versions, key=lambda v: v.value)
177             )
178         else:
179             version_str = "-"
180         parts = [
181             version_str,
182             str(self.line_length),
183             str(int(self.string_normalization)),
184             str(int(self.is_pyi)),
185         ]
186         return ".".join(parts)
187
188
189 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
190     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
191
192
193 def read_pyproject_toml(
194     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
195 ) -> Optional[str]:
196     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
197
198     Returns the path to a successfully found and read configuration file, None
199     otherwise.
200     """
201     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
202     if not value:
203         root = find_project_root(ctx.params.get("src", ()))
204         path = root / "pyproject.toml"
205         if path.is_file():
206             value = str(path)
207         else:
208             return None
209
210     try:
211         pyproject_toml = toml.load(value)
212         config = pyproject_toml.get("tool", {}).get("black", {})
213     except (toml.TomlDecodeError, OSError) as e:
214         raise click.FileError(
215             filename=value, hint=f"Error reading configuration file: {e}"
216         )
217
218     if not config:
219         return None
220
221     if ctx.default_map is None:
222         ctx.default_map = {}
223     ctx.default_map.update(  # type: ignore  # bad types in .pyi
224         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
225     )
226     return value
227
228
229 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
230 @click.option(
231     "-l",
232     "--line-length",
233     type=int,
234     default=DEFAULT_LINE_LENGTH,
235     help="How many characters per line to allow.",
236     show_default=True,
237 )
238 @click.option(
239     "-t",
240     "--target-version",
241     type=click.Choice([v.name.lower() for v in TargetVersion]),
242     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
243     multiple=True,
244     help=(
245         "Python versions that should be supported by Black's output. [default: "
246         "per-file auto-detection]"
247     ),
248 )
249 @click.option(
250     "--pyi",
251     is_flag=True,
252     help=(
253         "Format all input files like typing stubs regardless of file extension "
254         "(useful when piping source on standard input)."
255     ),
256 )
257 @click.option(
258     "-S",
259     "--skip-string-normalization",
260     is_flag=True,
261     help="Don't normalize string quotes or prefixes.",
262 )
263 @click.option(
264     "--check",
265     is_flag=True,
266     help=(
267         "Don't write the files back, just return the status.  Return code 0 "
268         "means nothing would change.  Return code 1 means some files would be "
269         "reformatted.  Return code 123 means there was an internal error."
270     ),
271 )
272 @click.option(
273     "--diff",
274     is_flag=True,
275     help="Don't write the files back, just output a diff for each file on stdout.",
276 )
277 @click.option(
278     "--fast/--safe",
279     is_flag=True,
280     help="If --fast given, skip temporary sanity checks. [default: --safe]",
281 )
282 @click.option(
283     "--include",
284     type=str,
285     default=DEFAULT_INCLUDES,
286     help=(
287         "A regular expression that matches files and directories that should be "
288         "included on recursive searches.  An empty value means all files are "
289         "included regardless of the name.  Use forward slashes for directories on "
290         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
291         "later."
292     ),
293     show_default=True,
294 )
295 @click.option(
296     "--exclude",
297     type=str,
298     default=DEFAULT_EXCLUDES,
299     help=(
300         "A regular expression that matches files and directories that should be "
301         "excluded on recursive searches.  An empty value means no paths are excluded. "
302         "Use forward slashes for directories on all platforms (Windows, too).  "
303         "Exclusions are calculated first, inclusions later."
304     ),
305     show_default=True,
306 )
307 @click.option(
308     "-q",
309     "--quiet",
310     is_flag=True,
311     help=(
312         "Don't emit non-error messages to stderr. Errors are still emitted, "
313         "silence those with 2>/dev/null."
314     ),
315 )
316 @click.option(
317     "-v",
318     "--verbose",
319     is_flag=True,
320     help=(
321         "Also emit messages to stderr about files that were not changed or were "
322         "ignored due to --exclude=."
323     ),
324 )
325 @click.version_option(version=__version__)
326 @click.argument(
327     "src",
328     nargs=-1,
329     type=click.Path(
330         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
331     ),
332     is_eager=True,
333 )
334 @click.option(
335     "--config",
336     type=click.Path(
337         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
338     ),
339     is_eager=True,
340     callback=read_pyproject_toml,
341     help="Read configuration from PATH.",
342 )
343 @click.pass_context
344 def main(
345     ctx: click.Context,
346     line_length: int,
347     target_version: List[TargetVersion],
348     check: bool,
349     diff: bool,
350     fast: bool,
351     pyi: bool,
352     skip_string_normalization: bool,
353     quiet: bool,
354     verbose: bool,
355     include: str,
356     exclude: str,
357     src: Tuple[str],
358     config: Optional[str],
359 ) -> None:
360     """The uncompromising code formatter."""
361     write_back = WriteBack.from_configuration(check=check, diff=diff)
362     if target_version:
363         versions = set(target_version)
364     else:
365         # We'll autodetect later.
366         versions = set()
367     mode = FileMode(
368         target_versions=versions,
369         line_length=line_length,
370         is_pyi=pyi,
371         string_normalization=not skip_string_normalization,
372     )
373     if config and verbose:
374         out(f"Using configuration from {config}.", bold=False, fg="blue")
375     try:
376         include_regex = re_compile_maybe_verbose(include)
377     except re.error:
378         err(f"Invalid regular expression for include given: {include!r}")
379         ctx.exit(2)
380     try:
381         exclude_regex = re_compile_maybe_verbose(exclude)
382     except re.error:
383         err(f"Invalid regular expression for exclude given: {exclude!r}")
384         ctx.exit(2)
385     report = Report(check=check, quiet=quiet, verbose=verbose)
386     root = find_project_root(src)
387     sources: Set[Path] = set()
388     for s in src:
389         p = Path(s)
390         if p.is_dir():
391             sources.update(
392                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
393             )
394         elif p.is_file() or s == "-":
395             # if a file was explicitly given, we don't care about its extension
396             sources.add(p)
397         else:
398             err(f"invalid path: {s}")
399     if len(sources) == 0:
400         if verbose or not quiet:
401             out("No paths given. Nothing to do 😴")
402         ctx.exit(0)
403
404     if len(sources) == 1:
405         reformat_one(
406             src=sources.pop(),
407             fast=fast,
408             write_back=write_back,
409             mode=mode,
410             report=report,
411         )
412     else:
413         loop = asyncio.get_event_loop()
414         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
415         try:
416             loop.run_until_complete(
417                 schedule_formatting(
418                     sources=sources,
419                     fast=fast,
420                     write_back=write_back,
421                     mode=mode,
422                     report=report,
423                     loop=loop,
424                     executor=executor,
425                 )
426             )
427         finally:
428             shutdown(loop)
429     if verbose or not quiet:
430         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
431         out(f"All done! {bang}")
432         click.secho(str(report), err=True)
433     ctx.exit(report.return_code)
434
435
436 def reformat_one(
437     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
438 ) -> None:
439     """Reformat a single file under `src` without spawning child processes.
440
441     If `quiet` is True, non-error messages are not output. `line_length`,
442     `write_back`, `fast` and `pyi` options are passed to
443     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
444     """
445     try:
446         changed = Changed.NO
447         if not src.is_file() and str(src) == "-":
448             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
449                 changed = Changed.YES
450         else:
451             cache: Cache = {}
452             if write_back != WriteBack.DIFF:
453                 cache = read_cache(mode)
454                 res_src = src.resolve()
455                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
456                     changed = Changed.CACHED
457             if changed is not Changed.CACHED and format_file_in_place(
458                 src, fast=fast, write_back=write_back, mode=mode
459             ):
460                 changed = Changed.YES
461             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
462                 write_back is WriteBack.CHECK and changed is Changed.NO
463             ):
464                 write_cache(cache, [src], mode)
465         report.done(src, changed)
466     except Exception as exc:
467         report.failed(src, str(exc))
468
469
470 async def schedule_formatting(
471     sources: Set[Path],
472     fast: bool,
473     write_back: WriteBack,
474     mode: FileMode,
475     report: "Report",
476     loop: BaseEventLoop,
477     executor: Executor,
478 ) -> None:
479     """Run formatting of `sources` in parallel using the provided `executor`.
480
481     (Use ProcessPoolExecutors for actual parallelism.)
482
483     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
484     :func:`format_file_in_place`.
485     """
486     cache: Cache = {}
487     if write_back != WriteBack.DIFF:
488         cache = read_cache(mode)
489         sources, cached = filter_cached(cache, sources)
490         for src in sorted(cached):
491             report.done(src, Changed.CACHED)
492     if not sources:
493         return
494
495     cancelled = []
496     sources_to_cache = []
497     lock = None
498     if write_back == WriteBack.DIFF:
499         # For diff output, we need locks to ensure we don't interleave output
500         # from different processes.
501         manager = Manager()
502         lock = manager.Lock()
503     tasks = {
504         loop.run_in_executor(
505             executor, format_file_in_place, src, fast, mode, write_back, lock
506         ): src
507         for src in sorted(sources)
508     }
509     pending: Iterable[asyncio.Task] = tasks.keys()
510     try:
511         loop.add_signal_handler(signal.SIGINT, cancel, pending)
512         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
513     except NotImplementedError:
514         # There are no good alternatives for these on Windows.
515         pass
516     while pending:
517         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
518         for task in done:
519             src = tasks.pop(task)
520             if task.cancelled():
521                 cancelled.append(task)
522             elif task.exception():
523                 report.failed(src, str(task.exception()))
524             else:
525                 changed = Changed.YES if task.result() else Changed.NO
526                 # If the file was written back or was successfully checked as
527                 # well-formatted, store this information in the cache.
528                 if write_back is WriteBack.YES or (
529                     write_back is WriteBack.CHECK and changed is Changed.NO
530                 ):
531                     sources_to_cache.append(src)
532                 report.done(src, changed)
533     if cancelled:
534         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
535     if sources_to_cache:
536         write_cache(cache, sources_to_cache, mode)
537
538
539 def format_file_in_place(
540     src: Path,
541     fast: bool,
542     mode: FileMode,
543     write_back: WriteBack = WriteBack.NO,
544     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
545 ) -> bool:
546     """Format file under `src` path. Return True if changed.
547
548     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
549     code to the file.
550     `line_length` and `fast` options are passed to :func:`format_file_contents`.
551     """
552     if src.suffix == ".pyi":
553         mode = evolve(mode, is_pyi=True)
554
555     then = datetime.utcfromtimestamp(src.stat().st_mtime)
556     with open(src, "rb") as buf:
557         src_contents, encoding, newline = decode_bytes(buf.read())
558     try:
559         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
560     except NothingChanged:
561         return False
562
563     if write_back == write_back.YES:
564         with open(src, "w", encoding=encoding, newline=newline) as f:
565             f.write(dst_contents)
566     elif write_back == write_back.DIFF:
567         now = datetime.utcnow()
568         src_name = f"{src}\t{then} +0000"
569         dst_name = f"{src}\t{now} +0000"
570         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
571         if lock:
572             lock.acquire()
573         try:
574             f = io.TextIOWrapper(
575                 sys.stdout.buffer,
576                 encoding=encoding,
577                 newline=newline,
578                 write_through=True,
579             )
580             f.write(diff_contents)
581             f.detach()
582         finally:
583             if lock:
584                 lock.release()
585     return True
586
587
588 def format_stdin_to_stdout(
589     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
590 ) -> bool:
591     """Format file on stdin. Return True if changed.
592
593     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
594     write a diff to stdout. The `mode` argument is passed to
595     :func:`format_file_contents`.
596     """
597     then = datetime.utcnow()
598     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
599     dst = src
600     try:
601         dst = format_file_contents(src, fast=fast, mode=mode)
602         return True
603
604     except NothingChanged:
605         return False
606
607     finally:
608         f = io.TextIOWrapper(
609             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
610         )
611         if write_back == WriteBack.YES:
612             f.write(dst)
613         elif write_back == WriteBack.DIFF:
614             now = datetime.utcnow()
615             src_name = f"STDIN\t{then} +0000"
616             dst_name = f"STDOUT\t{now} +0000"
617             f.write(diff(src, dst, src_name, dst_name))
618         f.detach()
619
620
621 def format_file_contents(
622     src_contents: str, *, fast: bool, mode: FileMode
623 ) -> FileContent:
624     """Reformat contents a file and return new contents.
625
626     If `fast` is False, additionally confirm that the reformatted code is
627     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
628     `line_length` is passed to :func:`format_str`.
629     """
630     if src_contents.strip() == "":
631         raise NothingChanged
632
633     dst_contents = format_str(src_contents, mode=mode)
634     if src_contents == dst_contents:
635         raise NothingChanged
636
637     if not fast:
638         assert_equivalent(src_contents, dst_contents)
639         assert_stable(src_contents, dst_contents, mode=mode)
640     return dst_contents
641
642
643 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
644     """Reformat a string and return new contents.
645
646     `line_length` determines how many characters per line are allowed.
647     """
648     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
649     dst_contents = ""
650     future_imports = get_future_imports(src_node)
651     if mode.target_versions:
652         versions = mode.target_versions
653     else:
654         versions = detect_target_versions(src_node)
655     normalize_fmt_off(src_node)
656     lines = LineGenerator(
657         remove_u_prefix="unicode_literals" in future_imports
658         or supports_feature(versions, Feature.UNICODE_LITERALS),
659         is_pyi=mode.is_pyi,
660         normalize_strings=mode.string_normalization,
661     )
662     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
663     empty_line = Line()
664     after = 0
665     for current_line in lines.visit(src_node):
666         for _ in range(after):
667             dst_contents += str(empty_line)
668         before, after = elt.maybe_empty_lines(current_line)
669         for _ in range(before):
670             dst_contents += str(empty_line)
671         for line in split_line(
672             current_line,
673             line_length=mode.line_length,
674             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
675         ):
676             dst_contents += str(line)
677     return dst_contents
678
679
680 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
681     """Return a tuple of (decoded_contents, encoding, newline).
682
683     `newline` is either CRLF or LF but `decoded_contents` is decoded with
684     universal newlines (i.e. only contains LF).
685     """
686     srcbuf = io.BytesIO(src)
687     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
688     if not lines:
689         return "", encoding, "\n"
690
691     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
692     srcbuf.seek(0)
693     with io.TextIOWrapper(srcbuf, encoding) as tiow:
694         return tiow.read(), encoding, newline
695
696
697 GRAMMARS = [
698     pygram.python_grammar_no_print_statement_no_exec_statement,
699     pygram.python_grammar_no_print_statement,
700     pygram.python_grammar,
701 ]
702
703
704 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
705     if not target_versions:
706         return GRAMMARS
707     elif all(not version.is_python2() for version in target_versions):
708         # Python 2-compatible code, so don't try Python 3 grammar.
709         return [
710             pygram.python_grammar_no_print_statement_no_exec_statement,
711             pygram.python_grammar_no_print_statement,
712         ]
713     else:
714         return [pygram.python_grammar]
715
716
717 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
718     """Given a string with source, return the lib2to3 Node."""
719     if src_txt[-1:] != "\n":
720         src_txt += "\n"
721
722     for grammar in get_grammars(set(target_versions)):
723         drv = driver.Driver(grammar, pytree.convert)
724         try:
725             result = drv.parse_string(src_txt, True)
726             break
727
728         except ParseError as pe:
729             lineno, column = pe.context[1]
730             lines = src_txt.splitlines()
731             try:
732                 faulty_line = lines[lineno - 1]
733             except IndexError:
734                 faulty_line = "<line number missing in source>"
735             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
736     else:
737         raise exc from None
738
739     if isinstance(result, Leaf):
740         result = Node(syms.file_input, [result])
741     return result
742
743
744 def lib2to3_unparse(node: Node) -> str:
745     """Given a lib2to3 node, return its string representation."""
746     code = str(node)
747     return code
748
749
750 T = TypeVar("T")
751
752
753 class Visitor(Generic[T]):
754     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
755
756     def visit(self, node: LN) -> Iterator[T]:
757         """Main method to visit `node` and its children.
758
759         It tries to find a `visit_*()` method for the given `node.type`, like
760         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
761         If no dedicated `visit_*()` method is found, chooses `visit_default()`
762         instead.
763
764         Then yields objects of type `T` from the selected visitor.
765         """
766         if node.type < 256:
767             name = token.tok_name[node.type]
768         else:
769             name = type_repr(node.type)
770         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
771
772     def visit_default(self, node: LN) -> Iterator[T]:
773         """Default `visit_*()` implementation. Recurses to children of `node`."""
774         if isinstance(node, Node):
775             for child in node.children:
776                 yield from self.visit(child)
777
778
779 @dataclass
780 class DebugVisitor(Visitor[T]):
781     tree_depth: int = 0
782
783     def visit_default(self, node: LN) -> Iterator[T]:
784         indent = " " * (2 * self.tree_depth)
785         if isinstance(node, Node):
786             _type = type_repr(node.type)
787             out(f"{indent}{_type}", fg="yellow")
788             self.tree_depth += 1
789             for child in node.children:
790                 yield from self.visit(child)
791
792             self.tree_depth -= 1
793             out(f"{indent}/{_type}", fg="yellow", bold=False)
794         else:
795             _type = token.tok_name.get(node.type, str(node.type))
796             out(f"{indent}{_type}", fg="blue", nl=False)
797             if node.prefix:
798                 # We don't have to handle prefixes for `Node` objects since
799                 # that delegates to the first child anyway.
800                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
801             out(f" {node.value!r}", fg="blue", bold=False)
802
803     @classmethod
804     def show(cls, code: Union[str, Leaf, Node]) -> None:
805         """Pretty-print the lib2to3 AST of a given string of `code`.
806
807         Convenience method for debugging.
808         """
809         v: DebugVisitor[None] = DebugVisitor()
810         if isinstance(code, str):
811             code = lib2to3_parse(code)
812         list(v.visit(code))
813
814
815 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
816 STATEMENT = {
817     syms.if_stmt,
818     syms.while_stmt,
819     syms.for_stmt,
820     syms.try_stmt,
821     syms.except_clause,
822     syms.with_stmt,
823     syms.funcdef,
824     syms.classdef,
825 }
826 STANDALONE_COMMENT = 153
827 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
828 LOGIC_OPERATORS = {"and", "or"}
829 COMPARATORS = {
830     token.LESS,
831     token.GREATER,
832     token.EQEQUAL,
833     token.NOTEQUAL,
834     token.LESSEQUAL,
835     token.GREATEREQUAL,
836 }
837 MATH_OPERATORS = {
838     token.VBAR,
839     token.CIRCUMFLEX,
840     token.AMPER,
841     token.LEFTSHIFT,
842     token.RIGHTSHIFT,
843     token.PLUS,
844     token.MINUS,
845     token.STAR,
846     token.SLASH,
847     token.DOUBLESLASH,
848     token.PERCENT,
849     token.AT,
850     token.TILDE,
851     token.DOUBLESTAR,
852 }
853 STARS = {token.STAR, token.DOUBLESTAR}
854 VARARGS_PARENTS = {
855     syms.arglist,
856     syms.argument,  # double star in arglist
857     syms.trailer,  # single argument to call
858     syms.typedargslist,
859     syms.varargslist,  # lambdas
860 }
861 UNPACKING_PARENTS = {
862     syms.atom,  # single element of a list or set literal
863     syms.dictsetmaker,
864     syms.listmaker,
865     syms.testlist_gexp,
866     syms.testlist_star_expr,
867 }
868 TEST_DESCENDANTS = {
869     syms.test,
870     syms.lambdef,
871     syms.or_test,
872     syms.and_test,
873     syms.not_test,
874     syms.comparison,
875     syms.star_expr,
876     syms.expr,
877     syms.xor_expr,
878     syms.and_expr,
879     syms.shift_expr,
880     syms.arith_expr,
881     syms.trailer,
882     syms.term,
883     syms.power,
884 }
885 ASSIGNMENTS = {
886     "=",
887     "+=",
888     "-=",
889     "*=",
890     "@=",
891     "/=",
892     "%=",
893     "&=",
894     "|=",
895     "^=",
896     "<<=",
897     ">>=",
898     "**=",
899     "//=",
900 }
901 COMPREHENSION_PRIORITY = 20
902 COMMA_PRIORITY = 18
903 TERNARY_PRIORITY = 16
904 LOGIC_PRIORITY = 14
905 STRING_PRIORITY = 12
906 COMPARATOR_PRIORITY = 10
907 MATH_PRIORITIES = {
908     token.VBAR: 9,
909     token.CIRCUMFLEX: 8,
910     token.AMPER: 7,
911     token.LEFTSHIFT: 6,
912     token.RIGHTSHIFT: 6,
913     token.PLUS: 5,
914     token.MINUS: 5,
915     token.STAR: 4,
916     token.SLASH: 4,
917     token.DOUBLESLASH: 4,
918     token.PERCENT: 4,
919     token.AT: 4,
920     token.TILDE: 3,
921     token.DOUBLESTAR: 2,
922 }
923 DOT_PRIORITY = 1
924
925
926 @dataclass
927 class BracketTracker:
928     """Keeps track of brackets on a line."""
929
930     depth: int = 0
931     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
932     delimiters: Dict[LeafID, Priority] = Factory(dict)
933     previous: Optional[Leaf] = None
934     _for_loop_depths: List[int] = Factory(list)
935     _lambda_argument_depths: List[int] = Factory(list)
936
937     def mark(self, leaf: Leaf) -> None:
938         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
939
940         All leaves receive an int `bracket_depth` field that stores how deep
941         within brackets a given leaf is. 0 means there are no enclosing brackets
942         that started on this line.
943
944         If a leaf is itself a closing bracket, it receives an `opening_bracket`
945         field that it forms a pair with. This is a one-directional link to
946         avoid reference cycles.
947
948         If a leaf is a delimiter (a token on which Black can split the line if
949         needed) and it's on depth 0, its `id()` is stored in the tracker's
950         `delimiters` field.
951         """
952         if leaf.type == token.COMMENT:
953             return
954
955         self.maybe_decrement_after_for_loop_variable(leaf)
956         self.maybe_decrement_after_lambda_arguments(leaf)
957         if leaf.type in CLOSING_BRACKETS:
958             self.depth -= 1
959             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
960             leaf.opening_bracket = opening_bracket
961         leaf.bracket_depth = self.depth
962         if self.depth == 0:
963             delim = is_split_before_delimiter(leaf, self.previous)
964             if delim and self.previous is not None:
965                 self.delimiters[id(self.previous)] = delim
966             else:
967                 delim = is_split_after_delimiter(leaf, self.previous)
968                 if delim:
969                     self.delimiters[id(leaf)] = delim
970         if leaf.type in OPENING_BRACKETS:
971             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
972             self.depth += 1
973         self.previous = leaf
974         self.maybe_increment_lambda_arguments(leaf)
975         self.maybe_increment_for_loop_variable(leaf)
976
977     def any_open_brackets(self) -> bool:
978         """Return True if there is an yet unmatched open bracket on the line."""
979         return bool(self.bracket_match)
980
981     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
982         """Return the highest priority of a delimiter found on the line.
983
984         Values are consistent with what `is_split_*_delimiter()` return.
985         Raises ValueError on no delimiters.
986         """
987         return max(v for k, v in self.delimiters.items() if k not in exclude)
988
989     def delimiter_count_with_priority(self, priority: int = 0) -> int:
990         """Return the number of delimiters with the given `priority`.
991
992         If no `priority` is passed, defaults to max priority on the line.
993         """
994         if not self.delimiters:
995             return 0
996
997         priority = priority or self.max_delimiter_priority()
998         return sum(1 for p in self.delimiters.values() if p == priority)
999
1000     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1001         """In a for loop, or comprehension, the variables are often unpacks.
1002
1003         To avoid splitting on the comma in this situation, increase the depth of
1004         tokens between `for` and `in`.
1005         """
1006         if leaf.type == token.NAME and leaf.value == "for":
1007             self.depth += 1
1008             self._for_loop_depths.append(self.depth)
1009             return True
1010
1011         return False
1012
1013     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1014         """See `maybe_increment_for_loop_variable` above for explanation."""
1015         if (
1016             self._for_loop_depths
1017             and self._for_loop_depths[-1] == self.depth
1018             and leaf.type == token.NAME
1019             and leaf.value == "in"
1020         ):
1021             self.depth -= 1
1022             self._for_loop_depths.pop()
1023             return True
1024
1025         return False
1026
1027     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1028         """In a lambda expression, there might be more than one argument.
1029
1030         To avoid splitting on the comma in this situation, increase the depth of
1031         tokens between `lambda` and `:`.
1032         """
1033         if leaf.type == token.NAME and leaf.value == "lambda":
1034             self.depth += 1
1035             self._lambda_argument_depths.append(self.depth)
1036             return True
1037
1038         return False
1039
1040     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1041         """See `maybe_increment_lambda_arguments` above for explanation."""
1042         if (
1043             self._lambda_argument_depths
1044             and self._lambda_argument_depths[-1] == self.depth
1045             and leaf.type == token.COLON
1046         ):
1047             self.depth -= 1
1048             self._lambda_argument_depths.pop()
1049             return True
1050
1051         return False
1052
1053     def get_open_lsqb(self) -> Optional[Leaf]:
1054         """Return the most recent opening square bracket (if any)."""
1055         return self.bracket_match.get((self.depth - 1, token.RSQB))
1056
1057
1058 @dataclass
1059 class Line:
1060     """Holds leaves and comments. Can be printed with `str(line)`."""
1061
1062     depth: int = 0
1063     leaves: List[Leaf] = Factory(list)
1064     # The LeafID keys of comments must remain ordered by the corresponding leaf's index
1065     # in leaves
1066     comments: Dict[LeafID, List[Leaf]] = Factory(dict)
1067     bracket_tracker: BracketTracker = Factory(BracketTracker)
1068     inside_brackets: bool = False
1069     should_explode: bool = False
1070
1071     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1072         """Add a new `leaf` to the end of the line.
1073
1074         Unless `preformatted` is True, the `leaf` will receive a new consistent
1075         whitespace prefix and metadata applied by :class:`BracketTracker`.
1076         Trailing commas are maybe removed, unpacked for loop variables are
1077         demoted from being delimiters.
1078
1079         Inline comments are put aside.
1080         """
1081         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1082         if not has_value:
1083             return
1084
1085         if token.COLON == leaf.type and self.is_class_paren_empty:
1086             del self.leaves[-2:]
1087         if self.leaves and not preformatted:
1088             # Note: at this point leaf.prefix should be empty except for
1089             # imports, for which we only preserve newlines.
1090             leaf.prefix += whitespace(
1091                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1092             )
1093         if self.inside_brackets or not preformatted:
1094             self.bracket_tracker.mark(leaf)
1095             self.maybe_remove_trailing_comma(leaf)
1096         if not self.append_comment(leaf):
1097             self.leaves.append(leaf)
1098
1099     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1100         """Like :func:`append()` but disallow invalid standalone comment structure.
1101
1102         Raises ValueError when any `leaf` is appended after a standalone comment
1103         or when a standalone comment is not the first leaf on the line.
1104         """
1105         if self.bracket_tracker.depth == 0:
1106             if self.is_comment:
1107                 raise ValueError("cannot append to standalone comments")
1108
1109             if self.leaves and leaf.type == STANDALONE_COMMENT:
1110                 raise ValueError(
1111                     "cannot append standalone comments to a populated line"
1112                 )
1113
1114         self.append(leaf, preformatted=preformatted)
1115
1116     @property
1117     def is_comment(self) -> bool:
1118         """Is this line a standalone comment?"""
1119         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1120
1121     @property
1122     def is_decorator(self) -> bool:
1123         """Is this line a decorator?"""
1124         return bool(self) and self.leaves[0].type == token.AT
1125
1126     @property
1127     def is_import(self) -> bool:
1128         """Is this an import line?"""
1129         return bool(self) and is_import(self.leaves[0])
1130
1131     @property
1132     def is_class(self) -> bool:
1133         """Is this line a class definition?"""
1134         return (
1135             bool(self)
1136             and self.leaves[0].type == token.NAME
1137             and self.leaves[0].value == "class"
1138         )
1139
1140     @property
1141     def is_stub_class(self) -> bool:
1142         """Is this line a class definition with a body consisting only of "..."?"""
1143         return self.is_class and self.leaves[-3:] == [
1144             Leaf(token.DOT, ".") for _ in range(3)
1145         ]
1146
1147     @property
1148     def is_def(self) -> bool:
1149         """Is this a function definition? (Also returns True for async defs.)"""
1150         try:
1151             first_leaf = self.leaves[0]
1152         except IndexError:
1153             return False
1154
1155         try:
1156             second_leaf: Optional[Leaf] = self.leaves[1]
1157         except IndexError:
1158             second_leaf = None
1159         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1160             first_leaf.type == token.ASYNC
1161             and second_leaf is not None
1162             and second_leaf.type == token.NAME
1163             and second_leaf.value == "def"
1164         )
1165
1166     @property
1167     def is_class_paren_empty(self) -> bool:
1168         """Is this a class with no base classes but using parentheses?
1169
1170         Those are unnecessary and should be removed.
1171         """
1172         return (
1173             bool(self)
1174             and len(self.leaves) == 4
1175             and self.is_class
1176             and self.leaves[2].type == token.LPAR
1177             and self.leaves[2].value == "("
1178             and self.leaves[3].type == token.RPAR
1179             and self.leaves[3].value == ")"
1180         )
1181
1182     @property
1183     def is_triple_quoted_string(self) -> bool:
1184         """Is the line a triple quoted string?"""
1185         return (
1186             bool(self)
1187             and self.leaves[0].type == token.STRING
1188             and self.leaves[0].value.startswith(('"""', "'''"))
1189         )
1190
1191     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1192         """If so, needs to be split before emitting."""
1193         for leaf in self.leaves:
1194             if leaf.type == STANDALONE_COMMENT:
1195                 if leaf.bracket_depth <= depth_limit:
1196                     return True
1197
1198         return False
1199
1200     def contains_multiline_strings(self) -> bool:
1201         for leaf in self.leaves:
1202             if is_multiline_string(leaf):
1203                 return True
1204
1205         return False
1206
1207     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1208         """Remove trailing comma if there is one and it's safe."""
1209         if not (
1210             self.leaves
1211             and self.leaves[-1].type == token.COMMA
1212             and closing.type in CLOSING_BRACKETS
1213         ):
1214             return False
1215
1216         if closing.type == token.RBRACE:
1217             self.remove_trailing_comma()
1218             return True
1219
1220         if closing.type == token.RSQB:
1221             comma = self.leaves[-1]
1222             if comma.parent and comma.parent.type == syms.listmaker:
1223                 self.remove_trailing_comma()
1224                 return True
1225
1226         # For parens let's check if it's safe to remove the comma.
1227         # Imports are always safe.
1228         if self.is_import:
1229             self.remove_trailing_comma()
1230             return True
1231
1232         # Otherwise, if the trailing one is the only one, we might mistakenly
1233         # change a tuple into a different type by removing the comma.
1234         depth = closing.bracket_depth + 1
1235         commas = 0
1236         opening = closing.opening_bracket
1237         for _opening_index, leaf in enumerate(self.leaves):
1238             if leaf is opening:
1239                 break
1240
1241         else:
1242             return False
1243
1244         for leaf in self.leaves[_opening_index + 1 :]:
1245             if leaf is closing:
1246                 break
1247
1248             bracket_depth = leaf.bracket_depth
1249             if bracket_depth == depth and leaf.type == token.COMMA:
1250                 commas += 1
1251                 if leaf.parent and leaf.parent.type == syms.arglist:
1252                     commas += 1
1253                     break
1254
1255         if commas > 1:
1256             self.remove_trailing_comma()
1257             return True
1258
1259         return False
1260
1261     def append_comment(self, comment: Leaf) -> bool:
1262         """Add an inline or standalone comment to the line."""
1263         if (
1264             comment.type == STANDALONE_COMMENT
1265             and self.bracket_tracker.any_open_brackets()
1266         ):
1267             comment.prefix = ""
1268             return False
1269
1270         if comment.type != token.COMMENT:
1271             return False
1272
1273         if not self.leaves:
1274             comment.type = STANDALONE_COMMENT
1275             comment.prefix = ""
1276             return False
1277
1278         else:
1279             leaf_id = id(self.leaves[-1])
1280             if leaf_id not in self.comments:
1281                 self.comments[leaf_id] = [comment]
1282             else:
1283                 self.comments[leaf_id].append(comment)
1284             return True
1285
1286     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1287         """Generate comments that should appear directly after `leaf`."""
1288         return self.comments.get(id(leaf), [])
1289
1290     def remove_trailing_comma(self) -> None:
1291         """Remove the trailing comma and moves the comments attached to it."""
1292         # Remember, the LeafID keys of self.comments are ordered by the
1293         # corresponding leaf's index in self.leaves
1294         # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
1295         # Otherwise, we insert it into self.comments, and it becomes the last entry.
1296         # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
1297         # is maintained
1298         self.comments.setdefault(id(self.leaves[-2]), []).extend(
1299             self.comments.get(id(self.leaves[-1]), [])
1300         )
1301         self.comments.pop(id(self.leaves[-1]), None)
1302         self.leaves.pop()
1303
1304     def is_complex_subscript(self, leaf: Leaf) -> bool:
1305         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1306         open_lsqb = self.bracket_tracker.get_open_lsqb()
1307         if open_lsqb is None:
1308             return False
1309
1310         subscript_start = open_lsqb.next_sibling
1311
1312         if isinstance(subscript_start, Node):
1313             if subscript_start.type == syms.listmaker:
1314                 return False
1315
1316             if subscript_start.type == syms.subscriptlist:
1317                 subscript_start = child_towards(subscript_start, leaf)
1318         return subscript_start is not None and any(
1319             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1320         )
1321
1322     def __str__(self) -> str:
1323         """Render the line."""
1324         if not self:
1325             return "\n"
1326
1327         indent = "    " * self.depth
1328         leaves = iter(self.leaves)
1329         first = next(leaves)
1330         res = f"{first.prefix}{indent}{first.value}"
1331         for leaf in leaves:
1332             res += str(leaf)
1333         for comment in itertools.chain.from_iterable(self.comments.values()):
1334             res += str(comment)
1335         return res + "\n"
1336
1337     def __bool__(self) -> bool:
1338         """Return True if the line has leaves or comments."""
1339         return bool(self.leaves or self.comments)
1340
1341
1342 @dataclass
1343 class EmptyLineTracker:
1344     """Provides a stateful method that returns the number of potential extra
1345     empty lines needed before and after the currently processed line.
1346
1347     Note: this tracker works on lines that haven't been split yet.  It assumes
1348     the prefix of the first leaf consists of optional newlines.  Those newlines
1349     are consumed by `maybe_empty_lines()` and included in the computation.
1350     """
1351
1352     is_pyi: bool = False
1353     previous_line: Optional[Line] = None
1354     previous_after: int = 0
1355     previous_defs: List[int] = Factory(list)
1356
1357     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1358         """Return the number of extra empty lines before and after the `current_line`.
1359
1360         This is for separating `def`, `async def` and `class` with extra empty
1361         lines (two on module-level).
1362         """
1363         before, after = self._maybe_empty_lines(current_line)
1364         before -= self.previous_after
1365         self.previous_after = after
1366         self.previous_line = current_line
1367         return before, after
1368
1369     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1370         max_allowed = 1
1371         if current_line.depth == 0:
1372             max_allowed = 1 if self.is_pyi else 2
1373         if current_line.leaves:
1374             # Consume the first leaf's extra newlines.
1375             first_leaf = current_line.leaves[0]
1376             before = first_leaf.prefix.count("\n")
1377             before = min(before, max_allowed)
1378             first_leaf.prefix = ""
1379         else:
1380             before = 0
1381         depth = current_line.depth
1382         while self.previous_defs and self.previous_defs[-1] >= depth:
1383             self.previous_defs.pop()
1384             if self.is_pyi:
1385                 before = 0 if depth else 1
1386             else:
1387                 before = 1 if depth else 2
1388         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1389             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1390
1391         if (
1392             self.previous_line
1393             and self.previous_line.is_import
1394             and not current_line.is_import
1395             and depth == self.previous_line.depth
1396         ):
1397             return (before or 1), 0
1398
1399         if (
1400             self.previous_line
1401             and self.previous_line.is_class
1402             and current_line.is_triple_quoted_string
1403         ):
1404             return before, 1
1405
1406         return before, 0
1407
1408     def _maybe_empty_lines_for_class_or_def(
1409         self, current_line: Line, before: int
1410     ) -> Tuple[int, int]:
1411         if not current_line.is_decorator:
1412             self.previous_defs.append(current_line.depth)
1413         if self.previous_line is None:
1414             # Don't insert empty lines before the first line in the file.
1415             return 0, 0
1416
1417         if self.previous_line.is_decorator:
1418             return 0, 0
1419
1420         if self.previous_line.depth < current_line.depth and (
1421             self.previous_line.is_class or self.previous_line.is_def
1422         ):
1423             return 0, 0
1424
1425         if (
1426             self.previous_line.is_comment
1427             and self.previous_line.depth == current_line.depth
1428             and before == 0
1429         ):
1430             return 0, 0
1431
1432         if self.is_pyi:
1433             if self.previous_line.depth > current_line.depth:
1434                 newlines = 1
1435             elif current_line.is_class or self.previous_line.is_class:
1436                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1437                     # No blank line between classes with an empty body
1438                     newlines = 0
1439                 else:
1440                     newlines = 1
1441             elif current_line.is_def and not self.previous_line.is_def:
1442                 # Blank line between a block of functions and a block of non-functions
1443                 newlines = 1
1444             else:
1445                 newlines = 0
1446         else:
1447             newlines = 2
1448         if current_line.depth and newlines:
1449             newlines -= 1
1450         return newlines, 0
1451
1452
1453 @dataclass
1454 class LineGenerator(Visitor[Line]):
1455     """Generates reformatted Line objects.  Empty lines are not emitted.
1456
1457     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1458     in ways that will no longer stringify to valid Python code on the tree.
1459     """
1460
1461     is_pyi: bool = False
1462     normalize_strings: bool = True
1463     current_line: Line = Factory(Line)
1464     remove_u_prefix: bool = False
1465
1466     def line(self, indent: int = 0) -> Iterator[Line]:
1467         """Generate a line.
1468
1469         If the line is empty, only emit if it makes sense.
1470         If the line is too long, split it first and then generate.
1471
1472         If any lines were generated, set up a new current_line.
1473         """
1474         if not self.current_line:
1475             self.current_line.depth += indent
1476             return  # Line is empty, don't emit. Creating a new one unnecessary.
1477
1478         complete_line = self.current_line
1479         self.current_line = Line(depth=complete_line.depth + indent)
1480         yield complete_line
1481
1482     def visit_default(self, node: LN) -> Iterator[Line]:
1483         """Default `visit_*()` implementation. Recurses to children of `node`."""
1484         if isinstance(node, Leaf):
1485             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1486             for comment in generate_comments(node):
1487                 if any_open_brackets:
1488                     # any comment within brackets is subject to splitting
1489                     self.current_line.append(comment)
1490                 elif comment.type == token.COMMENT:
1491                     # regular trailing comment
1492                     self.current_line.append(comment)
1493                     yield from self.line()
1494
1495                 else:
1496                     # regular standalone comment
1497                     yield from self.line()
1498
1499                     self.current_line.append(comment)
1500                     yield from self.line()
1501
1502             normalize_prefix(node, inside_brackets=any_open_brackets)
1503             if self.normalize_strings and node.type == token.STRING:
1504                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1505                 normalize_string_quotes(node)
1506             if node.type == token.NUMBER:
1507                 normalize_numeric_literal(node)
1508             if node.type not in WHITESPACE:
1509                 self.current_line.append(node)
1510         yield from super().visit_default(node)
1511
1512     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1513         """Increase indentation level, maybe yield a line."""
1514         # In blib2to3 INDENT never holds comments.
1515         yield from self.line(+1)
1516         yield from self.visit_default(node)
1517
1518     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1519         """Decrease indentation level, maybe yield a line."""
1520         # The current line might still wait for trailing comments.  At DEDENT time
1521         # there won't be any (they would be prefixes on the preceding NEWLINE).
1522         # Emit the line then.
1523         yield from self.line()
1524
1525         # While DEDENT has no value, its prefix may contain standalone comments
1526         # that belong to the current indentation level.  Get 'em.
1527         yield from self.visit_default(node)
1528
1529         # Finally, emit the dedent.
1530         yield from self.line(-1)
1531
1532     def visit_stmt(
1533         self, node: Node, keywords: Set[str], parens: Set[str]
1534     ) -> Iterator[Line]:
1535         """Visit a statement.
1536
1537         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1538         `def`, `with`, `class`, `assert` and assignments.
1539
1540         The relevant Python language `keywords` for a given statement will be
1541         NAME leaves within it. This methods puts those on a separate line.
1542
1543         `parens` holds a set of string leaf values immediately after which
1544         invisible parens should be put.
1545         """
1546         normalize_invisible_parens(node, parens_after=parens)
1547         for child in node.children:
1548             if child.type == token.NAME and child.value in keywords:  # type: ignore
1549                 yield from self.line()
1550
1551             yield from self.visit(child)
1552
1553     def visit_suite(self, node: Node) -> Iterator[Line]:
1554         """Visit a suite."""
1555         if self.is_pyi and is_stub_suite(node):
1556             yield from self.visit(node.children[2])
1557         else:
1558             yield from self.visit_default(node)
1559
1560     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1561         """Visit a statement without nested statements."""
1562         is_suite_like = node.parent and node.parent.type in STATEMENT
1563         if is_suite_like:
1564             if self.is_pyi and is_stub_body(node):
1565                 yield from self.visit_default(node)
1566             else:
1567                 yield from self.line(+1)
1568                 yield from self.visit_default(node)
1569                 yield from self.line(-1)
1570
1571         else:
1572             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1573                 yield from self.line()
1574             yield from self.visit_default(node)
1575
1576     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1577         """Visit `async def`, `async for`, `async with`."""
1578         yield from self.line()
1579
1580         children = iter(node.children)
1581         for child in children:
1582             yield from self.visit(child)
1583
1584             if child.type == token.ASYNC:
1585                 break
1586
1587         internal_stmt = next(children)
1588         for child in internal_stmt.children:
1589             yield from self.visit(child)
1590
1591     def visit_decorators(self, node: Node) -> Iterator[Line]:
1592         """Visit decorators."""
1593         for child in node.children:
1594             yield from self.line()
1595             yield from self.visit(child)
1596
1597     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1598         """Remove a semicolon and put the other statement on a separate line."""
1599         yield from self.line()
1600
1601     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1602         """End of file. Process outstanding comments and end with a newline."""
1603         yield from self.visit_default(leaf)
1604         yield from self.line()
1605
1606     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1607         if not self.current_line.bracket_tracker.any_open_brackets():
1608             yield from self.line()
1609         yield from self.visit_default(leaf)
1610
1611     def __attrs_post_init__(self) -> None:
1612         """You are in a twisty little maze of passages."""
1613         v = self.visit_stmt
1614         Ø: Set[str] = set()
1615         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1616         self.visit_if_stmt = partial(
1617             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1618         )
1619         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1620         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1621         self.visit_try_stmt = partial(
1622             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1623         )
1624         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1625         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1626         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1627         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1628         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1629         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1630         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1631         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1632         self.visit_async_funcdef = self.visit_async_stmt
1633         self.visit_decorated = self.visit_decorators
1634
1635
1636 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1637 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1638 OPENING_BRACKETS = set(BRACKET.keys())
1639 CLOSING_BRACKETS = set(BRACKET.values())
1640 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1641 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1642
1643
1644 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1645     """Return whitespace prefix if needed for the given `leaf`.
1646
1647     `complex_subscript` signals whether the given leaf is part of a subscription
1648     which has non-trivial arguments, like arithmetic expressions or function calls.
1649     """
1650     NO = ""
1651     SPACE = " "
1652     DOUBLESPACE = "  "
1653     t = leaf.type
1654     p = leaf.parent
1655     v = leaf.value
1656     if t in ALWAYS_NO_SPACE:
1657         return NO
1658
1659     if t == token.COMMENT:
1660         return DOUBLESPACE
1661
1662     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1663     if t == token.COLON and p.type not in {
1664         syms.subscript,
1665         syms.subscriptlist,
1666         syms.sliceop,
1667     }:
1668         return NO
1669
1670     prev = leaf.prev_sibling
1671     if not prev:
1672         prevp = preceding_leaf(p)
1673         if not prevp or prevp.type in OPENING_BRACKETS:
1674             return NO
1675
1676         if t == token.COLON:
1677             if prevp.type == token.COLON:
1678                 return NO
1679
1680             elif prevp.type != token.COMMA and not complex_subscript:
1681                 return NO
1682
1683             return SPACE
1684
1685         if prevp.type == token.EQUAL:
1686             if prevp.parent:
1687                 if prevp.parent.type in {
1688                     syms.arglist,
1689                     syms.argument,
1690                     syms.parameters,
1691                     syms.varargslist,
1692                 }:
1693                     return NO
1694
1695                 elif prevp.parent.type == syms.typedargslist:
1696                     # A bit hacky: if the equal sign has whitespace, it means we
1697                     # previously found it's a typed argument.  So, we're using
1698                     # that, too.
1699                     return prevp.prefix
1700
1701         elif prevp.type in STARS:
1702             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1703                 return NO
1704
1705         elif prevp.type == token.COLON:
1706             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1707                 return SPACE if complex_subscript else NO
1708
1709         elif (
1710             prevp.parent
1711             and prevp.parent.type == syms.factor
1712             and prevp.type in MATH_OPERATORS
1713         ):
1714             return NO
1715
1716         elif (
1717             prevp.type == token.RIGHTSHIFT
1718             and prevp.parent
1719             and prevp.parent.type == syms.shift_expr
1720             and prevp.prev_sibling
1721             and prevp.prev_sibling.type == token.NAME
1722             and prevp.prev_sibling.value == "print"  # type: ignore
1723         ):
1724             # Python 2 print chevron
1725             return NO
1726
1727     elif prev.type in OPENING_BRACKETS:
1728         return NO
1729
1730     if p.type in {syms.parameters, syms.arglist}:
1731         # untyped function signatures or calls
1732         if not prev or prev.type != token.COMMA:
1733             return NO
1734
1735     elif p.type == syms.varargslist:
1736         # lambdas
1737         if prev and prev.type != token.COMMA:
1738             return NO
1739
1740     elif p.type == syms.typedargslist:
1741         # typed function signatures
1742         if not prev:
1743             return NO
1744
1745         if t == token.EQUAL:
1746             if prev.type != syms.tname:
1747                 return NO
1748
1749         elif prev.type == token.EQUAL:
1750             # A bit hacky: if the equal sign has whitespace, it means we
1751             # previously found it's a typed argument.  So, we're using that, too.
1752             return prev.prefix
1753
1754         elif prev.type != token.COMMA:
1755             return NO
1756
1757     elif p.type == syms.tname:
1758         # type names
1759         if not prev:
1760             prevp = preceding_leaf(p)
1761             if not prevp or prevp.type != token.COMMA:
1762                 return NO
1763
1764     elif p.type == syms.trailer:
1765         # attributes and calls
1766         if t == token.LPAR or t == token.RPAR:
1767             return NO
1768
1769         if not prev:
1770             if t == token.DOT:
1771                 prevp = preceding_leaf(p)
1772                 if not prevp or prevp.type != token.NUMBER:
1773                     return NO
1774
1775             elif t == token.LSQB:
1776                 return NO
1777
1778         elif prev.type != token.COMMA:
1779             return NO
1780
1781     elif p.type == syms.argument:
1782         # single argument
1783         if t == token.EQUAL:
1784             return NO
1785
1786         if not prev:
1787             prevp = preceding_leaf(p)
1788             if not prevp or prevp.type == token.LPAR:
1789                 return NO
1790
1791         elif prev.type in {token.EQUAL} | STARS:
1792             return NO
1793
1794     elif p.type == syms.decorator:
1795         # decorators
1796         return NO
1797
1798     elif p.type == syms.dotted_name:
1799         if prev:
1800             return NO
1801
1802         prevp = preceding_leaf(p)
1803         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1804             return NO
1805
1806     elif p.type == syms.classdef:
1807         if t == token.LPAR:
1808             return NO
1809
1810         if prev and prev.type == token.LPAR:
1811             return NO
1812
1813     elif p.type in {syms.subscript, syms.sliceop}:
1814         # indexing
1815         if not prev:
1816             assert p.parent is not None, "subscripts are always parented"
1817             if p.parent.type == syms.subscriptlist:
1818                 return SPACE
1819
1820             return NO
1821
1822         elif not complex_subscript:
1823             return NO
1824
1825     elif p.type == syms.atom:
1826         if prev and t == token.DOT:
1827             # dots, but not the first one.
1828             return NO
1829
1830     elif p.type == syms.dictsetmaker:
1831         # dict unpacking
1832         if prev and prev.type == token.DOUBLESTAR:
1833             return NO
1834
1835     elif p.type in {syms.factor, syms.star_expr}:
1836         # unary ops
1837         if not prev:
1838             prevp = preceding_leaf(p)
1839             if not prevp or prevp.type in OPENING_BRACKETS:
1840                 return NO
1841
1842             prevp_parent = prevp.parent
1843             assert prevp_parent is not None
1844             if prevp.type == token.COLON and prevp_parent.type in {
1845                 syms.subscript,
1846                 syms.sliceop,
1847             }:
1848                 return NO
1849
1850             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1851                 return NO
1852
1853         elif t in {token.NAME, token.NUMBER, token.STRING}:
1854             return NO
1855
1856     elif p.type == syms.import_from:
1857         if t == token.DOT:
1858             if prev and prev.type == token.DOT:
1859                 return NO
1860
1861         elif t == token.NAME:
1862             if v == "import":
1863                 return SPACE
1864
1865             if prev and prev.type == token.DOT:
1866                 return NO
1867
1868     elif p.type == syms.sliceop:
1869         return NO
1870
1871     return SPACE
1872
1873
1874 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1875     """Return the first leaf that precedes `node`, if any."""
1876     while node:
1877         res = node.prev_sibling
1878         if res:
1879             if isinstance(res, Leaf):
1880                 return res
1881
1882             try:
1883                 return list(res.leaves())[-1]
1884
1885             except IndexError:
1886                 return None
1887
1888         node = node.parent
1889     return None
1890
1891
1892 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1893     """Return the child of `ancestor` that contains `descendant`."""
1894     node: Optional[LN] = descendant
1895     while node and node.parent != ancestor:
1896         node = node.parent
1897     return node
1898
1899
1900 def container_of(leaf: Leaf) -> LN:
1901     """Return `leaf` or one of its ancestors that is the topmost container of it.
1902
1903     By "container" we mean a node where `leaf` is the very first child.
1904     """
1905     same_prefix = leaf.prefix
1906     container: LN = leaf
1907     while container:
1908         parent = container.parent
1909         if parent is None:
1910             break
1911
1912         if parent.children[0].prefix != same_prefix:
1913             break
1914
1915         if parent.type == syms.file_input:
1916             break
1917
1918         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1919             break
1920
1921         container = parent
1922     return container
1923
1924
1925 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1926     """Return the priority of the `leaf` delimiter, given a line break after it.
1927
1928     The delimiter priorities returned here are from those delimiters that would
1929     cause a line break after themselves.
1930
1931     Higher numbers are higher priority.
1932     """
1933     if leaf.type == token.COMMA:
1934         return COMMA_PRIORITY
1935
1936     return 0
1937
1938
1939 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1940     """Return the priority of the `leaf` delimiter, given a line break before it.
1941
1942     The delimiter priorities returned here are from those delimiters that would
1943     cause a line break before themselves.
1944
1945     Higher numbers are higher priority.
1946     """
1947     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1948         # * and ** might also be MATH_OPERATORS but in this case they are not.
1949         # Don't treat them as a delimiter.
1950         return 0
1951
1952     if (
1953         leaf.type == token.DOT
1954         and leaf.parent
1955         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1956         and (previous is None or previous.type in CLOSING_BRACKETS)
1957     ):
1958         return DOT_PRIORITY
1959
1960     if (
1961         leaf.type in MATH_OPERATORS
1962         and leaf.parent
1963         and leaf.parent.type not in {syms.factor, syms.star_expr}
1964     ):
1965         return MATH_PRIORITIES[leaf.type]
1966
1967     if leaf.type in COMPARATORS:
1968         return COMPARATOR_PRIORITY
1969
1970     if (
1971         leaf.type == token.STRING
1972         and previous is not None
1973         and previous.type == token.STRING
1974     ):
1975         return STRING_PRIORITY
1976
1977     if leaf.type not in {token.NAME, token.ASYNC}:
1978         return 0
1979
1980     if (
1981         leaf.value == "for"
1982         and leaf.parent
1983         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1984         or leaf.type == token.ASYNC
1985     ):
1986         if (
1987             not isinstance(leaf.prev_sibling, Leaf)
1988             or leaf.prev_sibling.value != "async"
1989         ):
1990             return COMPREHENSION_PRIORITY
1991
1992     if (
1993         leaf.value == "if"
1994         and leaf.parent
1995         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1996     ):
1997         return COMPREHENSION_PRIORITY
1998
1999     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2000         return TERNARY_PRIORITY
2001
2002     if leaf.value == "is":
2003         return COMPARATOR_PRIORITY
2004
2005     if (
2006         leaf.value == "in"
2007         and leaf.parent
2008         and leaf.parent.type in {syms.comp_op, syms.comparison}
2009         and not (
2010             previous is not None
2011             and previous.type == token.NAME
2012             and previous.value == "not"
2013         )
2014     ):
2015         return COMPARATOR_PRIORITY
2016
2017     if (
2018         leaf.value == "not"
2019         and leaf.parent
2020         and leaf.parent.type == syms.comp_op
2021         and not (
2022             previous is not None
2023             and previous.type == token.NAME
2024             and previous.value == "is"
2025         )
2026     ):
2027         return COMPARATOR_PRIORITY
2028
2029     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2030         return LOGIC_PRIORITY
2031
2032     return 0
2033
2034
2035 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2036 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2037
2038
2039 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2040     """Clean the prefix of the `leaf` and generate comments from it, if any.
2041
2042     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2043     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2044     move because it does away with modifying the grammar to include all the
2045     possible places in which comments can be placed.
2046
2047     The sad consequence for us though is that comments don't "belong" anywhere.
2048     This is why this function generates simple parentless Leaf objects for
2049     comments.  We simply don't know what the correct parent should be.
2050
2051     No matter though, we can live without this.  We really only need to
2052     differentiate between inline and standalone comments.  The latter don't
2053     share the line with any code.
2054
2055     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2056     are emitted with a fake STANDALONE_COMMENT token identifier.
2057     """
2058     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2059         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2060
2061
2062 @dataclass
2063 class ProtoComment:
2064     """Describes a piece of syntax that is a comment.
2065
2066     It's not a :class:`blib2to3.pytree.Leaf` so that:
2067
2068     * it can be cached (`Leaf` objects should not be reused more than once as
2069       they store their lineno, column, prefix, and parent information);
2070     * `newlines` and `consumed` fields are kept separate from the `value`. This
2071       simplifies handling of special marker comments like ``# fmt: off/on``.
2072     """
2073
2074     type: int  # token.COMMENT or STANDALONE_COMMENT
2075     value: str  # content of the comment
2076     newlines: int  # how many newlines before the comment
2077     consumed: int  # how many characters of the original leaf's prefix did we consume
2078
2079
2080 @lru_cache(maxsize=4096)
2081 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2082     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2083     result: List[ProtoComment] = []
2084     if not prefix or "#" not in prefix:
2085         return result
2086
2087     consumed = 0
2088     nlines = 0
2089     for index, line in enumerate(prefix.split("\n")):
2090         consumed += len(line) + 1  # adding the length of the split '\n'
2091         line = line.lstrip()
2092         if not line:
2093             nlines += 1
2094         if not line.startswith("#"):
2095             continue
2096
2097         if index == 0 and not is_endmarker:
2098             comment_type = token.COMMENT  # simple trailing comment
2099         else:
2100             comment_type = STANDALONE_COMMENT
2101         comment = make_comment(line)
2102         result.append(
2103             ProtoComment(
2104                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2105             )
2106         )
2107         nlines = 0
2108     return result
2109
2110
2111 def make_comment(content: str) -> str:
2112     """Return a consistently formatted comment from the given `content` string.
2113
2114     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2115     space between the hash sign and the content.
2116
2117     If `content` didn't start with a hash sign, one is provided.
2118     """
2119     content = content.rstrip()
2120     if not content:
2121         return "#"
2122
2123     if content[0] == "#":
2124         content = content[1:]
2125     if content and content[0] not in " !:#'%":
2126         content = " " + content
2127     return "#" + content
2128
2129
2130 def split_line(
2131     line: Line,
2132     line_length: int,
2133     inner: bool = False,
2134     supports_trailing_commas: bool = False,
2135 ) -> Iterator[Line]:
2136     """Split a `line` into potentially many lines.
2137
2138     They should fit in the allotted `line_length` but might not be able to.
2139     `inner` signifies that there were a pair of brackets somewhere around the
2140     current `line`, possibly transitively. This means we can fallback to splitting
2141     by delimiters if the LHS/RHS don't yield any results.
2142
2143     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2144     """
2145     if line.is_comment:
2146         yield line
2147         return
2148
2149     line_str = str(line).strip("\n")
2150
2151     # we don't want to split special comments like type annotations
2152     # https://github.com/python/typing/issues/186
2153     has_special_comment = False
2154     for leaf in line.leaves:
2155         for comment in line.comments_after(leaf):
2156             if leaf.type == token.COMMA and is_special_comment(comment):
2157                 has_special_comment = True
2158
2159     if (
2160         not has_special_comment
2161         and not line.should_explode
2162         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2163     ):
2164         yield line
2165         return
2166
2167     split_funcs: List[SplitFunc]
2168     if line.is_def:
2169         split_funcs = [left_hand_split]
2170     else:
2171
2172         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2173             for omit in generate_trailers_to_omit(line, line_length):
2174                 lines = list(
2175                     right_hand_split(
2176                         line, line_length, supports_trailing_commas, omit=omit
2177                     )
2178                 )
2179                 if is_line_short_enough(lines[0], line_length=line_length):
2180                     yield from lines
2181                     return
2182
2183             # All splits failed, best effort split with no omits.
2184             # This mostly happens to multiline strings that are by definition
2185             # reported as not fitting a single line.
2186             yield from right_hand_split(line, supports_trailing_commas)
2187
2188         if line.inside_brackets:
2189             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2190         else:
2191             split_funcs = [rhs]
2192     for split_func in split_funcs:
2193         # We are accumulating lines in `result` because we might want to abort
2194         # mission and return the original line in the end, or attempt a different
2195         # split altogether.
2196         result: List[Line] = []
2197         try:
2198             for l in split_func(line, supports_trailing_commas):
2199                 if str(l).strip("\n") == line_str:
2200                     raise CannotSplit("Split function returned an unchanged result")
2201
2202                 result.extend(
2203                     split_line(
2204                         l,
2205                         line_length=line_length,
2206                         inner=True,
2207                         supports_trailing_commas=supports_trailing_commas,
2208                     )
2209                 )
2210         except CannotSplit:
2211             continue
2212
2213         else:
2214             yield from result
2215             break
2216
2217     else:
2218         yield line
2219
2220
2221 def left_hand_split(
2222     line: Line, supports_trailing_commas: bool = False
2223 ) -> Iterator[Line]:
2224     """Split line into many lines, starting with the first matching bracket pair.
2225
2226     Note: this usually looks weird, only use this for function definitions.
2227     Prefer RHS otherwise.  This is why this function is not symmetrical with
2228     :func:`right_hand_split` which also handles optional parentheses.
2229     """
2230     tail_leaves: List[Leaf] = []
2231     body_leaves: List[Leaf] = []
2232     head_leaves: List[Leaf] = []
2233     current_leaves = head_leaves
2234     matching_bracket = None
2235     for leaf in line.leaves:
2236         if (
2237             current_leaves is body_leaves
2238             and leaf.type in CLOSING_BRACKETS
2239             and leaf.opening_bracket is matching_bracket
2240         ):
2241             current_leaves = tail_leaves if body_leaves else head_leaves
2242         current_leaves.append(leaf)
2243         if current_leaves is head_leaves:
2244             if leaf.type in OPENING_BRACKETS:
2245                 matching_bracket = leaf
2246                 current_leaves = body_leaves
2247     if not matching_bracket:
2248         raise CannotSplit("No brackets found")
2249
2250     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2251     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2252     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2253     bracket_split_succeeded_or_raise(head, body, tail)
2254     for result in (head, body, tail):
2255         if result:
2256             yield result
2257
2258
2259 def right_hand_split(
2260     line: Line,
2261     line_length: int,
2262     supports_trailing_commas: bool = False,
2263     omit: Collection[LeafID] = (),
2264 ) -> Iterator[Line]:
2265     """Split line into many lines, starting with the last matching bracket pair.
2266
2267     If the split was by optional parentheses, attempt splitting without them, too.
2268     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2269     this split.
2270
2271     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2272     """
2273     tail_leaves: List[Leaf] = []
2274     body_leaves: List[Leaf] = []
2275     head_leaves: List[Leaf] = []
2276     current_leaves = tail_leaves
2277     opening_bracket = None
2278     closing_bracket = None
2279     for leaf in reversed(line.leaves):
2280         if current_leaves is body_leaves:
2281             if leaf is opening_bracket:
2282                 current_leaves = head_leaves if body_leaves else tail_leaves
2283         current_leaves.append(leaf)
2284         if current_leaves is tail_leaves:
2285             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2286                 opening_bracket = leaf.opening_bracket
2287                 closing_bracket = leaf
2288                 current_leaves = body_leaves
2289     if not (opening_bracket and closing_bracket and head_leaves):
2290         # If there is no opening or closing_bracket that means the split failed and
2291         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2292         # the matching `opening_bracket` wasn't available on `line` anymore.
2293         raise CannotSplit("No brackets found")
2294
2295     tail_leaves.reverse()
2296     body_leaves.reverse()
2297     head_leaves.reverse()
2298     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2299     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2300     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2301     bracket_split_succeeded_or_raise(head, body, tail)
2302     if (
2303         # the body shouldn't be exploded
2304         not body.should_explode
2305         # the opening bracket is an optional paren
2306         and opening_bracket.type == token.LPAR
2307         and not opening_bracket.value
2308         # the closing bracket is an optional paren
2309         and closing_bracket.type == token.RPAR
2310         and not closing_bracket.value
2311         # it's not an import (optional parens are the only thing we can split on
2312         # in this case; attempting a split without them is a waste of time)
2313         and not line.is_import
2314         # there are no standalone comments in the body
2315         and not body.contains_standalone_comments(0)
2316         # and we can actually remove the parens
2317         and can_omit_invisible_parens(body, line_length)
2318     ):
2319         omit = {id(closing_bracket), *omit}
2320         try:
2321             yield from right_hand_split(
2322                 line,
2323                 line_length,
2324                 supports_trailing_commas=supports_trailing_commas,
2325                 omit=omit,
2326             )
2327             return
2328
2329         except CannotSplit:
2330             if not (
2331                 can_be_split(body)
2332                 or is_line_short_enough(body, line_length=line_length)
2333             ):
2334                 raise CannotSplit(
2335                     "Splitting failed, body is still too long and can't be split."
2336                 )
2337
2338             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2339                 raise CannotSplit(
2340                     "The current optional pair of parentheses is bound to fail to "
2341                     "satisfy the splitting algorithm because the head or the tail "
2342                     "contains multiline strings which by definition never fit one "
2343                     "line."
2344                 )
2345
2346     ensure_visible(opening_bracket)
2347     ensure_visible(closing_bracket)
2348     for result in (head, body, tail):
2349         if result:
2350             yield result
2351
2352
2353 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2354     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2355
2356     Do nothing otherwise.
2357
2358     A left- or right-hand split is based on a pair of brackets. Content before
2359     (and including) the opening bracket is left on one line, content inside the
2360     brackets is put on a separate line, and finally content starting with and
2361     following the closing bracket is put on a separate line.
2362
2363     Those are called `head`, `body`, and `tail`, respectively. If the split
2364     produced the same line (all content in `head`) or ended up with an empty `body`
2365     and the `tail` is just the closing bracket, then it's considered failed.
2366     """
2367     tail_len = len(str(tail).strip())
2368     if not body:
2369         if tail_len == 0:
2370             raise CannotSplit("Splitting brackets produced the same line")
2371
2372         elif tail_len < 3:
2373             raise CannotSplit(
2374                 f"Splitting brackets on an empty body to save "
2375                 f"{tail_len} characters is not worth it"
2376             )
2377
2378
2379 def bracket_split_build_line(
2380     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2381 ) -> Line:
2382     """Return a new line with given `leaves` and respective comments from `original`.
2383
2384     If `is_body` is True, the result line is one-indented inside brackets and as such
2385     has its first leaf's prefix normalized and a trailing comma added when expected.
2386     """
2387     result = Line(depth=original.depth)
2388     if is_body:
2389         result.inside_brackets = True
2390         result.depth += 1
2391         if leaves:
2392             # Since body is a new indent level, remove spurious leading whitespace.
2393             normalize_prefix(leaves[0], inside_brackets=True)
2394             # Ensure a trailing comma when expected.
2395             if original.is_import:
2396                 if leaves[-1].type != token.COMMA:
2397                     leaves.append(Leaf(token.COMMA, ","))
2398     # Populate the line
2399     for leaf in leaves:
2400         result.append(leaf, preformatted=True)
2401         for comment_after in original.comments_after(leaf):
2402             result.append(comment_after, preformatted=True)
2403     if is_body:
2404         result.should_explode = should_explode(result, opening_bracket)
2405     return result
2406
2407
2408 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2409     """Normalize prefix of the first leaf in every line returned by `split_func`.
2410
2411     This is a decorator over relevant split functions.
2412     """
2413
2414     @wraps(split_func)
2415     def split_wrapper(
2416         line: Line, supports_trailing_commas: bool = False
2417     ) -> Iterator[Line]:
2418         for l in split_func(line, supports_trailing_commas):
2419             normalize_prefix(l.leaves[0], inside_brackets=True)
2420             yield l
2421
2422     return split_wrapper
2423
2424
2425 @dont_increase_indentation
2426 def delimiter_split(
2427     line: Line, supports_trailing_commas: bool = False
2428 ) -> Iterator[Line]:
2429     """Split according to delimiters of the highest priority.
2430
2431     If `supports_trailing_commas` is True, the split will add trailing commas
2432     also in function signatures that contain `*` and `**`.
2433     """
2434     try:
2435         last_leaf = line.leaves[-1]
2436     except IndexError:
2437         raise CannotSplit("Line empty")
2438
2439     bt = line.bracket_tracker
2440     try:
2441         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2442     except ValueError:
2443         raise CannotSplit("No delimiters found")
2444
2445     if delimiter_priority == DOT_PRIORITY:
2446         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2447             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2448
2449     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2450     lowest_depth = sys.maxsize
2451     trailing_comma_safe = True
2452
2453     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2454         """Append `leaf` to current line or to new line if appending impossible."""
2455         nonlocal current_line
2456         try:
2457             current_line.append_safe(leaf, preformatted=True)
2458         except ValueError:
2459             yield current_line
2460
2461             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2462             current_line.append(leaf)
2463
2464     for leaf in line.leaves:
2465         yield from append_to_line(leaf)
2466
2467         for comment_after in line.comments_after(leaf):
2468             yield from append_to_line(comment_after)
2469
2470         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2471         if leaf.bracket_depth == lowest_depth and is_vararg(
2472             leaf, within=VARARGS_PARENTS
2473         ):
2474             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2475         leaf_priority = bt.delimiters.get(id(leaf))
2476         if leaf_priority == delimiter_priority:
2477             yield current_line
2478
2479             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2480     if current_line:
2481         if (
2482             trailing_comma_safe
2483             and delimiter_priority == COMMA_PRIORITY
2484             and current_line.leaves[-1].type != token.COMMA
2485             and current_line.leaves[-1].type != STANDALONE_COMMENT
2486         ):
2487             current_line.append(Leaf(token.COMMA, ","))
2488         yield current_line
2489
2490
2491 @dont_increase_indentation
2492 def standalone_comment_split(
2493     line: Line, supports_trailing_commas: bool = False
2494 ) -> Iterator[Line]:
2495     """Split standalone comments from the rest of the line."""
2496     if not line.contains_standalone_comments(0):
2497         raise CannotSplit("Line does not have any standalone comments")
2498
2499     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2500
2501     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2502         """Append `leaf` to current line or to new line if appending impossible."""
2503         nonlocal current_line
2504         try:
2505             current_line.append_safe(leaf, preformatted=True)
2506         except ValueError:
2507             yield current_line
2508
2509             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2510             current_line.append(leaf)
2511
2512     for leaf in line.leaves:
2513         yield from append_to_line(leaf)
2514
2515         for comment_after in line.comments_after(leaf):
2516             yield from append_to_line(comment_after)
2517
2518     if current_line:
2519         yield current_line
2520
2521
2522 def is_import(leaf: Leaf) -> bool:
2523     """Return True if the given leaf starts an import statement."""
2524     p = leaf.parent
2525     t = leaf.type
2526     v = leaf.value
2527     return bool(
2528         t == token.NAME
2529         and (
2530             (v == "import" and p and p.type == syms.import_name)
2531             or (v == "from" and p and p.type == syms.import_from)
2532         )
2533     )
2534
2535
2536 def is_special_comment(leaf: Leaf) -> bool:
2537     """Return True if the given leaf is a special comment.
2538     Only returns true for type comments for now."""
2539     t = leaf.type
2540     v = leaf.value
2541     return bool(
2542         (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
2543     )
2544
2545
2546 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2547     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2548     else.
2549
2550     Note: don't use backslashes for formatting or you'll lose your voting rights.
2551     """
2552     if not inside_brackets:
2553         spl = leaf.prefix.split("#")
2554         if "\\" not in spl[0]:
2555             nl_count = spl[-1].count("\n")
2556             if len(spl) > 1:
2557                 nl_count -= 1
2558             leaf.prefix = "\n" * nl_count
2559             return
2560
2561     leaf.prefix = ""
2562
2563
2564 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2565     """Make all string prefixes lowercase.
2566
2567     If remove_u_prefix is given, also removes any u prefix from the string.
2568
2569     Note: Mutates its argument.
2570     """
2571     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2572     assert match is not None, f"failed to match string {leaf.value!r}"
2573     orig_prefix = match.group(1)
2574     new_prefix = orig_prefix.lower()
2575     if remove_u_prefix:
2576         new_prefix = new_prefix.replace("u", "")
2577     leaf.value = f"{new_prefix}{match.group(2)}"
2578
2579
2580 def normalize_string_quotes(leaf: Leaf) -> None:
2581     """Prefer double quotes but only if it doesn't cause more escaping.
2582
2583     Adds or removes backslashes as appropriate. Doesn't parse and fix
2584     strings nested in f-strings (yet).
2585
2586     Note: Mutates its argument.
2587     """
2588     value = leaf.value.lstrip("furbFURB")
2589     if value[:3] == '"""':
2590         return
2591
2592     elif value[:3] == "'''":
2593         orig_quote = "'''"
2594         new_quote = '"""'
2595     elif value[0] == '"':
2596         orig_quote = '"'
2597         new_quote = "'"
2598     else:
2599         orig_quote = "'"
2600         new_quote = '"'
2601     first_quote_pos = leaf.value.find(orig_quote)
2602     if first_quote_pos == -1:
2603         return  # There's an internal error
2604
2605     prefix = leaf.value[:first_quote_pos]
2606     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2607     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2608     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2609     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2610     if "r" in prefix.casefold():
2611         if unescaped_new_quote.search(body):
2612             # There's at least one unescaped new_quote in this raw string
2613             # so converting is impossible
2614             return
2615
2616         # Do not introduce or remove backslashes in raw strings
2617         new_body = body
2618     else:
2619         # remove unnecessary escapes
2620         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2621         if body != new_body:
2622             # Consider the string without unnecessary escapes as the original
2623             body = new_body
2624             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2625         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2626         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2627     if "f" in prefix.casefold():
2628         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2629         for m in matches:
2630             if "\\" in str(m):
2631                 # Do not introduce backslashes in interpolated expressions
2632                 return
2633     if new_quote == '"""' and new_body[-1:] == '"':
2634         # edge case:
2635         new_body = new_body[:-1] + '\\"'
2636     orig_escape_count = body.count("\\")
2637     new_escape_count = new_body.count("\\")
2638     if new_escape_count > orig_escape_count:
2639         return  # Do not introduce more escaping
2640
2641     if new_escape_count == orig_escape_count and orig_quote == '"':
2642         return  # Prefer double quotes
2643
2644     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2645
2646
2647 def normalize_numeric_literal(leaf: Leaf) -> None:
2648     """Normalizes numeric (float, int, and complex) literals.
2649
2650     All letters used in the representation are normalized to lowercase (except
2651     in Python 2 long literals).
2652     """
2653     text = leaf.value.lower()
2654     if text.startswith(("0o", "0b")):
2655         # Leave octal and binary literals alone.
2656         pass
2657     elif text.startswith("0x"):
2658         # Change hex literals to upper case.
2659         before, after = text[:2], text[2:]
2660         text = f"{before}{after.upper()}"
2661     elif "e" in text:
2662         before, after = text.split("e")
2663         sign = ""
2664         if after.startswith("-"):
2665             after = after[1:]
2666             sign = "-"
2667         elif after.startswith("+"):
2668             after = after[1:]
2669         before = format_float_or_int_string(before)
2670         text = f"{before}e{sign}{after}"
2671     elif text.endswith(("j", "l")):
2672         number = text[:-1]
2673         suffix = text[-1]
2674         # Capitalize in "2L" because "l" looks too similar to "1".
2675         if suffix == "l":
2676             suffix = "L"
2677         text = f"{format_float_or_int_string(number)}{suffix}"
2678     else:
2679         text = format_float_or_int_string(text)
2680     leaf.value = text
2681
2682
2683 def format_float_or_int_string(text: str) -> str:
2684     """Formats a float string like "1.0"."""
2685     if "." not in text:
2686         return text
2687
2688     before, after = text.split(".")
2689     return f"{before or 0}.{after or 0}"
2690
2691
2692 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2693     """Make existing optional parentheses invisible or create new ones.
2694
2695     `parens_after` is a set of string leaf values immeditely after which parens
2696     should be put.
2697
2698     Standardizes on visible parentheses for single-element tuples, and keeps
2699     existing visible parentheses for other tuples and generator expressions.
2700     """
2701     for pc in list_comments(node.prefix, is_endmarker=False):
2702         if pc.value in FMT_OFF:
2703             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2704             return
2705
2706     check_lpar = False
2707     for index, child in enumerate(list(node.children)):
2708         if check_lpar:
2709             if child.type == syms.atom:
2710                 if maybe_make_parens_invisible_in_atom(child):
2711                     lpar = Leaf(token.LPAR, "")
2712                     rpar = Leaf(token.RPAR, "")
2713                     index = child.remove() or 0
2714                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2715             elif is_one_tuple(child):
2716                 # wrap child in visible parentheses
2717                 lpar = Leaf(token.LPAR, "(")
2718                 rpar = Leaf(token.RPAR, ")")
2719                 child.remove()
2720                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2721             elif node.type == syms.import_from:
2722                 # "import from" nodes store parentheses directly as part of
2723                 # the statement
2724                 if child.type == token.LPAR:
2725                     # make parentheses invisible
2726                     child.value = ""  # type: ignore
2727                     node.children[-1].value = ""  # type: ignore
2728                 elif child.type != token.STAR:
2729                     # insert invisible parentheses
2730                     node.insert_child(index, Leaf(token.LPAR, ""))
2731                     node.append_child(Leaf(token.RPAR, ""))
2732                 break
2733
2734             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2735                 # wrap child in invisible parentheses
2736                 lpar = Leaf(token.LPAR, "")
2737                 rpar = Leaf(token.RPAR, "")
2738                 index = child.remove() or 0
2739                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2740
2741         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2742
2743
2744 def normalize_fmt_off(node: Node) -> None:
2745     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2746     try_again = True
2747     while try_again:
2748         try_again = convert_one_fmt_off_pair(node)
2749
2750
2751 def convert_one_fmt_off_pair(node: Node) -> bool:
2752     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2753
2754     Returns True if a pair was converted.
2755     """
2756     for leaf in node.leaves():
2757         previous_consumed = 0
2758         for comment in list_comments(leaf.prefix, is_endmarker=False):
2759             if comment.value in FMT_OFF:
2760                 # We only want standalone comments. If there's no previous leaf or
2761                 # the previous leaf is indentation, it's a standalone comment in
2762                 # disguise.
2763                 if comment.type != STANDALONE_COMMENT:
2764                     prev = preceding_leaf(leaf)
2765                     if prev and prev.type not in WHITESPACE:
2766                         continue
2767
2768                 ignored_nodes = list(generate_ignored_nodes(leaf))
2769                 if not ignored_nodes:
2770                     continue
2771
2772                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2773                 parent = first.parent
2774                 prefix = first.prefix
2775                 first.prefix = prefix[comment.consumed :]
2776                 hidden_value = (
2777                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2778                 )
2779                 if hidden_value.endswith("\n"):
2780                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2781                     # leaf (possibly followed by a DEDENT).
2782                     hidden_value = hidden_value[:-1]
2783                 first_idx = None
2784                 for ignored in ignored_nodes:
2785                     index = ignored.remove()
2786                     if first_idx is None:
2787                         first_idx = index
2788                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2789                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2790                 parent.insert_child(
2791                     first_idx,
2792                     Leaf(
2793                         STANDALONE_COMMENT,
2794                         hidden_value,
2795                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2796                     ),
2797                 )
2798                 return True
2799
2800             previous_consumed = comment.consumed
2801
2802     return False
2803
2804
2805 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2806     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2807
2808     Stops at the end of the block.
2809     """
2810     container: Optional[LN] = container_of(leaf)
2811     while container is not None and container.type != token.ENDMARKER:
2812         for comment in list_comments(container.prefix, is_endmarker=False):
2813             if comment.value in FMT_ON:
2814                 return
2815
2816         yield container
2817
2818         container = container.next_sibling
2819
2820
2821 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2822     """If it's safe, make the parens in the atom `node` invisible, recursively.
2823
2824     Returns whether the node should itself be wrapped in invisible parentheses.
2825
2826     """
2827     if (
2828         node.type != syms.atom
2829         or is_empty_tuple(node)
2830         or is_one_tuple(node)
2831         or is_yield(node)
2832         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2833     ):
2834         return False
2835
2836     first = node.children[0]
2837     last = node.children[-1]
2838     if first.type == token.LPAR and last.type == token.RPAR:
2839         # make parentheses invisible
2840         first.value = ""  # type: ignore
2841         last.value = ""  # type: ignore
2842         if len(node.children) > 1:
2843             maybe_make_parens_invisible_in_atom(node.children[1])
2844         return False
2845
2846     return True
2847
2848
2849 def is_empty_tuple(node: LN) -> bool:
2850     """Return True if `node` holds an empty tuple."""
2851     return (
2852         node.type == syms.atom
2853         and len(node.children) == 2
2854         and node.children[0].type == token.LPAR
2855         and node.children[1].type == token.RPAR
2856     )
2857
2858
2859 def is_one_tuple(node: LN) -> bool:
2860     """Return True if `node` holds a tuple with one element, with or without parens."""
2861     if node.type == syms.atom:
2862         if len(node.children) != 3:
2863             return False
2864
2865         lpar, gexp, rpar = node.children
2866         if not (
2867             lpar.type == token.LPAR
2868             and gexp.type == syms.testlist_gexp
2869             and rpar.type == token.RPAR
2870         ):
2871             return False
2872
2873         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2874
2875     return (
2876         node.type in IMPLICIT_TUPLE
2877         and len(node.children) == 2
2878         and node.children[1].type == token.COMMA
2879     )
2880
2881
2882 def is_yield(node: LN) -> bool:
2883     """Return True if `node` holds a `yield` or `yield from` expression."""
2884     if node.type == syms.yield_expr:
2885         return True
2886
2887     if node.type == token.NAME and node.value == "yield":  # type: ignore
2888         return True
2889
2890     if node.type != syms.atom:
2891         return False
2892
2893     if len(node.children) != 3:
2894         return False
2895
2896     lpar, expr, rpar = node.children
2897     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2898         return is_yield(expr)
2899
2900     return False
2901
2902
2903 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2904     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2905
2906     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2907     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2908     extended iterable unpacking (PEP 3132) and additional unpacking
2909     generalizations (PEP 448).
2910     """
2911     if leaf.type not in STARS or not leaf.parent:
2912         return False
2913
2914     p = leaf.parent
2915     if p.type == syms.star_expr:
2916         # Star expressions are also used as assignment targets in extended
2917         # iterable unpacking (PEP 3132).  See what its parent is instead.
2918         if not p.parent:
2919             return False
2920
2921         p = p.parent
2922
2923     return p.type in within
2924
2925
2926 def is_multiline_string(leaf: Leaf) -> bool:
2927     """Return True if `leaf` is a multiline string that actually spans many lines."""
2928     value = leaf.value.lstrip("furbFURB")
2929     return value[:3] in {'"""', "'''"} and "\n" in value
2930
2931
2932 def is_stub_suite(node: Node) -> bool:
2933     """Return True if `node` is a suite with a stub body."""
2934     if (
2935         len(node.children) != 4
2936         or node.children[0].type != token.NEWLINE
2937         or node.children[1].type != token.INDENT
2938         or node.children[3].type != token.DEDENT
2939     ):
2940         return False
2941
2942     return is_stub_body(node.children[2])
2943
2944
2945 def is_stub_body(node: LN) -> bool:
2946     """Return True if `node` is a simple statement containing an ellipsis."""
2947     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2948         return False
2949
2950     if len(node.children) != 2:
2951         return False
2952
2953     child = node.children[0]
2954     return (
2955         child.type == syms.atom
2956         and len(child.children) == 3
2957         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2958     )
2959
2960
2961 def max_delimiter_priority_in_atom(node: LN) -> int:
2962     """Return maximum delimiter priority inside `node`.
2963
2964     This is specific to atoms with contents contained in a pair of parentheses.
2965     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2966     """
2967     if node.type != syms.atom:
2968         return 0
2969
2970     first = node.children[0]
2971     last = node.children[-1]
2972     if not (first.type == token.LPAR and last.type == token.RPAR):
2973         return 0
2974
2975     bt = BracketTracker()
2976     for c in node.children[1:-1]:
2977         if isinstance(c, Leaf):
2978             bt.mark(c)
2979         else:
2980             for leaf in c.leaves():
2981                 bt.mark(leaf)
2982     try:
2983         return bt.max_delimiter_priority()
2984
2985     except ValueError:
2986         return 0
2987
2988
2989 def ensure_visible(leaf: Leaf) -> None:
2990     """Make sure parentheses are visible.
2991
2992     They could be invisible as part of some statements (see
2993     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2994     """
2995     if leaf.type == token.LPAR:
2996         leaf.value = "("
2997     elif leaf.type == token.RPAR:
2998         leaf.value = ")"
2999
3000
3001 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3002     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3003
3004     if not (
3005         opening_bracket.parent
3006         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3007         and opening_bracket.value in "[{("
3008     ):
3009         return False
3010
3011     try:
3012         last_leaf = line.leaves[-1]
3013         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3014         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3015     except (IndexError, ValueError):
3016         return False
3017
3018     return max_priority == COMMA_PRIORITY
3019
3020
3021 def get_features_used(node: Node) -> Set[Feature]:
3022     """Return a set of (relatively) new Python features used in this file.
3023
3024     Currently looking for:
3025     - f-strings;
3026     - underscores in numeric literals; and
3027     - trailing commas after * or ** in function signatures and calls.
3028     """
3029     features: Set[Feature] = set()
3030     for n in node.pre_order():
3031         if n.type == token.STRING:
3032             value_head = n.value[:2]  # type: ignore
3033             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3034                 features.add(Feature.F_STRINGS)
3035
3036         elif n.type == token.NUMBER:
3037             if "_" in n.value:  # type: ignore
3038                 features.add(Feature.NUMERIC_UNDERSCORES)
3039
3040         elif (
3041             n.type in {syms.typedargslist, syms.arglist}
3042             and n.children
3043             and n.children[-1].type == token.COMMA
3044         ):
3045             for ch in n.children:
3046                 if ch.type in STARS:
3047                     features.add(Feature.TRAILING_COMMA)
3048
3049                 if ch.type == syms.argument:
3050                     for argch in ch.children:
3051                         if argch.type in STARS:
3052                             features.add(Feature.TRAILING_COMMA)
3053
3054     return features
3055
3056
3057 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3058     """Detect the version to target based on the nodes used."""
3059     features = get_features_used(node)
3060     return {
3061         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3062     }
3063
3064
3065 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3066     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3067
3068     Brackets can be omitted if the entire trailer up to and including
3069     a preceding closing bracket fits in one line.
3070
3071     Yielded sets are cumulative (contain results of previous yields, too).  First
3072     set is empty.
3073     """
3074
3075     omit: Set[LeafID] = set()
3076     yield omit
3077
3078     length = 4 * line.depth
3079     opening_bracket = None
3080     closing_bracket = None
3081     inner_brackets: Set[LeafID] = set()
3082     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3083         length += leaf_length
3084         if length > line_length:
3085             break
3086
3087         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3088         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3089             break
3090
3091         if opening_bracket:
3092             if leaf is opening_bracket:
3093                 opening_bracket = None
3094             elif leaf.type in CLOSING_BRACKETS:
3095                 inner_brackets.add(id(leaf))
3096         elif leaf.type in CLOSING_BRACKETS:
3097             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3098                 # Empty brackets would fail a split so treat them as "inner"
3099                 # brackets (e.g. only add them to the `omit` set if another
3100                 # pair of brackets was good enough.
3101                 inner_brackets.add(id(leaf))
3102                 continue
3103
3104             if closing_bracket:
3105                 omit.add(id(closing_bracket))
3106                 omit.update(inner_brackets)
3107                 inner_brackets.clear()
3108                 yield omit
3109
3110             if leaf.value:
3111                 opening_bracket = leaf.opening_bracket
3112                 closing_bracket = leaf
3113
3114
3115 def get_future_imports(node: Node) -> Set[str]:
3116     """Return a set of __future__ imports in the file."""
3117     imports: Set[str] = set()
3118
3119     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3120         for child in children:
3121             if isinstance(child, Leaf):
3122                 if child.type == token.NAME:
3123                     yield child.value
3124             elif child.type == syms.import_as_name:
3125                 orig_name = child.children[0]
3126                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3127                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3128                 yield orig_name.value
3129             elif child.type == syms.import_as_names:
3130                 yield from get_imports_from_children(child.children)
3131             else:
3132                 assert False, "Invalid syntax parsing imports"
3133
3134     for child in node.children:
3135         if child.type != syms.simple_stmt:
3136             break
3137         first_child = child.children[0]
3138         if isinstance(first_child, Leaf):
3139             # Continue looking if we see a docstring; otherwise stop.
3140             if (
3141                 len(child.children) == 2
3142                 and first_child.type == token.STRING
3143                 and child.children[1].type == token.NEWLINE
3144             ):
3145                 continue
3146             else:
3147                 break
3148         elif first_child.type == syms.import_from:
3149             module_name = first_child.children[1]
3150             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3151                 break
3152             imports |= set(get_imports_from_children(first_child.children[3:]))
3153         else:
3154             break
3155     return imports
3156
3157
3158 def gen_python_files_in_dir(
3159     path: Path,
3160     root: Path,
3161     include: Pattern[str],
3162     exclude: Pattern[str],
3163     report: "Report",
3164 ) -> Iterator[Path]:
3165     """Generate all files under `path` whose paths are not excluded by the
3166     `exclude` regex, but are included by the `include` regex.
3167
3168     Symbolic links pointing outside of the `root` directory are ignored.
3169
3170     `report` is where output about exclusions goes.
3171     """
3172     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3173     for child in path.iterdir():
3174         try:
3175             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3176         except ValueError:
3177             if child.is_symlink():
3178                 report.path_ignored(
3179                     child, f"is a symbolic link that points outside {root}"
3180                 )
3181                 continue
3182
3183             raise
3184
3185         if child.is_dir():
3186             normalized_path += "/"
3187         exclude_match = exclude.search(normalized_path)
3188         if exclude_match and exclude_match.group(0):
3189             report.path_ignored(child, f"matches the --exclude regular expression")
3190             continue
3191
3192         if child.is_dir():
3193             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3194
3195         elif child.is_file():
3196             include_match = include.search(normalized_path)
3197             if include_match:
3198                 yield child
3199
3200
3201 @lru_cache()
3202 def find_project_root(srcs: Iterable[str]) -> Path:
3203     """Return a directory containing .git, .hg, or pyproject.toml.
3204
3205     That directory can be one of the directories passed in `srcs` or their
3206     common parent.
3207
3208     If no directory in the tree contains a marker that would specify it's the
3209     project root, the root of the file system is returned.
3210     """
3211     if not srcs:
3212         return Path("/").resolve()
3213
3214     common_base = min(Path(src).resolve() for src in srcs)
3215     if common_base.is_dir():
3216         # Append a fake file so `parents` below returns `common_base_dir`, too.
3217         common_base /= "fake-file"
3218     for directory in common_base.parents:
3219         if (directory / ".git").is_dir():
3220             return directory
3221
3222         if (directory / ".hg").is_dir():
3223             return directory
3224
3225         if (directory / "pyproject.toml").is_file():
3226             return directory
3227
3228     return directory
3229
3230
3231 @dataclass
3232 class Report:
3233     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3234
3235     check: bool = False
3236     quiet: bool = False
3237     verbose: bool = False
3238     change_count: int = 0
3239     same_count: int = 0
3240     failure_count: int = 0
3241
3242     def done(self, src: Path, changed: Changed) -> None:
3243         """Increment the counter for successful reformatting. Write out a message."""
3244         if changed is Changed.YES:
3245             reformatted = "would reformat" if self.check else "reformatted"
3246             if self.verbose or not self.quiet:
3247                 out(f"{reformatted} {src}")
3248             self.change_count += 1
3249         else:
3250             if self.verbose:
3251                 if changed is Changed.NO:
3252                     msg = f"{src} already well formatted, good job."
3253                 else:
3254                     msg = f"{src} wasn't modified on disk since last run."
3255                 out(msg, bold=False)
3256             self.same_count += 1
3257
3258     def failed(self, src: Path, message: str) -> None:
3259         """Increment the counter for failed reformatting. Write out a message."""
3260         err(f"error: cannot format {src}: {message}")
3261         self.failure_count += 1
3262
3263     def path_ignored(self, path: Path, message: str) -> None:
3264         if self.verbose:
3265             out(f"{path} ignored: {message}", bold=False)
3266
3267     @property
3268     def return_code(self) -> int:
3269         """Return the exit code that the app should use.
3270
3271         This considers the current state of changed files and failures:
3272         - if there were any failures, return 123;
3273         - if any files were changed and --check is being used, return 1;
3274         - otherwise return 0.
3275         """
3276         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3277         # 126 we have special return codes reserved by the shell.
3278         if self.failure_count:
3279             return 123
3280
3281         elif self.change_count and self.check:
3282             return 1
3283
3284         return 0
3285
3286     def __str__(self) -> str:
3287         """Render a color report of the current state.
3288
3289         Use `click.unstyle` to remove colors.
3290         """
3291         if self.check:
3292             reformatted = "would be reformatted"
3293             unchanged = "would be left unchanged"
3294             failed = "would fail to reformat"
3295         else:
3296             reformatted = "reformatted"
3297             unchanged = "left unchanged"
3298             failed = "failed to reformat"
3299         report = []
3300         if self.change_count:
3301             s = "s" if self.change_count > 1 else ""
3302             report.append(
3303                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3304             )
3305         if self.same_count:
3306             s = "s" if self.same_count > 1 else ""
3307             report.append(f"{self.same_count} file{s} {unchanged}")
3308         if self.failure_count:
3309             s = "s" if self.failure_count > 1 else ""
3310             report.append(
3311                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3312             )
3313         return ", ".join(report) + "."
3314
3315
3316 def assert_equivalent(src: str, dst: str) -> None:
3317     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3318
3319     import ast
3320     import traceback
3321
3322     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3323         """Simple visitor generating strings to compare ASTs by content."""
3324         yield f"{'  ' * depth}{node.__class__.__name__}("
3325
3326         for field in sorted(node._fields):
3327             try:
3328                 value = getattr(node, field)
3329             except AttributeError:
3330                 continue
3331
3332             yield f"{'  ' * (depth+1)}{field}="
3333
3334             if isinstance(value, list):
3335                 for item in value:
3336                     # Ignore nested tuples within del statements, because we may insert
3337                     # parentheses and they change the AST.
3338                     if (
3339                         field == "targets"
3340                         and isinstance(node, ast.Delete)
3341                         and isinstance(item, ast.Tuple)
3342                     ):
3343                         for item in item.elts:
3344                             yield from _v(item, depth + 2)
3345                     elif isinstance(item, ast.AST):
3346                         yield from _v(item, depth + 2)
3347
3348             elif isinstance(value, ast.AST):
3349                 yield from _v(value, depth + 2)
3350
3351             else:
3352                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3353
3354         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3355
3356     try:
3357         src_ast = ast.parse(src)
3358     except Exception as exc:
3359         major, minor = sys.version_info[:2]
3360         raise AssertionError(
3361             f"cannot use --safe with this file; failed to parse source file "
3362             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3363             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3364         )
3365
3366     try:
3367         dst_ast = ast.parse(dst)
3368     except Exception as exc:
3369         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3370         raise AssertionError(
3371             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3372             f"Please report a bug on https://github.com/ambv/black/issues.  "
3373             f"This invalid output might be helpful: {log}"
3374         ) from None
3375
3376     src_ast_str = "\n".join(_v(src_ast))
3377     dst_ast_str = "\n".join(_v(dst_ast))
3378     if src_ast_str != dst_ast_str:
3379         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3380         raise AssertionError(
3381             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3382             f"the source.  "
3383             f"Please report a bug on https://github.com/ambv/black/issues.  "
3384             f"This diff might be helpful: {log}"
3385         ) from None
3386
3387
3388 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3389     """Raise AssertionError if `dst` reformats differently the second time."""
3390     newdst = format_str(dst, mode=mode)
3391     if dst != newdst:
3392         log = dump_to_file(
3393             diff(src, dst, "source", "first pass"),
3394             diff(dst, newdst, "first pass", "second pass"),
3395         )
3396         raise AssertionError(
3397             f"INTERNAL ERROR: Black produced different code on the second pass "
3398             f"of the formatter.  "
3399             f"Please report a bug on https://github.com/ambv/black/issues.  "
3400             f"This diff might be helpful: {log}"
3401         ) from None
3402
3403
3404 def dump_to_file(*output: str) -> str:
3405     """Dump `output` to a temporary file. Return path to the file."""
3406     import tempfile
3407
3408     with tempfile.NamedTemporaryFile(
3409         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3410     ) as f:
3411         for lines in output:
3412             f.write(lines)
3413             if lines and lines[-1] != "\n":
3414                 f.write("\n")
3415     return f.name
3416
3417
3418 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3419     """Return a unified diff string between strings `a` and `b`."""
3420     import difflib
3421
3422     a_lines = [line + "\n" for line in a.split("\n")]
3423     b_lines = [line + "\n" for line in b.split("\n")]
3424     return "".join(
3425         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3426     )
3427
3428
3429 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3430     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3431     err("Aborted!")
3432     for task in tasks:
3433         task.cancel()
3434
3435
3436 def shutdown(loop: BaseEventLoop) -> None:
3437     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3438     try:
3439         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3440         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3441         if not to_cancel:
3442             return
3443
3444         for task in to_cancel:
3445             task.cancel()
3446         loop.run_until_complete(
3447             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3448         )
3449     finally:
3450         # `concurrent.futures.Future` objects cannot be cancelled once they
3451         # are already running. There might be some when the `shutdown()` happened.
3452         # Silence their logger's spew about the event loop being closed.
3453         cf_logger = logging.getLogger("concurrent.futures")
3454         cf_logger.setLevel(logging.CRITICAL)
3455         loop.close()
3456
3457
3458 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3459     """Replace `regex` with `replacement` twice on `original`.
3460
3461     This is used by string normalization to perform replaces on
3462     overlapping matches.
3463     """
3464     return regex.sub(replacement, regex.sub(replacement, original))
3465
3466
3467 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3468     """Compile a regular expression string in `regex`.
3469
3470     If it contains newlines, use verbose mode.
3471     """
3472     if "\n" in regex:
3473         regex = "(?x)" + regex
3474     return re.compile(regex)
3475
3476
3477 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3478     """Like `reversed(enumerate(sequence))` if that were possible."""
3479     index = len(sequence) - 1
3480     for element in reversed(sequence):
3481         yield (index, element)
3482         index -= 1
3483
3484
3485 def enumerate_with_length(
3486     line: Line, reversed: bool = False
3487 ) -> Iterator[Tuple[Index, Leaf, int]]:
3488     """Return an enumeration of leaves with their length.
3489
3490     Stops prematurely on multiline strings and standalone comments.
3491     """
3492     op = cast(
3493         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3494         enumerate_reversed if reversed else enumerate,
3495     )
3496     for index, leaf in op(line.leaves):
3497         length = len(leaf.prefix) + len(leaf.value)
3498         if "\n" in leaf.value:
3499             return  # Multiline strings, we can't continue.
3500
3501         comment: Optional[Leaf]
3502         for comment in line.comments_after(leaf):
3503             length += len(comment.value)
3504
3505         yield index, leaf, length
3506
3507
3508 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3509     """Return True if `line` is no longer than `line_length`.
3510
3511     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3512     """
3513     if not line_str:
3514         line_str = str(line).strip("\n")
3515     return (
3516         len(line_str) <= line_length
3517         and "\n" not in line_str  # multiline strings
3518         and not line.contains_standalone_comments()
3519     )
3520
3521
3522 def can_be_split(line: Line) -> bool:
3523     """Return False if the line cannot be split *for sure*.
3524
3525     This is not an exhaustive search but a cheap heuristic that we can use to
3526     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3527     in unnecessary parentheses).
3528     """
3529     leaves = line.leaves
3530     if len(leaves) < 2:
3531         return False
3532
3533     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3534         call_count = 0
3535         dot_count = 0
3536         next = leaves[-1]
3537         for leaf in leaves[-2::-1]:
3538             if leaf.type in OPENING_BRACKETS:
3539                 if next.type not in CLOSING_BRACKETS:
3540                     return False
3541
3542                 call_count += 1
3543             elif leaf.type == token.DOT:
3544                 dot_count += 1
3545             elif leaf.type == token.NAME:
3546                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3547                     return False
3548
3549             elif leaf.type not in CLOSING_BRACKETS:
3550                 return False
3551
3552             if dot_count > 1 and call_count > 1:
3553                 return False
3554
3555     return True
3556
3557
3558 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3559     """Does `line` have a shape safe to reformat without optional parens around it?
3560
3561     Returns True for only a subset of potentially nice looking formattings but
3562     the point is to not return false positives that end up producing lines that
3563     are too long.
3564     """
3565     bt = line.bracket_tracker
3566     if not bt.delimiters:
3567         # Without delimiters the optional parentheses are useless.
3568         return True
3569
3570     max_priority = bt.max_delimiter_priority()
3571     if bt.delimiter_count_with_priority(max_priority) > 1:
3572         # With more than one delimiter of a kind the optional parentheses read better.
3573         return False
3574
3575     if max_priority == DOT_PRIORITY:
3576         # A single stranded method call doesn't require optional parentheses.
3577         return True
3578
3579     assert len(line.leaves) >= 2, "Stranded delimiter"
3580
3581     first = line.leaves[0]
3582     second = line.leaves[1]
3583     penultimate = line.leaves[-2]
3584     last = line.leaves[-1]
3585
3586     # With a single delimiter, omit if the expression starts or ends with
3587     # a bracket.
3588     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3589         remainder = False
3590         length = 4 * line.depth
3591         for _index, leaf, leaf_length in enumerate_with_length(line):
3592             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3593                 remainder = True
3594             if remainder:
3595                 length += leaf_length
3596                 if length > line_length:
3597                     break
3598
3599                 if leaf.type in OPENING_BRACKETS:
3600                     # There are brackets we can further split on.
3601                     remainder = False
3602
3603         else:
3604             # checked the entire string and line length wasn't exceeded
3605             if len(line.leaves) == _index + 1:
3606                 return True
3607
3608         # Note: we are not returning False here because a line might have *both*
3609         # a leading opening bracket and a trailing closing bracket.  If the
3610         # opening bracket doesn't match our rule, maybe the closing will.
3611
3612     if (
3613         last.type == token.RPAR
3614         or last.type == token.RBRACE
3615         or (
3616             # don't use indexing for omitting optional parentheses;
3617             # it looks weird
3618             last.type == token.RSQB
3619             and last.parent
3620             and last.parent.type != syms.trailer
3621         )
3622     ):
3623         if penultimate.type in OPENING_BRACKETS:
3624             # Empty brackets don't help.
3625             return False
3626
3627         if is_multiline_string(first):
3628             # Additional wrapping of a multiline string in this situation is
3629             # unnecessary.
3630             return True
3631
3632         length = 4 * line.depth
3633         seen_other_brackets = False
3634         for _index, leaf, leaf_length in enumerate_with_length(line):
3635             length += leaf_length
3636             if leaf is last.opening_bracket:
3637                 if seen_other_brackets or length <= line_length:
3638                     return True
3639
3640             elif leaf.type in OPENING_BRACKETS:
3641                 # There are brackets we can further split on.
3642                 seen_other_brackets = True
3643
3644     return False
3645
3646
3647 def get_cache_file(mode: FileMode) -> Path:
3648     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3649
3650
3651 def read_cache(mode: FileMode) -> Cache:
3652     """Read the cache if it exists and is well formed.
3653
3654     If it is not well formed, the call to write_cache later should resolve the issue.
3655     """
3656     cache_file = get_cache_file(mode)
3657     if not cache_file.exists():
3658         return {}
3659
3660     with cache_file.open("rb") as fobj:
3661         try:
3662             cache: Cache = pickle.load(fobj)
3663         except pickle.UnpicklingError:
3664             return {}
3665
3666     return cache
3667
3668
3669 def get_cache_info(path: Path) -> CacheInfo:
3670     """Return the information used to check if a file is already formatted or not."""
3671     stat = path.stat()
3672     return stat.st_mtime, stat.st_size
3673
3674
3675 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3676     """Split an iterable of paths in `sources` into two sets.
3677
3678     The first contains paths of files that modified on disk or are not in the
3679     cache. The other contains paths to non-modified files.
3680     """
3681     todo, done = set(), set()
3682     for src in sources:
3683         src = src.resolve()
3684         if cache.get(src) != get_cache_info(src):
3685             todo.add(src)
3686         else:
3687             done.add(src)
3688     return todo, done
3689
3690
3691 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3692     """Update the cache file."""
3693     cache_file = get_cache_file(mode)
3694     try:
3695         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3696         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3697         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3698             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3699         os.replace(f.name, cache_file)
3700     except OSError:
3701         pass
3702
3703
3704 def patch_click() -> None:
3705     """Make Click not crash.
3706
3707     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3708     default which restricts paths that it can access during the lifetime of the
3709     application.  Click refuses to work in this scenario by raising a RuntimeError.
3710
3711     In case of Black the likelihood that non-ASCII characters are going to be used in
3712     file paths is minimal since it's Python source code.  Moreover, this crash was
3713     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3714     """
3715     try:
3716         from click import core
3717         from click import _unicodefun  # type: ignore
3718     except ModuleNotFoundError:
3719         return
3720
3721     for module in (core, _unicodefun):
3722         if hasattr(module, "_verify_python3_env"):
3723             module._verify_python3_env = lambda: None
3724
3725
3726 def patched_main() -> None:
3727     freeze_support()
3728     patch_click()
3729     main()
3730
3731
3732 if __name__ == "__main__":
3733     patched_main()