]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

don't run more than 61 workers on Windows (#838)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "19.3b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA_IN_CALL = 4
137     TRAILING_COMMA_IN_DEF = 5
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.PY27: set(),
142     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
144     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
145     TargetVersion.PY36: {
146         Feature.UNICODE_LITERALS,
147         Feature.F_STRINGS,
148         Feature.NUMERIC_UNDERSCORES,
149         Feature.TRAILING_COMMA_IN_CALL,
150         Feature.TRAILING_COMMA_IN_DEF,
151     },
152     TargetVersion.PY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA_IN_CALL,
157         Feature.TRAILING_COMMA_IN_DEF,
158     },
159     TargetVersion.PY38: {
160         Feature.UNICODE_LITERALS,
161         Feature.F_STRINGS,
162         Feature.NUMERIC_UNDERSCORES,
163         Feature.TRAILING_COMMA_IN_CALL,
164         Feature.TRAILING_COMMA_IN_DEF,
165     },
166 }
167
168
169 @dataclass
170 class FileMode:
171     target_versions: Set[TargetVersion] = Factory(set)
172     line_length: int = DEFAULT_LINE_LENGTH
173     string_normalization: bool = True
174     is_pyi: bool = False
175
176     def get_cache_key(self) -> str:
177         if self.target_versions:
178             version_str = ",".join(
179                 str(version.value)
180                 for version in sorted(self.target_versions, key=lambda v: v.value)
181             )
182         else:
183             version_str = "-"
184         parts = [
185             version_str,
186             str(self.line_length),
187             str(int(self.string_normalization)),
188             str(int(self.is_pyi)),
189         ]
190         return ".".join(parts)
191
192
193 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
194     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
195
196
197 def read_pyproject_toml(
198     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
199 ) -> Optional[str]:
200     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
201
202     Returns the path to a successfully found and read configuration file, None
203     otherwise.
204     """
205     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
206     if not value:
207         root = find_project_root(ctx.params.get("src", ()))
208         path = root / "pyproject.toml"
209         if path.is_file():
210             value = str(path)
211         else:
212             return None
213
214     try:
215         pyproject_toml = toml.load(value)
216         config = pyproject_toml.get("tool", {}).get("black", {})
217     except (toml.TomlDecodeError, OSError) as e:
218         raise click.FileError(
219             filename=value, hint=f"Error reading configuration file: {e}"
220         )
221
222     if not config:
223         return None
224
225     if ctx.default_map is None:
226         ctx.default_map = {}
227     ctx.default_map.update(  # type: ignore  # bad types in .pyi
228         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
229     )
230     return value
231
232
233 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
234 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
235 @click.option(
236     "-l",
237     "--line-length",
238     type=int,
239     default=DEFAULT_LINE_LENGTH,
240     help="How many characters per line to allow.",
241     show_default=True,
242 )
243 @click.option(
244     "-t",
245     "--target-version",
246     type=click.Choice([v.name.lower() for v in TargetVersion]),
247     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
248     multiple=True,
249     help=(
250         "Python versions that should be supported by Black's output. [default: "
251         "per-file auto-detection]"
252     ),
253 )
254 @click.option(
255     "--py36",
256     is_flag=True,
257     help=(
258         "Allow using Python 3.6-only syntax on all input files.  This will put "
259         "trailing commas in function signatures and calls also after *args and "
260         "**kwargs. Deprecated; use --target-version instead. "
261         "[default: per-file auto-detection]"
262     ),
263 )
264 @click.option(
265     "--pyi",
266     is_flag=True,
267     help=(
268         "Format all input files like typing stubs regardless of file extension "
269         "(useful when piping source on standard input)."
270     ),
271 )
272 @click.option(
273     "-S",
274     "--skip-string-normalization",
275     is_flag=True,
276     help="Don't normalize string quotes or prefixes.",
277 )
278 @click.option(
279     "--check",
280     is_flag=True,
281     help=(
282         "Don't write the files back, just return the status.  Return code 0 "
283         "means nothing would change.  Return code 1 means some files would be "
284         "reformatted.  Return code 123 means there was an internal error."
285     ),
286 )
287 @click.option(
288     "--diff",
289     is_flag=True,
290     help="Don't write the files back, just output a diff for each file on stdout.",
291 )
292 @click.option(
293     "--fast/--safe",
294     is_flag=True,
295     help="If --fast given, skip temporary sanity checks. [default: --safe]",
296 )
297 @click.option(
298     "--include",
299     type=str,
300     default=DEFAULT_INCLUDES,
301     help=(
302         "A regular expression that matches files and directories that should be "
303         "included on recursive searches.  An empty value means all files are "
304         "included regardless of the name.  Use forward slashes for directories on "
305         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
306         "later."
307     ),
308     show_default=True,
309 )
310 @click.option(
311     "--exclude",
312     type=str,
313     default=DEFAULT_EXCLUDES,
314     help=(
315         "A regular expression that matches files and directories that should be "
316         "excluded on recursive searches.  An empty value means no paths are excluded. "
317         "Use forward slashes for directories on all platforms (Windows, too).  "
318         "Exclusions are calculated first, inclusions later."
319     ),
320     show_default=True,
321 )
322 @click.option(
323     "-q",
324     "--quiet",
325     is_flag=True,
326     help=(
327         "Don't emit non-error messages to stderr. Errors are still emitted, "
328         "silence those with 2>/dev/null."
329     ),
330 )
331 @click.option(
332     "-v",
333     "--verbose",
334     is_flag=True,
335     help=(
336         "Also emit messages to stderr about files that were not changed or were "
337         "ignored due to --exclude=."
338     ),
339 )
340 @click.version_option(version=__version__)
341 @click.argument(
342     "src",
343     nargs=-1,
344     type=click.Path(
345         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
346     ),
347     is_eager=True,
348 )
349 @click.option(
350     "--config",
351     type=click.Path(
352         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
353     ),
354     is_eager=True,
355     callback=read_pyproject_toml,
356     help="Read configuration from PATH.",
357 )
358 @click.pass_context
359 def main(
360     ctx: click.Context,
361     code: Optional[str],
362     line_length: int,
363     target_version: List[TargetVersion],
364     check: bool,
365     diff: bool,
366     fast: bool,
367     pyi: bool,
368     py36: bool,
369     skip_string_normalization: bool,
370     quiet: bool,
371     verbose: bool,
372     include: str,
373     exclude: str,
374     src: Tuple[str],
375     config: Optional[str],
376 ) -> None:
377     """The uncompromising code formatter."""
378     write_back = WriteBack.from_configuration(check=check, diff=diff)
379     if target_version:
380         if py36:
381             err(f"Cannot use both --target-version and --py36")
382             ctx.exit(2)
383         else:
384             versions = set(target_version)
385     elif py36:
386         err(
387             "--py36 is deprecated and will be removed in a future version. "
388             "Use --target-version py36 instead."
389         )
390         versions = PY36_VERSIONS
391     else:
392         # We'll autodetect later.
393         versions = set()
394     mode = FileMode(
395         target_versions=versions,
396         line_length=line_length,
397         is_pyi=pyi,
398         string_normalization=not skip_string_normalization,
399     )
400     if config and verbose:
401         out(f"Using configuration from {config}.", bold=False, fg="blue")
402     if code is not None:
403         print(format_str(code, mode=mode))
404         ctx.exit(0)
405     try:
406         include_regex = re_compile_maybe_verbose(include)
407     except re.error:
408         err(f"Invalid regular expression for include given: {include!r}")
409         ctx.exit(2)
410     try:
411         exclude_regex = re_compile_maybe_verbose(exclude)
412     except re.error:
413         err(f"Invalid regular expression for exclude given: {exclude!r}")
414         ctx.exit(2)
415     report = Report(check=check, quiet=quiet, verbose=verbose)
416     root = find_project_root(src)
417     sources: Set[Path] = set()
418     for s in src:
419         p = Path(s)
420         if p.is_dir():
421             sources.update(
422                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
423             )
424         elif p.is_file() or s == "-":
425             # if a file was explicitly given, we don't care about its extension
426             sources.add(p)
427         else:
428             err(f"invalid path: {s}")
429     if len(sources) == 0:
430         if verbose or not quiet:
431             out("No paths given. Nothing to do 😴")
432         ctx.exit(0)
433
434     if len(sources) == 1:
435         reformat_one(
436             src=sources.pop(),
437             fast=fast,
438             write_back=write_back,
439             mode=mode,
440             report=report,
441         )
442     else:
443         reformat_many(
444             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
445         )
446
447     if verbose or not quiet:
448         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
449         out(f"All done! {bang}")
450         click.secho(str(report), err=True)
451     ctx.exit(report.return_code)
452
453
454 def reformat_one(
455     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
456 ) -> None:
457     """Reformat a single file under `src` without spawning child processes.
458
459     If `quiet` is True, non-error messages are not output. `line_length`,
460     `write_back`, `fast` and `pyi` options are passed to
461     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
462     """
463     try:
464         changed = Changed.NO
465         if not src.is_file() and str(src) == "-":
466             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
467                 changed = Changed.YES
468         else:
469             cache: Cache = {}
470             if write_back != WriteBack.DIFF:
471                 cache = read_cache(mode)
472                 res_src = src.resolve()
473                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
474                     changed = Changed.CACHED
475             if changed is not Changed.CACHED and format_file_in_place(
476                 src, fast=fast, write_back=write_back, mode=mode
477             ):
478                 changed = Changed.YES
479             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
480                 write_back is WriteBack.CHECK and changed is Changed.NO
481             ):
482                 write_cache(cache, [src], mode)
483         report.done(src, changed)
484     except Exception as exc:
485         report.failed(src, str(exc))
486
487
488 def reformat_many(
489     sources: Set[Path],
490     fast: bool,
491     write_back: WriteBack,
492     mode: FileMode,
493     report: "Report",
494 ) -> None:
495     """Reformat multiple files using a ProcessPoolExecutor."""
496     loop = asyncio.get_event_loop()
497     worker_count = os.cpu_count()
498     if sys.platform == "win32":
499         # Work around https://bugs.python.org/issue26903
500         worker_count = min(worker_count, 61)
501     executor = ProcessPoolExecutor(max_workers=worker_count)
502     try:
503         loop.run_until_complete(
504             schedule_formatting(
505                 sources=sources,
506                 fast=fast,
507                 write_back=write_back,
508                 mode=mode,
509                 report=report,
510                 loop=loop,
511                 executor=executor,
512             )
513         )
514     finally:
515         shutdown(loop)
516
517
518 async def schedule_formatting(
519     sources: Set[Path],
520     fast: bool,
521     write_back: WriteBack,
522     mode: FileMode,
523     report: "Report",
524     loop: BaseEventLoop,
525     executor: Executor,
526 ) -> None:
527     """Run formatting of `sources` in parallel using the provided `executor`.
528
529     (Use ProcessPoolExecutors for actual parallelism.)
530
531     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
532     :func:`format_file_in_place`.
533     """
534     cache: Cache = {}
535     if write_back != WriteBack.DIFF:
536         cache = read_cache(mode)
537         sources, cached = filter_cached(cache, sources)
538         for src in sorted(cached):
539             report.done(src, Changed.CACHED)
540     if not sources:
541         return
542
543     cancelled = []
544     sources_to_cache = []
545     lock = None
546     if write_back == WriteBack.DIFF:
547         # For diff output, we need locks to ensure we don't interleave output
548         # from different processes.
549         manager = Manager()
550         lock = manager.Lock()
551     tasks = {
552         asyncio.ensure_future(
553             loop.run_in_executor(
554                 executor, format_file_in_place, src, fast, mode, write_back, lock
555             )
556         ): src
557         for src in sorted(sources)
558     }
559     pending: Iterable[asyncio.Future] = tasks.keys()
560     try:
561         loop.add_signal_handler(signal.SIGINT, cancel, pending)
562         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
563     except NotImplementedError:
564         # There are no good alternatives for these on Windows.
565         pass
566     while pending:
567         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
568         for task in done:
569             src = tasks.pop(task)
570             if task.cancelled():
571                 cancelled.append(task)
572             elif task.exception():
573                 report.failed(src, str(task.exception()))
574             else:
575                 changed = Changed.YES if task.result() else Changed.NO
576                 # If the file was written back or was successfully checked as
577                 # well-formatted, store this information in the cache.
578                 if write_back is WriteBack.YES or (
579                     write_back is WriteBack.CHECK and changed is Changed.NO
580                 ):
581                     sources_to_cache.append(src)
582                 report.done(src, changed)
583     if cancelled:
584         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
585     if sources_to_cache:
586         write_cache(cache, sources_to_cache, mode)
587
588
589 def format_file_in_place(
590     src: Path,
591     fast: bool,
592     mode: FileMode,
593     write_back: WriteBack = WriteBack.NO,
594     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
595 ) -> bool:
596     """Format file under `src` path. Return True if changed.
597
598     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
599     code to the file.
600     `line_length` and `fast` options are passed to :func:`format_file_contents`.
601     """
602     if src.suffix == ".pyi":
603         mode = evolve(mode, is_pyi=True)
604
605     then = datetime.utcfromtimestamp(src.stat().st_mtime)
606     with open(src, "rb") as buf:
607         src_contents, encoding, newline = decode_bytes(buf.read())
608     try:
609         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
610     except NothingChanged:
611         return False
612
613     if write_back == write_back.YES:
614         with open(src, "w", encoding=encoding, newline=newline) as f:
615             f.write(dst_contents)
616     elif write_back == write_back.DIFF:
617         now = datetime.utcnow()
618         src_name = f"{src}\t{then} +0000"
619         dst_name = f"{src}\t{now} +0000"
620         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
621         if lock:
622             lock.acquire()
623         try:
624             f = io.TextIOWrapper(
625                 sys.stdout.buffer,
626                 encoding=encoding,
627                 newline=newline,
628                 write_through=True,
629             )
630             f.write(diff_contents)
631             f.detach()
632         finally:
633             if lock:
634                 lock.release()
635     return True
636
637
638 def format_stdin_to_stdout(
639     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
640 ) -> bool:
641     """Format file on stdin. Return True if changed.
642
643     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
644     write a diff to stdout. The `mode` argument is passed to
645     :func:`format_file_contents`.
646     """
647     then = datetime.utcnow()
648     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
649     dst = src
650     try:
651         dst = format_file_contents(src, fast=fast, mode=mode)
652         return True
653
654     except NothingChanged:
655         return False
656
657     finally:
658         f = io.TextIOWrapper(
659             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
660         )
661         if write_back == WriteBack.YES:
662             f.write(dst)
663         elif write_back == WriteBack.DIFF:
664             now = datetime.utcnow()
665             src_name = f"STDIN\t{then} +0000"
666             dst_name = f"STDOUT\t{now} +0000"
667             f.write(diff(src, dst, src_name, dst_name))
668         f.detach()
669
670
671 def format_file_contents(
672     src_contents: str, *, fast: bool, mode: FileMode
673 ) -> FileContent:
674     """Reformat contents a file and return new contents.
675
676     If `fast` is False, additionally confirm that the reformatted code is
677     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
678     `line_length` is passed to :func:`format_str`.
679     """
680     if src_contents.strip() == "":
681         raise NothingChanged
682
683     dst_contents = format_str(src_contents, mode=mode)
684     if src_contents == dst_contents:
685         raise NothingChanged
686
687     if not fast:
688         assert_equivalent(src_contents, dst_contents)
689         assert_stable(src_contents, dst_contents, mode=mode)
690     return dst_contents
691
692
693 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
694     """Reformat a string and return new contents.
695
696     `line_length` determines how many characters per line are allowed.
697     """
698     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
699     dst_contents = ""
700     future_imports = get_future_imports(src_node)
701     if mode.target_versions:
702         versions = mode.target_versions
703     else:
704         versions = detect_target_versions(src_node)
705     normalize_fmt_off(src_node)
706     lines = LineGenerator(
707         remove_u_prefix="unicode_literals" in future_imports
708         or supports_feature(versions, Feature.UNICODE_LITERALS),
709         is_pyi=mode.is_pyi,
710         normalize_strings=mode.string_normalization,
711     )
712     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
713     empty_line = Line()
714     after = 0
715     split_line_features = {
716         feature
717         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
718         if supports_feature(versions, feature)
719     }
720     for current_line in lines.visit(src_node):
721         for _ in range(after):
722             dst_contents += str(empty_line)
723         before, after = elt.maybe_empty_lines(current_line)
724         for _ in range(before):
725             dst_contents += str(empty_line)
726         for line in split_line(
727             current_line, line_length=mode.line_length, features=split_line_features
728         ):
729             dst_contents += str(line)
730     return dst_contents
731
732
733 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
734     """Return a tuple of (decoded_contents, encoding, newline).
735
736     `newline` is either CRLF or LF but `decoded_contents` is decoded with
737     universal newlines (i.e. only contains LF).
738     """
739     srcbuf = io.BytesIO(src)
740     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
741     if not lines:
742         return "", encoding, "\n"
743
744     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
745     srcbuf.seek(0)
746     with io.TextIOWrapper(srcbuf, encoding) as tiow:
747         return tiow.read(), encoding, newline
748
749
750 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
751     if not target_versions:
752         # No target_version specified, so try all grammars.
753         return [
754             pygram.python_grammar_no_print_statement_no_exec_statement,
755             pygram.python_grammar_no_print_statement,
756             pygram.python_grammar,
757         ]
758     elif all(version.is_python2() for version in target_versions):
759         # Python 2-only code, so try Python 2 grammars.
760         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
761     else:
762         # Python 3-compatible code, so only try Python 3 grammar.
763         return [pygram.python_grammar_no_print_statement_no_exec_statement]
764
765
766 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
767     """Given a string with source, return the lib2to3 Node."""
768     if src_txt[-1:] != "\n":
769         src_txt += "\n"
770
771     for grammar in get_grammars(set(target_versions)):
772         drv = driver.Driver(grammar, pytree.convert)
773         try:
774             result = drv.parse_string(src_txt, True)
775             break
776
777         except ParseError as pe:
778             lineno, column = pe.context[1]
779             lines = src_txt.splitlines()
780             try:
781                 faulty_line = lines[lineno - 1]
782             except IndexError:
783                 faulty_line = "<line number missing in source>"
784             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
785     else:
786         raise exc from None
787
788     if isinstance(result, Leaf):
789         result = Node(syms.file_input, [result])
790     return result
791
792
793 def lib2to3_unparse(node: Node) -> str:
794     """Given a lib2to3 node, return its string representation."""
795     code = str(node)
796     return code
797
798
799 T = TypeVar("T")
800
801
802 class Visitor(Generic[T]):
803     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
804
805     def visit(self, node: LN) -> Iterator[T]:
806         """Main method to visit `node` and its children.
807
808         It tries to find a `visit_*()` method for the given `node.type`, like
809         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
810         If no dedicated `visit_*()` method is found, chooses `visit_default()`
811         instead.
812
813         Then yields objects of type `T` from the selected visitor.
814         """
815         if node.type < 256:
816             name = token.tok_name[node.type]
817         else:
818             name = type_repr(node.type)
819         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
820
821     def visit_default(self, node: LN) -> Iterator[T]:
822         """Default `visit_*()` implementation. Recurses to children of `node`."""
823         if isinstance(node, Node):
824             for child in node.children:
825                 yield from self.visit(child)
826
827
828 @dataclass
829 class DebugVisitor(Visitor[T]):
830     tree_depth: int = 0
831
832     def visit_default(self, node: LN) -> Iterator[T]:
833         indent = " " * (2 * self.tree_depth)
834         if isinstance(node, Node):
835             _type = type_repr(node.type)
836             out(f"{indent}{_type}", fg="yellow")
837             self.tree_depth += 1
838             for child in node.children:
839                 yield from self.visit(child)
840
841             self.tree_depth -= 1
842             out(f"{indent}/{_type}", fg="yellow", bold=False)
843         else:
844             _type = token.tok_name.get(node.type, str(node.type))
845             out(f"{indent}{_type}", fg="blue", nl=False)
846             if node.prefix:
847                 # We don't have to handle prefixes for `Node` objects since
848                 # that delegates to the first child anyway.
849                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
850             out(f" {node.value!r}", fg="blue", bold=False)
851
852     @classmethod
853     def show(cls, code: Union[str, Leaf, Node]) -> None:
854         """Pretty-print the lib2to3 AST of a given string of `code`.
855
856         Convenience method for debugging.
857         """
858         v: DebugVisitor[None] = DebugVisitor()
859         if isinstance(code, str):
860             code = lib2to3_parse(code)
861         list(v.visit(code))
862
863
864 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
865 STATEMENT = {
866     syms.if_stmt,
867     syms.while_stmt,
868     syms.for_stmt,
869     syms.try_stmt,
870     syms.except_clause,
871     syms.with_stmt,
872     syms.funcdef,
873     syms.classdef,
874 }
875 STANDALONE_COMMENT = 153
876 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
877 LOGIC_OPERATORS = {"and", "or"}
878 COMPARATORS = {
879     token.LESS,
880     token.GREATER,
881     token.EQEQUAL,
882     token.NOTEQUAL,
883     token.LESSEQUAL,
884     token.GREATEREQUAL,
885 }
886 MATH_OPERATORS = {
887     token.VBAR,
888     token.CIRCUMFLEX,
889     token.AMPER,
890     token.LEFTSHIFT,
891     token.RIGHTSHIFT,
892     token.PLUS,
893     token.MINUS,
894     token.STAR,
895     token.SLASH,
896     token.DOUBLESLASH,
897     token.PERCENT,
898     token.AT,
899     token.TILDE,
900     token.DOUBLESTAR,
901 }
902 STARS = {token.STAR, token.DOUBLESTAR}
903 VARARGS_PARENTS = {
904     syms.arglist,
905     syms.argument,  # double star in arglist
906     syms.trailer,  # single argument to call
907     syms.typedargslist,
908     syms.varargslist,  # lambdas
909 }
910 UNPACKING_PARENTS = {
911     syms.atom,  # single element of a list or set literal
912     syms.dictsetmaker,
913     syms.listmaker,
914     syms.testlist_gexp,
915     syms.testlist_star_expr,
916 }
917 TEST_DESCENDANTS = {
918     syms.test,
919     syms.lambdef,
920     syms.or_test,
921     syms.and_test,
922     syms.not_test,
923     syms.comparison,
924     syms.star_expr,
925     syms.expr,
926     syms.xor_expr,
927     syms.and_expr,
928     syms.shift_expr,
929     syms.arith_expr,
930     syms.trailer,
931     syms.term,
932     syms.power,
933 }
934 ASSIGNMENTS = {
935     "=",
936     "+=",
937     "-=",
938     "*=",
939     "@=",
940     "/=",
941     "%=",
942     "&=",
943     "|=",
944     "^=",
945     "<<=",
946     ">>=",
947     "**=",
948     "//=",
949 }
950 COMPREHENSION_PRIORITY = 20
951 COMMA_PRIORITY = 18
952 TERNARY_PRIORITY = 16
953 LOGIC_PRIORITY = 14
954 STRING_PRIORITY = 12
955 COMPARATOR_PRIORITY = 10
956 MATH_PRIORITIES = {
957     token.VBAR: 9,
958     token.CIRCUMFLEX: 8,
959     token.AMPER: 7,
960     token.LEFTSHIFT: 6,
961     token.RIGHTSHIFT: 6,
962     token.PLUS: 5,
963     token.MINUS: 5,
964     token.STAR: 4,
965     token.SLASH: 4,
966     token.DOUBLESLASH: 4,
967     token.PERCENT: 4,
968     token.AT: 4,
969     token.TILDE: 3,
970     token.DOUBLESTAR: 2,
971 }
972 DOT_PRIORITY = 1
973
974
975 @dataclass
976 class BracketTracker:
977     """Keeps track of brackets on a line."""
978
979     depth: int = 0
980     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
981     delimiters: Dict[LeafID, Priority] = Factory(dict)
982     previous: Optional[Leaf] = None
983     _for_loop_depths: List[int] = Factory(list)
984     _lambda_argument_depths: List[int] = Factory(list)
985
986     def mark(self, leaf: Leaf) -> None:
987         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
988
989         All leaves receive an int `bracket_depth` field that stores how deep
990         within brackets a given leaf is. 0 means there are no enclosing brackets
991         that started on this line.
992
993         If a leaf is itself a closing bracket, it receives an `opening_bracket`
994         field that it forms a pair with. This is a one-directional link to
995         avoid reference cycles.
996
997         If a leaf is a delimiter (a token on which Black can split the line if
998         needed) and it's on depth 0, its `id()` is stored in the tracker's
999         `delimiters` field.
1000         """
1001         if leaf.type == token.COMMENT:
1002             return
1003
1004         self.maybe_decrement_after_for_loop_variable(leaf)
1005         self.maybe_decrement_after_lambda_arguments(leaf)
1006         if leaf.type in CLOSING_BRACKETS:
1007             self.depth -= 1
1008             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1009             leaf.opening_bracket = opening_bracket
1010         leaf.bracket_depth = self.depth
1011         if self.depth == 0:
1012             delim = is_split_before_delimiter(leaf, self.previous)
1013             if delim and self.previous is not None:
1014                 self.delimiters[id(self.previous)] = delim
1015             else:
1016                 delim = is_split_after_delimiter(leaf, self.previous)
1017                 if delim:
1018                     self.delimiters[id(leaf)] = delim
1019         if leaf.type in OPENING_BRACKETS:
1020             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1021             self.depth += 1
1022         self.previous = leaf
1023         self.maybe_increment_lambda_arguments(leaf)
1024         self.maybe_increment_for_loop_variable(leaf)
1025
1026     def any_open_brackets(self) -> bool:
1027         """Return True if there is an yet unmatched open bracket on the line."""
1028         return bool(self.bracket_match)
1029
1030     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1031         """Return the highest priority of a delimiter found on the line.
1032
1033         Values are consistent with what `is_split_*_delimiter()` return.
1034         Raises ValueError on no delimiters.
1035         """
1036         return max(v for k, v in self.delimiters.items() if k not in exclude)
1037
1038     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1039         """Return the number of delimiters with the given `priority`.
1040
1041         If no `priority` is passed, defaults to max priority on the line.
1042         """
1043         if not self.delimiters:
1044             return 0
1045
1046         priority = priority or self.max_delimiter_priority()
1047         return sum(1 for p in self.delimiters.values() if p == priority)
1048
1049     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1050         """In a for loop, or comprehension, the variables are often unpacks.
1051
1052         To avoid splitting on the comma in this situation, increase the depth of
1053         tokens between `for` and `in`.
1054         """
1055         if leaf.type == token.NAME and leaf.value == "for":
1056             self.depth += 1
1057             self._for_loop_depths.append(self.depth)
1058             return True
1059
1060         return False
1061
1062     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1063         """See `maybe_increment_for_loop_variable` above for explanation."""
1064         if (
1065             self._for_loop_depths
1066             and self._for_loop_depths[-1] == self.depth
1067             and leaf.type == token.NAME
1068             and leaf.value == "in"
1069         ):
1070             self.depth -= 1
1071             self._for_loop_depths.pop()
1072             return True
1073
1074         return False
1075
1076     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1077         """In a lambda expression, there might be more than one argument.
1078
1079         To avoid splitting on the comma in this situation, increase the depth of
1080         tokens between `lambda` and `:`.
1081         """
1082         if leaf.type == token.NAME and leaf.value == "lambda":
1083             self.depth += 1
1084             self._lambda_argument_depths.append(self.depth)
1085             return True
1086
1087         return False
1088
1089     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1090         """See `maybe_increment_lambda_arguments` above for explanation."""
1091         if (
1092             self._lambda_argument_depths
1093             and self._lambda_argument_depths[-1] == self.depth
1094             and leaf.type == token.COLON
1095         ):
1096             self.depth -= 1
1097             self._lambda_argument_depths.pop()
1098             return True
1099
1100         return False
1101
1102     def get_open_lsqb(self) -> Optional[Leaf]:
1103         """Return the most recent opening square bracket (if any)."""
1104         return self.bracket_match.get((self.depth - 1, token.RSQB))
1105
1106
1107 @dataclass
1108 class Line:
1109     """Holds leaves and comments. Can be printed with `str(line)`."""
1110
1111     depth: int = 0
1112     leaves: List[Leaf] = Factory(list)
1113     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1114     bracket_tracker: BracketTracker = Factory(BracketTracker)
1115     inside_brackets: bool = False
1116     should_explode: bool = False
1117
1118     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1119         """Add a new `leaf` to the end of the line.
1120
1121         Unless `preformatted` is True, the `leaf` will receive a new consistent
1122         whitespace prefix and metadata applied by :class:`BracketTracker`.
1123         Trailing commas are maybe removed, unpacked for loop variables are
1124         demoted from being delimiters.
1125
1126         Inline comments are put aside.
1127         """
1128         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1129         if not has_value:
1130             return
1131
1132         if token.COLON == leaf.type and self.is_class_paren_empty:
1133             del self.leaves[-2:]
1134         if self.leaves and not preformatted:
1135             # Note: at this point leaf.prefix should be empty except for
1136             # imports, for which we only preserve newlines.
1137             leaf.prefix += whitespace(
1138                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1139             )
1140         if self.inside_brackets or not preformatted:
1141             self.bracket_tracker.mark(leaf)
1142             self.maybe_remove_trailing_comma(leaf)
1143         if not self.append_comment(leaf):
1144             self.leaves.append(leaf)
1145
1146     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1147         """Like :func:`append()` but disallow invalid standalone comment structure.
1148
1149         Raises ValueError when any `leaf` is appended after a standalone comment
1150         or when a standalone comment is not the first leaf on the line.
1151         """
1152         if self.bracket_tracker.depth == 0:
1153             if self.is_comment:
1154                 raise ValueError("cannot append to standalone comments")
1155
1156             if self.leaves and leaf.type == STANDALONE_COMMENT:
1157                 raise ValueError(
1158                     "cannot append standalone comments to a populated line"
1159                 )
1160
1161         self.append(leaf, preformatted=preformatted)
1162
1163     @property
1164     def is_comment(self) -> bool:
1165         """Is this line a standalone comment?"""
1166         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1167
1168     @property
1169     def is_decorator(self) -> bool:
1170         """Is this line a decorator?"""
1171         return bool(self) and self.leaves[0].type == token.AT
1172
1173     @property
1174     def is_import(self) -> bool:
1175         """Is this an import line?"""
1176         return bool(self) and is_import(self.leaves[0])
1177
1178     @property
1179     def is_class(self) -> bool:
1180         """Is this line a class definition?"""
1181         return (
1182             bool(self)
1183             and self.leaves[0].type == token.NAME
1184             and self.leaves[0].value == "class"
1185         )
1186
1187     @property
1188     def is_stub_class(self) -> bool:
1189         """Is this line a class definition with a body consisting only of "..."?"""
1190         return self.is_class and self.leaves[-3:] == [
1191             Leaf(token.DOT, ".") for _ in range(3)
1192         ]
1193
1194     @property
1195     def is_def(self) -> bool:
1196         """Is this a function definition? (Also returns True for async defs.)"""
1197         try:
1198             first_leaf = self.leaves[0]
1199         except IndexError:
1200             return False
1201
1202         try:
1203             second_leaf: Optional[Leaf] = self.leaves[1]
1204         except IndexError:
1205             second_leaf = None
1206         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1207             first_leaf.type == token.ASYNC
1208             and second_leaf is not None
1209             and second_leaf.type == token.NAME
1210             and second_leaf.value == "def"
1211         )
1212
1213     @property
1214     def is_class_paren_empty(self) -> bool:
1215         """Is this a class with no base classes but using parentheses?
1216
1217         Those are unnecessary and should be removed.
1218         """
1219         return (
1220             bool(self)
1221             and len(self.leaves) == 4
1222             and self.is_class
1223             and self.leaves[2].type == token.LPAR
1224             and self.leaves[2].value == "("
1225             and self.leaves[3].type == token.RPAR
1226             and self.leaves[3].value == ")"
1227         )
1228
1229     @property
1230     def is_triple_quoted_string(self) -> bool:
1231         """Is the line a triple quoted string?"""
1232         return (
1233             bool(self)
1234             and self.leaves[0].type == token.STRING
1235             and self.leaves[0].value.startswith(('"""', "'''"))
1236         )
1237
1238     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1239         """If so, needs to be split before emitting."""
1240         for leaf in self.leaves:
1241             if leaf.type == STANDALONE_COMMENT:
1242                 if leaf.bracket_depth <= depth_limit:
1243                     return True
1244         return False
1245
1246     def contains_inner_type_comments(self) -> bool:
1247         ignored_ids = set()
1248         try:
1249             last_leaf = self.leaves[-1]
1250             ignored_ids.add(id(last_leaf))
1251             if last_leaf.type == token.COMMA:
1252                 # When trailing commas are inserted by Black for consistency, comments
1253                 # after the previous last element are not moved (they don't have to,
1254                 # rendering will still be correct).  So we ignore trailing commas.
1255                 last_leaf = self.leaves[-2]
1256                 ignored_ids.add(id(last_leaf))
1257         except IndexError:
1258             return False
1259
1260         for leaf_id, comments in self.comments.items():
1261             if leaf_id in ignored_ids:
1262                 continue
1263
1264             for comment in comments:
1265                 if is_type_comment(comment):
1266                     return True
1267
1268         return False
1269
1270     def contains_multiline_strings(self) -> bool:
1271         for leaf in self.leaves:
1272             if is_multiline_string(leaf):
1273                 return True
1274
1275         return False
1276
1277     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1278         """Remove trailing comma if there is one and it's safe."""
1279         if not (
1280             self.leaves
1281             and self.leaves[-1].type == token.COMMA
1282             and closing.type in CLOSING_BRACKETS
1283         ):
1284             return False
1285
1286         if closing.type == token.RBRACE:
1287             self.remove_trailing_comma()
1288             return True
1289
1290         if closing.type == token.RSQB:
1291             comma = self.leaves[-1]
1292             if comma.parent and comma.parent.type == syms.listmaker:
1293                 self.remove_trailing_comma()
1294                 return True
1295
1296         # For parens let's check if it's safe to remove the comma.
1297         # Imports are always safe.
1298         if self.is_import:
1299             self.remove_trailing_comma()
1300             return True
1301
1302         # Otherwise, if the trailing one is the only one, we might mistakenly
1303         # change a tuple into a different type by removing the comma.
1304         depth = closing.bracket_depth + 1
1305         commas = 0
1306         opening = closing.opening_bracket
1307         for _opening_index, leaf in enumerate(self.leaves):
1308             if leaf is opening:
1309                 break
1310
1311         else:
1312             return False
1313
1314         for leaf in self.leaves[_opening_index + 1 :]:
1315             if leaf is closing:
1316                 break
1317
1318             bracket_depth = leaf.bracket_depth
1319             if bracket_depth == depth and leaf.type == token.COMMA:
1320                 commas += 1
1321                 if leaf.parent and leaf.parent.type == syms.arglist:
1322                     commas += 1
1323                     break
1324
1325         if commas > 1:
1326             self.remove_trailing_comma()
1327             return True
1328
1329         return False
1330
1331     def append_comment(self, comment: Leaf) -> bool:
1332         """Add an inline or standalone comment to the line."""
1333         if (
1334             comment.type == STANDALONE_COMMENT
1335             and self.bracket_tracker.any_open_brackets()
1336         ):
1337             comment.prefix = ""
1338             return False
1339
1340         if comment.type != token.COMMENT:
1341             return False
1342
1343         if not self.leaves:
1344             comment.type = STANDALONE_COMMENT
1345             comment.prefix = ""
1346             return False
1347
1348         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1349         return True
1350
1351     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1352         """Generate comments that should appear directly after `leaf`."""
1353         return self.comments.get(id(leaf), [])
1354
1355     def remove_trailing_comma(self) -> None:
1356         """Remove the trailing comma and moves the comments attached to it."""
1357         trailing_comma = self.leaves.pop()
1358         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1359         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1360             trailing_comma_comments
1361         )
1362
1363     def is_complex_subscript(self, leaf: Leaf) -> bool:
1364         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1365         open_lsqb = self.bracket_tracker.get_open_lsqb()
1366         if open_lsqb is None:
1367             return False
1368
1369         subscript_start = open_lsqb.next_sibling
1370
1371         if isinstance(subscript_start, Node):
1372             if subscript_start.type == syms.listmaker:
1373                 return False
1374
1375             if subscript_start.type == syms.subscriptlist:
1376                 subscript_start = child_towards(subscript_start, leaf)
1377         return subscript_start is not None and any(
1378             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1379         )
1380
1381     def __str__(self) -> str:
1382         """Render the line."""
1383         if not self:
1384             return "\n"
1385
1386         indent = "    " * self.depth
1387         leaves = iter(self.leaves)
1388         first = next(leaves)
1389         res = f"{first.prefix}{indent}{first.value}"
1390         for leaf in leaves:
1391             res += str(leaf)
1392         for comment in itertools.chain.from_iterable(self.comments.values()):
1393             res += str(comment)
1394         return res + "\n"
1395
1396     def __bool__(self) -> bool:
1397         """Return True if the line has leaves or comments."""
1398         return bool(self.leaves or self.comments)
1399
1400
1401 @dataclass
1402 class EmptyLineTracker:
1403     """Provides a stateful method that returns the number of potential extra
1404     empty lines needed before and after the currently processed line.
1405
1406     Note: this tracker works on lines that haven't been split yet.  It assumes
1407     the prefix of the first leaf consists of optional newlines.  Those newlines
1408     are consumed by `maybe_empty_lines()` and included in the computation.
1409     """
1410
1411     is_pyi: bool = False
1412     previous_line: Optional[Line] = None
1413     previous_after: int = 0
1414     previous_defs: List[int] = Factory(list)
1415
1416     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1417         """Return the number of extra empty lines before and after the `current_line`.
1418
1419         This is for separating `def`, `async def` and `class` with extra empty
1420         lines (two on module-level).
1421         """
1422         before, after = self._maybe_empty_lines(current_line)
1423         before -= self.previous_after
1424         self.previous_after = after
1425         self.previous_line = current_line
1426         return before, after
1427
1428     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1429         max_allowed = 1
1430         if current_line.depth == 0:
1431             max_allowed = 1 if self.is_pyi else 2
1432         if current_line.leaves:
1433             # Consume the first leaf's extra newlines.
1434             first_leaf = current_line.leaves[0]
1435             before = first_leaf.prefix.count("\n")
1436             before = min(before, max_allowed)
1437             first_leaf.prefix = ""
1438         else:
1439             before = 0
1440         depth = current_line.depth
1441         while self.previous_defs and self.previous_defs[-1] >= depth:
1442             self.previous_defs.pop()
1443             if self.is_pyi:
1444                 before = 0 if depth else 1
1445             else:
1446                 before = 1 if depth else 2
1447         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1448             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1449
1450         if (
1451             self.previous_line
1452             and self.previous_line.is_import
1453             and not current_line.is_import
1454             and depth == self.previous_line.depth
1455         ):
1456             return (before or 1), 0
1457
1458         if (
1459             self.previous_line
1460             and self.previous_line.is_class
1461             and current_line.is_triple_quoted_string
1462         ):
1463             return before, 1
1464
1465         return before, 0
1466
1467     def _maybe_empty_lines_for_class_or_def(
1468         self, current_line: Line, before: int
1469     ) -> Tuple[int, int]:
1470         if not current_line.is_decorator:
1471             self.previous_defs.append(current_line.depth)
1472         if self.previous_line is None:
1473             # Don't insert empty lines before the first line in the file.
1474             return 0, 0
1475
1476         if self.previous_line.is_decorator:
1477             return 0, 0
1478
1479         if self.previous_line.depth < current_line.depth and (
1480             self.previous_line.is_class or self.previous_line.is_def
1481         ):
1482             return 0, 0
1483
1484         if (
1485             self.previous_line.is_comment
1486             and self.previous_line.depth == current_line.depth
1487             and before == 0
1488         ):
1489             return 0, 0
1490
1491         if self.is_pyi:
1492             if self.previous_line.depth > current_line.depth:
1493                 newlines = 1
1494             elif current_line.is_class or self.previous_line.is_class:
1495                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1496                     # No blank line between classes with an empty body
1497                     newlines = 0
1498                 else:
1499                     newlines = 1
1500             elif current_line.is_def and not self.previous_line.is_def:
1501                 # Blank line between a block of functions and a block of non-functions
1502                 newlines = 1
1503             else:
1504                 newlines = 0
1505         else:
1506             newlines = 2
1507         if current_line.depth and newlines:
1508             newlines -= 1
1509         return newlines, 0
1510
1511
1512 @dataclass
1513 class LineGenerator(Visitor[Line]):
1514     """Generates reformatted Line objects.  Empty lines are not emitted.
1515
1516     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1517     in ways that will no longer stringify to valid Python code on the tree.
1518     """
1519
1520     is_pyi: bool = False
1521     normalize_strings: bool = True
1522     current_line: Line = Factory(Line)
1523     remove_u_prefix: bool = False
1524
1525     def line(self, indent: int = 0) -> Iterator[Line]:
1526         """Generate a line.
1527
1528         If the line is empty, only emit if it makes sense.
1529         If the line is too long, split it first and then generate.
1530
1531         If any lines were generated, set up a new current_line.
1532         """
1533         if not self.current_line:
1534             self.current_line.depth += indent
1535             return  # Line is empty, don't emit. Creating a new one unnecessary.
1536
1537         complete_line = self.current_line
1538         self.current_line = Line(depth=complete_line.depth + indent)
1539         yield complete_line
1540
1541     def visit_default(self, node: LN) -> Iterator[Line]:
1542         """Default `visit_*()` implementation. Recurses to children of `node`."""
1543         if isinstance(node, Leaf):
1544             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1545             for comment in generate_comments(node):
1546                 if any_open_brackets:
1547                     # any comment within brackets is subject to splitting
1548                     self.current_line.append(comment)
1549                 elif comment.type == token.COMMENT:
1550                     # regular trailing comment
1551                     self.current_line.append(comment)
1552                     yield from self.line()
1553
1554                 else:
1555                     # regular standalone comment
1556                     yield from self.line()
1557
1558                     self.current_line.append(comment)
1559                     yield from self.line()
1560
1561             normalize_prefix(node, inside_brackets=any_open_brackets)
1562             if self.normalize_strings and node.type == token.STRING:
1563                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1564                 normalize_string_quotes(node)
1565             if node.type == token.NUMBER:
1566                 normalize_numeric_literal(node)
1567             if node.type not in WHITESPACE:
1568                 self.current_line.append(node)
1569         yield from super().visit_default(node)
1570
1571     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1572         """Increase indentation level, maybe yield a line."""
1573         # In blib2to3 INDENT never holds comments.
1574         yield from self.line(+1)
1575         yield from self.visit_default(node)
1576
1577     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1578         """Decrease indentation level, maybe yield a line."""
1579         # The current line might still wait for trailing comments.  At DEDENT time
1580         # there won't be any (they would be prefixes on the preceding NEWLINE).
1581         # Emit the line then.
1582         yield from self.line()
1583
1584         # While DEDENT has no value, its prefix may contain standalone comments
1585         # that belong to the current indentation level.  Get 'em.
1586         yield from self.visit_default(node)
1587
1588         # Finally, emit the dedent.
1589         yield from self.line(-1)
1590
1591     def visit_stmt(
1592         self, node: Node, keywords: Set[str], parens: Set[str]
1593     ) -> Iterator[Line]:
1594         """Visit a statement.
1595
1596         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1597         `def`, `with`, `class`, `assert` and assignments.
1598
1599         The relevant Python language `keywords` for a given statement will be
1600         NAME leaves within it. This methods puts those on a separate line.
1601
1602         `parens` holds a set of string leaf values immediately after which
1603         invisible parens should be put.
1604         """
1605         normalize_invisible_parens(node, parens_after=parens)
1606         for child in node.children:
1607             if child.type == token.NAME and child.value in keywords:  # type: ignore
1608                 yield from self.line()
1609
1610             yield from self.visit(child)
1611
1612     def visit_suite(self, node: Node) -> Iterator[Line]:
1613         """Visit a suite."""
1614         if self.is_pyi and is_stub_suite(node):
1615             yield from self.visit(node.children[2])
1616         else:
1617             yield from self.visit_default(node)
1618
1619     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1620         """Visit a statement without nested statements."""
1621         is_suite_like = node.parent and node.parent.type in STATEMENT
1622         if is_suite_like:
1623             if self.is_pyi and is_stub_body(node):
1624                 yield from self.visit_default(node)
1625             else:
1626                 yield from self.line(+1)
1627                 yield from self.visit_default(node)
1628                 yield from self.line(-1)
1629
1630         else:
1631             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1632                 yield from self.line()
1633             yield from self.visit_default(node)
1634
1635     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1636         """Visit `async def`, `async for`, `async with`."""
1637         yield from self.line()
1638
1639         children = iter(node.children)
1640         for child in children:
1641             yield from self.visit(child)
1642
1643             if child.type == token.ASYNC:
1644                 break
1645
1646         internal_stmt = next(children)
1647         for child in internal_stmt.children:
1648             yield from self.visit(child)
1649
1650     def visit_decorators(self, node: Node) -> Iterator[Line]:
1651         """Visit decorators."""
1652         for child in node.children:
1653             yield from self.line()
1654             yield from self.visit(child)
1655
1656     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1657         """Remove a semicolon and put the other statement on a separate line."""
1658         yield from self.line()
1659
1660     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1661         """End of file. Process outstanding comments and end with a newline."""
1662         yield from self.visit_default(leaf)
1663         yield from self.line()
1664
1665     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1666         if not self.current_line.bracket_tracker.any_open_brackets():
1667             yield from self.line()
1668         yield from self.visit_default(leaf)
1669
1670     def __attrs_post_init__(self) -> None:
1671         """You are in a twisty little maze of passages."""
1672         v = self.visit_stmt
1673         Ø: Set[str] = set()
1674         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1675         self.visit_if_stmt = partial(
1676             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1677         )
1678         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1679         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1680         self.visit_try_stmt = partial(
1681             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1682         )
1683         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1684         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1685         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1686         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1687         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1688         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1689         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1690         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1691         self.visit_async_funcdef = self.visit_async_stmt
1692         self.visit_decorated = self.visit_decorators
1693
1694
1695 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1696 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1697 OPENING_BRACKETS = set(BRACKET.keys())
1698 CLOSING_BRACKETS = set(BRACKET.values())
1699 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1700 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1701
1702
1703 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1704     """Return whitespace prefix if needed for the given `leaf`.
1705
1706     `complex_subscript` signals whether the given leaf is part of a subscription
1707     which has non-trivial arguments, like arithmetic expressions or function calls.
1708     """
1709     NO = ""
1710     SPACE = " "
1711     DOUBLESPACE = "  "
1712     t = leaf.type
1713     p = leaf.parent
1714     v = leaf.value
1715     if t in ALWAYS_NO_SPACE:
1716         return NO
1717
1718     if t == token.COMMENT:
1719         return DOUBLESPACE
1720
1721     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1722     if t == token.COLON and p.type not in {
1723         syms.subscript,
1724         syms.subscriptlist,
1725         syms.sliceop,
1726     }:
1727         return NO
1728
1729     prev = leaf.prev_sibling
1730     if not prev:
1731         prevp = preceding_leaf(p)
1732         if not prevp or prevp.type in OPENING_BRACKETS:
1733             return NO
1734
1735         if t == token.COLON:
1736             if prevp.type == token.COLON:
1737                 return NO
1738
1739             elif prevp.type != token.COMMA and not complex_subscript:
1740                 return NO
1741
1742             return SPACE
1743
1744         if prevp.type == token.EQUAL:
1745             if prevp.parent:
1746                 if prevp.parent.type in {
1747                     syms.arglist,
1748                     syms.argument,
1749                     syms.parameters,
1750                     syms.varargslist,
1751                 }:
1752                     return NO
1753
1754                 elif prevp.parent.type == syms.typedargslist:
1755                     # A bit hacky: if the equal sign has whitespace, it means we
1756                     # previously found it's a typed argument.  So, we're using
1757                     # that, too.
1758                     return prevp.prefix
1759
1760         elif prevp.type in STARS:
1761             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1762                 return NO
1763
1764         elif prevp.type == token.COLON:
1765             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1766                 return SPACE if complex_subscript else NO
1767
1768         elif (
1769             prevp.parent
1770             and prevp.parent.type == syms.factor
1771             and prevp.type in MATH_OPERATORS
1772         ):
1773             return NO
1774
1775         elif (
1776             prevp.type == token.RIGHTSHIFT
1777             and prevp.parent
1778             and prevp.parent.type == syms.shift_expr
1779             and prevp.prev_sibling
1780             and prevp.prev_sibling.type == token.NAME
1781             and prevp.prev_sibling.value == "print"  # type: ignore
1782         ):
1783             # Python 2 print chevron
1784             return NO
1785
1786     elif prev.type in OPENING_BRACKETS:
1787         return NO
1788
1789     if p.type in {syms.parameters, syms.arglist}:
1790         # untyped function signatures or calls
1791         if not prev or prev.type != token.COMMA:
1792             return NO
1793
1794     elif p.type == syms.varargslist:
1795         # lambdas
1796         if prev and prev.type != token.COMMA:
1797             return NO
1798
1799     elif p.type == syms.typedargslist:
1800         # typed function signatures
1801         if not prev:
1802             return NO
1803
1804         if t == token.EQUAL:
1805             if prev.type != syms.tname:
1806                 return NO
1807
1808         elif prev.type == token.EQUAL:
1809             # A bit hacky: if the equal sign has whitespace, it means we
1810             # previously found it's a typed argument.  So, we're using that, too.
1811             return prev.prefix
1812
1813         elif prev.type != token.COMMA:
1814             return NO
1815
1816     elif p.type == syms.tname:
1817         # type names
1818         if not prev:
1819             prevp = preceding_leaf(p)
1820             if not prevp or prevp.type != token.COMMA:
1821                 return NO
1822
1823     elif p.type == syms.trailer:
1824         # attributes and calls
1825         if t == token.LPAR or t == token.RPAR:
1826             return NO
1827
1828         if not prev:
1829             if t == token.DOT:
1830                 prevp = preceding_leaf(p)
1831                 if not prevp or prevp.type != token.NUMBER:
1832                     return NO
1833
1834             elif t == token.LSQB:
1835                 return NO
1836
1837         elif prev.type != token.COMMA:
1838             return NO
1839
1840     elif p.type == syms.argument:
1841         # single argument
1842         if t == token.EQUAL:
1843             return NO
1844
1845         if not prev:
1846             prevp = preceding_leaf(p)
1847             if not prevp or prevp.type == token.LPAR:
1848                 return NO
1849
1850         elif prev.type in {token.EQUAL} | STARS:
1851             return NO
1852
1853     elif p.type == syms.decorator:
1854         # decorators
1855         return NO
1856
1857     elif p.type == syms.dotted_name:
1858         if prev:
1859             return NO
1860
1861         prevp = preceding_leaf(p)
1862         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1863             return NO
1864
1865     elif p.type == syms.classdef:
1866         if t == token.LPAR:
1867             return NO
1868
1869         if prev and prev.type == token.LPAR:
1870             return NO
1871
1872     elif p.type in {syms.subscript, syms.sliceop}:
1873         # indexing
1874         if not prev:
1875             assert p.parent is not None, "subscripts are always parented"
1876             if p.parent.type == syms.subscriptlist:
1877                 return SPACE
1878
1879             return NO
1880
1881         elif not complex_subscript:
1882             return NO
1883
1884     elif p.type == syms.atom:
1885         if prev and t == token.DOT:
1886             # dots, but not the first one.
1887             return NO
1888
1889     elif p.type == syms.dictsetmaker:
1890         # dict unpacking
1891         if prev and prev.type == token.DOUBLESTAR:
1892             return NO
1893
1894     elif p.type in {syms.factor, syms.star_expr}:
1895         # unary ops
1896         if not prev:
1897             prevp = preceding_leaf(p)
1898             if not prevp or prevp.type in OPENING_BRACKETS:
1899                 return NO
1900
1901             prevp_parent = prevp.parent
1902             assert prevp_parent is not None
1903             if prevp.type == token.COLON and prevp_parent.type in {
1904                 syms.subscript,
1905                 syms.sliceop,
1906             }:
1907                 return NO
1908
1909             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1910                 return NO
1911
1912         elif t in {token.NAME, token.NUMBER, token.STRING}:
1913             return NO
1914
1915     elif p.type == syms.import_from:
1916         if t == token.DOT:
1917             if prev and prev.type == token.DOT:
1918                 return NO
1919
1920         elif t == token.NAME:
1921             if v == "import":
1922                 return SPACE
1923
1924             if prev and prev.type == token.DOT:
1925                 return NO
1926
1927     elif p.type == syms.sliceop:
1928         return NO
1929
1930     return SPACE
1931
1932
1933 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1934     """Return the first leaf that precedes `node`, if any."""
1935     while node:
1936         res = node.prev_sibling
1937         if res:
1938             if isinstance(res, Leaf):
1939                 return res
1940
1941             try:
1942                 return list(res.leaves())[-1]
1943
1944             except IndexError:
1945                 return None
1946
1947         node = node.parent
1948     return None
1949
1950
1951 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1952     """Return the child of `ancestor` that contains `descendant`."""
1953     node: Optional[LN] = descendant
1954     while node and node.parent != ancestor:
1955         node = node.parent
1956     return node
1957
1958
1959 def container_of(leaf: Leaf) -> LN:
1960     """Return `leaf` or one of its ancestors that is the topmost container of it.
1961
1962     By "container" we mean a node where `leaf` is the very first child.
1963     """
1964     same_prefix = leaf.prefix
1965     container: LN = leaf
1966     while container:
1967         parent = container.parent
1968         if parent is None:
1969             break
1970
1971         if parent.children[0].prefix != same_prefix:
1972             break
1973
1974         if parent.type == syms.file_input:
1975             break
1976
1977         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1978             break
1979
1980         container = parent
1981     return container
1982
1983
1984 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1985     """Return the priority of the `leaf` delimiter, given a line break after it.
1986
1987     The delimiter priorities returned here are from those delimiters that would
1988     cause a line break after themselves.
1989
1990     Higher numbers are higher priority.
1991     """
1992     if leaf.type == token.COMMA:
1993         return COMMA_PRIORITY
1994
1995     return 0
1996
1997
1998 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1999     """Return the priority of the `leaf` delimiter, given a line break before it.
2000
2001     The delimiter priorities returned here are from those delimiters that would
2002     cause a line break before themselves.
2003
2004     Higher numbers are higher priority.
2005     """
2006     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2007         # * and ** might also be MATH_OPERATORS but in this case they are not.
2008         # Don't treat them as a delimiter.
2009         return 0
2010
2011     if (
2012         leaf.type == token.DOT
2013         and leaf.parent
2014         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2015         and (previous is None or previous.type in CLOSING_BRACKETS)
2016     ):
2017         return DOT_PRIORITY
2018
2019     if (
2020         leaf.type in MATH_OPERATORS
2021         and leaf.parent
2022         and leaf.parent.type not in {syms.factor, syms.star_expr}
2023     ):
2024         return MATH_PRIORITIES[leaf.type]
2025
2026     if leaf.type in COMPARATORS:
2027         return COMPARATOR_PRIORITY
2028
2029     if (
2030         leaf.type == token.STRING
2031         and previous is not None
2032         and previous.type == token.STRING
2033     ):
2034         return STRING_PRIORITY
2035
2036     if leaf.type not in {token.NAME, token.ASYNC}:
2037         return 0
2038
2039     if (
2040         leaf.value == "for"
2041         and leaf.parent
2042         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2043         or leaf.type == token.ASYNC
2044     ):
2045         if (
2046             not isinstance(leaf.prev_sibling, Leaf)
2047             or leaf.prev_sibling.value != "async"
2048         ):
2049             return COMPREHENSION_PRIORITY
2050
2051     if (
2052         leaf.value == "if"
2053         and leaf.parent
2054         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2055     ):
2056         return COMPREHENSION_PRIORITY
2057
2058     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2059         return TERNARY_PRIORITY
2060
2061     if leaf.value == "is":
2062         return COMPARATOR_PRIORITY
2063
2064     if (
2065         leaf.value == "in"
2066         and leaf.parent
2067         and leaf.parent.type in {syms.comp_op, syms.comparison}
2068         and not (
2069             previous is not None
2070             and previous.type == token.NAME
2071             and previous.value == "not"
2072         )
2073     ):
2074         return COMPARATOR_PRIORITY
2075
2076     if (
2077         leaf.value == "not"
2078         and leaf.parent
2079         and leaf.parent.type == syms.comp_op
2080         and not (
2081             previous is not None
2082             and previous.type == token.NAME
2083             and previous.value == "is"
2084         )
2085     ):
2086         return COMPARATOR_PRIORITY
2087
2088     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2089         return LOGIC_PRIORITY
2090
2091     return 0
2092
2093
2094 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2095 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2096
2097
2098 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2099     """Clean the prefix of the `leaf` and generate comments from it, if any.
2100
2101     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2102     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2103     move because it does away with modifying the grammar to include all the
2104     possible places in which comments can be placed.
2105
2106     The sad consequence for us though is that comments don't "belong" anywhere.
2107     This is why this function generates simple parentless Leaf objects for
2108     comments.  We simply don't know what the correct parent should be.
2109
2110     No matter though, we can live without this.  We really only need to
2111     differentiate between inline and standalone comments.  The latter don't
2112     share the line with any code.
2113
2114     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2115     are emitted with a fake STANDALONE_COMMENT token identifier.
2116     """
2117     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2118         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2119
2120
2121 @dataclass
2122 class ProtoComment:
2123     """Describes a piece of syntax that is a comment.
2124
2125     It's not a :class:`blib2to3.pytree.Leaf` so that:
2126
2127     * it can be cached (`Leaf` objects should not be reused more than once as
2128       they store their lineno, column, prefix, and parent information);
2129     * `newlines` and `consumed` fields are kept separate from the `value`. This
2130       simplifies handling of special marker comments like ``# fmt: off/on``.
2131     """
2132
2133     type: int  # token.COMMENT or STANDALONE_COMMENT
2134     value: str  # content of the comment
2135     newlines: int  # how many newlines before the comment
2136     consumed: int  # how many characters of the original leaf's prefix did we consume
2137
2138
2139 @lru_cache(maxsize=4096)
2140 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2141     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2142     result: List[ProtoComment] = []
2143     if not prefix or "#" not in prefix:
2144         return result
2145
2146     consumed = 0
2147     nlines = 0
2148     for index, line in enumerate(prefix.split("\n")):
2149         consumed += len(line) + 1  # adding the length of the split '\n'
2150         line = line.lstrip()
2151         if not line:
2152             nlines += 1
2153         if not line.startswith("#"):
2154             continue
2155
2156         if index == 0 and not is_endmarker:
2157             comment_type = token.COMMENT  # simple trailing comment
2158         else:
2159             comment_type = STANDALONE_COMMENT
2160         comment = make_comment(line)
2161         result.append(
2162             ProtoComment(
2163                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2164             )
2165         )
2166         nlines = 0
2167     return result
2168
2169
2170 def make_comment(content: str) -> str:
2171     """Return a consistently formatted comment from the given `content` string.
2172
2173     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2174     space between the hash sign and the content.
2175
2176     If `content` didn't start with a hash sign, one is provided.
2177     """
2178     content = content.rstrip()
2179     if not content:
2180         return "#"
2181
2182     if content[0] == "#":
2183         content = content[1:]
2184     if content and content[0] not in " !:#'%":
2185         content = " " + content
2186     return "#" + content
2187
2188
2189 def split_line(
2190     line: Line,
2191     line_length: int,
2192     inner: bool = False,
2193     features: Collection[Feature] = (),
2194 ) -> Iterator[Line]:
2195     """Split a `line` into potentially many lines.
2196
2197     They should fit in the allotted `line_length` but might not be able to.
2198     `inner` signifies that there were a pair of brackets somewhere around the
2199     current `line`, possibly transitively. This means we can fallback to splitting
2200     by delimiters if the LHS/RHS don't yield any results.
2201
2202     `features` are syntactical features that may be used in the output.
2203     """
2204     if line.is_comment:
2205         yield line
2206         return
2207
2208     line_str = str(line).strip("\n")
2209
2210     if (
2211         not line.contains_inner_type_comments()
2212         and not line.should_explode
2213         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2214     ):
2215         yield line
2216         return
2217
2218     split_funcs: List[SplitFunc]
2219     if line.is_def:
2220         split_funcs = [left_hand_split]
2221     else:
2222
2223         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2224             for omit in generate_trailers_to_omit(line, line_length):
2225                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2226                 if is_line_short_enough(lines[0], line_length=line_length):
2227                     yield from lines
2228                     return
2229
2230             # All splits failed, best effort split with no omits.
2231             # This mostly happens to multiline strings that are by definition
2232             # reported as not fitting a single line.
2233             yield from right_hand_split(line, line_length, features=features)
2234
2235         if line.inside_brackets:
2236             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2237         else:
2238             split_funcs = [rhs]
2239     for split_func in split_funcs:
2240         # We are accumulating lines in `result` because we might want to abort
2241         # mission and return the original line in the end, or attempt a different
2242         # split altogether.
2243         result: List[Line] = []
2244         try:
2245             for l in split_func(line, features):
2246                 if str(l).strip("\n") == line_str:
2247                     raise CannotSplit("Split function returned an unchanged result")
2248
2249                 result.extend(
2250                     split_line(
2251                         l, line_length=line_length, inner=True, features=features
2252                     )
2253                 )
2254         except CannotSplit:
2255             continue
2256
2257         else:
2258             yield from result
2259             break
2260
2261     else:
2262         yield line
2263
2264
2265 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2266     """Split line into many lines, starting with the first matching bracket pair.
2267
2268     Note: this usually looks weird, only use this for function definitions.
2269     Prefer RHS otherwise.  This is why this function is not symmetrical with
2270     :func:`right_hand_split` which also handles optional parentheses.
2271     """
2272     tail_leaves: List[Leaf] = []
2273     body_leaves: List[Leaf] = []
2274     head_leaves: List[Leaf] = []
2275     current_leaves = head_leaves
2276     matching_bracket = None
2277     for leaf in line.leaves:
2278         if (
2279             current_leaves is body_leaves
2280             and leaf.type in CLOSING_BRACKETS
2281             and leaf.opening_bracket is matching_bracket
2282         ):
2283             current_leaves = tail_leaves if body_leaves else head_leaves
2284         current_leaves.append(leaf)
2285         if current_leaves is head_leaves:
2286             if leaf.type in OPENING_BRACKETS:
2287                 matching_bracket = leaf
2288                 current_leaves = body_leaves
2289     if not matching_bracket:
2290         raise CannotSplit("No brackets found")
2291
2292     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2293     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2294     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2295     bracket_split_succeeded_or_raise(head, body, tail)
2296     for result in (head, body, tail):
2297         if result:
2298             yield result
2299
2300
2301 def right_hand_split(
2302     line: Line,
2303     line_length: int,
2304     features: Collection[Feature] = (),
2305     omit: Collection[LeafID] = (),
2306 ) -> Iterator[Line]:
2307     """Split line into many lines, starting with the last matching bracket pair.
2308
2309     If the split was by optional parentheses, attempt splitting without them, too.
2310     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2311     this split.
2312
2313     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2314     """
2315     tail_leaves: List[Leaf] = []
2316     body_leaves: List[Leaf] = []
2317     head_leaves: List[Leaf] = []
2318     current_leaves = tail_leaves
2319     opening_bracket = None
2320     closing_bracket = None
2321     for leaf in reversed(line.leaves):
2322         if current_leaves is body_leaves:
2323             if leaf is opening_bracket:
2324                 current_leaves = head_leaves if body_leaves else tail_leaves
2325         current_leaves.append(leaf)
2326         if current_leaves is tail_leaves:
2327             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2328                 opening_bracket = leaf.opening_bracket
2329                 closing_bracket = leaf
2330                 current_leaves = body_leaves
2331     if not (opening_bracket and closing_bracket and head_leaves):
2332         # If there is no opening or closing_bracket that means the split failed and
2333         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2334         # the matching `opening_bracket` wasn't available on `line` anymore.
2335         raise CannotSplit("No brackets found")
2336
2337     tail_leaves.reverse()
2338     body_leaves.reverse()
2339     head_leaves.reverse()
2340     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2341     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2342     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2343     bracket_split_succeeded_or_raise(head, body, tail)
2344     if (
2345         # the body shouldn't be exploded
2346         not body.should_explode
2347         # the opening bracket is an optional paren
2348         and opening_bracket.type == token.LPAR
2349         and not opening_bracket.value
2350         # the closing bracket is an optional paren
2351         and closing_bracket.type == token.RPAR
2352         and not closing_bracket.value
2353         # it's not an import (optional parens are the only thing we can split on
2354         # in this case; attempting a split without them is a waste of time)
2355         and not line.is_import
2356         # there are no standalone comments in the body
2357         and not body.contains_standalone_comments(0)
2358         # and we can actually remove the parens
2359         and can_omit_invisible_parens(body, line_length)
2360     ):
2361         omit = {id(closing_bracket), *omit}
2362         try:
2363             yield from right_hand_split(line, line_length, features=features, omit=omit)
2364             return
2365
2366         except CannotSplit:
2367             if not (
2368                 can_be_split(body)
2369                 or is_line_short_enough(body, line_length=line_length)
2370             ):
2371                 raise CannotSplit(
2372                     "Splitting failed, body is still too long and can't be split."
2373                 )
2374
2375             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2376                 raise CannotSplit(
2377                     "The current optional pair of parentheses is bound to fail to "
2378                     "satisfy the splitting algorithm because the head or the tail "
2379                     "contains multiline strings which by definition never fit one "
2380                     "line."
2381                 )
2382
2383     ensure_visible(opening_bracket)
2384     ensure_visible(closing_bracket)
2385     for result in (head, body, tail):
2386         if result:
2387             yield result
2388
2389
2390 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2391     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2392
2393     Do nothing otherwise.
2394
2395     A left- or right-hand split is based on a pair of brackets. Content before
2396     (and including) the opening bracket is left on one line, content inside the
2397     brackets is put on a separate line, and finally content starting with and
2398     following the closing bracket is put on a separate line.
2399
2400     Those are called `head`, `body`, and `tail`, respectively. If the split
2401     produced the same line (all content in `head`) or ended up with an empty `body`
2402     and the `tail` is just the closing bracket, then it's considered failed.
2403     """
2404     tail_len = len(str(tail).strip())
2405     if not body:
2406         if tail_len == 0:
2407             raise CannotSplit("Splitting brackets produced the same line")
2408
2409         elif tail_len < 3:
2410             raise CannotSplit(
2411                 f"Splitting brackets on an empty body to save "
2412                 f"{tail_len} characters is not worth it"
2413             )
2414
2415
2416 def bracket_split_build_line(
2417     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2418 ) -> Line:
2419     """Return a new line with given `leaves` and respective comments from `original`.
2420
2421     If `is_body` is True, the result line is one-indented inside brackets and as such
2422     has its first leaf's prefix normalized and a trailing comma added when expected.
2423     """
2424     result = Line(depth=original.depth)
2425     if is_body:
2426         result.inside_brackets = True
2427         result.depth += 1
2428         if leaves:
2429             # Since body is a new indent level, remove spurious leading whitespace.
2430             normalize_prefix(leaves[0], inside_brackets=True)
2431             # Ensure a trailing comma for imports, but be careful not to add one after
2432             # any comments.
2433             if original.is_import:
2434                 for i in range(len(leaves) - 1, -1, -1):
2435                     if leaves[i].type == STANDALONE_COMMENT:
2436                         continue
2437                     elif leaves[i].type == token.COMMA:
2438                         break
2439                     else:
2440                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2441                         break
2442     # Populate the line
2443     for leaf in leaves:
2444         result.append(leaf, preformatted=True)
2445         for comment_after in original.comments_after(leaf):
2446             result.append(comment_after, preformatted=True)
2447     if is_body:
2448         result.should_explode = should_explode(result, opening_bracket)
2449     return result
2450
2451
2452 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2453     """Normalize prefix of the first leaf in every line returned by `split_func`.
2454
2455     This is a decorator over relevant split functions.
2456     """
2457
2458     @wraps(split_func)
2459     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2460         for l in split_func(line, features):
2461             normalize_prefix(l.leaves[0], inside_brackets=True)
2462             yield l
2463
2464     return split_wrapper
2465
2466
2467 @dont_increase_indentation
2468 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2469     """Split according to delimiters of the highest priority.
2470
2471     If the appropriate Features are given, the split will add trailing commas
2472     also in function signatures and calls that contain `*` and `**`.
2473     """
2474     try:
2475         last_leaf = line.leaves[-1]
2476     except IndexError:
2477         raise CannotSplit("Line empty")
2478
2479     bt = line.bracket_tracker
2480     try:
2481         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2482     except ValueError:
2483         raise CannotSplit("No delimiters found")
2484
2485     if delimiter_priority == DOT_PRIORITY:
2486         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2487             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2488
2489     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2490     lowest_depth = sys.maxsize
2491     trailing_comma_safe = True
2492
2493     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2494         """Append `leaf` to current line or to new line if appending impossible."""
2495         nonlocal current_line
2496         try:
2497             current_line.append_safe(leaf, preformatted=True)
2498         except ValueError:
2499             yield current_line
2500
2501             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2502             current_line.append(leaf)
2503
2504     for leaf in line.leaves:
2505         yield from append_to_line(leaf)
2506
2507         for comment_after in line.comments_after(leaf):
2508             yield from append_to_line(comment_after)
2509
2510         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2511         if leaf.bracket_depth == lowest_depth:
2512             if is_vararg(leaf, within={syms.typedargslist}):
2513                 trailing_comma_safe = (
2514                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2515                 )
2516             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2517                 trailing_comma_safe = (
2518                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2519                 )
2520
2521         leaf_priority = bt.delimiters.get(id(leaf))
2522         if leaf_priority == delimiter_priority:
2523             yield current_line
2524
2525             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2526     if current_line:
2527         if (
2528             trailing_comma_safe
2529             and delimiter_priority == COMMA_PRIORITY
2530             and current_line.leaves[-1].type != token.COMMA
2531             and current_line.leaves[-1].type != STANDALONE_COMMENT
2532         ):
2533             current_line.append(Leaf(token.COMMA, ","))
2534         yield current_line
2535
2536
2537 @dont_increase_indentation
2538 def standalone_comment_split(
2539     line: Line, features: Collection[Feature] = ()
2540 ) -> Iterator[Line]:
2541     """Split standalone comments from the rest of the line."""
2542     if not line.contains_standalone_comments(0):
2543         raise CannotSplit("Line does not have any standalone comments")
2544
2545     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2546
2547     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2548         """Append `leaf` to current line or to new line if appending impossible."""
2549         nonlocal current_line
2550         try:
2551             current_line.append_safe(leaf, preformatted=True)
2552         except ValueError:
2553             yield current_line
2554
2555             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2556             current_line.append(leaf)
2557
2558     for leaf in line.leaves:
2559         yield from append_to_line(leaf)
2560
2561         for comment_after in line.comments_after(leaf):
2562             yield from append_to_line(comment_after)
2563
2564     if current_line:
2565         yield current_line
2566
2567
2568 def is_import(leaf: Leaf) -> bool:
2569     """Return True if the given leaf starts an import statement."""
2570     p = leaf.parent
2571     t = leaf.type
2572     v = leaf.value
2573     return bool(
2574         t == token.NAME
2575         and (
2576             (v == "import" and p and p.type == syms.import_name)
2577             or (v == "from" and p and p.type == syms.import_from)
2578         )
2579     )
2580
2581
2582 def is_type_comment(leaf: Leaf) -> bool:
2583     """Return True if the given leaf is a special comment.
2584     Only returns true for type comments for now."""
2585     t = leaf.type
2586     v = leaf.value
2587     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2588
2589
2590 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2591     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2592     else.
2593
2594     Note: don't use backslashes for formatting or you'll lose your voting rights.
2595     """
2596     if not inside_brackets:
2597         spl = leaf.prefix.split("#")
2598         if "\\" not in spl[0]:
2599             nl_count = spl[-1].count("\n")
2600             if len(spl) > 1:
2601                 nl_count -= 1
2602             leaf.prefix = "\n" * nl_count
2603             return
2604
2605     leaf.prefix = ""
2606
2607
2608 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2609     """Make all string prefixes lowercase.
2610
2611     If remove_u_prefix is given, also removes any u prefix from the string.
2612
2613     Note: Mutates its argument.
2614     """
2615     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2616     assert match is not None, f"failed to match string {leaf.value!r}"
2617     orig_prefix = match.group(1)
2618     new_prefix = orig_prefix.lower()
2619     if remove_u_prefix:
2620         new_prefix = new_prefix.replace("u", "")
2621     leaf.value = f"{new_prefix}{match.group(2)}"
2622
2623
2624 def normalize_string_quotes(leaf: Leaf) -> None:
2625     """Prefer double quotes but only if it doesn't cause more escaping.
2626
2627     Adds or removes backslashes as appropriate. Doesn't parse and fix
2628     strings nested in f-strings (yet).
2629
2630     Note: Mutates its argument.
2631     """
2632     value = leaf.value.lstrip("furbFURB")
2633     if value[:3] == '"""':
2634         return
2635
2636     elif value[:3] == "'''":
2637         orig_quote = "'''"
2638         new_quote = '"""'
2639     elif value[0] == '"':
2640         orig_quote = '"'
2641         new_quote = "'"
2642     else:
2643         orig_quote = "'"
2644         new_quote = '"'
2645     first_quote_pos = leaf.value.find(orig_quote)
2646     if first_quote_pos == -1:
2647         return  # There's an internal error
2648
2649     prefix = leaf.value[:first_quote_pos]
2650     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2651     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2652     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2653     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2654     if "r" in prefix.casefold():
2655         if unescaped_new_quote.search(body):
2656             # There's at least one unescaped new_quote in this raw string
2657             # so converting is impossible
2658             return
2659
2660         # Do not introduce or remove backslashes in raw strings
2661         new_body = body
2662     else:
2663         # remove unnecessary escapes
2664         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2665         if body != new_body:
2666             # Consider the string without unnecessary escapes as the original
2667             body = new_body
2668             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2669         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2670         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2671     if "f" in prefix.casefold():
2672         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2673         for m in matches:
2674             if "\\" in str(m):
2675                 # Do not introduce backslashes in interpolated expressions
2676                 return
2677     if new_quote == '"""' and new_body[-1:] == '"':
2678         # edge case:
2679         new_body = new_body[:-1] + '\\"'
2680     orig_escape_count = body.count("\\")
2681     new_escape_count = new_body.count("\\")
2682     if new_escape_count > orig_escape_count:
2683         return  # Do not introduce more escaping
2684
2685     if new_escape_count == orig_escape_count and orig_quote == '"':
2686         return  # Prefer double quotes
2687
2688     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2689
2690
2691 def normalize_numeric_literal(leaf: Leaf) -> None:
2692     """Normalizes numeric (float, int, and complex) literals.
2693
2694     All letters used in the representation are normalized to lowercase (except
2695     in Python 2 long literals).
2696     """
2697     text = leaf.value.lower()
2698     if text.startswith(("0o", "0b")):
2699         # Leave octal and binary literals alone.
2700         pass
2701     elif text.startswith("0x"):
2702         # Change hex literals to upper case.
2703         before, after = text[:2], text[2:]
2704         text = f"{before}{after.upper()}"
2705     elif "e" in text:
2706         before, after = text.split("e")
2707         sign = ""
2708         if after.startswith("-"):
2709             after = after[1:]
2710             sign = "-"
2711         elif after.startswith("+"):
2712             after = after[1:]
2713         before = format_float_or_int_string(before)
2714         text = f"{before}e{sign}{after}"
2715     elif text.endswith(("j", "l")):
2716         number = text[:-1]
2717         suffix = text[-1]
2718         # Capitalize in "2L" because "l" looks too similar to "1".
2719         if suffix == "l":
2720             suffix = "L"
2721         text = f"{format_float_or_int_string(number)}{suffix}"
2722     else:
2723         text = format_float_or_int_string(text)
2724     leaf.value = text
2725
2726
2727 def format_float_or_int_string(text: str) -> str:
2728     """Formats a float string like "1.0"."""
2729     if "." not in text:
2730         return text
2731
2732     before, after = text.split(".")
2733     return f"{before or 0}.{after or 0}"
2734
2735
2736 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2737     """Make existing optional parentheses invisible or create new ones.
2738
2739     `parens_after` is a set of string leaf values immeditely after which parens
2740     should be put.
2741
2742     Standardizes on visible parentheses for single-element tuples, and keeps
2743     existing visible parentheses for other tuples and generator expressions.
2744     """
2745     for pc in list_comments(node.prefix, is_endmarker=False):
2746         if pc.value in FMT_OFF:
2747             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2748             return
2749
2750     check_lpar = False
2751     for index, child in enumerate(list(node.children)):
2752         # Add parentheses around long tuple unpacking in assignments.
2753         if (
2754             index == 0
2755             and isinstance(child, Node)
2756             and child.type == syms.testlist_star_expr
2757         ):
2758             check_lpar = True
2759
2760         if check_lpar:
2761             if child.type == syms.atom:
2762                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2763                     lpar = Leaf(token.LPAR, "")
2764                     rpar = Leaf(token.RPAR, "")
2765                     index = child.remove() or 0
2766                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2767             elif is_one_tuple(child):
2768                 # wrap child in visible parentheses
2769                 lpar = Leaf(token.LPAR, "(")
2770                 rpar = Leaf(token.RPAR, ")")
2771                 child.remove()
2772                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2773             elif node.type == syms.import_from:
2774                 # "import from" nodes store parentheses directly as part of
2775                 # the statement
2776                 if child.type == token.LPAR:
2777                     # make parentheses invisible
2778                     child.value = ""  # type: ignore
2779                     node.children[-1].value = ""  # type: ignore
2780                 elif child.type != token.STAR:
2781                     # insert invisible parentheses
2782                     node.insert_child(index, Leaf(token.LPAR, ""))
2783                     node.append_child(Leaf(token.RPAR, ""))
2784                 break
2785
2786             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2787                 # wrap child in invisible parentheses
2788                 lpar = Leaf(token.LPAR, "")
2789                 rpar = Leaf(token.RPAR, "")
2790                 index = child.remove() or 0
2791                 prefix = child.prefix
2792                 child.prefix = ""
2793                 new_child = Node(syms.atom, [lpar, child, rpar])
2794                 new_child.prefix = prefix
2795                 node.insert_child(index, new_child)
2796
2797         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2798
2799
2800 def normalize_fmt_off(node: Node) -> None:
2801     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2802     try_again = True
2803     while try_again:
2804         try_again = convert_one_fmt_off_pair(node)
2805
2806
2807 def convert_one_fmt_off_pair(node: Node) -> bool:
2808     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2809
2810     Returns True if a pair was converted.
2811     """
2812     for leaf in node.leaves():
2813         previous_consumed = 0
2814         for comment in list_comments(leaf.prefix, is_endmarker=False):
2815             if comment.value in FMT_OFF:
2816                 # We only want standalone comments. If there's no previous leaf or
2817                 # the previous leaf is indentation, it's a standalone comment in
2818                 # disguise.
2819                 if comment.type != STANDALONE_COMMENT:
2820                     prev = preceding_leaf(leaf)
2821                     if prev and prev.type not in WHITESPACE:
2822                         continue
2823
2824                 ignored_nodes = list(generate_ignored_nodes(leaf))
2825                 if not ignored_nodes:
2826                     continue
2827
2828                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2829                 parent = first.parent
2830                 prefix = first.prefix
2831                 first.prefix = prefix[comment.consumed :]
2832                 hidden_value = (
2833                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2834                 )
2835                 if hidden_value.endswith("\n"):
2836                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2837                     # leaf (possibly followed by a DEDENT).
2838                     hidden_value = hidden_value[:-1]
2839                 first_idx = None
2840                 for ignored in ignored_nodes:
2841                     index = ignored.remove()
2842                     if first_idx is None:
2843                         first_idx = index
2844                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2845                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2846                 parent.insert_child(
2847                     first_idx,
2848                     Leaf(
2849                         STANDALONE_COMMENT,
2850                         hidden_value,
2851                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2852                     ),
2853                 )
2854                 return True
2855
2856             previous_consumed = comment.consumed
2857
2858     return False
2859
2860
2861 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2862     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2863
2864     Stops at the end of the block.
2865     """
2866     container: Optional[LN] = container_of(leaf)
2867     while container is not None and container.type != token.ENDMARKER:
2868         for comment in list_comments(container.prefix, is_endmarker=False):
2869             if comment.value in FMT_ON:
2870                 return
2871
2872         yield container
2873
2874         container = container.next_sibling
2875
2876
2877 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2878     """If it's safe, make the parens in the atom `node` invisible, recursively.
2879
2880     Returns whether the node should itself be wrapped in invisible parentheses.
2881
2882     """
2883     if (
2884         node.type != syms.atom
2885         or is_empty_tuple(node)
2886         or is_one_tuple(node)
2887         or (is_yield(node) and parent.type != syms.expr_stmt)
2888         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2889     ):
2890         return False
2891
2892     first = node.children[0]
2893     last = node.children[-1]
2894     if first.type == token.LPAR and last.type == token.RPAR:
2895         # make parentheses invisible
2896         first.value = ""  # type: ignore
2897         last.value = ""  # type: ignore
2898         if len(node.children) > 1:
2899             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2900         return False
2901
2902     return True
2903
2904
2905 def is_empty_tuple(node: LN) -> bool:
2906     """Return True if `node` holds an empty tuple."""
2907     return (
2908         node.type == syms.atom
2909         and len(node.children) == 2
2910         and node.children[0].type == token.LPAR
2911         and node.children[1].type == token.RPAR
2912     )
2913
2914
2915 def is_one_tuple(node: LN) -> bool:
2916     """Return True if `node` holds a tuple with one element, with or without parens."""
2917     if node.type == syms.atom:
2918         if len(node.children) != 3:
2919             return False
2920
2921         lpar, gexp, rpar = node.children
2922         if not (
2923             lpar.type == token.LPAR
2924             and gexp.type == syms.testlist_gexp
2925             and rpar.type == token.RPAR
2926         ):
2927             return False
2928
2929         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2930
2931     return (
2932         node.type in IMPLICIT_TUPLE
2933         and len(node.children) == 2
2934         and node.children[1].type == token.COMMA
2935     )
2936
2937
2938 def is_yield(node: LN) -> bool:
2939     """Return True if `node` holds a `yield` or `yield from` expression."""
2940     if node.type == syms.yield_expr:
2941         return True
2942
2943     if node.type == token.NAME and node.value == "yield":  # type: ignore
2944         return True
2945
2946     if node.type != syms.atom:
2947         return False
2948
2949     if len(node.children) != 3:
2950         return False
2951
2952     lpar, expr, rpar = node.children
2953     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2954         return is_yield(expr)
2955
2956     return False
2957
2958
2959 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2960     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2961
2962     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2963     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2964     extended iterable unpacking (PEP 3132) and additional unpacking
2965     generalizations (PEP 448).
2966     """
2967     if leaf.type not in STARS or not leaf.parent:
2968         return False
2969
2970     p = leaf.parent
2971     if p.type == syms.star_expr:
2972         # Star expressions are also used as assignment targets in extended
2973         # iterable unpacking (PEP 3132).  See what its parent is instead.
2974         if not p.parent:
2975             return False
2976
2977         p = p.parent
2978
2979     return p.type in within
2980
2981
2982 def is_multiline_string(leaf: Leaf) -> bool:
2983     """Return True if `leaf` is a multiline string that actually spans many lines."""
2984     value = leaf.value.lstrip("furbFURB")
2985     return value[:3] in {'"""', "'''"} and "\n" in value
2986
2987
2988 def is_stub_suite(node: Node) -> bool:
2989     """Return True if `node` is a suite with a stub body."""
2990     if (
2991         len(node.children) != 4
2992         or node.children[0].type != token.NEWLINE
2993         or node.children[1].type != token.INDENT
2994         or node.children[3].type != token.DEDENT
2995     ):
2996         return False
2997
2998     return is_stub_body(node.children[2])
2999
3000
3001 def is_stub_body(node: LN) -> bool:
3002     """Return True if `node` is a simple statement containing an ellipsis."""
3003     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3004         return False
3005
3006     if len(node.children) != 2:
3007         return False
3008
3009     child = node.children[0]
3010     return (
3011         child.type == syms.atom
3012         and len(child.children) == 3
3013         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3014     )
3015
3016
3017 def max_delimiter_priority_in_atom(node: LN) -> int:
3018     """Return maximum delimiter priority inside `node`.
3019
3020     This is specific to atoms with contents contained in a pair of parentheses.
3021     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3022     """
3023     if node.type != syms.atom:
3024         return 0
3025
3026     first = node.children[0]
3027     last = node.children[-1]
3028     if not (first.type == token.LPAR and last.type == token.RPAR):
3029         return 0
3030
3031     bt = BracketTracker()
3032     for c in node.children[1:-1]:
3033         if isinstance(c, Leaf):
3034             bt.mark(c)
3035         else:
3036             for leaf in c.leaves():
3037                 bt.mark(leaf)
3038     try:
3039         return bt.max_delimiter_priority()
3040
3041     except ValueError:
3042         return 0
3043
3044
3045 def ensure_visible(leaf: Leaf) -> None:
3046     """Make sure parentheses are visible.
3047
3048     They could be invisible as part of some statements (see
3049     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3050     """
3051     if leaf.type == token.LPAR:
3052         leaf.value = "("
3053     elif leaf.type == token.RPAR:
3054         leaf.value = ")"
3055
3056
3057 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3058     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3059
3060     if not (
3061         opening_bracket.parent
3062         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3063         and opening_bracket.value in "[{("
3064     ):
3065         return False
3066
3067     try:
3068         last_leaf = line.leaves[-1]
3069         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3070         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3071     except (IndexError, ValueError):
3072         return False
3073
3074     return max_priority == COMMA_PRIORITY
3075
3076
3077 def get_features_used(node: Node) -> Set[Feature]:
3078     """Return a set of (relatively) new Python features used in this file.
3079
3080     Currently looking for:
3081     - f-strings;
3082     - underscores in numeric literals; and
3083     - trailing commas after * or ** in function signatures and calls.
3084     """
3085     features: Set[Feature] = set()
3086     for n in node.pre_order():
3087         if n.type == token.STRING:
3088             value_head = n.value[:2]  # type: ignore
3089             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3090                 features.add(Feature.F_STRINGS)
3091
3092         elif n.type == token.NUMBER:
3093             if "_" in n.value:  # type: ignore
3094                 features.add(Feature.NUMERIC_UNDERSCORES)
3095
3096         elif (
3097             n.type in {syms.typedargslist, syms.arglist}
3098             and n.children
3099             and n.children[-1].type == token.COMMA
3100         ):
3101             if n.type == syms.typedargslist:
3102                 feature = Feature.TRAILING_COMMA_IN_DEF
3103             else:
3104                 feature = Feature.TRAILING_COMMA_IN_CALL
3105
3106             for ch in n.children:
3107                 if ch.type in STARS:
3108                     features.add(feature)
3109
3110                 if ch.type == syms.argument:
3111                     for argch in ch.children:
3112                         if argch.type in STARS:
3113                             features.add(feature)
3114
3115     return features
3116
3117
3118 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3119     """Detect the version to target based on the nodes used."""
3120     features = get_features_used(node)
3121     return {
3122         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3123     }
3124
3125
3126 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3127     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3128
3129     Brackets can be omitted if the entire trailer up to and including
3130     a preceding closing bracket fits in one line.
3131
3132     Yielded sets are cumulative (contain results of previous yields, too).  First
3133     set is empty.
3134     """
3135
3136     omit: Set[LeafID] = set()
3137     yield omit
3138
3139     length = 4 * line.depth
3140     opening_bracket = None
3141     closing_bracket = None
3142     inner_brackets: Set[LeafID] = set()
3143     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3144         length += leaf_length
3145         if length > line_length:
3146             break
3147
3148         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3149         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3150             break
3151
3152         if opening_bracket:
3153             if leaf is opening_bracket:
3154                 opening_bracket = None
3155             elif leaf.type in CLOSING_BRACKETS:
3156                 inner_brackets.add(id(leaf))
3157         elif leaf.type in CLOSING_BRACKETS:
3158             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3159                 # Empty brackets would fail a split so treat them as "inner"
3160                 # brackets (e.g. only add them to the `omit` set if another
3161                 # pair of brackets was good enough.
3162                 inner_brackets.add(id(leaf))
3163                 continue
3164
3165             if closing_bracket:
3166                 omit.add(id(closing_bracket))
3167                 omit.update(inner_brackets)
3168                 inner_brackets.clear()
3169                 yield omit
3170
3171             if leaf.value:
3172                 opening_bracket = leaf.opening_bracket
3173                 closing_bracket = leaf
3174
3175
3176 def get_future_imports(node: Node) -> Set[str]:
3177     """Return a set of __future__ imports in the file."""
3178     imports: Set[str] = set()
3179
3180     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3181         for child in children:
3182             if isinstance(child, Leaf):
3183                 if child.type == token.NAME:
3184                     yield child.value
3185             elif child.type == syms.import_as_name:
3186                 orig_name = child.children[0]
3187                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3188                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3189                 yield orig_name.value
3190             elif child.type == syms.import_as_names:
3191                 yield from get_imports_from_children(child.children)
3192             else:
3193                 raise AssertionError("Invalid syntax parsing imports")
3194
3195     for child in node.children:
3196         if child.type != syms.simple_stmt:
3197             break
3198         first_child = child.children[0]
3199         if isinstance(first_child, Leaf):
3200             # Continue looking if we see a docstring; otherwise stop.
3201             if (
3202                 len(child.children) == 2
3203                 and first_child.type == token.STRING
3204                 and child.children[1].type == token.NEWLINE
3205             ):
3206                 continue
3207             else:
3208                 break
3209         elif first_child.type == syms.import_from:
3210             module_name = first_child.children[1]
3211             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3212                 break
3213             imports |= set(get_imports_from_children(first_child.children[3:]))
3214         else:
3215             break
3216     return imports
3217
3218
3219 def gen_python_files_in_dir(
3220     path: Path,
3221     root: Path,
3222     include: Pattern[str],
3223     exclude: Pattern[str],
3224     report: "Report",
3225 ) -> Iterator[Path]:
3226     """Generate all files under `path` whose paths are not excluded by the
3227     `exclude` regex, but are included by the `include` regex.
3228
3229     Symbolic links pointing outside of the `root` directory are ignored.
3230
3231     `report` is where output about exclusions goes.
3232     """
3233     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3234     for child in path.iterdir():
3235         try:
3236             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3237         except ValueError:
3238             if child.is_symlink():
3239                 report.path_ignored(
3240                     child, f"is a symbolic link that points outside {root}"
3241                 )
3242                 continue
3243
3244             raise
3245
3246         if child.is_dir():
3247             normalized_path += "/"
3248         exclude_match = exclude.search(normalized_path)
3249         if exclude_match and exclude_match.group(0):
3250             report.path_ignored(child, f"matches the --exclude regular expression")
3251             continue
3252
3253         if child.is_dir():
3254             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3255
3256         elif child.is_file():
3257             include_match = include.search(normalized_path)
3258             if include_match:
3259                 yield child
3260
3261
3262 @lru_cache()
3263 def find_project_root(srcs: Iterable[str]) -> Path:
3264     """Return a directory containing .git, .hg, or pyproject.toml.
3265
3266     That directory can be one of the directories passed in `srcs` or their
3267     common parent.
3268
3269     If no directory in the tree contains a marker that would specify it's the
3270     project root, the root of the file system is returned.
3271     """
3272     if not srcs:
3273         return Path("/").resolve()
3274
3275     common_base = min(Path(src).resolve() for src in srcs)
3276     if common_base.is_dir():
3277         # Append a fake file so `parents` below returns `common_base_dir`, too.
3278         common_base /= "fake-file"
3279     for directory in common_base.parents:
3280         if (directory / ".git").is_dir():
3281             return directory
3282
3283         if (directory / ".hg").is_dir():
3284             return directory
3285
3286         if (directory / "pyproject.toml").is_file():
3287             return directory
3288
3289     return directory
3290
3291
3292 @dataclass
3293 class Report:
3294     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3295
3296     check: bool = False
3297     quiet: bool = False
3298     verbose: bool = False
3299     change_count: int = 0
3300     same_count: int = 0
3301     failure_count: int = 0
3302
3303     def done(self, src: Path, changed: Changed) -> None:
3304         """Increment the counter for successful reformatting. Write out a message."""
3305         if changed is Changed.YES:
3306             reformatted = "would reformat" if self.check else "reformatted"
3307             if self.verbose or not self.quiet:
3308                 out(f"{reformatted} {src}")
3309             self.change_count += 1
3310         else:
3311             if self.verbose:
3312                 if changed is Changed.NO:
3313                     msg = f"{src} already well formatted, good job."
3314                 else:
3315                     msg = f"{src} wasn't modified on disk since last run."
3316                 out(msg, bold=False)
3317             self.same_count += 1
3318
3319     def failed(self, src: Path, message: str) -> None:
3320         """Increment the counter for failed reformatting. Write out a message."""
3321         err(f"error: cannot format {src}: {message}")
3322         self.failure_count += 1
3323
3324     def path_ignored(self, path: Path, message: str) -> None:
3325         if self.verbose:
3326             out(f"{path} ignored: {message}", bold=False)
3327
3328     @property
3329     def return_code(self) -> int:
3330         """Return the exit code that the app should use.
3331
3332         This considers the current state of changed files and failures:
3333         - if there were any failures, return 123;
3334         - if any files were changed and --check is being used, return 1;
3335         - otherwise return 0.
3336         """
3337         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3338         # 126 we have special return codes reserved by the shell.
3339         if self.failure_count:
3340             return 123
3341
3342         elif self.change_count and self.check:
3343             return 1
3344
3345         return 0
3346
3347     def __str__(self) -> str:
3348         """Render a color report of the current state.
3349
3350         Use `click.unstyle` to remove colors.
3351         """
3352         if self.check:
3353             reformatted = "would be reformatted"
3354             unchanged = "would be left unchanged"
3355             failed = "would fail to reformat"
3356         else:
3357             reformatted = "reformatted"
3358             unchanged = "left unchanged"
3359             failed = "failed to reformat"
3360         report = []
3361         if self.change_count:
3362             s = "s" if self.change_count > 1 else ""
3363             report.append(
3364                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3365             )
3366         if self.same_count:
3367             s = "s" if self.same_count > 1 else ""
3368             report.append(f"{self.same_count} file{s} {unchanged}")
3369         if self.failure_count:
3370             s = "s" if self.failure_count > 1 else ""
3371             report.append(
3372                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3373             )
3374         return ", ".join(report) + "."
3375
3376
3377 def assert_equivalent(src: str, dst: str) -> None:
3378     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3379
3380     import ast
3381     import traceback
3382
3383     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3384         """Simple visitor generating strings to compare ASTs by content."""
3385         yield f"{'  ' * depth}{node.__class__.__name__}("
3386
3387         for field in sorted(node._fields):
3388             try:
3389                 value = getattr(node, field)
3390             except AttributeError:
3391                 continue
3392
3393             yield f"{'  ' * (depth+1)}{field}="
3394
3395             if isinstance(value, list):
3396                 for item in value:
3397                     # Ignore nested tuples within del statements, because we may insert
3398                     # parentheses and they change the AST.
3399                     if (
3400                         field == "targets"
3401                         and isinstance(node, ast.Delete)
3402                         and isinstance(item, ast.Tuple)
3403                     ):
3404                         for item in item.elts:
3405                             yield from _v(item, depth + 2)
3406                     elif isinstance(item, ast.AST):
3407                         yield from _v(item, depth + 2)
3408
3409             elif isinstance(value, ast.AST):
3410                 yield from _v(value, depth + 2)
3411
3412             else:
3413                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3414
3415         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3416
3417     try:
3418         src_ast = ast.parse(src)
3419     except Exception as exc:
3420         major, minor = sys.version_info[:2]
3421         raise AssertionError(
3422             f"cannot use --safe with this file; failed to parse source file "
3423             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3424             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3425         )
3426
3427     try:
3428         dst_ast = ast.parse(dst)
3429     except Exception as exc:
3430         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3431         raise AssertionError(
3432             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3433             f"Please report a bug on https://github.com/python/black/issues.  "
3434             f"This invalid output might be helpful: {log}"
3435         ) from None
3436
3437     src_ast_str = "\n".join(_v(src_ast))
3438     dst_ast_str = "\n".join(_v(dst_ast))
3439     if src_ast_str != dst_ast_str:
3440         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3441         raise AssertionError(
3442             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3443             f"the source.  "
3444             f"Please report a bug on https://github.com/python/black/issues.  "
3445             f"This diff might be helpful: {log}"
3446         ) from None
3447
3448
3449 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3450     """Raise AssertionError if `dst` reformats differently the second time."""
3451     newdst = format_str(dst, mode=mode)
3452     if dst != newdst:
3453         log = dump_to_file(
3454             diff(src, dst, "source", "first pass"),
3455             diff(dst, newdst, "first pass", "second pass"),
3456         )
3457         raise AssertionError(
3458             f"INTERNAL ERROR: Black produced different code on the second pass "
3459             f"of the formatter.  "
3460             f"Please report a bug on https://github.com/python/black/issues.  "
3461             f"This diff might be helpful: {log}"
3462         ) from None
3463
3464
3465 def dump_to_file(*output: str) -> str:
3466     """Dump `output` to a temporary file. Return path to the file."""
3467     import tempfile
3468
3469     with tempfile.NamedTemporaryFile(
3470         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3471     ) as f:
3472         for lines in output:
3473             f.write(lines)
3474             if lines and lines[-1] != "\n":
3475                 f.write("\n")
3476     return f.name
3477
3478
3479 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3480     """Return a unified diff string between strings `a` and `b`."""
3481     import difflib
3482
3483     a_lines = [line + "\n" for line in a.split("\n")]
3484     b_lines = [line + "\n" for line in b.split("\n")]
3485     return "".join(
3486         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3487     )
3488
3489
3490 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3491     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3492     err("Aborted!")
3493     for task in tasks:
3494         task.cancel()
3495
3496
3497 def shutdown(loop: BaseEventLoop) -> None:
3498     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3499     try:
3500         if sys.version_info[:2] >= (3, 7):
3501             all_tasks = asyncio.all_tasks
3502         else:
3503             all_tasks = asyncio.Task.all_tasks
3504         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3505         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3506         if not to_cancel:
3507             return
3508
3509         for task in to_cancel:
3510             task.cancel()
3511         loop.run_until_complete(
3512             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3513         )
3514     finally:
3515         # `concurrent.futures.Future` objects cannot be cancelled once they
3516         # are already running. There might be some when the `shutdown()` happened.
3517         # Silence their logger's spew about the event loop being closed.
3518         cf_logger = logging.getLogger("concurrent.futures")
3519         cf_logger.setLevel(logging.CRITICAL)
3520         loop.close()
3521
3522
3523 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3524     """Replace `regex` with `replacement` twice on `original`.
3525
3526     This is used by string normalization to perform replaces on
3527     overlapping matches.
3528     """
3529     return regex.sub(replacement, regex.sub(replacement, original))
3530
3531
3532 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3533     """Compile a regular expression string in `regex`.
3534
3535     If it contains newlines, use verbose mode.
3536     """
3537     if "\n" in regex:
3538         regex = "(?x)" + regex
3539     return re.compile(regex)
3540
3541
3542 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3543     """Like `reversed(enumerate(sequence))` if that were possible."""
3544     index = len(sequence) - 1
3545     for element in reversed(sequence):
3546         yield (index, element)
3547         index -= 1
3548
3549
3550 def enumerate_with_length(
3551     line: Line, reversed: bool = False
3552 ) -> Iterator[Tuple[Index, Leaf, int]]:
3553     """Return an enumeration of leaves with their length.
3554
3555     Stops prematurely on multiline strings and standalone comments.
3556     """
3557     op = cast(
3558         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3559         enumerate_reversed if reversed else enumerate,
3560     )
3561     for index, leaf in op(line.leaves):
3562         length = len(leaf.prefix) + len(leaf.value)
3563         if "\n" in leaf.value:
3564             return  # Multiline strings, we can't continue.
3565
3566         comment: Optional[Leaf]
3567         for comment in line.comments_after(leaf):
3568             length += len(comment.value)
3569
3570         yield index, leaf, length
3571
3572
3573 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3574     """Return True if `line` is no longer than `line_length`.
3575
3576     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3577     """
3578     if not line_str:
3579         line_str = str(line).strip("\n")
3580     return (
3581         len(line_str) <= line_length
3582         and "\n" not in line_str  # multiline strings
3583         and not line.contains_standalone_comments()
3584     )
3585
3586
3587 def can_be_split(line: Line) -> bool:
3588     """Return False if the line cannot be split *for sure*.
3589
3590     This is not an exhaustive search but a cheap heuristic that we can use to
3591     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3592     in unnecessary parentheses).
3593     """
3594     leaves = line.leaves
3595     if len(leaves) < 2:
3596         return False
3597
3598     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3599         call_count = 0
3600         dot_count = 0
3601         next = leaves[-1]
3602         for leaf in leaves[-2::-1]:
3603             if leaf.type in OPENING_BRACKETS:
3604                 if next.type not in CLOSING_BRACKETS:
3605                     return False
3606
3607                 call_count += 1
3608             elif leaf.type == token.DOT:
3609                 dot_count += 1
3610             elif leaf.type == token.NAME:
3611                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3612                     return False
3613
3614             elif leaf.type not in CLOSING_BRACKETS:
3615                 return False
3616
3617             if dot_count > 1 and call_count > 1:
3618                 return False
3619
3620     return True
3621
3622
3623 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3624     """Does `line` have a shape safe to reformat without optional parens around it?
3625
3626     Returns True for only a subset of potentially nice looking formattings but
3627     the point is to not return false positives that end up producing lines that
3628     are too long.
3629     """
3630     bt = line.bracket_tracker
3631     if not bt.delimiters:
3632         # Without delimiters the optional parentheses are useless.
3633         return True
3634
3635     max_priority = bt.max_delimiter_priority()
3636     if bt.delimiter_count_with_priority(max_priority) > 1:
3637         # With more than one delimiter of a kind the optional parentheses read better.
3638         return False
3639
3640     if max_priority == DOT_PRIORITY:
3641         # A single stranded method call doesn't require optional parentheses.
3642         return True
3643
3644     assert len(line.leaves) >= 2, "Stranded delimiter"
3645
3646     first = line.leaves[0]
3647     second = line.leaves[1]
3648     penultimate = line.leaves[-2]
3649     last = line.leaves[-1]
3650
3651     # With a single delimiter, omit if the expression starts or ends with
3652     # a bracket.
3653     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3654         remainder = False
3655         length = 4 * line.depth
3656         for _index, leaf, leaf_length in enumerate_with_length(line):
3657             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3658                 remainder = True
3659             if remainder:
3660                 length += leaf_length
3661                 if length > line_length:
3662                     break
3663
3664                 if leaf.type in OPENING_BRACKETS:
3665                     # There are brackets we can further split on.
3666                     remainder = False
3667
3668         else:
3669             # checked the entire string and line length wasn't exceeded
3670             if len(line.leaves) == _index + 1:
3671                 return True
3672
3673         # Note: we are not returning False here because a line might have *both*
3674         # a leading opening bracket and a trailing closing bracket.  If the
3675         # opening bracket doesn't match our rule, maybe the closing will.
3676
3677     if (
3678         last.type == token.RPAR
3679         or last.type == token.RBRACE
3680         or (
3681             # don't use indexing for omitting optional parentheses;
3682             # it looks weird
3683             last.type == token.RSQB
3684             and last.parent
3685             and last.parent.type != syms.trailer
3686         )
3687     ):
3688         if penultimate.type in OPENING_BRACKETS:
3689             # Empty brackets don't help.
3690             return False
3691
3692         if is_multiline_string(first):
3693             # Additional wrapping of a multiline string in this situation is
3694             # unnecessary.
3695             return True
3696
3697         length = 4 * line.depth
3698         seen_other_brackets = False
3699         for _index, leaf, leaf_length in enumerate_with_length(line):
3700             length += leaf_length
3701             if leaf is last.opening_bracket:
3702                 if seen_other_brackets or length <= line_length:
3703                     return True
3704
3705             elif leaf.type in OPENING_BRACKETS:
3706                 # There are brackets we can further split on.
3707                 seen_other_brackets = True
3708
3709     return False
3710
3711
3712 def get_cache_file(mode: FileMode) -> Path:
3713     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3714
3715
3716 def read_cache(mode: FileMode) -> Cache:
3717     """Read the cache if it exists and is well formed.
3718
3719     If it is not well formed, the call to write_cache later should resolve the issue.
3720     """
3721     cache_file = get_cache_file(mode)
3722     if not cache_file.exists():
3723         return {}
3724
3725     with cache_file.open("rb") as fobj:
3726         try:
3727             cache: Cache = pickle.load(fobj)
3728         except pickle.UnpicklingError:
3729             return {}
3730
3731     return cache
3732
3733
3734 def get_cache_info(path: Path) -> CacheInfo:
3735     """Return the information used to check if a file is already formatted or not."""
3736     stat = path.stat()
3737     return stat.st_mtime, stat.st_size
3738
3739
3740 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3741     """Split an iterable of paths in `sources` into two sets.
3742
3743     The first contains paths of files that modified on disk or are not in the
3744     cache. The other contains paths to non-modified files.
3745     """
3746     todo, done = set(), set()
3747     for src in sources:
3748         src = src.resolve()
3749         if cache.get(src) != get_cache_info(src):
3750             todo.add(src)
3751         else:
3752             done.add(src)
3753     return todo, done
3754
3755
3756 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3757     """Update the cache file."""
3758     cache_file = get_cache_file(mode)
3759     try:
3760         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3761         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3762         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3763             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3764         os.replace(f.name, cache_file)
3765     except OSError:
3766         pass
3767
3768
3769 def patch_click() -> None:
3770     """Make Click not crash.
3771
3772     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3773     default which restricts paths that it can access during the lifetime of the
3774     application.  Click refuses to work in this scenario by raising a RuntimeError.
3775
3776     In case of Black the likelihood that non-ASCII characters are going to be used in
3777     file paths is minimal since it's Python source code.  Moreover, this crash was
3778     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3779     """
3780     try:
3781         from click import core
3782         from click import _unicodefun  # type: ignore
3783     except ModuleNotFoundError:
3784         return
3785
3786     for module in (core, _unicodefun):
3787         if hasattr(module, "_verify_python3_env"):
3788             module._verify_python3_env = lambda: None
3789
3790
3791 def patched_main() -> None:
3792     freeze_support()
3793     patch_click()
3794     main()
3795
3796
3797 if __name__ == "__main__":
3798     patched_main()