]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

minor: remove wrong comment in .flake8 (#788)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "19.3b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA_IN_CALL = 4
137     TRAILING_COMMA_IN_DEF = 5
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.PY27: set(),
142     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
144     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
145     TargetVersion.PY36: {
146         Feature.UNICODE_LITERALS,
147         Feature.F_STRINGS,
148         Feature.NUMERIC_UNDERSCORES,
149         Feature.TRAILING_COMMA_IN_CALL,
150         Feature.TRAILING_COMMA_IN_DEF,
151     },
152     TargetVersion.PY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA_IN_CALL,
157         Feature.TRAILING_COMMA_IN_DEF,
158     },
159     TargetVersion.PY38: {
160         Feature.UNICODE_LITERALS,
161         Feature.F_STRINGS,
162         Feature.NUMERIC_UNDERSCORES,
163         Feature.TRAILING_COMMA_IN_CALL,
164         Feature.TRAILING_COMMA_IN_DEF,
165     },
166 }
167
168
169 @dataclass
170 class FileMode:
171     target_versions: Set[TargetVersion] = Factory(set)
172     line_length: int = DEFAULT_LINE_LENGTH
173     string_normalization: bool = True
174     is_pyi: bool = False
175
176     def get_cache_key(self) -> str:
177         if self.target_versions:
178             version_str = ",".join(
179                 str(version.value)
180                 for version in sorted(self.target_versions, key=lambda v: v.value)
181             )
182         else:
183             version_str = "-"
184         parts = [
185             version_str,
186             str(self.line_length),
187             str(int(self.string_normalization)),
188             str(int(self.is_pyi)),
189         ]
190         return ".".join(parts)
191
192
193 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
194     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
195
196
197 def read_pyproject_toml(
198     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
199 ) -> Optional[str]:
200     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
201
202     Returns the path to a successfully found and read configuration file, None
203     otherwise.
204     """
205     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
206     if not value:
207         root = find_project_root(ctx.params.get("src", ()))
208         path = root / "pyproject.toml"
209         if path.is_file():
210             value = str(path)
211         else:
212             return None
213
214     try:
215         pyproject_toml = toml.load(value)
216         config = pyproject_toml.get("tool", {}).get("black", {})
217     except (toml.TomlDecodeError, OSError) as e:
218         raise click.FileError(
219             filename=value, hint=f"Error reading configuration file: {e}"
220         )
221
222     if not config:
223         return None
224
225     if ctx.default_map is None:
226         ctx.default_map = {}
227     ctx.default_map.update(  # type: ignore  # bad types in .pyi
228         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
229     )
230     return value
231
232
233 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
234 @click.option(
235     "-l",
236     "--line-length",
237     type=int,
238     default=DEFAULT_LINE_LENGTH,
239     help="How many characters per line to allow.",
240     show_default=True,
241 )
242 @click.option(
243     "-t",
244     "--target-version",
245     type=click.Choice([v.name.lower() for v in TargetVersion]),
246     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
247     multiple=True,
248     help=(
249         "Python versions that should be supported by Black's output. [default: "
250         "per-file auto-detection]"
251     ),
252 )
253 @click.option(
254     "--py36",
255     is_flag=True,
256     help=(
257         "Allow using Python 3.6-only syntax on all input files.  This will put "
258         "trailing commas in function signatures and calls also after *args and "
259         "**kwargs. Deprecated; use --target-version instead. "
260         "[default: per-file auto-detection]"
261     ),
262 )
263 @click.option(
264     "--pyi",
265     is_flag=True,
266     help=(
267         "Format all input files like typing stubs regardless of file extension "
268         "(useful when piping source on standard input)."
269     ),
270 )
271 @click.option(
272     "-S",
273     "--skip-string-normalization",
274     is_flag=True,
275     help="Don't normalize string quotes or prefixes.",
276 )
277 @click.option(
278     "--check",
279     is_flag=True,
280     help=(
281         "Don't write the files back, just return the status.  Return code 0 "
282         "means nothing would change.  Return code 1 means some files would be "
283         "reformatted.  Return code 123 means there was an internal error."
284     ),
285 )
286 @click.option(
287     "--diff",
288     is_flag=True,
289     help="Don't write the files back, just output a diff for each file on stdout.",
290 )
291 @click.option(
292     "--fast/--safe",
293     is_flag=True,
294     help="If --fast given, skip temporary sanity checks. [default: --safe]",
295 )
296 @click.option(
297     "--include",
298     type=str,
299     default=DEFAULT_INCLUDES,
300     help=(
301         "A regular expression that matches files and directories that should be "
302         "included on recursive searches.  An empty value means all files are "
303         "included regardless of the name.  Use forward slashes for directories on "
304         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
305         "later."
306     ),
307     show_default=True,
308 )
309 @click.option(
310     "--exclude",
311     type=str,
312     default=DEFAULT_EXCLUDES,
313     help=(
314         "A regular expression that matches files and directories that should be "
315         "excluded on recursive searches.  An empty value means no paths are excluded. "
316         "Use forward slashes for directories on all platforms (Windows, too).  "
317         "Exclusions are calculated first, inclusions later."
318     ),
319     show_default=True,
320 )
321 @click.option(
322     "-q",
323     "--quiet",
324     is_flag=True,
325     help=(
326         "Don't emit non-error messages to stderr. Errors are still emitted, "
327         "silence those with 2>/dev/null."
328     ),
329 )
330 @click.option(
331     "-v",
332     "--verbose",
333     is_flag=True,
334     help=(
335         "Also emit messages to stderr about files that were not changed or were "
336         "ignored due to --exclude=."
337     ),
338 )
339 @click.version_option(version=__version__)
340 @click.argument(
341     "src",
342     nargs=-1,
343     type=click.Path(
344         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
345     ),
346     is_eager=True,
347 )
348 @click.option(
349     "--config",
350     type=click.Path(
351         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
352     ),
353     is_eager=True,
354     callback=read_pyproject_toml,
355     help="Read configuration from PATH.",
356 )
357 @click.pass_context
358 def main(
359     ctx: click.Context,
360     line_length: int,
361     target_version: List[TargetVersion],
362     check: bool,
363     diff: bool,
364     fast: bool,
365     pyi: bool,
366     py36: bool,
367     skip_string_normalization: bool,
368     quiet: bool,
369     verbose: bool,
370     include: str,
371     exclude: str,
372     src: Tuple[str],
373     config: Optional[str],
374 ) -> None:
375     """The uncompromising code formatter."""
376     write_back = WriteBack.from_configuration(check=check, diff=diff)
377     if target_version:
378         if py36:
379             err(f"Cannot use both --target-version and --py36")
380             ctx.exit(2)
381         else:
382             versions = set(target_version)
383     elif py36:
384         err(
385             "--py36 is deprecated and will be removed in a future version. "
386             "Use --target-version py36 instead."
387         )
388         versions = PY36_VERSIONS
389     else:
390         # We'll autodetect later.
391         versions = set()
392     mode = FileMode(
393         target_versions=versions,
394         line_length=line_length,
395         is_pyi=pyi,
396         string_normalization=not skip_string_normalization,
397     )
398     if config and verbose:
399         out(f"Using configuration from {config}.", bold=False, fg="blue")
400     try:
401         include_regex = re_compile_maybe_verbose(include)
402     except re.error:
403         err(f"Invalid regular expression for include given: {include!r}")
404         ctx.exit(2)
405     try:
406         exclude_regex = re_compile_maybe_verbose(exclude)
407     except re.error:
408         err(f"Invalid regular expression for exclude given: {exclude!r}")
409         ctx.exit(2)
410     report = Report(check=check, quiet=quiet, verbose=verbose)
411     root = find_project_root(src)
412     sources: Set[Path] = set()
413     for s in src:
414         p = Path(s)
415         if p.is_dir():
416             sources.update(
417                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
418             )
419         elif p.is_file() or s == "-":
420             # if a file was explicitly given, we don't care about its extension
421             sources.add(p)
422         else:
423             err(f"invalid path: {s}")
424     if len(sources) == 0:
425         if verbose or not quiet:
426             out("No paths given. Nothing to do 😴")
427         ctx.exit(0)
428
429     if len(sources) == 1:
430         reformat_one(
431             src=sources.pop(),
432             fast=fast,
433             write_back=write_back,
434             mode=mode,
435             report=report,
436         )
437     else:
438         loop = asyncio.get_event_loop()
439         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
440         try:
441             loop.run_until_complete(
442                 schedule_formatting(
443                     sources=sources,
444                     fast=fast,
445                     write_back=write_back,
446                     mode=mode,
447                     report=report,
448                     loop=loop,
449                     executor=executor,
450                 )
451             )
452         finally:
453             shutdown(loop)
454     if verbose or not quiet:
455         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
456         out(f"All done! {bang}")
457         click.secho(str(report), err=True)
458     ctx.exit(report.return_code)
459
460
461 def reformat_one(
462     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
463 ) -> None:
464     """Reformat a single file under `src` without spawning child processes.
465
466     If `quiet` is True, non-error messages are not output. `line_length`,
467     `write_back`, `fast` and `pyi` options are passed to
468     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
469     """
470     try:
471         changed = Changed.NO
472         if not src.is_file() and str(src) == "-":
473             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
474                 changed = Changed.YES
475         else:
476             cache: Cache = {}
477             if write_back != WriteBack.DIFF:
478                 cache = read_cache(mode)
479                 res_src = src.resolve()
480                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
481                     changed = Changed.CACHED
482             if changed is not Changed.CACHED and format_file_in_place(
483                 src, fast=fast, write_back=write_back, mode=mode
484             ):
485                 changed = Changed.YES
486             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
487                 write_back is WriteBack.CHECK and changed is Changed.NO
488             ):
489                 write_cache(cache, [src], mode)
490         report.done(src, changed)
491     except Exception as exc:
492         report.failed(src, str(exc))
493
494
495 async def schedule_formatting(
496     sources: Set[Path],
497     fast: bool,
498     write_back: WriteBack,
499     mode: FileMode,
500     report: "Report",
501     loop: BaseEventLoop,
502     executor: Executor,
503 ) -> None:
504     """Run formatting of `sources` in parallel using the provided `executor`.
505
506     (Use ProcessPoolExecutors for actual parallelism.)
507
508     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
509     :func:`format_file_in_place`.
510     """
511     cache: Cache = {}
512     if write_back != WriteBack.DIFF:
513         cache = read_cache(mode)
514         sources, cached = filter_cached(cache, sources)
515         for src in sorted(cached):
516             report.done(src, Changed.CACHED)
517     if not sources:
518         return
519
520     cancelled = []
521     sources_to_cache = []
522     lock = None
523     if write_back == WriteBack.DIFF:
524         # For diff output, we need locks to ensure we don't interleave output
525         # from different processes.
526         manager = Manager()
527         lock = manager.Lock()
528     tasks = {
529         loop.run_in_executor(
530             executor, format_file_in_place, src, fast, mode, write_back, lock
531         ): src
532         for src in sorted(sources)
533     }
534     pending: Iterable[asyncio.Task] = tasks.keys()
535     try:
536         loop.add_signal_handler(signal.SIGINT, cancel, pending)
537         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
538     except NotImplementedError:
539         # There are no good alternatives for these on Windows.
540         pass
541     while pending:
542         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
543         for task in done:
544             src = tasks.pop(task)
545             if task.cancelled():
546                 cancelled.append(task)
547             elif task.exception():
548                 report.failed(src, str(task.exception()))
549             else:
550                 changed = Changed.YES if task.result() else Changed.NO
551                 # If the file was written back or was successfully checked as
552                 # well-formatted, store this information in the cache.
553                 if write_back is WriteBack.YES or (
554                     write_back is WriteBack.CHECK and changed is Changed.NO
555                 ):
556                     sources_to_cache.append(src)
557                 report.done(src, changed)
558     if cancelled:
559         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
560     if sources_to_cache:
561         write_cache(cache, sources_to_cache, mode)
562
563
564 def format_file_in_place(
565     src: Path,
566     fast: bool,
567     mode: FileMode,
568     write_back: WriteBack = WriteBack.NO,
569     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
570 ) -> bool:
571     """Format file under `src` path. Return True if changed.
572
573     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
574     code to the file.
575     `line_length` and `fast` options are passed to :func:`format_file_contents`.
576     """
577     if src.suffix == ".pyi":
578         mode = evolve(mode, is_pyi=True)
579
580     then = datetime.utcfromtimestamp(src.stat().st_mtime)
581     with open(src, "rb") as buf:
582         src_contents, encoding, newline = decode_bytes(buf.read())
583     try:
584         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
585     except NothingChanged:
586         return False
587
588     if write_back == write_back.YES:
589         with open(src, "w", encoding=encoding, newline=newline) as f:
590             f.write(dst_contents)
591     elif write_back == write_back.DIFF:
592         now = datetime.utcnow()
593         src_name = f"{src}\t{then} +0000"
594         dst_name = f"{src}\t{now} +0000"
595         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
596         if lock:
597             lock.acquire()
598         try:
599             f = io.TextIOWrapper(
600                 sys.stdout.buffer,
601                 encoding=encoding,
602                 newline=newline,
603                 write_through=True,
604             )
605             f.write(diff_contents)
606             f.detach()
607         finally:
608             if lock:
609                 lock.release()
610     return True
611
612
613 def format_stdin_to_stdout(
614     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
615 ) -> bool:
616     """Format file on stdin. Return True if changed.
617
618     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
619     write a diff to stdout. The `mode` argument is passed to
620     :func:`format_file_contents`.
621     """
622     then = datetime.utcnow()
623     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
624     dst = src
625     try:
626         dst = format_file_contents(src, fast=fast, mode=mode)
627         return True
628
629     except NothingChanged:
630         return False
631
632     finally:
633         f = io.TextIOWrapper(
634             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
635         )
636         if write_back == WriteBack.YES:
637             f.write(dst)
638         elif write_back == WriteBack.DIFF:
639             now = datetime.utcnow()
640             src_name = f"STDIN\t{then} +0000"
641             dst_name = f"STDOUT\t{now} +0000"
642             f.write(diff(src, dst, src_name, dst_name))
643         f.detach()
644
645
646 def format_file_contents(
647     src_contents: str, *, fast: bool, mode: FileMode
648 ) -> FileContent:
649     """Reformat contents a file and return new contents.
650
651     If `fast` is False, additionally confirm that the reformatted code is
652     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
653     `line_length` is passed to :func:`format_str`.
654     """
655     if src_contents.strip() == "":
656         raise NothingChanged
657
658     dst_contents = format_str(src_contents, mode=mode)
659     if src_contents == dst_contents:
660         raise NothingChanged
661
662     if not fast:
663         assert_equivalent(src_contents, dst_contents)
664         assert_stable(src_contents, dst_contents, mode=mode)
665     return dst_contents
666
667
668 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
669     """Reformat a string and return new contents.
670
671     `line_length` determines how many characters per line are allowed.
672     """
673     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
674     dst_contents = ""
675     future_imports = get_future_imports(src_node)
676     if mode.target_versions:
677         versions = mode.target_versions
678     else:
679         versions = detect_target_versions(src_node)
680     normalize_fmt_off(src_node)
681     lines = LineGenerator(
682         remove_u_prefix="unicode_literals" in future_imports
683         or supports_feature(versions, Feature.UNICODE_LITERALS),
684         is_pyi=mode.is_pyi,
685         normalize_strings=mode.string_normalization,
686     )
687     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
688     empty_line = Line()
689     after = 0
690     split_line_features = {
691         feature
692         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
693         if supports_feature(versions, feature)
694     }
695     for current_line in lines.visit(src_node):
696         for _ in range(after):
697             dst_contents += str(empty_line)
698         before, after = elt.maybe_empty_lines(current_line)
699         for _ in range(before):
700             dst_contents += str(empty_line)
701         for line in split_line(
702             current_line, line_length=mode.line_length, features=split_line_features
703         ):
704             dst_contents += str(line)
705     return dst_contents
706
707
708 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
709     """Return a tuple of (decoded_contents, encoding, newline).
710
711     `newline` is either CRLF or LF but `decoded_contents` is decoded with
712     universal newlines (i.e. only contains LF).
713     """
714     srcbuf = io.BytesIO(src)
715     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
716     if not lines:
717         return "", encoding, "\n"
718
719     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
720     srcbuf.seek(0)
721     with io.TextIOWrapper(srcbuf, encoding) as tiow:
722         return tiow.read(), encoding, newline
723
724
725 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
726     if not target_versions:
727         # No target_version specified, so try all grammars.
728         return [
729             pygram.python_grammar_no_print_statement_no_exec_statement,
730             pygram.python_grammar_no_print_statement,
731             pygram.python_grammar,
732         ]
733     elif all(version.is_python2() for version in target_versions):
734         # Python 2-only code, so try Python 2 grammars.
735         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
736     else:
737         # Python 3-compatible code, so only try Python 3 grammar.
738         return [pygram.python_grammar_no_print_statement_no_exec_statement]
739
740
741 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
742     """Given a string with source, return the lib2to3 Node."""
743     if src_txt[-1:] != "\n":
744         src_txt += "\n"
745
746     for grammar in get_grammars(set(target_versions)):
747         drv = driver.Driver(grammar, pytree.convert)
748         try:
749             result = drv.parse_string(src_txt, True)
750             break
751
752         except ParseError as pe:
753             lineno, column = pe.context[1]
754             lines = src_txt.splitlines()
755             try:
756                 faulty_line = lines[lineno - 1]
757             except IndexError:
758                 faulty_line = "<line number missing in source>"
759             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
760     else:
761         raise exc from None
762
763     if isinstance(result, Leaf):
764         result = Node(syms.file_input, [result])
765     return result
766
767
768 def lib2to3_unparse(node: Node) -> str:
769     """Given a lib2to3 node, return its string representation."""
770     code = str(node)
771     return code
772
773
774 T = TypeVar("T")
775
776
777 class Visitor(Generic[T]):
778     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
779
780     def visit(self, node: LN) -> Iterator[T]:
781         """Main method to visit `node` and its children.
782
783         It tries to find a `visit_*()` method for the given `node.type`, like
784         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
785         If no dedicated `visit_*()` method is found, chooses `visit_default()`
786         instead.
787
788         Then yields objects of type `T` from the selected visitor.
789         """
790         if node.type < 256:
791             name = token.tok_name[node.type]
792         else:
793             name = type_repr(node.type)
794         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
795
796     def visit_default(self, node: LN) -> Iterator[T]:
797         """Default `visit_*()` implementation. Recurses to children of `node`."""
798         if isinstance(node, Node):
799             for child in node.children:
800                 yield from self.visit(child)
801
802
803 @dataclass
804 class DebugVisitor(Visitor[T]):
805     tree_depth: int = 0
806
807     def visit_default(self, node: LN) -> Iterator[T]:
808         indent = " " * (2 * self.tree_depth)
809         if isinstance(node, Node):
810             _type = type_repr(node.type)
811             out(f"{indent}{_type}", fg="yellow")
812             self.tree_depth += 1
813             for child in node.children:
814                 yield from self.visit(child)
815
816             self.tree_depth -= 1
817             out(f"{indent}/{_type}", fg="yellow", bold=False)
818         else:
819             _type = token.tok_name.get(node.type, str(node.type))
820             out(f"{indent}{_type}", fg="blue", nl=False)
821             if node.prefix:
822                 # We don't have to handle prefixes for `Node` objects since
823                 # that delegates to the first child anyway.
824                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
825             out(f" {node.value!r}", fg="blue", bold=False)
826
827     @classmethod
828     def show(cls, code: Union[str, Leaf, Node]) -> None:
829         """Pretty-print the lib2to3 AST of a given string of `code`.
830
831         Convenience method for debugging.
832         """
833         v: DebugVisitor[None] = DebugVisitor()
834         if isinstance(code, str):
835             code = lib2to3_parse(code)
836         list(v.visit(code))
837
838
839 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
840 STATEMENT = {
841     syms.if_stmt,
842     syms.while_stmt,
843     syms.for_stmt,
844     syms.try_stmt,
845     syms.except_clause,
846     syms.with_stmt,
847     syms.funcdef,
848     syms.classdef,
849 }
850 STANDALONE_COMMENT = 153
851 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
852 LOGIC_OPERATORS = {"and", "or"}
853 COMPARATORS = {
854     token.LESS,
855     token.GREATER,
856     token.EQEQUAL,
857     token.NOTEQUAL,
858     token.LESSEQUAL,
859     token.GREATEREQUAL,
860 }
861 MATH_OPERATORS = {
862     token.VBAR,
863     token.CIRCUMFLEX,
864     token.AMPER,
865     token.LEFTSHIFT,
866     token.RIGHTSHIFT,
867     token.PLUS,
868     token.MINUS,
869     token.STAR,
870     token.SLASH,
871     token.DOUBLESLASH,
872     token.PERCENT,
873     token.AT,
874     token.TILDE,
875     token.DOUBLESTAR,
876 }
877 STARS = {token.STAR, token.DOUBLESTAR}
878 VARARGS_PARENTS = {
879     syms.arglist,
880     syms.argument,  # double star in arglist
881     syms.trailer,  # single argument to call
882     syms.typedargslist,
883     syms.varargslist,  # lambdas
884 }
885 UNPACKING_PARENTS = {
886     syms.atom,  # single element of a list or set literal
887     syms.dictsetmaker,
888     syms.listmaker,
889     syms.testlist_gexp,
890     syms.testlist_star_expr,
891 }
892 TEST_DESCENDANTS = {
893     syms.test,
894     syms.lambdef,
895     syms.or_test,
896     syms.and_test,
897     syms.not_test,
898     syms.comparison,
899     syms.star_expr,
900     syms.expr,
901     syms.xor_expr,
902     syms.and_expr,
903     syms.shift_expr,
904     syms.arith_expr,
905     syms.trailer,
906     syms.term,
907     syms.power,
908 }
909 ASSIGNMENTS = {
910     "=",
911     "+=",
912     "-=",
913     "*=",
914     "@=",
915     "/=",
916     "%=",
917     "&=",
918     "|=",
919     "^=",
920     "<<=",
921     ">>=",
922     "**=",
923     "//=",
924 }
925 COMPREHENSION_PRIORITY = 20
926 COMMA_PRIORITY = 18
927 TERNARY_PRIORITY = 16
928 LOGIC_PRIORITY = 14
929 STRING_PRIORITY = 12
930 COMPARATOR_PRIORITY = 10
931 MATH_PRIORITIES = {
932     token.VBAR: 9,
933     token.CIRCUMFLEX: 8,
934     token.AMPER: 7,
935     token.LEFTSHIFT: 6,
936     token.RIGHTSHIFT: 6,
937     token.PLUS: 5,
938     token.MINUS: 5,
939     token.STAR: 4,
940     token.SLASH: 4,
941     token.DOUBLESLASH: 4,
942     token.PERCENT: 4,
943     token.AT: 4,
944     token.TILDE: 3,
945     token.DOUBLESTAR: 2,
946 }
947 DOT_PRIORITY = 1
948
949
950 @dataclass
951 class BracketTracker:
952     """Keeps track of brackets on a line."""
953
954     depth: int = 0
955     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
956     delimiters: Dict[LeafID, Priority] = Factory(dict)
957     previous: Optional[Leaf] = None
958     _for_loop_depths: List[int] = Factory(list)
959     _lambda_argument_depths: List[int] = Factory(list)
960
961     def mark(self, leaf: Leaf) -> None:
962         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
963
964         All leaves receive an int `bracket_depth` field that stores how deep
965         within brackets a given leaf is. 0 means there are no enclosing brackets
966         that started on this line.
967
968         If a leaf is itself a closing bracket, it receives an `opening_bracket`
969         field that it forms a pair with. This is a one-directional link to
970         avoid reference cycles.
971
972         If a leaf is a delimiter (a token on which Black can split the line if
973         needed) and it's on depth 0, its `id()` is stored in the tracker's
974         `delimiters` field.
975         """
976         if leaf.type == token.COMMENT:
977             return
978
979         self.maybe_decrement_after_for_loop_variable(leaf)
980         self.maybe_decrement_after_lambda_arguments(leaf)
981         if leaf.type in CLOSING_BRACKETS:
982             self.depth -= 1
983             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
984             leaf.opening_bracket = opening_bracket
985         leaf.bracket_depth = self.depth
986         if self.depth == 0:
987             delim = is_split_before_delimiter(leaf, self.previous)
988             if delim and self.previous is not None:
989                 self.delimiters[id(self.previous)] = delim
990             else:
991                 delim = is_split_after_delimiter(leaf, self.previous)
992                 if delim:
993                     self.delimiters[id(leaf)] = delim
994         if leaf.type in OPENING_BRACKETS:
995             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
996             self.depth += 1
997         self.previous = leaf
998         self.maybe_increment_lambda_arguments(leaf)
999         self.maybe_increment_for_loop_variable(leaf)
1000
1001     def any_open_brackets(self) -> bool:
1002         """Return True if there is an yet unmatched open bracket on the line."""
1003         return bool(self.bracket_match)
1004
1005     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1006         """Return the highest priority of a delimiter found on the line.
1007
1008         Values are consistent with what `is_split_*_delimiter()` return.
1009         Raises ValueError on no delimiters.
1010         """
1011         return max(v for k, v in self.delimiters.items() if k not in exclude)
1012
1013     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1014         """Return the number of delimiters with the given `priority`.
1015
1016         If no `priority` is passed, defaults to max priority on the line.
1017         """
1018         if not self.delimiters:
1019             return 0
1020
1021         priority = priority or self.max_delimiter_priority()
1022         return sum(1 for p in self.delimiters.values() if p == priority)
1023
1024     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1025         """In a for loop, or comprehension, the variables are often unpacks.
1026
1027         To avoid splitting on the comma in this situation, increase the depth of
1028         tokens between `for` and `in`.
1029         """
1030         if leaf.type == token.NAME and leaf.value == "for":
1031             self.depth += 1
1032             self._for_loop_depths.append(self.depth)
1033             return True
1034
1035         return False
1036
1037     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1038         """See `maybe_increment_for_loop_variable` above for explanation."""
1039         if (
1040             self._for_loop_depths
1041             and self._for_loop_depths[-1] == self.depth
1042             and leaf.type == token.NAME
1043             and leaf.value == "in"
1044         ):
1045             self.depth -= 1
1046             self._for_loop_depths.pop()
1047             return True
1048
1049         return False
1050
1051     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1052         """In a lambda expression, there might be more than one argument.
1053
1054         To avoid splitting on the comma in this situation, increase the depth of
1055         tokens between `lambda` and `:`.
1056         """
1057         if leaf.type == token.NAME and leaf.value == "lambda":
1058             self.depth += 1
1059             self._lambda_argument_depths.append(self.depth)
1060             return True
1061
1062         return False
1063
1064     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1065         """See `maybe_increment_lambda_arguments` above for explanation."""
1066         if (
1067             self._lambda_argument_depths
1068             and self._lambda_argument_depths[-1] == self.depth
1069             and leaf.type == token.COLON
1070         ):
1071             self.depth -= 1
1072             self._lambda_argument_depths.pop()
1073             return True
1074
1075         return False
1076
1077     def get_open_lsqb(self) -> Optional[Leaf]:
1078         """Return the most recent opening square bracket (if any)."""
1079         return self.bracket_match.get((self.depth - 1, token.RSQB))
1080
1081
1082 @dataclass
1083 class Line:
1084     """Holds leaves and comments. Can be printed with `str(line)`."""
1085
1086     depth: int = 0
1087     leaves: List[Leaf] = Factory(list)
1088     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1089     bracket_tracker: BracketTracker = Factory(BracketTracker)
1090     inside_brackets: bool = False
1091     should_explode: bool = False
1092
1093     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1094         """Add a new `leaf` to the end of the line.
1095
1096         Unless `preformatted` is True, the `leaf` will receive a new consistent
1097         whitespace prefix and metadata applied by :class:`BracketTracker`.
1098         Trailing commas are maybe removed, unpacked for loop variables are
1099         demoted from being delimiters.
1100
1101         Inline comments are put aside.
1102         """
1103         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1104         if not has_value:
1105             return
1106
1107         if token.COLON == leaf.type and self.is_class_paren_empty:
1108             del self.leaves[-2:]
1109         if self.leaves and not preformatted:
1110             # Note: at this point leaf.prefix should be empty except for
1111             # imports, for which we only preserve newlines.
1112             leaf.prefix += whitespace(
1113                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1114             )
1115         if self.inside_brackets or not preformatted:
1116             self.bracket_tracker.mark(leaf)
1117             self.maybe_remove_trailing_comma(leaf)
1118         if not self.append_comment(leaf):
1119             self.leaves.append(leaf)
1120
1121     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1122         """Like :func:`append()` but disallow invalid standalone comment structure.
1123
1124         Raises ValueError when any `leaf` is appended after a standalone comment
1125         or when a standalone comment is not the first leaf on the line.
1126         """
1127         if self.bracket_tracker.depth == 0:
1128             if self.is_comment:
1129                 raise ValueError("cannot append to standalone comments")
1130
1131             if self.leaves and leaf.type == STANDALONE_COMMENT:
1132                 raise ValueError(
1133                     "cannot append standalone comments to a populated line"
1134                 )
1135
1136         self.append(leaf, preformatted=preformatted)
1137
1138     @property
1139     def is_comment(self) -> bool:
1140         """Is this line a standalone comment?"""
1141         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1142
1143     @property
1144     def is_decorator(self) -> bool:
1145         """Is this line a decorator?"""
1146         return bool(self) and self.leaves[0].type == token.AT
1147
1148     @property
1149     def is_import(self) -> bool:
1150         """Is this an import line?"""
1151         return bool(self) and is_import(self.leaves[0])
1152
1153     @property
1154     def is_class(self) -> bool:
1155         """Is this line a class definition?"""
1156         return (
1157             bool(self)
1158             and self.leaves[0].type == token.NAME
1159             and self.leaves[0].value == "class"
1160         )
1161
1162     @property
1163     def is_stub_class(self) -> bool:
1164         """Is this line a class definition with a body consisting only of "..."?"""
1165         return self.is_class and self.leaves[-3:] == [
1166             Leaf(token.DOT, ".") for _ in range(3)
1167         ]
1168
1169     @property
1170     def is_def(self) -> bool:
1171         """Is this a function definition? (Also returns True for async defs.)"""
1172         try:
1173             first_leaf = self.leaves[0]
1174         except IndexError:
1175             return False
1176
1177         try:
1178             second_leaf: Optional[Leaf] = self.leaves[1]
1179         except IndexError:
1180             second_leaf = None
1181         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1182             first_leaf.type == token.ASYNC
1183             and second_leaf is not None
1184             and second_leaf.type == token.NAME
1185             and second_leaf.value == "def"
1186         )
1187
1188     @property
1189     def is_class_paren_empty(self) -> bool:
1190         """Is this a class with no base classes but using parentheses?
1191
1192         Those are unnecessary and should be removed.
1193         """
1194         return (
1195             bool(self)
1196             and len(self.leaves) == 4
1197             and self.is_class
1198             and self.leaves[2].type == token.LPAR
1199             and self.leaves[2].value == "("
1200             and self.leaves[3].type == token.RPAR
1201             and self.leaves[3].value == ")"
1202         )
1203
1204     @property
1205     def is_triple_quoted_string(self) -> bool:
1206         """Is the line a triple quoted string?"""
1207         return (
1208             bool(self)
1209             and self.leaves[0].type == token.STRING
1210             and self.leaves[0].value.startswith(('"""', "'''"))
1211         )
1212
1213     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1214         """If so, needs to be split before emitting."""
1215         for leaf in self.leaves:
1216             if leaf.type == STANDALONE_COMMENT:
1217                 if leaf.bracket_depth <= depth_limit:
1218                     return True
1219         return False
1220
1221     def contains_inner_type_comments(self) -> bool:
1222         ignored_ids = set()
1223         try:
1224             last_leaf = self.leaves[-1]
1225             ignored_ids.add(id(last_leaf))
1226             if last_leaf.type == token.COMMA:
1227                 # When trailing commas are inserted by Black for consistency, comments
1228                 # after the previous last element are not moved (they don't have to,
1229                 # rendering will still be correct).  So we ignore trailing commas.
1230                 last_leaf = self.leaves[-2]
1231                 ignored_ids.add(id(last_leaf))
1232         except IndexError:
1233             return False
1234
1235         for leaf_id, comments in self.comments.items():
1236             if leaf_id in ignored_ids:
1237                 continue
1238
1239             for comment in comments:
1240                 if is_type_comment(comment):
1241                     return True
1242
1243         return False
1244
1245     def contains_multiline_strings(self) -> bool:
1246         for leaf in self.leaves:
1247             if is_multiline_string(leaf):
1248                 return True
1249
1250         return False
1251
1252     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1253         """Remove trailing comma if there is one and it's safe."""
1254         if not (
1255             self.leaves
1256             and self.leaves[-1].type == token.COMMA
1257             and closing.type in CLOSING_BRACKETS
1258         ):
1259             return False
1260
1261         if closing.type == token.RBRACE:
1262             self.remove_trailing_comma()
1263             return True
1264
1265         if closing.type == token.RSQB:
1266             comma = self.leaves[-1]
1267             if comma.parent and comma.parent.type == syms.listmaker:
1268                 self.remove_trailing_comma()
1269                 return True
1270
1271         # For parens let's check if it's safe to remove the comma.
1272         # Imports are always safe.
1273         if self.is_import:
1274             self.remove_trailing_comma()
1275             return True
1276
1277         # Otherwise, if the trailing one is the only one, we might mistakenly
1278         # change a tuple into a different type by removing the comma.
1279         depth = closing.bracket_depth + 1
1280         commas = 0
1281         opening = closing.opening_bracket
1282         for _opening_index, leaf in enumerate(self.leaves):
1283             if leaf is opening:
1284                 break
1285
1286         else:
1287             return False
1288
1289         for leaf in self.leaves[_opening_index + 1 :]:
1290             if leaf is closing:
1291                 break
1292
1293             bracket_depth = leaf.bracket_depth
1294             if bracket_depth == depth and leaf.type == token.COMMA:
1295                 commas += 1
1296                 if leaf.parent and leaf.parent.type == syms.arglist:
1297                     commas += 1
1298                     break
1299
1300         if commas > 1:
1301             self.remove_trailing_comma()
1302             return True
1303
1304         return False
1305
1306     def append_comment(self, comment: Leaf) -> bool:
1307         """Add an inline or standalone comment to the line."""
1308         if (
1309             comment.type == STANDALONE_COMMENT
1310             and self.bracket_tracker.any_open_brackets()
1311         ):
1312             comment.prefix = ""
1313             return False
1314
1315         if comment.type != token.COMMENT:
1316             return False
1317
1318         if not self.leaves:
1319             comment.type = STANDALONE_COMMENT
1320             comment.prefix = ""
1321             return False
1322
1323         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1324         return True
1325
1326     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1327         """Generate comments that should appear directly after `leaf`."""
1328         return self.comments.get(id(leaf), [])
1329
1330     def remove_trailing_comma(self) -> None:
1331         """Remove the trailing comma and moves the comments attached to it."""
1332         trailing_comma = self.leaves.pop()
1333         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1334         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1335             trailing_comma_comments
1336         )
1337
1338     def is_complex_subscript(self, leaf: Leaf) -> bool:
1339         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1340         open_lsqb = self.bracket_tracker.get_open_lsqb()
1341         if open_lsqb is None:
1342             return False
1343
1344         subscript_start = open_lsqb.next_sibling
1345
1346         if isinstance(subscript_start, Node):
1347             if subscript_start.type == syms.listmaker:
1348                 return False
1349
1350             if subscript_start.type == syms.subscriptlist:
1351                 subscript_start = child_towards(subscript_start, leaf)
1352         return subscript_start is not None and any(
1353             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1354         )
1355
1356     def __str__(self) -> str:
1357         """Render the line."""
1358         if not self:
1359             return "\n"
1360
1361         indent = "    " * self.depth
1362         leaves = iter(self.leaves)
1363         first = next(leaves)
1364         res = f"{first.prefix}{indent}{first.value}"
1365         for leaf in leaves:
1366             res += str(leaf)
1367         for comment in itertools.chain.from_iterable(self.comments.values()):
1368             res += str(comment)
1369         return res + "\n"
1370
1371     def __bool__(self) -> bool:
1372         """Return True if the line has leaves or comments."""
1373         return bool(self.leaves or self.comments)
1374
1375
1376 @dataclass
1377 class EmptyLineTracker:
1378     """Provides a stateful method that returns the number of potential extra
1379     empty lines needed before and after the currently processed line.
1380
1381     Note: this tracker works on lines that haven't been split yet.  It assumes
1382     the prefix of the first leaf consists of optional newlines.  Those newlines
1383     are consumed by `maybe_empty_lines()` and included in the computation.
1384     """
1385
1386     is_pyi: bool = False
1387     previous_line: Optional[Line] = None
1388     previous_after: int = 0
1389     previous_defs: List[int] = Factory(list)
1390
1391     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1392         """Return the number of extra empty lines before and after the `current_line`.
1393
1394         This is for separating `def`, `async def` and `class` with extra empty
1395         lines (two on module-level).
1396         """
1397         before, after = self._maybe_empty_lines(current_line)
1398         before -= self.previous_after
1399         self.previous_after = after
1400         self.previous_line = current_line
1401         return before, after
1402
1403     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1404         max_allowed = 1
1405         if current_line.depth == 0:
1406             max_allowed = 1 if self.is_pyi else 2
1407         if current_line.leaves:
1408             # Consume the first leaf's extra newlines.
1409             first_leaf = current_line.leaves[0]
1410             before = first_leaf.prefix.count("\n")
1411             before = min(before, max_allowed)
1412             first_leaf.prefix = ""
1413         else:
1414             before = 0
1415         depth = current_line.depth
1416         while self.previous_defs and self.previous_defs[-1] >= depth:
1417             self.previous_defs.pop()
1418             if self.is_pyi:
1419                 before = 0 if depth else 1
1420             else:
1421                 before = 1 if depth else 2
1422         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1423             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1424
1425         if (
1426             self.previous_line
1427             and self.previous_line.is_import
1428             and not current_line.is_import
1429             and depth == self.previous_line.depth
1430         ):
1431             return (before or 1), 0
1432
1433         if (
1434             self.previous_line
1435             and self.previous_line.is_class
1436             and current_line.is_triple_quoted_string
1437         ):
1438             return before, 1
1439
1440         return before, 0
1441
1442     def _maybe_empty_lines_for_class_or_def(
1443         self, current_line: Line, before: int
1444     ) -> Tuple[int, int]:
1445         if not current_line.is_decorator:
1446             self.previous_defs.append(current_line.depth)
1447         if self.previous_line is None:
1448             # Don't insert empty lines before the first line in the file.
1449             return 0, 0
1450
1451         if self.previous_line.is_decorator:
1452             return 0, 0
1453
1454         if self.previous_line.depth < current_line.depth and (
1455             self.previous_line.is_class or self.previous_line.is_def
1456         ):
1457             return 0, 0
1458
1459         if (
1460             self.previous_line.is_comment
1461             and self.previous_line.depth == current_line.depth
1462             and before == 0
1463         ):
1464             return 0, 0
1465
1466         if self.is_pyi:
1467             if self.previous_line.depth > current_line.depth:
1468                 newlines = 1
1469             elif current_line.is_class or self.previous_line.is_class:
1470                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1471                     # No blank line between classes with an empty body
1472                     newlines = 0
1473                 else:
1474                     newlines = 1
1475             elif current_line.is_def and not self.previous_line.is_def:
1476                 # Blank line between a block of functions and a block of non-functions
1477                 newlines = 1
1478             else:
1479                 newlines = 0
1480         else:
1481             newlines = 2
1482         if current_line.depth and newlines:
1483             newlines -= 1
1484         return newlines, 0
1485
1486
1487 @dataclass
1488 class LineGenerator(Visitor[Line]):
1489     """Generates reformatted Line objects.  Empty lines are not emitted.
1490
1491     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1492     in ways that will no longer stringify to valid Python code on the tree.
1493     """
1494
1495     is_pyi: bool = False
1496     normalize_strings: bool = True
1497     current_line: Line = Factory(Line)
1498     remove_u_prefix: bool = False
1499
1500     def line(self, indent: int = 0) -> Iterator[Line]:
1501         """Generate a line.
1502
1503         If the line is empty, only emit if it makes sense.
1504         If the line is too long, split it first and then generate.
1505
1506         If any lines were generated, set up a new current_line.
1507         """
1508         if not self.current_line:
1509             self.current_line.depth += indent
1510             return  # Line is empty, don't emit. Creating a new one unnecessary.
1511
1512         complete_line = self.current_line
1513         self.current_line = Line(depth=complete_line.depth + indent)
1514         yield complete_line
1515
1516     def visit_default(self, node: LN) -> Iterator[Line]:
1517         """Default `visit_*()` implementation. Recurses to children of `node`."""
1518         if isinstance(node, Leaf):
1519             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1520             for comment in generate_comments(node):
1521                 if any_open_brackets:
1522                     # any comment within brackets is subject to splitting
1523                     self.current_line.append(comment)
1524                 elif comment.type == token.COMMENT:
1525                     # regular trailing comment
1526                     self.current_line.append(comment)
1527                     yield from self.line()
1528
1529                 else:
1530                     # regular standalone comment
1531                     yield from self.line()
1532
1533                     self.current_line.append(comment)
1534                     yield from self.line()
1535
1536             normalize_prefix(node, inside_brackets=any_open_brackets)
1537             if self.normalize_strings and node.type == token.STRING:
1538                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1539                 normalize_string_quotes(node)
1540             if node.type == token.NUMBER:
1541                 normalize_numeric_literal(node)
1542             if node.type not in WHITESPACE:
1543                 self.current_line.append(node)
1544         yield from super().visit_default(node)
1545
1546     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1547         """Increase indentation level, maybe yield a line."""
1548         # In blib2to3 INDENT never holds comments.
1549         yield from self.line(+1)
1550         yield from self.visit_default(node)
1551
1552     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1553         """Decrease indentation level, maybe yield a line."""
1554         # The current line might still wait for trailing comments.  At DEDENT time
1555         # there won't be any (they would be prefixes on the preceding NEWLINE).
1556         # Emit the line then.
1557         yield from self.line()
1558
1559         # While DEDENT has no value, its prefix may contain standalone comments
1560         # that belong to the current indentation level.  Get 'em.
1561         yield from self.visit_default(node)
1562
1563         # Finally, emit the dedent.
1564         yield from self.line(-1)
1565
1566     def visit_stmt(
1567         self, node: Node, keywords: Set[str], parens: Set[str]
1568     ) -> Iterator[Line]:
1569         """Visit a statement.
1570
1571         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1572         `def`, `with`, `class`, `assert` and assignments.
1573
1574         The relevant Python language `keywords` for a given statement will be
1575         NAME leaves within it. This methods puts those on a separate line.
1576
1577         `parens` holds a set of string leaf values immediately after which
1578         invisible parens should be put.
1579         """
1580         normalize_invisible_parens(node, parens_after=parens)
1581         for child in node.children:
1582             if child.type == token.NAME and child.value in keywords:  # type: ignore
1583                 yield from self.line()
1584
1585             yield from self.visit(child)
1586
1587     def visit_suite(self, node: Node) -> Iterator[Line]:
1588         """Visit a suite."""
1589         if self.is_pyi and is_stub_suite(node):
1590             yield from self.visit(node.children[2])
1591         else:
1592             yield from self.visit_default(node)
1593
1594     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1595         """Visit a statement without nested statements."""
1596         is_suite_like = node.parent and node.parent.type in STATEMENT
1597         if is_suite_like:
1598             if self.is_pyi and is_stub_body(node):
1599                 yield from self.visit_default(node)
1600             else:
1601                 yield from self.line(+1)
1602                 yield from self.visit_default(node)
1603                 yield from self.line(-1)
1604
1605         else:
1606             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1607                 yield from self.line()
1608             yield from self.visit_default(node)
1609
1610     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1611         """Visit `async def`, `async for`, `async with`."""
1612         yield from self.line()
1613
1614         children = iter(node.children)
1615         for child in children:
1616             yield from self.visit(child)
1617
1618             if child.type == token.ASYNC:
1619                 break
1620
1621         internal_stmt = next(children)
1622         for child in internal_stmt.children:
1623             yield from self.visit(child)
1624
1625     def visit_decorators(self, node: Node) -> Iterator[Line]:
1626         """Visit decorators."""
1627         for child in node.children:
1628             yield from self.line()
1629             yield from self.visit(child)
1630
1631     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1632         """Remove a semicolon and put the other statement on a separate line."""
1633         yield from self.line()
1634
1635     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1636         """End of file. Process outstanding comments and end with a newline."""
1637         yield from self.visit_default(leaf)
1638         yield from self.line()
1639
1640     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1641         if not self.current_line.bracket_tracker.any_open_brackets():
1642             yield from self.line()
1643         yield from self.visit_default(leaf)
1644
1645     def __attrs_post_init__(self) -> None:
1646         """You are in a twisty little maze of passages."""
1647         v = self.visit_stmt
1648         Ø: Set[str] = set()
1649         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1650         self.visit_if_stmt = partial(
1651             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1652         )
1653         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1654         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1655         self.visit_try_stmt = partial(
1656             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1657         )
1658         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1659         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1660         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1661         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1662         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1663         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1664         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1665         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1666         self.visit_async_funcdef = self.visit_async_stmt
1667         self.visit_decorated = self.visit_decorators
1668
1669
1670 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1671 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1672 OPENING_BRACKETS = set(BRACKET.keys())
1673 CLOSING_BRACKETS = set(BRACKET.values())
1674 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1675 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1676
1677
1678 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1679     """Return whitespace prefix if needed for the given `leaf`.
1680
1681     `complex_subscript` signals whether the given leaf is part of a subscription
1682     which has non-trivial arguments, like arithmetic expressions or function calls.
1683     """
1684     NO = ""
1685     SPACE = " "
1686     DOUBLESPACE = "  "
1687     t = leaf.type
1688     p = leaf.parent
1689     v = leaf.value
1690     if t in ALWAYS_NO_SPACE:
1691         return NO
1692
1693     if t == token.COMMENT:
1694         return DOUBLESPACE
1695
1696     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1697     if t == token.COLON and p.type not in {
1698         syms.subscript,
1699         syms.subscriptlist,
1700         syms.sliceop,
1701     }:
1702         return NO
1703
1704     prev = leaf.prev_sibling
1705     if not prev:
1706         prevp = preceding_leaf(p)
1707         if not prevp or prevp.type in OPENING_BRACKETS:
1708             return NO
1709
1710         if t == token.COLON:
1711             if prevp.type == token.COLON:
1712                 return NO
1713
1714             elif prevp.type != token.COMMA and not complex_subscript:
1715                 return NO
1716
1717             return SPACE
1718
1719         if prevp.type == token.EQUAL:
1720             if prevp.parent:
1721                 if prevp.parent.type in {
1722                     syms.arglist,
1723                     syms.argument,
1724                     syms.parameters,
1725                     syms.varargslist,
1726                 }:
1727                     return NO
1728
1729                 elif prevp.parent.type == syms.typedargslist:
1730                     # A bit hacky: if the equal sign has whitespace, it means we
1731                     # previously found it's a typed argument.  So, we're using
1732                     # that, too.
1733                     return prevp.prefix
1734
1735         elif prevp.type in STARS:
1736             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1737                 return NO
1738
1739         elif prevp.type == token.COLON:
1740             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1741                 return SPACE if complex_subscript else NO
1742
1743         elif (
1744             prevp.parent
1745             and prevp.parent.type == syms.factor
1746             and prevp.type in MATH_OPERATORS
1747         ):
1748             return NO
1749
1750         elif (
1751             prevp.type == token.RIGHTSHIFT
1752             and prevp.parent
1753             and prevp.parent.type == syms.shift_expr
1754             and prevp.prev_sibling
1755             and prevp.prev_sibling.type == token.NAME
1756             and prevp.prev_sibling.value == "print"  # type: ignore
1757         ):
1758             # Python 2 print chevron
1759             return NO
1760
1761     elif prev.type in OPENING_BRACKETS:
1762         return NO
1763
1764     if p.type in {syms.parameters, syms.arglist}:
1765         # untyped function signatures or calls
1766         if not prev or prev.type != token.COMMA:
1767             return NO
1768
1769     elif p.type == syms.varargslist:
1770         # lambdas
1771         if prev and prev.type != token.COMMA:
1772             return NO
1773
1774     elif p.type == syms.typedargslist:
1775         # typed function signatures
1776         if not prev:
1777             return NO
1778
1779         if t == token.EQUAL:
1780             if prev.type != syms.tname:
1781                 return NO
1782
1783         elif prev.type == token.EQUAL:
1784             # A bit hacky: if the equal sign has whitespace, it means we
1785             # previously found it's a typed argument.  So, we're using that, too.
1786             return prev.prefix
1787
1788         elif prev.type != token.COMMA:
1789             return NO
1790
1791     elif p.type == syms.tname:
1792         # type names
1793         if not prev:
1794             prevp = preceding_leaf(p)
1795             if not prevp or prevp.type != token.COMMA:
1796                 return NO
1797
1798     elif p.type == syms.trailer:
1799         # attributes and calls
1800         if t == token.LPAR or t == token.RPAR:
1801             return NO
1802
1803         if not prev:
1804             if t == token.DOT:
1805                 prevp = preceding_leaf(p)
1806                 if not prevp or prevp.type != token.NUMBER:
1807                     return NO
1808
1809             elif t == token.LSQB:
1810                 return NO
1811
1812         elif prev.type != token.COMMA:
1813             return NO
1814
1815     elif p.type == syms.argument:
1816         # single argument
1817         if t == token.EQUAL:
1818             return NO
1819
1820         if not prev:
1821             prevp = preceding_leaf(p)
1822             if not prevp or prevp.type == token.LPAR:
1823                 return NO
1824
1825         elif prev.type in {token.EQUAL} | STARS:
1826             return NO
1827
1828     elif p.type == syms.decorator:
1829         # decorators
1830         return NO
1831
1832     elif p.type == syms.dotted_name:
1833         if prev:
1834             return NO
1835
1836         prevp = preceding_leaf(p)
1837         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1838             return NO
1839
1840     elif p.type == syms.classdef:
1841         if t == token.LPAR:
1842             return NO
1843
1844         if prev and prev.type == token.LPAR:
1845             return NO
1846
1847     elif p.type in {syms.subscript, syms.sliceop}:
1848         # indexing
1849         if not prev:
1850             assert p.parent is not None, "subscripts are always parented"
1851             if p.parent.type == syms.subscriptlist:
1852                 return SPACE
1853
1854             return NO
1855
1856         elif not complex_subscript:
1857             return NO
1858
1859     elif p.type == syms.atom:
1860         if prev and t == token.DOT:
1861             # dots, but not the first one.
1862             return NO
1863
1864     elif p.type == syms.dictsetmaker:
1865         # dict unpacking
1866         if prev and prev.type == token.DOUBLESTAR:
1867             return NO
1868
1869     elif p.type in {syms.factor, syms.star_expr}:
1870         # unary ops
1871         if not prev:
1872             prevp = preceding_leaf(p)
1873             if not prevp or prevp.type in OPENING_BRACKETS:
1874                 return NO
1875
1876             prevp_parent = prevp.parent
1877             assert prevp_parent is not None
1878             if prevp.type == token.COLON and prevp_parent.type in {
1879                 syms.subscript,
1880                 syms.sliceop,
1881             }:
1882                 return NO
1883
1884             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1885                 return NO
1886
1887         elif t in {token.NAME, token.NUMBER, token.STRING}:
1888             return NO
1889
1890     elif p.type == syms.import_from:
1891         if t == token.DOT:
1892             if prev and prev.type == token.DOT:
1893                 return NO
1894
1895         elif t == token.NAME:
1896             if v == "import":
1897                 return SPACE
1898
1899             if prev and prev.type == token.DOT:
1900                 return NO
1901
1902     elif p.type == syms.sliceop:
1903         return NO
1904
1905     return SPACE
1906
1907
1908 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1909     """Return the first leaf that precedes `node`, if any."""
1910     while node:
1911         res = node.prev_sibling
1912         if res:
1913             if isinstance(res, Leaf):
1914                 return res
1915
1916             try:
1917                 return list(res.leaves())[-1]
1918
1919             except IndexError:
1920                 return None
1921
1922         node = node.parent
1923     return None
1924
1925
1926 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1927     """Return the child of `ancestor` that contains `descendant`."""
1928     node: Optional[LN] = descendant
1929     while node and node.parent != ancestor:
1930         node = node.parent
1931     return node
1932
1933
1934 def container_of(leaf: Leaf) -> LN:
1935     """Return `leaf` or one of its ancestors that is the topmost container of it.
1936
1937     By "container" we mean a node where `leaf` is the very first child.
1938     """
1939     same_prefix = leaf.prefix
1940     container: LN = leaf
1941     while container:
1942         parent = container.parent
1943         if parent is None:
1944             break
1945
1946         if parent.children[0].prefix != same_prefix:
1947             break
1948
1949         if parent.type == syms.file_input:
1950             break
1951
1952         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1953             break
1954
1955         container = parent
1956     return container
1957
1958
1959 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1960     """Return the priority of the `leaf` delimiter, given a line break after it.
1961
1962     The delimiter priorities returned here are from those delimiters that would
1963     cause a line break after themselves.
1964
1965     Higher numbers are higher priority.
1966     """
1967     if leaf.type == token.COMMA:
1968         return COMMA_PRIORITY
1969
1970     return 0
1971
1972
1973 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1974     """Return the priority of the `leaf` delimiter, given a line break before it.
1975
1976     The delimiter priorities returned here are from those delimiters that would
1977     cause a line break before themselves.
1978
1979     Higher numbers are higher priority.
1980     """
1981     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1982         # * and ** might also be MATH_OPERATORS but in this case they are not.
1983         # Don't treat them as a delimiter.
1984         return 0
1985
1986     if (
1987         leaf.type == token.DOT
1988         and leaf.parent
1989         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1990         and (previous is None or previous.type in CLOSING_BRACKETS)
1991     ):
1992         return DOT_PRIORITY
1993
1994     if (
1995         leaf.type in MATH_OPERATORS
1996         and leaf.parent
1997         and leaf.parent.type not in {syms.factor, syms.star_expr}
1998     ):
1999         return MATH_PRIORITIES[leaf.type]
2000
2001     if leaf.type in COMPARATORS:
2002         return COMPARATOR_PRIORITY
2003
2004     if (
2005         leaf.type == token.STRING
2006         and previous is not None
2007         and previous.type == token.STRING
2008     ):
2009         return STRING_PRIORITY
2010
2011     if leaf.type not in {token.NAME, token.ASYNC}:
2012         return 0
2013
2014     if (
2015         leaf.value == "for"
2016         and leaf.parent
2017         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2018         or leaf.type == token.ASYNC
2019     ):
2020         if (
2021             not isinstance(leaf.prev_sibling, Leaf)
2022             or leaf.prev_sibling.value != "async"
2023         ):
2024             return COMPREHENSION_PRIORITY
2025
2026     if (
2027         leaf.value == "if"
2028         and leaf.parent
2029         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2030     ):
2031         return COMPREHENSION_PRIORITY
2032
2033     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2034         return TERNARY_PRIORITY
2035
2036     if leaf.value == "is":
2037         return COMPARATOR_PRIORITY
2038
2039     if (
2040         leaf.value == "in"
2041         and leaf.parent
2042         and leaf.parent.type in {syms.comp_op, syms.comparison}
2043         and not (
2044             previous is not None
2045             and previous.type == token.NAME
2046             and previous.value == "not"
2047         )
2048     ):
2049         return COMPARATOR_PRIORITY
2050
2051     if (
2052         leaf.value == "not"
2053         and leaf.parent
2054         and leaf.parent.type == syms.comp_op
2055         and not (
2056             previous is not None
2057             and previous.type == token.NAME
2058             and previous.value == "is"
2059         )
2060     ):
2061         return COMPARATOR_PRIORITY
2062
2063     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2064         return LOGIC_PRIORITY
2065
2066     return 0
2067
2068
2069 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2070 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2071
2072
2073 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2074     """Clean the prefix of the `leaf` and generate comments from it, if any.
2075
2076     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2077     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2078     move because it does away with modifying the grammar to include all the
2079     possible places in which comments can be placed.
2080
2081     The sad consequence for us though is that comments don't "belong" anywhere.
2082     This is why this function generates simple parentless Leaf objects for
2083     comments.  We simply don't know what the correct parent should be.
2084
2085     No matter though, we can live without this.  We really only need to
2086     differentiate between inline and standalone comments.  The latter don't
2087     share the line with any code.
2088
2089     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2090     are emitted with a fake STANDALONE_COMMENT token identifier.
2091     """
2092     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2093         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2094
2095
2096 @dataclass
2097 class ProtoComment:
2098     """Describes a piece of syntax that is a comment.
2099
2100     It's not a :class:`blib2to3.pytree.Leaf` so that:
2101
2102     * it can be cached (`Leaf` objects should not be reused more than once as
2103       they store their lineno, column, prefix, and parent information);
2104     * `newlines` and `consumed` fields are kept separate from the `value`. This
2105       simplifies handling of special marker comments like ``# fmt: off/on``.
2106     """
2107
2108     type: int  # token.COMMENT or STANDALONE_COMMENT
2109     value: str  # content of the comment
2110     newlines: int  # how many newlines before the comment
2111     consumed: int  # how many characters of the original leaf's prefix did we consume
2112
2113
2114 @lru_cache(maxsize=4096)
2115 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2116     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2117     result: List[ProtoComment] = []
2118     if not prefix or "#" not in prefix:
2119         return result
2120
2121     consumed = 0
2122     nlines = 0
2123     for index, line in enumerate(prefix.split("\n")):
2124         consumed += len(line) + 1  # adding the length of the split '\n'
2125         line = line.lstrip()
2126         if not line:
2127             nlines += 1
2128         if not line.startswith("#"):
2129             continue
2130
2131         if index == 0 and not is_endmarker:
2132             comment_type = token.COMMENT  # simple trailing comment
2133         else:
2134             comment_type = STANDALONE_COMMENT
2135         comment = make_comment(line)
2136         result.append(
2137             ProtoComment(
2138                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2139             )
2140         )
2141         nlines = 0
2142     return result
2143
2144
2145 def make_comment(content: str) -> str:
2146     """Return a consistently formatted comment from the given `content` string.
2147
2148     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2149     space between the hash sign and the content.
2150
2151     If `content` didn't start with a hash sign, one is provided.
2152     """
2153     content = content.rstrip()
2154     if not content:
2155         return "#"
2156
2157     if content[0] == "#":
2158         content = content[1:]
2159     if content and content[0] not in " !:#'%":
2160         content = " " + content
2161     return "#" + content
2162
2163
2164 def split_line(
2165     line: Line,
2166     line_length: int,
2167     inner: bool = False,
2168     features: Collection[Feature] = (),
2169 ) -> Iterator[Line]:
2170     """Split a `line` into potentially many lines.
2171
2172     They should fit in the allotted `line_length` but might not be able to.
2173     `inner` signifies that there were a pair of brackets somewhere around the
2174     current `line`, possibly transitively. This means we can fallback to splitting
2175     by delimiters if the LHS/RHS don't yield any results.
2176
2177     `features` are syntactical features that may be used in the output.
2178     """
2179     if line.is_comment:
2180         yield line
2181         return
2182
2183     line_str = str(line).strip("\n")
2184
2185     if (
2186         not line.contains_inner_type_comments()
2187         and not line.should_explode
2188         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2189     ):
2190         yield line
2191         return
2192
2193     split_funcs: List[SplitFunc]
2194     if line.is_def:
2195         split_funcs = [left_hand_split]
2196     else:
2197
2198         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2199             for omit in generate_trailers_to_omit(line, line_length):
2200                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2201                 if is_line_short_enough(lines[0], line_length=line_length):
2202                     yield from lines
2203                     return
2204
2205             # All splits failed, best effort split with no omits.
2206             # This mostly happens to multiline strings that are by definition
2207             # reported as not fitting a single line.
2208             yield from right_hand_split(line, line_length, features=features)
2209
2210         if line.inside_brackets:
2211             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2212         else:
2213             split_funcs = [rhs]
2214     for split_func in split_funcs:
2215         # We are accumulating lines in `result` because we might want to abort
2216         # mission and return the original line in the end, or attempt a different
2217         # split altogether.
2218         result: List[Line] = []
2219         try:
2220             for l in split_func(line, features):
2221                 if str(l).strip("\n") == line_str:
2222                     raise CannotSplit("Split function returned an unchanged result")
2223
2224                 result.extend(
2225                     split_line(
2226                         l, line_length=line_length, inner=True, features=features
2227                     )
2228                 )
2229         except CannotSplit:
2230             continue
2231
2232         else:
2233             yield from result
2234             break
2235
2236     else:
2237         yield line
2238
2239
2240 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2241     """Split line into many lines, starting with the first matching bracket pair.
2242
2243     Note: this usually looks weird, only use this for function definitions.
2244     Prefer RHS otherwise.  This is why this function is not symmetrical with
2245     :func:`right_hand_split` which also handles optional parentheses.
2246     """
2247     tail_leaves: List[Leaf] = []
2248     body_leaves: List[Leaf] = []
2249     head_leaves: List[Leaf] = []
2250     current_leaves = head_leaves
2251     matching_bracket = None
2252     for leaf in line.leaves:
2253         if (
2254             current_leaves is body_leaves
2255             and leaf.type in CLOSING_BRACKETS
2256             and leaf.opening_bracket is matching_bracket
2257         ):
2258             current_leaves = tail_leaves if body_leaves else head_leaves
2259         current_leaves.append(leaf)
2260         if current_leaves is head_leaves:
2261             if leaf.type in OPENING_BRACKETS:
2262                 matching_bracket = leaf
2263                 current_leaves = body_leaves
2264     if not matching_bracket:
2265         raise CannotSplit("No brackets found")
2266
2267     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2268     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2269     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2270     bracket_split_succeeded_or_raise(head, body, tail)
2271     for result in (head, body, tail):
2272         if result:
2273             yield result
2274
2275
2276 def right_hand_split(
2277     line: Line,
2278     line_length: int,
2279     features: Collection[Feature] = (),
2280     omit: Collection[LeafID] = (),
2281 ) -> Iterator[Line]:
2282     """Split line into many lines, starting with the last matching bracket pair.
2283
2284     If the split was by optional parentheses, attempt splitting without them, too.
2285     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2286     this split.
2287
2288     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2289     """
2290     tail_leaves: List[Leaf] = []
2291     body_leaves: List[Leaf] = []
2292     head_leaves: List[Leaf] = []
2293     current_leaves = tail_leaves
2294     opening_bracket = None
2295     closing_bracket = None
2296     for leaf in reversed(line.leaves):
2297         if current_leaves is body_leaves:
2298             if leaf is opening_bracket:
2299                 current_leaves = head_leaves if body_leaves else tail_leaves
2300         current_leaves.append(leaf)
2301         if current_leaves is tail_leaves:
2302             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2303                 opening_bracket = leaf.opening_bracket
2304                 closing_bracket = leaf
2305                 current_leaves = body_leaves
2306     if not (opening_bracket and closing_bracket and head_leaves):
2307         # If there is no opening or closing_bracket that means the split failed and
2308         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2309         # the matching `opening_bracket` wasn't available on `line` anymore.
2310         raise CannotSplit("No brackets found")
2311
2312     tail_leaves.reverse()
2313     body_leaves.reverse()
2314     head_leaves.reverse()
2315     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2316     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2317     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2318     bracket_split_succeeded_or_raise(head, body, tail)
2319     if (
2320         # the body shouldn't be exploded
2321         not body.should_explode
2322         # the opening bracket is an optional paren
2323         and opening_bracket.type == token.LPAR
2324         and not opening_bracket.value
2325         # the closing bracket is an optional paren
2326         and closing_bracket.type == token.RPAR
2327         and not closing_bracket.value
2328         # it's not an import (optional parens are the only thing we can split on
2329         # in this case; attempting a split without them is a waste of time)
2330         and not line.is_import
2331         # there are no standalone comments in the body
2332         and not body.contains_standalone_comments(0)
2333         # and we can actually remove the parens
2334         and can_omit_invisible_parens(body, line_length)
2335     ):
2336         omit = {id(closing_bracket), *omit}
2337         try:
2338             yield from right_hand_split(line, line_length, features=features, omit=omit)
2339             return
2340
2341         except CannotSplit:
2342             if not (
2343                 can_be_split(body)
2344                 or is_line_short_enough(body, line_length=line_length)
2345             ):
2346                 raise CannotSplit(
2347                     "Splitting failed, body is still too long and can't be split."
2348                 )
2349
2350             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2351                 raise CannotSplit(
2352                     "The current optional pair of parentheses is bound to fail to "
2353                     "satisfy the splitting algorithm because the head or the tail "
2354                     "contains multiline strings which by definition never fit one "
2355                     "line."
2356                 )
2357
2358     ensure_visible(opening_bracket)
2359     ensure_visible(closing_bracket)
2360     for result in (head, body, tail):
2361         if result:
2362             yield result
2363
2364
2365 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2366     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2367
2368     Do nothing otherwise.
2369
2370     A left- or right-hand split is based on a pair of brackets. Content before
2371     (and including) the opening bracket is left on one line, content inside the
2372     brackets is put on a separate line, and finally content starting with and
2373     following the closing bracket is put on a separate line.
2374
2375     Those are called `head`, `body`, and `tail`, respectively. If the split
2376     produced the same line (all content in `head`) or ended up with an empty `body`
2377     and the `tail` is just the closing bracket, then it's considered failed.
2378     """
2379     tail_len = len(str(tail).strip())
2380     if not body:
2381         if tail_len == 0:
2382             raise CannotSplit("Splitting brackets produced the same line")
2383
2384         elif tail_len < 3:
2385             raise CannotSplit(
2386                 f"Splitting brackets on an empty body to save "
2387                 f"{tail_len} characters is not worth it"
2388             )
2389
2390
2391 def bracket_split_build_line(
2392     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2393 ) -> Line:
2394     """Return a new line with given `leaves` and respective comments from `original`.
2395
2396     If `is_body` is True, the result line is one-indented inside brackets and as such
2397     has its first leaf's prefix normalized and a trailing comma added when expected.
2398     """
2399     result = Line(depth=original.depth)
2400     if is_body:
2401         result.inside_brackets = True
2402         result.depth += 1
2403         if leaves:
2404             # Since body is a new indent level, remove spurious leading whitespace.
2405             normalize_prefix(leaves[0], inside_brackets=True)
2406             # Ensure a trailing comma when expected.
2407             if original.is_import:
2408                 if leaves[-1].type != token.COMMA:
2409                     leaves.append(Leaf(token.COMMA, ","))
2410     # Populate the line
2411     for leaf in leaves:
2412         result.append(leaf, preformatted=True)
2413         for comment_after in original.comments_after(leaf):
2414             result.append(comment_after, preformatted=True)
2415     if is_body:
2416         result.should_explode = should_explode(result, opening_bracket)
2417     return result
2418
2419
2420 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2421     """Normalize prefix of the first leaf in every line returned by `split_func`.
2422
2423     This is a decorator over relevant split functions.
2424     """
2425
2426     @wraps(split_func)
2427     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2428         for l in split_func(line, features):
2429             normalize_prefix(l.leaves[0], inside_brackets=True)
2430             yield l
2431
2432     return split_wrapper
2433
2434
2435 @dont_increase_indentation
2436 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2437     """Split according to delimiters of the highest priority.
2438
2439     If the appropriate Features are given, the split will add trailing commas
2440     also in function signatures and calls that contain `*` and `**`.
2441     """
2442     try:
2443         last_leaf = line.leaves[-1]
2444     except IndexError:
2445         raise CannotSplit("Line empty")
2446
2447     bt = line.bracket_tracker
2448     try:
2449         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2450     except ValueError:
2451         raise CannotSplit("No delimiters found")
2452
2453     if delimiter_priority == DOT_PRIORITY:
2454         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2455             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2456
2457     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2458     lowest_depth = sys.maxsize
2459     trailing_comma_safe = True
2460
2461     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2462         """Append `leaf` to current line or to new line if appending impossible."""
2463         nonlocal current_line
2464         try:
2465             current_line.append_safe(leaf, preformatted=True)
2466         except ValueError:
2467             yield current_line
2468
2469             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2470             current_line.append(leaf)
2471
2472     for leaf in line.leaves:
2473         yield from append_to_line(leaf)
2474
2475         for comment_after in line.comments_after(leaf):
2476             yield from append_to_line(comment_after)
2477
2478         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2479         if leaf.bracket_depth == lowest_depth:
2480             if is_vararg(leaf, within={syms.typedargslist}):
2481                 trailing_comma_safe = (
2482                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2483                 )
2484             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2485                 trailing_comma_safe = (
2486                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2487                 )
2488
2489         leaf_priority = bt.delimiters.get(id(leaf))
2490         if leaf_priority == delimiter_priority:
2491             yield current_line
2492
2493             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2494     if current_line:
2495         if (
2496             trailing_comma_safe
2497             and delimiter_priority == COMMA_PRIORITY
2498             and current_line.leaves[-1].type != token.COMMA
2499             and current_line.leaves[-1].type != STANDALONE_COMMENT
2500         ):
2501             current_line.append(Leaf(token.COMMA, ","))
2502         yield current_line
2503
2504
2505 @dont_increase_indentation
2506 def standalone_comment_split(
2507     line: Line, features: Collection[Feature] = ()
2508 ) -> Iterator[Line]:
2509     """Split standalone comments from the rest of the line."""
2510     if not line.contains_standalone_comments(0):
2511         raise CannotSplit("Line does not have any standalone comments")
2512
2513     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2514
2515     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2516         """Append `leaf` to current line or to new line if appending impossible."""
2517         nonlocal current_line
2518         try:
2519             current_line.append_safe(leaf, preformatted=True)
2520         except ValueError:
2521             yield current_line
2522
2523             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2524             current_line.append(leaf)
2525
2526     for leaf in line.leaves:
2527         yield from append_to_line(leaf)
2528
2529         for comment_after in line.comments_after(leaf):
2530             yield from append_to_line(comment_after)
2531
2532     if current_line:
2533         yield current_line
2534
2535
2536 def is_import(leaf: Leaf) -> bool:
2537     """Return True if the given leaf starts an import statement."""
2538     p = leaf.parent
2539     t = leaf.type
2540     v = leaf.value
2541     return bool(
2542         t == token.NAME
2543         and (
2544             (v == "import" and p and p.type == syms.import_name)
2545             or (v == "from" and p and p.type == syms.import_from)
2546         )
2547     )
2548
2549
2550 def is_type_comment(leaf: Leaf) -> bool:
2551     """Return True if the given leaf is a special comment.
2552     Only returns true for type comments for now."""
2553     t = leaf.type
2554     v = leaf.value
2555     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2556
2557
2558 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2559     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2560     else.
2561
2562     Note: don't use backslashes for formatting or you'll lose your voting rights.
2563     """
2564     if not inside_brackets:
2565         spl = leaf.prefix.split("#")
2566         if "\\" not in spl[0]:
2567             nl_count = spl[-1].count("\n")
2568             if len(spl) > 1:
2569                 nl_count -= 1
2570             leaf.prefix = "\n" * nl_count
2571             return
2572
2573     leaf.prefix = ""
2574
2575
2576 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2577     """Make all string prefixes lowercase.
2578
2579     If remove_u_prefix is given, also removes any u prefix from the string.
2580
2581     Note: Mutates its argument.
2582     """
2583     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2584     assert match is not None, f"failed to match string {leaf.value!r}"
2585     orig_prefix = match.group(1)
2586     new_prefix = orig_prefix.lower()
2587     if remove_u_prefix:
2588         new_prefix = new_prefix.replace("u", "")
2589     leaf.value = f"{new_prefix}{match.group(2)}"
2590
2591
2592 def normalize_string_quotes(leaf: Leaf) -> None:
2593     """Prefer double quotes but only if it doesn't cause more escaping.
2594
2595     Adds or removes backslashes as appropriate. Doesn't parse and fix
2596     strings nested in f-strings (yet).
2597
2598     Note: Mutates its argument.
2599     """
2600     value = leaf.value.lstrip("furbFURB")
2601     if value[:3] == '"""':
2602         return
2603
2604     elif value[:3] == "'''":
2605         orig_quote = "'''"
2606         new_quote = '"""'
2607     elif value[0] == '"':
2608         orig_quote = '"'
2609         new_quote = "'"
2610     else:
2611         orig_quote = "'"
2612         new_quote = '"'
2613     first_quote_pos = leaf.value.find(orig_quote)
2614     if first_quote_pos == -1:
2615         return  # There's an internal error
2616
2617     prefix = leaf.value[:first_quote_pos]
2618     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2619     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2620     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2621     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2622     if "r" in prefix.casefold():
2623         if unescaped_new_quote.search(body):
2624             # There's at least one unescaped new_quote in this raw string
2625             # so converting is impossible
2626             return
2627
2628         # Do not introduce or remove backslashes in raw strings
2629         new_body = body
2630     else:
2631         # remove unnecessary escapes
2632         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2633         if body != new_body:
2634             # Consider the string without unnecessary escapes as the original
2635             body = new_body
2636             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2637         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2638         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2639     if "f" in prefix.casefold():
2640         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2641         for m in matches:
2642             if "\\" in str(m):
2643                 # Do not introduce backslashes in interpolated expressions
2644                 return
2645     if new_quote == '"""' and new_body[-1:] == '"':
2646         # edge case:
2647         new_body = new_body[:-1] + '\\"'
2648     orig_escape_count = body.count("\\")
2649     new_escape_count = new_body.count("\\")
2650     if new_escape_count > orig_escape_count:
2651         return  # Do not introduce more escaping
2652
2653     if new_escape_count == orig_escape_count and orig_quote == '"':
2654         return  # Prefer double quotes
2655
2656     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2657
2658
2659 def normalize_numeric_literal(leaf: Leaf) -> None:
2660     """Normalizes numeric (float, int, and complex) literals.
2661
2662     All letters used in the representation are normalized to lowercase (except
2663     in Python 2 long literals).
2664     """
2665     text = leaf.value.lower()
2666     if text.startswith(("0o", "0b")):
2667         # Leave octal and binary literals alone.
2668         pass
2669     elif text.startswith("0x"):
2670         # Change hex literals to upper case.
2671         before, after = text[:2], text[2:]
2672         text = f"{before}{after.upper()}"
2673     elif "e" in text:
2674         before, after = text.split("e")
2675         sign = ""
2676         if after.startswith("-"):
2677             after = after[1:]
2678             sign = "-"
2679         elif after.startswith("+"):
2680             after = after[1:]
2681         before = format_float_or_int_string(before)
2682         text = f"{before}e{sign}{after}"
2683     elif text.endswith(("j", "l")):
2684         number = text[:-1]
2685         suffix = text[-1]
2686         # Capitalize in "2L" because "l" looks too similar to "1".
2687         if suffix == "l":
2688             suffix = "L"
2689         text = f"{format_float_or_int_string(number)}{suffix}"
2690     else:
2691         text = format_float_or_int_string(text)
2692     leaf.value = text
2693
2694
2695 def format_float_or_int_string(text: str) -> str:
2696     """Formats a float string like "1.0"."""
2697     if "." not in text:
2698         return text
2699
2700     before, after = text.split(".")
2701     return f"{before or 0}.{after or 0}"
2702
2703
2704 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2705     """Make existing optional parentheses invisible or create new ones.
2706
2707     `parens_after` is a set of string leaf values immeditely after which parens
2708     should be put.
2709
2710     Standardizes on visible parentheses for single-element tuples, and keeps
2711     existing visible parentheses for other tuples and generator expressions.
2712     """
2713     for pc in list_comments(node.prefix, is_endmarker=False):
2714         if pc.value in FMT_OFF:
2715             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2716             return
2717
2718     check_lpar = False
2719     for index, child in enumerate(list(node.children)):
2720         if check_lpar:
2721             if child.type == syms.atom:
2722                 if maybe_make_parens_invisible_in_atom(child):
2723                     lpar = Leaf(token.LPAR, "")
2724                     rpar = Leaf(token.RPAR, "")
2725                     index = child.remove() or 0
2726                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2727             elif is_one_tuple(child):
2728                 # wrap child in visible parentheses
2729                 lpar = Leaf(token.LPAR, "(")
2730                 rpar = Leaf(token.RPAR, ")")
2731                 child.remove()
2732                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2733             elif node.type == syms.import_from:
2734                 # "import from" nodes store parentheses directly as part of
2735                 # the statement
2736                 if child.type == token.LPAR:
2737                     # make parentheses invisible
2738                     child.value = ""  # type: ignore
2739                     node.children[-1].value = ""  # type: ignore
2740                 elif child.type != token.STAR:
2741                     # insert invisible parentheses
2742                     node.insert_child(index, Leaf(token.LPAR, ""))
2743                     node.append_child(Leaf(token.RPAR, ""))
2744                 break
2745
2746             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2747                 # wrap child in invisible parentheses
2748                 lpar = Leaf(token.LPAR, "")
2749                 rpar = Leaf(token.RPAR, "")
2750                 index = child.remove() or 0
2751                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2752
2753         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2754
2755
2756 def normalize_fmt_off(node: Node) -> None:
2757     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2758     try_again = True
2759     while try_again:
2760         try_again = convert_one_fmt_off_pair(node)
2761
2762
2763 def convert_one_fmt_off_pair(node: Node) -> bool:
2764     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2765
2766     Returns True if a pair was converted.
2767     """
2768     for leaf in node.leaves():
2769         previous_consumed = 0
2770         for comment in list_comments(leaf.prefix, is_endmarker=False):
2771             if comment.value in FMT_OFF:
2772                 # We only want standalone comments. If there's no previous leaf or
2773                 # the previous leaf is indentation, it's a standalone comment in
2774                 # disguise.
2775                 if comment.type != STANDALONE_COMMENT:
2776                     prev = preceding_leaf(leaf)
2777                     if prev and prev.type not in WHITESPACE:
2778                         continue
2779
2780                 ignored_nodes = list(generate_ignored_nodes(leaf))
2781                 if not ignored_nodes:
2782                     continue
2783
2784                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2785                 parent = first.parent
2786                 prefix = first.prefix
2787                 first.prefix = prefix[comment.consumed :]
2788                 hidden_value = (
2789                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2790                 )
2791                 if hidden_value.endswith("\n"):
2792                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2793                     # leaf (possibly followed by a DEDENT).
2794                     hidden_value = hidden_value[:-1]
2795                 first_idx = None
2796                 for ignored in ignored_nodes:
2797                     index = ignored.remove()
2798                     if first_idx is None:
2799                         first_idx = index
2800                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2801                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2802                 parent.insert_child(
2803                     first_idx,
2804                     Leaf(
2805                         STANDALONE_COMMENT,
2806                         hidden_value,
2807                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2808                     ),
2809                 )
2810                 return True
2811
2812             previous_consumed = comment.consumed
2813
2814     return False
2815
2816
2817 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2818     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2819
2820     Stops at the end of the block.
2821     """
2822     container: Optional[LN] = container_of(leaf)
2823     while container is not None and container.type != token.ENDMARKER:
2824         for comment in list_comments(container.prefix, is_endmarker=False):
2825             if comment.value in FMT_ON:
2826                 return
2827
2828         yield container
2829
2830         container = container.next_sibling
2831
2832
2833 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2834     """If it's safe, make the parens in the atom `node` invisible, recursively.
2835
2836     Returns whether the node should itself be wrapped in invisible parentheses.
2837
2838     """
2839     if (
2840         node.type != syms.atom
2841         or is_empty_tuple(node)
2842         or is_one_tuple(node)
2843         or is_yield(node)
2844         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2845     ):
2846         return False
2847
2848     first = node.children[0]
2849     last = node.children[-1]
2850     if first.type == token.LPAR and last.type == token.RPAR:
2851         # make parentheses invisible
2852         first.value = ""  # type: ignore
2853         last.value = ""  # type: ignore
2854         if len(node.children) > 1:
2855             maybe_make_parens_invisible_in_atom(node.children[1])
2856         return False
2857
2858     return True
2859
2860
2861 def is_empty_tuple(node: LN) -> bool:
2862     """Return True if `node` holds an empty tuple."""
2863     return (
2864         node.type == syms.atom
2865         and len(node.children) == 2
2866         and node.children[0].type == token.LPAR
2867         and node.children[1].type == token.RPAR
2868     )
2869
2870
2871 def is_one_tuple(node: LN) -> bool:
2872     """Return True if `node` holds a tuple with one element, with or without parens."""
2873     if node.type == syms.atom:
2874         if len(node.children) != 3:
2875             return False
2876
2877         lpar, gexp, rpar = node.children
2878         if not (
2879             lpar.type == token.LPAR
2880             and gexp.type == syms.testlist_gexp
2881             and rpar.type == token.RPAR
2882         ):
2883             return False
2884
2885         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2886
2887     return (
2888         node.type in IMPLICIT_TUPLE
2889         and len(node.children) == 2
2890         and node.children[1].type == token.COMMA
2891     )
2892
2893
2894 def is_yield(node: LN) -> bool:
2895     """Return True if `node` holds a `yield` or `yield from` expression."""
2896     if node.type == syms.yield_expr:
2897         return True
2898
2899     if node.type == token.NAME and node.value == "yield":  # type: ignore
2900         return True
2901
2902     if node.type != syms.atom:
2903         return False
2904
2905     if len(node.children) != 3:
2906         return False
2907
2908     lpar, expr, rpar = node.children
2909     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2910         return is_yield(expr)
2911
2912     return False
2913
2914
2915 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2916     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2917
2918     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2919     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2920     extended iterable unpacking (PEP 3132) and additional unpacking
2921     generalizations (PEP 448).
2922     """
2923     if leaf.type not in STARS or not leaf.parent:
2924         return False
2925
2926     p = leaf.parent
2927     if p.type == syms.star_expr:
2928         # Star expressions are also used as assignment targets in extended
2929         # iterable unpacking (PEP 3132).  See what its parent is instead.
2930         if not p.parent:
2931             return False
2932
2933         p = p.parent
2934
2935     return p.type in within
2936
2937
2938 def is_multiline_string(leaf: Leaf) -> bool:
2939     """Return True if `leaf` is a multiline string that actually spans many lines."""
2940     value = leaf.value.lstrip("furbFURB")
2941     return value[:3] in {'"""', "'''"} and "\n" in value
2942
2943
2944 def is_stub_suite(node: Node) -> bool:
2945     """Return True if `node` is a suite with a stub body."""
2946     if (
2947         len(node.children) != 4
2948         or node.children[0].type != token.NEWLINE
2949         or node.children[1].type != token.INDENT
2950         or node.children[3].type != token.DEDENT
2951     ):
2952         return False
2953
2954     return is_stub_body(node.children[2])
2955
2956
2957 def is_stub_body(node: LN) -> bool:
2958     """Return True if `node` is a simple statement containing an ellipsis."""
2959     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2960         return False
2961
2962     if len(node.children) != 2:
2963         return False
2964
2965     child = node.children[0]
2966     return (
2967         child.type == syms.atom
2968         and len(child.children) == 3
2969         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2970     )
2971
2972
2973 def max_delimiter_priority_in_atom(node: LN) -> int:
2974     """Return maximum delimiter priority inside `node`.
2975
2976     This is specific to atoms with contents contained in a pair of parentheses.
2977     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2978     """
2979     if node.type != syms.atom:
2980         return 0
2981
2982     first = node.children[0]
2983     last = node.children[-1]
2984     if not (first.type == token.LPAR and last.type == token.RPAR):
2985         return 0
2986
2987     bt = BracketTracker()
2988     for c in node.children[1:-1]:
2989         if isinstance(c, Leaf):
2990             bt.mark(c)
2991         else:
2992             for leaf in c.leaves():
2993                 bt.mark(leaf)
2994     try:
2995         return bt.max_delimiter_priority()
2996
2997     except ValueError:
2998         return 0
2999
3000
3001 def ensure_visible(leaf: Leaf) -> None:
3002     """Make sure parentheses are visible.
3003
3004     They could be invisible as part of some statements (see
3005     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3006     """
3007     if leaf.type == token.LPAR:
3008         leaf.value = "("
3009     elif leaf.type == token.RPAR:
3010         leaf.value = ")"
3011
3012
3013 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3014     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3015
3016     if not (
3017         opening_bracket.parent
3018         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3019         and opening_bracket.value in "[{("
3020     ):
3021         return False
3022
3023     try:
3024         last_leaf = line.leaves[-1]
3025         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3026         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3027     except (IndexError, ValueError):
3028         return False
3029
3030     return max_priority == COMMA_PRIORITY
3031
3032
3033 def get_features_used(node: Node) -> Set[Feature]:
3034     """Return a set of (relatively) new Python features used in this file.
3035
3036     Currently looking for:
3037     - f-strings;
3038     - underscores in numeric literals; and
3039     - trailing commas after * or ** in function signatures and calls.
3040     """
3041     features: Set[Feature] = set()
3042     for n in node.pre_order():
3043         if n.type == token.STRING:
3044             value_head = n.value[:2]  # type: ignore
3045             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3046                 features.add(Feature.F_STRINGS)
3047
3048         elif n.type == token.NUMBER:
3049             if "_" in n.value:  # type: ignore
3050                 features.add(Feature.NUMERIC_UNDERSCORES)
3051
3052         elif (
3053             n.type in {syms.typedargslist, syms.arglist}
3054             and n.children
3055             and n.children[-1].type == token.COMMA
3056         ):
3057             if n.type == syms.typedargslist:
3058                 feature = Feature.TRAILING_COMMA_IN_DEF
3059             else:
3060                 feature = Feature.TRAILING_COMMA_IN_CALL
3061
3062             for ch in n.children:
3063                 if ch.type in STARS:
3064                     features.add(feature)
3065
3066                 if ch.type == syms.argument:
3067                     for argch in ch.children:
3068                         if argch.type in STARS:
3069                             features.add(feature)
3070
3071     return features
3072
3073
3074 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3075     """Detect the version to target based on the nodes used."""
3076     features = get_features_used(node)
3077     return {
3078         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3079     }
3080
3081
3082 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3083     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3084
3085     Brackets can be omitted if the entire trailer up to and including
3086     a preceding closing bracket fits in one line.
3087
3088     Yielded sets are cumulative (contain results of previous yields, too).  First
3089     set is empty.
3090     """
3091
3092     omit: Set[LeafID] = set()
3093     yield omit
3094
3095     length = 4 * line.depth
3096     opening_bracket = None
3097     closing_bracket = None
3098     inner_brackets: Set[LeafID] = set()
3099     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3100         length += leaf_length
3101         if length > line_length:
3102             break
3103
3104         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3105         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3106             break
3107
3108         if opening_bracket:
3109             if leaf is opening_bracket:
3110                 opening_bracket = None
3111             elif leaf.type in CLOSING_BRACKETS:
3112                 inner_brackets.add(id(leaf))
3113         elif leaf.type in CLOSING_BRACKETS:
3114             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3115                 # Empty brackets would fail a split so treat them as "inner"
3116                 # brackets (e.g. only add them to the `omit` set if another
3117                 # pair of brackets was good enough.
3118                 inner_brackets.add(id(leaf))
3119                 continue
3120
3121             if closing_bracket:
3122                 omit.add(id(closing_bracket))
3123                 omit.update(inner_brackets)
3124                 inner_brackets.clear()
3125                 yield omit
3126
3127             if leaf.value:
3128                 opening_bracket = leaf.opening_bracket
3129                 closing_bracket = leaf
3130
3131
3132 def get_future_imports(node: Node) -> Set[str]:
3133     """Return a set of __future__ imports in the file."""
3134     imports: Set[str] = set()
3135
3136     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3137         for child in children:
3138             if isinstance(child, Leaf):
3139                 if child.type == token.NAME:
3140                     yield child.value
3141             elif child.type == syms.import_as_name:
3142                 orig_name = child.children[0]
3143                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3144                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3145                 yield orig_name.value
3146             elif child.type == syms.import_as_names:
3147                 yield from get_imports_from_children(child.children)
3148             else:
3149                 assert False, "Invalid syntax parsing imports"
3150
3151     for child in node.children:
3152         if child.type != syms.simple_stmt:
3153             break
3154         first_child = child.children[0]
3155         if isinstance(first_child, Leaf):
3156             # Continue looking if we see a docstring; otherwise stop.
3157             if (
3158                 len(child.children) == 2
3159                 and first_child.type == token.STRING
3160                 and child.children[1].type == token.NEWLINE
3161             ):
3162                 continue
3163             else:
3164                 break
3165         elif first_child.type == syms.import_from:
3166             module_name = first_child.children[1]
3167             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3168                 break
3169             imports |= set(get_imports_from_children(first_child.children[3:]))
3170         else:
3171             break
3172     return imports
3173
3174
3175 def gen_python_files_in_dir(
3176     path: Path,
3177     root: Path,
3178     include: Pattern[str],
3179     exclude: Pattern[str],
3180     report: "Report",
3181 ) -> Iterator[Path]:
3182     """Generate all files under `path` whose paths are not excluded by the
3183     `exclude` regex, but are included by the `include` regex.
3184
3185     Symbolic links pointing outside of the `root` directory are ignored.
3186
3187     `report` is where output about exclusions goes.
3188     """
3189     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3190     for child in path.iterdir():
3191         try:
3192             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3193         except ValueError:
3194             if child.is_symlink():
3195                 report.path_ignored(
3196                     child, f"is a symbolic link that points outside {root}"
3197                 )
3198                 continue
3199
3200             raise
3201
3202         if child.is_dir():
3203             normalized_path += "/"
3204         exclude_match = exclude.search(normalized_path)
3205         if exclude_match and exclude_match.group(0):
3206             report.path_ignored(child, f"matches the --exclude regular expression")
3207             continue
3208
3209         if child.is_dir():
3210             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3211
3212         elif child.is_file():
3213             include_match = include.search(normalized_path)
3214             if include_match:
3215                 yield child
3216
3217
3218 @lru_cache()
3219 def find_project_root(srcs: Iterable[str]) -> Path:
3220     """Return a directory containing .git, .hg, or pyproject.toml.
3221
3222     That directory can be one of the directories passed in `srcs` or their
3223     common parent.
3224
3225     If no directory in the tree contains a marker that would specify it's the
3226     project root, the root of the file system is returned.
3227     """
3228     if not srcs:
3229         return Path("/").resolve()
3230
3231     common_base = min(Path(src).resolve() for src in srcs)
3232     if common_base.is_dir():
3233         # Append a fake file so `parents` below returns `common_base_dir`, too.
3234         common_base /= "fake-file"
3235     for directory in common_base.parents:
3236         if (directory / ".git").is_dir():
3237             return directory
3238
3239         if (directory / ".hg").is_dir():
3240             return directory
3241
3242         if (directory / "pyproject.toml").is_file():
3243             return directory
3244
3245     return directory
3246
3247
3248 @dataclass
3249 class Report:
3250     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3251
3252     check: bool = False
3253     quiet: bool = False
3254     verbose: bool = False
3255     change_count: int = 0
3256     same_count: int = 0
3257     failure_count: int = 0
3258
3259     def done(self, src: Path, changed: Changed) -> None:
3260         """Increment the counter for successful reformatting. Write out a message."""
3261         if changed is Changed.YES:
3262             reformatted = "would reformat" if self.check else "reformatted"
3263             if self.verbose or not self.quiet:
3264                 out(f"{reformatted} {src}")
3265             self.change_count += 1
3266         else:
3267             if self.verbose:
3268                 if changed is Changed.NO:
3269                     msg = f"{src} already well formatted, good job."
3270                 else:
3271                     msg = f"{src} wasn't modified on disk since last run."
3272                 out(msg, bold=False)
3273             self.same_count += 1
3274
3275     def failed(self, src: Path, message: str) -> None:
3276         """Increment the counter for failed reformatting. Write out a message."""
3277         err(f"error: cannot format {src}: {message}")
3278         self.failure_count += 1
3279
3280     def path_ignored(self, path: Path, message: str) -> None:
3281         if self.verbose:
3282             out(f"{path} ignored: {message}", bold=False)
3283
3284     @property
3285     def return_code(self) -> int:
3286         """Return the exit code that the app should use.
3287
3288         This considers the current state of changed files and failures:
3289         - if there were any failures, return 123;
3290         - if any files were changed and --check is being used, return 1;
3291         - otherwise return 0.
3292         """
3293         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3294         # 126 we have special return codes reserved by the shell.
3295         if self.failure_count:
3296             return 123
3297
3298         elif self.change_count and self.check:
3299             return 1
3300
3301         return 0
3302
3303     def __str__(self) -> str:
3304         """Render a color report of the current state.
3305
3306         Use `click.unstyle` to remove colors.
3307         """
3308         if self.check:
3309             reformatted = "would be reformatted"
3310             unchanged = "would be left unchanged"
3311             failed = "would fail to reformat"
3312         else:
3313             reformatted = "reformatted"
3314             unchanged = "left unchanged"
3315             failed = "failed to reformat"
3316         report = []
3317         if self.change_count:
3318             s = "s" if self.change_count > 1 else ""
3319             report.append(
3320                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3321             )
3322         if self.same_count:
3323             s = "s" if self.same_count > 1 else ""
3324             report.append(f"{self.same_count} file{s} {unchanged}")
3325         if self.failure_count:
3326             s = "s" if self.failure_count > 1 else ""
3327             report.append(
3328                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3329             )
3330         return ", ".join(report) + "."
3331
3332
3333 def assert_equivalent(src: str, dst: str) -> None:
3334     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3335
3336     import ast
3337     import traceback
3338
3339     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3340         """Simple visitor generating strings to compare ASTs by content."""
3341         yield f"{'  ' * depth}{node.__class__.__name__}("
3342
3343         for field in sorted(node._fields):
3344             try:
3345                 value = getattr(node, field)
3346             except AttributeError:
3347                 continue
3348
3349             yield f"{'  ' * (depth+1)}{field}="
3350
3351             if isinstance(value, list):
3352                 for item in value:
3353                     # Ignore nested tuples within del statements, because we may insert
3354                     # parentheses and they change the AST.
3355                     if (
3356                         field == "targets"
3357                         and isinstance(node, ast.Delete)
3358                         and isinstance(item, ast.Tuple)
3359                     ):
3360                         for item in item.elts:
3361                             yield from _v(item, depth + 2)
3362                     elif isinstance(item, ast.AST):
3363                         yield from _v(item, depth + 2)
3364
3365             elif isinstance(value, ast.AST):
3366                 yield from _v(value, depth + 2)
3367
3368             else:
3369                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3370
3371         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3372
3373     try:
3374         src_ast = ast.parse(src)
3375     except Exception as exc:
3376         major, minor = sys.version_info[:2]
3377         raise AssertionError(
3378             f"cannot use --safe with this file; failed to parse source file "
3379             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3380             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3381         )
3382
3383     try:
3384         dst_ast = ast.parse(dst)
3385     except Exception as exc:
3386         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3387         raise AssertionError(
3388             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3389             f"Please report a bug on https://github.com/ambv/black/issues.  "
3390             f"This invalid output might be helpful: {log}"
3391         ) from None
3392
3393     src_ast_str = "\n".join(_v(src_ast))
3394     dst_ast_str = "\n".join(_v(dst_ast))
3395     if src_ast_str != dst_ast_str:
3396         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3397         raise AssertionError(
3398             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3399             f"the source.  "
3400             f"Please report a bug on https://github.com/ambv/black/issues.  "
3401             f"This diff might be helpful: {log}"
3402         ) from None
3403
3404
3405 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3406     """Raise AssertionError if `dst` reformats differently the second time."""
3407     newdst = format_str(dst, mode=mode)
3408     if dst != newdst:
3409         log = dump_to_file(
3410             diff(src, dst, "source", "first pass"),
3411             diff(dst, newdst, "first pass", "second pass"),
3412         )
3413         raise AssertionError(
3414             f"INTERNAL ERROR: Black produced different code on the second pass "
3415             f"of the formatter.  "
3416             f"Please report a bug on https://github.com/ambv/black/issues.  "
3417             f"This diff might be helpful: {log}"
3418         ) from None
3419
3420
3421 def dump_to_file(*output: str) -> str:
3422     """Dump `output` to a temporary file. Return path to the file."""
3423     import tempfile
3424
3425     with tempfile.NamedTemporaryFile(
3426         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3427     ) as f:
3428         for lines in output:
3429             f.write(lines)
3430             if lines and lines[-1] != "\n":
3431                 f.write("\n")
3432     return f.name
3433
3434
3435 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3436     """Return a unified diff string between strings `a` and `b`."""
3437     import difflib
3438
3439     a_lines = [line + "\n" for line in a.split("\n")]
3440     b_lines = [line + "\n" for line in b.split("\n")]
3441     return "".join(
3442         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3443     )
3444
3445
3446 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3447     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3448     err("Aborted!")
3449     for task in tasks:
3450         task.cancel()
3451
3452
3453 def shutdown(loop: BaseEventLoop) -> None:
3454     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3455     try:
3456         if sys.version_info[:2] >= (3, 7):
3457             all_tasks = asyncio.all_tasks
3458         else:
3459             all_tasks = asyncio.Task.all_tasks
3460         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3461         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3462         if not to_cancel:
3463             return
3464
3465         for task in to_cancel:
3466             task.cancel()
3467         loop.run_until_complete(
3468             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3469         )
3470     finally:
3471         # `concurrent.futures.Future` objects cannot be cancelled once they
3472         # are already running. There might be some when the `shutdown()` happened.
3473         # Silence their logger's spew about the event loop being closed.
3474         cf_logger = logging.getLogger("concurrent.futures")
3475         cf_logger.setLevel(logging.CRITICAL)
3476         loop.close()
3477
3478
3479 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3480     """Replace `regex` with `replacement` twice on `original`.
3481
3482     This is used by string normalization to perform replaces on
3483     overlapping matches.
3484     """
3485     return regex.sub(replacement, regex.sub(replacement, original))
3486
3487
3488 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3489     """Compile a regular expression string in `regex`.
3490
3491     If it contains newlines, use verbose mode.
3492     """
3493     if "\n" in regex:
3494         regex = "(?x)" + regex
3495     return re.compile(regex)
3496
3497
3498 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3499     """Like `reversed(enumerate(sequence))` if that were possible."""
3500     index = len(sequence) - 1
3501     for element in reversed(sequence):
3502         yield (index, element)
3503         index -= 1
3504
3505
3506 def enumerate_with_length(
3507     line: Line, reversed: bool = False
3508 ) -> Iterator[Tuple[Index, Leaf, int]]:
3509     """Return an enumeration of leaves with their length.
3510
3511     Stops prematurely on multiline strings and standalone comments.
3512     """
3513     op = cast(
3514         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3515         enumerate_reversed if reversed else enumerate,
3516     )
3517     for index, leaf in op(line.leaves):
3518         length = len(leaf.prefix) + len(leaf.value)
3519         if "\n" in leaf.value:
3520             return  # Multiline strings, we can't continue.
3521
3522         comment: Optional[Leaf]
3523         for comment in line.comments_after(leaf):
3524             length += len(comment.value)
3525
3526         yield index, leaf, length
3527
3528
3529 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3530     """Return True if `line` is no longer than `line_length`.
3531
3532     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3533     """
3534     if not line_str:
3535         line_str = str(line).strip("\n")
3536     return (
3537         len(line_str) <= line_length
3538         and "\n" not in line_str  # multiline strings
3539         and not line.contains_standalone_comments()
3540     )
3541
3542
3543 def can_be_split(line: Line) -> bool:
3544     """Return False if the line cannot be split *for sure*.
3545
3546     This is not an exhaustive search but a cheap heuristic that we can use to
3547     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3548     in unnecessary parentheses).
3549     """
3550     leaves = line.leaves
3551     if len(leaves) < 2:
3552         return False
3553
3554     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3555         call_count = 0
3556         dot_count = 0
3557         next = leaves[-1]
3558         for leaf in leaves[-2::-1]:
3559             if leaf.type in OPENING_BRACKETS:
3560                 if next.type not in CLOSING_BRACKETS:
3561                     return False
3562
3563                 call_count += 1
3564             elif leaf.type == token.DOT:
3565                 dot_count += 1
3566             elif leaf.type == token.NAME:
3567                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3568                     return False
3569
3570             elif leaf.type not in CLOSING_BRACKETS:
3571                 return False
3572
3573             if dot_count > 1 and call_count > 1:
3574                 return False
3575
3576     return True
3577
3578
3579 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3580     """Does `line` have a shape safe to reformat without optional parens around it?
3581
3582     Returns True for only a subset of potentially nice looking formattings but
3583     the point is to not return false positives that end up producing lines that
3584     are too long.
3585     """
3586     bt = line.bracket_tracker
3587     if not bt.delimiters:
3588         # Without delimiters the optional parentheses are useless.
3589         return True
3590
3591     max_priority = bt.max_delimiter_priority()
3592     if bt.delimiter_count_with_priority(max_priority) > 1:
3593         # With more than one delimiter of a kind the optional parentheses read better.
3594         return False
3595
3596     if max_priority == DOT_PRIORITY:
3597         # A single stranded method call doesn't require optional parentheses.
3598         return True
3599
3600     assert len(line.leaves) >= 2, "Stranded delimiter"
3601
3602     first = line.leaves[0]
3603     second = line.leaves[1]
3604     penultimate = line.leaves[-2]
3605     last = line.leaves[-1]
3606
3607     # With a single delimiter, omit if the expression starts or ends with
3608     # a bracket.
3609     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3610         remainder = False
3611         length = 4 * line.depth
3612         for _index, leaf, leaf_length in enumerate_with_length(line):
3613             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3614                 remainder = True
3615             if remainder:
3616                 length += leaf_length
3617                 if length > line_length:
3618                     break
3619
3620                 if leaf.type in OPENING_BRACKETS:
3621                     # There are brackets we can further split on.
3622                     remainder = False
3623
3624         else:
3625             # checked the entire string and line length wasn't exceeded
3626             if len(line.leaves) == _index + 1:
3627                 return True
3628
3629         # Note: we are not returning False here because a line might have *both*
3630         # a leading opening bracket and a trailing closing bracket.  If the
3631         # opening bracket doesn't match our rule, maybe the closing will.
3632
3633     if (
3634         last.type == token.RPAR
3635         or last.type == token.RBRACE
3636         or (
3637             # don't use indexing for omitting optional parentheses;
3638             # it looks weird
3639             last.type == token.RSQB
3640             and last.parent
3641             and last.parent.type != syms.trailer
3642         )
3643     ):
3644         if penultimate.type in OPENING_BRACKETS:
3645             # Empty brackets don't help.
3646             return False
3647
3648         if is_multiline_string(first):
3649             # Additional wrapping of a multiline string in this situation is
3650             # unnecessary.
3651             return True
3652
3653         length = 4 * line.depth
3654         seen_other_brackets = False
3655         for _index, leaf, leaf_length in enumerate_with_length(line):
3656             length += leaf_length
3657             if leaf is last.opening_bracket:
3658                 if seen_other_brackets or length <= line_length:
3659                     return True
3660
3661             elif leaf.type in OPENING_BRACKETS:
3662                 # There are brackets we can further split on.
3663                 seen_other_brackets = True
3664
3665     return False
3666
3667
3668 def get_cache_file(mode: FileMode) -> Path:
3669     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3670
3671
3672 def read_cache(mode: FileMode) -> Cache:
3673     """Read the cache if it exists and is well formed.
3674
3675     If it is not well formed, the call to write_cache later should resolve the issue.
3676     """
3677     cache_file = get_cache_file(mode)
3678     if not cache_file.exists():
3679         return {}
3680
3681     with cache_file.open("rb") as fobj:
3682         try:
3683             cache: Cache = pickle.load(fobj)
3684         except pickle.UnpicklingError:
3685             return {}
3686
3687     return cache
3688
3689
3690 def get_cache_info(path: Path) -> CacheInfo:
3691     """Return the information used to check if a file is already formatted or not."""
3692     stat = path.stat()
3693     return stat.st_mtime, stat.st_size
3694
3695
3696 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3697     """Split an iterable of paths in `sources` into two sets.
3698
3699     The first contains paths of files that modified on disk or are not in the
3700     cache. The other contains paths to non-modified files.
3701     """
3702     todo, done = set(), set()
3703     for src in sources:
3704         src = src.resolve()
3705         if cache.get(src) != get_cache_info(src):
3706             todo.add(src)
3707         else:
3708             done.add(src)
3709     return todo, done
3710
3711
3712 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3713     """Update the cache file."""
3714     cache_file = get_cache_file(mode)
3715     try:
3716         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3717         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3718         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3719             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3720         os.replace(f.name, cache_file)
3721     except OSError:
3722         pass
3723
3724
3725 def patch_click() -> None:
3726     """Make Click not crash.
3727
3728     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3729     default which restricts paths that it can access during the lifetime of the
3730     application.  Click refuses to work in this scenario by raising a RuntimeError.
3731
3732     In case of Black the likelihood that non-ASCII characters are going to be used in
3733     file paths is minimal since it's Python source code.  Moreover, this crash was
3734     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3735     """
3736     try:
3737         from click import core
3738         from click import _unicodefun  # type: ignore
3739     except ModuleNotFoundError:
3740         return
3741
3742     for module in (core, _unicodefun):
3743         if hasattr(module, "_verify_python3_env"):
3744             module._verify_python3_env = lambda: None
3745
3746
3747 def patched_main() -> None:
3748     freeze_support()
3749     patch_click()
3750     main()
3751
3752
3753 if __name__ == "__main__":
3754     patched_main()