]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add `--target-version` option to allow users to choose targeted Python versions ...
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "18.9b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PYPY35 = 1
117     CPY27 = 2
118     CPY33 = 3
119     CPY34 = 4
120     CPY35 = 5
121     CPY36 = 6
122     CPY37 = 7
123     CPY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.CPY27
127
128
129 PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA = 4
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.CPY27: set(),
142     TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
143     TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
144     TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
145     TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
146     TargetVersion.CPY36: {
147         Feature.UNICODE_LITERALS,
148         Feature.F_STRINGS,
149         Feature.NUMERIC_UNDERSCORES,
150         Feature.TRAILING_COMMA,
151     },
152     TargetVersion.CPY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA,
157     },
158     TargetVersion.CPY38: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA,
163     },
164 }
165
166
167 @dataclass
168 class FileMode:
169     target_versions: Set[TargetVersion] = Factory(set)
170     line_length: int = DEFAULT_LINE_LENGTH
171     numeric_underscore_normalization: bool = True
172     string_normalization: bool = True
173     is_pyi: bool = False
174
175     def get_cache_key(self) -> str:
176         if self.target_versions:
177             version_str = ",".join(
178                 str(version.value)
179                 for version in sorted(self.target_versions, key=lambda v: v.value)
180             )
181         else:
182             version_str = "-"
183         parts = [
184             version_str,
185             str(self.line_length),
186             str(int(self.numeric_underscore_normalization)),
187             str(int(self.string_normalization)),
188             str(int(self.is_pyi)),
189         ]
190         return ".".join(parts)
191
192
193 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
194     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
195
196
197 def read_pyproject_toml(
198     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
199 ) -> Optional[str]:
200     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
201
202     Returns the path to a successfully found and read configuration file, None
203     otherwise.
204     """
205     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
206     if not value:
207         root = find_project_root(ctx.params.get("src", ()))
208         path = root / "pyproject.toml"
209         if path.is_file():
210             value = str(path)
211         else:
212             return None
213
214     try:
215         pyproject_toml = toml.load(value)
216         config = pyproject_toml.get("tool", {}).get("black", {})
217     except (toml.TomlDecodeError, OSError) as e:
218         raise click.FileError(
219             filename=value, hint=f"Error reading configuration file: {e}"
220         )
221
222     if not config:
223         return None
224
225     if ctx.default_map is None:
226         ctx.default_map = {}
227     ctx.default_map.update(  # type: ignore  # bad types in .pyi
228         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
229     )
230     return value
231
232
233 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
234 @click.option(
235     "-l",
236     "--line-length",
237     type=int,
238     default=DEFAULT_LINE_LENGTH,
239     help="How many characters per line to allow.",
240     show_default=True,
241 )
242 @click.option(
243     "-t",
244     "--target-version",
245     type=click.Choice([v.name.lower() for v in TargetVersion]),
246     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
247     multiple=True,
248     help=(
249         "Python versions that should be supported by Black's output. [default: "
250         "per-file auto-detection]"
251     ),
252 )
253 @click.option(
254     "--py36",
255     is_flag=True,
256     help=(
257         "Allow using Python 3.6-only syntax on all input files.  This will put "
258         "trailing commas in function signatures and calls also after *args and "
259         "**kwargs.  [default: per-file auto-detection]"
260     ),
261 )
262 @click.option(
263     "--pyi",
264     is_flag=True,
265     help=(
266         "Format all input files like typing stubs regardless of file extension "
267         "(useful when piping source on standard input)."
268     ),
269 )
270 @click.option(
271     "-S",
272     "--skip-string-normalization",
273     is_flag=True,
274     help="Don't normalize string quotes or prefixes.",
275 )
276 @click.option(
277     "-N",
278     "--skip-numeric-underscore-normalization",
279     is_flag=True,
280     help="Don't normalize underscores in numeric literals.",
281 )
282 @click.option(
283     "--check",
284     is_flag=True,
285     help=(
286         "Don't write the files back, just return the status.  Return code 0 "
287         "means nothing would change.  Return code 1 means some files would be "
288         "reformatted.  Return code 123 means there was an internal error."
289     ),
290 )
291 @click.option(
292     "--diff",
293     is_flag=True,
294     help="Don't write the files back, just output a diff for each file on stdout.",
295 )
296 @click.option(
297     "--fast/--safe",
298     is_flag=True,
299     help="If --fast given, skip temporary sanity checks. [default: --safe]",
300 )
301 @click.option(
302     "--include",
303     type=str,
304     default=DEFAULT_INCLUDES,
305     help=(
306         "A regular expression that matches files and directories that should be "
307         "included on recursive searches.  An empty value means all files are "
308         "included regardless of the name.  Use forward slashes for directories on "
309         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
310         "later."
311     ),
312     show_default=True,
313 )
314 @click.option(
315     "--exclude",
316     type=str,
317     default=DEFAULT_EXCLUDES,
318     help=(
319         "A regular expression that matches files and directories that should be "
320         "excluded on recursive searches.  An empty value means no paths are excluded. "
321         "Use forward slashes for directories on all platforms (Windows, too).  "
322         "Exclusions are calculated first, inclusions later."
323     ),
324     show_default=True,
325 )
326 @click.option(
327     "-q",
328     "--quiet",
329     is_flag=True,
330     help=(
331         "Don't emit non-error messages to stderr. Errors are still emitted, "
332         "silence those with 2>/dev/null."
333     ),
334 )
335 @click.option(
336     "-v",
337     "--verbose",
338     is_flag=True,
339     help=(
340         "Also emit messages to stderr about files that were not changed or were "
341         "ignored due to --exclude=."
342     ),
343 )
344 @click.version_option(version=__version__)
345 @click.argument(
346     "src",
347     nargs=-1,
348     type=click.Path(
349         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
350     ),
351     is_eager=True,
352 )
353 @click.option(
354     "--config",
355     type=click.Path(
356         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
357     ),
358     is_eager=True,
359     callback=read_pyproject_toml,
360     help="Read configuration from PATH.",
361 )
362 @click.pass_context
363 def main(
364     ctx: click.Context,
365     line_length: int,
366     target_version: List[TargetVersion],
367     check: bool,
368     diff: bool,
369     fast: bool,
370     pyi: bool,
371     py36: bool,
372     skip_string_normalization: bool,
373     skip_numeric_underscore_normalization: bool,
374     quiet: bool,
375     verbose: bool,
376     include: str,
377     exclude: str,
378     src: Tuple[str],
379     config: Optional[str],
380 ) -> None:
381     """The uncompromising code formatter."""
382     write_back = WriteBack.from_configuration(check=check, diff=diff)
383     if target_version:
384         if py36:
385             err(f"Cannot use both --target-version and --py36")
386             ctx.exit(2)
387         else:
388             versions = set(target_version)
389     elif py36:
390         versions = PY36_VERSIONS
391     else:
392         # We'll autodetect later.
393         versions = set()
394     mode = FileMode(
395         target_versions=versions,
396         line_length=line_length,
397         is_pyi=pyi,
398         string_normalization=not skip_string_normalization,
399         numeric_underscore_normalization=not skip_numeric_underscore_normalization,
400     )
401     if config and verbose:
402         out(f"Using configuration from {config}.", bold=False, fg="blue")
403     try:
404         include_regex = re_compile_maybe_verbose(include)
405     except re.error:
406         err(f"Invalid regular expression for include given: {include!r}")
407         ctx.exit(2)
408     try:
409         exclude_regex = re_compile_maybe_verbose(exclude)
410     except re.error:
411         err(f"Invalid regular expression for exclude given: {exclude!r}")
412         ctx.exit(2)
413     report = Report(check=check, quiet=quiet, verbose=verbose)
414     root = find_project_root(src)
415     sources: Set[Path] = set()
416     for s in src:
417         p = Path(s)
418         if p.is_dir():
419             sources.update(
420                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
421             )
422         elif p.is_file() or s == "-":
423             # if a file was explicitly given, we don't care about its extension
424             sources.add(p)
425         else:
426             err(f"invalid path: {s}")
427     if len(sources) == 0:
428         if verbose or not quiet:
429             out("No paths given. Nothing to do 😴")
430         ctx.exit(0)
431
432     if len(sources) == 1:
433         reformat_one(
434             src=sources.pop(),
435             fast=fast,
436             write_back=write_back,
437             mode=mode,
438             report=report,
439         )
440     else:
441         loop = asyncio.get_event_loop()
442         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
443         try:
444             loop.run_until_complete(
445                 schedule_formatting(
446                     sources=sources,
447                     fast=fast,
448                     write_back=write_back,
449                     mode=mode,
450                     report=report,
451                     loop=loop,
452                     executor=executor,
453                 )
454             )
455         finally:
456             shutdown(loop)
457     if verbose or not quiet:
458         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
459         out(f"All done! {bang}")
460         click.secho(str(report), err=True)
461     ctx.exit(report.return_code)
462
463
464 def reformat_one(
465     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
466 ) -> None:
467     """Reformat a single file under `src` without spawning child processes.
468
469     If `quiet` is True, non-error messages are not output. `line_length`,
470     `write_back`, `fast` and `pyi` options are passed to
471     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
472     """
473     try:
474         changed = Changed.NO
475         if not src.is_file() and str(src) == "-":
476             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
477                 changed = Changed.YES
478         else:
479             cache: Cache = {}
480             if write_back != WriteBack.DIFF:
481                 cache = read_cache(mode)
482                 res_src = src.resolve()
483                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
484                     changed = Changed.CACHED
485             if changed is not Changed.CACHED and format_file_in_place(
486                 src, fast=fast, write_back=write_back, mode=mode
487             ):
488                 changed = Changed.YES
489             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
490                 write_back is WriteBack.CHECK and changed is Changed.NO
491             ):
492                 write_cache(cache, [src], mode)
493         report.done(src, changed)
494     except Exception as exc:
495         report.failed(src, str(exc))
496
497
498 async def schedule_formatting(
499     sources: Set[Path],
500     fast: bool,
501     write_back: WriteBack,
502     mode: FileMode,
503     report: "Report",
504     loop: BaseEventLoop,
505     executor: Executor,
506 ) -> None:
507     """Run formatting of `sources` in parallel using the provided `executor`.
508
509     (Use ProcessPoolExecutors for actual parallelism.)
510
511     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
512     :func:`format_file_in_place`.
513     """
514     cache: Cache = {}
515     if write_back != WriteBack.DIFF:
516         cache = read_cache(mode)
517         sources, cached = filter_cached(cache, sources)
518         for src in sorted(cached):
519             report.done(src, Changed.CACHED)
520     if not sources:
521         return
522
523     cancelled = []
524     sources_to_cache = []
525     lock = None
526     if write_back == WriteBack.DIFF:
527         # For diff output, we need locks to ensure we don't interleave output
528         # from different processes.
529         manager = Manager()
530         lock = manager.Lock()
531     tasks = {
532         loop.run_in_executor(
533             executor, format_file_in_place, src, fast, mode, write_back, lock
534         ): src
535         for src in sorted(sources)
536     }
537     pending: Iterable[asyncio.Task] = tasks.keys()
538     try:
539         loop.add_signal_handler(signal.SIGINT, cancel, pending)
540         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
541     except NotImplementedError:
542         # There are no good alternatives for these on Windows.
543         pass
544     while pending:
545         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
546         for task in done:
547             src = tasks.pop(task)
548             if task.cancelled():
549                 cancelled.append(task)
550             elif task.exception():
551                 report.failed(src, str(task.exception()))
552             else:
553                 changed = Changed.YES if task.result() else Changed.NO
554                 # If the file was written back or was successfully checked as
555                 # well-formatted, store this information in the cache.
556                 if write_back is WriteBack.YES or (
557                     write_back is WriteBack.CHECK and changed is Changed.NO
558                 ):
559                     sources_to_cache.append(src)
560                 report.done(src, changed)
561     if cancelled:
562         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
563     if sources_to_cache:
564         write_cache(cache, sources_to_cache, mode)
565
566
567 def format_file_in_place(
568     src: Path,
569     fast: bool,
570     mode: FileMode,
571     write_back: WriteBack = WriteBack.NO,
572     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
573 ) -> bool:
574     """Format file under `src` path. Return True if changed.
575
576     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
577     code to the file.
578     `line_length` and `fast` options are passed to :func:`format_file_contents`.
579     """
580     if src.suffix == ".pyi":
581         mode = evolve(mode, is_pyi=True)
582
583     then = datetime.utcfromtimestamp(src.stat().st_mtime)
584     with open(src, "rb") as buf:
585         src_contents, encoding, newline = decode_bytes(buf.read())
586     try:
587         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
588     except NothingChanged:
589         return False
590
591     if write_back == write_back.YES:
592         with open(src, "w", encoding=encoding, newline=newline) as f:
593             f.write(dst_contents)
594     elif write_back == write_back.DIFF:
595         now = datetime.utcnow()
596         src_name = f"{src}\t{then} +0000"
597         dst_name = f"{src}\t{now} +0000"
598         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
599         if lock:
600             lock.acquire()
601         try:
602             f = io.TextIOWrapper(
603                 sys.stdout.buffer,
604                 encoding=encoding,
605                 newline=newline,
606                 write_through=True,
607             )
608             f.write(diff_contents)
609             f.detach()
610         finally:
611             if lock:
612                 lock.release()
613     return True
614
615
616 def format_stdin_to_stdout(
617     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
618 ) -> bool:
619     """Format file on stdin. Return True if changed.
620
621     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
622     write a diff to stdout. The `mode` argument is passed to
623     :func:`format_file_contents`.
624     """
625     then = datetime.utcnow()
626     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
627     dst = src
628     try:
629         dst = format_file_contents(src, fast=fast, mode=mode)
630         return True
631
632     except NothingChanged:
633         return False
634
635     finally:
636         f = io.TextIOWrapper(
637             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
638         )
639         if write_back == WriteBack.YES:
640             f.write(dst)
641         elif write_back == WriteBack.DIFF:
642             now = datetime.utcnow()
643             src_name = f"STDIN\t{then} +0000"
644             dst_name = f"STDOUT\t{now} +0000"
645             f.write(diff(src, dst, src_name, dst_name))
646         f.detach()
647
648
649 def format_file_contents(
650     src_contents: str, *, fast: bool, mode: FileMode
651 ) -> FileContent:
652     """Reformat contents a file and return new contents.
653
654     If `fast` is False, additionally confirm that the reformatted code is
655     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
656     `line_length` is passed to :func:`format_str`.
657     """
658     if src_contents.strip() == "":
659         raise NothingChanged
660
661     dst_contents = format_str(src_contents, mode=mode)
662     if src_contents == dst_contents:
663         raise NothingChanged
664
665     if not fast:
666         assert_equivalent(src_contents, dst_contents)
667         assert_stable(src_contents, dst_contents, mode=mode)
668     return dst_contents
669
670
671 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
672     """Reformat a string and return new contents.
673
674     `line_length` determines how many characters per line are allowed.
675     """
676     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
677     dst_contents = ""
678     future_imports = get_future_imports(src_node)
679     if mode.target_versions:
680         versions = mode.target_versions
681     else:
682         versions = detect_target_versions(src_node)
683     normalize_fmt_off(src_node)
684     lines = LineGenerator(
685         remove_u_prefix="unicode_literals" in future_imports
686         or supports_feature(versions, Feature.UNICODE_LITERALS),
687         is_pyi=mode.is_pyi,
688         normalize_strings=mode.string_normalization,
689         allow_underscores=mode.numeric_underscore_normalization
690         and supports_feature(versions, Feature.NUMERIC_UNDERSCORES),
691     )
692     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
693     empty_line = Line()
694     after = 0
695     for current_line in lines.visit(src_node):
696         for _ in range(after):
697             dst_contents += str(empty_line)
698         before, after = elt.maybe_empty_lines(current_line)
699         for _ in range(before):
700             dst_contents += str(empty_line)
701         for line in split_line(
702             current_line,
703             line_length=mode.line_length,
704             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
705         ):
706             dst_contents += str(line)
707     return dst_contents
708
709
710 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
711     """Return a tuple of (decoded_contents, encoding, newline).
712
713     `newline` is either CRLF or LF but `decoded_contents` is decoded with
714     universal newlines (i.e. only contains LF).
715     """
716     srcbuf = io.BytesIO(src)
717     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
718     if not lines:
719         return "", encoding, "\n"
720
721     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
722     srcbuf.seek(0)
723     with io.TextIOWrapper(srcbuf, encoding) as tiow:
724         return tiow.read(), encoding, newline
725
726
727 GRAMMARS = [
728     pygram.python_grammar_no_print_statement_no_exec_statement,
729     pygram.python_grammar_no_print_statement,
730     pygram.python_grammar,
731 ]
732
733
734 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
735     if not target_versions:
736         return GRAMMARS
737     elif all(not version.is_python2() for version in target_versions):
738         # Python 2-compatible code, so don't try Python 3 grammar.
739         return [
740             pygram.python_grammar_no_print_statement_no_exec_statement,
741             pygram.python_grammar_no_print_statement,
742         ]
743     else:
744         return [pygram.python_grammar]
745
746
747 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
748     """Given a string with source, return the lib2to3 Node."""
749     if src_txt[-1:] != "\n":
750         src_txt += "\n"
751
752     for grammar in get_grammars(set(target_versions)):
753         drv = driver.Driver(grammar, pytree.convert)
754         try:
755             result = drv.parse_string(src_txt, True)
756             break
757
758         except ParseError as pe:
759             lineno, column = pe.context[1]
760             lines = src_txt.splitlines()
761             try:
762                 faulty_line = lines[lineno - 1]
763             except IndexError:
764                 faulty_line = "<line number missing in source>"
765             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
766     else:
767         raise exc from None
768
769     if isinstance(result, Leaf):
770         result = Node(syms.file_input, [result])
771     return result
772
773
774 def lib2to3_unparse(node: Node) -> str:
775     """Given a lib2to3 node, return its string representation."""
776     code = str(node)
777     return code
778
779
780 T = TypeVar("T")
781
782
783 class Visitor(Generic[T]):
784     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
785
786     def visit(self, node: LN) -> Iterator[T]:
787         """Main method to visit `node` and its children.
788
789         It tries to find a `visit_*()` method for the given `node.type`, like
790         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
791         If no dedicated `visit_*()` method is found, chooses `visit_default()`
792         instead.
793
794         Then yields objects of type `T` from the selected visitor.
795         """
796         if node.type < 256:
797             name = token.tok_name[node.type]
798         else:
799             name = type_repr(node.type)
800         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
801
802     def visit_default(self, node: LN) -> Iterator[T]:
803         """Default `visit_*()` implementation. Recurses to children of `node`."""
804         if isinstance(node, Node):
805             for child in node.children:
806                 yield from self.visit(child)
807
808
809 @dataclass
810 class DebugVisitor(Visitor[T]):
811     tree_depth: int = 0
812
813     def visit_default(self, node: LN) -> Iterator[T]:
814         indent = " " * (2 * self.tree_depth)
815         if isinstance(node, Node):
816             _type = type_repr(node.type)
817             out(f"{indent}{_type}", fg="yellow")
818             self.tree_depth += 1
819             for child in node.children:
820                 yield from self.visit(child)
821
822             self.tree_depth -= 1
823             out(f"{indent}/{_type}", fg="yellow", bold=False)
824         else:
825             _type = token.tok_name.get(node.type, str(node.type))
826             out(f"{indent}{_type}", fg="blue", nl=False)
827             if node.prefix:
828                 # We don't have to handle prefixes for `Node` objects since
829                 # that delegates to the first child anyway.
830                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
831             out(f" {node.value!r}", fg="blue", bold=False)
832
833     @classmethod
834     def show(cls, code: Union[str, Leaf, Node]) -> None:
835         """Pretty-print the lib2to3 AST of a given string of `code`.
836
837         Convenience method for debugging.
838         """
839         v: DebugVisitor[None] = DebugVisitor()
840         if isinstance(code, str):
841             code = lib2to3_parse(code)
842         list(v.visit(code))
843
844
845 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
846 STATEMENT = {
847     syms.if_stmt,
848     syms.while_stmt,
849     syms.for_stmt,
850     syms.try_stmt,
851     syms.except_clause,
852     syms.with_stmt,
853     syms.funcdef,
854     syms.classdef,
855 }
856 STANDALONE_COMMENT = 153
857 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
858 LOGIC_OPERATORS = {"and", "or"}
859 COMPARATORS = {
860     token.LESS,
861     token.GREATER,
862     token.EQEQUAL,
863     token.NOTEQUAL,
864     token.LESSEQUAL,
865     token.GREATEREQUAL,
866 }
867 MATH_OPERATORS = {
868     token.VBAR,
869     token.CIRCUMFLEX,
870     token.AMPER,
871     token.LEFTSHIFT,
872     token.RIGHTSHIFT,
873     token.PLUS,
874     token.MINUS,
875     token.STAR,
876     token.SLASH,
877     token.DOUBLESLASH,
878     token.PERCENT,
879     token.AT,
880     token.TILDE,
881     token.DOUBLESTAR,
882 }
883 STARS = {token.STAR, token.DOUBLESTAR}
884 VARARGS_PARENTS = {
885     syms.arglist,
886     syms.argument,  # double star in arglist
887     syms.trailer,  # single argument to call
888     syms.typedargslist,
889     syms.varargslist,  # lambdas
890 }
891 UNPACKING_PARENTS = {
892     syms.atom,  # single element of a list or set literal
893     syms.dictsetmaker,
894     syms.listmaker,
895     syms.testlist_gexp,
896     syms.testlist_star_expr,
897 }
898 TEST_DESCENDANTS = {
899     syms.test,
900     syms.lambdef,
901     syms.or_test,
902     syms.and_test,
903     syms.not_test,
904     syms.comparison,
905     syms.star_expr,
906     syms.expr,
907     syms.xor_expr,
908     syms.and_expr,
909     syms.shift_expr,
910     syms.arith_expr,
911     syms.trailer,
912     syms.term,
913     syms.power,
914 }
915 ASSIGNMENTS = {
916     "=",
917     "+=",
918     "-=",
919     "*=",
920     "@=",
921     "/=",
922     "%=",
923     "&=",
924     "|=",
925     "^=",
926     "<<=",
927     ">>=",
928     "**=",
929     "//=",
930 }
931 COMPREHENSION_PRIORITY = 20
932 COMMA_PRIORITY = 18
933 TERNARY_PRIORITY = 16
934 LOGIC_PRIORITY = 14
935 STRING_PRIORITY = 12
936 COMPARATOR_PRIORITY = 10
937 MATH_PRIORITIES = {
938     token.VBAR: 9,
939     token.CIRCUMFLEX: 8,
940     token.AMPER: 7,
941     token.LEFTSHIFT: 6,
942     token.RIGHTSHIFT: 6,
943     token.PLUS: 5,
944     token.MINUS: 5,
945     token.STAR: 4,
946     token.SLASH: 4,
947     token.DOUBLESLASH: 4,
948     token.PERCENT: 4,
949     token.AT: 4,
950     token.TILDE: 3,
951     token.DOUBLESTAR: 2,
952 }
953 DOT_PRIORITY = 1
954
955
956 @dataclass
957 class BracketTracker:
958     """Keeps track of brackets on a line."""
959
960     depth: int = 0
961     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
962     delimiters: Dict[LeafID, Priority] = Factory(dict)
963     previous: Optional[Leaf] = None
964     _for_loop_depths: List[int] = Factory(list)
965     _lambda_argument_depths: List[int] = Factory(list)
966
967     def mark(self, leaf: Leaf) -> None:
968         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
969
970         All leaves receive an int `bracket_depth` field that stores how deep
971         within brackets a given leaf is. 0 means there are no enclosing brackets
972         that started on this line.
973
974         If a leaf is itself a closing bracket, it receives an `opening_bracket`
975         field that it forms a pair with. This is a one-directional link to
976         avoid reference cycles.
977
978         If a leaf is a delimiter (a token on which Black can split the line if
979         needed) and it's on depth 0, its `id()` is stored in the tracker's
980         `delimiters` field.
981         """
982         if leaf.type == token.COMMENT:
983             return
984
985         self.maybe_decrement_after_for_loop_variable(leaf)
986         self.maybe_decrement_after_lambda_arguments(leaf)
987         if leaf.type in CLOSING_BRACKETS:
988             self.depth -= 1
989             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
990             leaf.opening_bracket = opening_bracket
991         leaf.bracket_depth = self.depth
992         if self.depth == 0:
993             delim = is_split_before_delimiter(leaf, self.previous)
994             if delim and self.previous is not None:
995                 self.delimiters[id(self.previous)] = delim
996             else:
997                 delim = is_split_after_delimiter(leaf, self.previous)
998                 if delim:
999                     self.delimiters[id(leaf)] = delim
1000         if leaf.type in OPENING_BRACKETS:
1001             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1002             self.depth += 1
1003         self.previous = leaf
1004         self.maybe_increment_lambda_arguments(leaf)
1005         self.maybe_increment_for_loop_variable(leaf)
1006
1007     def any_open_brackets(self) -> bool:
1008         """Return True if there is an yet unmatched open bracket on the line."""
1009         return bool(self.bracket_match)
1010
1011     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1012         """Return the highest priority of a delimiter found on the line.
1013
1014         Values are consistent with what `is_split_*_delimiter()` return.
1015         Raises ValueError on no delimiters.
1016         """
1017         return max(v for k, v in self.delimiters.items() if k not in exclude)
1018
1019     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1020         """Return the number of delimiters with the given `priority`.
1021
1022         If no `priority` is passed, defaults to max priority on the line.
1023         """
1024         if not self.delimiters:
1025             return 0
1026
1027         priority = priority or self.max_delimiter_priority()
1028         return sum(1 for p in self.delimiters.values() if p == priority)
1029
1030     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1031         """In a for loop, or comprehension, the variables are often unpacks.
1032
1033         To avoid splitting on the comma in this situation, increase the depth of
1034         tokens between `for` and `in`.
1035         """
1036         if leaf.type == token.NAME and leaf.value == "for":
1037             self.depth += 1
1038             self._for_loop_depths.append(self.depth)
1039             return True
1040
1041         return False
1042
1043     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1044         """See `maybe_increment_for_loop_variable` above for explanation."""
1045         if (
1046             self._for_loop_depths
1047             and self._for_loop_depths[-1] == self.depth
1048             and leaf.type == token.NAME
1049             and leaf.value == "in"
1050         ):
1051             self.depth -= 1
1052             self._for_loop_depths.pop()
1053             return True
1054
1055         return False
1056
1057     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1058         """In a lambda expression, there might be more than one argument.
1059
1060         To avoid splitting on the comma in this situation, increase the depth of
1061         tokens between `lambda` and `:`.
1062         """
1063         if leaf.type == token.NAME and leaf.value == "lambda":
1064             self.depth += 1
1065             self._lambda_argument_depths.append(self.depth)
1066             return True
1067
1068         return False
1069
1070     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1071         """See `maybe_increment_lambda_arguments` above for explanation."""
1072         if (
1073             self._lambda_argument_depths
1074             and self._lambda_argument_depths[-1] == self.depth
1075             and leaf.type == token.COLON
1076         ):
1077             self.depth -= 1
1078             self._lambda_argument_depths.pop()
1079             return True
1080
1081         return False
1082
1083     def get_open_lsqb(self) -> Optional[Leaf]:
1084         """Return the most recent opening square bracket (if any)."""
1085         return self.bracket_match.get((self.depth - 1, token.RSQB))
1086
1087
1088 @dataclass
1089 class Line:
1090     """Holds leaves and comments. Can be printed with `str(line)`."""
1091
1092     depth: int = 0
1093     leaves: List[Leaf] = Factory(list)
1094     # The LeafID keys of comments must remain ordered by the corresponding leaf's index
1095     # in leaves
1096     comments: Dict[LeafID, List[Leaf]] = Factory(dict)
1097     bracket_tracker: BracketTracker = Factory(BracketTracker)
1098     inside_brackets: bool = False
1099     should_explode: bool = False
1100
1101     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1102         """Add a new `leaf` to the end of the line.
1103
1104         Unless `preformatted` is True, the `leaf` will receive a new consistent
1105         whitespace prefix and metadata applied by :class:`BracketTracker`.
1106         Trailing commas are maybe removed, unpacked for loop variables are
1107         demoted from being delimiters.
1108
1109         Inline comments are put aside.
1110         """
1111         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1112         if not has_value:
1113             return
1114
1115         if token.COLON == leaf.type and self.is_class_paren_empty:
1116             del self.leaves[-2:]
1117         if self.leaves and not preformatted:
1118             # Note: at this point leaf.prefix should be empty except for
1119             # imports, for which we only preserve newlines.
1120             leaf.prefix += whitespace(
1121                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1122             )
1123         if self.inside_brackets or not preformatted:
1124             self.bracket_tracker.mark(leaf)
1125             self.maybe_remove_trailing_comma(leaf)
1126         if not self.append_comment(leaf):
1127             self.leaves.append(leaf)
1128
1129     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1130         """Like :func:`append()` but disallow invalid standalone comment structure.
1131
1132         Raises ValueError when any `leaf` is appended after a standalone comment
1133         or when a standalone comment is not the first leaf on the line.
1134         """
1135         if self.bracket_tracker.depth == 0:
1136             if self.is_comment:
1137                 raise ValueError("cannot append to standalone comments")
1138
1139             if self.leaves and leaf.type == STANDALONE_COMMENT:
1140                 raise ValueError(
1141                     "cannot append standalone comments to a populated line"
1142                 )
1143
1144         self.append(leaf, preformatted=preformatted)
1145
1146     @property
1147     def is_comment(self) -> bool:
1148         """Is this line a standalone comment?"""
1149         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1150
1151     @property
1152     def is_decorator(self) -> bool:
1153         """Is this line a decorator?"""
1154         return bool(self) and self.leaves[0].type == token.AT
1155
1156     @property
1157     def is_import(self) -> bool:
1158         """Is this an import line?"""
1159         return bool(self) and is_import(self.leaves[0])
1160
1161     @property
1162     def is_class(self) -> bool:
1163         """Is this line a class definition?"""
1164         return (
1165             bool(self)
1166             and self.leaves[0].type == token.NAME
1167             and self.leaves[0].value == "class"
1168         )
1169
1170     @property
1171     def is_stub_class(self) -> bool:
1172         """Is this line a class definition with a body consisting only of "..."?"""
1173         return self.is_class and self.leaves[-3:] == [
1174             Leaf(token.DOT, ".") for _ in range(3)
1175         ]
1176
1177     @property
1178     def is_def(self) -> bool:
1179         """Is this a function definition? (Also returns True for async defs.)"""
1180         try:
1181             first_leaf = self.leaves[0]
1182         except IndexError:
1183             return False
1184
1185         try:
1186             second_leaf: Optional[Leaf] = self.leaves[1]
1187         except IndexError:
1188             second_leaf = None
1189         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1190             first_leaf.type == token.ASYNC
1191             and second_leaf is not None
1192             and second_leaf.type == token.NAME
1193             and second_leaf.value == "def"
1194         )
1195
1196     @property
1197     def is_class_paren_empty(self) -> bool:
1198         """Is this a class with no base classes but using parentheses?
1199
1200         Those are unnecessary and should be removed.
1201         """
1202         return (
1203             bool(self)
1204             and len(self.leaves) == 4
1205             and self.is_class
1206             and self.leaves[2].type == token.LPAR
1207             and self.leaves[2].value == "("
1208             and self.leaves[3].type == token.RPAR
1209             and self.leaves[3].value == ")"
1210         )
1211
1212     @property
1213     def is_triple_quoted_string(self) -> bool:
1214         """Is the line a triple quoted string?"""
1215         return (
1216             bool(self)
1217             and self.leaves[0].type == token.STRING
1218             and self.leaves[0].value.startswith(('"""', "'''"))
1219         )
1220
1221     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1222         """If so, needs to be split before emitting."""
1223         for leaf in self.leaves:
1224             if leaf.type == STANDALONE_COMMENT:
1225                 if leaf.bracket_depth <= depth_limit:
1226                     return True
1227
1228         return False
1229
1230     def contains_multiline_strings(self) -> bool:
1231         for leaf in self.leaves:
1232             if is_multiline_string(leaf):
1233                 return True
1234
1235         return False
1236
1237     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1238         """Remove trailing comma if there is one and it's safe."""
1239         if not (
1240             self.leaves
1241             and self.leaves[-1].type == token.COMMA
1242             and closing.type in CLOSING_BRACKETS
1243         ):
1244             return False
1245
1246         if closing.type == token.RBRACE:
1247             self.remove_trailing_comma()
1248             return True
1249
1250         if closing.type == token.RSQB:
1251             comma = self.leaves[-1]
1252             if comma.parent and comma.parent.type == syms.listmaker:
1253                 self.remove_trailing_comma()
1254                 return True
1255
1256         # For parens let's check if it's safe to remove the comma.
1257         # Imports are always safe.
1258         if self.is_import:
1259             self.remove_trailing_comma()
1260             return True
1261
1262         # Otherwise, if the trailing one is the only one, we might mistakenly
1263         # change a tuple into a different type by removing the comma.
1264         depth = closing.bracket_depth + 1
1265         commas = 0
1266         opening = closing.opening_bracket
1267         for _opening_index, leaf in enumerate(self.leaves):
1268             if leaf is opening:
1269                 break
1270
1271         else:
1272             return False
1273
1274         for leaf in self.leaves[_opening_index + 1 :]:
1275             if leaf is closing:
1276                 break
1277
1278             bracket_depth = leaf.bracket_depth
1279             if bracket_depth == depth and leaf.type == token.COMMA:
1280                 commas += 1
1281                 if leaf.parent and leaf.parent.type == syms.arglist:
1282                     commas += 1
1283                     break
1284
1285         if commas > 1:
1286             self.remove_trailing_comma()
1287             return True
1288
1289         return False
1290
1291     def append_comment(self, comment: Leaf) -> bool:
1292         """Add an inline or standalone comment to the line."""
1293         if (
1294             comment.type == STANDALONE_COMMENT
1295             and self.bracket_tracker.any_open_brackets()
1296         ):
1297             comment.prefix = ""
1298             return False
1299
1300         if comment.type != token.COMMENT:
1301             return False
1302
1303         if not self.leaves:
1304             comment.type = STANDALONE_COMMENT
1305             comment.prefix = ""
1306             return False
1307
1308         else:
1309             leaf_id = id(self.leaves[-1])
1310             if leaf_id not in self.comments:
1311                 self.comments[leaf_id] = [comment]
1312             else:
1313                 self.comments[leaf_id].append(comment)
1314             return True
1315
1316     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1317         """Generate comments that should appear directly after `leaf`."""
1318         return self.comments.get(id(leaf), [])
1319
1320     def remove_trailing_comma(self) -> None:
1321         """Remove the trailing comma and moves the comments attached to it."""
1322         # Remember, the LeafID keys of self.comments are ordered by the
1323         # corresponding leaf's index in self.leaves
1324         # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
1325         # Otherwise, we insert it into self.comments, and it becomes the last entry.
1326         # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
1327         # is maintained
1328         self.comments.setdefault(id(self.leaves[-2]), []).extend(
1329             self.comments.get(id(self.leaves[-1]), [])
1330         )
1331         self.comments.pop(id(self.leaves[-1]), None)
1332         self.leaves.pop()
1333
1334     def is_complex_subscript(self, leaf: Leaf) -> bool:
1335         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1336         open_lsqb = self.bracket_tracker.get_open_lsqb()
1337         if open_lsqb is None:
1338             return False
1339
1340         subscript_start = open_lsqb.next_sibling
1341
1342         if isinstance(subscript_start, Node):
1343             if subscript_start.type == syms.listmaker:
1344                 return False
1345
1346             if subscript_start.type == syms.subscriptlist:
1347                 subscript_start = child_towards(subscript_start, leaf)
1348         return subscript_start is not None and any(
1349             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1350         )
1351
1352     def __str__(self) -> str:
1353         """Render the line."""
1354         if not self:
1355             return "\n"
1356
1357         indent = "    " * self.depth
1358         leaves = iter(self.leaves)
1359         first = next(leaves)
1360         res = f"{first.prefix}{indent}{first.value}"
1361         for leaf in leaves:
1362             res += str(leaf)
1363         for comment in itertools.chain.from_iterable(self.comments.values()):
1364             res += str(comment)
1365         return res + "\n"
1366
1367     def __bool__(self) -> bool:
1368         """Return True if the line has leaves or comments."""
1369         return bool(self.leaves or self.comments)
1370
1371
1372 @dataclass
1373 class EmptyLineTracker:
1374     """Provides a stateful method that returns the number of potential extra
1375     empty lines needed before and after the currently processed line.
1376
1377     Note: this tracker works on lines that haven't been split yet.  It assumes
1378     the prefix of the first leaf consists of optional newlines.  Those newlines
1379     are consumed by `maybe_empty_lines()` and included in the computation.
1380     """
1381
1382     is_pyi: bool = False
1383     previous_line: Optional[Line] = None
1384     previous_after: int = 0
1385     previous_defs: List[int] = Factory(list)
1386
1387     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1388         """Return the number of extra empty lines before and after the `current_line`.
1389
1390         This is for separating `def`, `async def` and `class` with extra empty
1391         lines (two on module-level).
1392         """
1393         before, after = self._maybe_empty_lines(current_line)
1394         before -= self.previous_after
1395         self.previous_after = after
1396         self.previous_line = current_line
1397         return before, after
1398
1399     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1400         max_allowed = 1
1401         if current_line.depth == 0:
1402             max_allowed = 1 if self.is_pyi else 2
1403         if current_line.leaves:
1404             # Consume the first leaf's extra newlines.
1405             first_leaf = current_line.leaves[0]
1406             before = first_leaf.prefix.count("\n")
1407             before = min(before, max_allowed)
1408             first_leaf.prefix = ""
1409         else:
1410             before = 0
1411         depth = current_line.depth
1412         while self.previous_defs and self.previous_defs[-1] >= depth:
1413             self.previous_defs.pop()
1414             if self.is_pyi:
1415                 before = 0 if depth else 1
1416             else:
1417                 before = 1 if depth else 2
1418         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1419             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1420
1421         if (
1422             self.previous_line
1423             and self.previous_line.is_import
1424             and not current_line.is_import
1425             and depth == self.previous_line.depth
1426         ):
1427             return (before or 1), 0
1428
1429         if (
1430             self.previous_line
1431             and self.previous_line.is_class
1432             and current_line.is_triple_quoted_string
1433         ):
1434             return before, 1
1435
1436         return before, 0
1437
1438     def _maybe_empty_lines_for_class_or_def(
1439         self, current_line: Line, before: int
1440     ) -> Tuple[int, int]:
1441         if not current_line.is_decorator:
1442             self.previous_defs.append(current_line.depth)
1443         if self.previous_line is None:
1444             # Don't insert empty lines before the first line in the file.
1445             return 0, 0
1446
1447         if self.previous_line.is_decorator:
1448             return 0, 0
1449
1450         if self.previous_line.depth < current_line.depth and (
1451             self.previous_line.is_class or self.previous_line.is_def
1452         ):
1453             return 0, 0
1454
1455         if (
1456             self.previous_line.is_comment
1457             and self.previous_line.depth == current_line.depth
1458             and before == 0
1459         ):
1460             return 0, 0
1461
1462         if self.is_pyi:
1463             if self.previous_line.depth > current_line.depth:
1464                 newlines = 1
1465             elif current_line.is_class or self.previous_line.is_class:
1466                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1467                     # No blank line between classes with an empty body
1468                     newlines = 0
1469                 else:
1470                     newlines = 1
1471             elif current_line.is_def and not self.previous_line.is_def:
1472                 # Blank line between a block of functions and a block of non-functions
1473                 newlines = 1
1474             else:
1475                 newlines = 0
1476         else:
1477             newlines = 2
1478         if current_line.depth and newlines:
1479             newlines -= 1
1480         return newlines, 0
1481
1482
1483 @dataclass
1484 class LineGenerator(Visitor[Line]):
1485     """Generates reformatted Line objects.  Empty lines are not emitted.
1486
1487     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1488     in ways that will no longer stringify to valid Python code on the tree.
1489     """
1490
1491     is_pyi: bool = False
1492     normalize_strings: bool = True
1493     current_line: Line = Factory(Line)
1494     remove_u_prefix: bool = False
1495     allow_underscores: bool = False
1496
1497     def line(self, indent: int = 0) -> Iterator[Line]:
1498         """Generate a line.
1499
1500         If the line is empty, only emit if it makes sense.
1501         If the line is too long, split it first and then generate.
1502
1503         If any lines were generated, set up a new current_line.
1504         """
1505         if not self.current_line:
1506             self.current_line.depth += indent
1507             return  # Line is empty, don't emit. Creating a new one unnecessary.
1508
1509         complete_line = self.current_line
1510         self.current_line = Line(depth=complete_line.depth + indent)
1511         yield complete_line
1512
1513     def visit_default(self, node: LN) -> Iterator[Line]:
1514         """Default `visit_*()` implementation. Recurses to children of `node`."""
1515         if isinstance(node, Leaf):
1516             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1517             for comment in generate_comments(node):
1518                 if any_open_brackets:
1519                     # any comment within brackets is subject to splitting
1520                     self.current_line.append(comment)
1521                 elif comment.type == token.COMMENT:
1522                     # regular trailing comment
1523                     self.current_line.append(comment)
1524                     yield from self.line()
1525
1526                 else:
1527                     # regular standalone comment
1528                     yield from self.line()
1529
1530                     self.current_line.append(comment)
1531                     yield from self.line()
1532
1533             normalize_prefix(node, inside_brackets=any_open_brackets)
1534             if self.normalize_strings and node.type == token.STRING:
1535                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1536                 normalize_string_quotes(node)
1537             if node.type == token.NUMBER:
1538                 normalize_numeric_literal(node, self.allow_underscores)
1539             if node.type not in WHITESPACE:
1540                 self.current_line.append(node)
1541         yield from super().visit_default(node)
1542
1543     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1544         """Increase indentation level, maybe yield a line."""
1545         # In blib2to3 INDENT never holds comments.
1546         yield from self.line(+1)
1547         yield from self.visit_default(node)
1548
1549     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1550         """Decrease indentation level, maybe yield a line."""
1551         # The current line might still wait for trailing comments.  At DEDENT time
1552         # there won't be any (they would be prefixes on the preceding NEWLINE).
1553         # Emit the line then.
1554         yield from self.line()
1555
1556         # While DEDENT has no value, its prefix may contain standalone comments
1557         # that belong to the current indentation level.  Get 'em.
1558         yield from self.visit_default(node)
1559
1560         # Finally, emit the dedent.
1561         yield from self.line(-1)
1562
1563     def visit_stmt(
1564         self, node: Node, keywords: Set[str], parens: Set[str]
1565     ) -> Iterator[Line]:
1566         """Visit a statement.
1567
1568         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1569         `def`, `with`, `class`, `assert` and assignments.
1570
1571         The relevant Python language `keywords` for a given statement will be
1572         NAME leaves within it. This methods puts those on a separate line.
1573
1574         `parens` holds a set of string leaf values immediately after which
1575         invisible parens should be put.
1576         """
1577         normalize_invisible_parens(node, parens_after=parens)
1578         for child in node.children:
1579             if child.type == token.NAME and child.value in keywords:  # type: ignore
1580                 yield from self.line()
1581
1582             yield from self.visit(child)
1583
1584     def visit_suite(self, node: Node) -> Iterator[Line]:
1585         """Visit a suite."""
1586         if self.is_pyi and is_stub_suite(node):
1587             yield from self.visit(node.children[2])
1588         else:
1589             yield from self.visit_default(node)
1590
1591     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1592         """Visit a statement without nested statements."""
1593         is_suite_like = node.parent and node.parent.type in STATEMENT
1594         if is_suite_like:
1595             if self.is_pyi and is_stub_body(node):
1596                 yield from self.visit_default(node)
1597             else:
1598                 yield from self.line(+1)
1599                 yield from self.visit_default(node)
1600                 yield from self.line(-1)
1601
1602         else:
1603             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1604                 yield from self.line()
1605             yield from self.visit_default(node)
1606
1607     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1608         """Visit `async def`, `async for`, `async with`."""
1609         yield from self.line()
1610
1611         children = iter(node.children)
1612         for child in children:
1613             yield from self.visit(child)
1614
1615             if child.type == token.ASYNC:
1616                 break
1617
1618         internal_stmt = next(children)
1619         for child in internal_stmt.children:
1620             yield from self.visit(child)
1621
1622     def visit_decorators(self, node: Node) -> Iterator[Line]:
1623         """Visit decorators."""
1624         for child in node.children:
1625             yield from self.line()
1626             yield from self.visit(child)
1627
1628     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1629         """Remove a semicolon and put the other statement on a separate line."""
1630         yield from self.line()
1631
1632     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1633         """End of file. Process outstanding comments and end with a newline."""
1634         yield from self.visit_default(leaf)
1635         yield from self.line()
1636
1637     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1638         if not self.current_line.bracket_tracker.any_open_brackets():
1639             yield from self.line()
1640         yield from self.visit_default(leaf)
1641
1642     def __attrs_post_init__(self) -> None:
1643         """You are in a twisty little maze of passages."""
1644         v = self.visit_stmt
1645         Ø: Set[str] = set()
1646         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1647         self.visit_if_stmt = partial(
1648             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1649         )
1650         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1651         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1652         self.visit_try_stmt = partial(
1653             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1654         )
1655         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1656         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1657         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1658         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1659         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1660         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1661         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1662         self.visit_async_funcdef = self.visit_async_stmt
1663         self.visit_decorated = self.visit_decorators
1664
1665
1666 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1667 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1668 OPENING_BRACKETS = set(BRACKET.keys())
1669 CLOSING_BRACKETS = set(BRACKET.values())
1670 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1671 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1672
1673
1674 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1675     """Return whitespace prefix if needed for the given `leaf`.
1676
1677     `complex_subscript` signals whether the given leaf is part of a subscription
1678     which has non-trivial arguments, like arithmetic expressions or function calls.
1679     """
1680     NO = ""
1681     SPACE = " "
1682     DOUBLESPACE = "  "
1683     t = leaf.type
1684     p = leaf.parent
1685     v = leaf.value
1686     if t in ALWAYS_NO_SPACE:
1687         return NO
1688
1689     if t == token.COMMENT:
1690         return DOUBLESPACE
1691
1692     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1693     if t == token.COLON and p.type not in {
1694         syms.subscript,
1695         syms.subscriptlist,
1696         syms.sliceop,
1697     }:
1698         return NO
1699
1700     prev = leaf.prev_sibling
1701     if not prev:
1702         prevp = preceding_leaf(p)
1703         if not prevp or prevp.type in OPENING_BRACKETS:
1704             return NO
1705
1706         if t == token.COLON:
1707             if prevp.type == token.COLON:
1708                 return NO
1709
1710             elif prevp.type != token.COMMA and not complex_subscript:
1711                 return NO
1712
1713             return SPACE
1714
1715         if prevp.type == token.EQUAL:
1716             if prevp.parent:
1717                 if prevp.parent.type in {
1718                     syms.arglist,
1719                     syms.argument,
1720                     syms.parameters,
1721                     syms.varargslist,
1722                 }:
1723                     return NO
1724
1725                 elif prevp.parent.type == syms.typedargslist:
1726                     # A bit hacky: if the equal sign has whitespace, it means we
1727                     # previously found it's a typed argument.  So, we're using
1728                     # that, too.
1729                     return prevp.prefix
1730
1731         elif prevp.type in STARS:
1732             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1733                 return NO
1734
1735         elif prevp.type == token.COLON:
1736             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1737                 return SPACE if complex_subscript else NO
1738
1739         elif (
1740             prevp.parent
1741             and prevp.parent.type == syms.factor
1742             and prevp.type in MATH_OPERATORS
1743         ):
1744             return NO
1745
1746         elif (
1747             prevp.type == token.RIGHTSHIFT
1748             and prevp.parent
1749             and prevp.parent.type == syms.shift_expr
1750             and prevp.prev_sibling
1751             and prevp.prev_sibling.type == token.NAME
1752             and prevp.prev_sibling.value == "print"  # type: ignore
1753         ):
1754             # Python 2 print chevron
1755             return NO
1756
1757     elif prev.type in OPENING_BRACKETS:
1758         return NO
1759
1760     if p.type in {syms.parameters, syms.arglist}:
1761         # untyped function signatures or calls
1762         if not prev or prev.type != token.COMMA:
1763             return NO
1764
1765     elif p.type == syms.varargslist:
1766         # lambdas
1767         if prev and prev.type != token.COMMA:
1768             return NO
1769
1770     elif p.type == syms.typedargslist:
1771         # typed function signatures
1772         if not prev:
1773             return NO
1774
1775         if t == token.EQUAL:
1776             if prev.type != syms.tname:
1777                 return NO
1778
1779         elif prev.type == token.EQUAL:
1780             # A bit hacky: if the equal sign has whitespace, it means we
1781             # previously found it's a typed argument.  So, we're using that, too.
1782             return prev.prefix
1783
1784         elif prev.type != token.COMMA:
1785             return NO
1786
1787     elif p.type == syms.tname:
1788         # type names
1789         if not prev:
1790             prevp = preceding_leaf(p)
1791             if not prevp or prevp.type != token.COMMA:
1792                 return NO
1793
1794     elif p.type == syms.trailer:
1795         # attributes and calls
1796         if t == token.LPAR or t == token.RPAR:
1797             return NO
1798
1799         if not prev:
1800             if t == token.DOT:
1801                 prevp = preceding_leaf(p)
1802                 if not prevp or prevp.type != token.NUMBER:
1803                     return NO
1804
1805             elif t == token.LSQB:
1806                 return NO
1807
1808         elif prev.type != token.COMMA:
1809             return NO
1810
1811     elif p.type == syms.argument:
1812         # single argument
1813         if t == token.EQUAL:
1814             return NO
1815
1816         if not prev:
1817             prevp = preceding_leaf(p)
1818             if not prevp or prevp.type == token.LPAR:
1819                 return NO
1820
1821         elif prev.type in {token.EQUAL} | STARS:
1822             return NO
1823
1824     elif p.type == syms.decorator:
1825         # decorators
1826         return NO
1827
1828     elif p.type == syms.dotted_name:
1829         if prev:
1830             return NO
1831
1832         prevp = preceding_leaf(p)
1833         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1834             return NO
1835
1836     elif p.type == syms.classdef:
1837         if t == token.LPAR:
1838             return NO
1839
1840         if prev and prev.type == token.LPAR:
1841             return NO
1842
1843     elif p.type in {syms.subscript, syms.sliceop}:
1844         # indexing
1845         if not prev:
1846             assert p.parent is not None, "subscripts are always parented"
1847             if p.parent.type == syms.subscriptlist:
1848                 return SPACE
1849
1850             return NO
1851
1852         elif not complex_subscript:
1853             return NO
1854
1855     elif p.type == syms.atom:
1856         if prev and t == token.DOT:
1857             # dots, but not the first one.
1858             return NO
1859
1860     elif p.type == syms.dictsetmaker:
1861         # dict unpacking
1862         if prev and prev.type == token.DOUBLESTAR:
1863             return NO
1864
1865     elif p.type in {syms.factor, syms.star_expr}:
1866         # unary ops
1867         if not prev:
1868             prevp = preceding_leaf(p)
1869             if not prevp or prevp.type in OPENING_BRACKETS:
1870                 return NO
1871
1872             prevp_parent = prevp.parent
1873             assert prevp_parent is not None
1874             if prevp.type == token.COLON and prevp_parent.type in {
1875                 syms.subscript,
1876                 syms.sliceop,
1877             }:
1878                 return NO
1879
1880             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1881                 return NO
1882
1883         elif t in {token.NAME, token.NUMBER, token.STRING}:
1884             return NO
1885
1886     elif p.type == syms.import_from:
1887         if t == token.DOT:
1888             if prev and prev.type == token.DOT:
1889                 return NO
1890
1891         elif t == token.NAME:
1892             if v == "import":
1893                 return SPACE
1894
1895             if prev and prev.type == token.DOT:
1896                 return NO
1897
1898     elif p.type == syms.sliceop:
1899         return NO
1900
1901     return SPACE
1902
1903
1904 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1905     """Return the first leaf that precedes `node`, if any."""
1906     while node:
1907         res = node.prev_sibling
1908         if res:
1909             if isinstance(res, Leaf):
1910                 return res
1911
1912             try:
1913                 return list(res.leaves())[-1]
1914
1915             except IndexError:
1916                 return None
1917
1918         node = node.parent
1919     return None
1920
1921
1922 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1923     """Return the child of `ancestor` that contains `descendant`."""
1924     node: Optional[LN] = descendant
1925     while node and node.parent != ancestor:
1926         node = node.parent
1927     return node
1928
1929
1930 def container_of(leaf: Leaf) -> LN:
1931     """Return `leaf` or one of its ancestors that is the topmost container of it.
1932
1933     By "container" we mean a node where `leaf` is the very first child.
1934     """
1935     same_prefix = leaf.prefix
1936     container: LN = leaf
1937     while container:
1938         parent = container.parent
1939         if parent is None:
1940             break
1941
1942         if parent.children[0].prefix != same_prefix:
1943             break
1944
1945         if parent.type == syms.file_input:
1946             break
1947
1948         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1949             break
1950
1951         container = parent
1952     return container
1953
1954
1955 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1956     """Return the priority of the `leaf` delimiter, given a line break after it.
1957
1958     The delimiter priorities returned here are from those delimiters that would
1959     cause a line break after themselves.
1960
1961     Higher numbers are higher priority.
1962     """
1963     if leaf.type == token.COMMA:
1964         return COMMA_PRIORITY
1965
1966     return 0
1967
1968
1969 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1970     """Return the priority of the `leaf` delimiter, given a line break before it.
1971
1972     The delimiter priorities returned here are from those delimiters that would
1973     cause a line break before themselves.
1974
1975     Higher numbers are higher priority.
1976     """
1977     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1978         # * and ** might also be MATH_OPERATORS but in this case they are not.
1979         # Don't treat them as a delimiter.
1980         return 0
1981
1982     if (
1983         leaf.type == token.DOT
1984         and leaf.parent
1985         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1986         and (previous is None or previous.type in CLOSING_BRACKETS)
1987     ):
1988         return DOT_PRIORITY
1989
1990     if (
1991         leaf.type in MATH_OPERATORS
1992         and leaf.parent
1993         and leaf.parent.type not in {syms.factor, syms.star_expr}
1994     ):
1995         return MATH_PRIORITIES[leaf.type]
1996
1997     if leaf.type in COMPARATORS:
1998         return COMPARATOR_PRIORITY
1999
2000     if (
2001         leaf.type == token.STRING
2002         and previous is not None
2003         and previous.type == token.STRING
2004     ):
2005         return STRING_PRIORITY
2006
2007     if leaf.type not in {token.NAME, token.ASYNC}:
2008         return 0
2009
2010     if (
2011         leaf.value == "for"
2012         and leaf.parent
2013         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2014         or leaf.type == token.ASYNC
2015     ):
2016         if (
2017             not isinstance(leaf.prev_sibling, Leaf)
2018             or leaf.prev_sibling.value != "async"
2019         ):
2020             return COMPREHENSION_PRIORITY
2021
2022     if (
2023         leaf.value == "if"
2024         and leaf.parent
2025         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2026     ):
2027         return COMPREHENSION_PRIORITY
2028
2029     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2030         return TERNARY_PRIORITY
2031
2032     if leaf.value == "is":
2033         return COMPARATOR_PRIORITY
2034
2035     if (
2036         leaf.value == "in"
2037         and leaf.parent
2038         and leaf.parent.type in {syms.comp_op, syms.comparison}
2039         and not (
2040             previous is not None
2041             and previous.type == token.NAME
2042             and previous.value == "not"
2043         )
2044     ):
2045         return COMPARATOR_PRIORITY
2046
2047     if (
2048         leaf.value == "not"
2049         and leaf.parent
2050         and leaf.parent.type == syms.comp_op
2051         and not (
2052             previous is not None
2053             and previous.type == token.NAME
2054             and previous.value == "is"
2055         )
2056     ):
2057         return COMPARATOR_PRIORITY
2058
2059     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2060         return LOGIC_PRIORITY
2061
2062     return 0
2063
2064
2065 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2066 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2067
2068
2069 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2070     """Clean the prefix of the `leaf` and generate comments from it, if any.
2071
2072     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2073     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2074     move because it does away with modifying the grammar to include all the
2075     possible places in which comments can be placed.
2076
2077     The sad consequence for us though is that comments don't "belong" anywhere.
2078     This is why this function generates simple parentless Leaf objects for
2079     comments.  We simply don't know what the correct parent should be.
2080
2081     No matter though, we can live without this.  We really only need to
2082     differentiate between inline and standalone comments.  The latter don't
2083     share the line with any code.
2084
2085     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2086     are emitted with a fake STANDALONE_COMMENT token identifier.
2087     """
2088     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2089         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2090
2091
2092 @dataclass
2093 class ProtoComment:
2094     """Describes a piece of syntax that is a comment.
2095
2096     It's not a :class:`blib2to3.pytree.Leaf` so that:
2097
2098     * it can be cached (`Leaf` objects should not be reused more than once as
2099       they store their lineno, column, prefix, and parent information);
2100     * `newlines` and `consumed` fields are kept separate from the `value`. This
2101       simplifies handling of special marker comments like ``# fmt: off/on``.
2102     """
2103
2104     type: int  # token.COMMENT or STANDALONE_COMMENT
2105     value: str  # content of the comment
2106     newlines: int  # how many newlines before the comment
2107     consumed: int  # how many characters of the original leaf's prefix did we consume
2108
2109
2110 @lru_cache(maxsize=4096)
2111 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2112     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2113     result: List[ProtoComment] = []
2114     if not prefix or "#" not in prefix:
2115         return result
2116
2117     consumed = 0
2118     nlines = 0
2119     for index, line in enumerate(prefix.split("\n")):
2120         consumed += len(line) + 1  # adding the length of the split '\n'
2121         line = line.lstrip()
2122         if not line:
2123             nlines += 1
2124         if not line.startswith("#"):
2125             continue
2126
2127         if index == 0 and not is_endmarker:
2128             comment_type = token.COMMENT  # simple trailing comment
2129         else:
2130             comment_type = STANDALONE_COMMENT
2131         comment = make_comment(line)
2132         result.append(
2133             ProtoComment(
2134                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2135             )
2136         )
2137         nlines = 0
2138     return result
2139
2140
2141 def make_comment(content: str) -> str:
2142     """Return a consistently formatted comment from the given `content` string.
2143
2144     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2145     space between the hash sign and the content.
2146
2147     If `content` didn't start with a hash sign, one is provided.
2148     """
2149     content = content.rstrip()
2150     if not content:
2151         return "#"
2152
2153     if content[0] == "#":
2154         content = content[1:]
2155     if content and content[0] not in " !:#'%":
2156         content = " " + content
2157     return "#" + content
2158
2159
2160 def split_line(
2161     line: Line,
2162     line_length: int,
2163     inner: bool = False,
2164     supports_trailing_commas: bool = False,
2165 ) -> Iterator[Line]:
2166     """Split a `line` into potentially many lines.
2167
2168     They should fit in the allotted `line_length` but might not be able to.
2169     `inner` signifies that there were a pair of brackets somewhere around the
2170     current `line`, possibly transitively. This means we can fallback to splitting
2171     by delimiters if the LHS/RHS don't yield any results.
2172
2173     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2174     """
2175     if line.is_comment:
2176         yield line
2177         return
2178
2179     line_str = str(line).strip("\n")
2180
2181     # we don't want to split special comments like type annotations
2182     # https://github.com/python/typing/issues/186
2183     has_special_comment = False
2184     for leaf in line.leaves:
2185         for comment in line.comments_after(leaf):
2186             if leaf.type == token.COMMA and is_special_comment(comment):
2187                 has_special_comment = True
2188
2189     if (
2190         not has_special_comment
2191         and not line.should_explode
2192         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2193     ):
2194         yield line
2195         return
2196
2197     split_funcs: List[SplitFunc]
2198     if line.is_def:
2199         split_funcs = [left_hand_split]
2200     else:
2201
2202         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2203             for omit in generate_trailers_to_omit(line, line_length):
2204                 lines = list(
2205                     right_hand_split(
2206                         line, line_length, supports_trailing_commas, omit=omit
2207                     )
2208                 )
2209                 if is_line_short_enough(lines[0], line_length=line_length):
2210                     yield from lines
2211                     return
2212
2213             # All splits failed, best effort split with no omits.
2214             # This mostly happens to multiline strings that are by definition
2215             # reported as not fitting a single line.
2216             yield from right_hand_split(line, supports_trailing_commas)
2217
2218         if line.inside_brackets:
2219             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2220         else:
2221             split_funcs = [rhs]
2222     for split_func in split_funcs:
2223         # We are accumulating lines in `result` because we might want to abort
2224         # mission and return the original line in the end, or attempt a different
2225         # split altogether.
2226         result: List[Line] = []
2227         try:
2228             for l in split_func(line, supports_trailing_commas):
2229                 if str(l).strip("\n") == line_str:
2230                     raise CannotSplit("Split function returned an unchanged result")
2231
2232                 result.extend(
2233                     split_line(
2234                         l,
2235                         line_length=line_length,
2236                         inner=True,
2237                         supports_trailing_commas=supports_trailing_commas,
2238                     )
2239                 )
2240         except CannotSplit:
2241             continue
2242
2243         else:
2244             yield from result
2245             break
2246
2247     else:
2248         yield line
2249
2250
2251 def left_hand_split(
2252     line: Line, supports_trailing_commas: bool = False
2253 ) -> Iterator[Line]:
2254     """Split line into many lines, starting with the first matching bracket pair.
2255
2256     Note: this usually looks weird, only use this for function definitions.
2257     Prefer RHS otherwise.  This is why this function is not symmetrical with
2258     :func:`right_hand_split` which also handles optional parentheses.
2259     """
2260     tail_leaves: List[Leaf] = []
2261     body_leaves: List[Leaf] = []
2262     head_leaves: List[Leaf] = []
2263     current_leaves = head_leaves
2264     matching_bracket = None
2265     for leaf in line.leaves:
2266         if (
2267             current_leaves is body_leaves
2268             and leaf.type in CLOSING_BRACKETS
2269             and leaf.opening_bracket is matching_bracket
2270         ):
2271             current_leaves = tail_leaves if body_leaves else head_leaves
2272         current_leaves.append(leaf)
2273         if current_leaves is head_leaves:
2274             if leaf.type in OPENING_BRACKETS:
2275                 matching_bracket = leaf
2276                 current_leaves = body_leaves
2277     if not matching_bracket:
2278         raise CannotSplit("No brackets found")
2279
2280     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2281     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2282     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2283     bracket_split_succeeded_or_raise(head, body, tail)
2284     for result in (head, body, tail):
2285         if result:
2286             yield result
2287
2288
2289 def right_hand_split(
2290     line: Line,
2291     line_length: int,
2292     supports_trailing_commas: bool = False,
2293     omit: Collection[LeafID] = (),
2294 ) -> Iterator[Line]:
2295     """Split line into many lines, starting with the last matching bracket pair.
2296
2297     If the split was by optional parentheses, attempt splitting without them, too.
2298     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2299     this split.
2300
2301     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2302     """
2303     tail_leaves: List[Leaf] = []
2304     body_leaves: List[Leaf] = []
2305     head_leaves: List[Leaf] = []
2306     current_leaves = tail_leaves
2307     opening_bracket = None
2308     closing_bracket = None
2309     for leaf in reversed(line.leaves):
2310         if current_leaves is body_leaves:
2311             if leaf is opening_bracket:
2312                 current_leaves = head_leaves if body_leaves else tail_leaves
2313         current_leaves.append(leaf)
2314         if current_leaves is tail_leaves:
2315             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2316                 opening_bracket = leaf.opening_bracket
2317                 closing_bracket = leaf
2318                 current_leaves = body_leaves
2319     if not (opening_bracket and closing_bracket and head_leaves):
2320         # If there is no opening or closing_bracket that means the split failed and
2321         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2322         # the matching `opening_bracket` wasn't available on `line` anymore.
2323         raise CannotSplit("No brackets found")
2324
2325     tail_leaves.reverse()
2326     body_leaves.reverse()
2327     head_leaves.reverse()
2328     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2329     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2330     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2331     bracket_split_succeeded_or_raise(head, body, tail)
2332     if (
2333         # the body shouldn't be exploded
2334         not body.should_explode
2335         # the opening bracket is an optional paren
2336         and opening_bracket.type == token.LPAR
2337         and not opening_bracket.value
2338         # the closing bracket is an optional paren
2339         and closing_bracket.type == token.RPAR
2340         and not closing_bracket.value
2341         # it's not an import (optional parens are the only thing we can split on
2342         # in this case; attempting a split without them is a waste of time)
2343         and not line.is_import
2344         # there are no standalone comments in the body
2345         and not body.contains_standalone_comments(0)
2346         # and we can actually remove the parens
2347         and can_omit_invisible_parens(body, line_length)
2348     ):
2349         omit = {id(closing_bracket), *omit}
2350         try:
2351             yield from right_hand_split(
2352                 line,
2353                 line_length,
2354                 supports_trailing_commas=supports_trailing_commas,
2355                 omit=omit,
2356             )
2357             return
2358
2359         except CannotSplit:
2360             if not (
2361                 can_be_split(body)
2362                 or is_line_short_enough(body, line_length=line_length)
2363             ):
2364                 raise CannotSplit(
2365                     "Splitting failed, body is still too long and can't be split."
2366                 )
2367
2368             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2369                 raise CannotSplit(
2370                     "The current optional pair of parentheses is bound to fail to "
2371                     "satisfy the splitting algorithm because the head or the tail "
2372                     "contains multiline strings which by definition never fit one "
2373                     "line."
2374                 )
2375
2376     ensure_visible(opening_bracket)
2377     ensure_visible(closing_bracket)
2378     for result in (head, body, tail):
2379         if result:
2380             yield result
2381
2382
2383 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2384     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2385
2386     Do nothing otherwise.
2387
2388     A left- or right-hand split is based on a pair of brackets. Content before
2389     (and including) the opening bracket is left on one line, content inside the
2390     brackets is put on a separate line, and finally content starting with and
2391     following the closing bracket is put on a separate line.
2392
2393     Those are called `head`, `body`, and `tail`, respectively. If the split
2394     produced the same line (all content in `head`) or ended up with an empty `body`
2395     and the `tail` is just the closing bracket, then it's considered failed.
2396     """
2397     tail_len = len(str(tail).strip())
2398     if not body:
2399         if tail_len == 0:
2400             raise CannotSplit("Splitting brackets produced the same line")
2401
2402         elif tail_len < 3:
2403             raise CannotSplit(
2404                 f"Splitting brackets on an empty body to save "
2405                 f"{tail_len} characters is not worth it"
2406             )
2407
2408
2409 def bracket_split_build_line(
2410     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2411 ) -> Line:
2412     """Return a new line with given `leaves` and respective comments from `original`.
2413
2414     If `is_body` is True, the result line is one-indented inside brackets and as such
2415     has its first leaf's prefix normalized and a trailing comma added when expected.
2416     """
2417     result = Line(depth=original.depth)
2418     if is_body:
2419         result.inside_brackets = True
2420         result.depth += 1
2421         if leaves:
2422             # Since body is a new indent level, remove spurious leading whitespace.
2423             normalize_prefix(leaves[0], inside_brackets=True)
2424             # Ensure a trailing comma when expected.
2425             if original.is_import:
2426                 if leaves[-1].type != token.COMMA:
2427                     leaves.append(Leaf(token.COMMA, ","))
2428     # Populate the line
2429     for leaf in leaves:
2430         result.append(leaf, preformatted=True)
2431         for comment_after in original.comments_after(leaf):
2432             result.append(comment_after, preformatted=True)
2433     if is_body:
2434         result.should_explode = should_explode(result, opening_bracket)
2435     return result
2436
2437
2438 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2439     """Normalize prefix of the first leaf in every line returned by `split_func`.
2440
2441     This is a decorator over relevant split functions.
2442     """
2443
2444     @wraps(split_func)
2445     def split_wrapper(
2446         line: Line, supports_trailing_commas: bool = False
2447     ) -> Iterator[Line]:
2448         for l in split_func(line, supports_trailing_commas):
2449             normalize_prefix(l.leaves[0], inside_brackets=True)
2450             yield l
2451
2452     return split_wrapper
2453
2454
2455 @dont_increase_indentation
2456 def delimiter_split(
2457     line: Line, supports_trailing_commas: bool = False
2458 ) -> Iterator[Line]:
2459     """Split according to delimiters of the highest priority.
2460
2461     If `py36` is True, the split will add trailing commas also in function
2462     signatures that contain `*` and `**`.
2463     """
2464     try:
2465         last_leaf = line.leaves[-1]
2466     except IndexError:
2467         raise CannotSplit("Line empty")
2468
2469     bt = line.bracket_tracker
2470     try:
2471         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2472     except ValueError:
2473         raise CannotSplit("No delimiters found")
2474
2475     if delimiter_priority == DOT_PRIORITY:
2476         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2477             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2478
2479     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2480     lowest_depth = sys.maxsize
2481     trailing_comma_safe = True
2482
2483     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2484         """Append `leaf` to current line or to new line if appending impossible."""
2485         nonlocal current_line
2486         try:
2487             current_line.append_safe(leaf, preformatted=True)
2488         except ValueError:
2489             yield current_line
2490
2491             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2492             current_line.append(leaf)
2493
2494     for leaf in line.leaves:
2495         yield from append_to_line(leaf)
2496
2497         for comment_after in line.comments_after(leaf):
2498             yield from append_to_line(comment_after)
2499
2500         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2501         if leaf.bracket_depth == lowest_depth and is_vararg(
2502             leaf, within=VARARGS_PARENTS
2503         ):
2504             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2505         leaf_priority = bt.delimiters.get(id(leaf))
2506         if leaf_priority == delimiter_priority:
2507             yield current_line
2508
2509             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2510     if current_line:
2511         if (
2512             trailing_comma_safe
2513             and delimiter_priority == COMMA_PRIORITY
2514             and current_line.leaves[-1].type != token.COMMA
2515             and current_line.leaves[-1].type != STANDALONE_COMMENT
2516         ):
2517             current_line.append(Leaf(token.COMMA, ","))
2518         yield current_line
2519
2520
2521 @dont_increase_indentation
2522 def standalone_comment_split(
2523     line: Line, supports_trailing_commas: bool = False
2524 ) -> Iterator[Line]:
2525     """Split standalone comments from the rest of the line."""
2526     if not line.contains_standalone_comments(0):
2527         raise CannotSplit("Line does not have any standalone comments")
2528
2529     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2530
2531     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2532         """Append `leaf` to current line or to new line if appending impossible."""
2533         nonlocal current_line
2534         try:
2535             current_line.append_safe(leaf, preformatted=True)
2536         except ValueError:
2537             yield current_line
2538
2539             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2540             current_line.append(leaf)
2541
2542     for leaf in line.leaves:
2543         yield from append_to_line(leaf)
2544
2545         for comment_after in line.comments_after(leaf):
2546             yield from append_to_line(comment_after)
2547
2548     if current_line:
2549         yield current_line
2550
2551
2552 def is_import(leaf: Leaf) -> bool:
2553     """Return True if the given leaf starts an import statement."""
2554     p = leaf.parent
2555     t = leaf.type
2556     v = leaf.value
2557     return bool(
2558         t == token.NAME
2559         and (
2560             (v == "import" and p and p.type == syms.import_name)
2561             or (v == "from" and p and p.type == syms.import_from)
2562         )
2563     )
2564
2565
2566 def is_special_comment(leaf: Leaf) -> bool:
2567     """Return True if the given leaf is a special comment.
2568     Only returns true for type comments for now."""
2569     t = leaf.type
2570     v = leaf.value
2571     return bool(
2572         (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
2573     )
2574
2575
2576 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2577     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2578     else.
2579
2580     Note: don't use backslashes for formatting or you'll lose your voting rights.
2581     """
2582     if not inside_brackets:
2583         spl = leaf.prefix.split("#")
2584         if "\\" not in spl[0]:
2585             nl_count = spl[-1].count("\n")
2586             if len(spl) > 1:
2587                 nl_count -= 1
2588             leaf.prefix = "\n" * nl_count
2589             return
2590
2591     leaf.prefix = ""
2592
2593
2594 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2595     """Make all string prefixes lowercase.
2596
2597     If remove_u_prefix is given, also removes any u prefix from the string.
2598
2599     Note: Mutates its argument.
2600     """
2601     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2602     assert match is not None, f"failed to match string {leaf.value!r}"
2603     orig_prefix = match.group(1)
2604     new_prefix = orig_prefix.lower()
2605     if remove_u_prefix:
2606         new_prefix = new_prefix.replace("u", "")
2607     leaf.value = f"{new_prefix}{match.group(2)}"
2608
2609
2610 def normalize_string_quotes(leaf: Leaf) -> None:
2611     """Prefer double quotes but only if it doesn't cause more escaping.
2612
2613     Adds or removes backslashes as appropriate. Doesn't parse and fix
2614     strings nested in f-strings (yet).
2615
2616     Note: Mutates its argument.
2617     """
2618     value = leaf.value.lstrip("furbFURB")
2619     if value[:3] == '"""':
2620         return
2621
2622     elif value[:3] == "'''":
2623         orig_quote = "'''"
2624         new_quote = '"""'
2625     elif value[0] == '"':
2626         orig_quote = '"'
2627         new_quote = "'"
2628     else:
2629         orig_quote = "'"
2630         new_quote = '"'
2631     first_quote_pos = leaf.value.find(orig_quote)
2632     if first_quote_pos == -1:
2633         return  # There's an internal error
2634
2635     prefix = leaf.value[:first_quote_pos]
2636     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2637     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2638     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2639     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2640     if "r" in prefix.casefold():
2641         if unescaped_new_quote.search(body):
2642             # There's at least one unescaped new_quote in this raw string
2643             # so converting is impossible
2644             return
2645
2646         # Do not introduce or remove backslashes in raw strings
2647         new_body = body
2648     else:
2649         # remove unnecessary escapes
2650         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2651         if body != new_body:
2652             # Consider the string without unnecessary escapes as the original
2653             body = new_body
2654             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2655         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2656         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2657     if "f" in prefix.casefold():
2658         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2659         for m in matches:
2660             if "\\" in str(m):
2661                 # Do not introduce backslashes in interpolated expressions
2662                 return
2663     if new_quote == '"""' and new_body[-1:] == '"':
2664         # edge case:
2665         new_body = new_body[:-1] + '\\"'
2666     orig_escape_count = body.count("\\")
2667     new_escape_count = new_body.count("\\")
2668     if new_escape_count > orig_escape_count:
2669         return  # Do not introduce more escaping
2670
2671     if new_escape_count == orig_escape_count and orig_quote == '"':
2672         return  # Prefer double quotes
2673
2674     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2675
2676
2677 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
2678     """Normalizes numeric (float, int, and complex) literals.
2679
2680     All letters used in the representation are normalized to lowercase (except
2681     in Python 2 long literals), and long number literals are split using underscores.
2682     """
2683     text = leaf.value.lower()
2684     if text.startswith(("0o", "0b")):
2685         # Leave octal and binary literals alone.
2686         pass
2687     elif text.startswith("0x"):
2688         # Change hex literals to upper case.
2689         before, after = text[:2], text[2:]
2690         text = f"{before}{after.upper()}"
2691     elif "e" in text:
2692         before, after = text.split("e")
2693         sign = ""
2694         if after.startswith("-"):
2695             after = after[1:]
2696             sign = "-"
2697         elif after.startswith("+"):
2698             after = after[1:]
2699         before = format_float_or_int_string(before, allow_underscores)
2700         after = format_int_string(after, allow_underscores)
2701         text = f"{before}e{sign}{after}"
2702     elif text.endswith(("j", "l")):
2703         number = text[:-1]
2704         suffix = text[-1]
2705         # Capitalize in "2L" because "l" looks too similar to "1".
2706         if suffix == "l":
2707             suffix = "L"
2708         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
2709     else:
2710         text = format_float_or_int_string(text, allow_underscores)
2711     leaf.value = text
2712
2713
2714 def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
2715     """Formats a float string like "1.0"."""
2716     if "." not in text:
2717         return format_int_string(text, allow_underscores)
2718
2719     before, after = text.split(".")
2720     before = format_int_string(before, allow_underscores) if before else "0"
2721     if after:
2722         after = format_int_string(after, allow_underscores, count_from_end=False)
2723     else:
2724         after = "0"
2725     return f"{before}.{after}"
2726
2727
2728 def format_int_string(
2729     text: str, allow_underscores: bool, count_from_end: bool = True
2730 ) -> str:
2731     """Normalizes underscores in a string to e.g. 1_000_000.
2732
2733     Input must be a string of digits and optional underscores.
2734     If count_from_end is False, we add underscores after groups of three digits
2735     counting from the beginning instead of the end of the strings. This is used
2736     for the fractional part of float literals.
2737     """
2738     if not allow_underscores:
2739         return text
2740
2741     text = text.replace("_", "")
2742     if len(text) <= 5:
2743         # No underscores for numbers <= 5 digits long.
2744         return text
2745
2746     if count_from_end:
2747         # Avoid removing leading zeros, which are important if we're formatting
2748         # part of a number like "0.001".
2749         return format(int("1" + text), "3_")[1:].lstrip("_")
2750     else:
2751         return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
2752
2753
2754 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2755     """Make existing optional parentheses invisible or create new ones.
2756
2757     `parens_after` is a set of string leaf values immeditely after which parens
2758     should be put.
2759
2760     Standardizes on visible parentheses for single-element tuples, and keeps
2761     existing visible parentheses for other tuples and generator expressions.
2762     """
2763     for pc in list_comments(node.prefix, is_endmarker=False):
2764         if pc.value in FMT_OFF:
2765             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2766             return
2767
2768     check_lpar = False
2769     for index, child in enumerate(list(node.children)):
2770         if check_lpar:
2771             if child.type == syms.atom:
2772                 if maybe_make_parens_invisible_in_atom(child):
2773                     lpar = Leaf(token.LPAR, "")
2774                     rpar = Leaf(token.RPAR, "")
2775                     index = child.remove() or 0
2776                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2777             elif is_one_tuple(child):
2778                 # wrap child in visible parentheses
2779                 lpar = Leaf(token.LPAR, "(")
2780                 rpar = Leaf(token.RPAR, ")")
2781                 child.remove()
2782                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2783             elif node.type == syms.import_from:
2784                 # "import from" nodes store parentheses directly as part of
2785                 # the statement
2786                 if child.type == token.LPAR:
2787                     # make parentheses invisible
2788                     child.value = ""  # type: ignore
2789                     node.children[-1].value = ""  # type: ignore
2790                 elif child.type != token.STAR:
2791                     # insert invisible parentheses
2792                     node.insert_child(index, Leaf(token.LPAR, ""))
2793                     node.append_child(Leaf(token.RPAR, ""))
2794                 break
2795
2796             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2797                 # wrap child in invisible parentheses
2798                 lpar = Leaf(token.LPAR, "")
2799                 rpar = Leaf(token.RPAR, "")
2800                 index = child.remove() or 0
2801                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2802
2803         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2804
2805
2806 def normalize_fmt_off(node: Node) -> None:
2807     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2808     try_again = True
2809     while try_again:
2810         try_again = convert_one_fmt_off_pair(node)
2811
2812
2813 def convert_one_fmt_off_pair(node: Node) -> bool:
2814     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2815
2816     Returns True if a pair was converted.
2817     """
2818     for leaf in node.leaves():
2819         previous_consumed = 0
2820         for comment in list_comments(leaf.prefix, is_endmarker=False):
2821             if comment.value in FMT_OFF:
2822                 # We only want standalone comments. If there's no previous leaf or
2823                 # the previous leaf is indentation, it's a standalone comment in
2824                 # disguise.
2825                 if comment.type != STANDALONE_COMMENT:
2826                     prev = preceding_leaf(leaf)
2827                     if prev and prev.type not in WHITESPACE:
2828                         continue
2829
2830                 ignored_nodes = list(generate_ignored_nodes(leaf))
2831                 if not ignored_nodes:
2832                     continue
2833
2834                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2835                 parent = first.parent
2836                 prefix = first.prefix
2837                 first.prefix = prefix[comment.consumed :]
2838                 hidden_value = (
2839                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2840                 )
2841                 if hidden_value.endswith("\n"):
2842                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2843                     # leaf (possibly followed by a DEDENT).
2844                     hidden_value = hidden_value[:-1]
2845                 first_idx = None
2846                 for ignored in ignored_nodes:
2847                     index = ignored.remove()
2848                     if first_idx is None:
2849                         first_idx = index
2850                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2851                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2852                 parent.insert_child(
2853                     first_idx,
2854                     Leaf(
2855                         STANDALONE_COMMENT,
2856                         hidden_value,
2857                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2858                     ),
2859                 )
2860                 return True
2861
2862             previous_consumed = comment.consumed
2863
2864     return False
2865
2866
2867 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2868     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2869
2870     Stops at the end of the block.
2871     """
2872     container: Optional[LN] = container_of(leaf)
2873     while container is not None and container.type != token.ENDMARKER:
2874         for comment in list_comments(container.prefix, is_endmarker=False):
2875             if comment.value in FMT_ON:
2876                 return
2877
2878         yield container
2879
2880         container = container.next_sibling
2881
2882
2883 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2884     """If it's safe, make the parens in the atom `node` invisible, recursively.
2885
2886     Returns whether the node should itself be wrapped in invisible parentheses.
2887
2888     """
2889     if (
2890         node.type != syms.atom
2891         or is_empty_tuple(node)
2892         or is_one_tuple(node)
2893         or is_yield(node)
2894         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2895     ):
2896         return False
2897
2898     first = node.children[0]
2899     last = node.children[-1]
2900     if first.type == token.LPAR and last.type == token.RPAR:
2901         # make parentheses invisible
2902         first.value = ""  # type: ignore
2903         last.value = ""  # type: ignore
2904         if len(node.children) > 1:
2905             maybe_make_parens_invisible_in_atom(node.children[1])
2906         return False
2907
2908     return True
2909
2910
2911 def is_empty_tuple(node: LN) -> bool:
2912     """Return True if `node` holds an empty tuple."""
2913     return (
2914         node.type == syms.atom
2915         and len(node.children) == 2
2916         and node.children[0].type == token.LPAR
2917         and node.children[1].type == token.RPAR
2918     )
2919
2920
2921 def is_one_tuple(node: LN) -> bool:
2922     """Return True if `node` holds a tuple with one element, with or without parens."""
2923     if node.type == syms.atom:
2924         if len(node.children) != 3:
2925             return False
2926
2927         lpar, gexp, rpar = node.children
2928         if not (
2929             lpar.type == token.LPAR
2930             and gexp.type == syms.testlist_gexp
2931             and rpar.type == token.RPAR
2932         ):
2933             return False
2934
2935         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2936
2937     return (
2938         node.type in IMPLICIT_TUPLE
2939         and len(node.children) == 2
2940         and node.children[1].type == token.COMMA
2941     )
2942
2943
2944 def is_yield(node: LN) -> bool:
2945     """Return True if `node` holds a `yield` or `yield from` expression."""
2946     if node.type == syms.yield_expr:
2947         return True
2948
2949     if node.type == token.NAME and node.value == "yield":  # type: ignore
2950         return True
2951
2952     if node.type != syms.atom:
2953         return False
2954
2955     if len(node.children) != 3:
2956         return False
2957
2958     lpar, expr, rpar = node.children
2959     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2960         return is_yield(expr)
2961
2962     return False
2963
2964
2965 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2966     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2967
2968     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2969     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2970     extended iterable unpacking (PEP 3132) and additional unpacking
2971     generalizations (PEP 448).
2972     """
2973     if leaf.type not in STARS or not leaf.parent:
2974         return False
2975
2976     p = leaf.parent
2977     if p.type == syms.star_expr:
2978         # Star expressions are also used as assignment targets in extended
2979         # iterable unpacking (PEP 3132).  See what its parent is instead.
2980         if not p.parent:
2981             return False
2982
2983         p = p.parent
2984
2985     return p.type in within
2986
2987
2988 def is_multiline_string(leaf: Leaf) -> bool:
2989     """Return True if `leaf` is a multiline string that actually spans many lines."""
2990     value = leaf.value.lstrip("furbFURB")
2991     return value[:3] in {'"""', "'''"} and "\n" in value
2992
2993
2994 def is_stub_suite(node: Node) -> bool:
2995     """Return True if `node` is a suite with a stub body."""
2996     if (
2997         len(node.children) != 4
2998         or node.children[0].type != token.NEWLINE
2999         or node.children[1].type != token.INDENT
3000         or node.children[3].type != token.DEDENT
3001     ):
3002         return False
3003
3004     return is_stub_body(node.children[2])
3005
3006
3007 def is_stub_body(node: LN) -> bool:
3008     """Return True if `node` is a simple statement containing an ellipsis."""
3009     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3010         return False
3011
3012     if len(node.children) != 2:
3013         return False
3014
3015     child = node.children[0]
3016     return (
3017         child.type == syms.atom
3018         and len(child.children) == 3
3019         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3020     )
3021
3022
3023 def max_delimiter_priority_in_atom(node: LN) -> int:
3024     """Return maximum delimiter priority inside `node`.
3025
3026     This is specific to atoms with contents contained in a pair of parentheses.
3027     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3028     """
3029     if node.type != syms.atom:
3030         return 0
3031
3032     first = node.children[0]
3033     last = node.children[-1]
3034     if not (first.type == token.LPAR and last.type == token.RPAR):
3035         return 0
3036
3037     bt = BracketTracker()
3038     for c in node.children[1:-1]:
3039         if isinstance(c, Leaf):
3040             bt.mark(c)
3041         else:
3042             for leaf in c.leaves():
3043                 bt.mark(leaf)
3044     try:
3045         return bt.max_delimiter_priority()
3046
3047     except ValueError:
3048         return 0
3049
3050
3051 def ensure_visible(leaf: Leaf) -> None:
3052     """Make sure parentheses are visible.
3053
3054     They could be invisible as part of some statements (see
3055     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3056     """
3057     if leaf.type == token.LPAR:
3058         leaf.value = "("
3059     elif leaf.type == token.RPAR:
3060         leaf.value = ")"
3061
3062
3063 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3064     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3065
3066     if not (
3067         opening_bracket.parent
3068         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3069         and opening_bracket.value in "[{("
3070     ):
3071         return False
3072
3073     try:
3074         last_leaf = line.leaves[-1]
3075         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3076         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3077     except (IndexError, ValueError):
3078         return False
3079
3080     return max_priority == COMMA_PRIORITY
3081
3082
3083 def get_features_used(node: Node) -> Set[Feature]:
3084     """Return a set of (relatively) new Python features used in this file.
3085
3086     Currently looking for:
3087     - f-strings;
3088     - underscores in numeric literals; and
3089     - trailing commas after * or ** in function signatures and calls.
3090     """
3091     features: Set[Feature] = set()
3092     for n in node.pre_order():
3093         if n.type == token.STRING:
3094             value_head = n.value[:2]  # type: ignore
3095             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3096                 features.add(Feature.F_STRINGS)
3097
3098         elif n.type == token.NUMBER:
3099             if "_" in n.value:  # type: ignore
3100                 features.add(Feature.NUMERIC_UNDERSCORES)
3101
3102         elif (
3103             n.type in {syms.typedargslist, syms.arglist}
3104             and n.children
3105             and n.children[-1].type == token.COMMA
3106         ):
3107             for ch in n.children:
3108                 if ch.type in STARS:
3109                     features.add(Feature.TRAILING_COMMA)
3110
3111                 if ch.type == syms.argument:
3112                     for argch in ch.children:
3113                         if argch.type in STARS:
3114                             features.add(Feature.TRAILING_COMMA)
3115
3116     return features
3117
3118
3119 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3120     """Detect the version to target based on the nodes used."""
3121     features = get_features_used(node)
3122     return {
3123         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3124     }
3125
3126
3127 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3128     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3129
3130     Brackets can be omitted if the entire trailer up to and including
3131     a preceding closing bracket fits in one line.
3132
3133     Yielded sets are cumulative (contain results of previous yields, too).  First
3134     set is empty.
3135     """
3136
3137     omit: Set[LeafID] = set()
3138     yield omit
3139
3140     length = 4 * line.depth
3141     opening_bracket = None
3142     closing_bracket = None
3143     inner_brackets: Set[LeafID] = set()
3144     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3145         length += leaf_length
3146         if length > line_length:
3147             break
3148
3149         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3150         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3151             break
3152
3153         if opening_bracket:
3154             if leaf is opening_bracket:
3155                 opening_bracket = None
3156             elif leaf.type in CLOSING_BRACKETS:
3157                 inner_brackets.add(id(leaf))
3158         elif leaf.type in CLOSING_BRACKETS:
3159             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3160                 # Empty brackets would fail a split so treat them as "inner"
3161                 # brackets (e.g. only add them to the `omit` set if another
3162                 # pair of brackets was good enough.
3163                 inner_brackets.add(id(leaf))
3164                 continue
3165
3166             if closing_bracket:
3167                 omit.add(id(closing_bracket))
3168                 omit.update(inner_brackets)
3169                 inner_brackets.clear()
3170                 yield omit
3171
3172             if leaf.value:
3173                 opening_bracket = leaf.opening_bracket
3174                 closing_bracket = leaf
3175
3176
3177 def get_future_imports(node: Node) -> Set[str]:
3178     """Return a set of __future__ imports in the file."""
3179     imports: Set[str] = set()
3180
3181     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3182         for child in children:
3183             if isinstance(child, Leaf):
3184                 if child.type == token.NAME:
3185                     yield child.value
3186             elif child.type == syms.import_as_name:
3187                 orig_name = child.children[0]
3188                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3189                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3190                 yield orig_name.value
3191             elif child.type == syms.import_as_names:
3192                 yield from get_imports_from_children(child.children)
3193             else:
3194                 assert False, "Invalid syntax parsing imports"
3195
3196     for child in node.children:
3197         if child.type != syms.simple_stmt:
3198             break
3199         first_child = child.children[0]
3200         if isinstance(first_child, Leaf):
3201             # Continue looking if we see a docstring; otherwise stop.
3202             if (
3203                 len(child.children) == 2
3204                 and first_child.type == token.STRING
3205                 and child.children[1].type == token.NEWLINE
3206             ):
3207                 continue
3208             else:
3209                 break
3210         elif first_child.type == syms.import_from:
3211             module_name = first_child.children[1]
3212             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3213                 break
3214             imports |= set(get_imports_from_children(first_child.children[3:]))
3215         else:
3216             break
3217     return imports
3218
3219
3220 def gen_python_files_in_dir(
3221     path: Path,
3222     root: Path,
3223     include: Pattern[str],
3224     exclude: Pattern[str],
3225     report: "Report",
3226 ) -> Iterator[Path]:
3227     """Generate all files under `path` whose paths are not excluded by the
3228     `exclude` regex, but are included by the `include` regex.
3229
3230     Symbolic links pointing outside of the `root` directory are ignored.
3231
3232     `report` is where output about exclusions goes.
3233     """
3234     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3235     for child in path.iterdir():
3236         try:
3237             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3238         except ValueError:
3239             if child.is_symlink():
3240                 report.path_ignored(
3241                     child, f"is a symbolic link that points outside {root}"
3242                 )
3243                 continue
3244
3245             raise
3246
3247         if child.is_dir():
3248             normalized_path += "/"
3249         exclude_match = exclude.search(normalized_path)
3250         if exclude_match and exclude_match.group(0):
3251             report.path_ignored(child, f"matches the --exclude regular expression")
3252             continue
3253
3254         if child.is_dir():
3255             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3256
3257         elif child.is_file():
3258             include_match = include.search(normalized_path)
3259             if include_match:
3260                 yield child
3261
3262
3263 @lru_cache()
3264 def find_project_root(srcs: Iterable[str]) -> Path:
3265     """Return a directory containing .git, .hg, or pyproject.toml.
3266
3267     That directory can be one of the directories passed in `srcs` or their
3268     common parent.
3269
3270     If no directory in the tree contains a marker that would specify it's the
3271     project root, the root of the file system is returned.
3272     """
3273     if not srcs:
3274         return Path("/").resolve()
3275
3276     common_base = min(Path(src).resolve() for src in srcs)
3277     if common_base.is_dir():
3278         # Append a fake file so `parents` below returns `common_base_dir`, too.
3279         common_base /= "fake-file"
3280     for directory in common_base.parents:
3281         if (directory / ".git").is_dir():
3282             return directory
3283
3284         if (directory / ".hg").is_dir():
3285             return directory
3286
3287         if (directory / "pyproject.toml").is_file():
3288             return directory
3289
3290     return directory
3291
3292
3293 @dataclass
3294 class Report:
3295     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3296
3297     check: bool = False
3298     quiet: bool = False
3299     verbose: bool = False
3300     change_count: int = 0
3301     same_count: int = 0
3302     failure_count: int = 0
3303
3304     def done(self, src: Path, changed: Changed) -> None:
3305         """Increment the counter for successful reformatting. Write out a message."""
3306         if changed is Changed.YES:
3307             reformatted = "would reformat" if self.check else "reformatted"
3308             if self.verbose or not self.quiet:
3309                 out(f"{reformatted} {src}")
3310             self.change_count += 1
3311         else:
3312             if self.verbose:
3313                 if changed is Changed.NO:
3314                     msg = f"{src} already well formatted, good job."
3315                 else:
3316                     msg = f"{src} wasn't modified on disk since last run."
3317                 out(msg, bold=False)
3318             self.same_count += 1
3319
3320     def failed(self, src: Path, message: str) -> None:
3321         """Increment the counter for failed reformatting. Write out a message."""
3322         err(f"error: cannot format {src}: {message}")
3323         self.failure_count += 1
3324
3325     def path_ignored(self, path: Path, message: str) -> None:
3326         if self.verbose:
3327             out(f"{path} ignored: {message}", bold=False)
3328
3329     @property
3330     def return_code(self) -> int:
3331         """Return the exit code that the app should use.
3332
3333         This considers the current state of changed files and failures:
3334         - if there were any failures, return 123;
3335         - if any files were changed and --check is being used, return 1;
3336         - otherwise return 0.
3337         """
3338         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3339         # 126 we have special return codes reserved by the shell.
3340         if self.failure_count:
3341             return 123
3342
3343         elif self.change_count and self.check:
3344             return 1
3345
3346         return 0
3347
3348     def __str__(self) -> str:
3349         """Render a color report of the current state.
3350
3351         Use `click.unstyle` to remove colors.
3352         """
3353         if self.check:
3354             reformatted = "would be reformatted"
3355             unchanged = "would be left unchanged"
3356             failed = "would fail to reformat"
3357         else:
3358             reformatted = "reformatted"
3359             unchanged = "left unchanged"
3360             failed = "failed to reformat"
3361         report = []
3362         if self.change_count:
3363             s = "s" if self.change_count > 1 else ""
3364             report.append(
3365                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3366             )
3367         if self.same_count:
3368             s = "s" if self.same_count > 1 else ""
3369             report.append(f"{self.same_count} file{s} {unchanged}")
3370         if self.failure_count:
3371             s = "s" if self.failure_count > 1 else ""
3372             report.append(
3373                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3374             )
3375         return ", ".join(report) + "."
3376
3377
3378 def assert_equivalent(src: str, dst: str) -> None:
3379     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3380
3381     import ast
3382     import traceback
3383
3384     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3385         """Simple visitor generating strings to compare ASTs by content."""
3386         yield f"{'  ' * depth}{node.__class__.__name__}("
3387
3388         for field in sorted(node._fields):
3389             try:
3390                 value = getattr(node, field)
3391             except AttributeError:
3392                 continue
3393
3394             yield f"{'  ' * (depth+1)}{field}="
3395
3396             if isinstance(value, list):
3397                 for item in value:
3398                     if isinstance(item, ast.AST):
3399                         yield from _v(item, depth + 2)
3400
3401             elif isinstance(value, ast.AST):
3402                 yield from _v(value, depth + 2)
3403
3404             else:
3405                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3406
3407         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3408
3409     try:
3410         src_ast = ast.parse(src)
3411     except Exception as exc:
3412         major, minor = sys.version_info[:2]
3413         raise AssertionError(
3414             f"cannot use --safe with this file; failed to parse source file "
3415             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3416             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3417         )
3418
3419     try:
3420         dst_ast = ast.parse(dst)
3421     except Exception as exc:
3422         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3423         raise AssertionError(
3424             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3425             f"Please report a bug on https://github.com/ambv/black/issues.  "
3426             f"This invalid output might be helpful: {log}"
3427         ) from None
3428
3429     src_ast_str = "\n".join(_v(src_ast))
3430     dst_ast_str = "\n".join(_v(dst_ast))
3431     if src_ast_str != dst_ast_str:
3432         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3433         raise AssertionError(
3434             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3435             f"the source.  "
3436             f"Please report a bug on https://github.com/ambv/black/issues.  "
3437             f"This diff might be helpful: {log}"
3438         ) from None
3439
3440
3441 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3442     """Raise AssertionError if `dst` reformats differently the second time."""
3443     newdst = format_str(dst, mode=mode)
3444     if dst != newdst:
3445         log = dump_to_file(
3446             diff(src, dst, "source", "first pass"),
3447             diff(dst, newdst, "first pass", "second pass"),
3448         )
3449         raise AssertionError(
3450             f"INTERNAL ERROR: Black produced different code on the second pass "
3451             f"of the formatter.  "
3452             f"Please report a bug on https://github.com/ambv/black/issues.  "
3453             f"This diff might be helpful: {log}"
3454         ) from None
3455
3456
3457 def dump_to_file(*output: str) -> str:
3458     """Dump `output` to a temporary file. Return path to the file."""
3459     import tempfile
3460
3461     with tempfile.NamedTemporaryFile(
3462         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3463     ) as f:
3464         for lines in output:
3465             f.write(lines)
3466             if lines and lines[-1] != "\n":
3467                 f.write("\n")
3468     return f.name
3469
3470
3471 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3472     """Return a unified diff string between strings `a` and `b`."""
3473     import difflib
3474
3475     a_lines = [line + "\n" for line in a.split("\n")]
3476     b_lines = [line + "\n" for line in b.split("\n")]
3477     return "".join(
3478         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3479     )
3480
3481
3482 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3483     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3484     err("Aborted!")
3485     for task in tasks:
3486         task.cancel()
3487
3488
3489 def shutdown(loop: BaseEventLoop) -> None:
3490     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3491     try:
3492         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3493         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3494         if not to_cancel:
3495             return
3496
3497         for task in to_cancel:
3498             task.cancel()
3499         loop.run_until_complete(
3500             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3501         )
3502     finally:
3503         # `concurrent.futures.Future` objects cannot be cancelled once they
3504         # are already running. There might be some when the `shutdown()` happened.
3505         # Silence their logger's spew about the event loop being closed.
3506         cf_logger = logging.getLogger("concurrent.futures")
3507         cf_logger.setLevel(logging.CRITICAL)
3508         loop.close()
3509
3510
3511 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3512     """Replace `regex` with `replacement` twice on `original`.
3513
3514     This is used by string normalization to perform replaces on
3515     overlapping matches.
3516     """
3517     return regex.sub(replacement, regex.sub(replacement, original))
3518
3519
3520 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3521     """Compile a regular expression string in `regex`.
3522
3523     If it contains newlines, use verbose mode.
3524     """
3525     if "\n" in regex:
3526         regex = "(?x)" + regex
3527     return re.compile(regex)
3528
3529
3530 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3531     """Like `reversed(enumerate(sequence))` if that were possible."""
3532     index = len(sequence) - 1
3533     for element in reversed(sequence):
3534         yield (index, element)
3535         index -= 1
3536
3537
3538 def enumerate_with_length(
3539     line: Line, reversed: bool = False
3540 ) -> Iterator[Tuple[Index, Leaf, int]]:
3541     """Return an enumeration of leaves with their length.
3542
3543     Stops prematurely on multiline strings and standalone comments.
3544     """
3545     op = cast(
3546         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3547         enumerate_reversed if reversed else enumerate,
3548     )
3549     for index, leaf in op(line.leaves):
3550         length = len(leaf.prefix) + len(leaf.value)
3551         if "\n" in leaf.value:
3552             return  # Multiline strings, we can't continue.
3553
3554         comment: Optional[Leaf]
3555         for comment in line.comments_after(leaf):
3556             length += len(comment.value)
3557
3558         yield index, leaf, length
3559
3560
3561 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3562     """Return True if `line` is no longer than `line_length`.
3563
3564     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3565     """
3566     if not line_str:
3567         line_str = str(line).strip("\n")
3568     return (
3569         len(line_str) <= line_length
3570         and "\n" not in line_str  # multiline strings
3571         and not line.contains_standalone_comments()
3572     )
3573
3574
3575 def can_be_split(line: Line) -> bool:
3576     """Return False if the line cannot be split *for sure*.
3577
3578     This is not an exhaustive search but a cheap heuristic that we can use to
3579     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3580     in unnecessary parentheses).
3581     """
3582     leaves = line.leaves
3583     if len(leaves) < 2:
3584         return False
3585
3586     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3587         call_count = 0
3588         dot_count = 0
3589         next = leaves[-1]
3590         for leaf in leaves[-2::-1]:
3591             if leaf.type in OPENING_BRACKETS:
3592                 if next.type not in CLOSING_BRACKETS:
3593                     return False
3594
3595                 call_count += 1
3596             elif leaf.type == token.DOT:
3597                 dot_count += 1
3598             elif leaf.type == token.NAME:
3599                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3600                     return False
3601
3602             elif leaf.type not in CLOSING_BRACKETS:
3603                 return False
3604
3605             if dot_count > 1 and call_count > 1:
3606                 return False
3607
3608     return True
3609
3610
3611 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3612     """Does `line` have a shape safe to reformat without optional parens around it?
3613
3614     Returns True for only a subset of potentially nice looking formattings but
3615     the point is to not return false positives that end up producing lines that
3616     are too long.
3617     """
3618     bt = line.bracket_tracker
3619     if not bt.delimiters:
3620         # Without delimiters the optional parentheses are useless.
3621         return True
3622
3623     max_priority = bt.max_delimiter_priority()
3624     if bt.delimiter_count_with_priority(max_priority) > 1:
3625         # With more than one delimiter of a kind the optional parentheses read better.
3626         return False
3627
3628     if max_priority == DOT_PRIORITY:
3629         # A single stranded method call doesn't require optional parentheses.
3630         return True
3631
3632     assert len(line.leaves) >= 2, "Stranded delimiter"
3633
3634     first = line.leaves[0]
3635     second = line.leaves[1]
3636     penultimate = line.leaves[-2]
3637     last = line.leaves[-1]
3638
3639     # With a single delimiter, omit if the expression starts or ends with
3640     # a bracket.
3641     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3642         remainder = False
3643         length = 4 * line.depth
3644         for _index, leaf, leaf_length in enumerate_with_length(line):
3645             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3646                 remainder = True
3647             if remainder:
3648                 length += leaf_length
3649                 if length > line_length:
3650                     break
3651
3652                 if leaf.type in OPENING_BRACKETS:
3653                     # There are brackets we can further split on.
3654                     remainder = False
3655
3656         else:
3657             # checked the entire string and line length wasn't exceeded
3658             if len(line.leaves) == _index + 1:
3659                 return True
3660
3661         # Note: we are not returning False here because a line might have *both*
3662         # a leading opening bracket and a trailing closing bracket.  If the
3663         # opening bracket doesn't match our rule, maybe the closing will.
3664
3665     if (
3666         last.type == token.RPAR
3667         or last.type == token.RBRACE
3668         or (
3669             # don't use indexing for omitting optional parentheses;
3670             # it looks weird
3671             last.type == token.RSQB
3672             and last.parent
3673             and last.parent.type != syms.trailer
3674         )
3675     ):
3676         if penultimate.type in OPENING_BRACKETS:
3677             # Empty brackets don't help.
3678             return False
3679
3680         if is_multiline_string(first):
3681             # Additional wrapping of a multiline string in this situation is
3682             # unnecessary.
3683             return True
3684
3685         length = 4 * line.depth
3686         seen_other_brackets = False
3687         for _index, leaf, leaf_length in enumerate_with_length(line):
3688             length += leaf_length
3689             if leaf is last.opening_bracket:
3690                 if seen_other_brackets or length <= line_length:
3691                     return True
3692
3693             elif leaf.type in OPENING_BRACKETS:
3694                 # There are brackets we can further split on.
3695                 seen_other_brackets = True
3696
3697     return False
3698
3699
3700 def get_cache_file(mode: FileMode) -> Path:
3701     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3702
3703
3704 def read_cache(mode: FileMode) -> Cache:
3705     """Read the cache if it exists and is well formed.
3706
3707     If it is not well formed, the call to write_cache later should resolve the issue.
3708     """
3709     cache_file = get_cache_file(mode)
3710     if not cache_file.exists():
3711         return {}
3712
3713     with cache_file.open("rb") as fobj:
3714         try:
3715             cache: Cache = pickle.load(fobj)
3716         except pickle.UnpicklingError:
3717             return {}
3718
3719     return cache
3720
3721
3722 def get_cache_info(path: Path) -> CacheInfo:
3723     """Return the information used to check if a file is already formatted or not."""
3724     stat = path.stat()
3725     return stat.st_mtime, stat.st_size
3726
3727
3728 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3729     """Split an iterable of paths in `sources` into two sets.
3730
3731     The first contains paths of files that modified on disk or are not in the
3732     cache. The other contains paths to non-modified files.
3733     """
3734     todo, done = set(), set()
3735     for src in sources:
3736         src = src.resolve()
3737         if cache.get(src) != get_cache_info(src):
3738             todo.add(src)
3739         else:
3740             done.add(src)
3741     return todo, done
3742
3743
3744 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3745     """Update the cache file."""
3746     cache_file = get_cache_file(mode)
3747     try:
3748         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3749         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3750         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3751             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3752         os.replace(f.name, cache_file)
3753     except OSError:
3754         pass
3755
3756
3757 def patch_click() -> None:
3758     """Make Click not crash.
3759
3760     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3761     default which restricts paths that it can access during the lifetime of the
3762     application.  Click refuses to work in this scenario by raising a RuntimeError.
3763
3764     In case of Black the likelihood that non-ASCII characters are going to be used in
3765     file paths is minimal since it's Python source code.  Moreover, this crash was
3766     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3767     """
3768     try:
3769         from click import core
3770         from click import _unicodefun  # type: ignore
3771     except ModuleNotFoundError:
3772         return
3773
3774     for module in (core, _unicodefun):
3775         if hasattr(module, "_verify_python3_env"):
3776             module._verify_python3_env = lambda: None
3777
3778
3779 def patched_main() -> None:
3780     freeze_support()
3781     patch_click()
3782     main()
3783
3784
3785 if __name__ == "__main__":
3786     patched_main()