]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

2850ae1a19cb0038ce5acd4bbb86837db171cda0
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "18.9b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PYPY35 = 1
117     CPY27 = 2
118     CPY33 = 3
119     CPY34 = 4
120     CPY35 = 5
121     CPY36 = 6
122     CPY37 = 7
123     CPY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.CPY27
127
128
129 PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA = 4
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.CPY27: set(),
142     TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
143     TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
144     TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
145     TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
146     TargetVersion.CPY36: {
147         Feature.UNICODE_LITERALS,
148         Feature.F_STRINGS,
149         Feature.NUMERIC_UNDERSCORES,
150         Feature.TRAILING_COMMA,
151     },
152     TargetVersion.CPY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA,
157     },
158     TargetVersion.CPY38: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA,
163     },
164 }
165
166
167 @dataclass
168 class FileMode:
169     target_versions: Set[TargetVersion] = Factory(set)
170     line_length: int = DEFAULT_LINE_LENGTH
171     string_normalization: bool = True
172     is_pyi: bool = False
173
174     def get_cache_key(self) -> str:
175         if self.target_versions:
176             version_str = ",".join(
177                 str(version.value)
178                 for version in sorted(self.target_versions, key=lambda v: v.value)
179             )
180         else:
181             version_str = "-"
182         parts = [
183             version_str,
184             str(self.line_length),
185             str(int(self.string_normalization)),
186             str(int(self.is_pyi)),
187         ]
188         return ".".join(parts)
189
190
191 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
192     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
193
194
195 def read_pyproject_toml(
196     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
197 ) -> Optional[str]:
198     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
199
200     Returns the path to a successfully found and read configuration file, None
201     otherwise.
202     """
203     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
204     if not value:
205         root = find_project_root(ctx.params.get("src", ()))
206         path = root / "pyproject.toml"
207         if path.is_file():
208             value = str(path)
209         else:
210             return None
211
212     try:
213         pyproject_toml = toml.load(value)
214         config = pyproject_toml.get("tool", {}).get("black", {})
215     except (toml.TomlDecodeError, OSError) as e:
216         raise click.FileError(
217             filename=value, hint=f"Error reading configuration file: {e}"
218         )
219
220     if not config:
221         return None
222
223     if ctx.default_map is None:
224         ctx.default_map = {}
225     ctx.default_map.update(  # type: ignore  # bad types in .pyi
226         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
227     )
228     return value
229
230
231 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
232 @click.option(
233     "-l",
234     "--line-length",
235     type=int,
236     default=DEFAULT_LINE_LENGTH,
237     help="How many characters per line to allow.",
238     show_default=True,
239 )
240 @click.option(
241     "-t",
242     "--target-version",
243     type=click.Choice([v.name.lower() for v in TargetVersion]),
244     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
245     multiple=True,
246     help=(
247         "Python versions that should be supported by Black's output. [default: "
248         "per-file auto-detection]"
249     ),
250 )
251 @click.option(
252     "--pyi",
253     is_flag=True,
254     help=(
255         "Format all input files like typing stubs regardless of file extension "
256         "(useful when piping source on standard input)."
257     ),
258 )
259 @click.option(
260     "-S",
261     "--skip-string-normalization",
262     is_flag=True,
263     help="Don't normalize string quotes or prefixes.",
264 )
265 @click.option(
266     "--check",
267     is_flag=True,
268     help=(
269         "Don't write the files back, just return the status.  Return code 0 "
270         "means nothing would change.  Return code 1 means some files would be "
271         "reformatted.  Return code 123 means there was an internal error."
272     ),
273 )
274 @click.option(
275     "--diff",
276     is_flag=True,
277     help="Don't write the files back, just output a diff for each file on stdout.",
278 )
279 @click.option(
280     "--fast/--safe",
281     is_flag=True,
282     help="If --fast given, skip temporary sanity checks. [default: --safe]",
283 )
284 @click.option(
285     "--include",
286     type=str,
287     default=DEFAULT_INCLUDES,
288     help=(
289         "A regular expression that matches files and directories that should be "
290         "included on recursive searches.  An empty value means all files are "
291         "included regardless of the name.  Use forward slashes for directories on "
292         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
293         "later."
294     ),
295     show_default=True,
296 )
297 @click.option(
298     "--exclude",
299     type=str,
300     default=DEFAULT_EXCLUDES,
301     help=(
302         "A regular expression that matches files and directories that should be "
303         "excluded on recursive searches.  An empty value means no paths are excluded. "
304         "Use forward slashes for directories on all platforms (Windows, too).  "
305         "Exclusions are calculated first, inclusions later."
306     ),
307     show_default=True,
308 )
309 @click.option(
310     "-q",
311     "--quiet",
312     is_flag=True,
313     help=(
314         "Don't emit non-error messages to stderr. Errors are still emitted, "
315         "silence those with 2>/dev/null."
316     ),
317 )
318 @click.option(
319     "-v",
320     "--verbose",
321     is_flag=True,
322     help=(
323         "Also emit messages to stderr about files that were not changed or were "
324         "ignored due to --exclude=."
325     ),
326 )
327 @click.version_option(version=__version__)
328 @click.argument(
329     "src",
330     nargs=-1,
331     type=click.Path(
332         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
333     ),
334     is_eager=True,
335 )
336 @click.option(
337     "--config",
338     type=click.Path(
339         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
340     ),
341     is_eager=True,
342     callback=read_pyproject_toml,
343     help="Read configuration from PATH.",
344 )
345 @click.pass_context
346 def main(
347     ctx: click.Context,
348     line_length: int,
349     target_version: List[TargetVersion],
350     check: bool,
351     diff: bool,
352     fast: bool,
353     pyi: bool,
354     skip_string_normalization: bool,
355     quiet: bool,
356     verbose: bool,
357     include: str,
358     exclude: str,
359     src: Tuple[str],
360     config: Optional[str],
361 ) -> None:
362     """The uncompromising code formatter."""
363     write_back = WriteBack.from_configuration(check=check, diff=diff)
364     if target_version:
365         versions = set(target_version)
366     else:
367         # We'll autodetect later.
368         versions = set()
369     mode = FileMode(
370         target_versions=versions,
371         line_length=line_length,
372         is_pyi=pyi,
373         string_normalization=not skip_string_normalization,
374     )
375     if config and verbose:
376         out(f"Using configuration from {config}.", bold=False, fg="blue")
377     try:
378         include_regex = re_compile_maybe_verbose(include)
379     except re.error:
380         err(f"Invalid regular expression for include given: {include!r}")
381         ctx.exit(2)
382     try:
383         exclude_regex = re_compile_maybe_verbose(exclude)
384     except re.error:
385         err(f"Invalid regular expression for exclude given: {exclude!r}")
386         ctx.exit(2)
387     report = Report(check=check, quiet=quiet, verbose=verbose)
388     root = find_project_root(src)
389     sources: Set[Path] = set()
390     for s in src:
391         p = Path(s)
392         if p.is_dir():
393             sources.update(
394                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
395             )
396         elif p.is_file() or s == "-":
397             # if a file was explicitly given, we don't care about its extension
398             sources.add(p)
399         else:
400             err(f"invalid path: {s}")
401     if len(sources) == 0:
402         if verbose or not quiet:
403             out("No paths given. Nothing to do 😴")
404         ctx.exit(0)
405
406     if len(sources) == 1:
407         reformat_one(
408             src=sources.pop(),
409             fast=fast,
410             write_back=write_back,
411             mode=mode,
412             report=report,
413         )
414     else:
415         loop = asyncio.get_event_loop()
416         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
417         try:
418             loop.run_until_complete(
419                 schedule_formatting(
420                     sources=sources,
421                     fast=fast,
422                     write_back=write_back,
423                     mode=mode,
424                     report=report,
425                     loop=loop,
426                     executor=executor,
427                 )
428             )
429         finally:
430             shutdown(loop)
431     if verbose or not quiet:
432         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
433         out(f"All done! {bang}")
434         click.secho(str(report), err=True)
435     ctx.exit(report.return_code)
436
437
438 def reformat_one(
439     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
440 ) -> None:
441     """Reformat a single file under `src` without spawning child processes.
442
443     If `quiet` is True, non-error messages are not output. `line_length`,
444     `write_back`, `fast` and `pyi` options are passed to
445     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
446     """
447     try:
448         changed = Changed.NO
449         if not src.is_file() and str(src) == "-":
450             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
451                 changed = Changed.YES
452         else:
453             cache: Cache = {}
454             if write_back != WriteBack.DIFF:
455                 cache = read_cache(mode)
456                 res_src = src.resolve()
457                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
458                     changed = Changed.CACHED
459             if changed is not Changed.CACHED and format_file_in_place(
460                 src, fast=fast, write_back=write_back, mode=mode
461             ):
462                 changed = Changed.YES
463             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
464                 write_back is WriteBack.CHECK and changed is Changed.NO
465             ):
466                 write_cache(cache, [src], mode)
467         report.done(src, changed)
468     except Exception as exc:
469         report.failed(src, str(exc))
470
471
472 async def schedule_formatting(
473     sources: Set[Path],
474     fast: bool,
475     write_back: WriteBack,
476     mode: FileMode,
477     report: "Report",
478     loop: BaseEventLoop,
479     executor: Executor,
480 ) -> None:
481     """Run formatting of `sources` in parallel using the provided `executor`.
482
483     (Use ProcessPoolExecutors for actual parallelism.)
484
485     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
486     :func:`format_file_in_place`.
487     """
488     cache: Cache = {}
489     if write_back != WriteBack.DIFF:
490         cache = read_cache(mode)
491         sources, cached = filter_cached(cache, sources)
492         for src in sorted(cached):
493             report.done(src, Changed.CACHED)
494     if not sources:
495         return
496
497     cancelled = []
498     sources_to_cache = []
499     lock = None
500     if write_back == WriteBack.DIFF:
501         # For diff output, we need locks to ensure we don't interleave output
502         # from different processes.
503         manager = Manager()
504         lock = manager.Lock()
505     tasks = {
506         loop.run_in_executor(
507             executor, format_file_in_place, src, fast, mode, write_back, lock
508         ): src
509         for src in sorted(sources)
510     }
511     pending: Iterable[asyncio.Task] = tasks.keys()
512     try:
513         loop.add_signal_handler(signal.SIGINT, cancel, pending)
514         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
515     except NotImplementedError:
516         # There are no good alternatives for these on Windows.
517         pass
518     while pending:
519         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
520         for task in done:
521             src = tasks.pop(task)
522             if task.cancelled():
523                 cancelled.append(task)
524             elif task.exception():
525                 report.failed(src, str(task.exception()))
526             else:
527                 changed = Changed.YES if task.result() else Changed.NO
528                 # If the file was written back or was successfully checked as
529                 # well-formatted, store this information in the cache.
530                 if write_back is WriteBack.YES or (
531                     write_back is WriteBack.CHECK and changed is Changed.NO
532                 ):
533                     sources_to_cache.append(src)
534                 report.done(src, changed)
535     if cancelled:
536         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
537     if sources_to_cache:
538         write_cache(cache, sources_to_cache, mode)
539
540
541 def format_file_in_place(
542     src: Path,
543     fast: bool,
544     mode: FileMode,
545     write_back: WriteBack = WriteBack.NO,
546     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
547 ) -> bool:
548     """Format file under `src` path. Return True if changed.
549
550     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
551     code to the file.
552     `line_length` and `fast` options are passed to :func:`format_file_contents`.
553     """
554     if src.suffix == ".pyi":
555         mode = evolve(mode, is_pyi=True)
556
557     then = datetime.utcfromtimestamp(src.stat().st_mtime)
558     with open(src, "rb") as buf:
559         src_contents, encoding, newline = decode_bytes(buf.read())
560     try:
561         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
562     except NothingChanged:
563         return False
564
565     if write_back == write_back.YES:
566         with open(src, "w", encoding=encoding, newline=newline) as f:
567             f.write(dst_contents)
568     elif write_back == write_back.DIFF:
569         now = datetime.utcnow()
570         src_name = f"{src}\t{then} +0000"
571         dst_name = f"{src}\t{now} +0000"
572         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
573         if lock:
574             lock.acquire()
575         try:
576             f = io.TextIOWrapper(
577                 sys.stdout.buffer,
578                 encoding=encoding,
579                 newline=newline,
580                 write_through=True,
581             )
582             f.write(diff_contents)
583             f.detach()
584         finally:
585             if lock:
586                 lock.release()
587     return True
588
589
590 def format_stdin_to_stdout(
591     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
592 ) -> bool:
593     """Format file on stdin. Return True if changed.
594
595     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
596     write a diff to stdout. The `mode` argument is passed to
597     :func:`format_file_contents`.
598     """
599     then = datetime.utcnow()
600     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
601     dst = src
602     try:
603         dst = format_file_contents(src, fast=fast, mode=mode)
604         return True
605
606     except NothingChanged:
607         return False
608
609     finally:
610         f = io.TextIOWrapper(
611             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
612         )
613         if write_back == WriteBack.YES:
614             f.write(dst)
615         elif write_back == WriteBack.DIFF:
616             now = datetime.utcnow()
617             src_name = f"STDIN\t{then} +0000"
618             dst_name = f"STDOUT\t{now} +0000"
619             f.write(diff(src, dst, src_name, dst_name))
620         f.detach()
621
622
623 def format_file_contents(
624     src_contents: str, *, fast: bool, mode: FileMode
625 ) -> FileContent:
626     """Reformat contents a file and return new contents.
627
628     If `fast` is False, additionally confirm that the reformatted code is
629     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
630     `line_length` is passed to :func:`format_str`.
631     """
632     if src_contents.strip() == "":
633         raise NothingChanged
634
635     dst_contents = format_str(src_contents, mode=mode)
636     if src_contents == dst_contents:
637         raise NothingChanged
638
639     if not fast:
640         assert_equivalent(src_contents, dst_contents)
641         assert_stable(src_contents, dst_contents, mode=mode)
642     return dst_contents
643
644
645 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
646     """Reformat a string and return new contents.
647
648     `line_length` determines how many characters per line are allowed.
649     """
650     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
651     dst_contents = ""
652     future_imports = get_future_imports(src_node)
653     if mode.target_versions:
654         versions = mode.target_versions
655     else:
656         versions = detect_target_versions(src_node)
657     normalize_fmt_off(src_node)
658     lines = LineGenerator(
659         remove_u_prefix="unicode_literals" in future_imports
660         or supports_feature(versions, Feature.UNICODE_LITERALS),
661         is_pyi=mode.is_pyi,
662         normalize_strings=mode.string_normalization,
663     )
664     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
665     empty_line = Line()
666     after = 0
667     for current_line in lines.visit(src_node):
668         for _ in range(after):
669             dst_contents += str(empty_line)
670         before, after = elt.maybe_empty_lines(current_line)
671         for _ in range(before):
672             dst_contents += str(empty_line)
673         for line in split_line(
674             current_line,
675             line_length=mode.line_length,
676             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
677         ):
678             dst_contents += str(line)
679     return dst_contents
680
681
682 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
683     """Return a tuple of (decoded_contents, encoding, newline).
684
685     `newline` is either CRLF or LF but `decoded_contents` is decoded with
686     universal newlines (i.e. only contains LF).
687     """
688     srcbuf = io.BytesIO(src)
689     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
690     if not lines:
691         return "", encoding, "\n"
692
693     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
694     srcbuf.seek(0)
695     with io.TextIOWrapper(srcbuf, encoding) as tiow:
696         return tiow.read(), encoding, newline
697
698
699 GRAMMARS = [
700     pygram.python_grammar_no_print_statement_no_exec_statement,
701     pygram.python_grammar_no_print_statement,
702     pygram.python_grammar,
703 ]
704
705
706 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
707     if not target_versions:
708         return GRAMMARS
709     elif all(not version.is_python2() for version in target_versions):
710         # Python 2-compatible code, so don't try Python 3 grammar.
711         return [
712             pygram.python_grammar_no_print_statement_no_exec_statement,
713             pygram.python_grammar_no_print_statement,
714         ]
715     else:
716         return [pygram.python_grammar]
717
718
719 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
720     """Given a string with source, return the lib2to3 Node."""
721     if src_txt[-1:] != "\n":
722         src_txt += "\n"
723
724     for grammar in get_grammars(set(target_versions)):
725         drv = driver.Driver(grammar, pytree.convert)
726         try:
727             result = drv.parse_string(src_txt, True)
728             break
729
730         except ParseError as pe:
731             lineno, column = pe.context[1]
732             lines = src_txt.splitlines()
733             try:
734                 faulty_line = lines[lineno - 1]
735             except IndexError:
736                 faulty_line = "<line number missing in source>"
737             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
738     else:
739         raise exc from None
740
741     if isinstance(result, Leaf):
742         result = Node(syms.file_input, [result])
743     return result
744
745
746 def lib2to3_unparse(node: Node) -> str:
747     """Given a lib2to3 node, return its string representation."""
748     code = str(node)
749     return code
750
751
752 T = TypeVar("T")
753
754
755 class Visitor(Generic[T]):
756     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
757
758     def visit(self, node: LN) -> Iterator[T]:
759         """Main method to visit `node` and its children.
760
761         It tries to find a `visit_*()` method for the given `node.type`, like
762         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
763         If no dedicated `visit_*()` method is found, chooses `visit_default()`
764         instead.
765
766         Then yields objects of type `T` from the selected visitor.
767         """
768         if node.type < 256:
769             name = token.tok_name[node.type]
770         else:
771             name = type_repr(node.type)
772         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
773
774     def visit_default(self, node: LN) -> Iterator[T]:
775         """Default `visit_*()` implementation. Recurses to children of `node`."""
776         if isinstance(node, Node):
777             for child in node.children:
778                 yield from self.visit(child)
779
780
781 @dataclass
782 class DebugVisitor(Visitor[T]):
783     tree_depth: int = 0
784
785     def visit_default(self, node: LN) -> Iterator[T]:
786         indent = " " * (2 * self.tree_depth)
787         if isinstance(node, Node):
788             _type = type_repr(node.type)
789             out(f"{indent}{_type}", fg="yellow")
790             self.tree_depth += 1
791             for child in node.children:
792                 yield from self.visit(child)
793
794             self.tree_depth -= 1
795             out(f"{indent}/{_type}", fg="yellow", bold=False)
796         else:
797             _type = token.tok_name.get(node.type, str(node.type))
798             out(f"{indent}{_type}", fg="blue", nl=False)
799             if node.prefix:
800                 # We don't have to handle prefixes for `Node` objects since
801                 # that delegates to the first child anyway.
802                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
803             out(f" {node.value!r}", fg="blue", bold=False)
804
805     @classmethod
806     def show(cls, code: Union[str, Leaf, Node]) -> None:
807         """Pretty-print the lib2to3 AST of a given string of `code`.
808
809         Convenience method for debugging.
810         """
811         v: DebugVisitor[None] = DebugVisitor()
812         if isinstance(code, str):
813             code = lib2to3_parse(code)
814         list(v.visit(code))
815
816
817 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
818 STATEMENT = {
819     syms.if_stmt,
820     syms.while_stmt,
821     syms.for_stmt,
822     syms.try_stmt,
823     syms.except_clause,
824     syms.with_stmt,
825     syms.funcdef,
826     syms.classdef,
827 }
828 STANDALONE_COMMENT = 153
829 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
830 LOGIC_OPERATORS = {"and", "or"}
831 COMPARATORS = {
832     token.LESS,
833     token.GREATER,
834     token.EQEQUAL,
835     token.NOTEQUAL,
836     token.LESSEQUAL,
837     token.GREATEREQUAL,
838 }
839 MATH_OPERATORS = {
840     token.VBAR,
841     token.CIRCUMFLEX,
842     token.AMPER,
843     token.LEFTSHIFT,
844     token.RIGHTSHIFT,
845     token.PLUS,
846     token.MINUS,
847     token.STAR,
848     token.SLASH,
849     token.DOUBLESLASH,
850     token.PERCENT,
851     token.AT,
852     token.TILDE,
853     token.DOUBLESTAR,
854 }
855 STARS = {token.STAR, token.DOUBLESTAR}
856 VARARGS_PARENTS = {
857     syms.arglist,
858     syms.argument,  # double star in arglist
859     syms.trailer,  # single argument to call
860     syms.typedargslist,
861     syms.varargslist,  # lambdas
862 }
863 UNPACKING_PARENTS = {
864     syms.atom,  # single element of a list or set literal
865     syms.dictsetmaker,
866     syms.listmaker,
867     syms.testlist_gexp,
868     syms.testlist_star_expr,
869 }
870 TEST_DESCENDANTS = {
871     syms.test,
872     syms.lambdef,
873     syms.or_test,
874     syms.and_test,
875     syms.not_test,
876     syms.comparison,
877     syms.star_expr,
878     syms.expr,
879     syms.xor_expr,
880     syms.and_expr,
881     syms.shift_expr,
882     syms.arith_expr,
883     syms.trailer,
884     syms.term,
885     syms.power,
886 }
887 ASSIGNMENTS = {
888     "=",
889     "+=",
890     "-=",
891     "*=",
892     "@=",
893     "/=",
894     "%=",
895     "&=",
896     "|=",
897     "^=",
898     "<<=",
899     ">>=",
900     "**=",
901     "//=",
902 }
903 COMPREHENSION_PRIORITY = 20
904 COMMA_PRIORITY = 18
905 TERNARY_PRIORITY = 16
906 LOGIC_PRIORITY = 14
907 STRING_PRIORITY = 12
908 COMPARATOR_PRIORITY = 10
909 MATH_PRIORITIES = {
910     token.VBAR: 9,
911     token.CIRCUMFLEX: 8,
912     token.AMPER: 7,
913     token.LEFTSHIFT: 6,
914     token.RIGHTSHIFT: 6,
915     token.PLUS: 5,
916     token.MINUS: 5,
917     token.STAR: 4,
918     token.SLASH: 4,
919     token.DOUBLESLASH: 4,
920     token.PERCENT: 4,
921     token.AT: 4,
922     token.TILDE: 3,
923     token.DOUBLESTAR: 2,
924 }
925 DOT_PRIORITY = 1
926
927
928 @dataclass
929 class BracketTracker:
930     """Keeps track of brackets on a line."""
931
932     depth: int = 0
933     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
934     delimiters: Dict[LeafID, Priority] = Factory(dict)
935     previous: Optional[Leaf] = None
936     _for_loop_depths: List[int] = Factory(list)
937     _lambda_argument_depths: List[int] = Factory(list)
938
939     def mark(self, leaf: Leaf) -> None:
940         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
941
942         All leaves receive an int `bracket_depth` field that stores how deep
943         within brackets a given leaf is. 0 means there are no enclosing brackets
944         that started on this line.
945
946         If a leaf is itself a closing bracket, it receives an `opening_bracket`
947         field that it forms a pair with. This is a one-directional link to
948         avoid reference cycles.
949
950         If a leaf is a delimiter (a token on which Black can split the line if
951         needed) and it's on depth 0, its `id()` is stored in the tracker's
952         `delimiters` field.
953         """
954         if leaf.type == token.COMMENT:
955             return
956
957         self.maybe_decrement_after_for_loop_variable(leaf)
958         self.maybe_decrement_after_lambda_arguments(leaf)
959         if leaf.type in CLOSING_BRACKETS:
960             self.depth -= 1
961             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
962             leaf.opening_bracket = opening_bracket
963         leaf.bracket_depth = self.depth
964         if self.depth == 0:
965             delim = is_split_before_delimiter(leaf, self.previous)
966             if delim and self.previous is not None:
967                 self.delimiters[id(self.previous)] = delim
968             else:
969                 delim = is_split_after_delimiter(leaf, self.previous)
970                 if delim:
971                     self.delimiters[id(leaf)] = delim
972         if leaf.type in OPENING_BRACKETS:
973             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
974             self.depth += 1
975         self.previous = leaf
976         self.maybe_increment_lambda_arguments(leaf)
977         self.maybe_increment_for_loop_variable(leaf)
978
979     def any_open_brackets(self) -> bool:
980         """Return True if there is an yet unmatched open bracket on the line."""
981         return bool(self.bracket_match)
982
983     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
984         """Return the highest priority of a delimiter found on the line.
985
986         Values are consistent with what `is_split_*_delimiter()` return.
987         Raises ValueError on no delimiters.
988         """
989         return max(v for k, v in self.delimiters.items() if k not in exclude)
990
991     def delimiter_count_with_priority(self, priority: int = 0) -> int:
992         """Return the number of delimiters with the given `priority`.
993
994         If no `priority` is passed, defaults to max priority on the line.
995         """
996         if not self.delimiters:
997             return 0
998
999         priority = priority or self.max_delimiter_priority()
1000         return sum(1 for p in self.delimiters.values() if p == priority)
1001
1002     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1003         """In a for loop, or comprehension, the variables are often unpacks.
1004
1005         To avoid splitting on the comma in this situation, increase the depth of
1006         tokens between `for` and `in`.
1007         """
1008         if leaf.type == token.NAME and leaf.value == "for":
1009             self.depth += 1
1010             self._for_loop_depths.append(self.depth)
1011             return True
1012
1013         return False
1014
1015     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1016         """See `maybe_increment_for_loop_variable` above for explanation."""
1017         if (
1018             self._for_loop_depths
1019             and self._for_loop_depths[-1] == self.depth
1020             and leaf.type == token.NAME
1021             and leaf.value == "in"
1022         ):
1023             self.depth -= 1
1024             self._for_loop_depths.pop()
1025             return True
1026
1027         return False
1028
1029     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1030         """In a lambda expression, there might be more than one argument.
1031
1032         To avoid splitting on the comma in this situation, increase the depth of
1033         tokens between `lambda` and `:`.
1034         """
1035         if leaf.type == token.NAME and leaf.value == "lambda":
1036             self.depth += 1
1037             self._lambda_argument_depths.append(self.depth)
1038             return True
1039
1040         return False
1041
1042     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1043         """See `maybe_increment_lambda_arguments` above for explanation."""
1044         if (
1045             self._lambda_argument_depths
1046             and self._lambda_argument_depths[-1] == self.depth
1047             and leaf.type == token.COLON
1048         ):
1049             self.depth -= 1
1050             self._lambda_argument_depths.pop()
1051             return True
1052
1053         return False
1054
1055     def get_open_lsqb(self) -> Optional[Leaf]:
1056         """Return the most recent opening square bracket (if any)."""
1057         return self.bracket_match.get((self.depth - 1, token.RSQB))
1058
1059
1060 @dataclass
1061 class Line:
1062     """Holds leaves and comments. Can be printed with `str(line)`."""
1063
1064     depth: int = 0
1065     leaves: List[Leaf] = Factory(list)
1066     # The LeafID keys of comments must remain ordered by the corresponding leaf's index
1067     # in leaves
1068     comments: Dict[LeafID, List[Leaf]] = Factory(dict)
1069     bracket_tracker: BracketTracker = Factory(BracketTracker)
1070     inside_brackets: bool = False
1071     should_explode: bool = False
1072
1073     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1074         """Add a new `leaf` to the end of the line.
1075
1076         Unless `preformatted` is True, the `leaf` will receive a new consistent
1077         whitespace prefix and metadata applied by :class:`BracketTracker`.
1078         Trailing commas are maybe removed, unpacked for loop variables are
1079         demoted from being delimiters.
1080
1081         Inline comments are put aside.
1082         """
1083         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1084         if not has_value:
1085             return
1086
1087         if token.COLON == leaf.type and self.is_class_paren_empty:
1088             del self.leaves[-2:]
1089         if self.leaves and not preformatted:
1090             # Note: at this point leaf.prefix should be empty except for
1091             # imports, for which we only preserve newlines.
1092             leaf.prefix += whitespace(
1093                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1094             )
1095         if self.inside_brackets or not preformatted:
1096             self.bracket_tracker.mark(leaf)
1097             self.maybe_remove_trailing_comma(leaf)
1098         if not self.append_comment(leaf):
1099             self.leaves.append(leaf)
1100
1101     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1102         """Like :func:`append()` but disallow invalid standalone comment structure.
1103
1104         Raises ValueError when any `leaf` is appended after a standalone comment
1105         or when a standalone comment is not the first leaf on the line.
1106         """
1107         if self.bracket_tracker.depth == 0:
1108             if self.is_comment:
1109                 raise ValueError("cannot append to standalone comments")
1110
1111             if self.leaves and leaf.type == STANDALONE_COMMENT:
1112                 raise ValueError(
1113                     "cannot append standalone comments to a populated line"
1114                 )
1115
1116         self.append(leaf, preformatted=preformatted)
1117
1118     @property
1119     def is_comment(self) -> bool:
1120         """Is this line a standalone comment?"""
1121         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1122
1123     @property
1124     def is_decorator(self) -> bool:
1125         """Is this line a decorator?"""
1126         return bool(self) and self.leaves[0].type == token.AT
1127
1128     @property
1129     def is_import(self) -> bool:
1130         """Is this an import line?"""
1131         return bool(self) and is_import(self.leaves[0])
1132
1133     @property
1134     def is_class(self) -> bool:
1135         """Is this line a class definition?"""
1136         return (
1137             bool(self)
1138             and self.leaves[0].type == token.NAME
1139             and self.leaves[0].value == "class"
1140         )
1141
1142     @property
1143     def is_stub_class(self) -> bool:
1144         """Is this line a class definition with a body consisting only of "..."?"""
1145         return self.is_class and self.leaves[-3:] == [
1146             Leaf(token.DOT, ".") for _ in range(3)
1147         ]
1148
1149     @property
1150     def is_def(self) -> bool:
1151         """Is this a function definition? (Also returns True for async defs.)"""
1152         try:
1153             first_leaf = self.leaves[0]
1154         except IndexError:
1155             return False
1156
1157         try:
1158             second_leaf: Optional[Leaf] = self.leaves[1]
1159         except IndexError:
1160             second_leaf = None
1161         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1162             first_leaf.type == token.ASYNC
1163             and second_leaf is not None
1164             and second_leaf.type == token.NAME
1165             and second_leaf.value == "def"
1166         )
1167
1168     @property
1169     def is_class_paren_empty(self) -> bool:
1170         """Is this a class with no base classes but using parentheses?
1171
1172         Those are unnecessary and should be removed.
1173         """
1174         return (
1175             bool(self)
1176             and len(self.leaves) == 4
1177             and self.is_class
1178             and self.leaves[2].type == token.LPAR
1179             and self.leaves[2].value == "("
1180             and self.leaves[3].type == token.RPAR
1181             and self.leaves[3].value == ")"
1182         )
1183
1184     @property
1185     def is_triple_quoted_string(self) -> bool:
1186         """Is the line a triple quoted string?"""
1187         return (
1188             bool(self)
1189             and self.leaves[0].type == token.STRING
1190             and self.leaves[0].value.startswith(('"""', "'''"))
1191         )
1192
1193     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1194         """If so, needs to be split before emitting."""
1195         for leaf in self.leaves:
1196             if leaf.type == STANDALONE_COMMENT:
1197                 if leaf.bracket_depth <= depth_limit:
1198                     return True
1199
1200         return False
1201
1202     def contains_multiline_strings(self) -> bool:
1203         for leaf in self.leaves:
1204             if is_multiline_string(leaf):
1205                 return True
1206
1207         return False
1208
1209     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1210         """Remove trailing comma if there is one and it's safe."""
1211         if not (
1212             self.leaves
1213             and self.leaves[-1].type == token.COMMA
1214             and closing.type in CLOSING_BRACKETS
1215         ):
1216             return False
1217
1218         if closing.type == token.RBRACE:
1219             self.remove_trailing_comma()
1220             return True
1221
1222         if closing.type == token.RSQB:
1223             comma = self.leaves[-1]
1224             if comma.parent and comma.parent.type == syms.listmaker:
1225                 self.remove_trailing_comma()
1226                 return True
1227
1228         # For parens let's check if it's safe to remove the comma.
1229         # Imports are always safe.
1230         if self.is_import:
1231             self.remove_trailing_comma()
1232             return True
1233
1234         # Otherwise, if the trailing one is the only one, we might mistakenly
1235         # change a tuple into a different type by removing the comma.
1236         depth = closing.bracket_depth + 1
1237         commas = 0
1238         opening = closing.opening_bracket
1239         for _opening_index, leaf in enumerate(self.leaves):
1240             if leaf is opening:
1241                 break
1242
1243         else:
1244             return False
1245
1246         for leaf in self.leaves[_opening_index + 1 :]:
1247             if leaf is closing:
1248                 break
1249
1250             bracket_depth = leaf.bracket_depth
1251             if bracket_depth == depth and leaf.type == token.COMMA:
1252                 commas += 1
1253                 if leaf.parent and leaf.parent.type == syms.arglist:
1254                     commas += 1
1255                     break
1256
1257         if commas > 1:
1258             self.remove_trailing_comma()
1259             return True
1260
1261         return False
1262
1263     def append_comment(self, comment: Leaf) -> bool:
1264         """Add an inline or standalone comment to the line."""
1265         if (
1266             comment.type == STANDALONE_COMMENT
1267             and self.bracket_tracker.any_open_brackets()
1268         ):
1269             comment.prefix = ""
1270             return False
1271
1272         if comment.type != token.COMMENT:
1273             return False
1274
1275         if not self.leaves:
1276             comment.type = STANDALONE_COMMENT
1277             comment.prefix = ""
1278             return False
1279
1280         else:
1281             leaf_id = id(self.leaves[-1])
1282             if leaf_id not in self.comments:
1283                 self.comments[leaf_id] = [comment]
1284             else:
1285                 self.comments[leaf_id].append(comment)
1286             return True
1287
1288     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1289         """Generate comments that should appear directly after `leaf`."""
1290         return self.comments.get(id(leaf), [])
1291
1292     def remove_trailing_comma(self) -> None:
1293         """Remove the trailing comma and moves the comments attached to it."""
1294         # Remember, the LeafID keys of self.comments are ordered by the
1295         # corresponding leaf's index in self.leaves
1296         # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
1297         # Otherwise, we insert it into self.comments, and it becomes the last entry.
1298         # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
1299         # is maintained
1300         self.comments.setdefault(id(self.leaves[-2]), []).extend(
1301             self.comments.get(id(self.leaves[-1]), [])
1302         )
1303         self.comments.pop(id(self.leaves[-1]), None)
1304         self.leaves.pop()
1305
1306     def is_complex_subscript(self, leaf: Leaf) -> bool:
1307         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1308         open_lsqb = self.bracket_tracker.get_open_lsqb()
1309         if open_lsqb is None:
1310             return False
1311
1312         subscript_start = open_lsqb.next_sibling
1313
1314         if isinstance(subscript_start, Node):
1315             if subscript_start.type == syms.listmaker:
1316                 return False
1317
1318             if subscript_start.type == syms.subscriptlist:
1319                 subscript_start = child_towards(subscript_start, leaf)
1320         return subscript_start is not None and any(
1321             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1322         )
1323
1324     def __str__(self) -> str:
1325         """Render the line."""
1326         if not self:
1327             return "\n"
1328
1329         indent = "    " * self.depth
1330         leaves = iter(self.leaves)
1331         first = next(leaves)
1332         res = f"{first.prefix}{indent}{first.value}"
1333         for leaf in leaves:
1334             res += str(leaf)
1335         for comment in itertools.chain.from_iterable(self.comments.values()):
1336             res += str(comment)
1337         return res + "\n"
1338
1339     def __bool__(self) -> bool:
1340         """Return True if the line has leaves or comments."""
1341         return bool(self.leaves or self.comments)
1342
1343
1344 @dataclass
1345 class EmptyLineTracker:
1346     """Provides a stateful method that returns the number of potential extra
1347     empty lines needed before and after the currently processed line.
1348
1349     Note: this tracker works on lines that haven't been split yet.  It assumes
1350     the prefix of the first leaf consists of optional newlines.  Those newlines
1351     are consumed by `maybe_empty_lines()` and included in the computation.
1352     """
1353
1354     is_pyi: bool = False
1355     previous_line: Optional[Line] = None
1356     previous_after: int = 0
1357     previous_defs: List[int] = Factory(list)
1358
1359     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1360         """Return the number of extra empty lines before and after the `current_line`.
1361
1362         This is for separating `def`, `async def` and `class` with extra empty
1363         lines (two on module-level).
1364         """
1365         before, after = self._maybe_empty_lines(current_line)
1366         before -= self.previous_after
1367         self.previous_after = after
1368         self.previous_line = current_line
1369         return before, after
1370
1371     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1372         max_allowed = 1
1373         if current_line.depth == 0:
1374             max_allowed = 1 if self.is_pyi else 2
1375         if current_line.leaves:
1376             # Consume the first leaf's extra newlines.
1377             first_leaf = current_line.leaves[0]
1378             before = first_leaf.prefix.count("\n")
1379             before = min(before, max_allowed)
1380             first_leaf.prefix = ""
1381         else:
1382             before = 0
1383         depth = current_line.depth
1384         while self.previous_defs and self.previous_defs[-1] >= depth:
1385             self.previous_defs.pop()
1386             if self.is_pyi:
1387                 before = 0 if depth else 1
1388             else:
1389                 before = 1 if depth else 2
1390         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1391             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1392
1393         if (
1394             self.previous_line
1395             and self.previous_line.is_import
1396             and not current_line.is_import
1397             and depth == self.previous_line.depth
1398         ):
1399             return (before or 1), 0
1400
1401         if (
1402             self.previous_line
1403             and self.previous_line.is_class
1404             and current_line.is_triple_quoted_string
1405         ):
1406             return before, 1
1407
1408         return before, 0
1409
1410     def _maybe_empty_lines_for_class_or_def(
1411         self, current_line: Line, before: int
1412     ) -> Tuple[int, int]:
1413         if not current_line.is_decorator:
1414             self.previous_defs.append(current_line.depth)
1415         if self.previous_line is None:
1416             # Don't insert empty lines before the first line in the file.
1417             return 0, 0
1418
1419         if self.previous_line.is_decorator:
1420             return 0, 0
1421
1422         if self.previous_line.depth < current_line.depth and (
1423             self.previous_line.is_class or self.previous_line.is_def
1424         ):
1425             return 0, 0
1426
1427         if (
1428             self.previous_line.is_comment
1429             and self.previous_line.depth == current_line.depth
1430             and before == 0
1431         ):
1432             return 0, 0
1433
1434         if self.is_pyi:
1435             if self.previous_line.depth > current_line.depth:
1436                 newlines = 1
1437             elif current_line.is_class or self.previous_line.is_class:
1438                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1439                     # No blank line between classes with an empty body
1440                     newlines = 0
1441                 else:
1442                     newlines = 1
1443             elif current_line.is_def and not self.previous_line.is_def:
1444                 # Blank line between a block of functions and a block of non-functions
1445                 newlines = 1
1446             else:
1447                 newlines = 0
1448         else:
1449             newlines = 2
1450         if current_line.depth and newlines:
1451             newlines -= 1
1452         return newlines, 0
1453
1454
1455 @dataclass
1456 class LineGenerator(Visitor[Line]):
1457     """Generates reformatted Line objects.  Empty lines are not emitted.
1458
1459     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1460     in ways that will no longer stringify to valid Python code on the tree.
1461     """
1462
1463     is_pyi: bool = False
1464     normalize_strings: bool = True
1465     current_line: Line = Factory(Line)
1466     remove_u_prefix: bool = False
1467
1468     def line(self, indent: int = 0) -> Iterator[Line]:
1469         """Generate a line.
1470
1471         If the line is empty, only emit if it makes sense.
1472         If the line is too long, split it first and then generate.
1473
1474         If any lines were generated, set up a new current_line.
1475         """
1476         if not self.current_line:
1477             self.current_line.depth += indent
1478             return  # Line is empty, don't emit. Creating a new one unnecessary.
1479
1480         complete_line = self.current_line
1481         self.current_line = Line(depth=complete_line.depth + indent)
1482         yield complete_line
1483
1484     def visit_default(self, node: LN) -> Iterator[Line]:
1485         """Default `visit_*()` implementation. Recurses to children of `node`."""
1486         if isinstance(node, Leaf):
1487             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1488             for comment in generate_comments(node):
1489                 if any_open_brackets:
1490                     # any comment within brackets is subject to splitting
1491                     self.current_line.append(comment)
1492                 elif comment.type == token.COMMENT:
1493                     # regular trailing comment
1494                     self.current_line.append(comment)
1495                     yield from self.line()
1496
1497                 else:
1498                     # regular standalone comment
1499                     yield from self.line()
1500
1501                     self.current_line.append(comment)
1502                     yield from self.line()
1503
1504             normalize_prefix(node, inside_brackets=any_open_brackets)
1505             if self.normalize_strings and node.type == token.STRING:
1506                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1507                 normalize_string_quotes(node)
1508             if node.type == token.NUMBER:
1509                 normalize_numeric_literal(node)
1510             if node.type not in WHITESPACE:
1511                 self.current_line.append(node)
1512         yield from super().visit_default(node)
1513
1514     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1515         """Increase indentation level, maybe yield a line."""
1516         # In blib2to3 INDENT never holds comments.
1517         yield from self.line(+1)
1518         yield from self.visit_default(node)
1519
1520     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1521         """Decrease indentation level, maybe yield a line."""
1522         # The current line might still wait for trailing comments.  At DEDENT time
1523         # there won't be any (they would be prefixes on the preceding NEWLINE).
1524         # Emit the line then.
1525         yield from self.line()
1526
1527         # While DEDENT has no value, its prefix may contain standalone comments
1528         # that belong to the current indentation level.  Get 'em.
1529         yield from self.visit_default(node)
1530
1531         # Finally, emit the dedent.
1532         yield from self.line(-1)
1533
1534     def visit_stmt(
1535         self, node: Node, keywords: Set[str], parens: Set[str]
1536     ) -> Iterator[Line]:
1537         """Visit a statement.
1538
1539         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1540         `def`, `with`, `class`, `assert` and assignments.
1541
1542         The relevant Python language `keywords` for a given statement will be
1543         NAME leaves within it. This methods puts those on a separate line.
1544
1545         `parens` holds a set of string leaf values immediately after which
1546         invisible parens should be put.
1547         """
1548         normalize_invisible_parens(node, parens_after=parens)
1549         for child in node.children:
1550             if child.type == token.NAME and child.value in keywords:  # type: ignore
1551                 yield from self.line()
1552
1553             yield from self.visit(child)
1554
1555     def visit_suite(self, node: Node) -> Iterator[Line]:
1556         """Visit a suite."""
1557         if self.is_pyi and is_stub_suite(node):
1558             yield from self.visit(node.children[2])
1559         else:
1560             yield from self.visit_default(node)
1561
1562     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1563         """Visit a statement without nested statements."""
1564         is_suite_like = node.parent and node.parent.type in STATEMENT
1565         if is_suite_like:
1566             if self.is_pyi and is_stub_body(node):
1567                 yield from self.visit_default(node)
1568             else:
1569                 yield from self.line(+1)
1570                 yield from self.visit_default(node)
1571                 yield from self.line(-1)
1572
1573         else:
1574             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1575                 yield from self.line()
1576             yield from self.visit_default(node)
1577
1578     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1579         """Visit `async def`, `async for`, `async with`."""
1580         yield from self.line()
1581
1582         children = iter(node.children)
1583         for child in children:
1584             yield from self.visit(child)
1585
1586             if child.type == token.ASYNC:
1587                 break
1588
1589         internal_stmt = next(children)
1590         for child in internal_stmt.children:
1591             yield from self.visit(child)
1592
1593     def visit_decorators(self, node: Node) -> Iterator[Line]:
1594         """Visit decorators."""
1595         for child in node.children:
1596             yield from self.line()
1597             yield from self.visit(child)
1598
1599     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1600         """Remove a semicolon and put the other statement on a separate line."""
1601         yield from self.line()
1602
1603     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1604         """End of file. Process outstanding comments and end with a newline."""
1605         yield from self.visit_default(leaf)
1606         yield from self.line()
1607
1608     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1609         if not self.current_line.bracket_tracker.any_open_brackets():
1610             yield from self.line()
1611         yield from self.visit_default(leaf)
1612
1613     def __attrs_post_init__(self) -> None:
1614         """You are in a twisty little maze of passages."""
1615         v = self.visit_stmt
1616         Ø: Set[str] = set()
1617         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1618         self.visit_if_stmt = partial(
1619             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1620         )
1621         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1622         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1623         self.visit_try_stmt = partial(
1624             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1625         )
1626         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1627         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1628         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1629         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1630         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1631         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1632         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1633         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1634         self.visit_async_funcdef = self.visit_async_stmt
1635         self.visit_decorated = self.visit_decorators
1636
1637
1638 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1639 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1640 OPENING_BRACKETS = set(BRACKET.keys())
1641 CLOSING_BRACKETS = set(BRACKET.values())
1642 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1643 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1644
1645
1646 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1647     """Return whitespace prefix if needed for the given `leaf`.
1648
1649     `complex_subscript` signals whether the given leaf is part of a subscription
1650     which has non-trivial arguments, like arithmetic expressions or function calls.
1651     """
1652     NO = ""
1653     SPACE = " "
1654     DOUBLESPACE = "  "
1655     t = leaf.type
1656     p = leaf.parent
1657     v = leaf.value
1658     if t in ALWAYS_NO_SPACE:
1659         return NO
1660
1661     if t == token.COMMENT:
1662         return DOUBLESPACE
1663
1664     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1665     if t == token.COLON and p.type not in {
1666         syms.subscript,
1667         syms.subscriptlist,
1668         syms.sliceop,
1669     }:
1670         return NO
1671
1672     prev = leaf.prev_sibling
1673     if not prev:
1674         prevp = preceding_leaf(p)
1675         if not prevp or prevp.type in OPENING_BRACKETS:
1676             return NO
1677
1678         if t == token.COLON:
1679             if prevp.type == token.COLON:
1680                 return NO
1681
1682             elif prevp.type != token.COMMA and not complex_subscript:
1683                 return NO
1684
1685             return SPACE
1686
1687         if prevp.type == token.EQUAL:
1688             if prevp.parent:
1689                 if prevp.parent.type in {
1690                     syms.arglist,
1691                     syms.argument,
1692                     syms.parameters,
1693                     syms.varargslist,
1694                 }:
1695                     return NO
1696
1697                 elif prevp.parent.type == syms.typedargslist:
1698                     # A bit hacky: if the equal sign has whitespace, it means we
1699                     # previously found it's a typed argument.  So, we're using
1700                     # that, too.
1701                     return prevp.prefix
1702
1703         elif prevp.type in STARS:
1704             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1705                 return NO
1706
1707         elif prevp.type == token.COLON:
1708             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1709                 return SPACE if complex_subscript else NO
1710
1711         elif (
1712             prevp.parent
1713             and prevp.parent.type == syms.factor
1714             and prevp.type in MATH_OPERATORS
1715         ):
1716             return NO
1717
1718         elif (
1719             prevp.type == token.RIGHTSHIFT
1720             and prevp.parent
1721             and prevp.parent.type == syms.shift_expr
1722             and prevp.prev_sibling
1723             and prevp.prev_sibling.type == token.NAME
1724             and prevp.prev_sibling.value == "print"  # type: ignore
1725         ):
1726             # Python 2 print chevron
1727             return NO
1728
1729     elif prev.type in OPENING_BRACKETS:
1730         return NO
1731
1732     if p.type in {syms.parameters, syms.arglist}:
1733         # untyped function signatures or calls
1734         if not prev or prev.type != token.COMMA:
1735             return NO
1736
1737     elif p.type == syms.varargslist:
1738         # lambdas
1739         if prev and prev.type != token.COMMA:
1740             return NO
1741
1742     elif p.type == syms.typedargslist:
1743         # typed function signatures
1744         if not prev:
1745             return NO
1746
1747         if t == token.EQUAL:
1748             if prev.type != syms.tname:
1749                 return NO
1750
1751         elif prev.type == token.EQUAL:
1752             # A bit hacky: if the equal sign has whitespace, it means we
1753             # previously found it's a typed argument.  So, we're using that, too.
1754             return prev.prefix
1755
1756         elif prev.type != token.COMMA:
1757             return NO
1758
1759     elif p.type == syms.tname:
1760         # type names
1761         if not prev:
1762             prevp = preceding_leaf(p)
1763             if not prevp or prevp.type != token.COMMA:
1764                 return NO
1765
1766     elif p.type == syms.trailer:
1767         # attributes and calls
1768         if t == token.LPAR or t == token.RPAR:
1769             return NO
1770
1771         if not prev:
1772             if t == token.DOT:
1773                 prevp = preceding_leaf(p)
1774                 if not prevp or prevp.type != token.NUMBER:
1775                     return NO
1776
1777             elif t == token.LSQB:
1778                 return NO
1779
1780         elif prev.type != token.COMMA:
1781             return NO
1782
1783     elif p.type == syms.argument:
1784         # single argument
1785         if t == token.EQUAL:
1786             return NO
1787
1788         if not prev:
1789             prevp = preceding_leaf(p)
1790             if not prevp or prevp.type == token.LPAR:
1791                 return NO
1792
1793         elif prev.type in {token.EQUAL} | STARS:
1794             return NO
1795
1796     elif p.type == syms.decorator:
1797         # decorators
1798         return NO
1799
1800     elif p.type == syms.dotted_name:
1801         if prev:
1802             return NO
1803
1804         prevp = preceding_leaf(p)
1805         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1806             return NO
1807
1808     elif p.type == syms.classdef:
1809         if t == token.LPAR:
1810             return NO
1811
1812         if prev and prev.type == token.LPAR:
1813             return NO
1814
1815     elif p.type in {syms.subscript, syms.sliceop}:
1816         # indexing
1817         if not prev:
1818             assert p.parent is not None, "subscripts are always parented"
1819             if p.parent.type == syms.subscriptlist:
1820                 return SPACE
1821
1822             return NO
1823
1824         elif not complex_subscript:
1825             return NO
1826
1827     elif p.type == syms.atom:
1828         if prev and t == token.DOT:
1829             # dots, but not the first one.
1830             return NO
1831
1832     elif p.type == syms.dictsetmaker:
1833         # dict unpacking
1834         if prev and prev.type == token.DOUBLESTAR:
1835             return NO
1836
1837     elif p.type in {syms.factor, syms.star_expr}:
1838         # unary ops
1839         if not prev:
1840             prevp = preceding_leaf(p)
1841             if not prevp or prevp.type in OPENING_BRACKETS:
1842                 return NO
1843
1844             prevp_parent = prevp.parent
1845             assert prevp_parent is not None
1846             if prevp.type == token.COLON and prevp_parent.type in {
1847                 syms.subscript,
1848                 syms.sliceop,
1849             }:
1850                 return NO
1851
1852             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1853                 return NO
1854
1855         elif t in {token.NAME, token.NUMBER, token.STRING}:
1856             return NO
1857
1858     elif p.type == syms.import_from:
1859         if t == token.DOT:
1860             if prev and prev.type == token.DOT:
1861                 return NO
1862
1863         elif t == token.NAME:
1864             if v == "import":
1865                 return SPACE
1866
1867             if prev and prev.type == token.DOT:
1868                 return NO
1869
1870     elif p.type == syms.sliceop:
1871         return NO
1872
1873     return SPACE
1874
1875
1876 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1877     """Return the first leaf that precedes `node`, if any."""
1878     while node:
1879         res = node.prev_sibling
1880         if res:
1881             if isinstance(res, Leaf):
1882                 return res
1883
1884             try:
1885                 return list(res.leaves())[-1]
1886
1887             except IndexError:
1888                 return None
1889
1890         node = node.parent
1891     return None
1892
1893
1894 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1895     """Return the child of `ancestor` that contains `descendant`."""
1896     node: Optional[LN] = descendant
1897     while node and node.parent != ancestor:
1898         node = node.parent
1899     return node
1900
1901
1902 def container_of(leaf: Leaf) -> LN:
1903     """Return `leaf` or one of its ancestors that is the topmost container of it.
1904
1905     By "container" we mean a node where `leaf` is the very first child.
1906     """
1907     same_prefix = leaf.prefix
1908     container: LN = leaf
1909     while container:
1910         parent = container.parent
1911         if parent is None:
1912             break
1913
1914         if parent.children[0].prefix != same_prefix:
1915             break
1916
1917         if parent.type == syms.file_input:
1918             break
1919
1920         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1921             break
1922
1923         container = parent
1924     return container
1925
1926
1927 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1928     """Return the priority of the `leaf` delimiter, given a line break after it.
1929
1930     The delimiter priorities returned here are from those delimiters that would
1931     cause a line break after themselves.
1932
1933     Higher numbers are higher priority.
1934     """
1935     if leaf.type == token.COMMA:
1936         return COMMA_PRIORITY
1937
1938     return 0
1939
1940
1941 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1942     """Return the priority of the `leaf` delimiter, given a line break before it.
1943
1944     The delimiter priorities returned here are from those delimiters that would
1945     cause a line break before themselves.
1946
1947     Higher numbers are higher priority.
1948     """
1949     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1950         # * and ** might also be MATH_OPERATORS but in this case they are not.
1951         # Don't treat them as a delimiter.
1952         return 0
1953
1954     if (
1955         leaf.type == token.DOT
1956         and leaf.parent
1957         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1958         and (previous is None or previous.type in CLOSING_BRACKETS)
1959     ):
1960         return DOT_PRIORITY
1961
1962     if (
1963         leaf.type in MATH_OPERATORS
1964         and leaf.parent
1965         and leaf.parent.type not in {syms.factor, syms.star_expr}
1966     ):
1967         return MATH_PRIORITIES[leaf.type]
1968
1969     if leaf.type in COMPARATORS:
1970         return COMPARATOR_PRIORITY
1971
1972     if (
1973         leaf.type == token.STRING
1974         and previous is not None
1975         and previous.type == token.STRING
1976     ):
1977         return STRING_PRIORITY
1978
1979     if leaf.type not in {token.NAME, token.ASYNC}:
1980         return 0
1981
1982     if (
1983         leaf.value == "for"
1984         and leaf.parent
1985         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1986         or leaf.type == token.ASYNC
1987     ):
1988         if (
1989             not isinstance(leaf.prev_sibling, Leaf)
1990             or leaf.prev_sibling.value != "async"
1991         ):
1992             return COMPREHENSION_PRIORITY
1993
1994     if (
1995         leaf.value == "if"
1996         and leaf.parent
1997         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1998     ):
1999         return COMPREHENSION_PRIORITY
2000
2001     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2002         return TERNARY_PRIORITY
2003
2004     if leaf.value == "is":
2005         return COMPARATOR_PRIORITY
2006
2007     if (
2008         leaf.value == "in"
2009         and leaf.parent
2010         and leaf.parent.type in {syms.comp_op, syms.comparison}
2011         and not (
2012             previous is not None
2013             and previous.type == token.NAME
2014             and previous.value == "not"
2015         )
2016     ):
2017         return COMPARATOR_PRIORITY
2018
2019     if (
2020         leaf.value == "not"
2021         and leaf.parent
2022         and leaf.parent.type == syms.comp_op
2023         and not (
2024             previous is not None
2025             and previous.type == token.NAME
2026             and previous.value == "is"
2027         )
2028     ):
2029         return COMPARATOR_PRIORITY
2030
2031     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2032         return LOGIC_PRIORITY
2033
2034     return 0
2035
2036
2037 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2038 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2039
2040
2041 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2042     """Clean the prefix of the `leaf` and generate comments from it, if any.
2043
2044     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2045     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2046     move because it does away with modifying the grammar to include all the
2047     possible places in which comments can be placed.
2048
2049     The sad consequence for us though is that comments don't "belong" anywhere.
2050     This is why this function generates simple parentless Leaf objects for
2051     comments.  We simply don't know what the correct parent should be.
2052
2053     No matter though, we can live without this.  We really only need to
2054     differentiate between inline and standalone comments.  The latter don't
2055     share the line with any code.
2056
2057     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2058     are emitted with a fake STANDALONE_COMMENT token identifier.
2059     """
2060     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2061         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2062
2063
2064 @dataclass
2065 class ProtoComment:
2066     """Describes a piece of syntax that is a comment.
2067
2068     It's not a :class:`blib2to3.pytree.Leaf` so that:
2069
2070     * it can be cached (`Leaf` objects should not be reused more than once as
2071       they store their lineno, column, prefix, and parent information);
2072     * `newlines` and `consumed` fields are kept separate from the `value`. This
2073       simplifies handling of special marker comments like ``# fmt: off/on``.
2074     """
2075
2076     type: int  # token.COMMENT or STANDALONE_COMMENT
2077     value: str  # content of the comment
2078     newlines: int  # how many newlines before the comment
2079     consumed: int  # how many characters of the original leaf's prefix did we consume
2080
2081
2082 @lru_cache(maxsize=4096)
2083 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2084     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2085     result: List[ProtoComment] = []
2086     if not prefix or "#" not in prefix:
2087         return result
2088
2089     consumed = 0
2090     nlines = 0
2091     for index, line in enumerate(prefix.split("\n")):
2092         consumed += len(line) + 1  # adding the length of the split '\n'
2093         line = line.lstrip()
2094         if not line:
2095             nlines += 1
2096         if not line.startswith("#"):
2097             continue
2098
2099         if index == 0 and not is_endmarker:
2100             comment_type = token.COMMENT  # simple trailing comment
2101         else:
2102             comment_type = STANDALONE_COMMENT
2103         comment = make_comment(line)
2104         result.append(
2105             ProtoComment(
2106                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2107             )
2108         )
2109         nlines = 0
2110     return result
2111
2112
2113 def make_comment(content: str) -> str:
2114     """Return a consistently formatted comment from the given `content` string.
2115
2116     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2117     space between the hash sign and the content.
2118
2119     If `content` didn't start with a hash sign, one is provided.
2120     """
2121     content = content.rstrip()
2122     if not content:
2123         return "#"
2124
2125     if content[0] == "#":
2126         content = content[1:]
2127     if content and content[0] not in " !:#'%":
2128         content = " " + content
2129     return "#" + content
2130
2131
2132 def split_line(
2133     line: Line,
2134     line_length: int,
2135     inner: bool = False,
2136     supports_trailing_commas: bool = False,
2137 ) -> Iterator[Line]:
2138     """Split a `line` into potentially many lines.
2139
2140     They should fit in the allotted `line_length` but might not be able to.
2141     `inner` signifies that there were a pair of brackets somewhere around the
2142     current `line`, possibly transitively. This means we can fallback to splitting
2143     by delimiters if the LHS/RHS don't yield any results.
2144
2145     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2146     """
2147     if line.is_comment:
2148         yield line
2149         return
2150
2151     line_str = str(line).strip("\n")
2152
2153     # we don't want to split special comments like type annotations
2154     # https://github.com/python/typing/issues/186
2155     has_special_comment = False
2156     for leaf in line.leaves:
2157         for comment in line.comments_after(leaf):
2158             if leaf.type == token.COMMA and is_special_comment(comment):
2159                 has_special_comment = True
2160
2161     if (
2162         not has_special_comment
2163         and not line.should_explode
2164         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2165     ):
2166         yield line
2167         return
2168
2169     split_funcs: List[SplitFunc]
2170     if line.is_def:
2171         split_funcs = [left_hand_split]
2172     else:
2173
2174         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2175             for omit in generate_trailers_to_omit(line, line_length):
2176                 lines = list(
2177                     right_hand_split(
2178                         line, line_length, supports_trailing_commas, omit=omit
2179                     )
2180                 )
2181                 if is_line_short_enough(lines[0], line_length=line_length):
2182                     yield from lines
2183                     return
2184
2185             # All splits failed, best effort split with no omits.
2186             # This mostly happens to multiline strings that are by definition
2187             # reported as not fitting a single line.
2188             yield from right_hand_split(line, supports_trailing_commas)
2189
2190         if line.inside_brackets:
2191             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2192         else:
2193             split_funcs = [rhs]
2194     for split_func in split_funcs:
2195         # We are accumulating lines in `result` because we might want to abort
2196         # mission and return the original line in the end, or attempt a different
2197         # split altogether.
2198         result: List[Line] = []
2199         try:
2200             for l in split_func(line, supports_trailing_commas):
2201                 if str(l).strip("\n") == line_str:
2202                     raise CannotSplit("Split function returned an unchanged result")
2203
2204                 result.extend(
2205                     split_line(
2206                         l,
2207                         line_length=line_length,
2208                         inner=True,
2209                         supports_trailing_commas=supports_trailing_commas,
2210                     )
2211                 )
2212         except CannotSplit:
2213             continue
2214
2215         else:
2216             yield from result
2217             break
2218
2219     else:
2220         yield line
2221
2222
2223 def left_hand_split(
2224     line: Line, supports_trailing_commas: bool = False
2225 ) -> Iterator[Line]:
2226     """Split line into many lines, starting with the first matching bracket pair.
2227
2228     Note: this usually looks weird, only use this for function definitions.
2229     Prefer RHS otherwise.  This is why this function is not symmetrical with
2230     :func:`right_hand_split` which also handles optional parentheses.
2231     """
2232     tail_leaves: List[Leaf] = []
2233     body_leaves: List[Leaf] = []
2234     head_leaves: List[Leaf] = []
2235     current_leaves = head_leaves
2236     matching_bracket = None
2237     for leaf in line.leaves:
2238         if (
2239             current_leaves is body_leaves
2240             and leaf.type in CLOSING_BRACKETS
2241             and leaf.opening_bracket is matching_bracket
2242         ):
2243             current_leaves = tail_leaves if body_leaves else head_leaves
2244         current_leaves.append(leaf)
2245         if current_leaves is head_leaves:
2246             if leaf.type in OPENING_BRACKETS:
2247                 matching_bracket = leaf
2248                 current_leaves = body_leaves
2249     if not matching_bracket:
2250         raise CannotSplit("No brackets found")
2251
2252     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2253     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2254     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2255     bracket_split_succeeded_or_raise(head, body, tail)
2256     for result in (head, body, tail):
2257         if result:
2258             yield result
2259
2260
2261 def right_hand_split(
2262     line: Line,
2263     line_length: int,
2264     supports_trailing_commas: bool = False,
2265     omit: Collection[LeafID] = (),
2266 ) -> Iterator[Line]:
2267     """Split line into many lines, starting with the last matching bracket pair.
2268
2269     If the split was by optional parentheses, attempt splitting without them, too.
2270     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2271     this split.
2272
2273     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2274     """
2275     tail_leaves: List[Leaf] = []
2276     body_leaves: List[Leaf] = []
2277     head_leaves: List[Leaf] = []
2278     current_leaves = tail_leaves
2279     opening_bracket = None
2280     closing_bracket = None
2281     for leaf in reversed(line.leaves):
2282         if current_leaves is body_leaves:
2283             if leaf is opening_bracket:
2284                 current_leaves = head_leaves if body_leaves else tail_leaves
2285         current_leaves.append(leaf)
2286         if current_leaves is tail_leaves:
2287             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2288                 opening_bracket = leaf.opening_bracket
2289                 closing_bracket = leaf
2290                 current_leaves = body_leaves
2291     if not (opening_bracket and closing_bracket and head_leaves):
2292         # If there is no opening or closing_bracket that means the split failed and
2293         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2294         # the matching `opening_bracket` wasn't available on `line` anymore.
2295         raise CannotSplit("No brackets found")
2296
2297     tail_leaves.reverse()
2298     body_leaves.reverse()
2299     head_leaves.reverse()
2300     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2301     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2302     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2303     bracket_split_succeeded_or_raise(head, body, tail)
2304     if (
2305         # the body shouldn't be exploded
2306         not body.should_explode
2307         # the opening bracket is an optional paren
2308         and opening_bracket.type == token.LPAR
2309         and not opening_bracket.value
2310         # the closing bracket is an optional paren
2311         and closing_bracket.type == token.RPAR
2312         and not closing_bracket.value
2313         # it's not an import (optional parens are the only thing we can split on
2314         # in this case; attempting a split without them is a waste of time)
2315         and not line.is_import
2316         # there are no standalone comments in the body
2317         and not body.contains_standalone_comments(0)
2318         # and we can actually remove the parens
2319         and can_omit_invisible_parens(body, line_length)
2320     ):
2321         omit = {id(closing_bracket), *omit}
2322         try:
2323             yield from right_hand_split(
2324                 line,
2325                 line_length,
2326                 supports_trailing_commas=supports_trailing_commas,
2327                 omit=omit,
2328             )
2329             return
2330
2331         except CannotSplit:
2332             if not (
2333                 can_be_split(body)
2334                 or is_line_short_enough(body, line_length=line_length)
2335             ):
2336                 raise CannotSplit(
2337                     "Splitting failed, body is still too long and can't be split."
2338                 )
2339
2340             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2341                 raise CannotSplit(
2342                     "The current optional pair of parentheses is bound to fail to "
2343                     "satisfy the splitting algorithm because the head or the tail "
2344                     "contains multiline strings which by definition never fit one "
2345                     "line."
2346                 )
2347
2348     ensure_visible(opening_bracket)
2349     ensure_visible(closing_bracket)
2350     for result in (head, body, tail):
2351         if result:
2352             yield result
2353
2354
2355 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2356     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2357
2358     Do nothing otherwise.
2359
2360     A left- or right-hand split is based on a pair of brackets. Content before
2361     (and including) the opening bracket is left on one line, content inside the
2362     brackets is put on a separate line, and finally content starting with and
2363     following the closing bracket is put on a separate line.
2364
2365     Those are called `head`, `body`, and `tail`, respectively. If the split
2366     produced the same line (all content in `head`) or ended up with an empty `body`
2367     and the `tail` is just the closing bracket, then it's considered failed.
2368     """
2369     tail_len = len(str(tail).strip())
2370     if not body:
2371         if tail_len == 0:
2372             raise CannotSplit("Splitting brackets produced the same line")
2373
2374         elif tail_len < 3:
2375             raise CannotSplit(
2376                 f"Splitting brackets on an empty body to save "
2377                 f"{tail_len} characters is not worth it"
2378             )
2379
2380
2381 def bracket_split_build_line(
2382     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2383 ) -> Line:
2384     """Return a new line with given `leaves` and respective comments from `original`.
2385
2386     If `is_body` is True, the result line is one-indented inside brackets and as such
2387     has its first leaf's prefix normalized and a trailing comma added when expected.
2388     """
2389     result = Line(depth=original.depth)
2390     if is_body:
2391         result.inside_brackets = True
2392         result.depth += 1
2393         if leaves:
2394             # Since body is a new indent level, remove spurious leading whitespace.
2395             normalize_prefix(leaves[0], inside_brackets=True)
2396             # Ensure a trailing comma when expected.
2397             if original.is_import:
2398                 if leaves[-1].type != token.COMMA:
2399                     leaves.append(Leaf(token.COMMA, ","))
2400     # Populate the line
2401     for leaf in leaves:
2402         result.append(leaf, preformatted=True)
2403         for comment_after in original.comments_after(leaf):
2404             result.append(comment_after, preformatted=True)
2405     if is_body:
2406         result.should_explode = should_explode(result, opening_bracket)
2407     return result
2408
2409
2410 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2411     """Normalize prefix of the first leaf in every line returned by `split_func`.
2412
2413     This is a decorator over relevant split functions.
2414     """
2415
2416     @wraps(split_func)
2417     def split_wrapper(
2418         line: Line, supports_trailing_commas: bool = False
2419     ) -> Iterator[Line]:
2420         for l in split_func(line, supports_trailing_commas):
2421             normalize_prefix(l.leaves[0], inside_brackets=True)
2422             yield l
2423
2424     return split_wrapper
2425
2426
2427 @dont_increase_indentation
2428 def delimiter_split(
2429     line: Line, supports_trailing_commas: bool = False
2430 ) -> Iterator[Line]:
2431     """Split according to delimiters of the highest priority.
2432
2433     If `supports_trailing_commas` is True, the split will add trailing commas
2434     also in function signatures that contain `*` and `**`.
2435     """
2436     try:
2437         last_leaf = line.leaves[-1]
2438     except IndexError:
2439         raise CannotSplit("Line empty")
2440
2441     bt = line.bracket_tracker
2442     try:
2443         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2444     except ValueError:
2445         raise CannotSplit("No delimiters found")
2446
2447     if delimiter_priority == DOT_PRIORITY:
2448         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2449             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2450
2451     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2452     lowest_depth = sys.maxsize
2453     trailing_comma_safe = True
2454
2455     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2456         """Append `leaf` to current line or to new line if appending impossible."""
2457         nonlocal current_line
2458         try:
2459             current_line.append_safe(leaf, preformatted=True)
2460         except ValueError:
2461             yield current_line
2462
2463             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2464             current_line.append(leaf)
2465
2466     for leaf in line.leaves:
2467         yield from append_to_line(leaf)
2468
2469         for comment_after in line.comments_after(leaf):
2470             yield from append_to_line(comment_after)
2471
2472         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2473         if leaf.bracket_depth == lowest_depth and is_vararg(
2474             leaf, within=VARARGS_PARENTS
2475         ):
2476             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2477         leaf_priority = bt.delimiters.get(id(leaf))
2478         if leaf_priority == delimiter_priority:
2479             yield current_line
2480
2481             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2482     if current_line:
2483         if (
2484             trailing_comma_safe
2485             and delimiter_priority == COMMA_PRIORITY
2486             and current_line.leaves[-1].type != token.COMMA
2487             and current_line.leaves[-1].type != STANDALONE_COMMENT
2488         ):
2489             current_line.append(Leaf(token.COMMA, ","))
2490         yield current_line
2491
2492
2493 @dont_increase_indentation
2494 def standalone_comment_split(
2495     line: Line, supports_trailing_commas: bool = False
2496 ) -> Iterator[Line]:
2497     """Split standalone comments from the rest of the line."""
2498     if not line.contains_standalone_comments(0):
2499         raise CannotSplit("Line does not have any standalone comments")
2500
2501     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2502
2503     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2504         """Append `leaf` to current line or to new line if appending impossible."""
2505         nonlocal current_line
2506         try:
2507             current_line.append_safe(leaf, preformatted=True)
2508         except ValueError:
2509             yield current_line
2510
2511             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2512             current_line.append(leaf)
2513
2514     for leaf in line.leaves:
2515         yield from append_to_line(leaf)
2516
2517         for comment_after in line.comments_after(leaf):
2518             yield from append_to_line(comment_after)
2519
2520     if current_line:
2521         yield current_line
2522
2523
2524 def is_import(leaf: Leaf) -> bool:
2525     """Return True if the given leaf starts an import statement."""
2526     p = leaf.parent
2527     t = leaf.type
2528     v = leaf.value
2529     return bool(
2530         t == token.NAME
2531         and (
2532             (v == "import" and p and p.type == syms.import_name)
2533             or (v == "from" and p and p.type == syms.import_from)
2534         )
2535     )
2536
2537
2538 def is_special_comment(leaf: Leaf) -> bool:
2539     """Return True if the given leaf is a special comment.
2540     Only returns true for type comments for now."""
2541     t = leaf.type
2542     v = leaf.value
2543     return bool(
2544         (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
2545     )
2546
2547
2548 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2549     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2550     else.
2551
2552     Note: don't use backslashes for formatting or you'll lose your voting rights.
2553     """
2554     if not inside_brackets:
2555         spl = leaf.prefix.split("#")
2556         if "\\" not in spl[0]:
2557             nl_count = spl[-1].count("\n")
2558             if len(spl) > 1:
2559                 nl_count -= 1
2560             leaf.prefix = "\n" * nl_count
2561             return
2562
2563     leaf.prefix = ""
2564
2565
2566 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2567     """Make all string prefixes lowercase.
2568
2569     If remove_u_prefix is given, also removes any u prefix from the string.
2570
2571     Note: Mutates its argument.
2572     """
2573     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2574     assert match is not None, f"failed to match string {leaf.value!r}"
2575     orig_prefix = match.group(1)
2576     new_prefix = orig_prefix.lower()
2577     if remove_u_prefix:
2578         new_prefix = new_prefix.replace("u", "")
2579     leaf.value = f"{new_prefix}{match.group(2)}"
2580
2581
2582 def normalize_string_quotes(leaf: Leaf) -> None:
2583     """Prefer double quotes but only if it doesn't cause more escaping.
2584
2585     Adds or removes backslashes as appropriate. Doesn't parse and fix
2586     strings nested in f-strings (yet).
2587
2588     Note: Mutates its argument.
2589     """
2590     value = leaf.value.lstrip("furbFURB")
2591     if value[:3] == '"""':
2592         return
2593
2594     elif value[:3] == "'''":
2595         orig_quote = "'''"
2596         new_quote = '"""'
2597     elif value[0] == '"':
2598         orig_quote = '"'
2599         new_quote = "'"
2600     else:
2601         orig_quote = "'"
2602         new_quote = '"'
2603     first_quote_pos = leaf.value.find(orig_quote)
2604     if first_quote_pos == -1:
2605         return  # There's an internal error
2606
2607     prefix = leaf.value[:first_quote_pos]
2608     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2609     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2610     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2611     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2612     if "r" in prefix.casefold():
2613         if unescaped_new_quote.search(body):
2614             # There's at least one unescaped new_quote in this raw string
2615             # so converting is impossible
2616             return
2617
2618         # Do not introduce or remove backslashes in raw strings
2619         new_body = body
2620     else:
2621         # remove unnecessary escapes
2622         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2623         if body != new_body:
2624             # Consider the string without unnecessary escapes as the original
2625             body = new_body
2626             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2627         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2628         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2629     if "f" in prefix.casefold():
2630         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2631         for m in matches:
2632             if "\\" in str(m):
2633                 # Do not introduce backslashes in interpolated expressions
2634                 return
2635     if new_quote == '"""' and new_body[-1:] == '"':
2636         # edge case:
2637         new_body = new_body[:-1] + '\\"'
2638     orig_escape_count = body.count("\\")
2639     new_escape_count = new_body.count("\\")
2640     if new_escape_count > orig_escape_count:
2641         return  # Do not introduce more escaping
2642
2643     if new_escape_count == orig_escape_count and orig_quote == '"':
2644         return  # Prefer double quotes
2645
2646     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2647
2648
2649 def normalize_numeric_literal(leaf: Leaf) -> None:
2650     """Normalizes numeric (float, int, and complex) literals.
2651
2652     All letters used in the representation are normalized to lowercase (except
2653     in Python 2 long literals).
2654     """
2655     text = leaf.value.lower()
2656     if text.startswith(("0o", "0b")):
2657         # Leave octal and binary literals alone.
2658         pass
2659     elif text.startswith("0x"):
2660         # Change hex literals to upper case.
2661         before, after = text[:2], text[2:]
2662         text = f"{before}{after.upper()}"
2663     elif "e" in text:
2664         before, after = text.split("e")
2665         sign = ""
2666         if after.startswith("-"):
2667             after = after[1:]
2668             sign = "-"
2669         elif after.startswith("+"):
2670             after = after[1:]
2671         before = format_float_or_int_string(before)
2672         text = f"{before}e{sign}{after}"
2673     elif text.endswith(("j", "l")):
2674         number = text[:-1]
2675         suffix = text[-1]
2676         # Capitalize in "2L" because "l" looks too similar to "1".
2677         if suffix == "l":
2678             suffix = "L"
2679         text = f"{format_float_or_int_string(number)}{suffix}"
2680     else:
2681         text = format_float_or_int_string(text)
2682     leaf.value = text
2683
2684
2685 def format_float_or_int_string(text: str) -> str:
2686     """Formats a float string like "1.0"."""
2687     if "." not in text:
2688         return text
2689
2690     before, after = text.split(".")
2691     return f"{before or 0}.{after or 0}"
2692
2693
2694 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2695     """Make existing optional parentheses invisible or create new ones.
2696
2697     `parens_after` is a set of string leaf values immeditely after which parens
2698     should be put.
2699
2700     Standardizes on visible parentheses for single-element tuples, and keeps
2701     existing visible parentheses for other tuples and generator expressions.
2702     """
2703     for pc in list_comments(node.prefix, is_endmarker=False):
2704         if pc.value in FMT_OFF:
2705             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2706             return
2707
2708     check_lpar = False
2709     for index, child in enumerate(list(node.children)):
2710         if check_lpar:
2711             if child.type == syms.atom:
2712                 if maybe_make_parens_invisible_in_atom(child):
2713                     lpar = Leaf(token.LPAR, "")
2714                     rpar = Leaf(token.RPAR, "")
2715                     index = child.remove() or 0
2716                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2717             elif is_one_tuple(child):
2718                 # wrap child in visible parentheses
2719                 lpar = Leaf(token.LPAR, "(")
2720                 rpar = Leaf(token.RPAR, ")")
2721                 child.remove()
2722                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2723             elif node.type == syms.import_from:
2724                 # "import from" nodes store parentheses directly as part of
2725                 # the statement
2726                 if child.type == token.LPAR:
2727                     # make parentheses invisible
2728                     child.value = ""  # type: ignore
2729                     node.children[-1].value = ""  # type: ignore
2730                 elif child.type != token.STAR:
2731                     # insert invisible parentheses
2732                     node.insert_child(index, Leaf(token.LPAR, ""))
2733                     node.append_child(Leaf(token.RPAR, ""))
2734                 break
2735
2736             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2737                 # wrap child in invisible parentheses
2738                 lpar = Leaf(token.LPAR, "")
2739                 rpar = Leaf(token.RPAR, "")
2740                 index = child.remove() or 0
2741                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2742
2743         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2744
2745
2746 def normalize_fmt_off(node: Node) -> None:
2747     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2748     try_again = True
2749     while try_again:
2750         try_again = convert_one_fmt_off_pair(node)
2751
2752
2753 def convert_one_fmt_off_pair(node: Node) -> bool:
2754     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2755
2756     Returns True if a pair was converted.
2757     """
2758     for leaf in node.leaves():
2759         previous_consumed = 0
2760         for comment in list_comments(leaf.prefix, is_endmarker=False):
2761             if comment.value in FMT_OFF:
2762                 # We only want standalone comments. If there's no previous leaf or
2763                 # the previous leaf is indentation, it's a standalone comment in
2764                 # disguise.
2765                 if comment.type != STANDALONE_COMMENT:
2766                     prev = preceding_leaf(leaf)
2767                     if prev and prev.type not in WHITESPACE:
2768                         continue
2769
2770                 ignored_nodes = list(generate_ignored_nodes(leaf))
2771                 if not ignored_nodes:
2772                     continue
2773
2774                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2775                 parent = first.parent
2776                 prefix = first.prefix
2777                 first.prefix = prefix[comment.consumed :]
2778                 hidden_value = (
2779                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2780                 )
2781                 if hidden_value.endswith("\n"):
2782                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2783                     # leaf (possibly followed by a DEDENT).
2784                     hidden_value = hidden_value[:-1]
2785                 first_idx = None
2786                 for ignored in ignored_nodes:
2787                     index = ignored.remove()
2788                     if first_idx is None:
2789                         first_idx = index
2790                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2791                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2792                 parent.insert_child(
2793                     first_idx,
2794                     Leaf(
2795                         STANDALONE_COMMENT,
2796                         hidden_value,
2797                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2798                     ),
2799                 )
2800                 return True
2801
2802             previous_consumed = comment.consumed
2803
2804     return False
2805
2806
2807 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2808     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2809
2810     Stops at the end of the block.
2811     """
2812     container: Optional[LN] = container_of(leaf)
2813     while container is not None and container.type != token.ENDMARKER:
2814         for comment in list_comments(container.prefix, is_endmarker=False):
2815             if comment.value in FMT_ON:
2816                 return
2817
2818         yield container
2819
2820         container = container.next_sibling
2821
2822
2823 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2824     """If it's safe, make the parens in the atom `node` invisible, recursively.
2825
2826     Returns whether the node should itself be wrapped in invisible parentheses.
2827
2828     """
2829     if (
2830         node.type != syms.atom
2831         or is_empty_tuple(node)
2832         or is_one_tuple(node)
2833         or is_yield(node)
2834         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2835     ):
2836         return False
2837
2838     first = node.children[0]
2839     last = node.children[-1]
2840     if first.type == token.LPAR and last.type == token.RPAR:
2841         # make parentheses invisible
2842         first.value = ""  # type: ignore
2843         last.value = ""  # type: ignore
2844         if len(node.children) > 1:
2845             maybe_make_parens_invisible_in_atom(node.children[1])
2846         return False
2847
2848     return True
2849
2850
2851 def is_empty_tuple(node: LN) -> bool:
2852     """Return True if `node` holds an empty tuple."""
2853     return (
2854         node.type == syms.atom
2855         and len(node.children) == 2
2856         and node.children[0].type == token.LPAR
2857         and node.children[1].type == token.RPAR
2858     )
2859
2860
2861 def is_one_tuple(node: LN) -> bool:
2862     """Return True if `node` holds a tuple with one element, with or without parens."""
2863     if node.type == syms.atom:
2864         if len(node.children) != 3:
2865             return False
2866
2867         lpar, gexp, rpar = node.children
2868         if not (
2869             lpar.type == token.LPAR
2870             and gexp.type == syms.testlist_gexp
2871             and rpar.type == token.RPAR
2872         ):
2873             return False
2874
2875         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2876
2877     return (
2878         node.type in IMPLICIT_TUPLE
2879         and len(node.children) == 2
2880         and node.children[1].type == token.COMMA
2881     )
2882
2883
2884 def is_yield(node: LN) -> bool:
2885     """Return True if `node` holds a `yield` or `yield from` expression."""
2886     if node.type == syms.yield_expr:
2887         return True
2888
2889     if node.type == token.NAME and node.value == "yield":  # type: ignore
2890         return True
2891
2892     if node.type != syms.atom:
2893         return False
2894
2895     if len(node.children) != 3:
2896         return False
2897
2898     lpar, expr, rpar = node.children
2899     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2900         return is_yield(expr)
2901
2902     return False
2903
2904
2905 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2906     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2907
2908     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2909     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2910     extended iterable unpacking (PEP 3132) and additional unpacking
2911     generalizations (PEP 448).
2912     """
2913     if leaf.type not in STARS or not leaf.parent:
2914         return False
2915
2916     p = leaf.parent
2917     if p.type == syms.star_expr:
2918         # Star expressions are also used as assignment targets in extended
2919         # iterable unpacking (PEP 3132).  See what its parent is instead.
2920         if not p.parent:
2921             return False
2922
2923         p = p.parent
2924
2925     return p.type in within
2926
2927
2928 def is_multiline_string(leaf: Leaf) -> bool:
2929     """Return True if `leaf` is a multiline string that actually spans many lines."""
2930     value = leaf.value.lstrip("furbFURB")
2931     return value[:3] in {'"""', "'''"} and "\n" in value
2932
2933
2934 def is_stub_suite(node: Node) -> bool:
2935     """Return True if `node` is a suite with a stub body."""
2936     if (
2937         len(node.children) != 4
2938         or node.children[0].type != token.NEWLINE
2939         or node.children[1].type != token.INDENT
2940         or node.children[3].type != token.DEDENT
2941     ):
2942         return False
2943
2944     return is_stub_body(node.children[2])
2945
2946
2947 def is_stub_body(node: LN) -> bool:
2948     """Return True if `node` is a simple statement containing an ellipsis."""
2949     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2950         return False
2951
2952     if len(node.children) != 2:
2953         return False
2954
2955     child = node.children[0]
2956     return (
2957         child.type == syms.atom
2958         and len(child.children) == 3
2959         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2960     )
2961
2962
2963 def max_delimiter_priority_in_atom(node: LN) -> int:
2964     """Return maximum delimiter priority inside `node`.
2965
2966     This is specific to atoms with contents contained in a pair of parentheses.
2967     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2968     """
2969     if node.type != syms.atom:
2970         return 0
2971
2972     first = node.children[0]
2973     last = node.children[-1]
2974     if not (first.type == token.LPAR and last.type == token.RPAR):
2975         return 0
2976
2977     bt = BracketTracker()
2978     for c in node.children[1:-1]:
2979         if isinstance(c, Leaf):
2980             bt.mark(c)
2981         else:
2982             for leaf in c.leaves():
2983                 bt.mark(leaf)
2984     try:
2985         return bt.max_delimiter_priority()
2986
2987     except ValueError:
2988         return 0
2989
2990
2991 def ensure_visible(leaf: Leaf) -> None:
2992     """Make sure parentheses are visible.
2993
2994     They could be invisible as part of some statements (see
2995     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2996     """
2997     if leaf.type == token.LPAR:
2998         leaf.value = "("
2999     elif leaf.type == token.RPAR:
3000         leaf.value = ")"
3001
3002
3003 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3004     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3005
3006     if not (
3007         opening_bracket.parent
3008         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3009         and opening_bracket.value in "[{("
3010     ):
3011         return False
3012
3013     try:
3014         last_leaf = line.leaves[-1]
3015         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3016         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3017     except (IndexError, ValueError):
3018         return False
3019
3020     return max_priority == COMMA_PRIORITY
3021
3022
3023 def get_features_used(node: Node) -> Set[Feature]:
3024     """Return a set of (relatively) new Python features used in this file.
3025
3026     Currently looking for:
3027     - f-strings;
3028     - underscores in numeric literals; and
3029     - trailing commas after * or ** in function signatures and calls.
3030     """
3031     features: Set[Feature] = set()
3032     for n in node.pre_order():
3033         if n.type == token.STRING:
3034             value_head = n.value[:2]  # type: ignore
3035             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3036                 features.add(Feature.F_STRINGS)
3037
3038         elif n.type == token.NUMBER:
3039             if "_" in n.value:  # type: ignore
3040                 features.add(Feature.NUMERIC_UNDERSCORES)
3041
3042         elif (
3043             n.type in {syms.typedargslist, syms.arglist}
3044             and n.children
3045             and n.children[-1].type == token.COMMA
3046         ):
3047             for ch in n.children:
3048                 if ch.type in STARS:
3049                     features.add(Feature.TRAILING_COMMA)
3050
3051                 if ch.type == syms.argument:
3052                     for argch in ch.children:
3053                         if argch.type in STARS:
3054                             features.add(Feature.TRAILING_COMMA)
3055
3056     return features
3057
3058
3059 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3060     """Detect the version to target based on the nodes used."""
3061     features = get_features_used(node)
3062     return {
3063         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3064     }
3065
3066
3067 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3068     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3069
3070     Brackets can be omitted if the entire trailer up to and including
3071     a preceding closing bracket fits in one line.
3072
3073     Yielded sets are cumulative (contain results of previous yields, too).  First
3074     set is empty.
3075     """
3076
3077     omit: Set[LeafID] = set()
3078     yield omit
3079
3080     length = 4 * line.depth
3081     opening_bracket = None
3082     closing_bracket = None
3083     inner_brackets: Set[LeafID] = set()
3084     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3085         length += leaf_length
3086         if length > line_length:
3087             break
3088
3089         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3090         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3091             break
3092
3093         if opening_bracket:
3094             if leaf is opening_bracket:
3095                 opening_bracket = None
3096             elif leaf.type in CLOSING_BRACKETS:
3097                 inner_brackets.add(id(leaf))
3098         elif leaf.type in CLOSING_BRACKETS:
3099             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3100                 # Empty brackets would fail a split so treat them as "inner"
3101                 # brackets (e.g. only add them to the `omit` set if another
3102                 # pair of brackets was good enough.
3103                 inner_brackets.add(id(leaf))
3104                 continue
3105
3106             if closing_bracket:
3107                 omit.add(id(closing_bracket))
3108                 omit.update(inner_brackets)
3109                 inner_brackets.clear()
3110                 yield omit
3111
3112             if leaf.value:
3113                 opening_bracket = leaf.opening_bracket
3114                 closing_bracket = leaf
3115
3116
3117 def get_future_imports(node: Node) -> Set[str]:
3118     """Return a set of __future__ imports in the file."""
3119     imports: Set[str] = set()
3120
3121     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3122         for child in children:
3123             if isinstance(child, Leaf):
3124                 if child.type == token.NAME:
3125                     yield child.value
3126             elif child.type == syms.import_as_name:
3127                 orig_name = child.children[0]
3128                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3129                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3130                 yield orig_name.value
3131             elif child.type == syms.import_as_names:
3132                 yield from get_imports_from_children(child.children)
3133             else:
3134                 assert False, "Invalid syntax parsing imports"
3135
3136     for child in node.children:
3137         if child.type != syms.simple_stmt:
3138             break
3139         first_child = child.children[0]
3140         if isinstance(first_child, Leaf):
3141             # Continue looking if we see a docstring; otherwise stop.
3142             if (
3143                 len(child.children) == 2
3144                 and first_child.type == token.STRING
3145                 and child.children[1].type == token.NEWLINE
3146             ):
3147                 continue
3148             else:
3149                 break
3150         elif first_child.type == syms.import_from:
3151             module_name = first_child.children[1]
3152             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3153                 break
3154             imports |= set(get_imports_from_children(first_child.children[3:]))
3155         else:
3156             break
3157     return imports
3158
3159
3160 def gen_python_files_in_dir(
3161     path: Path,
3162     root: Path,
3163     include: Pattern[str],
3164     exclude: Pattern[str],
3165     report: "Report",
3166 ) -> Iterator[Path]:
3167     """Generate all files under `path` whose paths are not excluded by the
3168     `exclude` regex, but are included by the `include` regex.
3169
3170     Symbolic links pointing outside of the `root` directory are ignored.
3171
3172     `report` is where output about exclusions goes.
3173     """
3174     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3175     for child in path.iterdir():
3176         try:
3177             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3178         except ValueError:
3179             if child.is_symlink():
3180                 report.path_ignored(
3181                     child, f"is a symbolic link that points outside {root}"
3182                 )
3183                 continue
3184
3185             raise
3186
3187         if child.is_dir():
3188             normalized_path += "/"
3189         exclude_match = exclude.search(normalized_path)
3190         if exclude_match and exclude_match.group(0):
3191             report.path_ignored(child, f"matches the --exclude regular expression")
3192             continue
3193
3194         if child.is_dir():
3195             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3196
3197         elif child.is_file():
3198             include_match = include.search(normalized_path)
3199             if include_match:
3200                 yield child
3201
3202
3203 @lru_cache()
3204 def find_project_root(srcs: Iterable[str]) -> Path:
3205     """Return a directory containing .git, .hg, or pyproject.toml.
3206
3207     That directory can be one of the directories passed in `srcs` or their
3208     common parent.
3209
3210     If no directory in the tree contains a marker that would specify it's the
3211     project root, the root of the file system is returned.
3212     """
3213     if not srcs:
3214         return Path("/").resolve()
3215
3216     common_base = min(Path(src).resolve() for src in srcs)
3217     if common_base.is_dir():
3218         # Append a fake file so `parents` below returns `common_base_dir`, too.
3219         common_base /= "fake-file"
3220     for directory in common_base.parents:
3221         if (directory / ".git").is_dir():
3222             return directory
3223
3224         if (directory / ".hg").is_dir():
3225             return directory
3226
3227         if (directory / "pyproject.toml").is_file():
3228             return directory
3229
3230     return directory
3231
3232
3233 @dataclass
3234 class Report:
3235     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3236
3237     check: bool = False
3238     quiet: bool = False
3239     verbose: bool = False
3240     change_count: int = 0
3241     same_count: int = 0
3242     failure_count: int = 0
3243
3244     def done(self, src: Path, changed: Changed) -> None:
3245         """Increment the counter for successful reformatting. Write out a message."""
3246         if changed is Changed.YES:
3247             reformatted = "would reformat" if self.check else "reformatted"
3248             if self.verbose or not self.quiet:
3249                 out(f"{reformatted} {src}")
3250             self.change_count += 1
3251         else:
3252             if self.verbose:
3253                 if changed is Changed.NO:
3254                     msg = f"{src} already well formatted, good job."
3255                 else:
3256                     msg = f"{src} wasn't modified on disk since last run."
3257                 out(msg, bold=False)
3258             self.same_count += 1
3259
3260     def failed(self, src: Path, message: str) -> None:
3261         """Increment the counter for failed reformatting. Write out a message."""
3262         err(f"error: cannot format {src}: {message}")
3263         self.failure_count += 1
3264
3265     def path_ignored(self, path: Path, message: str) -> None:
3266         if self.verbose:
3267             out(f"{path} ignored: {message}", bold=False)
3268
3269     @property
3270     def return_code(self) -> int:
3271         """Return the exit code that the app should use.
3272
3273         This considers the current state of changed files and failures:
3274         - if there were any failures, return 123;
3275         - if any files were changed and --check is being used, return 1;
3276         - otherwise return 0.
3277         """
3278         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3279         # 126 we have special return codes reserved by the shell.
3280         if self.failure_count:
3281             return 123
3282
3283         elif self.change_count and self.check:
3284             return 1
3285
3286         return 0
3287
3288     def __str__(self) -> str:
3289         """Render a color report of the current state.
3290
3291         Use `click.unstyle` to remove colors.
3292         """
3293         if self.check:
3294             reformatted = "would be reformatted"
3295             unchanged = "would be left unchanged"
3296             failed = "would fail to reformat"
3297         else:
3298             reformatted = "reformatted"
3299             unchanged = "left unchanged"
3300             failed = "failed to reformat"
3301         report = []
3302         if self.change_count:
3303             s = "s" if self.change_count > 1 else ""
3304             report.append(
3305                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3306             )
3307         if self.same_count:
3308             s = "s" if self.same_count > 1 else ""
3309             report.append(f"{self.same_count} file{s} {unchanged}")
3310         if self.failure_count:
3311             s = "s" if self.failure_count > 1 else ""
3312             report.append(
3313                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3314             )
3315         return ", ".join(report) + "."
3316
3317
3318 def assert_equivalent(src: str, dst: str) -> None:
3319     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3320
3321     import ast
3322     import traceback
3323
3324     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3325         """Simple visitor generating strings to compare ASTs by content."""
3326         yield f"{'  ' * depth}{node.__class__.__name__}("
3327
3328         for field in sorted(node._fields):
3329             try:
3330                 value = getattr(node, field)
3331             except AttributeError:
3332                 continue
3333
3334             yield f"{'  ' * (depth+1)}{field}="
3335
3336             if isinstance(value, list):
3337                 for item in value:
3338                     # Ignore nested tuples within del statements, because we may insert
3339                     # parentheses and they change the AST.
3340                     if (
3341                         field == "targets"
3342                         and isinstance(node, ast.Delete)
3343                         and isinstance(item, ast.Tuple)
3344                     ):
3345                         for item in item.elts:
3346                             yield from _v(item, depth + 2)
3347                     elif isinstance(item, ast.AST):
3348                         yield from _v(item, depth + 2)
3349
3350             elif isinstance(value, ast.AST):
3351                 yield from _v(value, depth + 2)
3352
3353             else:
3354                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3355
3356         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3357
3358     try:
3359         src_ast = ast.parse(src)
3360     except Exception as exc:
3361         major, minor = sys.version_info[:2]
3362         raise AssertionError(
3363             f"cannot use --safe with this file; failed to parse source file "
3364             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3365             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3366         )
3367
3368     try:
3369         dst_ast = ast.parse(dst)
3370     except Exception as exc:
3371         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3372         raise AssertionError(
3373             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3374             f"Please report a bug on https://github.com/ambv/black/issues.  "
3375             f"This invalid output might be helpful: {log}"
3376         ) from None
3377
3378     src_ast_str = "\n".join(_v(src_ast))
3379     dst_ast_str = "\n".join(_v(dst_ast))
3380     if src_ast_str != dst_ast_str:
3381         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3382         raise AssertionError(
3383             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3384             f"the source.  "
3385             f"Please report a bug on https://github.com/ambv/black/issues.  "
3386             f"This diff might be helpful: {log}"
3387         ) from None
3388
3389
3390 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3391     """Raise AssertionError if `dst` reformats differently the second time."""
3392     newdst = format_str(dst, mode=mode)
3393     if dst != newdst:
3394         log = dump_to_file(
3395             diff(src, dst, "source", "first pass"),
3396             diff(dst, newdst, "first pass", "second pass"),
3397         )
3398         raise AssertionError(
3399             f"INTERNAL ERROR: Black produced different code on the second pass "
3400             f"of the formatter.  "
3401             f"Please report a bug on https://github.com/ambv/black/issues.  "
3402             f"This diff might be helpful: {log}"
3403         ) from None
3404
3405
3406 def dump_to_file(*output: str) -> str:
3407     """Dump `output` to a temporary file. Return path to the file."""
3408     import tempfile
3409
3410     with tempfile.NamedTemporaryFile(
3411         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3412     ) as f:
3413         for lines in output:
3414             f.write(lines)
3415             if lines and lines[-1] != "\n":
3416                 f.write("\n")
3417     return f.name
3418
3419
3420 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3421     """Return a unified diff string between strings `a` and `b`."""
3422     import difflib
3423
3424     a_lines = [line + "\n" for line in a.split("\n")]
3425     b_lines = [line + "\n" for line in b.split("\n")]
3426     return "".join(
3427         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3428     )
3429
3430
3431 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3432     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3433     err("Aborted!")
3434     for task in tasks:
3435         task.cancel()
3436
3437
3438 def shutdown(loop: BaseEventLoop) -> None:
3439     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3440     try:
3441         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3442         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3443         if not to_cancel:
3444             return
3445
3446         for task in to_cancel:
3447             task.cancel()
3448         loop.run_until_complete(
3449             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3450         )
3451     finally:
3452         # `concurrent.futures.Future` objects cannot be cancelled once they
3453         # are already running. There might be some when the `shutdown()` happened.
3454         # Silence their logger's spew about the event loop being closed.
3455         cf_logger = logging.getLogger("concurrent.futures")
3456         cf_logger.setLevel(logging.CRITICAL)
3457         loop.close()
3458
3459
3460 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3461     """Replace `regex` with `replacement` twice on `original`.
3462
3463     This is used by string normalization to perform replaces on
3464     overlapping matches.
3465     """
3466     return regex.sub(replacement, regex.sub(replacement, original))
3467
3468
3469 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3470     """Compile a regular expression string in `regex`.
3471
3472     If it contains newlines, use verbose mode.
3473     """
3474     if "\n" in regex:
3475         regex = "(?x)" + regex
3476     return re.compile(regex)
3477
3478
3479 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3480     """Like `reversed(enumerate(sequence))` if that were possible."""
3481     index = len(sequence) - 1
3482     for element in reversed(sequence):
3483         yield (index, element)
3484         index -= 1
3485
3486
3487 def enumerate_with_length(
3488     line: Line, reversed: bool = False
3489 ) -> Iterator[Tuple[Index, Leaf, int]]:
3490     """Return an enumeration of leaves with their length.
3491
3492     Stops prematurely on multiline strings and standalone comments.
3493     """
3494     op = cast(
3495         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3496         enumerate_reversed if reversed else enumerate,
3497     )
3498     for index, leaf in op(line.leaves):
3499         length = len(leaf.prefix) + len(leaf.value)
3500         if "\n" in leaf.value:
3501             return  # Multiline strings, we can't continue.
3502
3503         comment: Optional[Leaf]
3504         for comment in line.comments_after(leaf):
3505             length += len(comment.value)
3506
3507         yield index, leaf, length
3508
3509
3510 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3511     """Return True if `line` is no longer than `line_length`.
3512
3513     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3514     """
3515     if not line_str:
3516         line_str = str(line).strip("\n")
3517     return (
3518         len(line_str) <= line_length
3519         and "\n" not in line_str  # multiline strings
3520         and not line.contains_standalone_comments()
3521     )
3522
3523
3524 def can_be_split(line: Line) -> bool:
3525     """Return False if the line cannot be split *for sure*.
3526
3527     This is not an exhaustive search but a cheap heuristic that we can use to
3528     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3529     in unnecessary parentheses).
3530     """
3531     leaves = line.leaves
3532     if len(leaves) < 2:
3533         return False
3534
3535     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3536         call_count = 0
3537         dot_count = 0
3538         next = leaves[-1]
3539         for leaf in leaves[-2::-1]:
3540             if leaf.type in OPENING_BRACKETS:
3541                 if next.type not in CLOSING_BRACKETS:
3542                     return False
3543
3544                 call_count += 1
3545             elif leaf.type == token.DOT:
3546                 dot_count += 1
3547             elif leaf.type == token.NAME:
3548                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3549                     return False
3550
3551             elif leaf.type not in CLOSING_BRACKETS:
3552                 return False
3553
3554             if dot_count > 1 and call_count > 1:
3555                 return False
3556
3557     return True
3558
3559
3560 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3561     """Does `line` have a shape safe to reformat without optional parens around it?
3562
3563     Returns True for only a subset of potentially nice looking formattings but
3564     the point is to not return false positives that end up producing lines that
3565     are too long.
3566     """
3567     bt = line.bracket_tracker
3568     if not bt.delimiters:
3569         # Without delimiters the optional parentheses are useless.
3570         return True
3571
3572     max_priority = bt.max_delimiter_priority()
3573     if bt.delimiter_count_with_priority(max_priority) > 1:
3574         # With more than one delimiter of a kind the optional parentheses read better.
3575         return False
3576
3577     if max_priority == DOT_PRIORITY:
3578         # A single stranded method call doesn't require optional parentheses.
3579         return True
3580
3581     assert len(line.leaves) >= 2, "Stranded delimiter"
3582
3583     first = line.leaves[0]
3584     second = line.leaves[1]
3585     penultimate = line.leaves[-2]
3586     last = line.leaves[-1]
3587
3588     # With a single delimiter, omit if the expression starts or ends with
3589     # a bracket.
3590     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3591         remainder = False
3592         length = 4 * line.depth
3593         for _index, leaf, leaf_length in enumerate_with_length(line):
3594             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3595                 remainder = True
3596             if remainder:
3597                 length += leaf_length
3598                 if length > line_length:
3599                     break
3600
3601                 if leaf.type in OPENING_BRACKETS:
3602                     # There are brackets we can further split on.
3603                     remainder = False
3604
3605         else:
3606             # checked the entire string and line length wasn't exceeded
3607             if len(line.leaves) == _index + 1:
3608                 return True
3609
3610         # Note: we are not returning False here because a line might have *both*
3611         # a leading opening bracket and a trailing closing bracket.  If the
3612         # opening bracket doesn't match our rule, maybe the closing will.
3613
3614     if (
3615         last.type == token.RPAR
3616         or last.type == token.RBRACE
3617         or (
3618             # don't use indexing for omitting optional parentheses;
3619             # it looks weird
3620             last.type == token.RSQB
3621             and last.parent
3622             and last.parent.type != syms.trailer
3623         )
3624     ):
3625         if penultimate.type in OPENING_BRACKETS:
3626             # Empty brackets don't help.
3627             return False
3628
3629         if is_multiline_string(first):
3630             # Additional wrapping of a multiline string in this situation is
3631             # unnecessary.
3632             return True
3633
3634         length = 4 * line.depth
3635         seen_other_brackets = False
3636         for _index, leaf, leaf_length in enumerate_with_length(line):
3637             length += leaf_length
3638             if leaf is last.opening_bracket:
3639                 if seen_other_brackets or length <= line_length:
3640                     return True
3641
3642             elif leaf.type in OPENING_BRACKETS:
3643                 # There are brackets we can further split on.
3644                 seen_other_brackets = True
3645
3646     return False
3647
3648
3649 def get_cache_file(mode: FileMode) -> Path:
3650     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3651
3652
3653 def read_cache(mode: FileMode) -> Cache:
3654     """Read the cache if it exists and is well formed.
3655
3656     If it is not well formed, the call to write_cache later should resolve the issue.
3657     """
3658     cache_file = get_cache_file(mode)
3659     if not cache_file.exists():
3660         return {}
3661
3662     with cache_file.open("rb") as fobj:
3663         try:
3664             cache: Cache = pickle.load(fobj)
3665         except pickle.UnpicklingError:
3666             return {}
3667
3668     return cache
3669
3670
3671 def get_cache_info(path: Path) -> CacheInfo:
3672     """Return the information used to check if a file is already formatted or not."""
3673     stat = path.stat()
3674     return stat.st_mtime, stat.st_size
3675
3676
3677 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3678     """Split an iterable of paths in `sources` into two sets.
3679
3680     The first contains paths of files that modified on disk or are not in the
3681     cache. The other contains paths to non-modified files.
3682     """
3683     todo, done = set(), set()
3684     for src in sources:
3685         src = src.resolve()
3686         if cache.get(src) != get_cache_info(src):
3687             todo.add(src)
3688         else:
3689             done.add(src)
3690     return todo, done
3691
3692
3693 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3694     """Update the cache file."""
3695     cache_file = get_cache_file(mode)
3696     try:
3697         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3698         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3699         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3700             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3701         os.replace(f.name, cache_file)
3702     except OSError:
3703         pass
3704
3705
3706 def patch_click() -> None:
3707     """Make Click not crash.
3708
3709     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3710     default which restricts paths that it can access during the lifetime of the
3711     application.  Click refuses to work in this scenario by raising a RuntimeError.
3712
3713     In case of Black the likelihood that non-ASCII characters are going to be used in
3714     file paths is minimal since it's Python source code.  Moreover, this crash was
3715     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3716     """
3717     try:
3718         from click import core
3719         from click import _unicodefun  # type: ignore
3720     except ModuleNotFoundError:
3721         return
3722
3723     for module in (core, _unicodefun):
3724         if hasattr(module, "_verify_python3_env"):
3725             module._verify_python3_env = lambda: None
3726
3727
3728 def patched_main() -> None:
3729     freeze_support()
3730     patch_click()
3731     main()
3732
3733
3734 if __name__ == "__main__":
3735     patched_main()