]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Wrap `loop.run_in_executor` up in `asyncio.ensure_future` for reliable cross-platform...
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "19.3b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA_IN_CALL = 4
137     TRAILING_COMMA_IN_DEF = 5
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.PY27: set(),
142     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
144     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
145     TargetVersion.PY36: {
146         Feature.UNICODE_LITERALS,
147         Feature.F_STRINGS,
148         Feature.NUMERIC_UNDERSCORES,
149         Feature.TRAILING_COMMA_IN_CALL,
150         Feature.TRAILING_COMMA_IN_DEF,
151     },
152     TargetVersion.PY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA_IN_CALL,
157         Feature.TRAILING_COMMA_IN_DEF,
158     },
159     TargetVersion.PY38: {
160         Feature.UNICODE_LITERALS,
161         Feature.F_STRINGS,
162         Feature.NUMERIC_UNDERSCORES,
163         Feature.TRAILING_COMMA_IN_CALL,
164         Feature.TRAILING_COMMA_IN_DEF,
165     },
166 }
167
168
169 @dataclass
170 class FileMode:
171     target_versions: Set[TargetVersion] = Factory(set)
172     line_length: int = DEFAULT_LINE_LENGTH
173     string_normalization: bool = True
174     is_pyi: bool = False
175
176     def get_cache_key(self) -> str:
177         if self.target_versions:
178             version_str = ",".join(
179                 str(version.value)
180                 for version in sorted(self.target_versions, key=lambda v: v.value)
181             )
182         else:
183             version_str = "-"
184         parts = [
185             version_str,
186             str(self.line_length),
187             str(int(self.string_normalization)),
188             str(int(self.is_pyi)),
189         ]
190         return ".".join(parts)
191
192
193 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
194     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
195
196
197 def read_pyproject_toml(
198     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
199 ) -> Optional[str]:
200     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
201
202     Returns the path to a successfully found and read configuration file, None
203     otherwise.
204     """
205     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
206     if not value:
207         root = find_project_root(ctx.params.get("src", ()))
208         path = root / "pyproject.toml"
209         if path.is_file():
210             value = str(path)
211         else:
212             return None
213
214     try:
215         pyproject_toml = toml.load(value)
216         config = pyproject_toml.get("tool", {}).get("black", {})
217     except (toml.TomlDecodeError, OSError) as e:
218         raise click.FileError(
219             filename=value, hint=f"Error reading configuration file: {e}"
220         )
221
222     if not config:
223         return None
224
225     if ctx.default_map is None:
226         ctx.default_map = {}
227     ctx.default_map.update(  # type: ignore  # bad types in .pyi
228         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
229     )
230     return value
231
232
233 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
234 @click.option(
235     "-l",
236     "--line-length",
237     type=int,
238     default=DEFAULT_LINE_LENGTH,
239     help="How many characters per line to allow.",
240     show_default=True,
241 )
242 @click.option(
243     "-t",
244     "--target-version",
245     type=click.Choice([v.name.lower() for v in TargetVersion]),
246     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
247     multiple=True,
248     help=(
249         "Python versions that should be supported by Black's output. [default: "
250         "per-file auto-detection]"
251     ),
252 )
253 @click.option(
254     "--py36",
255     is_flag=True,
256     help=(
257         "Allow using Python 3.6-only syntax on all input files.  This will put "
258         "trailing commas in function signatures and calls also after *args and "
259         "**kwargs. Deprecated; use --target-version instead. "
260         "[default: per-file auto-detection]"
261     ),
262 )
263 @click.option(
264     "--pyi",
265     is_flag=True,
266     help=(
267         "Format all input files like typing stubs regardless of file extension "
268         "(useful when piping source on standard input)."
269     ),
270 )
271 @click.option(
272     "-S",
273     "--skip-string-normalization",
274     is_flag=True,
275     help="Don't normalize string quotes or prefixes.",
276 )
277 @click.option(
278     "--check",
279     is_flag=True,
280     help=(
281         "Don't write the files back, just return the status.  Return code 0 "
282         "means nothing would change.  Return code 1 means some files would be "
283         "reformatted.  Return code 123 means there was an internal error."
284     ),
285 )
286 @click.option(
287     "--diff",
288     is_flag=True,
289     help="Don't write the files back, just output a diff for each file on stdout.",
290 )
291 @click.option(
292     "--fast/--safe",
293     is_flag=True,
294     help="If --fast given, skip temporary sanity checks. [default: --safe]",
295 )
296 @click.option(
297     "--include",
298     type=str,
299     default=DEFAULT_INCLUDES,
300     help=(
301         "A regular expression that matches files and directories that should be "
302         "included on recursive searches.  An empty value means all files are "
303         "included regardless of the name.  Use forward slashes for directories on "
304         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
305         "later."
306     ),
307     show_default=True,
308 )
309 @click.option(
310     "--exclude",
311     type=str,
312     default=DEFAULT_EXCLUDES,
313     help=(
314         "A regular expression that matches files and directories that should be "
315         "excluded on recursive searches.  An empty value means no paths are excluded. "
316         "Use forward slashes for directories on all platforms (Windows, too).  "
317         "Exclusions are calculated first, inclusions later."
318     ),
319     show_default=True,
320 )
321 @click.option(
322     "-q",
323     "--quiet",
324     is_flag=True,
325     help=(
326         "Don't emit non-error messages to stderr. Errors are still emitted, "
327         "silence those with 2>/dev/null."
328     ),
329 )
330 @click.option(
331     "-v",
332     "--verbose",
333     is_flag=True,
334     help=(
335         "Also emit messages to stderr about files that were not changed or were "
336         "ignored due to --exclude=."
337     ),
338 )
339 @click.version_option(version=__version__)
340 @click.argument(
341     "src",
342     nargs=-1,
343     type=click.Path(
344         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
345     ),
346     is_eager=True,
347 )
348 @click.option(
349     "--config",
350     type=click.Path(
351         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
352     ),
353     is_eager=True,
354     callback=read_pyproject_toml,
355     help="Read configuration from PATH.",
356 )
357 @click.pass_context
358 def main(
359     ctx: click.Context,
360     line_length: int,
361     target_version: List[TargetVersion],
362     check: bool,
363     diff: bool,
364     fast: bool,
365     pyi: bool,
366     py36: bool,
367     skip_string_normalization: bool,
368     quiet: bool,
369     verbose: bool,
370     include: str,
371     exclude: str,
372     src: Tuple[str],
373     config: Optional[str],
374 ) -> None:
375     """The uncompromising code formatter."""
376     write_back = WriteBack.from_configuration(check=check, diff=diff)
377     if target_version:
378         if py36:
379             err(f"Cannot use both --target-version and --py36")
380             ctx.exit(2)
381         else:
382             versions = set(target_version)
383     elif py36:
384         err(
385             "--py36 is deprecated and will be removed in a future version. "
386             "Use --target-version py36 instead."
387         )
388         versions = PY36_VERSIONS
389     else:
390         # We'll autodetect later.
391         versions = set()
392     mode = FileMode(
393         target_versions=versions,
394         line_length=line_length,
395         is_pyi=pyi,
396         string_normalization=not skip_string_normalization,
397     )
398     if config and verbose:
399         out(f"Using configuration from {config}.", bold=False, fg="blue")
400     try:
401         include_regex = re_compile_maybe_verbose(include)
402     except re.error:
403         err(f"Invalid regular expression for include given: {include!r}")
404         ctx.exit(2)
405     try:
406         exclude_regex = re_compile_maybe_verbose(exclude)
407     except re.error:
408         err(f"Invalid regular expression for exclude given: {exclude!r}")
409         ctx.exit(2)
410     report = Report(check=check, quiet=quiet, verbose=verbose)
411     root = find_project_root(src)
412     sources: Set[Path] = set()
413     for s in src:
414         p = Path(s)
415         if p.is_dir():
416             sources.update(
417                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
418             )
419         elif p.is_file() or s == "-":
420             # if a file was explicitly given, we don't care about its extension
421             sources.add(p)
422         else:
423             err(f"invalid path: {s}")
424     if len(sources) == 0:
425         if verbose or not quiet:
426             out("No paths given. Nothing to do 😴")
427         ctx.exit(0)
428
429     if len(sources) == 1:
430         reformat_one(
431             src=sources.pop(),
432             fast=fast,
433             write_back=write_back,
434             mode=mode,
435             report=report,
436         )
437     else:
438         loop = asyncio.get_event_loop()
439         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
440         try:
441             loop.run_until_complete(
442                 schedule_formatting(
443                     sources=sources,
444                     fast=fast,
445                     write_back=write_back,
446                     mode=mode,
447                     report=report,
448                     loop=loop,
449                     executor=executor,
450                 )
451             )
452         finally:
453             shutdown(loop)
454     if verbose or not quiet:
455         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
456         out(f"All done! {bang}")
457         click.secho(str(report), err=True)
458     ctx.exit(report.return_code)
459
460
461 def reformat_one(
462     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
463 ) -> None:
464     """Reformat a single file under `src` without spawning child processes.
465
466     If `quiet` is True, non-error messages are not output. `line_length`,
467     `write_back`, `fast` and `pyi` options are passed to
468     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
469     """
470     try:
471         changed = Changed.NO
472         if not src.is_file() and str(src) == "-":
473             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
474                 changed = Changed.YES
475         else:
476             cache: Cache = {}
477             if write_back != WriteBack.DIFF:
478                 cache = read_cache(mode)
479                 res_src = src.resolve()
480                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
481                     changed = Changed.CACHED
482             if changed is not Changed.CACHED and format_file_in_place(
483                 src, fast=fast, write_back=write_back, mode=mode
484             ):
485                 changed = Changed.YES
486             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
487                 write_back is WriteBack.CHECK and changed is Changed.NO
488             ):
489                 write_cache(cache, [src], mode)
490         report.done(src, changed)
491     except Exception as exc:
492         report.failed(src, str(exc))
493
494
495 async def schedule_formatting(
496     sources: Set[Path],
497     fast: bool,
498     write_back: WriteBack,
499     mode: FileMode,
500     report: "Report",
501     loop: BaseEventLoop,
502     executor: Executor,
503 ) -> None:
504     """Run formatting of `sources` in parallel using the provided `executor`.
505
506     (Use ProcessPoolExecutors for actual parallelism.)
507
508     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
509     :func:`format_file_in_place`.
510     """
511     cache: Cache = {}
512     if write_back != WriteBack.DIFF:
513         cache = read_cache(mode)
514         sources, cached = filter_cached(cache, sources)
515         for src in sorted(cached):
516             report.done(src, Changed.CACHED)
517     if not sources:
518         return
519
520     cancelled = []
521     sources_to_cache = []
522     lock = None
523     if write_back == WriteBack.DIFF:
524         # For diff output, we need locks to ensure we don't interleave output
525         # from different processes.
526         manager = Manager()
527         lock = manager.Lock()
528     tasks = {
529         asyncio.ensure_future(
530             loop.run_in_executor(
531                 executor, format_file_in_place, src, fast, mode, write_back, lock
532             )
533         ): src
534         for src in sorted(sources)
535     }
536     pending: Iterable[asyncio.Future] = tasks.keys()
537     try:
538         loop.add_signal_handler(signal.SIGINT, cancel, pending)
539         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
540     except NotImplementedError:
541         # There are no good alternatives for these on Windows.
542         pass
543     while pending:
544         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
545         for task in done:
546             src = tasks.pop(task)
547             if task.cancelled():
548                 cancelled.append(task)
549             elif task.exception():
550                 report.failed(src, str(task.exception()))
551             else:
552                 changed = Changed.YES if task.result() else Changed.NO
553                 # If the file was written back or was successfully checked as
554                 # well-formatted, store this information in the cache.
555                 if write_back is WriteBack.YES or (
556                     write_back is WriteBack.CHECK and changed is Changed.NO
557                 ):
558                     sources_to_cache.append(src)
559                 report.done(src, changed)
560     if cancelled:
561         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
562     if sources_to_cache:
563         write_cache(cache, sources_to_cache, mode)
564
565
566 def format_file_in_place(
567     src: Path,
568     fast: bool,
569     mode: FileMode,
570     write_back: WriteBack = WriteBack.NO,
571     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
572 ) -> bool:
573     """Format file under `src` path. Return True if changed.
574
575     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
576     code to the file.
577     `line_length` and `fast` options are passed to :func:`format_file_contents`.
578     """
579     if src.suffix == ".pyi":
580         mode = evolve(mode, is_pyi=True)
581
582     then = datetime.utcfromtimestamp(src.stat().st_mtime)
583     with open(src, "rb") as buf:
584         src_contents, encoding, newline = decode_bytes(buf.read())
585     try:
586         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
587     except NothingChanged:
588         return False
589
590     if write_back == write_back.YES:
591         with open(src, "w", encoding=encoding, newline=newline) as f:
592             f.write(dst_contents)
593     elif write_back == write_back.DIFF:
594         now = datetime.utcnow()
595         src_name = f"{src}\t{then} +0000"
596         dst_name = f"{src}\t{now} +0000"
597         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
598         if lock:
599             lock.acquire()
600         try:
601             f = io.TextIOWrapper(
602                 sys.stdout.buffer,
603                 encoding=encoding,
604                 newline=newline,
605                 write_through=True,
606             )
607             f.write(diff_contents)
608             f.detach()
609         finally:
610             if lock:
611                 lock.release()
612     return True
613
614
615 def format_stdin_to_stdout(
616     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
617 ) -> bool:
618     """Format file on stdin. Return True if changed.
619
620     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
621     write a diff to stdout. The `mode` argument is passed to
622     :func:`format_file_contents`.
623     """
624     then = datetime.utcnow()
625     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
626     dst = src
627     try:
628         dst = format_file_contents(src, fast=fast, mode=mode)
629         return True
630
631     except NothingChanged:
632         return False
633
634     finally:
635         f = io.TextIOWrapper(
636             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
637         )
638         if write_back == WriteBack.YES:
639             f.write(dst)
640         elif write_back == WriteBack.DIFF:
641             now = datetime.utcnow()
642             src_name = f"STDIN\t{then} +0000"
643             dst_name = f"STDOUT\t{now} +0000"
644             f.write(diff(src, dst, src_name, dst_name))
645         f.detach()
646
647
648 def format_file_contents(
649     src_contents: str, *, fast: bool, mode: FileMode
650 ) -> FileContent:
651     """Reformat contents a file and return new contents.
652
653     If `fast` is False, additionally confirm that the reformatted code is
654     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
655     `line_length` is passed to :func:`format_str`.
656     """
657     if src_contents.strip() == "":
658         raise NothingChanged
659
660     dst_contents = format_str(src_contents, mode=mode)
661     if src_contents == dst_contents:
662         raise NothingChanged
663
664     if not fast:
665         assert_equivalent(src_contents, dst_contents)
666         assert_stable(src_contents, dst_contents, mode=mode)
667     return dst_contents
668
669
670 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
671     """Reformat a string and return new contents.
672
673     `line_length` determines how many characters per line are allowed.
674     """
675     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
676     dst_contents = ""
677     future_imports = get_future_imports(src_node)
678     if mode.target_versions:
679         versions = mode.target_versions
680     else:
681         versions = detect_target_versions(src_node)
682     normalize_fmt_off(src_node)
683     lines = LineGenerator(
684         remove_u_prefix="unicode_literals" in future_imports
685         or supports_feature(versions, Feature.UNICODE_LITERALS),
686         is_pyi=mode.is_pyi,
687         normalize_strings=mode.string_normalization,
688     )
689     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
690     empty_line = Line()
691     after = 0
692     split_line_features = {
693         feature
694         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
695         if supports_feature(versions, feature)
696     }
697     for current_line in lines.visit(src_node):
698         for _ in range(after):
699             dst_contents += str(empty_line)
700         before, after = elt.maybe_empty_lines(current_line)
701         for _ in range(before):
702             dst_contents += str(empty_line)
703         for line in split_line(
704             current_line, line_length=mode.line_length, features=split_line_features
705         ):
706             dst_contents += str(line)
707     return dst_contents
708
709
710 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
711     """Return a tuple of (decoded_contents, encoding, newline).
712
713     `newline` is either CRLF or LF but `decoded_contents` is decoded with
714     universal newlines (i.e. only contains LF).
715     """
716     srcbuf = io.BytesIO(src)
717     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
718     if not lines:
719         return "", encoding, "\n"
720
721     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
722     srcbuf.seek(0)
723     with io.TextIOWrapper(srcbuf, encoding) as tiow:
724         return tiow.read(), encoding, newline
725
726
727 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
728     if not target_versions:
729         # No target_version specified, so try all grammars.
730         return [
731             pygram.python_grammar_no_print_statement_no_exec_statement,
732             pygram.python_grammar_no_print_statement,
733             pygram.python_grammar,
734         ]
735     elif all(version.is_python2() for version in target_versions):
736         # Python 2-only code, so try Python 2 grammars.
737         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
738     else:
739         # Python 3-compatible code, so only try Python 3 grammar.
740         return [pygram.python_grammar_no_print_statement_no_exec_statement]
741
742
743 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
744     """Given a string with source, return the lib2to3 Node."""
745     if src_txt[-1:] != "\n":
746         src_txt += "\n"
747
748     for grammar in get_grammars(set(target_versions)):
749         drv = driver.Driver(grammar, pytree.convert)
750         try:
751             result = drv.parse_string(src_txt, True)
752             break
753
754         except ParseError as pe:
755             lineno, column = pe.context[1]
756             lines = src_txt.splitlines()
757             try:
758                 faulty_line = lines[lineno - 1]
759             except IndexError:
760                 faulty_line = "<line number missing in source>"
761             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
762     else:
763         raise exc from None
764
765     if isinstance(result, Leaf):
766         result = Node(syms.file_input, [result])
767     return result
768
769
770 def lib2to3_unparse(node: Node) -> str:
771     """Given a lib2to3 node, return its string representation."""
772     code = str(node)
773     return code
774
775
776 T = TypeVar("T")
777
778
779 class Visitor(Generic[T]):
780     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
781
782     def visit(self, node: LN) -> Iterator[T]:
783         """Main method to visit `node` and its children.
784
785         It tries to find a `visit_*()` method for the given `node.type`, like
786         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
787         If no dedicated `visit_*()` method is found, chooses `visit_default()`
788         instead.
789
790         Then yields objects of type `T` from the selected visitor.
791         """
792         if node.type < 256:
793             name = token.tok_name[node.type]
794         else:
795             name = type_repr(node.type)
796         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
797
798     def visit_default(self, node: LN) -> Iterator[T]:
799         """Default `visit_*()` implementation. Recurses to children of `node`."""
800         if isinstance(node, Node):
801             for child in node.children:
802                 yield from self.visit(child)
803
804
805 @dataclass
806 class DebugVisitor(Visitor[T]):
807     tree_depth: int = 0
808
809     def visit_default(self, node: LN) -> Iterator[T]:
810         indent = " " * (2 * self.tree_depth)
811         if isinstance(node, Node):
812             _type = type_repr(node.type)
813             out(f"{indent}{_type}", fg="yellow")
814             self.tree_depth += 1
815             for child in node.children:
816                 yield from self.visit(child)
817
818             self.tree_depth -= 1
819             out(f"{indent}/{_type}", fg="yellow", bold=False)
820         else:
821             _type = token.tok_name.get(node.type, str(node.type))
822             out(f"{indent}{_type}", fg="blue", nl=False)
823             if node.prefix:
824                 # We don't have to handle prefixes for `Node` objects since
825                 # that delegates to the first child anyway.
826                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
827             out(f" {node.value!r}", fg="blue", bold=False)
828
829     @classmethod
830     def show(cls, code: Union[str, Leaf, Node]) -> None:
831         """Pretty-print the lib2to3 AST of a given string of `code`.
832
833         Convenience method for debugging.
834         """
835         v: DebugVisitor[None] = DebugVisitor()
836         if isinstance(code, str):
837             code = lib2to3_parse(code)
838         list(v.visit(code))
839
840
841 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
842 STATEMENT = {
843     syms.if_stmt,
844     syms.while_stmt,
845     syms.for_stmt,
846     syms.try_stmt,
847     syms.except_clause,
848     syms.with_stmt,
849     syms.funcdef,
850     syms.classdef,
851 }
852 STANDALONE_COMMENT = 153
853 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
854 LOGIC_OPERATORS = {"and", "or"}
855 COMPARATORS = {
856     token.LESS,
857     token.GREATER,
858     token.EQEQUAL,
859     token.NOTEQUAL,
860     token.LESSEQUAL,
861     token.GREATEREQUAL,
862 }
863 MATH_OPERATORS = {
864     token.VBAR,
865     token.CIRCUMFLEX,
866     token.AMPER,
867     token.LEFTSHIFT,
868     token.RIGHTSHIFT,
869     token.PLUS,
870     token.MINUS,
871     token.STAR,
872     token.SLASH,
873     token.DOUBLESLASH,
874     token.PERCENT,
875     token.AT,
876     token.TILDE,
877     token.DOUBLESTAR,
878 }
879 STARS = {token.STAR, token.DOUBLESTAR}
880 VARARGS_PARENTS = {
881     syms.arglist,
882     syms.argument,  # double star in arglist
883     syms.trailer,  # single argument to call
884     syms.typedargslist,
885     syms.varargslist,  # lambdas
886 }
887 UNPACKING_PARENTS = {
888     syms.atom,  # single element of a list or set literal
889     syms.dictsetmaker,
890     syms.listmaker,
891     syms.testlist_gexp,
892     syms.testlist_star_expr,
893 }
894 TEST_DESCENDANTS = {
895     syms.test,
896     syms.lambdef,
897     syms.or_test,
898     syms.and_test,
899     syms.not_test,
900     syms.comparison,
901     syms.star_expr,
902     syms.expr,
903     syms.xor_expr,
904     syms.and_expr,
905     syms.shift_expr,
906     syms.arith_expr,
907     syms.trailer,
908     syms.term,
909     syms.power,
910 }
911 ASSIGNMENTS = {
912     "=",
913     "+=",
914     "-=",
915     "*=",
916     "@=",
917     "/=",
918     "%=",
919     "&=",
920     "|=",
921     "^=",
922     "<<=",
923     ">>=",
924     "**=",
925     "//=",
926 }
927 COMPREHENSION_PRIORITY = 20
928 COMMA_PRIORITY = 18
929 TERNARY_PRIORITY = 16
930 LOGIC_PRIORITY = 14
931 STRING_PRIORITY = 12
932 COMPARATOR_PRIORITY = 10
933 MATH_PRIORITIES = {
934     token.VBAR: 9,
935     token.CIRCUMFLEX: 8,
936     token.AMPER: 7,
937     token.LEFTSHIFT: 6,
938     token.RIGHTSHIFT: 6,
939     token.PLUS: 5,
940     token.MINUS: 5,
941     token.STAR: 4,
942     token.SLASH: 4,
943     token.DOUBLESLASH: 4,
944     token.PERCENT: 4,
945     token.AT: 4,
946     token.TILDE: 3,
947     token.DOUBLESTAR: 2,
948 }
949 DOT_PRIORITY = 1
950
951
952 @dataclass
953 class BracketTracker:
954     """Keeps track of brackets on a line."""
955
956     depth: int = 0
957     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
958     delimiters: Dict[LeafID, Priority] = Factory(dict)
959     previous: Optional[Leaf] = None
960     _for_loop_depths: List[int] = Factory(list)
961     _lambda_argument_depths: List[int] = Factory(list)
962
963     def mark(self, leaf: Leaf) -> None:
964         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
965
966         All leaves receive an int `bracket_depth` field that stores how deep
967         within brackets a given leaf is. 0 means there are no enclosing brackets
968         that started on this line.
969
970         If a leaf is itself a closing bracket, it receives an `opening_bracket`
971         field that it forms a pair with. This is a one-directional link to
972         avoid reference cycles.
973
974         If a leaf is a delimiter (a token on which Black can split the line if
975         needed) and it's on depth 0, its `id()` is stored in the tracker's
976         `delimiters` field.
977         """
978         if leaf.type == token.COMMENT:
979             return
980
981         self.maybe_decrement_after_for_loop_variable(leaf)
982         self.maybe_decrement_after_lambda_arguments(leaf)
983         if leaf.type in CLOSING_BRACKETS:
984             self.depth -= 1
985             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
986             leaf.opening_bracket = opening_bracket
987         leaf.bracket_depth = self.depth
988         if self.depth == 0:
989             delim = is_split_before_delimiter(leaf, self.previous)
990             if delim and self.previous is not None:
991                 self.delimiters[id(self.previous)] = delim
992             else:
993                 delim = is_split_after_delimiter(leaf, self.previous)
994                 if delim:
995                     self.delimiters[id(leaf)] = delim
996         if leaf.type in OPENING_BRACKETS:
997             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
998             self.depth += 1
999         self.previous = leaf
1000         self.maybe_increment_lambda_arguments(leaf)
1001         self.maybe_increment_for_loop_variable(leaf)
1002
1003     def any_open_brackets(self) -> bool:
1004         """Return True if there is an yet unmatched open bracket on the line."""
1005         return bool(self.bracket_match)
1006
1007     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1008         """Return the highest priority of a delimiter found on the line.
1009
1010         Values are consistent with what `is_split_*_delimiter()` return.
1011         Raises ValueError on no delimiters.
1012         """
1013         return max(v for k, v in self.delimiters.items() if k not in exclude)
1014
1015     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1016         """Return the number of delimiters with the given `priority`.
1017
1018         If no `priority` is passed, defaults to max priority on the line.
1019         """
1020         if not self.delimiters:
1021             return 0
1022
1023         priority = priority or self.max_delimiter_priority()
1024         return sum(1 for p in self.delimiters.values() if p == priority)
1025
1026     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1027         """In a for loop, or comprehension, the variables are often unpacks.
1028
1029         To avoid splitting on the comma in this situation, increase the depth of
1030         tokens between `for` and `in`.
1031         """
1032         if leaf.type == token.NAME and leaf.value == "for":
1033             self.depth += 1
1034             self._for_loop_depths.append(self.depth)
1035             return True
1036
1037         return False
1038
1039     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1040         """See `maybe_increment_for_loop_variable` above for explanation."""
1041         if (
1042             self._for_loop_depths
1043             and self._for_loop_depths[-1] == self.depth
1044             and leaf.type == token.NAME
1045             and leaf.value == "in"
1046         ):
1047             self.depth -= 1
1048             self._for_loop_depths.pop()
1049             return True
1050
1051         return False
1052
1053     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1054         """In a lambda expression, there might be more than one argument.
1055
1056         To avoid splitting on the comma in this situation, increase the depth of
1057         tokens between `lambda` and `:`.
1058         """
1059         if leaf.type == token.NAME and leaf.value == "lambda":
1060             self.depth += 1
1061             self._lambda_argument_depths.append(self.depth)
1062             return True
1063
1064         return False
1065
1066     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1067         """See `maybe_increment_lambda_arguments` above for explanation."""
1068         if (
1069             self._lambda_argument_depths
1070             and self._lambda_argument_depths[-1] == self.depth
1071             and leaf.type == token.COLON
1072         ):
1073             self.depth -= 1
1074             self._lambda_argument_depths.pop()
1075             return True
1076
1077         return False
1078
1079     def get_open_lsqb(self) -> Optional[Leaf]:
1080         """Return the most recent opening square bracket (if any)."""
1081         return self.bracket_match.get((self.depth - 1, token.RSQB))
1082
1083
1084 @dataclass
1085 class Line:
1086     """Holds leaves and comments. Can be printed with `str(line)`."""
1087
1088     depth: int = 0
1089     leaves: List[Leaf] = Factory(list)
1090     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1091     bracket_tracker: BracketTracker = Factory(BracketTracker)
1092     inside_brackets: bool = False
1093     should_explode: bool = False
1094
1095     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1096         """Add a new `leaf` to the end of the line.
1097
1098         Unless `preformatted` is True, the `leaf` will receive a new consistent
1099         whitespace prefix and metadata applied by :class:`BracketTracker`.
1100         Trailing commas are maybe removed, unpacked for loop variables are
1101         demoted from being delimiters.
1102
1103         Inline comments are put aside.
1104         """
1105         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1106         if not has_value:
1107             return
1108
1109         if token.COLON == leaf.type and self.is_class_paren_empty:
1110             del self.leaves[-2:]
1111         if self.leaves and not preformatted:
1112             # Note: at this point leaf.prefix should be empty except for
1113             # imports, for which we only preserve newlines.
1114             leaf.prefix += whitespace(
1115                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1116             )
1117         if self.inside_brackets or not preformatted:
1118             self.bracket_tracker.mark(leaf)
1119             self.maybe_remove_trailing_comma(leaf)
1120         if not self.append_comment(leaf):
1121             self.leaves.append(leaf)
1122
1123     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1124         """Like :func:`append()` but disallow invalid standalone comment structure.
1125
1126         Raises ValueError when any `leaf` is appended after a standalone comment
1127         or when a standalone comment is not the first leaf on the line.
1128         """
1129         if self.bracket_tracker.depth == 0:
1130             if self.is_comment:
1131                 raise ValueError("cannot append to standalone comments")
1132
1133             if self.leaves and leaf.type == STANDALONE_COMMENT:
1134                 raise ValueError(
1135                     "cannot append standalone comments to a populated line"
1136                 )
1137
1138         self.append(leaf, preformatted=preformatted)
1139
1140     @property
1141     def is_comment(self) -> bool:
1142         """Is this line a standalone comment?"""
1143         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1144
1145     @property
1146     def is_decorator(self) -> bool:
1147         """Is this line a decorator?"""
1148         return bool(self) and self.leaves[0].type == token.AT
1149
1150     @property
1151     def is_import(self) -> bool:
1152         """Is this an import line?"""
1153         return bool(self) and is_import(self.leaves[0])
1154
1155     @property
1156     def is_class(self) -> bool:
1157         """Is this line a class definition?"""
1158         return (
1159             bool(self)
1160             and self.leaves[0].type == token.NAME
1161             and self.leaves[0].value == "class"
1162         )
1163
1164     @property
1165     def is_stub_class(self) -> bool:
1166         """Is this line a class definition with a body consisting only of "..."?"""
1167         return self.is_class and self.leaves[-3:] == [
1168             Leaf(token.DOT, ".") for _ in range(3)
1169         ]
1170
1171     @property
1172     def is_def(self) -> bool:
1173         """Is this a function definition? (Also returns True for async defs.)"""
1174         try:
1175             first_leaf = self.leaves[0]
1176         except IndexError:
1177             return False
1178
1179         try:
1180             second_leaf: Optional[Leaf] = self.leaves[1]
1181         except IndexError:
1182             second_leaf = None
1183         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1184             first_leaf.type == token.ASYNC
1185             and second_leaf is not None
1186             and second_leaf.type == token.NAME
1187             and second_leaf.value == "def"
1188         )
1189
1190     @property
1191     def is_class_paren_empty(self) -> bool:
1192         """Is this a class with no base classes but using parentheses?
1193
1194         Those are unnecessary and should be removed.
1195         """
1196         return (
1197             bool(self)
1198             and len(self.leaves) == 4
1199             and self.is_class
1200             and self.leaves[2].type == token.LPAR
1201             and self.leaves[2].value == "("
1202             and self.leaves[3].type == token.RPAR
1203             and self.leaves[3].value == ")"
1204         )
1205
1206     @property
1207     def is_triple_quoted_string(self) -> bool:
1208         """Is the line a triple quoted string?"""
1209         return (
1210             bool(self)
1211             and self.leaves[0].type == token.STRING
1212             and self.leaves[0].value.startswith(('"""', "'''"))
1213         )
1214
1215     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1216         """If so, needs to be split before emitting."""
1217         for leaf in self.leaves:
1218             if leaf.type == STANDALONE_COMMENT:
1219                 if leaf.bracket_depth <= depth_limit:
1220                     return True
1221         return False
1222
1223     def contains_inner_type_comments(self) -> bool:
1224         ignored_ids = set()
1225         try:
1226             last_leaf = self.leaves[-1]
1227             ignored_ids.add(id(last_leaf))
1228             if last_leaf.type == token.COMMA:
1229                 # When trailing commas are inserted by Black for consistency, comments
1230                 # after the previous last element are not moved (they don't have to,
1231                 # rendering will still be correct).  So we ignore trailing commas.
1232                 last_leaf = self.leaves[-2]
1233                 ignored_ids.add(id(last_leaf))
1234         except IndexError:
1235             return False
1236
1237         for leaf_id, comments in self.comments.items():
1238             if leaf_id in ignored_ids:
1239                 continue
1240
1241             for comment in comments:
1242                 if is_type_comment(comment):
1243                     return True
1244
1245         return False
1246
1247     def contains_multiline_strings(self) -> bool:
1248         for leaf in self.leaves:
1249             if is_multiline_string(leaf):
1250                 return True
1251
1252         return False
1253
1254     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1255         """Remove trailing comma if there is one and it's safe."""
1256         if not (
1257             self.leaves
1258             and self.leaves[-1].type == token.COMMA
1259             and closing.type in CLOSING_BRACKETS
1260         ):
1261             return False
1262
1263         if closing.type == token.RBRACE:
1264             self.remove_trailing_comma()
1265             return True
1266
1267         if closing.type == token.RSQB:
1268             comma = self.leaves[-1]
1269             if comma.parent and comma.parent.type == syms.listmaker:
1270                 self.remove_trailing_comma()
1271                 return True
1272
1273         # For parens let's check if it's safe to remove the comma.
1274         # Imports are always safe.
1275         if self.is_import:
1276             self.remove_trailing_comma()
1277             return True
1278
1279         # Otherwise, if the trailing one is the only one, we might mistakenly
1280         # change a tuple into a different type by removing the comma.
1281         depth = closing.bracket_depth + 1
1282         commas = 0
1283         opening = closing.opening_bracket
1284         for _opening_index, leaf in enumerate(self.leaves):
1285             if leaf is opening:
1286                 break
1287
1288         else:
1289             return False
1290
1291         for leaf in self.leaves[_opening_index + 1 :]:
1292             if leaf is closing:
1293                 break
1294
1295             bracket_depth = leaf.bracket_depth
1296             if bracket_depth == depth and leaf.type == token.COMMA:
1297                 commas += 1
1298                 if leaf.parent and leaf.parent.type == syms.arglist:
1299                     commas += 1
1300                     break
1301
1302         if commas > 1:
1303             self.remove_trailing_comma()
1304             return True
1305
1306         return False
1307
1308     def append_comment(self, comment: Leaf) -> bool:
1309         """Add an inline or standalone comment to the line."""
1310         if (
1311             comment.type == STANDALONE_COMMENT
1312             and self.bracket_tracker.any_open_brackets()
1313         ):
1314             comment.prefix = ""
1315             return False
1316
1317         if comment.type != token.COMMENT:
1318             return False
1319
1320         if not self.leaves:
1321             comment.type = STANDALONE_COMMENT
1322             comment.prefix = ""
1323             return False
1324
1325         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1326         return True
1327
1328     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1329         """Generate comments that should appear directly after `leaf`."""
1330         return self.comments.get(id(leaf), [])
1331
1332     def remove_trailing_comma(self) -> None:
1333         """Remove the trailing comma and moves the comments attached to it."""
1334         trailing_comma = self.leaves.pop()
1335         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1336         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1337             trailing_comma_comments
1338         )
1339
1340     def is_complex_subscript(self, leaf: Leaf) -> bool:
1341         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1342         open_lsqb = self.bracket_tracker.get_open_lsqb()
1343         if open_lsqb is None:
1344             return False
1345
1346         subscript_start = open_lsqb.next_sibling
1347
1348         if isinstance(subscript_start, Node):
1349             if subscript_start.type == syms.listmaker:
1350                 return False
1351
1352             if subscript_start.type == syms.subscriptlist:
1353                 subscript_start = child_towards(subscript_start, leaf)
1354         return subscript_start is not None and any(
1355             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1356         )
1357
1358     def __str__(self) -> str:
1359         """Render the line."""
1360         if not self:
1361             return "\n"
1362
1363         indent = "    " * self.depth
1364         leaves = iter(self.leaves)
1365         first = next(leaves)
1366         res = f"{first.prefix}{indent}{first.value}"
1367         for leaf in leaves:
1368             res += str(leaf)
1369         for comment in itertools.chain.from_iterable(self.comments.values()):
1370             res += str(comment)
1371         return res + "\n"
1372
1373     def __bool__(self) -> bool:
1374         """Return True if the line has leaves or comments."""
1375         return bool(self.leaves or self.comments)
1376
1377
1378 @dataclass
1379 class EmptyLineTracker:
1380     """Provides a stateful method that returns the number of potential extra
1381     empty lines needed before and after the currently processed line.
1382
1383     Note: this tracker works on lines that haven't been split yet.  It assumes
1384     the prefix of the first leaf consists of optional newlines.  Those newlines
1385     are consumed by `maybe_empty_lines()` and included in the computation.
1386     """
1387
1388     is_pyi: bool = False
1389     previous_line: Optional[Line] = None
1390     previous_after: int = 0
1391     previous_defs: List[int] = Factory(list)
1392
1393     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1394         """Return the number of extra empty lines before and after the `current_line`.
1395
1396         This is for separating `def`, `async def` and `class` with extra empty
1397         lines (two on module-level).
1398         """
1399         before, after = self._maybe_empty_lines(current_line)
1400         before -= self.previous_after
1401         self.previous_after = after
1402         self.previous_line = current_line
1403         return before, after
1404
1405     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1406         max_allowed = 1
1407         if current_line.depth == 0:
1408             max_allowed = 1 if self.is_pyi else 2
1409         if current_line.leaves:
1410             # Consume the first leaf's extra newlines.
1411             first_leaf = current_line.leaves[0]
1412             before = first_leaf.prefix.count("\n")
1413             before = min(before, max_allowed)
1414             first_leaf.prefix = ""
1415         else:
1416             before = 0
1417         depth = current_line.depth
1418         while self.previous_defs and self.previous_defs[-1] >= depth:
1419             self.previous_defs.pop()
1420             if self.is_pyi:
1421                 before = 0 if depth else 1
1422             else:
1423                 before = 1 if depth else 2
1424         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1425             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1426
1427         if (
1428             self.previous_line
1429             and self.previous_line.is_import
1430             and not current_line.is_import
1431             and depth == self.previous_line.depth
1432         ):
1433             return (before or 1), 0
1434
1435         if (
1436             self.previous_line
1437             and self.previous_line.is_class
1438             and current_line.is_triple_quoted_string
1439         ):
1440             return before, 1
1441
1442         return before, 0
1443
1444     def _maybe_empty_lines_for_class_or_def(
1445         self, current_line: Line, before: int
1446     ) -> Tuple[int, int]:
1447         if not current_line.is_decorator:
1448             self.previous_defs.append(current_line.depth)
1449         if self.previous_line is None:
1450             # Don't insert empty lines before the first line in the file.
1451             return 0, 0
1452
1453         if self.previous_line.is_decorator:
1454             return 0, 0
1455
1456         if self.previous_line.depth < current_line.depth and (
1457             self.previous_line.is_class or self.previous_line.is_def
1458         ):
1459             return 0, 0
1460
1461         if (
1462             self.previous_line.is_comment
1463             and self.previous_line.depth == current_line.depth
1464             and before == 0
1465         ):
1466             return 0, 0
1467
1468         if self.is_pyi:
1469             if self.previous_line.depth > current_line.depth:
1470                 newlines = 1
1471             elif current_line.is_class or self.previous_line.is_class:
1472                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1473                     # No blank line between classes with an empty body
1474                     newlines = 0
1475                 else:
1476                     newlines = 1
1477             elif current_line.is_def and not self.previous_line.is_def:
1478                 # Blank line between a block of functions and a block of non-functions
1479                 newlines = 1
1480             else:
1481                 newlines = 0
1482         else:
1483             newlines = 2
1484         if current_line.depth and newlines:
1485             newlines -= 1
1486         return newlines, 0
1487
1488
1489 @dataclass
1490 class LineGenerator(Visitor[Line]):
1491     """Generates reformatted Line objects.  Empty lines are not emitted.
1492
1493     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1494     in ways that will no longer stringify to valid Python code on the tree.
1495     """
1496
1497     is_pyi: bool = False
1498     normalize_strings: bool = True
1499     current_line: Line = Factory(Line)
1500     remove_u_prefix: bool = False
1501
1502     def line(self, indent: int = 0) -> Iterator[Line]:
1503         """Generate a line.
1504
1505         If the line is empty, only emit if it makes sense.
1506         If the line is too long, split it first and then generate.
1507
1508         If any lines were generated, set up a new current_line.
1509         """
1510         if not self.current_line:
1511             self.current_line.depth += indent
1512             return  # Line is empty, don't emit. Creating a new one unnecessary.
1513
1514         complete_line = self.current_line
1515         self.current_line = Line(depth=complete_line.depth + indent)
1516         yield complete_line
1517
1518     def visit_default(self, node: LN) -> Iterator[Line]:
1519         """Default `visit_*()` implementation. Recurses to children of `node`."""
1520         if isinstance(node, Leaf):
1521             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1522             for comment in generate_comments(node):
1523                 if any_open_brackets:
1524                     # any comment within brackets is subject to splitting
1525                     self.current_line.append(comment)
1526                 elif comment.type == token.COMMENT:
1527                     # regular trailing comment
1528                     self.current_line.append(comment)
1529                     yield from self.line()
1530
1531                 else:
1532                     # regular standalone comment
1533                     yield from self.line()
1534
1535                     self.current_line.append(comment)
1536                     yield from self.line()
1537
1538             normalize_prefix(node, inside_brackets=any_open_brackets)
1539             if self.normalize_strings and node.type == token.STRING:
1540                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1541                 normalize_string_quotes(node)
1542             if node.type == token.NUMBER:
1543                 normalize_numeric_literal(node)
1544             if node.type not in WHITESPACE:
1545                 self.current_line.append(node)
1546         yield from super().visit_default(node)
1547
1548     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1549         """Increase indentation level, maybe yield a line."""
1550         # In blib2to3 INDENT never holds comments.
1551         yield from self.line(+1)
1552         yield from self.visit_default(node)
1553
1554     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1555         """Decrease indentation level, maybe yield a line."""
1556         # The current line might still wait for trailing comments.  At DEDENT time
1557         # there won't be any (they would be prefixes on the preceding NEWLINE).
1558         # Emit the line then.
1559         yield from self.line()
1560
1561         # While DEDENT has no value, its prefix may contain standalone comments
1562         # that belong to the current indentation level.  Get 'em.
1563         yield from self.visit_default(node)
1564
1565         # Finally, emit the dedent.
1566         yield from self.line(-1)
1567
1568     def visit_stmt(
1569         self, node: Node, keywords: Set[str], parens: Set[str]
1570     ) -> Iterator[Line]:
1571         """Visit a statement.
1572
1573         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1574         `def`, `with`, `class`, `assert` and assignments.
1575
1576         The relevant Python language `keywords` for a given statement will be
1577         NAME leaves within it. This methods puts those on a separate line.
1578
1579         `parens` holds a set of string leaf values immediately after which
1580         invisible parens should be put.
1581         """
1582         normalize_invisible_parens(node, parens_after=parens)
1583         for child in node.children:
1584             if child.type == token.NAME and child.value in keywords:  # type: ignore
1585                 yield from self.line()
1586
1587             yield from self.visit(child)
1588
1589     def visit_suite(self, node: Node) -> Iterator[Line]:
1590         """Visit a suite."""
1591         if self.is_pyi and is_stub_suite(node):
1592             yield from self.visit(node.children[2])
1593         else:
1594             yield from self.visit_default(node)
1595
1596     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1597         """Visit a statement without nested statements."""
1598         is_suite_like = node.parent and node.parent.type in STATEMENT
1599         if is_suite_like:
1600             if self.is_pyi and is_stub_body(node):
1601                 yield from self.visit_default(node)
1602             else:
1603                 yield from self.line(+1)
1604                 yield from self.visit_default(node)
1605                 yield from self.line(-1)
1606
1607         else:
1608             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1609                 yield from self.line()
1610             yield from self.visit_default(node)
1611
1612     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1613         """Visit `async def`, `async for`, `async with`."""
1614         yield from self.line()
1615
1616         children = iter(node.children)
1617         for child in children:
1618             yield from self.visit(child)
1619
1620             if child.type == token.ASYNC:
1621                 break
1622
1623         internal_stmt = next(children)
1624         for child in internal_stmt.children:
1625             yield from self.visit(child)
1626
1627     def visit_decorators(self, node: Node) -> Iterator[Line]:
1628         """Visit decorators."""
1629         for child in node.children:
1630             yield from self.line()
1631             yield from self.visit(child)
1632
1633     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1634         """Remove a semicolon and put the other statement on a separate line."""
1635         yield from self.line()
1636
1637     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1638         """End of file. Process outstanding comments and end with a newline."""
1639         yield from self.visit_default(leaf)
1640         yield from self.line()
1641
1642     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1643         if not self.current_line.bracket_tracker.any_open_brackets():
1644             yield from self.line()
1645         yield from self.visit_default(leaf)
1646
1647     def __attrs_post_init__(self) -> None:
1648         """You are in a twisty little maze of passages."""
1649         v = self.visit_stmt
1650         Ø: Set[str] = set()
1651         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1652         self.visit_if_stmt = partial(
1653             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1654         )
1655         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1656         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1657         self.visit_try_stmt = partial(
1658             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1659         )
1660         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1661         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1662         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1663         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1664         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1665         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1666         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1667         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1668         self.visit_async_funcdef = self.visit_async_stmt
1669         self.visit_decorated = self.visit_decorators
1670
1671
1672 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1673 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1674 OPENING_BRACKETS = set(BRACKET.keys())
1675 CLOSING_BRACKETS = set(BRACKET.values())
1676 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1677 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1678
1679
1680 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1681     """Return whitespace prefix if needed for the given `leaf`.
1682
1683     `complex_subscript` signals whether the given leaf is part of a subscription
1684     which has non-trivial arguments, like arithmetic expressions or function calls.
1685     """
1686     NO = ""
1687     SPACE = " "
1688     DOUBLESPACE = "  "
1689     t = leaf.type
1690     p = leaf.parent
1691     v = leaf.value
1692     if t in ALWAYS_NO_SPACE:
1693         return NO
1694
1695     if t == token.COMMENT:
1696         return DOUBLESPACE
1697
1698     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1699     if t == token.COLON and p.type not in {
1700         syms.subscript,
1701         syms.subscriptlist,
1702         syms.sliceop,
1703     }:
1704         return NO
1705
1706     prev = leaf.prev_sibling
1707     if not prev:
1708         prevp = preceding_leaf(p)
1709         if not prevp or prevp.type in OPENING_BRACKETS:
1710             return NO
1711
1712         if t == token.COLON:
1713             if prevp.type == token.COLON:
1714                 return NO
1715
1716             elif prevp.type != token.COMMA and not complex_subscript:
1717                 return NO
1718
1719             return SPACE
1720
1721         if prevp.type == token.EQUAL:
1722             if prevp.parent:
1723                 if prevp.parent.type in {
1724                     syms.arglist,
1725                     syms.argument,
1726                     syms.parameters,
1727                     syms.varargslist,
1728                 }:
1729                     return NO
1730
1731                 elif prevp.parent.type == syms.typedargslist:
1732                     # A bit hacky: if the equal sign has whitespace, it means we
1733                     # previously found it's a typed argument.  So, we're using
1734                     # that, too.
1735                     return prevp.prefix
1736
1737         elif prevp.type in STARS:
1738             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1739                 return NO
1740
1741         elif prevp.type == token.COLON:
1742             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1743                 return SPACE if complex_subscript else NO
1744
1745         elif (
1746             prevp.parent
1747             and prevp.parent.type == syms.factor
1748             and prevp.type in MATH_OPERATORS
1749         ):
1750             return NO
1751
1752         elif (
1753             prevp.type == token.RIGHTSHIFT
1754             and prevp.parent
1755             and prevp.parent.type == syms.shift_expr
1756             and prevp.prev_sibling
1757             and prevp.prev_sibling.type == token.NAME
1758             and prevp.prev_sibling.value == "print"  # type: ignore
1759         ):
1760             # Python 2 print chevron
1761             return NO
1762
1763     elif prev.type in OPENING_BRACKETS:
1764         return NO
1765
1766     if p.type in {syms.parameters, syms.arglist}:
1767         # untyped function signatures or calls
1768         if not prev or prev.type != token.COMMA:
1769             return NO
1770
1771     elif p.type == syms.varargslist:
1772         # lambdas
1773         if prev and prev.type != token.COMMA:
1774             return NO
1775
1776     elif p.type == syms.typedargslist:
1777         # typed function signatures
1778         if not prev:
1779             return NO
1780
1781         if t == token.EQUAL:
1782             if prev.type != syms.tname:
1783                 return NO
1784
1785         elif prev.type == token.EQUAL:
1786             # A bit hacky: if the equal sign has whitespace, it means we
1787             # previously found it's a typed argument.  So, we're using that, too.
1788             return prev.prefix
1789
1790         elif prev.type != token.COMMA:
1791             return NO
1792
1793     elif p.type == syms.tname:
1794         # type names
1795         if not prev:
1796             prevp = preceding_leaf(p)
1797             if not prevp or prevp.type != token.COMMA:
1798                 return NO
1799
1800     elif p.type == syms.trailer:
1801         # attributes and calls
1802         if t == token.LPAR or t == token.RPAR:
1803             return NO
1804
1805         if not prev:
1806             if t == token.DOT:
1807                 prevp = preceding_leaf(p)
1808                 if not prevp or prevp.type != token.NUMBER:
1809                     return NO
1810
1811             elif t == token.LSQB:
1812                 return NO
1813
1814         elif prev.type != token.COMMA:
1815             return NO
1816
1817     elif p.type == syms.argument:
1818         # single argument
1819         if t == token.EQUAL:
1820             return NO
1821
1822         if not prev:
1823             prevp = preceding_leaf(p)
1824             if not prevp or prevp.type == token.LPAR:
1825                 return NO
1826
1827         elif prev.type in {token.EQUAL} | STARS:
1828             return NO
1829
1830     elif p.type == syms.decorator:
1831         # decorators
1832         return NO
1833
1834     elif p.type == syms.dotted_name:
1835         if prev:
1836             return NO
1837
1838         prevp = preceding_leaf(p)
1839         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1840             return NO
1841
1842     elif p.type == syms.classdef:
1843         if t == token.LPAR:
1844             return NO
1845
1846         if prev and prev.type == token.LPAR:
1847             return NO
1848
1849     elif p.type in {syms.subscript, syms.sliceop}:
1850         # indexing
1851         if not prev:
1852             assert p.parent is not None, "subscripts are always parented"
1853             if p.parent.type == syms.subscriptlist:
1854                 return SPACE
1855
1856             return NO
1857
1858         elif not complex_subscript:
1859             return NO
1860
1861     elif p.type == syms.atom:
1862         if prev and t == token.DOT:
1863             # dots, but not the first one.
1864             return NO
1865
1866     elif p.type == syms.dictsetmaker:
1867         # dict unpacking
1868         if prev and prev.type == token.DOUBLESTAR:
1869             return NO
1870
1871     elif p.type in {syms.factor, syms.star_expr}:
1872         # unary ops
1873         if not prev:
1874             prevp = preceding_leaf(p)
1875             if not prevp or prevp.type in OPENING_BRACKETS:
1876                 return NO
1877
1878             prevp_parent = prevp.parent
1879             assert prevp_parent is not None
1880             if prevp.type == token.COLON and prevp_parent.type in {
1881                 syms.subscript,
1882                 syms.sliceop,
1883             }:
1884                 return NO
1885
1886             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1887                 return NO
1888
1889         elif t in {token.NAME, token.NUMBER, token.STRING}:
1890             return NO
1891
1892     elif p.type == syms.import_from:
1893         if t == token.DOT:
1894             if prev and prev.type == token.DOT:
1895                 return NO
1896
1897         elif t == token.NAME:
1898             if v == "import":
1899                 return SPACE
1900
1901             if prev and prev.type == token.DOT:
1902                 return NO
1903
1904     elif p.type == syms.sliceop:
1905         return NO
1906
1907     return SPACE
1908
1909
1910 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1911     """Return the first leaf that precedes `node`, if any."""
1912     while node:
1913         res = node.prev_sibling
1914         if res:
1915             if isinstance(res, Leaf):
1916                 return res
1917
1918             try:
1919                 return list(res.leaves())[-1]
1920
1921             except IndexError:
1922                 return None
1923
1924         node = node.parent
1925     return None
1926
1927
1928 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1929     """Return the child of `ancestor` that contains `descendant`."""
1930     node: Optional[LN] = descendant
1931     while node and node.parent != ancestor:
1932         node = node.parent
1933     return node
1934
1935
1936 def container_of(leaf: Leaf) -> LN:
1937     """Return `leaf` or one of its ancestors that is the topmost container of it.
1938
1939     By "container" we mean a node where `leaf` is the very first child.
1940     """
1941     same_prefix = leaf.prefix
1942     container: LN = leaf
1943     while container:
1944         parent = container.parent
1945         if parent is None:
1946             break
1947
1948         if parent.children[0].prefix != same_prefix:
1949             break
1950
1951         if parent.type == syms.file_input:
1952             break
1953
1954         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1955             break
1956
1957         container = parent
1958     return container
1959
1960
1961 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1962     """Return the priority of the `leaf` delimiter, given a line break after it.
1963
1964     The delimiter priorities returned here are from those delimiters that would
1965     cause a line break after themselves.
1966
1967     Higher numbers are higher priority.
1968     """
1969     if leaf.type == token.COMMA:
1970         return COMMA_PRIORITY
1971
1972     return 0
1973
1974
1975 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1976     """Return the priority of the `leaf` delimiter, given a line break before it.
1977
1978     The delimiter priorities returned here are from those delimiters that would
1979     cause a line break before themselves.
1980
1981     Higher numbers are higher priority.
1982     """
1983     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1984         # * and ** might also be MATH_OPERATORS but in this case they are not.
1985         # Don't treat them as a delimiter.
1986         return 0
1987
1988     if (
1989         leaf.type == token.DOT
1990         and leaf.parent
1991         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1992         and (previous is None or previous.type in CLOSING_BRACKETS)
1993     ):
1994         return DOT_PRIORITY
1995
1996     if (
1997         leaf.type in MATH_OPERATORS
1998         and leaf.parent
1999         and leaf.parent.type not in {syms.factor, syms.star_expr}
2000     ):
2001         return MATH_PRIORITIES[leaf.type]
2002
2003     if leaf.type in COMPARATORS:
2004         return COMPARATOR_PRIORITY
2005
2006     if (
2007         leaf.type == token.STRING
2008         and previous is not None
2009         and previous.type == token.STRING
2010     ):
2011         return STRING_PRIORITY
2012
2013     if leaf.type not in {token.NAME, token.ASYNC}:
2014         return 0
2015
2016     if (
2017         leaf.value == "for"
2018         and leaf.parent
2019         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2020         or leaf.type == token.ASYNC
2021     ):
2022         if (
2023             not isinstance(leaf.prev_sibling, Leaf)
2024             or leaf.prev_sibling.value != "async"
2025         ):
2026             return COMPREHENSION_PRIORITY
2027
2028     if (
2029         leaf.value == "if"
2030         and leaf.parent
2031         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2032     ):
2033         return COMPREHENSION_PRIORITY
2034
2035     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2036         return TERNARY_PRIORITY
2037
2038     if leaf.value == "is":
2039         return COMPARATOR_PRIORITY
2040
2041     if (
2042         leaf.value == "in"
2043         and leaf.parent
2044         and leaf.parent.type in {syms.comp_op, syms.comparison}
2045         and not (
2046             previous is not None
2047             and previous.type == token.NAME
2048             and previous.value == "not"
2049         )
2050     ):
2051         return COMPARATOR_PRIORITY
2052
2053     if (
2054         leaf.value == "not"
2055         and leaf.parent
2056         and leaf.parent.type == syms.comp_op
2057         and not (
2058             previous is not None
2059             and previous.type == token.NAME
2060             and previous.value == "is"
2061         )
2062     ):
2063         return COMPARATOR_PRIORITY
2064
2065     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2066         return LOGIC_PRIORITY
2067
2068     return 0
2069
2070
2071 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2072 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2073
2074
2075 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2076     """Clean the prefix of the `leaf` and generate comments from it, if any.
2077
2078     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2079     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2080     move because it does away with modifying the grammar to include all the
2081     possible places in which comments can be placed.
2082
2083     The sad consequence for us though is that comments don't "belong" anywhere.
2084     This is why this function generates simple parentless Leaf objects for
2085     comments.  We simply don't know what the correct parent should be.
2086
2087     No matter though, we can live without this.  We really only need to
2088     differentiate between inline and standalone comments.  The latter don't
2089     share the line with any code.
2090
2091     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2092     are emitted with a fake STANDALONE_COMMENT token identifier.
2093     """
2094     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2095         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2096
2097
2098 @dataclass
2099 class ProtoComment:
2100     """Describes a piece of syntax that is a comment.
2101
2102     It's not a :class:`blib2to3.pytree.Leaf` so that:
2103
2104     * it can be cached (`Leaf` objects should not be reused more than once as
2105       they store their lineno, column, prefix, and parent information);
2106     * `newlines` and `consumed` fields are kept separate from the `value`. This
2107       simplifies handling of special marker comments like ``# fmt: off/on``.
2108     """
2109
2110     type: int  # token.COMMENT or STANDALONE_COMMENT
2111     value: str  # content of the comment
2112     newlines: int  # how many newlines before the comment
2113     consumed: int  # how many characters of the original leaf's prefix did we consume
2114
2115
2116 @lru_cache(maxsize=4096)
2117 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2118     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2119     result: List[ProtoComment] = []
2120     if not prefix or "#" not in prefix:
2121         return result
2122
2123     consumed = 0
2124     nlines = 0
2125     for index, line in enumerate(prefix.split("\n")):
2126         consumed += len(line) + 1  # adding the length of the split '\n'
2127         line = line.lstrip()
2128         if not line:
2129             nlines += 1
2130         if not line.startswith("#"):
2131             continue
2132
2133         if index == 0 and not is_endmarker:
2134             comment_type = token.COMMENT  # simple trailing comment
2135         else:
2136             comment_type = STANDALONE_COMMENT
2137         comment = make_comment(line)
2138         result.append(
2139             ProtoComment(
2140                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2141             )
2142         )
2143         nlines = 0
2144     return result
2145
2146
2147 def make_comment(content: str) -> str:
2148     """Return a consistently formatted comment from the given `content` string.
2149
2150     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2151     space between the hash sign and the content.
2152
2153     If `content` didn't start with a hash sign, one is provided.
2154     """
2155     content = content.rstrip()
2156     if not content:
2157         return "#"
2158
2159     if content[0] == "#":
2160         content = content[1:]
2161     if content and content[0] not in " !:#'%":
2162         content = " " + content
2163     return "#" + content
2164
2165
2166 def split_line(
2167     line: Line,
2168     line_length: int,
2169     inner: bool = False,
2170     features: Collection[Feature] = (),
2171 ) -> Iterator[Line]:
2172     """Split a `line` into potentially many lines.
2173
2174     They should fit in the allotted `line_length` but might not be able to.
2175     `inner` signifies that there were a pair of brackets somewhere around the
2176     current `line`, possibly transitively. This means we can fallback to splitting
2177     by delimiters if the LHS/RHS don't yield any results.
2178
2179     `features` are syntactical features that may be used in the output.
2180     """
2181     if line.is_comment:
2182         yield line
2183         return
2184
2185     line_str = str(line).strip("\n")
2186
2187     if (
2188         not line.contains_inner_type_comments()
2189         and not line.should_explode
2190         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2191     ):
2192         yield line
2193         return
2194
2195     split_funcs: List[SplitFunc]
2196     if line.is_def:
2197         split_funcs = [left_hand_split]
2198     else:
2199
2200         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2201             for omit in generate_trailers_to_omit(line, line_length):
2202                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2203                 if is_line_short_enough(lines[0], line_length=line_length):
2204                     yield from lines
2205                     return
2206
2207             # All splits failed, best effort split with no omits.
2208             # This mostly happens to multiline strings that are by definition
2209             # reported as not fitting a single line.
2210             yield from right_hand_split(line, line_length, features=features)
2211
2212         if line.inside_brackets:
2213             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2214         else:
2215             split_funcs = [rhs]
2216     for split_func in split_funcs:
2217         # We are accumulating lines in `result` because we might want to abort
2218         # mission and return the original line in the end, or attempt a different
2219         # split altogether.
2220         result: List[Line] = []
2221         try:
2222             for l in split_func(line, features):
2223                 if str(l).strip("\n") == line_str:
2224                     raise CannotSplit("Split function returned an unchanged result")
2225
2226                 result.extend(
2227                     split_line(
2228                         l, line_length=line_length, inner=True, features=features
2229                     )
2230                 )
2231         except CannotSplit:
2232             continue
2233
2234         else:
2235             yield from result
2236             break
2237
2238     else:
2239         yield line
2240
2241
2242 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2243     """Split line into many lines, starting with the first matching bracket pair.
2244
2245     Note: this usually looks weird, only use this for function definitions.
2246     Prefer RHS otherwise.  This is why this function is not symmetrical with
2247     :func:`right_hand_split` which also handles optional parentheses.
2248     """
2249     tail_leaves: List[Leaf] = []
2250     body_leaves: List[Leaf] = []
2251     head_leaves: List[Leaf] = []
2252     current_leaves = head_leaves
2253     matching_bracket = None
2254     for leaf in line.leaves:
2255         if (
2256             current_leaves is body_leaves
2257             and leaf.type in CLOSING_BRACKETS
2258             and leaf.opening_bracket is matching_bracket
2259         ):
2260             current_leaves = tail_leaves if body_leaves else head_leaves
2261         current_leaves.append(leaf)
2262         if current_leaves is head_leaves:
2263             if leaf.type in OPENING_BRACKETS:
2264                 matching_bracket = leaf
2265                 current_leaves = body_leaves
2266     if not matching_bracket:
2267         raise CannotSplit("No brackets found")
2268
2269     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2270     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2271     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2272     bracket_split_succeeded_or_raise(head, body, tail)
2273     for result in (head, body, tail):
2274         if result:
2275             yield result
2276
2277
2278 def right_hand_split(
2279     line: Line,
2280     line_length: int,
2281     features: Collection[Feature] = (),
2282     omit: Collection[LeafID] = (),
2283 ) -> Iterator[Line]:
2284     """Split line into many lines, starting with the last matching bracket pair.
2285
2286     If the split was by optional parentheses, attempt splitting without them, too.
2287     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2288     this split.
2289
2290     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2291     """
2292     tail_leaves: List[Leaf] = []
2293     body_leaves: List[Leaf] = []
2294     head_leaves: List[Leaf] = []
2295     current_leaves = tail_leaves
2296     opening_bracket = None
2297     closing_bracket = None
2298     for leaf in reversed(line.leaves):
2299         if current_leaves is body_leaves:
2300             if leaf is opening_bracket:
2301                 current_leaves = head_leaves if body_leaves else tail_leaves
2302         current_leaves.append(leaf)
2303         if current_leaves is tail_leaves:
2304             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2305                 opening_bracket = leaf.opening_bracket
2306                 closing_bracket = leaf
2307                 current_leaves = body_leaves
2308     if not (opening_bracket and closing_bracket and head_leaves):
2309         # If there is no opening or closing_bracket that means the split failed and
2310         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2311         # the matching `opening_bracket` wasn't available on `line` anymore.
2312         raise CannotSplit("No brackets found")
2313
2314     tail_leaves.reverse()
2315     body_leaves.reverse()
2316     head_leaves.reverse()
2317     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2318     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2319     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2320     bracket_split_succeeded_or_raise(head, body, tail)
2321     if (
2322         # the body shouldn't be exploded
2323         not body.should_explode
2324         # the opening bracket is an optional paren
2325         and opening_bracket.type == token.LPAR
2326         and not opening_bracket.value
2327         # the closing bracket is an optional paren
2328         and closing_bracket.type == token.RPAR
2329         and not closing_bracket.value
2330         # it's not an import (optional parens are the only thing we can split on
2331         # in this case; attempting a split without them is a waste of time)
2332         and not line.is_import
2333         # there are no standalone comments in the body
2334         and not body.contains_standalone_comments(0)
2335         # and we can actually remove the parens
2336         and can_omit_invisible_parens(body, line_length)
2337     ):
2338         omit = {id(closing_bracket), *omit}
2339         try:
2340             yield from right_hand_split(line, line_length, features=features, omit=omit)
2341             return
2342
2343         except CannotSplit:
2344             if not (
2345                 can_be_split(body)
2346                 or is_line_short_enough(body, line_length=line_length)
2347             ):
2348                 raise CannotSplit(
2349                     "Splitting failed, body is still too long and can't be split."
2350                 )
2351
2352             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2353                 raise CannotSplit(
2354                     "The current optional pair of parentheses is bound to fail to "
2355                     "satisfy the splitting algorithm because the head or the tail "
2356                     "contains multiline strings which by definition never fit one "
2357                     "line."
2358                 )
2359
2360     ensure_visible(opening_bracket)
2361     ensure_visible(closing_bracket)
2362     for result in (head, body, tail):
2363         if result:
2364             yield result
2365
2366
2367 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2368     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2369
2370     Do nothing otherwise.
2371
2372     A left- or right-hand split is based on a pair of brackets. Content before
2373     (and including) the opening bracket is left on one line, content inside the
2374     brackets is put on a separate line, and finally content starting with and
2375     following the closing bracket is put on a separate line.
2376
2377     Those are called `head`, `body`, and `tail`, respectively. If the split
2378     produced the same line (all content in `head`) or ended up with an empty `body`
2379     and the `tail` is just the closing bracket, then it's considered failed.
2380     """
2381     tail_len = len(str(tail).strip())
2382     if not body:
2383         if tail_len == 0:
2384             raise CannotSplit("Splitting brackets produced the same line")
2385
2386         elif tail_len < 3:
2387             raise CannotSplit(
2388                 f"Splitting brackets on an empty body to save "
2389                 f"{tail_len} characters is not worth it"
2390             )
2391
2392
2393 def bracket_split_build_line(
2394     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2395 ) -> Line:
2396     """Return a new line with given `leaves` and respective comments from `original`.
2397
2398     If `is_body` is True, the result line is one-indented inside brackets and as such
2399     has its first leaf's prefix normalized and a trailing comma added when expected.
2400     """
2401     result = Line(depth=original.depth)
2402     if is_body:
2403         result.inside_brackets = True
2404         result.depth += 1
2405         if leaves:
2406             # Since body is a new indent level, remove spurious leading whitespace.
2407             normalize_prefix(leaves[0], inside_brackets=True)
2408             # Ensure a trailing comma when expected.
2409             if original.is_import:
2410                 if leaves[-1].type != token.COMMA:
2411                     leaves.append(Leaf(token.COMMA, ","))
2412     # Populate the line
2413     for leaf in leaves:
2414         result.append(leaf, preformatted=True)
2415         for comment_after in original.comments_after(leaf):
2416             result.append(comment_after, preformatted=True)
2417     if is_body:
2418         result.should_explode = should_explode(result, opening_bracket)
2419     return result
2420
2421
2422 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2423     """Normalize prefix of the first leaf in every line returned by `split_func`.
2424
2425     This is a decorator over relevant split functions.
2426     """
2427
2428     @wraps(split_func)
2429     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2430         for l in split_func(line, features):
2431             normalize_prefix(l.leaves[0], inside_brackets=True)
2432             yield l
2433
2434     return split_wrapper
2435
2436
2437 @dont_increase_indentation
2438 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2439     """Split according to delimiters of the highest priority.
2440
2441     If the appropriate Features are given, the split will add trailing commas
2442     also in function signatures and calls that contain `*` and `**`.
2443     """
2444     try:
2445         last_leaf = line.leaves[-1]
2446     except IndexError:
2447         raise CannotSplit("Line empty")
2448
2449     bt = line.bracket_tracker
2450     try:
2451         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2452     except ValueError:
2453         raise CannotSplit("No delimiters found")
2454
2455     if delimiter_priority == DOT_PRIORITY:
2456         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2457             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2458
2459     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2460     lowest_depth = sys.maxsize
2461     trailing_comma_safe = True
2462
2463     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2464         """Append `leaf` to current line or to new line if appending impossible."""
2465         nonlocal current_line
2466         try:
2467             current_line.append_safe(leaf, preformatted=True)
2468         except ValueError:
2469             yield current_line
2470
2471             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2472             current_line.append(leaf)
2473
2474     for leaf in line.leaves:
2475         yield from append_to_line(leaf)
2476
2477         for comment_after in line.comments_after(leaf):
2478             yield from append_to_line(comment_after)
2479
2480         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2481         if leaf.bracket_depth == lowest_depth:
2482             if is_vararg(leaf, within={syms.typedargslist}):
2483                 trailing_comma_safe = (
2484                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2485                 )
2486             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2487                 trailing_comma_safe = (
2488                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2489                 )
2490
2491         leaf_priority = bt.delimiters.get(id(leaf))
2492         if leaf_priority == delimiter_priority:
2493             yield current_line
2494
2495             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2496     if current_line:
2497         if (
2498             trailing_comma_safe
2499             and delimiter_priority == COMMA_PRIORITY
2500             and current_line.leaves[-1].type != token.COMMA
2501             and current_line.leaves[-1].type != STANDALONE_COMMENT
2502         ):
2503             current_line.append(Leaf(token.COMMA, ","))
2504         yield current_line
2505
2506
2507 @dont_increase_indentation
2508 def standalone_comment_split(
2509     line: Line, features: Collection[Feature] = ()
2510 ) -> Iterator[Line]:
2511     """Split standalone comments from the rest of the line."""
2512     if not line.contains_standalone_comments(0):
2513         raise CannotSplit("Line does not have any standalone comments")
2514
2515     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2516
2517     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2518         """Append `leaf` to current line or to new line if appending impossible."""
2519         nonlocal current_line
2520         try:
2521             current_line.append_safe(leaf, preformatted=True)
2522         except ValueError:
2523             yield current_line
2524
2525             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2526             current_line.append(leaf)
2527
2528     for leaf in line.leaves:
2529         yield from append_to_line(leaf)
2530
2531         for comment_after in line.comments_after(leaf):
2532             yield from append_to_line(comment_after)
2533
2534     if current_line:
2535         yield current_line
2536
2537
2538 def is_import(leaf: Leaf) -> bool:
2539     """Return True if the given leaf starts an import statement."""
2540     p = leaf.parent
2541     t = leaf.type
2542     v = leaf.value
2543     return bool(
2544         t == token.NAME
2545         and (
2546             (v == "import" and p and p.type == syms.import_name)
2547             or (v == "from" and p and p.type == syms.import_from)
2548         )
2549     )
2550
2551
2552 def is_type_comment(leaf: Leaf) -> bool:
2553     """Return True if the given leaf is a special comment.
2554     Only returns true for type comments for now."""
2555     t = leaf.type
2556     v = leaf.value
2557     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2558
2559
2560 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2561     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2562     else.
2563
2564     Note: don't use backslashes for formatting or you'll lose your voting rights.
2565     """
2566     if not inside_brackets:
2567         spl = leaf.prefix.split("#")
2568         if "\\" not in spl[0]:
2569             nl_count = spl[-1].count("\n")
2570             if len(spl) > 1:
2571                 nl_count -= 1
2572             leaf.prefix = "\n" * nl_count
2573             return
2574
2575     leaf.prefix = ""
2576
2577
2578 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2579     """Make all string prefixes lowercase.
2580
2581     If remove_u_prefix is given, also removes any u prefix from the string.
2582
2583     Note: Mutates its argument.
2584     """
2585     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2586     assert match is not None, f"failed to match string {leaf.value!r}"
2587     orig_prefix = match.group(1)
2588     new_prefix = orig_prefix.lower()
2589     if remove_u_prefix:
2590         new_prefix = new_prefix.replace("u", "")
2591     leaf.value = f"{new_prefix}{match.group(2)}"
2592
2593
2594 def normalize_string_quotes(leaf: Leaf) -> None:
2595     """Prefer double quotes but only if it doesn't cause more escaping.
2596
2597     Adds or removes backslashes as appropriate. Doesn't parse and fix
2598     strings nested in f-strings (yet).
2599
2600     Note: Mutates its argument.
2601     """
2602     value = leaf.value.lstrip("furbFURB")
2603     if value[:3] == '"""':
2604         return
2605
2606     elif value[:3] == "'''":
2607         orig_quote = "'''"
2608         new_quote = '"""'
2609     elif value[0] == '"':
2610         orig_quote = '"'
2611         new_quote = "'"
2612     else:
2613         orig_quote = "'"
2614         new_quote = '"'
2615     first_quote_pos = leaf.value.find(orig_quote)
2616     if first_quote_pos == -1:
2617         return  # There's an internal error
2618
2619     prefix = leaf.value[:first_quote_pos]
2620     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2621     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2622     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2623     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2624     if "r" in prefix.casefold():
2625         if unescaped_new_quote.search(body):
2626             # There's at least one unescaped new_quote in this raw string
2627             # so converting is impossible
2628             return
2629
2630         # Do not introduce or remove backslashes in raw strings
2631         new_body = body
2632     else:
2633         # remove unnecessary escapes
2634         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2635         if body != new_body:
2636             # Consider the string without unnecessary escapes as the original
2637             body = new_body
2638             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2639         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2640         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2641     if "f" in prefix.casefold():
2642         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2643         for m in matches:
2644             if "\\" in str(m):
2645                 # Do not introduce backslashes in interpolated expressions
2646                 return
2647     if new_quote == '"""' and new_body[-1:] == '"':
2648         # edge case:
2649         new_body = new_body[:-1] + '\\"'
2650     orig_escape_count = body.count("\\")
2651     new_escape_count = new_body.count("\\")
2652     if new_escape_count > orig_escape_count:
2653         return  # Do not introduce more escaping
2654
2655     if new_escape_count == orig_escape_count and orig_quote == '"':
2656         return  # Prefer double quotes
2657
2658     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2659
2660
2661 def normalize_numeric_literal(leaf: Leaf) -> None:
2662     """Normalizes numeric (float, int, and complex) literals.
2663
2664     All letters used in the representation are normalized to lowercase (except
2665     in Python 2 long literals).
2666     """
2667     text = leaf.value.lower()
2668     if text.startswith(("0o", "0b")):
2669         # Leave octal and binary literals alone.
2670         pass
2671     elif text.startswith("0x"):
2672         # Change hex literals to upper case.
2673         before, after = text[:2], text[2:]
2674         text = f"{before}{after.upper()}"
2675     elif "e" in text:
2676         before, after = text.split("e")
2677         sign = ""
2678         if after.startswith("-"):
2679             after = after[1:]
2680             sign = "-"
2681         elif after.startswith("+"):
2682             after = after[1:]
2683         before = format_float_or_int_string(before)
2684         text = f"{before}e{sign}{after}"
2685     elif text.endswith(("j", "l")):
2686         number = text[:-1]
2687         suffix = text[-1]
2688         # Capitalize in "2L" because "l" looks too similar to "1".
2689         if suffix == "l":
2690             suffix = "L"
2691         text = f"{format_float_or_int_string(number)}{suffix}"
2692     else:
2693         text = format_float_or_int_string(text)
2694     leaf.value = text
2695
2696
2697 def format_float_or_int_string(text: str) -> str:
2698     """Formats a float string like "1.0"."""
2699     if "." not in text:
2700         return text
2701
2702     before, after = text.split(".")
2703     return f"{before or 0}.{after or 0}"
2704
2705
2706 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2707     """Make existing optional parentheses invisible or create new ones.
2708
2709     `parens_after` is a set of string leaf values immeditely after which parens
2710     should be put.
2711
2712     Standardizes on visible parentheses for single-element tuples, and keeps
2713     existing visible parentheses for other tuples and generator expressions.
2714     """
2715     for pc in list_comments(node.prefix, is_endmarker=False):
2716         if pc.value in FMT_OFF:
2717             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2718             return
2719
2720     check_lpar = False
2721     for index, child in enumerate(list(node.children)):
2722         if check_lpar:
2723             if child.type == syms.atom:
2724                 if maybe_make_parens_invisible_in_atom(child):
2725                     lpar = Leaf(token.LPAR, "")
2726                     rpar = Leaf(token.RPAR, "")
2727                     index = child.remove() or 0
2728                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2729             elif is_one_tuple(child):
2730                 # wrap child in visible parentheses
2731                 lpar = Leaf(token.LPAR, "(")
2732                 rpar = Leaf(token.RPAR, ")")
2733                 child.remove()
2734                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2735             elif node.type == syms.import_from:
2736                 # "import from" nodes store parentheses directly as part of
2737                 # the statement
2738                 if child.type == token.LPAR:
2739                     # make parentheses invisible
2740                     child.value = ""  # type: ignore
2741                     node.children[-1].value = ""  # type: ignore
2742                 elif child.type != token.STAR:
2743                     # insert invisible parentheses
2744                     node.insert_child(index, Leaf(token.LPAR, ""))
2745                     node.append_child(Leaf(token.RPAR, ""))
2746                 break
2747
2748             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2749                 # wrap child in invisible parentheses
2750                 lpar = Leaf(token.LPAR, "")
2751                 rpar = Leaf(token.RPAR, "")
2752                 index = child.remove() or 0
2753                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2754
2755         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2756
2757
2758 def normalize_fmt_off(node: Node) -> None:
2759     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2760     try_again = True
2761     while try_again:
2762         try_again = convert_one_fmt_off_pair(node)
2763
2764
2765 def convert_one_fmt_off_pair(node: Node) -> bool:
2766     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2767
2768     Returns True if a pair was converted.
2769     """
2770     for leaf in node.leaves():
2771         previous_consumed = 0
2772         for comment in list_comments(leaf.prefix, is_endmarker=False):
2773             if comment.value in FMT_OFF:
2774                 # We only want standalone comments. If there's no previous leaf or
2775                 # the previous leaf is indentation, it's a standalone comment in
2776                 # disguise.
2777                 if comment.type != STANDALONE_COMMENT:
2778                     prev = preceding_leaf(leaf)
2779                     if prev and prev.type not in WHITESPACE:
2780                         continue
2781
2782                 ignored_nodes = list(generate_ignored_nodes(leaf))
2783                 if not ignored_nodes:
2784                     continue
2785
2786                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2787                 parent = first.parent
2788                 prefix = first.prefix
2789                 first.prefix = prefix[comment.consumed :]
2790                 hidden_value = (
2791                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2792                 )
2793                 if hidden_value.endswith("\n"):
2794                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2795                     # leaf (possibly followed by a DEDENT).
2796                     hidden_value = hidden_value[:-1]
2797                 first_idx = None
2798                 for ignored in ignored_nodes:
2799                     index = ignored.remove()
2800                     if first_idx is None:
2801                         first_idx = index
2802                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2803                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2804                 parent.insert_child(
2805                     first_idx,
2806                     Leaf(
2807                         STANDALONE_COMMENT,
2808                         hidden_value,
2809                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2810                     ),
2811                 )
2812                 return True
2813
2814             previous_consumed = comment.consumed
2815
2816     return False
2817
2818
2819 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2820     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2821
2822     Stops at the end of the block.
2823     """
2824     container: Optional[LN] = container_of(leaf)
2825     while container is not None and container.type != token.ENDMARKER:
2826         for comment in list_comments(container.prefix, is_endmarker=False):
2827             if comment.value in FMT_ON:
2828                 return
2829
2830         yield container
2831
2832         container = container.next_sibling
2833
2834
2835 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2836     """If it's safe, make the parens in the atom `node` invisible, recursively.
2837
2838     Returns whether the node should itself be wrapped in invisible parentheses.
2839
2840     """
2841     if (
2842         node.type != syms.atom
2843         or is_empty_tuple(node)
2844         or is_one_tuple(node)
2845         or is_yield(node)
2846         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2847     ):
2848         return False
2849
2850     first = node.children[0]
2851     last = node.children[-1]
2852     if first.type == token.LPAR and last.type == token.RPAR:
2853         # make parentheses invisible
2854         first.value = ""  # type: ignore
2855         last.value = ""  # type: ignore
2856         if len(node.children) > 1:
2857             maybe_make_parens_invisible_in_atom(node.children[1])
2858         return False
2859
2860     return True
2861
2862
2863 def is_empty_tuple(node: LN) -> bool:
2864     """Return True if `node` holds an empty tuple."""
2865     return (
2866         node.type == syms.atom
2867         and len(node.children) == 2
2868         and node.children[0].type == token.LPAR
2869         and node.children[1].type == token.RPAR
2870     )
2871
2872
2873 def is_one_tuple(node: LN) -> bool:
2874     """Return True if `node` holds a tuple with one element, with or without parens."""
2875     if node.type == syms.atom:
2876         if len(node.children) != 3:
2877             return False
2878
2879         lpar, gexp, rpar = node.children
2880         if not (
2881             lpar.type == token.LPAR
2882             and gexp.type == syms.testlist_gexp
2883             and rpar.type == token.RPAR
2884         ):
2885             return False
2886
2887         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2888
2889     return (
2890         node.type in IMPLICIT_TUPLE
2891         and len(node.children) == 2
2892         and node.children[1].type == token.COMMA
2893     )
2894
2895
2896 def is_yield(node: LN) -> bool:
2897     """Return True if `node` holds a `yield` or `yield from` expression."""
2898     if node.type == syms.yield_expr:
2899         return True
2900
2901     if node.type == token.NAME and node.value == "yield":  # type: ignore
2902         return True
2903
2904     if node.type != syms.atom:
2905         return False
2906
2907     if len(node.children) != 3:
2908         return False
2909
2910     lpar, expr, rpar = node.children
2911     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2912         return is_yield(expr)
2913
2914     return False
2915
2916
2917 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2918     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2919
2920     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2921     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2922     extended iterable unpacking (PEP 3132) and additional unpacking
2923     generalizations (PEP 448).
2924     """
2925     if leaf.type not in STARS or not leaf.parent:
2926         return False
2927
2928     p = leaf.parent
2929     if p.type == syms.star_expr:
2930         # Star expressions are also used as assignment targets in extended
2931         # iterable unpacking (PEP 3132).  See what its parent is instead.
2932         if not p.parent:
2933             return False
2934
2935         p = p.parent
2936
2937     return p.type in within
2938
2939
2940 def is_multiline_string(leaf: Leaf) -> bool:
2941     """Return True if `leaf` is a multiline string that actually spans many lines."""
2942     value = leaf.value.lstrip("furbFURB")
2943     return value[:3] in {'"""', "'''"} and "\n" in value
2944
2945
2946 def is_stub_suite(node: Node) -> bool:
2947     """Return True if `node` is a suite with a stub body."""
2948     if (
2949         len(node.children) != 4
2950         or node.children[0].type != token.NEWLINE
2951         or node.children[1].type != token.INDENT
2952         or node.children[3].type != token.DEDENT
2953     ):
2954         return False
2955
2956     return is_stub_body(node.children[2])
2957
2958
2959 def is_stub_body(node: LN) -> bool:
2960     """Return True if `node` is a simple statement containing an ellipsis."""
2961     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2962         return False
2963
2964     if len(node.children) != 2:
2965         return False
2966
2967     child = node.children[0]
2968     return (
2969         child.type == syms.atom
2970         and len(child.children) == 3
2971         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2972     )
2973
2974
2975 def max_delimiter_priority_in_atom(node: LN) -> int:
2976     """Return maximum delimiter priority inside `node`.
2977
2978     This is specific to atoms with contents contained in a pair of parentheses.
2979     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2980     """
2981     if node.type != syms.atom:
2982         return 0
2983
2984     first = node.children[0]
2985     last = node.children[-1]
2986     if not (first.type == token.LPAR and last.type == token.RPAR):
2987         return 0
2988
2989     bt = BracketTracker()
2990     for c in node.children[1:-1]:
2991         if isinstance(c, Leaf):
2992             bt.mark(c)
2993         else:
2994             for leaf in c.leaves():
2995                 bt.mark(leaf)
2996     try:
2997         return bt.max_delimiter_priority()
2998
2999     except ValueError:
3000         return 0
3001
3002
3003 def ensure_visible(leaf: Leaf) -> None:
3004     """Make sure parentheses are visible.
3005
3006     They could be invisible as part of some statements (see
3007     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3008     """
3009     if leaf.type == token.LPAR:
3010         leaf.value = "("
3011     elif leaf.type == token.RPAR:
3012         leaf.value = ")"
3013
3014
3015 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3016     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3017
3018     if not (
3019         opening_bracket.parent
3020         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3021         and opening_bracket.value in "[{("
3022     ):
3023         return False
3024
3025     try:
3026         last_leaf = line.leaves[-1]
3027         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3028         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3029     except (IndexError, ValueError):
3030         return False
3031
3032     return max_priority == COMMA_PRIORITY
3033
3034
3035 def get_features_used(node: Node) -> Set[Feature]:
3036     """Return a set of (relatively) new Python features used in this file.
3037
3038     Currently looking for:
3039     - f-strings;
3040     - underscores in numeric literals; and
3041     - trailing commas after * or ** in function signatures and calls.
3042     """
3043     features: Set[Feature] = set()
3044     for n in node.pre_order():
3045         if n.type == token.STRING:
3046             value_head = n.value[:2]  # type: ignore
3047             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3048                 features.add(Feature.F_STRINGS)
3049
3050         elif n.type == token.NUMBER:
3051             if "_" in n.value:  # type: ignore
3052                 features.add(Feature.NUMERIC_UNDERSCORES)
3053
3054         elif (
3055             n.type in {syms.typedargslist, syms.arglist}
3056             and n.children
3057             and n.children[-1].type == token.COMMA
3058         ):
3059             if n.type == syms.typedargslist:
3060                 feature = Feature.TRAILING_COMMA_IN_DEF
3061             else:
3062                 feature = Feature.TRAILING_COMMA_IN_CALL
3063
3064             for ch in n.children:
3065                 if ch.type in STARS:
3066                     features.add(feature)
3067
3068                 if ch.type == syms.argument:
3069                     for argch in ch.children:
3070                         if argch.type in STARS:
3071                             features.add(feature)
3072
3073     return features
3074
3075
3076 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3077     """Detect the version to target based on the nodes used."""
3078     features = get_features_used(node)
3079     return {
3080         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3081     }
3082
3083
3084 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3085     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3086
3087     Brackets can be omitted if the entire trailer up to and including
3088     a preceding closing bracket fits in one line.
3089
3090     Yielded sets are cumulative (contain results of previous yields, too).  First
3091     set is empty.
3092     """
3093
3094     omit: Set[LeafID] = set()
3095     yield omit
3096
3097     length = 4 * line.depth
3098     opening_bracket = None
3099     closing_bracket = None
3100     inner_brackets: Set[LeafID] = set()
3101     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3102         length += leaf_length
3103         if length > line_length:
3104             break
3105
3106         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3107         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3108             break
3109
3110         if opening_bracket:
3111             if leaf is opening_bracket:
3112                 opening_bracket = None
3113             elif leaf.type in CLOSING_BRACKETS:
3114                 inner_brackets.add(id(leaf))
3115         elif leaf.type in CLOSING_BRACKETS:
3116             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3117                 # Empty brackets would fail a split so treat them as "inner"
3118                 # brackets (e.g. only add them to the `omit` set if another
3119                 # pair of brackets was good enough.
3120                 inner_brackets.add(id(leaf))
3121                 continue
3122
3123             if closing_bracket:
3124                 omit.add(id(closing_bracket))
3125                 omit.update(inner_brackets)
3126                 inner_brackets.clear()
3127                 yield omit
3128
3129             if leaf.value:
3130                 opening_bracket = leaf.opening_bracket
3131                 closing_bracket = leaf
3132
3133
3134 def get_future_imports(node: Node) -> Set[str]:
3135     """Return a set of __future__ imports in the file."""
3136     imports: Set[str] = set()
3137
3138     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3139         for child in children:
3140             if isinstance(child, Leaf):
3141                 if child.type == token.NAME:
3142                     yield child.value
3143             elif child.type == syms.import_as_name:
3144                 orig_name = child.children[0]
3145                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3146                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3147                 yield orig_name.value
3148             elif child.type == syms.import_as_names:
3149                 yield from get_imports_from_children(child.children)
3150             else:
3151                 raise AssertionError("Invalid syntax parsing imports")
3152
3153     for child in node.children:
3154         if child.type != syms.simple_stmt:
3155             break
3156         first_child = child.children[0]
3157         if isinstance(first_child, Leaf):
3158             # Continue looking if we see a docstring; otherwise stop.
3159             if (
3160                 len(child.children) == 2
3161                 and first_child.type == token.STRING
3162                 and child.children[1].type == token.NEWLINE
3163             ):
3164                 continue
3165             else:
3166                 break
3167         elif first_child.type == syms.import_from:
3168             module_name = first_child.children[1]
3169             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3170                 break
3171             imports |= set(get_imports_from_children(first_child.children[3:]))
3172         else:
3173             break
3174     return imports
3175
3176
3177 def gen_python_files_in_dir(
3178     path: Path,
3179     root: Path,
3180     include: Pattern[str],
3181     exclude: Pattern[str],
3182     report: "Report",
3183 ) -> Iterator[Path]:
3184     """Generate all files under `path` whose paths are not excluded by the
3185     `exclude` regex, but are included by the `include` regex.
3186
3187     Symbolic links pointing outside of the `root` directory are ignored.
3188
3189     `report` is where output about exclusions goes.
3190     """
3191     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3192     for child in path.iterdir():
3193         try:
3194             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3195         except ValueError:
3196             if child.is_symlink():
3197                 report.path_ignored(
3198                     child, f"is a symbolic link that points outside {root}"
3199                 )
3200                 continue
3201
3202             raise
3203
3204         if child.is_dir():
3205             normalized_path += "/"
3206         exclude_match = exclude.search(normalized_path)
3207         if exclude_match and exclude_match.group(0):
3208             report.path_ignored(child, f"matches the --exclude regular expression")
3209             continue
3210
3211         if child.is_dir():
3212             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3213
3214         elif child.is_file():
3215             include_match = include.search(normalized_path)
3216             if include_match:
3217                 yield child
3218
3219
3220 @lru_cache()
3221 def find_project_root(srcs: Iterable[str]) -> Path:
3222     """Return a directory containing .git, .hg, or pyproject.toml.
3223
3224     That directory can be one of the directories passed in `srcs` or their
3225     common parent.
3226
3227     If no directory in the tree contains a marker that would specify it's the
3228     project root, the root of the file system is returned.
3229     """
3230     if not srcs:
3231         return Path("/").resolve()
3232
3233     common_base = min(Path(src).resolve() for src in srcs)
3234     if common_base.is_dir():
3235         # Append a fake file so `parents` below returns `common_base_dir`, too.
3236         common_base /= "fake-file"
3237     for directory in common_base.parents:
3238         if (directory / ".git").is_dir():
3239             return directory
3240
3241         if (directory / ".hg").is_dir():
3242             return directory
3243
3244         if (directory / "pyproject.toml").is_file():
3245             return directory
3246
3247     return directory
3248
3249
3250 @dataclass
3251 class Report:
3252     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3253
3254     check: bool = False
3255     quiet: bool = False
3256     verbose: bool = False
3257     change_count: int = 0
3258     same_count: int = 0
3259     failure_count: int = 0
3260
3261     def done(self, src: Path, changed: Changed) -> None:
3262         """Increment the counter for successful reformatting. Write out a message."""
3263         if changed is Changed.YES:
3264             reformatted = "would reformat" if self.check else "reformatted"
3265             if self.verbose or not self.quiet:
3266                 out(f"{reformatted} {src}")
3267             self.change_count += 1
3268         else:
3269             if self.verbose:
3270                 if changed is Changed.NO:
3271                     msg = f"{src} already well formatted, good job."
3272                 else:
3273                     msg = f"{src} wasn't modified on disk since last run."
3274                 out(msg, bold=False)
3275             self.same_count += 1
3276
3277     def failed(self, src: Path, message: str) -> None:
3278         """Increment the counter for failed reformatting. Write out a message."""
3279         err(f"error: cannot format {src}: {message}")
3280         self.failure_count += 1
3281
3282     def path_ignored(self, path: Path, message: str) -> None:
3283         if self.verbose:
3284             out(f"{path} ignored: {message}", bold=False)
3285
3286     @property
3287     def return_code(self) -> int:
3288         """Return the exit code that the app should use.
3289
3290         This considers the current state of changed files and failures:
3291         - if there were any failures, return 123;
3292         - if any files were changed and --check is being used, return 1;
3293         - otherwise return 0.
3294         """
3295         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3296         # 126 we have special return codes reserved by the shell.
3297         if self.failure_count:
3298             return 123
3299
3300         elif self.change_count and self.check:
3301             return 1
3302
3303         return 0
3304
3305     def __str__(self) -> str:
3306         """Render a color report of the current state.
3307
3308         Use `click.unstyle` to remove colors.
3309         """
3310         if self.check:
3311             reformatted = "would be reformatted"
3312             unchanged = "would be left unchanged"
3313             failed = "would fail to reformat"
3314         else:
3315             reformatted = "reformatted"
3316             unchanged = "left unchanged"
3317             failed = "failed to reformat"
3318         report = []
3319         if self.change_count:
3320             s = "s" if self.change_count > 1 else ""
3321             report.append(
3322                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3323             )
3324         if self.same_count:
3325             s = "s" if self.same_count > 1 else ""
3326             report.append(f"{self.same_count} file{s} {unchanged}")
3327         if self.failure_count:
3328             s = "s" if self.failure_count > 1 else ""
3329             report.append(
3330                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3331             )
3332         return ", ".join(report) + "."
3333
3334
3335 def assert_equivalent(src: str, dst: str) -> None:
3336     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3337
3338     import ast
3339     import traceback
3340
3341     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3342         """Simple visitor generating strings to compare ASTs by content."""
3343         yield f"{'  ' * depth}{node.__class__.__name__}("
3344
3345         for field in sorted(node._fields):
3346             try:
3347                 value = getattr(node, field)
3348             except AttributeError:
3349                 continue
3350
3351             yield f"{'  ' * (depth+1)}{field}="
3352
3353             if isinstance(value, list):
3354                 for item in value:
3355                     # Ignore nested tuples within del statements, because we may insert
3356                     # parentheses and they change the AST.
3357                     if (
3358                         field == "targets"
3359                         and isinstance(node, ast.Delete)
3360                         and isinstance(item, ast.Tuple)
3361                     ):
3362                         for item in item.elts:
3363                             yield from _v(item, depth + 2)
3364                     elif isinstance(item, ast.AST):
3365                         yield from _v(item, depth + 2)
3366
3367             elif isinstance(value, ast.AST):
3368                 yield from _v(value, depth + 2)
3369
3370             else:
3371                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3372
3373         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3374
3375     try:
3376         src_ast = ast.parse(src)
3377     except Exception as exc:
3378         major, minor = sys.version_info[:2]
3379         raise AssertionError(
3380             f"cannot use --safe with this file; failed to parse source file "
3381             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3382             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3383         )
3384
3385     try:
3386         dst_ast = ast.parse(dst)
3387     except Exception as exc:
3388         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3389         raise AssertionError(
3390             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3391             f"Please report a bug on https://github.com/python/black/issues.  "
3392             f"This invalid output might be helpful: {log}"
3393         ) from None
3394
3395     src_ast_str = "\n".join(_v(src_ast))
3396     dst_ast_str = "\n".join(_v(dst_ast))
3397     if src_ast_str != dst_ast_str:
3398         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3399         raise AssertionError(
3400             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3401             f"the source.  "
3402             f"Please report a bug on https://github.com/python/black/issues.  "
3403             f"This diff might be helpful: {log}"
3404         ) from None
3405
3406
3407 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3408     """Raise AssertionError if `dst` reformats differently the second time."""
3409     newdst = format_str(dst, mode=mode)
3410     if dst != newdst:
3411         log = dump_to_file(
3412             diff(src, dst, "source", "first pass"),
3413             diff(dst, newdst, "first pass", "second pass"),
3414         )
3415         raise AssertionError(
3416             f"INTERNAL ERROR: Black produced different code on the second pass "
3417             f"of the formatter.  "
3418             f"Please report a bug on https://github.com/python/black/issues.  "
3419             f"This diff might be helpful: {log}"
3420         ) from None
3421
3422
3423 def dump_to_file(*output: str) -> str:
3424     """Dump `output` to a temporary file. Return path to the file."""
3425     import tempfile
3426
3427     with tempfile.NamedTemporaryFile(
3428         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3429     ) as f:
3430         for lines in output:
3431             f.write(lines)
3432             if lines and lines[-1] != "\n":
3433                 f.write("\n")
3434     return f.name
3435
3436
3437 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3438     """Return a unified diff string between strings `a` and `b`."""
3439     import difflib
3440
3441     a_lines = [line + "\n" for line in a.split("\n")]
3442     b_lines = [line + "\n" for line in b.split("\n")]
3443     return "".join(
3444         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3445     )
3446
3447
3448 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3449     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3450     err("Aborted!")
3451     for task in tasks:
3452         task.cancel()
3453
3454
3455 def shutdown(loop: BaseEventLoop) -> None:
3456     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3457     try:
3458         if sys.version_info[:2] >= (3, 7):
3459             all_tasks = asyncio.all_tasks
3460         else:
3461             all_tasks = asyncio.Task.all_tasks
3462         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3463         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3464         if not to_cancel:
3465             return
3466
3467         for task in to_cancel:
3468             task.cancel()
3469         loop.run_until_complete(
3470             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3471         )
3472     finally:
3473         # `concurrent.futures.Future` objects cannot be cancelled once they
3474         # are already running. There might be some when the `shutdown()` happened.
3475         # Silence their logger's spew about the event loop being closed.
3476         cf_logger = logging.getLogger("concurrent.futures")
3477         cf_logger.setLevel(logging.CRITICAL)
3478         loop.close()
3479
3480
3481 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3482     """Replace `regex` with `replacement` twice on `original`.
3483
3484     This is used by string normalization to perform replaces on
3485     overlapping matches.
3486     """
3487     return regex.sub(replacement, regex.sub(replacement, original))
3488
3489
3490 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3491     """Compile a regular expression string in `regex`.
3492
3493     If it contains newlines, use verbose mode.
3494     """
3495     if "\n" in regex:
3496         regex = "(?x)" + regex
3497     return re.compile(regex)
3498
3499
3500 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3501     """Like `reversed(enumerate(sequence))` if that were possible."""
3502     index = len(sequence) - 1
3503     for element in reversed(sequence):
3504         yield (index, element)
3505         index -= 1
3506
3507
3508 def enumerate_with_length(
3509     line: Line, reversed: bool = False
3510 ) -> Iterator[Tuple[Index, Leaf, int]]:
3511     """Return an enumeration of leaves with their length.
3512
3513     Stops prematurely on multiline strings and standalone comments.
3514     """
3515     op = cast(
3516         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3517         enumerate_reversed if reversed else enumerate,
3518     )
3519     for index, leaf in op(line.leaves):
3520         length = len(leaf.prefix) + len(leaf.value)
3521         if "\n" in leaf.value:
3522             return  # Multiline strings, we can't continue.
3523
3524         comment: Optional[Leaf]
3525         for comment in line.comments_after(leaf):
3526             length += len(comment.value)
3527
3528         yield index, leaf, length
3529
3530
3531 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3532     """Return True if `line` is no longer than `line_length`.
3533
3534     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3535     """
3536     if not line_str:
3537         line_str = str(line).strip("\n")
3538     return (
3539         len(line_str) <= line_length
3540         and "\n" not in line_str  # multiline strings
3541         and not line.contains_standalone_comments()
3542     )
3543
3544
3545 def can_be_split(line: Line) -> bool:
3546     """Return False if the line cannot be split *for sure*.
3547
3548     This is not an exhaustive search but a cheap heuristic that we can use to
3549     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3550     in unnecessary parentheses).
3551     """
3552     leaves = line.leaves
3553     if len(leaves) < 2:
3554         return False
3555
3556     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3557         call_count = 0
3558         dot_count = 0
3559         next = leaves[-1]
3560         for leaf in leaves[-2::-1]:
3561             if leaf.type in OPENING_BRACKETS:
3562                 if next.type not in CLOSING_BRACKETS:
3563                     return False
3564
3565                 call_count += 1
3566             elif leaf.type == token.DOT:
3567                 dot_count += 1
3568             elif leaf.type == token.NAME:
3569                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3570                     return False
3571
3572             elif leaf.type not in CLOSING_BRACKETS:
3573                 return False
3574
3575             if dot_count > 1 and call_count > 1:
3576                 return False
3577
3578     return True
3579
3580
3581 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3582     """Does `line` have a shape safe to reformat without optional parens around it?
3583
3584     Returns True for only a subset of potentially nice looking formattings but
3585     the point is to not return false positives that end up producing lines that
3586     are too long.
3587     """
3588     bt = line.bracket_tracker
3589     if not bt.delimiters:
3590         # Without delimiters the optional parentheses are useless.
3591         return True
3592
3593     max_priority = bt.max_delimiter_priority()
3594     if bt.delimiter_count_with_priority(max_priority) > 1:
3595         # With more than one delimiter of a kind the optional parentheses read better.
3596         return False
3597
3598     if max_priority == DOT_PRIORITY:
3599         # A single stranded method call doesn't require optional parentheses.
3600         return True
3601
3602     assert len(line.leaves) >= 2, "Stranded delimiter"
3603
3604     first = line.leaves[0]
3605     second = line.leaves[1]
3606     penultimate = line.leaves[-2]
3607     last = line.leaves[-1]
3608
3609     # With a single delimiter, omit if the expression starts or ends with
3610     # a bracket.
3611     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3612         remainder = False
3613         length = 4 * line.depth
3614         for _index, leaf, leaf_length in enumerate_with_length(line):
3615             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3616                 remainder = True
3617             if remainder:
3618                 length += leaf_length
3619                 if length > line_length:
3620                     break
3621
3622                 if leaf.type in OPENING_BRACKETS:
3623                     # There are brackets we can further split on.
3624                     remainder = False
3625
3626         else:
3627             # checked the entire string and line length wasn't exceeded
3628             if len(line.leaves) == _index + 1:
3629                 return True
3630
3631         # Note: we are not returning False here because a line might have *both*
3632         # a leading opening bracket and a trailing closing bracket.  If the
3633         # opening bracket doesn't match our rule, maybe the closing will.
3634
3635     if (
3636         last.type == token.RPAR
3637         or last.type == token.RBRACE
3638         or (
3639             # don't use indexing for omitting optional parentheses;
3640             # it looks weird
3641             last.type == token.RSQB
3642             and last.parent
3643             and last.parent.type != syms.trailer
3644         )
3645     ):
3646         if penultimate.type in OPENING_BRACKETS:
3647             # Empty brackets don't help.
3648             return False
3649
3650         if is_multiline_string(first):
3651             # Additional wrapping of a multiline string in this situation is
3652             # unnecessary.
3653             return True
3654
3655         length = 4 * line.depth
3656         seen_other_brackets = False
3657         for _index, leaf, leaf_length in enumerate_with_length(line):
3658             length += leaf_length
3659             if leaf is last.opening_bracket:
3660                 if seen_other_brackets or length <= line_length:
3661                     return True
3662
3663             elif leaf.type in OPENING_BRACKETS:
3664                 # There are brackets we can further split on.
3665                 seen_other_brackets = True
3666
3667     return False
3668
3669
3670 def get_cache_file(mode: FileMode) -> Path:
3671     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3672
3673
3674 def read_cache(mode: FileMode) -> Cache:
3675     """Read the cache if it exists and is well formed.
3676
3677     If it is not well formed, the call to write_cache later should resolve the issue.
3678     """
3679     cache_file = get_cache_file(mode)
3680     if not cache_file.exists():
3681         return {}
3682
3683     with cache_file.open("rb") as fobj:
3684         try:
3685             cache: Cache = pickle.load(fobj)
3686         except pickle.UnpicklingError:
3687             return {}
3688
3689     return cache
3690
3691
3692 def get_cache_info(path: Path) -> CacheInfo:
3693     """Return the information used to check if a file is already formatted or not."""
3694     stat = path.stat()
3695     return stat.st_mtime, stat.st_size
3696
3697
3698 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3699     """Split an iterable of paths in `sources` into two sets.
3700
3701     The first contains paths of files that modified on disk or are not in the
3702     cache. The other contains paths to non-modified files.
3703     """
3704     todo, done = set(), set()
3705     for src in sources:
3706         src = src.resolve()
3707         if cache.get(src) != get_cache_info(src):
3708             todo.add(src)
3709         else:
3710             done.add(src)
3711     return todo, done
3712
3713
3714 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3715     """Update the cache file."""
3716     cache_file = get_cache_file(mode)
3717     try:
3718         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3719         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3720         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3721             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3722         os.replace(f.name, cache_file)
3723     except OSError:
3724         pass
3725
3726
3727 def patch_click() -> None:
3728     """Make Click not crash.
3729
3730     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3731     default which restricts paths that it can access during the lifetime of the
3732     application.  Click refuses to work in this scenario by raising a RuntimeError.
3733
3734     In case of Black the likelihood that non-ASCII characters are going to be used in
3735     file paths is minimal since it's Python source code.  Moreover, this crash was
3736     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3737     """
3738     try:
3739         from click import core
3740         from click import _unicodefun  # type: ignore
3741     except ModuleNotFoundError:
3742         return
3743
3744     for module in (core, _unicodefun):
3745         if hasattr(module, "_verify_python3_env"):
3746             module._verify_python3_env = lambda: None
3747
3748
3749 def patched_main() -> None:
3750     freeze_support()
3751     patch_click()
3752     main()
3753
3754
3755 if __name__ == "__main__":
3756     patched_main()