]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

1978fd52ec91969e226908ebb566d18f422a3ab7
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "19.3b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA_IN_CALL = 4
137     TRAILING_COMMA_IN_DEF = 5
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.PY27: set(),
142     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
144     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
145     TargetVersion.PY36: {
146         Feature.UNICODE_LITERALS,
147         Feature.F_STRINGS,
148         Feature.NUMERIC_UNDERSCORES,
149         Feature.TRAILING_COMMA_IN_CALL,
150         Feature.TRAILING_COMMA_IN_DEF,
151     },
152     TargetVersion.PY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA_IN_CALL,
157         Feature.TRAILING_COMMA_IN_DEF,
158     },
159     TargetVersion.PY38: {
160         Feature.UNICODE_LITERALS,
161         Feature.F_STRINGS,
162         Feature.NUMERIC_UNDERSCORES,
163         Feature.TRAILING_COMMA_IN_CALL,
164         Feature.TRAILING_COMMA_IN_DEF,
165     },
166 }
167
168
169 @dataclass
170 class FileMode:
171     target_versions: Set[TargetVersion] = Factory(set)
172     line_length: int = DEFAULT_LINE_LENGTH
173     string_normalization: bool = True
174     is_pyi: bool = False
175
176     def get_cache_key(self) -> str:
177         if self.target_versions:
178             version_str = ",".join(
179                 str(version.value)
180                 for version in sorted(self.target_versions, key=lambda v: v.value)
181             )
182         else:
183             version_str = "-"
184         parts = [
185             version_str,
186             str(self.line_length),
187             str(int(self.string_normalization)),
188             str(int(self.is_pyi)),
189         ]
190         return ".".join(parts)
191
192
193 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
194     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
195
196
197 def read_pyproject_toml(
198     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
199 ) -> Optional[str]:
200     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
201
202     Returns the path to a successfully found and read configuration file, None
203     otherwise.
204     """
205     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
206     if not value:
207         root = find_project_root(ctx.params.get("src", ()))
208         path = root / "pyproject.toml"
209         if path.is_file():
210             value = str(path)
211         else:
212             return None
213
214     try:
215         pyproject_toml = toml.load(value)
216         config = pyproject_toml.get("tool", {}).get("black", {})
217     except (toml.TomlDecodeError, OSError) as e:
218         raise click.FileError(
219             filename=value, hint=f"Error reading configuration file: {e}"
220         )
221
222     if not config:
223         return None
224
225     if ctx.default_map is None:
226         ctx.default_map = {}
227     ctx.default_map.update(  # type: ignore  # bad types in .pyi
228         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
229     )
230     return value
231
232
233 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
234 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
235 @click.option(
236     "-l",
237     "--line-length",
238     type=int,
239     default=DEFAULT_LINE_LENGTH,
240     help="How many characters per line to allow.",
241     show_default=True,
242 )
243 @click.option(
244     "-t",
245     "--target-version",
246     type=click.Choice([v.name.lower() for v in TargetVersion]),
247     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
248     multiple=True,
249     help=(
250         "Python versions that should be supported by Black's output. [default: "
251         "per-file auto-detection]"
252     ),
253 )
254 @click.option(
255     "--py36",
256     is_flag=True,
257     help=(
258         "Allow using Python 3.6-only syntax on all input files.  This will put "
259         "trailing commas in function signatures and calls also after *args and "
260         "**kwargs. Deprecated; use --target-version instead. "
261         "[default: per-file auto-detection]"
262     ),
263 )
264 @click.option(
265     "--pyi",
266     is_flag=True,
267     help=(
268         "Format all input files like typing stubs regardless of file extension "
269         "(useful when piping source on standard input)."
270     ),
271 )
272 @click.option(
273     "-S",
274     "--skip-string-normalization",
275     is_flag=True,
276     help="Don't normalize string quotes or prefixes.",
277 )
278 @click.option(
279     "--check",
280     is_flag=True,
281     help=(
282         "Don't write the files back, just return the status.  Return code 0 "
283         "means nothing would change.  Return code 1 means some files would be "
284         "reformatted.  Return code 123 means there was an internal error."
285     ),
286 )
287 @click.option(
288     "--diff",
289     is_flag=True,
290     help="Don't write the files back, just output a diff for each file on stdout.",
291 )
292 @click.option(
293     "--fast/--safe",
294     is_flag=True,
295     help="If --fast given, skip temporary sanity checks. [default: --safe]",
296 )
297 @click.option(
298     "--include",
299     type=str,
300     default=DEFAULT_INCLUDES,
301     help=(
302         "A regular expression that matches files and directories that should be "
303         "included on recursive searches.  An empty value means all files are "
304         "included regardless of the name.  Use forward slashes for directories on "
305         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
306         "later."
307     ),
308     show_default=True,
309 )
310 @click.option(
311     "--exclude",
312     type=str,
313     default=DEFAULT_EXCLUDES,
314     help=(
315         "A regular expression that matches files and directories that should be "
316         "excluded on recursive searches.  An empty value means no paths are excluded. "
317         "Use forward slashes for directories on all platforms (Windows, too).  "
318         "Exclusions are calculated first, inclusions later."
319     ),
320     show_default=True,
321 )
322 @click.option(
323     "-q",
324     "--quiet",
325     is_flag=True,
326     help=(
327         "Don't emit non-error messages to stderr. Errors are still emitted, "
328         "silence those with 2>/dev/null."
329     ),
330 )
331 @click.option(
332     "-v",
333     "--verbose",
334     is_flag=True,
335     help=(
336         "Also emit messages to stderr about files that were not changed or were "
337         "ignored due to --exclude=."
338     ),
339 )
340 @click.version_option(version=__version__)
341 @click.argument(
342     "src",
343     nargs=-1,
344     type=click.Path(
345         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
346     ),
347     is_eager=True,
348 )
349 @click.option(
350     "--config",
351     type=click.Path(
352         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
353     ),
354     is_eager=True,
355     callback=read_pyproject_toml,
356     help="Read configuration from PATH.",
357 )
358 @click.pass_context
359 def main(
360     ctx: click.Context,
361     code: Optional[str],
362     line_length: int,
363     target_version: List[TargetVersion],
364     check: bool,
365     diff: bool,
366     fast: bool,
367     pyi: bool,
368     py36: bool,
369     skip_string_normalization: bool,
370     quiet: bool,
371     verbose: bool,
372     include: str,
373     exclude: str,
374     src: Tuple[str],
375     config: Optional[str],
376 ) -> None:
377     """The uncompromising code formatter."""
378     write_back = WriteBack.from_configuration(check=check, diff=diff)
379     if target_version:
380         if py36:
381             err(f"Cannot use both --target-version and --py36")
382             ctx.exit(2)
383         else:
384             versions = set(target_version)
385     elif py36:
386         err(
387             "--py36 is deprecated and will be removed in a future version. "
388             "Use --target-version py36 instead."
389         )
390         versions = PY36_VERSIONS
391     else:
392         # We'll autodetect later.
393         versions = set()
394     mode = FileMode(
395         target_versions=versions,
396         line_length=line_length,
397         is_pyi=pyi,
398         string_normalization=not skip_string_normalization,
399     )
400     if config and verbose:
401         out(f"Using configuration from {config}.", bold=False, fg="blue")
402     if code is not None:
403         print(format_str(code, mode=mode))
404         ctx.exit(0)
405     try:
406         include_regex = re_compile_maybe_verbose(include)
407     except re.error:
408         err(f"Invalid regular expression for include given: {include!r}")
409         ctx.exit(2)
410     try:
411         exclude_regex = re_compile_maybe_verbose(exclude)
412     except re.error:
413         err(f"Invalid regular expression for exclude given: {exclude!r}")
414         ctx.exit(2)
415     report = Report(check=check, quiet=quiet, verbose=verbose)
416     root = find_project_root(src)
417     sources: Set[Path] = set()
418     for s in src:
419         p = Path(s)
420         if p.is_dir():
421             sources.update(
422                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
423             )
424         elif p.is_file() or s == "-":
425             # if a file was explicitly given, we don't care about its extension
426             sources.add(p)
427         else:
428             err(f"invalid path: {s}")
429     if len(sources) == 0:
430         if verbose or not quiet:
431             out("No paths given. Nothing to do 😴")
432         ctx.exit(0)
433
434     if len(sources) == 1:
435         reformat_one(
436             src=sources.pop(),
437             fast=fast,
438             write_back=write_back,
439             mode=mode,
440             report=report,
441         )
442     else:
443         reformat_many(
444             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
445         )
446
447     if verbose or not quiet:
448         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
449         out(f"All done! {bang}")
450         click.secho(str(report), err=True)
451     ctx.exit(report.return_code)
452
453
454 def reformat_one(
455     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
456 ) -> None:
457     """Reformat a single file under `src` without spawning child processes.
458
459     If `quiet` is True, non-error messages are not output. `line_length`,
460     `write_back`, `fast` and `pyi` options are passed to
461     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
462     """
463     try:
464         changed = Changed.NO
465         if not src.is_file() and str(src) == "-":
466             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
467                 changed = Changed.YES
468         else:
469             cache: Cache = {}
470             if write_back != WriteBack.DIFF:
471                 cache = read_cache(mode)
472                 res_src = src.resolve()
473                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
474                     changed = Changed.CACHED
475             if changed is not Changed.CACHED and format_file_in_place(
476                 src, fast=fast, write_back=write_back, mode=mode
477             ):
478                 changed = Changed.YES
479             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
480                 write_back is WriteBack.CHECK and changed is Changed.NO
481             ):
482                 write_cache(cache, [src], mode)
483         report.done(src, changed)
484     except Exception as exc:
485         report.failed(src, str(exc))
486
487
488 def reformat_many(
489     sources: Set[Path],
490     fast: bool,
491     write_back: WriteBack,
492     mode: FileMode,
493     report: "Report",
494 ) -> None:
495     """Reformat multiple files using a ProcessPoolExecutor."""
496     loop = asyncio.get_event_loop()
497     worker_count = os.cpu_count()
498     if sys.platform == "win32":
499         # Work around https://bugs.python.org/issue26903
500         worker_count = min(worker_count, 61)
501     executor = ProcessPoolExecutor(max_workers=worker_count)
502     try:
503         loop.run_until_complete(
504             schedule_formatting(
505                 sources=sources,
506                 fast=fast,
507                 write_back=write_back,
508                 mode=mode,
509                 report=report,
510                 loop=loop,
511                 executor=executor,
512             )
513         )
514     finally:
515         shutdown(loop)
516
517
518 async def schedule_formatting(
519     sources: Set[Path],
520     fast: bool,
521     write_back: WriteBack,
522     mode: FileMode,
523     report: "Report",
524     loop: BaseEventLoop,
525     executor: Executor,
526 ) -> None:
527     """Run formatting of `sources` in parallel using the provided `executor`.
528
529     (Use ProcessPoolExecutors for actual parallelism.)
530
531     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
532     :func:`format_file_in_place`.
533     """
534     cache: Cache = {}
535     if write_back != WriteBack.DIFF:
536         cache = read_cache(mode)
537         sources, cached = filter_cached(cache, sources)
538         for src in sorted(cached):
539             report.done(src, Changed.CACHED)
540     if not sources:
541         return
542
543     cancelled = []
544     sources_to_cache = []
545     lock = None
546     if write_back == WriteBack.DIFF:
547         # For diff output, we need locks to ensure we don't interleave output
548         # from different processes.
549         manager = Manager()
550         lock = manager.Lock()
551     tasks = {
552         asyncio.ensure_future(
553             loop.run_in_executor(
554                 executor, format_file_in_place, src, fast, mode, write_back, lock
555             )
556         ): src
557         for src in sorted(sources)
558     }
559     pending: Iterable[asyncio.Future] = tasks.keys()
560     try:
561         loop.add_signal_handler(signal.SIGINT, cancel, pending)
562         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
563     except NotImplementedError:
564         # There are no good alternatives for these on Windows.
565         pass
566     while pending:
567         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
568         for task in done:
569             src = tasks.pop(task)
570             if task.cancelled():
571                 cancelled.append(task)
572             elif task.exception():
573                 report.failed(src, str(task.exception()))
574             else:
575                 changed = Changed.YES if task.result() else Changed.NO
576                 # If the file was written back or was successfully checked as
577                 # well-formatted, store this information in the cache.
578                 if write_back is WriteBack.YES or (
579                     write_back is WriteBack.CHECK and changed is Changed.NO
580                 ):
581                     sources_to_cache.append(src)
582                 report.done(src, changed)
583     if cancelled:
584         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
585     if sources_to_cache:
586         write_cache(cache, sources_to_cache, mode)
587
588
589 def format_file_in_place(
590     src: Path,
591     fast: bool,
592     mode: FileMode,
593     write_back: WriteBack = WriteBack.NO,
594     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
595 ) -> bool:
596     """Format file under `src` path. Return True if changed.
597
598     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
599     code to the file.
600     `line_length` and `fast` options are passed to :func:`format_file_contents`.
601     """
602     if src.suffix == ".pyi":
603         mode = evolve(mode, is_pyi=True)
604
605     then = datetime.utcfromtimestamp(src.stat().st_mtime)
606     with open(src, "rb") as buf:
607         src_contents, encoding, newline = decode_bytes(buf.read())
608     try:
609         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
610     except NothingChanged:
611         return False
612
613     if write_back == write_back.YES:
614         with open(src, "w", encoding=encoding, newline=newline) as f:
615             f.write(dst_contents)
616     elif write_back == write_back.DIFF:
617         now = datetime.utcnow()
618         src_name = f"{src}\t{then} +0000"
619         dst_name = f"{src}\t{now} +0000"
620         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
621         if lock:
622             lock.acquire()
623         try:
624             f = io.TextIOWrapper(
625                 sys.stdout.buffer,
626                 encoding=encoding,
627                 newline=newline,
628                 write_through=True,
629             )
630             f.write(diff_contents)
631             f.detach()
632         finally:
633             if lock:
634                 lock.release()
635     return True
636
637
638 def format_stdin_to_stdout(
639     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
640 ) -> bool:
641     """Format file on stdin. Return True if changed.
642
643     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
644     write a diff to stdout. The `mode` argument is passed to
645     :func:`format_file_contents`.
646     """
647     then = datetime.utcnow()
648     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
649     dst = src
650     try:
651         dst = format_file_contents(src, fast=fast, mode=mode)
652         return True
653
654     except NothingChanged:
655         return False
656
657     finally:
658         f = io.TextIOWrapper(
659             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
660         )
661         if write_back == WriteBack.YES:
662             f.write(dst)
663         elif write_back == WriteBack.DIFF:
664             now = datetime.utcnow()
665             src_name = f"STDIN\t{then} +0000"
666             dst_name = f"STDOUT\t{now} +0000"
667             f.write(diff(src, dst, src_name, dst_name))
668         f.detach()
669
670
671 def format_file_contents(
672     src_contents: str, *, fast: bool, mode: FileMode
673 ) -> FileContent:
674     """Reformat contents a file and return new contents.
675
676     If `fast` is False, additionally confirm that the reformatted code is
677     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
678     `line_length` is passed to :func:`format_str`.
679     """
680     if src_contents.strip() == "":
681         raise NothingChanged
682
683     dst_contents = format_str(src_contents, mode=mode)
684     if src_contents == dst_contents:
685         raise NothingChanged
686
687     if not fast:
688         assert_equivalent(src_contents, dst_contents)
689         assert_stable(src_contents, dst_contents, mode=mode)
690     return dst_contents
691
692
693 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
694     """Reformat a string and return new contents.
695
696     `line_length` determines how many characters per line are allowed.
697     """
698     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
699     dst_contents = ""
700     future_imports = get_future_imports(src_node)
701     if mode.target_versions:
702         versions = mode.target_versions
703     else:
704         versions = detect_target_versions(src_node)
705     normalize_fmt_off(src_node)
706     lines = LineGenerator(
707         remove_u_prefix="unicode_literals" in future_imports
708         or supports_feature(versions, Feature.UNICODE_LITERALS),
709         is_pyi=mode.is_pyi,
710         normalize_strings=mode.string_normalization,
711     )
712     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
713     empty_line = Line()
714     after = 0
715     split_line_features = {
716         feature
717         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
718         if supports_feature(versions, feature)
719     }
720     for current_line in lines.visit(src_node):
721         for _ in range(after):
722             dst_contents += str(empty_line)
723         before, after = elt.maybe_empty_lines(current_line)
724         for _ in range(before):
725             dst_contents += str(empty_line)
726         for line in split_line(
727             current_line, line_length=mode.line_length, features=split_line_features
728         ):
729             dst_contents += str(line)
730     return dst_contents
731
732
733 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
734     """Return a tuple of (decoded_contents, encoding, newline).
735
736     `newline` is either CRLF or LF but `decoded_contents` is decoded with
737     universal newlines (i.e. only contains LF).
738     """
739     srcbuf = io.BytesIO(src)
740     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
741     if not lines:
742         return "", encoding, "\n"
743
744     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
745     srcbuf.seek(0)
746     with io.TextIOWrapper(srcbuf, encoding) as tiow:
747         return tiow.read(), encoding, newline
748
749
750 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
751     if not target_versions:
752         # No target_version specified, so try all grammars.
753         return [
754             pygram.python_grammar_no_print_statement_no_exec_statement,
755             pygram.python_grammar_no_print_statement,
756             pygram.python_grammar,
757         ]
758     elif all(version.is_python2() for version in target_versions):
759         # Python 2-only code, so try Python 2 grammars.
760         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
761     else:
762         # Python 3-compatible code, so only try Python 3 grammar.
763         return [pygram.python_grammar_no_print_statement_no_exec_statement]
764
765
766 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
767     """Given a string with source, return the lib2to3 Node."""
768     if src_txt[-1:] != "\n":
769         src_txt += "\n"
770
771     for grammar in get_grammars(set(target_versions)):
772         drv = driver.Driver(grammar, pytree.convert)
773         try:
774             result = drv.parse_string(src_txt, True)
775             break
776
777         except ParseError as pe:
778             lineno, column = pe.context[1]
779             lines = src_txt.splitlines()
780             try:
781                 faulty_line = lines[lineno - 1]
782             except IndexError:
783                 faulty_line = "<line number missing in source>"
784             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
785     else:
786         raise exc from None
787
788     if isinstance(result, Leaf):
789         result = Node(syms.file_input, [result])
790     return result
791
792
793 def lib2to3_unparse(node: Node) -> str:
794     """Given a lib2to3 node, return its string representation."""
795     code = str(node)
796     return code
797
798
799 T = TypeVar("T")
800
801
802 class Visitor(Generic[T]):
803     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
804
805     def visit(self, node: LN) -> Iterator[T]:
806         """Main method to visit `node` and its children.
807
808         It tries to find a `visit_*()` method for the given `node.type`, like
809         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
810         If no dedicated `visit_*()` method is found, chooses `visit_default()`
811         instead.
812
813         Then yields objects of type `T` from the selected visitor.
814         """
815         if node.type < 256:
816             name = token.tok_name[node.type]
817         else:
818             name = type_repr(node.type)
819         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
820
821     def visit_default(self, node: LN) -> Iterator[T]:
822         """Default `visit_*()` implementation. Recurses to children of `node`."""
823         if isinstance(node, Node):
824             for child in node.children:
825                 yield from self.visit(child)
826
827
828 @dataclass
829 class DebugVisitor(Visitor[T]):
830     tree_depth: int = 0
831
832     def visit_default(self, node: LN) -> Iterator[T]:
833         indent = " " * (2 * self.tree_depth)
834         if isinstance(node, Node):
835             _type = type_repr(node.type)
836             out(f"{indent}{_type}", fg="yellow")
837             self.tree_depth += 1
838             for child in node.children:
839                 yield from self.visit(child)
840
841             self.tree_depth -= 1
842             out(f"{indent}/{_type}", fg="yellow", bold=False)
843         else:
844             _type = token.tok_name.get(node.type, str(node.type))
845             out(f"{indent}{_type}", fg="blue", nl=False)
846             if node.prefix:
847                 # We don't have to handle prefixes for `Node` objects since
848                 # that delegates to the first child anyway.
849                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
850             out(f" {node.value!r}", fg="blue", bold=False)
851
852     @classmethod
853     def show(cls, code: Union[str, Leaf, Node]) -> None:
854         """Pretty-print the lib2to3 AST of a given string of `code`.
855
856         Convenience method for debugging.
857         """
858         v: DebugVisitor[None] = DebugVisitor()
859         if isinstance(code, str):
860             code = lib2to3_parse(code)
861         list(v.visit(code))
862
863
864 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
865 STATEMENT = {
866     syms.if_stmt,
867     syms.while_stmt,
868     syms.for_stmt,
869     syms.try_stmt,
870     syms.except_clause,
871     syms.with_stmt,
872     syms.funcdef,
873     syms.classdef,
874 }
875 STANDALONE_COMMENT = 153
876 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
877 LOGIC_OPERATORS = {"and", "or"}
878 COMPARATORS = {
879     token.LESS,
880     token.GREATER,
881     token.EQEQUAL,
882     token.NOTEQUAL,
883     token.LESSEQUAL,
884     token.GREATEREQUAL,
885 }
886 MATH_OPERATORS = {
887     token.VBAR,
888     token.CIRCUMFLEX,
889     token.AMPER,
890     token.LEFTSHIFT,
891     token.RIGHTSHIFT,
892     token.PLUS,
893     token.MINUS,
894     token.STAR,
895     token.SLASH,
896     token.DOUBLESLASH,
897     token.PERCENT,
898     token.AT,
899     token.TILDE,
900     token.DOUBLESTAR,
901 }
902 STARS = {token.STAR, token.DOUBLESTAR}
903 VARARGS_PARENTS = {
904     syms.arglist,
905     syms.argument,  # double star in arglist
906     syms.trailer,  # single argument to call
907     syms.typedargslist,
908     syms.varargslist,  # lambdas
909 }
910 UNPACKING_PARENTS = {
911     syms.atom,  # single element of a list or set literal
912     syms.dictsetmaker,
913     syms.listmaker,
914     syms.testlist_gexp,
915     syms.testlist_star_expr,
916 }
917 TEST_DESCENDANTS = {
918     syms.test,
919     syms.lambdef,
920     syms.or_test,
921     syms.and_test,
922     syms.not_test,
923     syms.comparison,
924     syms.star_expr,
925     syms.expr,
926     syms.xor_expr,
927     syms.and_expr,
928     syms.shift_expr,
929     syms.arith_expr,
930     syms.trailer,
931     syms.term,
932     syms.power,
933 }
934 ASSIGNMENTS = {
935     "=",
936     "+=",
937     "-=",
938     "*=",
939     "@=",
940     "/=",
941     "%=",
942     "&=",
943     "|=",
944     "^=",
945     "<<=",
946     ">>=",
947     "**=",
948     "//=",
949 }
950 COMPREHENSION_PRIORITY = 20
951 COMMA_PRIORITY = 18
952 TERNARY_PRIORITY = 16
953 LOGIC_PRIORITY = 14
954 STRING_PRIORITY = 12
955 COMPARATOR_PRIORITY = 10
956 MATH_PRIORITIES = {
957     token.VBAR: 9,
958     token.CIRCUMFLEX: 8,
959     token.AMPER: 7,
960     token.LEFTSHIFT: 6,
961     token.RIGHTSHIFT: 6,
962     token.PLUS: 5,
963     token.MINUS: 5,
964     token.STAR: 4,
965     token.SLASH: 4,
966     token.DOUBLESLASH: 4,
967     token.PERCENT: 4,
968     token.AT: 4,
969     token.TILDE: 3,
970     token.DOUBLESTAR: 2,
971 }
972 DOT_PRIORITY = 1
973
974
975 @dataclass
976 class BracketTracker:
977     """Keeps track of brackets on a line."""
978
979     depth: int = 0
980     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
981     delimiters: Dict[LeafID, Priority] = Factory(dict)
982     previous: Optional[Leaf] = None
983     _for_loop_depths: List[int] = Factory(list)
984     _lambda_argument_depths: List[int] = Factory(list)
985
986     def mark(self, leaf: Leaf) -> None:
987         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
988
989         All leaves receive an int `bracket_depth` field that stores how deep
990         within brackets a given leaf is. 0 means there are no enclosing brackets
991         that started on this line.
992
993         If a leaf is itself a closing bracket, it receives an `opening_bracket`
994         field that it forms a pair with. This is a one-directional link to
995         avoid reference cycles.
996
997         If a leaf is a delimiter (a token on which Black can split the line if
998         needed) and it's on depth 0, its `id()` is stored in the tracker's
999         `delimiters` field.
1000         """
1001         if leaf.type == token.COMMENT:
1002             return
1003
1004         self.maybe_decrement_after_for_loop_variable(leaf)
1005         self.maybe_decrement_after_lambda_arguments(leaf)
1006         if leaf.type in CLOSING_BRACKETS:
1007             self.depth -= 1
1008             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1009             leaf.opening_bracket = opening_bracket
1010         leaf.bracket_depth = self.depth
1011         if self.depth == 0:
1012             delim = is_split_before_delimiter(leaf, self.previous)
1013             if delim and self.previous is not None:
1014                 self.delimiters[id(self.previous)] = delim
1015             else:
1016                 delim = is_split_after_delimiter(leaf, self.previous)
1017                 if delim:
1018                     self.delimiters[id(leaf)] = delim
1019         if leaf.type in OPENING_BRACKETS:
1020             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1021             self.depth += 1
1022         self.previous = leaf
1023         self.maybe_increment_lambda_arguments(leaf)
1024         self.maybe_increment_for_loop_variable(leaf)
1025
1026     def any_open_brackets(self) -> bool:
1027         """Return True if there is an yet unmatched open bracket on the line."""
1028         return bool(self.bracket_match)
1029
1030     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1031         """Return the highest priority of a delimiter found on the line.
1032
1033         Values are consistent with what `is_split_*_delimiter()` return.
1034         Raises ValueError on no delimiters.
1035         """
1036         return max(v for k, v in self.delimiters.items() if k not in exclude)
1037
1038     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1039         """Return the number of delimiters with the given `priority`.
1040
1041         If no `priority` is passed, defaults to max priority on the line.
1042         """
1043         if not self.delimiters:
1044             return 0
1045
1046         priority = priority or self.max_delimiter_priority()
1047         return sum(1 for p in self.delimiters.values() if p == priority)
1048
1049     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1050         """In a for loop, or comprehension, the variables are often unpacks.
1051
1052         To avoid splitting on the comma in this situation, increase the depth of
1053         tokens between `for` and `in`.
1054         """
1055         if leaf.type == token.NAME and leaf.value == "for":
1056             self.depth += 1
1057             self._for_loop_depths.append(self.depth)
1058             return True
1059
1060         return False
1061
1062     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1063         """See `maybe_increment_for_loop_variable` above for explanation."""
1064         if (
1065             self._for_loop_depths
1066             and self._for_loop_depths[-1] == self.depth
1067             and leaf.type == token.NAME
1068             and leaf.value == "in"
1069         ):
1070             self.depth -= 1
1071             self._for_loop_depths.pop()
1072             return True
1073
1074         return False
1075
1076     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1077         """In a lambda expression, there might be more than one argument.
1078
1079         To avoid splitting on the comma in this situation, increase the depth of
1080         tokens between `lambda` and `:`.
1081         """
1082         if leaf.type == token.NAME and leaf.value == "lambda":
1083             self.depth += 1
1084             self._lambda_argument_depths.append(self.depth)
1085             return True
1086
1087         return False
1088
1089     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1090         """See `maybe_increment_lambda_arguments` above for explanation."""
1091         if (
1092             self._lambda_argument_depths
1093             and self._lambda_argument_depths[-1] == self.depth
1094             and leaf.type == token.COLON
1095         ):
1096             self.depth -= 1
1097             self._lambda_argument_depths.pop()
1098             return True
1099
1100         return False
1101
1102     def get_open_lsqb(self) -> Optional[Leaf]:
1103         """Return the most recent opening square bracket (if any)."""
1104         return self.bracket_match.get((self.depth - 1, token.RSQB))
1105
1106
1107 @dataclass
1108 class Line:
1109     """Holds leaves and comments. Can be printed with `str(line)`."""
1110
1111     depth: int = 0
1112     leaves: List[Leaf] = Factory(list)
1113     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1114     bracket_tracker: BracketTracker = Factory(BracketTracker)
1115     inside_brackets: bool = False
1116     should_explode: bool = False
1117
1118     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1119         """Add a new `leaf` to the end of the line.
1120
1121         Unless `preformatted` is True, the `leaf` will receive a new consistent
1122         whitespace prefix and metadata applied by :class:`BracketTracker`.
1123         Trailing commas are maybe removed, unpacked for loop variables are
1124         demoted from being delimiters.
1125
1126         Inline comments are put aside.
1127         """
1128         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1129         if not has_value:
1130             return
1131
1132         if token.COLON == leaf.type and self.is_class_paren_empty:
1133             del self.leaves[-2:]
1134         if self.leaves and not preformatted:
1135             # Note: at this point leaf.prefix should be empty except for
1136             # imports, for which we only preserve newlines.
1137             leaf.prefix += whitespace(
1138                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1139             )
1140         if self.inside_brackets or not preformatted:
1141             self.bracket_tracker.mark(leaf)
1142             self.maybe_remove_trailing_comma(leaf)
1143         if not self.append_comment(leaf):
1144             self.leaves.append(leaf)
1145
1146     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1147         """Like :func:`append()` but disallow invalid standalone comment structure.
1148
1149         Raises ValueError when any `leaf` is appended after a standalone comment
1150         or when a standalone comment is not the first leaf on the line.
1151         """
1152         if self.bracket_tracker.depth == 0:
1153             if self.is_comment:
1154                 raise ValueError("cannot append to standalone comments")
1155
1156             if self.leaves and leaf.type == STANDALONE_COMMENT:
1157                 raise ValueError(
1158                     "cannot append standalone comments to a populated line"
1159                 )
1160
1161         self.append(leaf, preformatted=preformatted)
1162
1163     @property
1164     def is_comment(self) -> bool:
1165         """Is this line a standalone comment?"""
1166         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1167
1168     @property
1169     def is_decorator(self) -> bool:
1170         """Is this line a decorator?"""
1171         return bool(self) and self.leaves[0].type == token.AT
1172
1173     @property
1174     def is_import(self) -> bool:
1175         """Is this an import line?"""
1176         return bool(self) and is_import(self.leaves[0])
1177
1178     @property
1179     def is_class(self) -> bool:
1180         """Is this line a class definition?"""
1181         return (
1182             bool(self)
1183             and self.leaves[0].type == token.NAME
1184             and self.leaves[0].value == "class"
1185         )
1186
1187     @property
1188     def is_stub_class(self) -> bool:
1189         """Is this line a class definition with a body consisting only of "..."?"""
1190         return self.is_class and self.leaves[-3:] == [
1191             Leaf(token.DOT, ".") for _ in range(3)
1192         ]
1193
1194     @property
1195     def is_def(self) -> bool:
1196         """Is this a function definition? (Also returns True for async defs.)"""
1197         try:
1198             first_leaf = self.leaves[0]
1199         except IndexError:
1200             return False
1201
1202         try:
1203             second_leaf: Optional[Leaf] = self.leaves[1]
1204         except IndexError:
1205             second_leaf = None
1206         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1207             first_leaf.type == token.ASYNC
1208             and second_leaf is not None
1209             and second_leaf.type == token.NAME
1210             and second_leaf.value == "def"
1211         )
1212
1213     @property
1214     def is_class_paren_empty(self) -> bool:
1215         """Is this a class with no base classes but using parentheses?
1216
1217         Those are unnecessary and should be removed.
1218         """
1219         return (
1220             bool(self)
1221             and len(self.leaves) == 4
1222             and self.is_class
1223             and self.leaves[2].type == token.LPAR
1224             and self.leaves[2].value == "("
1225             and self.leaves[3].type == token.RPAR
1226             and self.leaves[3].value == ")"
1227         )
1228
1229     @property
1230     def is_triple_quoted_string(self) -> bool:
1231         """Is the line a triple quoted string?"""
1232         return (
1233             bool(self)
1234             and self.leaves[0].type == token.STRING
1235             and self.leaves[0].value.startswith(('"""', "'''"))
1236         )
1237
1238     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1239         """If so, needs to be split before emitting."""
1240         for leaf in self.leaves:
1241             if leaf.type == STANDALONE_COMMENT:
1242                 if leaf.bracket_depth <= depth_limit:
1243                     return True
1244         return False
1245
1246     def contains_inner_type_comments(self) -> bool:
1247         ignored_ids = set()
1248         try:
1249             last_leaf = self.leaves[-1]
1250             ignored_ids.add(id(last_leaf))
1251             if last_leaf.type == token.COMMA:
1252                 # When trailing commas are inserted by Black for consistency, comments
1253                 # after the previous last element are not moved (they don't have to,
1254                 # rendering will still be correct).  So we ignore trailing commas.
1255                 last_leaf = self.leaves[-2]
1256                 ignored_ids.add(id(last_leaf))
1257         except IndexError:
1258             return False
1259
1260         for leaf_id, comments in self.comments.items():
1261             if leaf_id in ignored_ids:
1262                 continue
1263
1264             for comment in comments:
1265                 if is_type_comment(comment):
1266                     return True
1267
1268         return False
1269
1270     def contains_multiline_strings(self) -> bool:
1271         for leaf in self.leaves:
1272             if is_multiline_string(leaf):
1273                 return True
1274
1275         return False
1276
1277     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1278         """Remove trailing comma if there is one and it's safe."""
1279         if not (
1280             self.leaves
1281             and self.leaves[-1].type == token.COMMA
1282             and closing.type in CLOSING_BRACKETS
1283         ):
1284             return False
1285
1286         if closing.type == token.RBRACE:
1287             self.remove_trailing_comma()
1288             return True
1289
1290         if closing.type == token.RSQB:
1291             comma = self.leaves[-1]
1292             if comma.parent and comma.parent.type == syms.listmaker:
1293                 self.remove_trailing_comma()
1294                 return True
1295
1296         # For parens let's check if it's safe to remove the comma.
1297         # Imports are always safe.
1298         if self.is_import:
1299             self.remove_trailing_comma()
1300             return True
1301
1302         # Otherwise, if the trailing one is the only one, we might mistakenly
1303         # change a tuple into a different type by removing the comma.
1304         depth = closing.bracket_depth + 1
1305         commas = 0
1306         opening = closing.opening_bracket
1307         for _opening_index, leaf in enumerate(self.leaves):
1308             if leaf is opening:
1309                 break
1310
1311         else:
1312             return False
1313
1314         for leaf in self.leaves[_opening_index + 1 :]:
1315             if leaf is closing:
1316                 break
1317
1318             bracket_depth = leaf.bracket_depth
1319             if bracket_depth == depth and leaf.type == token.COMMA:
1320                 commas += 1
1321                 if leaf.parent and leaf.parent.type == syms.arglist:
1322                     commas += 1
1323                     break
1324
1325         if commas > 1:
1326             self.remove_trailing_comma()
1327             return True
1328
1329         return False
1330
1331     def append_comment(self, comment: Leaf) -> bool:
1332         """Add an inline or standalone comment to the line."""
1333         if (
1334             comment.type == STANDALONE_COMMENT
1335             and self.bracket_tracker.any_open_brackets()
1336         ):
1337             comment.prefix = ""
1338             return False
1339
1340         if comment.type != token.COMMENT:
1341             return False
1342
1343         if not self.leaves:
1344             comment.type = STANDALONE_COMMENT
1345             comment.prefix = ""
1346             return False
1347
1348         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1349         return True
1350
1351     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1352         """Generate comments that should appear directly after `leaf`."""
1353         return self.comments.get(id(leaf), [])
1354
1355     def remove_trailing_comma(self) -> None:
1356         """Remove the trailing comma and moves the comments attached to it."""
1357         trailing_comma = self.leaves.pop()
1358         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1359         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1360             trailing_comma_comments
1361         )
1362
1363     def is_complex_subscript(self, leaf: Leaf) -> bool:
1364         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1365         open_lsqb = self.bracket_tracker.get_open_lsqb()
1366         if open_lsqb is None:
1367             return False
1368
1369         subscript_start = open_lsqb.next_sibling
1370
1371         if isinstance(subscript_start, Node):
1372             if subscript_start.type == syms.listmaker:
1373                 return False
1374
1375             if subscript_start.type == syms.subscriptlist:
1376                 subscript_start = child_towards(subscript_start, leaf)
1377         return subscript_start is not None and any(
1378             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1379         )
1380
1381     def __str__(self) -> str:
1382         """Render the line."""
1383         if not self:
1384             return "\n"
1385
1386         indent = "    " * self.depth
1387         leaves = iter(self.leaves)
1388         first = next(leaves)
1389         res = f"{first.prefix}{indent}{first.value}"
1390         for leaf in leaves:
1391             res += str(leaf)
1392         for comment in itertools.chain.from_iterable(self.comments.values()):
1393             res += str(comment)
1394         return res + "\n"
1395
1396     def __bool__(self) -> bool:
1397         """Return True if the line has leaves or comments."""
1398         return bool(self.leaves or self.comments)
1399
1400
1401 @dataclass
1402 class EmptyLineTracker:
1403     """Provides a stateful method that returns the number of potential extra
1404     empty lines needed before and after the currently processed line.
1405
1406     Note: this tracker works on lines that haven't been split yet.  It assumes
1407     the prefix of the first leaf consists of optional newlines.  Those newlines
1408     are consumed by `maybe_empty_lines()` and included in the computation.
1409     """
1410
1411     is_pyi: bool = False
1412     previous_line: Optional[Line] = None
1413     previous_after: int = 0
1414     previous_defs: List[int] = Factory(list)
1415
1416     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1417         """Return the number of extra empty lines before and after the `current_line`.
1418
1419         This is for separating `def`, `async def` and `class` with extra empty
1420         lines (two on module-level).
1421         """
1422         before, after = self._maybe_empty_lines(current_line)
1423         before -= self.previous_after
1424         self.previous_after = after
1425         self.previous_line = current_line
1426         return before, after
1427
1428     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1429         max_allowed = 1
1430         if current_line.depth == 0:
1431             max_allowed = 1 if self.is_pyi else 2
1432         if current_line.leaves:
1433             # Consume the first leaf's extra newlines.
1434             first_leaf = current_line.leaves[0]
1435             before = first_leaf.prefix.count("\n")
1436             before = min(before, max_allowed)
1437             first_leaf.prefix = ""
1438         else:
1439             before = 0
1440         depth = current_line.depth
1441         while self.previous_defs and self.previous_defs[-1] >= depth:
1442             self.previous_defs.pop()
1443             if self.is_pyi:
1444                 before = 0 if depth else 1
1445             else:
1446                 before = 1 if depth else 2
1447         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1448             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1449
1450         if (
1451             self.previous_line
1452             and self.previous_line.is_import
1453             and not current_line.is_import
1454             and depth == self.previous_line.depth
1455         ):
1456             return (before or 1), 0
1457
1458         if (
1459             self.previous_line
1460             and self.previous_line.is_class
1461             and current_line.is_triple_quoted_string
1462         ):
1463             return before, 1
1464
1465         return before, 0
1466
1467     def _maybe_empty_lines_for_class_or_def(
1468         self, current_line: Line, before: int
1469     ) -> Tuple[int, int]:
1470         if not current_line.is_decorator:
1471             self.previous_defs.append(current_line.depth)
1472         if self.previous_line is None:
1473             # Don't insert empty lines before the first line in the file.
1474             return 0, 0
1475
1476         if self.previous_line.is_decorator:
1477             return 0, 0
1478
1479         if self.previous_line.depth < current_line.depth and (
1480             self.previous_line.is_class or self.previous_line.is_def
1481         ):
1482             return 0, 0
1483
1484         if (
1485             self.previous_line.is_comment
1486             and self.previous_line.depth == current_line.depth
1487             and before == 0
1488         ):
1489             return 0, 0
1490
1491         if self.is_pyi:
1492             if self.previous_line.depth > current_line.depth:
1493                 newlines = 1
1494             elif current_line.is_class or self.previous_line.is_class:
1495                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1496                     # No blank line between classes with an empty body
1497                     newlines = 0
1498                 else:
1499                     newlines = 1
1500             elif current_line.is_def and not self.previous_line.is_def:
1501                 # Blank line between a block of functions and a block of non-functions
1502                 newlines = 1
1503             else:
1504                 newlines = 0
1505         else:
1506             newlines = 2
1507         if current_line.depth and newlines:
1508             newlines -= 1
1509         return newlines, 0
1510
1511
1512 @dataclass
1513 class LineGenerator(Visitor[Line]):
1514     """Generates reformatted Line objects.  Empty lines are not emitted.
1515
1516     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1517     in ways that will no longer stringify to valid Python code on the tree.
1518     """
1519
1520     is_pyi: bool = False
1521     normalize_strings: bool = True
1522     current_line: Line = Factory(Line)
1523     remove_u_prefix: bool = False
1524
1525     def line(self, indent: int = 0) -> Iterator[Line]:
1526         """Generate a line.
1527
1528         If the line is empty, only emit if it makes sense.
1529         If the line is too long, split it first and then generate.
1530
1531         If any lines were generated, set up a new current_line.
1532         """
1533         if not self.current_line:
1534             self.current_line.depth += indent
1535             return  # Line is empty, don't emit. Creating a new one unnecessary.
1536
1537         complete_line = self.current_line
1538         self.current_line = Line(depth=complete_line.depth + indent)
1539         yield complete_line
1540
1541     def visit_default(self, node: LN) -> Iterator[Line]:
1542         """Default `visit_*()` implementation. Recurses to children of `node`."""
1543         if isinstance(node, Leaf):
1544             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1545             for comment in generate_comments(node):
1546                 if any_open_brackets:
1547                     # any comment within brackets is subject to splitting
1548                     self.current_line.append(comment)
1549                 elif comment.type == token.COMMENT:
1550                     # regular trailing comment
1551                     self.current_line.append(comment)
1552                     yield from self.line()
1553
1554                 else:
1555                     # regular standalone comment
1556                     yield from self.line()
1557
1558                     self.current_line.append(comment)
1559                     yield from self.line()
1560
1561             normalize_prefix(node, inside_brackets=any_open_brackets)
1562             if self.normalize_strings and node.type == token.STRING:
1563                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1564                 normalize_string_quotes(node)
1565             if node.type == token.NUMBER:
1566                 normalize_numeric_literal(node)
1567             if node.type not in WHITESPACE:
1568                 self.current_line.append(node)
1569         yield from super().visit_default(node)
1570
1571     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1572         """Increase indentation level, maybe yield a line."""
1573         # In blib2to3 INDENT never holds comments.
1574         yield from self.line(+1)
1575         yield from self.visit_default(node)
1576
1577     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1578         """Decrease indentation level, maybe yield a line."""
1579         # The current line might still wait for trailing comments.  At DEDENT time
1580         # there won't be any (they would be prefixes on the preceding NEWLINE).
1581         # Emit the line then.
1582         yield from self.line()
1583
1584         # While DEDENT has no value, its prefix may contain standalone comments
1585         # that belong to the current indentation level.  Get 'em.
1586         yield from self.visit_default(node)
1587
1588         # Finally, emit the dedent.
1589         yield from self.line(-1)
1590
1591     def visit_stmt(
1592         self, node: Node, keywords: Set[str], parens: Set[str]
1593     ) -> Iterator[Line]:
1594         """Visit a statement.
1595
1596         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1597         `def`, `with`, `class`, `assert` and assignments.
1598
1599         The relevant Python language `keywords` for a given statement will be
1600         NAME leaves within it. This methods puts those on a separate line.
1601
1602         `parens` holds a set of string leaf values immediately after which
1603         invisible parens should be put.
1604         """
1605         normalize_invisible_parens(node, parens_after=parens)
1606         for child in node.children:
1607             if child.type == token.NAME and child.value in keywords:  # type: ignore
1608                 yield from self.line()
1609
1610             yield from self.visit(child)
1611
1612     def visit_suite(self, node: Node) -> Iterator[Line]:
1613         """Visit a suite."""
1614         if self.is_pyi and is_stub_suite(node):
1615             yield from self.visit(node.children[2])
1616         else:
1617             yield from self.visit_default(node)
1618
1619     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1620         """Visit a statement without nested statements."""
1621         is_suite_like = node.parent and node.parent.type in STATEMENT
1622         if is_suite_like:
1623             if self.is_pyi and is_stub_body(node):
1624                 yield from self.visit_default(node)
1625             else:
1626                 yield from self.line(+1)
1627                 yield from self.visit_default(node)
1628                 yield from self.line(-1)
1629
1630         else:
1631             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1632                 yield from self.line()
1633             yield from self.visit_default(node)
1634
1635     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1636         """Visit `async def`, `async for`, `async with`."""
1637         yield from self.line()
1638
1639         children = iter(node.children)
1640         for child in children:
1641             yield from self.visit(child)
1642
1643             if child.type == token.ASYNC:
1644                 break
1645
1646         internal_stmt = next(children)
1647         for child in internal_stmt.children:
1648             yield from self.visit(child)
1649
1650     def visit_decorators(self, node: Node) -> Iterator[Line]:
1651         """Visit decorators."""
1652         for child in node.children:
1653             yield from self.line()
1654             yield from self.visit(child)
1655
1656     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1657         """Remove a semicolon and put the other statement on a separate line."""
1658         yield from self.line()
1659
1660     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1661         """End of file. Process outstanding comments and end with a newline."""
1662         yield from self.visit_default(leaf)
1663         yield from self.line()
1664
1665     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1666         if not self.current_line.bracket_tracker.any_open_brackets():
1667             yield from self.line()
1668         yield from self.visit_default(leaf)
1669
1670     def __attrs_post_init__(self) -> None:
1671         """You are in a twisty little maze of passages."""
1672         v = self.visit_stmt
1673         Ø: Set[str] = set()
1674         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1675         self.visit_if_stmt = partial(
1676             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1677         )
1678         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1679         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1680         self.visit_try_stmt = partial(
1681             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1682         )
1683         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1684         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1685         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1686         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1687         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1688         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1689         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1690         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1691         self.visit_async_funcdef = self.visit_async_stmt
1692         self.visit_decorated = self.visit_decorators
1693
1694
1695 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1696 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1697 OPENING_BRACKETS = set(BRACKET.keys())
1698 CLOSING_BRACKETS = set(BRACKET.values())
1699 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1700 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1701
1702
1703 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1704     """Return whitespace prefix if needed for the given `leaf`.
1705
1706     `complex_subscript` signals whether the given leaf is part of a subscription
1707     which has non-trivial arguments, like arithmetic expressions or function calls.
1708     """
1709     NO = ""
1710     SPACE = " "
1711     DOUBLESPACE = "  "
1712     t = leaf.type
1713     p = leaf.parent
1714     v = leaf.value
1715     if t in ALWAYS_NO_SPACE:
1716         return NO
1717
1718     if t == token.COMMENT:
1719         return DOUBLESPACE
1720
1721     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1722     if t == token.COLON and p.type not in {
1723         syms.subscript,
1724         syms.subscriptlist,
1725         syms.sliceop,
1726     }:
1727         return NO
1728
1729     prev = leaf.prev_sibling
1730     if not prev:
1731         prevp = preceding_leaf(p)
1732         if not prevp or prevp.type in OPENING_BRACKETS:
1733             return NO
1734
1735         if t == token.COLON:
1736             if prevp.type == token.COLON:
1737                 return NO
1738
1739             elif prevp.type != token.COMMA and not complex_subscript:
1740                 return NO
1741
1742             return SPACE
1743
1744         if prevp.type == token.EQUAL:
1745             if prevp.parent:
1746                 if prevp.parent.type in {
1747                     syms.arglist,
1748                     syms.argument,
1749                     syms.parameters,
1750                     syms.varargslist,
1751                 }:
1752                     return NO
1753
1754                 elif prevp.parent.type == syms.typedargslist:
1755                     # A bit hacky: if the equal sign has whitespace, it means we
1756                     # previously found it's a typed argument.  So, we're using
1757                     # that, too.
1758                     return prevp.prefix
1759
1760         elif prevp.type in STARS:
1761             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1762                 return NO
1763
1764         elif prevp.type == token.COLON:
1765             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1766                 return SPACE if complex_subscript else NO
1767
1768         elif (
1769             prevp.parent
1770             and prevp.parent.type == syms.factor
1771             and prevp.type in MATH_OPERATORS
1772         ):
1773             return NO
1774
1775         elif (
1776             prevp.type == token.RIGHTSHIFT
1777             and prevp.parent
1778             and prevp.parent.type == syms.shift_expr
1779             and prevp.prev_sibling
1780             and prevp.prev_sibling.type == token.NAME
1781             and prevp.prev_sibling.value == "print"  # type: ignore
1782         ):
1783             # Python 2 print chevron
1784             return NO
1785
1786     elif prev.type in OPENING_BRACKETS:
1787         return NO
1788
1789     if p.type in {syms.parameters, syms.arglist}:
1790         # untyped function signatures or calls
1791         if not prev or prev.type != token.COMMA:
1792             return NO
1793
1794     elif p.type == syms.varargslist:
1795         # lambdas
1796         if prev and prev.type != token.COMMA:
1797             return NO
1798
1799     elif p.type == syms.typedargslist:
1800         # typed function signatures
1801         if not prev:
1802             return NO
1803
1804         if t == token.EQUAL:
1805             if prev.type != syms.tname:
1806                 return NO
1807
1808         elif prev.type == token.EQUAL:
1809             # A bit hacky: if the equal sign has whitespace, it means we
1810             # previously found it's a typed argument.  So, we're using that, too.
1811             return prev.prefix
1812
1813         elif prev.type != token.COMMA:
1814             return NO
1815
1816     elif p.type == syms.tname:
1817         # type names
1818         if not prev:
1819             prevp = preceding_leaf(p)
1820             if not prevp or prevp.type != token.COMMA:
1821                 return NO
1822
1823     elif p.type == syms.trailer:
1824         # attributes and calls
1825         if t == token.LPAR or t == token.RPAR:
1826             return NO
1827
1828         if not prev:
1829             if t == token.DOT:
1830                 prevp = preceding_leaf(p)
1831                 if not prevp or prevp.type != token.NUMBER:
1832                     return NO
1833
1834             elif t == token.LSQB:
1835                 return NO
1836
1837         elif prev.type != token.COMMA:
1838             return NO
1839
1840     elif p.type == syms.argument:
1841         # single argument
1842         if t == token.EQUAL:
1843             return NO
1844
1845         if not prev:
1846             prevp = preceding_leaf(p)
1847             if not prevp or prevp.type == token.LPAR:
1848                 return NO
1849
1850         elif prev.type in {token.EQUAL} | STARS:
1851             return NO
1852
1853     elif p.type == syms.decorator:
1854         # decorators
1855         return NO
1856
1857     elif p.type == syms.dotted_name:
1858         if prev:
1859             return NO
1860
1861         prevp = preceding_leaf(p)
1862         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1863             return NO
1864
1865     elif p.type == syms.classdef:
1866         if t == token.LPAR:
1867             return NO
1868
1869         if prev and prev.type == token.LPAR:
1870             return NO
1871
1872     elif p.type in {syms.subscript, syms.sliceop}:
1873         # indexing
1874         if not prev:
1875             assert p.parent is not None, "subscripts are always parented"
1876             if p.parent.type == syms.subscriptlist:
1877                 return SPACE
1878
1879             return NO
1880
1881         elif not complex_subscript:
1882             return NO
1883
1884     elif p.type == syms.atom:
1885         if prev and t == token.DOT:
1886             # dots, but not the first one.
1887             return NO
1888
1889     elif p.type == syms.dictsetmaker:
1890         # dict unpacking
1891         if prev and prev.type == token.DOUBLESTAR:
1892             return NO
1893
1894     elif p.type in {syms.factor, syms.star_expr}:
1895         # unary ops
1896         if not prev:
1897             prevp = preceding_leaf(p)
1898             if not prevp or prevp.type in OPENING_BRACKETS:
1899                 return NO
1900
1901             prevp_parent = prevp.parent
1902             assert prevp_parent is not None
1903             if prevp.type == token.COLON and prevp_parent.type in {
1904                 syms.subscript,
1905                 syms.sliceop,
1906             }:
1907                 return NO
1908
1909             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1910                 return NO
1911
1912         elif t in {token.NAME, token.NUMBER, token.STRING}:
1913             return NO
1914
1915     elif p.type == syms.import_from:
1916         if t == token.DOT:
1917             if prev and prev.type == token.DOT:
1918                 return NO
1919
1920         elif t == token.NAME:
1921             if v == "import":
1922                 return SPACE
1923
1924             if prev and prev.type == token.DOT:
1925                 return NO
1926
1927     elif p.type == syms.sliceop:
1928         return NO
1929
1930     return SPACE
1931
1932
1933 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1934     """Return the first leaf that precedes `node`, if any."""
1935     while node:
1936         res = node.prev_sibling
1937         if res:
1938             if isinstance(res, Leaf):
1939                 return res
1940
1941             try:
1942                 return list(res.leaves())[-1]
1943
1944             except IndexError:
1945                 return None
1946
1947         node = node.parent
1948     return None
1949
1950
1951 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1952     """Return the child of `ancestor` that contains `descendant`."""
1953     node: Optional[LN] = descendant
1954     while node and node.parent != ancestor:
1955         node = node.parent
1956     return node
1957
1958
1959 def container_of(leaf: Leaf) -> LN:
1960     """Return `leaf` or one of its ancestors that is the topmost container of it.
1961
1962     By "container" we mean a node where `leaf` is the very first child.
1963     """
1964     same_prefix = leaf.prefix
1965     container: LN = leaf
1966     while container:
1967         parent = container.parent
1968         if parent is None:
1969             break
1970
1971         if parent.children[0].prefix != same_prefix:
1972             break
1973
1974         if parent.type == syms.file_input:
1975             break
1976
1977         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1978             break
1979
1980         container = parent
1981     return container
1982
1983
1984 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1985     """Return the priority of the `leaf` delimiter, given a line break after it.
1986
1987     The delimiter priorities returned here are from those delimiters that would
1988     cause a line break after themselves.
1989
1990     Higher numbers are higher priority.
1991     """
1992     if leaf.type == token.COMMA:
1993         return COMMA_PRIORITY
1994
1995     return 0
1996
1997
1998 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1999     """Return the priority of the `leaf` delimiter, given a line break before it.
2000
2001     The delimiter priorities returned here are from those delimiters that would
2002     cause a line break before themselves.
2003
2004     Higher numbers are higher priority.
2005     """
2006     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2007         # * and ** might also be MATH_OPERATORS but in this case they are not.
2008         # Don't treat them as a delimiter.
2009         return 0
2010
2011     if (
2012         leaf.type == token.DOT
2013         and leaf.parent
2014         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2015         and (previous is None or previous.type in CLOSING_BRACKETS)
2016     ):
2017         return DOT_PRIORITY
2018
2019     if (
2020         leaf.type in MATH_OPERATORS
2021         and leaf.parent
2022         and leaf.parent.type not in {syms.factor, syms.star_expr}
2023     ):
2024         return MATH_PRIORITIES[leaf.type]
2025
2026     if leaf.type in COMPARATORS:
2027         return COMPARATOR_PRIORITY
2028
2029     if (
2030         leaf.type == token.STRING
2031         and previous is not None
2032         and previous.type == token.STRING
2033     ):
2034         return STRING_PRIORITY
2035
2036     if leaf.type not in {token.NAME, token.ASYNC}:
2037         return 0
2038
2039     if (
2040         leaf.value == "for"
2041         and leaf.parent
2042         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2043         or leaf.type == token.ASYNC
2044     ):
2045         if (
2046             not isinstance(leaf.prev_sibling, Leaf)
2047             or leaf.prev_sibling.value != "async"
2048         ):
2049             return COMPREHENSION_PRIORITY
2050
2051     if (
2052         leaf.value == "if"
2053         and leaf.parent
2054         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2055     ):
2056         return COMPREHENSION_PRIORITY
2057
2058     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2059         return TERNARY_PRIORITY
2060
2061     if leaf.value == "is":
2062         return COMPARATOR_PRIORITY
2063
2064     if (
2065         leaf.value == "in"
2066         and leaf.parent
2067         and leaf.parent.type in {syms.comp_op, syms.comparison}
2068         and not (
2069             previous is not None
2070             and previous.type == token.NAME
2071             and previous.value == "not"
2072         )
2073     ):
2074         return COMPARATOR_PRIORITY
2075
2076     if (
2077         leaf.value == "not"
2078         and leaf.parent
2079         and leaf.parent.type == syms.comp_op
2080         and not (
2081             previous is not None
2082             and previous.type == token.NAME
2083             and previous.value == "is"
2084         )
2085     ):
2086         return COMPARATOR_PRIORITY
2087
2088     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2089         return LOGIC_PRIORITY
2090
2091     return 0
2092
2093
2094 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2095 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2096
2097
2098 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2099     """Clean the prefix of the `leaf` and generate comments from it, if any.
2100
2101     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2102     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2103     move because it does away with modifying the grammar to include all the
2104     possible places in which comments can be placed.
2105
2106     The sad consequence for us though is that comments don't "belong" anywhere.
2107     This is why this function generates simple parentless Leaf objects for
2108     comments.  We simply don't know what the correct parent should be.
2109
2110     No matter though, we can live without this.  We really only need to
2111     differentiate between inline and standalone comments.  The latter don't
2112     share the line with any code.
2113
2114     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2115     are emitted with a fake STANDALONE_COMMENT token identifier.
2116     """
2117     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2118         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2119
2120
2121 @dataclass
2122 class ProtoComment:
2123     """Describes a piece of syntax that is a comment.
2124
2125     It's not a :class:`blib2to3.pytree.Leaf` so that:
2126
2127     * it can be cached (`Leaf` objects should not be reused more than once as
2128       they store their lineno, column, prefix, and parent information);
2129     * `newlines` and `consumed` fields are kept separate from the `value`. This
2130       simplifies handling of special marker comments like ``# fmt: off/on``.
2131     """
2132
2133     type: int  # token.COMMENT or STANDALONE_COMMENT
2134     value: str  # content of the comment
2135     newlines: int  # how many newlines before the comment
2136     consumed: int  # how many characters of the original leaf's prefix did we consume
2137
2138
2139 @lru_cache(maxsize=4096)
2140 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2141     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2142     result: List[ProtoComment] = []
2143     if not prefix or "#" not in prefix:
2144         return result
2145
2146     consumed = 0
2147     nlines = 0
2148     ignored_lines = 0
2149     for index, line in enumerate(prefix.split("\n")):
2150         consumed += len(line) + 1  # adding the length of the split '\n'
2151         line = line.lstrip()
2152         if not line:
2153             nlines += 1
2154         if not line.startswith("#"):
2155             # Escaped newlines outside of a comment are not really newlines at
2156             # all. We treat a single-line comment following an escaped newline
2157             # as a simple trailing comment.
2158             if line.endswith("\\"):
2159                 ignored_lines += 1
2160             continue
2161
2162         if index == ignored_lines and not is_endmarker:
2163             comment_type = token.COMMENT  # simple trailing comment
2164         else:
2165             comment_type = STANDALONE_COMMENT
2166         comment = make_comment(line)
2167         result.append(
2168             ProtoComment(
2169                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2170             )
2171         )
2172         nlines = 0
2173     return result
2174
2175
2176 def make_comment(content: str) -> str:
2177     """Return a consistently formatted comment from the given `content` string.
2178
2179     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2180     space between the hash sign and the content.
2181
2182     If `content` didn't start with a hash sign, one is provided.
2183     """
2184     content = content.rstrip()
2185     if not content:
2186         return "#"
2187
2188     if content[0] == "#":
2189         content = content[1:]
2190     if content and content[0] not in " !:#'%":
2191         content = " " + content
2192     return "#" + content
2193
2194
2195 def split_line(
2196     line: Line,
2197     line_length: int,
2198     inner: bool = False,
2199     features: Collection[Feature] = (),
2200 ) -> Iterator[Line]:
2201     """Split a `line` into potentially many lines.
2202
2203     They should fit in the allotted `line_length` but might not be able to.
2204     `inner` signifies that there were a pair of brackets somewhere around the
2205     current `line`, possibly transitively. This means we can fallback to splitting
2206     by delimiters if the LHS/RHS don't yield any results.
2207
2208     `features` are syntactical features that may be used in the output.
2209     """
2210     if line.is_comment:
2211         yield line
2212         return
2213
2214     line_str = str(line).strip("\n")
2215
2216     if (
2217         not line.contains_inner_type_comments()
2218         and not line.should_explode
2219         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2220     ):
2221         yield line
2222         return
2223
2224     split_funcs: List[SplitFunc]
2225     if line.is_def:
2226         split_funcs = [left_hand_split]
2227     else:
2228
2229         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2230             for omit in generate_trailers_to_omit(line, line_length):
2231                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2232                 if is_line_short_enough(lines[0], line_length=line_length):
2233                     yield from lines
2234                     return
2235
2236             # All splits failed, best effort split with no omits.
2237             # This mostly happens to multiline strings that are by definition
2238             # reported as not fitting a single line.
2239             yield from right_hand_split(line, line_length, features=features)
2240
2241         if line.inside_brackets:
2242             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2243         else:
2244             split_funcs = [rhs]
2245     for split_func in split_funcs:
2246         # We are accumulating lines in `result` because we might want to abort
2247         # mission and return the original line in the end, or attempt a different
2248         # split altogether.
2249         result: List[Line] = []
2250         try:
2251             for l in split_func(line, features):
2252                 if str(l).strip("\n") == line_str:
2253                     raise CannotSplit("Split function returned an unchanged result")
2254
2255                 result.extend(
2256                     split_line(
2257                         l, line_length=line_length, inner=True, features=features
2258                     )
2259                 )
2260         except CannotSplit:
2261             continue
2262
2263         else:
2264             yield from result
2265             break
2266
2267     else:
2268         yield line
2269
2270
2271 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2272     """Split line into many lines, starting with the first matching bracket pair.
2273
2274     Note: this usually looks weird, only use this for function definitions.
2275     Prefer RHS otherwise.  This is why this function is not symmetrical with
2276     :func:`right_hand_split` which also handles optional parentheses.
2277     """
2278     tail_leaves: List[Leaf] = []
2279     body_leaves: List[Leaf] = []
2280     head_leaves: List[Leaf] = []
2281     current_leaves = head_leaves
2282     matching_bracket = None
2283     for leaf in line.leaves:
2284         if (
2285             current_leaves is body_leaves
2286             and leaf.type in CLOSING_BRACKETS
2287             and leaf.opening_bracket is matching_bracket
2288         ):
2289             current_leaves = tail_leaves if body_leaves else head_leaves
2290         current_leaves.append(leaf)
2291         if current_leaves is head_leaves:
2292             if leaf.type in OPENING_BRACKETS:
2293                 matching_bracket = leaf
2294                 current_leaves = body_leaves
2295     if not matching_bracket:
2296         raise CannotSplit("No brackets found")
2297
2298     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2299     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2300     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2301     bracket_split_succeeded_or_raise(head, body, tail)
2302     for result in (head, body, tail):
2303         if result:
2304             yield result
2305
2306
2307 def right_hand_split(
2308     line: Line,
2309     line_length: int,
2310     features: Collection[Feature] = (),
2311     omit: Collection[LeafID] = (),
2312 ) -> Iterator[Line]:
2313     """Split line into many lines, starting with the last matching bracket pair.
2314
2315     If the split was by optional parentheses, attempt splitting without them, too.
2316     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2317     this split.
2318
2319     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2320     """
2321     tail_leaves: List[Leaf] = []
2322     body_leaves: List[Leaf] = []
2323     head_leaves: List[Leaf] = []
2324     current_leaves = tail_leaves
2325     opening_bracket = None
2326     closing_bracket = None
2327     for leaf in reversed(line.leaves):
2328         if current_leaves is body_leaves:
2329             if leaf is opening_bracket:
2330                 current_leaves = head_leaves if body_leaves else tail_leaves
2331         current_leaves.append(leaf)
2332         if current_leaves is tail_leaves:
2333             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2334                 opening_bracket = leaf.opening_bracket
2335                 closing_bracket = leaf
2336                 current_leaves = body_leaves
2337     if not (opening_bracket and closing_bracket and head_leaves):
2338         # If there is no opening or closing_bracket that means the split failed and
2339         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2340         # the matching `opening_bracket` wasn't available on `line` anymore.
2341         raise CannotSplit("No brackets found")
2342
2343     tail_leaves.reverse()
2344     body_leaves.reverse()
2345     head_leaves.reverse()
2346     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2347     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2348     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2349     bracket_split_succeeded_or_raise(head, body, tail)
2350     if (
2351         # the body shouldn't be exploded
2352         not body.should_explode
2353         # the opening bracket is an optional paren
2354         and opening_bracket.type == token.LPAR
2355         and not opening_bracket.value
2356         # the closing bracket is an optional paren
2357         and closing_bracket.type == token.RPAR
2358         and not closing_bracket.value
2359         # it's not an import (optional parens are the only thing we can split on
2360         # in this case; attempting a split without them is a waste of time)
2361         and not line.is_import
2362         # there are no standalone comments in the body
2363         and not body.contains_standalone_comments(0)
2364         # and we can actually remove the parens
2365         and can_omit_invisible_parens(body, line_length)
2366     ):
2367         omit = {id(closing_bracket), *omit}
2368         try:
2369             yield from right_hand_split(line, line_length, features=features, omit=omit)
2370             return
2371
2372         except CannotSplit:
2373             if not (
2374                 can_be_split(body)
2375                 or is_line_short_enough(body, line_length=line_length)
2376             ):
2377                 raise CannotSplit(
2378                     "Splitting failed, body is still too long and can't be split."
2379                 )
2380
2381             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2382                 raise CannotSplit(
2383                     "The current optional pair of parentheses is bound to fail to "
2384                     "satisfy the splitting algorithm because the head or the tail "
2385                     "contains multiline strings which by definition never fit one "
2386                     "line."
2387                 )
2388
2389     ensure_visible(opening_bracket)
2390     ensure_visible(closing_bracket)
2391     for result in (head, body, tail):
2392         if result:
2393             yield result
2394
2395
2396 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2397     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2398
2399     Do nothing otherwise.
2400
2401     A left- or right-hand split is based on a pair of brackets. Content before
2402     (and including) the opening bracket is left on one line, content inside the
2403     brackets is put on a separate line, and finally content starting with and
2404     following the closing bracket is put on a separate line.
2405
2406     Those are called `head`, `body`, and `tail`, respectively. If the split
2407     produced the same line (all content in `head`) or ended up with an empty `body`
2408     and the `tail` is just the closing bracket, then it's considered failed.
2409     """
2410     tail_len = len(str(tail).strip())
2411     if not body:
2412         if tail_len == 0:
2413             raise CannotSplit("Splitting brackets produced the same line")
2414
2415         elif tail_len < 3:
2416             raise CannotSplit(
2417                 f"Splitting brackets on an empty body to save "
2418                 f"{tail_len} characters is not worth it"
2419             )
2420
2421
2422 def bracket_split_build_line(
2423     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2424 ) -> Line:
2425     """Return a new line with given `leaves` and respective comments from `original`.
2426
2427     If `is_body` is True, the result line is one-indented inside brackets and as such
2428     has its first leaf's prefix normalized and a trailing comma added when expected.
2429     """
2430     result = Line(depth=original.depth)
2431     if is_body:
2432         result.inside_brackets = True
2433         result.depth += 1
2434         if leaves:
2435             # Since body is a new indent level, remove spurious leading whitespace.
2436             normalize_prefix(leaves[0], inside_brackets=True)
2437             # Ensure a trailing comma for imports, but be careful not to add one after
2438             # any comments.
2439             if original.is_import:
2440                 for i in range(len(leaves) - 1, -1, -1):
2441                     if leaves[i].type == STANDALONE_COMMENT:
2442                         continue
2443                     elif leaves[i].type == token.COMMA:
2444                         break
2445                     else:
2446                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2447                         break
2448     # Populate the line
2449     for leaf in leaves:
2450         result.append(leaf, preformatted=True)
2451         for comment_after in original.comments_after(leaf):
2452             result.append(comment_after, preformatted=True)
2453     if is_body:
2454         result.should_explode = should_explode(result, opening_bracket)
2455     return result
2456
2457
2458 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2459     """Normalize prefix of the first leaf in every line returned by `split_func`.
2460
2461     This is a decorator over relevant split functions.
2462     """
2463
2464     @wraps(split_func)
2465     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2466         for l in split_func(line, features):
2467             normalize_prefix(l.leaves[0], inside_brackets=True)
2468             yield l
2469
2470     return split_wrapper
2471
2472
2473 @dont_increase_indentation
2474 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2475     """Split according to delimiters of the highest priority.
2476
2477     If the appropriate Features are given, the split will add trailing commas
2478     also in function signatures and calls that contain `*` and `**`.
2479     """
2480     try:
2481         last_leaf = line.leaves[-1]
2482     except IndexError:
2483         raise CannotSplit("Line empty")
2484
2485     bt = line.bracket_tracker
2486     try:
2487         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2488     except ValueError:
2489         raise CannotSplit("No delimiters found")
2490
2491     if delimiter_priority == DOT_PRIORITY:
2492         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2493             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2494
2495     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2496     lowest_depth = sys.maxsize
2497     trailing_comma_safe = True
2498
2499     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2500         """Append `leaf` to current line or to new line if appending impossible."""
2501         nonlocal current_line
2502         try:
2503             current_line.append_safe(leaf, preformatted=True)
2504         except ValueError:
2505             yield current_line
2506
2507             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2508             current_line.append(leaf)
2509
2510     for leaf in line.leaves:
2511         yield from append_to_line(leaf)
2512
2513         for comment_after in line.comments_after(leaf):
2514             yield from append_to_line(comment_after)
2515
2516         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2517         if leaf.bracket_depth == lowest_depth:
2518             if is_vararg(leaf, within={syms.typedargslist}):
2519                 trailing_comma_safe = (
2520                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2521                 )
2522             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2523                 trailing_comma_safe = (
2524                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2525                 )
2526
2527         leaf_priority = bt.delimiters.get(id(leaf))
2528         if leaf_priority == delimiter_priority:
2529             yield current_line
2530
2531             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2532     if current_line:
2533         if (
2534             trailing_comma_safe
2535             and delimiter_priority == COMMA_PRIORITY
2536             and current_line.leaves[-1].type != token.COMMA
2537             and current_line.leaves[-1].type != STANDALONE_COMMENT
2538         ):
2539             current_line.append(Leaf(token.COMMA, ","))
2540         yield current_line
2541
2542
2543 @dont_increase_indentation
2544 def standalone_comment_split(
2545     line: Line, features: Collection[Feature] = ()
2546 ) -> Iterator[Line]:
2547     """Split standalone comments from the rest of the line."""
2548     if not line.contains_standalone_comments(0):
2549         raise CannotSplit("Line does not have any standalone comments")
2550
2551     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2552
2553     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2554         """Append `leaf` to current line or to new line if appending impossible."""
2555         nonlocal current_line
2556         try:
2557             current_line.append_safe(leaf, preformatted=True)
2558         except ValueError:
2559             yield current_line
2560
2561             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2562             current_line.append(leaf)
2563
2564     for leaf in line.leaves:
2565         yield from append_to_line(leaf)
2566
2567         for comment_after in line.comments_after(leaf):
2568             yield from append_to_line(comment_after)
2569
2570     if current_line:
2571         yield current_line
2572
2573
2574 def is_import(leaf: Leaf) -> bool:
2575     """Return True if the given leaf starts an import statement."""
2576     p = leaf.parent
2577     t = leaf.type
2578     v = leaf.value
2579     return bool(
2580         t == token.NAME
2581         and (
2582             (v == "import" and p and p.type == syms.import_name)
2583             or (v == "from" and p and p.type == syms.import_from)
2584         )
2585     )
2586
2587
2588 def is_type_comment(leaf: Leaf) -> bool:
2589     """Return True if the given leaf is a special comment.
2590     Only returns true for type comments for now."""
2591     t = leaf.type
2592     v = leaf.value
2593     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2594
2595
2596 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2597     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2598     else.
2599
2600     Note: don't use backslashes for formatting or you'll lose your voting rights.
2601     """
2602     if not inside_brackets:
2603         spl = leaf.prefix.split("#")
2604         if "\\" not in spl[0]:
2605             nl_count = spl[-1].count("\n")
2606             if len(spl) > 1:
2607                 nl_count -= 1
2608             leaf.prefix = "\n" * nl_count
2609             return
2610
2611     leaf.prefix = ""
2612
2613
2614 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2615     """Make all string prefixes lowercase.
2616
2617     If remove_u_prefix is given, also removes any u prefix from the string.
2618
2619     Note: Mutates its argument.
2620     """
2621     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2622     assert match is not None, f"failed to match string {leaf.value!r}"
2623     orig_prefix = match.group(1)
2624     new_prefix = orig_prefix.lower()
2625     if remove_u_prefix:
2626         new_prefix = new_prefix.replace("u", "")
2627     leaf.value = f"{new_prefix}{match.group(2)}"
2628
2629
2630 def normalize_string_quotes(leaf: Leaf) -> None:
2631     """Prefer double quotes but only if it doesn't cause more escaping.
2632
2633     Adds or removes backslashes as appropriate. Doesn't parse and fix
2634     strings nested in f-strings (yet).
2635
2636     Note: Mutates its argument.
2637     """
2638     value = leaf.value.lstrip("furbFURB")
2639     if value[:3] == '"""':
2640         return
2641
2642     elif value[:3] == "'''":
2643         orig_quote = "'''"
2644         new_quote = '"""'
2645     elif value[0] == '"':
2646         orig_quote = '"'
2647         new_quote = "'"
2648     else:
2649         orig_quote = "'"
2650         new_quote = '"'
2651     first_quote_pos = leaf.value.find(orig_quote)
2652     if first_quote_pos == -1:
2653         return  # There's an internal error
2654
2655     prefix = leaf.value[:first_quote_pos]
2656     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2657     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2658     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2659     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2660     if "r" in prefix.casefold():
2661         if unescaped_new_quote.search(body):
2662             # There's at least one unescaped new_quote in this raw string
2663             # so converting is impossible
2664             return
2665
2666         # Do not introduce or remove backslashes in raw strings
2667         new_body = body
2668     else:
2669         # remove unnecessary escapes
2670         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2671         if body != new_body:
2672             # Consider the string without unnecessary escapes as the original
2673             body = new_body
2674             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2675         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2676         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2677     if "f" in prefix.casefold():
2678         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2679         for m in matches:
2680             if "\\" in str(m):
2681                 # Do not introduce backslashes in interpolated expressions
2682                 return
2683     if new_quote == '"""' and new_body[-1:] == '"':
2684         # edge case:
2685         new_body = new_body[:-1] + '\\"'
2686     orig_escape_count = body.count("\\")
2687     new_escape_count = new_body.count("\\")
2688     if new_escape_count > orig_escape_count:
2689         return  # Do not introduce more escaping
2690
2691     if new_escape_count == orig_escape_count and orig_quote == '"':
2692         return  # Prefer double quotes
2693
2694     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2695
2696
2697 def normalize_numeric_literal(leaf: Leaf) -> None:
2698     """Normalizes numeric (float, int, and complex) literals.
2699
2700     All letters used in the representation are normalized to lowercase (except
2701     in Python 2 long literals).
2702     """
2703     text = leaf.value.lower()
2704     if text.startswith(("0o", "0b")):
2705         # Leave octal and binary literals alone.
2706         pass
2707     elif text.startswith("0x"):
2708         # Change hex literals to upper case.
2709         before, after = text[:2], text[2:]
2710         text = f"{before}{after.upper()}"
2711     elif "e" in text:
2712         before, after = text.split("e")
2713         sign = ""
2714         if after.startswith("-"):
2715             after = after[1:]
2716             sign = "-"
2717         elif after.startswith("+"):
2718             after = after[1:]
2719         before = format_float_or_int_string(before)
2720         text = f"{before}e{sign}{after}"
2721     elif text.endswith(("j", "l")):
2722         number = text[:-1]
2723         suffix = text[-1]
2724         # Capitalize in "2L" because "l" looks too similar to "1".
2725         if suffix == "l":
2726             suffix = "L"
2727         text = f"{format_float_or_int_string(number)}{suffix}"
2728     else:
2729         text = format_float_or_int_string(text)
2730     leaf.value = text
2731
2732
2733 def format_float_or_int_string(text: str) -> str:
2734     """Formats a float string like "1.0"."""
2735     if "." not in text:
2736         return text
2737
2738     before, after = text.split(".")
2739     return f"{before or 0}.{after or 0}"
2740
2741
2742 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2743     """Make existing optional parentheses invisible or create new ones.
2744
2745     `parens_after` is a set of string leaf values immeditely after which parens
2746     should be put.
2747
2748     Standardizes on visible parentheses for single-element tuples, and keeps
2749     existing visible parentheses for other tuples and generator expressions.
2750     """
2751     for pc in list_comments(node.prefix, is_endmarker=False):
2752         if pc.value in FMT_OFF:
2753             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2754             return
2755
2756     check_lpar = False
2757     for index, child in enumerate(list(node.children)):
2758         # Add parentheses around long tuple unpacking in assignments.
2759         if (
2760             index == 0
2761             and isinstance(child, Node)
2762             and child.type == syms.testlist_star_expr
2763         ):
2764             check_lpar = True
2765
2766         if check_lpar:
2767             if child.type == syms.atom:
2768                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2769                     lpar = Leaf(token.LPAR, "")
2770                     rpar = Leaf(token.RPAR, "")
2771                     index = child.remove() or 0
2772                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2773             elif is_one_tuple(child):
2774                 # wrap child in visible parentheses
2775                 lpar = Leaf(token.LPAR, "(")
2776                 rpar = Leaf(token.RPAR, ")")
2777                 child.remove()
2778                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2779             elif node.type == syms.import_from:
2780                 # "import from" nodes store parentheses directly as part of
2781                 # the statement
2782                 if child.type == token.LPAR:
2783                     # make parentheses invisible
2784                     child.value = ""  # type: ignore
2785                     node.children[-1].value = ""  # type: ignore
2786                 elif child.type != token.STAR:
2787                     # insert invisible parentheses
2788                     node.insert_child(index, Leaf(token.LPAR, ""))
2789                     node.append_child(Leaf(token.RPAR, ""))
2790                 break
2791
2792             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2793                 # wrap child in invisible parentheses
2794                 lpar = Leaf(token.LPAR, "")
2795                 rpar = Leaf(token.RPAR, "")
2796                 index = child.remove() or 0
2797                 prefix = child.prefix
2798                 child.prefix = ""
2799                 new_child = Node(syms.atom, [lpar, child, rpar])
2800                 new_child.prefix = prefix
2801                 node.insert_child(index, new_child)
2802
2803         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2804
2805
2806 def normalize_fmt_off(node: Node) -> None:
2807     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2808     try_again = True
2809     while try_again:
2810         try_again = convert_one_fmt_off_pair(node)
2811
2812
2813 def convert_one_fmt_off_pair(node: Node) -> bool:
2814     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2815
2816     Returns True if a pair was converted.
2817     """
2818     for leaf in node.leaves():
2819         previous_consumed = 0
2820         for comment in list_comments(leaf.prefix, is_endmarker=False):
2821             if comment.value in FMT_OFF:
2822                 # We only want standalone comments. If there's no previous leaf or
2823                 # the previous leaf is indentation, it's a standalone comment in
2824                 # disguise.
2825                 if comment.type != STANDALONE_COMMENT:
2826                     prev = preceding_leaf(leaf)
2827                     if prev and prev.type not in WHITESPACE:
2828                         continue
2829
2830                 ignored_nodes = list(generate_ignored_nodes(leaf))
2831                 if not ignored_nodes:
2832                     continue
2833
2834                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2835                 parent = first.parent
2836                 prefix = first.prefix
2837                 first.prefix = prefix[comment.consumed :]
2838                 hidden_value = (
2839                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2840                 )
2841                 if hidden_value.endswith("\n"):
2842                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2843                     # leaf (possibly followed by a DEDENT).
2844                     hidden_value = hidden_value[:-1]
2845                 first_idx = None
2846                 for ignored in ignored_nodes:
2847                     index = ignored.remove()
2848                     if first_idx is None:
2849                         first_idx = index
2850                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2851                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2852                 parent.insert_child(
2853                     first_idx,
2854                     Leaf(
2855                         STANDALONE_COMMENT,
2856                         hidden_value,
2857                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2858                     ),
2859                 )
2860                 return True
2861
2862             previous_consumed = comment.consumed
2863
2864     return False
2865
2866
2867 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2868     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2869
2870     Stops at the end of the block.
2871     """
2872     container: Optional[LN] = container_of(leaf)
2873     while container is not None and container.type != token.ENDMARKER:
2874         for comment in list_comments(container.prefix, is_endmarker=False):
2875             if comment.value in FMT_ON:
2876                 return
2877
2878         yield container
2879
2880         container = container.next_sibling
2881
2882
2883 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2884     """If it's safe, make the parens in the atom `node` invisible, recursively.
2885
2886     Returns whether the node should itself be wrapped in invisible parentheses.
2887
2888     """
2889     if (
2890         node.type != syms.atom
2891         or is_empty_tuple(node)
2892         or is_one_tuple(node)
2893         or (is_yield(node) and parent.type != syms.expr_stmt)
2894         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2895     ):
2896         return False
2897
2898     first = node.children[0]
2899     last = node.children[-1]
2900     if first.type == token.LPAR and last.type == token.RPAR:
2901         # make parentheses invisible
2902         first.value = ""  # type: ignore
2903         last.value = ""  # type: ignore
2904         if len(node.children) > 1:
2905             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2906         return False
2907
2908     return True
2909
2910
2911 def is_empty_tuple(node: LN) -> bool:
2912     """Return True if `node` holds an empty tuple."""
2913     return (
2914         node.type == syms.atom
2915         and len(node.children) == 2
2916         and node.children[0].type == token.LPAR
2917         and node.children[1].type == token.RPAR
2918     )
2919
2920
2921 def is_one_tuple(node: LN) -> bool:
2922     """Return True if `node` holds a tuple with one element, with or without parens."""
2923     if node.type == syms.atom:
2924         if len(node.children) != 3:
2925             return False
2926
2927         lpar, gexp, rpar = node.children
2928         if not (
2929             lpar.type == token.LPAR
2930             and gexp.type == syms.testlist_gexp
2931             and rpar.type == token.RPAR
2932         ):
2933             return False
2934
2935         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2936
2937     return (
2938         node.type in IMPLICIT_TUPLE
2939         and len(node.children) == 2
2940         and node.children[1].type == token.COMMA
2941     )
2942
2943
2944 def is_yield(node: LN) -> bool:
2945     """Return True if `node` holds a `yield` or `yield from` expression."""
2946     if node.type == syms.yield_expr:
2947         return True
2948
2949     if node.type == token.NAME and node.value == "yield":  # type: ignore
2950         return True
2951
2952     if node.type != syms.atom:
2953         return False
2954
2955     if len(node.children) != 3:
2956         return False
2957
2958     lpar, expr, rpar = node.children
2959     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2960         return is_yield(expr)
2961
2962     return False
2963
2964
2965 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2966     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2967
2968     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2969     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2970     extended iterable unpacking (PEP 3132) and additional unpacking
2971     generalizations (PEP 448).
2972     """
2973     if leaf.type not in STARS or not leaf.parent:
2974         return False
2975
2976     p = leaf.parent
2977     if p.type == syms.star_expr:
2978         # Star expressions are also used as assignment targets in extended
2979         # iterable unpacking (PEP 3132).  See what its parent is instead.
2980         if not p.parent:
2981             return False
2982
2983         p = p.parent
2984
2985     return p.type in within
2986
2987
2988 def is_multiline_string(leaf: Leaf) -> bool:
2989     """Return True if `leaf` is a multiline string that actually spans many lines."""
2990     value = leaf.value.lstrip("furbFURB")
2991     return value[:3] in {'"""', "'''"} and "\n" in value
2992
2993
2994 def is_stub_suite(node: Node) -> bool:
2995     """Return True if `node` is a suite with a stub body."""
2996     if (
2997         len(node.children) != 4
2998         or node.children[0].type != token.NEWLINE
2999         or node.children[1].type != token.INDENT
3000         or node.children[3].type != token.DEDENT
3001     ):
3002         return False
3003
3004     return is_stub_body(node.children[2])
3005
3006
3007 def is_stub_body(node: LN) -> bool:
3008     """Return True if `node` is a simple statement containing an ellipsis."""
3009     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3010         return False
3011
3012     if len(node.children) != 2:
3013         return False
3014
3015     child = node.children[0]
3016     return (
3017         child.type == syms.atom
3018         and len(child.children) == 3
3019         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3020     )
3021
3022
3023 def max_delimiter_priority_in_atom(node: LN) -> int:
3024     """Return maximum delimiter priority inside `node`.
3025
3026     This is specific to atoms with contents contained in a pair of parentheses.
3027     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3028     """
3029     if node.type != syms.atom:
3030         return 0
3031
3032     first = node.children[0]
3033     last = node.children[-1]
3034     if not (first.type == token.LPAR and last.type == token.RPAR):
3035         return 0
3036
3037     bt = BracketTracker()
3038     for c in node.children[1:-1]:
3039         if isinstance(c, Leaf):
3040             bt.mark(c)
3041         else:
3042             for leaf in c.leaves():
3043                 bt.mark(leaf)
3044     try:
3045         return bt.max_delimiter_priority()
3046
3047     except ValueError:
3048         return 0
3049
3050
3051 def ensure_visible(leaf: Leaf) -> None:
3052     """Make sure parentheses are visible.
3053
3054     They could be invisible as part of some statements (see
3055     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3056     """
3057     if leaf.type == token.LPAR:
3058         leaf.value = "("
3059     elif leaf.type == token.RPAR:
3060         leaf.value = ")"
3061
3062
3063 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3064     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3065
3066     if not (
3067         opening_bracket.parent
3068         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3069         and opening_bracket.value in "[{("
3070     ):
3071         return False
3072
3073     try:
3074         last_leaf = line.leaves[-1]
3075         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3076         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3077     except (IndexError, ValueError):
3078         return False
3079
3080     return max_priority == COMMA_PRIORITY
3081
3082
3083 def get_features_used(node: Node) -> Set[Feature]:
3084     """Return a set of (relatively) new Python features used in this file.
3085
3086     Currently looking for:
3087     - f-strings;
3088     - underscores in numeric literals; and
3089     - trailing commas after * or ** in function signatures and calls.
3090     """
3091     features: Set[Feature] = set()
3092     for n in node.pre_order():
3093         if n.type == token.STRING:
3094             value_head = n.value[:2]  # type: ignore
3095             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3096                 features.add(Feature.F_STRINGS)
3097
3098         elif n.type == token.NUMBER:
3099             if "_" in n.value:  # type: ignore
3100                 features.add(Feature.NUMERIC_UNDERSCORES)
3101
3102         elif (
3103             n.type in {syms.typedargslist, syms.arglist}
3104             and n.children
3105             and n.children[-1].type == token.COMMA
3106         ):
3107             if n.type == syms.typedargslist:
3108                 feature = Feature.TRAILING_COMMA_IN_DEF
3109             else:
3110                 feature = Feature.TRAILING_COMMA_IN_CALL
3111
3112             for ch in n.children:
3113                 if ch.type in STARS:
3114                     features.add(feature)
3115
3116                 if ch.type == syms.argument:
3117                     for argch in ch.children:
3118                         if argch.type in STARS:
3119                             features.add(feature)
3120
3121     return features
3122
3123
3124 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3125     """Detect the version to target based on the nodes used."""
3126     features = get_features_used(node)
3127     return {
3128         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3129     }
3130
3131
3132 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3133     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3134
3135     Brackets can be omitted if the entire trailer up to and including
3136     a preceding closing bracket fits in one line.
3137
3138     Yielded sets are cumulative (contain results of previous yields, too).  First
3139     set is empty.
3140     """
3141
3142     omit: Set[LeafID] = set()
3143     yield omit
3144
3145     length = 4 * line.depth
3146     opening_bracket = None
3147     closing_bracket = None
3148     inner_brackets: Set[LeafID] = set()
3149     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3150         length += leaf_length
3151         if length > line_length:
3152             break
3153
3154         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3155         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3156             break
3157
3158         if opening_bracket:
3159             if leaf is opening_bracket:
3160                 opening_bracket = None
3161             elif leaf.type in CLOSING_BRACKETS:
3162                 inner_brackets.add(id(leaf))
3163         elif leaf.type in CLOSING_BRACKETS:
3164             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3165                 # Empty brackets would fail a split so treat them as "inner"
3166                 # brackets (e.g. only add them to the `omit` set if another
3167                 # pair of brackets was good enough.
3168                 inner_brackets.add(id(leaf))
3169                 continue
3170
3171             if closing_bracket:
3172                 omit.add(id(closing_bracket))
3173                 omit.update(inner_brackets)
3174                 inner_brackets.clear()
3175                 yield omit
3176
3177             if leaf.value:
3178                 opening_bracket = leaf.opening_bracket
3179                 closing_bracket = leaf
3180
3181
3182 def get_future_imports(node: Node) -> Set[str]:
3183     """Return a set of __future__ imports in the file."""
3184     imports: Set[str] = set()
3185
3186     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3187         for child in children:
3188             if isinstance(child, Leaf):
3189                 if child.type == token.NAME:
3190                     yield child.value
3191             elif child.type == syms.import_as_name:
3192                 orig_name = child.children[0]
3193                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3194                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3195                 yield orig_name.value
3196             elif child.type == syms.import_as_names:
3197                 yield from get_imports_from_children(child.children)
3198             else:
3199                 raise AssertionError("Invalid syntax parsing imports")
3200
3201     for child in node.children:
3202         if child.type != syms.simple_stmt:
3203             break
3204         first_child = child.children[0]
3205         if isinstance(first_child, Leaf):
3206             # Continue looking if we see a docstring; otherwise stop.
3207             if (
3208                 len(child.children) == 2
3209                 and first_child.type == token.STRING
3210                 and child.children[1].type == token.NEWLINE
3211             ):
3212                 continue
3213             else:
3214                 break
3215         elif first_child.type == syms.import_from:
3216             module_name = first_child.children[1]
3217             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3218                 break
3219             imports |= set(get_imports_from_children(first_child.children[3:]))
3220         else:
3221             break
3222     return imports
3223
3224
3225 def gen_python_files_in_dir(
3226     path: Path,
3227     root: Path,
3228     include: Pattern[str],
3229     exclude: Pattern[str],
3230     report: "Report",
3231 ) -> Iterator[Path]:
3232     """Generate all files under `path` whose paths are not excluded by the
3233     `exclude` regex, but are included by the `include` regex.
3234
3235     Symbolic links pointing outside of the `root` directory are ignored.
3236
3237     `report` is where output about exclusions goes.
3238     """
3239     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3240     for child in path.iterdir():
3241         try:
3242             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3243         except ValueError:
3244             if child.is_symlink():
3245                 report.path_ignored(
3246                     child, f"is a symbolic link that points outside {root}"
3247                 )
3248                 continue
3249
3250             raise
3251
3252         if child.is_dir():
3253             normalized_path += "/"
3254         exclude_match = exclude.search(normalized_path)
3255         if exclude_match and exclude_match.group(0):
3256             report.path_ignored(child, f"matches the --exclude regular expression")
3257             continue
3258
3259         if child.is_dir():
3260             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3261
3262         elif child.is_file():
3263             include_match = include.search(normalized_path)
3264             if include_match:
3265                 yield child
3266
3267
3268 @lru_cache()
3269 def find_project_root(srcs: Iterable[str]) -> Path:
3270     """Return a directory containing .git, .hg, or pyproject.toml.
3271
3272     That directory can be one of the directories passed in `srcs` or their
3273     common parent.
3274
3275     If no directory in the tree contains a marker that would specify it's the
3276     project root, the root of the file system is returned.
3277     """
3278     if not srcs:
3279         return Path("/").resolve()
3280
3281     common_base = min(Path(src).resolve() for src in srcs)
3282     if common_base.is_dir():
3283         # Append a fake file so `parents` below returns `common_base_dir`, too.
3284         common_base /= "fake-file"
3285     for directory in common_base.parents:
3286         if (directory / ".git").is_dir():
3287             return directory
3288
3289         if (directory / ".hg").is_dir():
3290             return directory
3291
3292         if (directory / "pyproject.toml").is_file():
3293             return directory
3294
3295     return directory
3296
3297
3298 @dataclass
3299 class Report:
3300     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3301
3302     check: bool = False
3303     quiet: bool = False
3304     verbose: bool = False
3305     change_count: int = 0
3306     same_count: int = 0
3307     failure_count: int = 0
3308
3309     def done(self, src: Path, changed: Changed) -> None:
3310         """Increment the counter for successful reformatting. Write out a message."""
3311         if changed is Changed.YES:
3312             reformatted = "would reformat" if self.check else "reformatted"
3313             if self.verbose or not self.quiet:
3314                 out(f"{reformatted} {src}")
3315             self.change_count += 1
3316         else:
3317             if self.verbose:
3318                 if changed is Changed.NO:
3319                     msg = f"{src} already well formatted, good job."
3320                 else:
3321                     msg = f"{src} wasn't modified on disk since last run."
3322                 out(msg, bold=False)
3323             self.same_count += 1
3324
3325     def failed(self, src: Path, message: str) -> None:
3326         """Increment the counter for failed reformatting. Write out a message."""
3327         err(f"error: cannot format {src}: {message}")
3328         self.failure_count += 1
3329
3330     def path_ignored(self, path: Path, message: str) -> None:
3331         if self.verbose:
3332             out(f"{path} ignored: {message}", bold=False)
3333
3334     @property
3335     def return_code(self) -> int:
3336         """Return the exit code that the app should use.
3337
3338         This considers the current state of changed files and failures:
3339         - if there were any failures, return 123;
3340         - if any files were changed and --check is being used, return 1;
3341         - otherwise return 0.
3342         """
3343         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3344         # 126 we have special return codes reserved by the shell.
3345         if self.failure_count:
3346             return 123
3347
3348         elif self.change_count and self.check:
3349             return 1
3350
3351         return 0
3352
3353     def __str__(self) -> str:
3354         """Render a color report of the current state.
3355
3356         Use `click.unstyle` to remove colors.
3357         """
3358         if self.check:
3359             reformatted = "would be reformatted"
3360             unchanged = "would be left unchanged"
3361             failed = "would fail to reformat"
3362         else:
3363             reformatted = "reformatted"
3364             unchanged = "left unchanged"
3365             failed = "failed to reformat"
3366         report = []
3367         if self.change_count:
3368             s = "s" if self.change_count > 1 else ""
3369             report.append(
3370                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3371             )
3372         if self.same_count:
3373             s = "s" if self.same_count > 1 else ""
3374             report.append(f"{self.same_count} file{s} {unchanged}")
3375         if self.failure_count:
3376             s = "s" if self.failure_count > 1 else ""
3377             report.append(
3378                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3379             )
3380         return ", ".join(report) + "."
3381
3382
3383 def assert_equivalent(src: str, dst: str) -> None:
3384     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3385
3386     import ast
3387     import traceback
3388
3389     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3390         """Simple visitor generating strings to compare ASTs by content."""
3391         yield f"{'  ' * depth}{node.__class__.__name__}("
3392
3393         for field in sorted(node._fields):
3394             try:
3395                 value = getattr(node, field)
3396             except AttributeError:
3397                 continue
3398
3399             yield f"{'  ' * (depth+1)}{field}="
3400
3401             if isinstance(value, list):
3402                 for item in value:
3403                     # Ignore nested tuples within del statements, because we may insert
3404                     # parentheses and they change the AST.
3405                     if (
3406                         field == "targets"
3407                         and isinstance(node, ast.Delete)
3408                         and isinstance(item, ast.Tuple)
3409                     ):
3410                         for item in item.elts:
3411                             yield from _v(item, depth + 2)
3412                     elif isinstance(item, ast.AST):
3413                         yield from _v(item, depth + 2)
3414
3415             elif isinstance(value, ast.AST):
3416                 yield from _v(value, depth + 2)
3417
3418             else:
3419                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3420
3421         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3422
3423     try:
3424         src_ast = ast.parse(src)
3425     except Exception as exc:
3426         major, minor = sys.version_info[:2]
3427         raise AssertionError(
3428             f"cannot use --safe with this file; failed to parse source file "
3429             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3430             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3431         )
3432
3433     try:
3434         dst_ast = ast.parse(dst)
3435     except Exception as exc:
3436         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3437         raise AssertionError(
3438             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3439             f"Please report a bug on https://github.com/python/black/issues.  "
3440             f"This invalid output might be helpful: {log}"
3441         ) from None
3442
3443     src_ast_str = "\n".join(_v(src_ast))
3444     dst_ast_str = "\n".join(_v(dst_ast))
3445     if src_ast_str != dst_ast_str:
3446         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3447         raise AssertionError(
3448             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3449             f"the source.  "
3450             f"Please report a bug on https://github.com/python/black/issues.  "
3451             f"This diff might be helpful: {log}"
3452         ) from None
3453
3454
3455 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3456     """Raise AssertionError if `dst` reformats differently the second time."""
3457     newdst = format_str(dst, mode=mode)
3458     if dst != newdst:
3459         log = dump_to_file(
3460             diff(src, dst, "source", "first pass"),
3461             diff(dst, newdst, "first pass", "second pass"),
3462         )
3463         raise AssertionError(
3464             f"INTERNAL ERROR: Black produced different code on the second pass "
3465             f"of the formatter.  "
3466             f"Please report a bug on https://github.com/python/black/issues.  "
3467             f"This diff might be helpful: {log}"
3468         ) from None
3469
3470
3471 def dump_to_file(*output: str) -> str:
3472     """Dump `output` to a temporary file. Return path to the file."""
3473     import tempfile
3474
3475     with tempfile.NamedTemporaryFile(
3476         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3477     ) as f:
3478         for lines in output:
3479             f.write(lines)
3480             if lines and lines[-1] != "\n":
3481                 f.write("\n")
3482     return f.name
3483
3484
3485 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3486     """Return a unified diff string between strings `a` and `b`."""
3487     import difflib
3488
3489     a_lines = [line + "\n" for line in a.split("\n")]
3490     b_lines = [line + "\n" for line in b.split("\n")]
3491     return "".join(
3492         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3493     )
3494
3495
3496 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3497     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3498     err("Aborted!")
3499     for task in tasks:
3500         task.cancel()
3501
3502
3503 def shutdown(loop: BaseEventLoop) -> None:
3504     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3505     try:
3506         if sys.version_info[:2] >= (3, 7):
3507             all_tasks = asyncio.all_tasks
3508         else:
3509             all_tasks = asyncio.Task.all_tasks
3510         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3511         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3512         if not to_cancel:
3513             return
3514
3515         for task in to_cancel:
3516             task.cancel()
3517         loop.run_until_complete(
3518             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3519         )
3520     finally:
3521         # `concurrent.futures.Future` objects cannot be cancelled once they
3522         # are already running. There might be some when the `shutdown()` happened.
3523         # Silence their logger's spew about the event loop being closed.
3524         cf_logger = logging.getLogger("concurrent.futures")
3525         cf_logger.setLevel(logging.CRITICAL)
3526         loop.close()
3527
3528
3529 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3530     """Replace `regex` with `replacement` twice on `original`.
3531
3532     This is used by string normalization to perform replaces on
3533     overlapping matches.
3534     """
3535     return regex.sub(replacement, regex.sub(replacement, original))
3536
3537
3538 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3539     """Compile a regular expression string in `regex`.
3540
3541     If it contains newlines, use verbose mode.
3542     """
3543     if "\n" in regex:
3544         regex = "(?x)" + regex
3545     return re.compile(regex)
3546
3547
3548 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3549     """Like `reversed(enumerate(sequence))` if that were possible."""
3550     index = len(sequence) - 1
3551     for element in reversed(sequence):
3552         yield (index, element)
3553         index -= 1
3554
3555
3556 def enumerate_with_length(
3557     line: Line, reversed: bool = False
3558 ) -> Iterator[Tuple[Index, Leaf, int]]:
3559     """Return an enumeration of leaves with their length.
3560
3561     Stops prematurely on multiline strings and standalone comments.
3562     """
3563     op = cast(
3564         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3565         enumerate_reversed if reversed else enumerate,
3566     )
3567     for index, leaf in op(line.leaves):
3568         length = len(leaf.prefix) + len(leaf.value)
3569         if "\n" in leaf.value:
3570             return  # Multiline strings, we can't continue.
3571
3572         comment: Optional[Leaf]
3573         for comment in line.comments_after(leaf):
3574             length += len(comment.value)
3575
3576         yield index, leaf, length
3577
3578
3579 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3580     """Return True if `line` is no longer than `line_length`.
3581
3582     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3583     """
3584     if not line_str:
3585         line_str = str(line).strip("\n")
3586     return (
3587         len(line_str) <= line_length
3588         and "\n" not in line_str  # multiline strings
3589         and not line.contains_standalone_comments()
3590     )
3591
3592
3593 def can_be_split(line: Line) -> bool:
3594     """Return False if the line cannot be split *for sure*.
3595
3596     This is not an exhaustive search but a cheap heuristic that we can use to
3597     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3598     in unnecessary parentheses).
3599     """
3600     leaves = line.leaves
3601     if len(leaves) < 2:
3602         return False
3603
3604     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3605         call_count = 0
3606         dot_count = 0
3607         next = leaves[-1]
3608         for leaf in leaves[-2::-1]:
3609             if leaf.type in OPENING_BRACKETS:
3610                 if next.type not in CLOSING_BRACKETS:
3611                     return False
3612
3613                 call_count += 1
3614             elif leaf.type == token.DOT:
3615                 dot_count += 1
3616             elif leaf.type == token.NAME:
3617                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3618                     return False
3619
3620             elif leaf.type not in CLOSING_BRACKETS:
3621                 return False
3622
3623             if dot_count > 1 and call_count > 1:
3624                 return False
3625
3626     return True
3627
3628
3629 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3630     """Does `line` have a shape safe to reformat without optional parens around it?
3631
3632     Returns True for only a subset of potentially nice looking formattings but
3633     the point is to not return false positives that end up producing lines that
3634     are too long.
3635     """
3636     bt = line.bracket_tracker
3637     if not bt.delimiters:
3638         # Without delimiters the optional parentheses are useless.
3639         return True
3640
3641     max_priority = bt.max_delimiter_priority()
3642     if bt.delimiter_count_with_priority(max_priority) > 1:
3643         # With more than one delimiter of a kind the optional parentheses read better.
3644         return False
3645
3646     if max_priority == DOT_PRIORITY:
3647         # A single stranded method call doesn't require optional parentheses.
3648         return True
3649
3650     assert len(line.leaves) >= 2, "Stranded delimiter"
3651
3652     first = line.leaves[0]
3653     second = line.leaves[1]
3654     penultimate = line.leaves[-2]
3655     last = line.leaves[-1]
3656
3657     # With a single delimiter, omit if the expression starts or ends with
3658     # a bracket.
3659     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3660         remainder = False
3661         length = 4 * line.depth
3662         for _index, leaf, leaf_length in enumerate_with_length(line):
3663             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3664                 remainder = True
3665             if remainder:
3666                 length += leaf_length
3667                 if length > line_length:
3668                     break
3669
3670                 if leaf.type in OPENING_BRACKETS:
3671                     # There are brackets we can further split on.
3672                     remainder = False
3673
3674         else:
3675             # checked the entire string and line length wasn't exceeded
3676             if len(line.leaves) == _index + 1:
3677                 return True
3678
3679         # Note: we are not returning False here because a line might have *both*
3680         # a leading opening bracket and a trailing closing bracket.  If the
3681         # opening bracket doesn't match our rule, maybe the closing will.
3682
3683     if (
3684         last.type == token.RPAR
3685         or last.type == token.RBRACE
3686         or (
3687             # don't use indexing for omitting optional parentheses;
3688             # it looks weird
3689             last.type == token.RSQB
3690             and last.parent
3691             and last.parent.type != syms.trailer
3692         )
3693     ):
3694         if penultimate.type in OPENING_BRACKETS:
3695             # Empty brackets don't help.
3696             return False
3697
3698         if is_multiline_string(first):
3699             # Additional wrapping of a multiline string in this situation is
3700             # unnecessary.
3701             return True
3702
3703         length = 4 * line.depth
3704         seen_other_brackets = False
3705         for _index, leaf, leaf_length in enumerate_with_length(line):
3706             length += leaf_length
3707             if leaf is last.opening_bracket:
3708                 if seen_other_brackets or length <= line_length:
3709                     return True
3710
3711             elif leaf.type in OPENING_BRACKETS:
3712                 # There are brackets we can further split on.
3713                 seen_other_brackets = True
3714
3715     return False
3716
3717
3718 def get_cache_file(mode: FileMode) -> Path:
3719     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3720
3721
3722 def read_cache(mode: FileMode) -> Cache:
3723     """Read the cache if it exists and is well formed.
3724
3725     If it is not well formed, the call to write_cache later should resolve the issue.
3726     """
3727     cache_file = get_cache_file(mode)
3728     if not cache_file.exists():
3729         return {}
3730
3731     with cache_file.open("rb") as fobj:
3732         try:
3733             cache: Cache = pickle.load(fobj)
3734         except pickle.UnpicklingError:
3735             return {}
3736
3737     return cache
3738
3739
3740 def get_cache_info(path: Path) -> CacheInfo:
3741     """Return the information used to check if a file is already formatted or not."""
3742     stat = path.stat()
3743     return stat.st_mtime, stat.st_size
3744
3745
3746 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3747     """Split an iterable of paths in `sources` into two sets.
3748
3749     The first contains paths of files that modified on disk or are not in the
3750     cache. The other contains paths to non-modified files.
3751     """
3752     todo, done = set(), set()
3753     for src in sources:
3754         src = src.resolve()
3755         if cache.get(src) != get_cache_info(src):
3756             todo.add(src)
3757         else:
3758             done.add(src)
3759     return todo, done
3760
3761
3762 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3763     """Update the cache file."""
3764     cache_file = get_cache_file(mode)
3765     try:
3766         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3767         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3768         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3769             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3770         os.replace(f.name, cache_file)
3771     except OSError:
3772         pass
3773
3774
3775 def patch_click() -> None:
3776     """Make Click not crash.
3777
3778     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3779     default which restricts paths that it can access during the lifetime of the
3780     application.  Click refuses to work in this scenario by raising a RuntimeError.
3781
3782     In case of Black the likelihood that non-ASCII characters are going to be used in
3783     file paths is minimal since it's Python source code.  Moreover, this crash was
3784     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3785     """
3786     try:
3787         from click import core
3788         from click import _unicodefun  # type: ignore
3789     except ModuleNotFoundError:
3790         return
3791
3792     for module in (core, _unicodefun):
3793         if hasattr(module, "_verify_python3_env"):
3794             module._verify_python3_env = lambda: None
3795
3796
3797 def patched_main() -> None:
3798     freeze_support()
3799     patch_click()
3800     main()
3801
3802
3803 if __name__ == "__main__":
3804     patched_main()