]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

e7dce5bf2abc4a3d1a685f653d21a8ed5f712895
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "19.3b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA_IN_CALL = 4
137     TRAILING_COMMA_IN_DEF = 5
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.PY27: set(),
142     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
144     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
145     TargetVersion.PY36: {
146         Feature.UNICODE_LITERALS,
147         Feature.F_STRINGS,
148         Feature.NUMERIC_UNDERSCORES,
149         Feature.TRAILING_COMMA_IN_CALL,
150         Feature.TRAILING_COMMA_IN_DEF,
151     },
152     TargetVersion.PY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA_IN_CALL,
157         Feature.TRAILING_COMMA_IN_DEF,
158     },
159     TargetVersion.PY38: {
160         Feature.UNICODE_LITERALS,
161         Feature.F_STRINGS,
162         Feature.NUMERIC_UNDERSCORES,
163         Feature.TRAILING_COMMA_IN_CALL,
164         Feature.TRAILING_COMMA_IN_DEF,
165     },
166 }
167
168
169 @dataclass
170 class FileMode:
171     target_versions: Set[TargetVersion] = Factory(set)
172     line_length: int = DEFAULT_LINE_LENGTH
173     string_normalization: bool = True
174     is_pyi: bool = False
175
176     def get_cache_key(self) -> str:
177         if self.target_versions:
178             version_str = ",".join(
179                 str(version.value)
180                 for version in sorted(self.target_versions, key=lambda v: v.value)
181             )
182         else:
183             version_str = "-"
184         parts = [
185             version_str,
186             str(self.line_length),
187             str(int(self.string_normalization)),
188             str(int(self.is_pyi)),
189         ]
190         return ".".join(parts)
191
192
193 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
194     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
195
196
197 def read_pyproject_toml(
198     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
199 ) -> Optional[str]:
200     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
201
202     Returns the path to a successfully found and read configuration file, None
203     otherwise.
204     """
205     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
206     if not value:
207         root = find_project_root(ctx.params.get("src", ()))
208         path = root / "pyproject.toml"
209         if path.is_file():
210             value = str(path)
211         else:
212             return None
213
214     try:
215         pyproject_toml = toml.load(value)
216         config = pyproject_toml.get("tool", {}).get("black", {})
217     except (toml.TomlDecodeError, OSError) as e:
218         raise click.FileError(
219             filename=value, hint=f"Error reading configuration file: {e}"
220         )
221
222     if not config:
223         return None
224
225     if ctx.default_map is None:
226         ctx.default_map = {}
227     ctx.default_map.update(  # type: ignore  # bad types in .pyi
228         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
229     )
230     return value
231
232
233 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
234 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
235 @click.option(
236     "-l",
237     "--line-length",
238     type=int,
239     default=DEFAULT_LINE_LENGTH,
240     help="How many characters per line to allow.",
241     show_default=True,
242 )
243 @click.option(
244     "-t",
245     "--target-version",
246     type=click.Choice([v.name.lower() for v in TargetVersion]),
247     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
248     multiple=True,
249     help=(
250         "Python versions that should be supported by Black's output. [default: "
251         "per-file auto-detection]"
252     ),
253 )
254 @click.option(
255     "--py36",
256     is_flag=True,
257     help=(
258         "Allow using Python 3.6-only syntax on all input files.  This will put "
259         "trailing commas in function signatures and calls also after *args and "
260         "**kwargs. Deprecated; use --target-version instead. "
261         "[default: per-file auto-detection]"
262     ),
263 )
264 @click.option(
265     "--pyi",
266     is_flag=True,
267     help=(
268         "Format all input files like typing stubs regardless of file extension "
269         "(useful when piping source on standard input)."
270     ),
271 )
272 @click.option(
273     "-S",
274     "--skip-string-normalization",
275     is_flag=True,
276     help="Don't normalize string quotes or prefixes.",
277 )
278 @click.option(
279     "--check",
280     is_flag=True,
281     help=(
282         "Don't write the files back, just return the status.  Return code 0 "
283         "means nothing would change.  Return code 1 means some files would be "
284         "reformatted.  Return code 123 means there was an internal error."
285     ),
286 )
287 @click.option(
288     "--diff",
289     is_flag=True,
290     help="Don't write the files back, just output a diff for each file on stdout.",
291 )
292 @click.option(
293     "--fast/--safe",
294     is_flag=True,
295     help="If --fast given, skip temporary sanity checks. [default: --safe]",
296 )
297 @click.option(
298     "--include",
299     type=str,
300     default=DEFAULT_INCLUDES,
301     help=(
302         "A regular expression that matches files and directories that should be "
303         "included on recursive searches.  An empty value means all files are "
304         "included regardless of the name.  Use forward slashes for directories on "
305         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
306         "later."
307     ),
308     show_default=True,
309 )
310 @click.option(
311     "--exclude",
312     type=str,
313     default=DEFAULT_EXCLUDES,
314     help=(
315         "A regular expression that matches files and directories that should be "
316         "excluded on recursive searches.  An empty value means no paths are excluded. "
317         "Use forward slashes for directories on all platforms (Windows, too).  "
318         "Exclusions are calculated first, inclusions later."
319     ),
320     show_default=True,
321 )
322 @click.option(
323     "-q",
324     "--quiet",
325     is_flag=True,
326     help=(
327         "Don't emit non-error messages to stderr. Errors are still emitted, "
328         "silence those with 2>/dev/null."
329     ),
330 )
331 @click.option(
332     "-v",
333     "--verbose",
334     is_flag=True,
335     help=(
336         "Also emit messages to stderr about files that were not changed or were "
337         "ignored due to --exclude=."
338     ),
339 )
340 @click.version_option(version=__version__)
341 @click.argument(
342     "src",
343     nargs=-1,
344     type=click.Path(
345         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
346     ),
347     is_eager=True,
348 )
349 @click.option(
350     "--config",
351     type=click.Path(
352         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
353     ),
354     is_eager=True,
355     callback=read_pyproject_toml,
356     help="Read configuration from PATH.",
357 )
358 @click.pass_context
359 def main(
360     ctx: click.Context,
361     code: Optional[str],
362     line_length: int,
363     target_version: List[TargetVersion],
364     check: bool,
365     diff: bool,
366     fast: bool,
367     pyi: bool,
368     py36: bool,
369     skip_string_normalization: bool,
370     quiet: bool,
371     verbose: bool,
372     include: str,
373     exclude: str,
374     src: Tuple[str],
375     config: Optional[str],
376 ) -> None:
377     """The uncompromising code formatter."""
378     write_back = WriteBack.from_configuration(check=check, diff=diff)
379     if target_version:
380         if py36:
381             err(f"Cannot use both --target-version and --py36")
382             ctx.exit(2)
383         else:
384             versions = set(target_version)
385     elif py36:
386         err(
387             "--py36 is deprecated and will be removed in a future version. "
388             "Use --target-version py36 instead."
389         )
390         versions = PY36_VERSIONS
391     else:
392         # We'll autodetect later.
393         versions = set()
394     mode = FileMode(
395         target_versions=versions,
396         line_length=line_length,
397         is_pyi=pyi,
398         string_normalization=not skip_string_normalization,
399     )
400     if config and verbose:
401         out(f"Using configuration from {config}.", bold=False, fg="blue")
402     if code is not None:
403         print(format_str(code, mode=mode))
404         ctx.exit(0)
405     try:
406         include_regex = re_compile_maybe_verbose(include)
407     except re.error:
408         err(f"Invalid regular expression for include given: {include!r}")
409         ctx.exit(2)
410     try:
411         exclude_regex = re_compile_maybe_verbose(exclude)
412     except re.error:
413         err(f"Invalid regular expression for exclude given: {exclude!r}")
414         ctx.exit(2)
415     report = Report(check=check, quiet=quiet, verbose=verbose)
416     root = find_project_root(src)
417     sources: Set[Path] = set()
418     for s in src:
419         p = Path(s)
420         if p.is_dir():
421             sources.update(
422                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
423             )
424         elif p.is_file() or s == "-":
425             # if a file was explicitly given, we don't care about its extension
426             sources.add(p)
427         else:
428             err(f"invalid path: {s}")
429     if len(sources) == 0:
430         if verbose or not quiet:
431             out("No paths given. Nothing to do 😴")
432         ctx.exit(0)
433
434     if len(sources) == 1:
435         reformat_one(
436             src=sources.pop(),
437             fast=fast,
438             write_back=write_back,
439             mode=mode,
440             report=report,
441         )
442     else:
443         loop = asyncio.get_event_loop()
444         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
445         try:
446             loop.run_until_complete(
447                 schedule_formatting(
448                     sources=sources,
449                     fast=fast,
450                     write_back=write_back,
451                     mode=mode,
452                     report=report,
453                     loop=loop,
454                     executor=executor,
455                 )
456             )
457         finally:
458             shutdown(loop)
459     if verbose or not quiet:
460         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
461         out(f"All done! {bang}")
462         click.secho(str(report), err=True)
463     ctx.exit(report.return_code)
464
465
466 def reformat_one(
467     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
468 ) -> None:
469     """Reformat a single file under `src` without spawning child processes.
470
471     If `quiet` is True, non-error messages are not output. `line_length`,
472     `write_back`, `fast` and `pyi` options are passed to
473     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
474     """
475     try:
476         changed = Changed.NO
477         if not src.is_file() and str(src) == "-":
478             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
479                 changed = Changed.YES
480         else:
481             cache: Cache = {}
482             if write_back != WriteBack.DIFF:
483                 cache = read_cache(mode)
484                 res_src = src.resolve()
485                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
486                     changed = Changed.CACHED
487             if changed is not Changed.CACHED and format_file_in_place(
488                 src, fast=fast, write_back=write_back, mode=mode
489             ):
490                 changed = Changed.YES
491             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
492                 write_back is WriteBack.CHECK and changed is Changed.NO
493             ):
494                 write_cache(cache, [src], mode)
495         report.done(src, changed)
496     except Exception as exc:
497         report.failed(src, str(exc))
498
499
500 async def schedule_formatting(
501     sources: Set[Path],
502     fast: bool,
503     write_back: WriteBack,
504     mode: FileMode,
505     report: "Report",
506     loop: BaseEventLoop,
507     executor: Executor,
508 ) -> None:
509     """Run formatting of `sources` in parallel using the provided `executor`.
510
511     (Use ProcessPoolExecutors for actual parallelism.)
512
513     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
514     :func:`format_file_in_place`.
515     """
516     cache: Cache = {}
517     if write_back != WriteBack.DIFF:
518         cache = read_cache(mode)
519         sources, cached = filter_cached(cache, sources)
520         for src in sorted(cached):
521             report.done(src, Changed.CACHED)
522     if not sources:
523         return
524
525     cancelled = []
526     sources_to_cache = []
527     lock = None
528     if write_back == WriteBack.DIFF:
529         # For diff output, we need locks to ensure we don't interleave output
530         # from different processes.
531         manager = Manager()
532         lock = manager.Lock()
533     tasks = {
534         asyncio.ensure_future(
535             loop.run_in_executor(
536                 executor, format_file_in_place, src, fast, mode, write_back, lock
537             )
538         ): src
539         for src in sorted(sources)
540     }
541     pending: Iterable[asyncio.Future] = tasks.keys()
542     try:
543         loop.add_signal_handler(signal.SIGINT, cancel, pending)
544         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
545     except NotImplementedError:
546         # There are no good alternatives for these on Windows.
547         pass
548     while pending:
549         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
550         for task in done:
551             src = tasks.pop(task)
552             if task.cancelled():
553                 cancelled.append(task)
554             elif task.exception():
555                 report.failed(src, str(task.exception()))
556             else:
557                 changed = Changed.YES if task.result() else Changed.NO
558                 # If the file was written back or was successfully checked as
559                 # well-formatted, store this information in the cache.
560                 if write_back is WriteBack.YES or (
561                     write_back is WriteBack.CHECK and changed is Changed.NO
562                 ):
563                     sources_to_cache.append(src)
564                 report.done(src, changed)
565     if cancelled:
566         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
567     if sources_to_cache:
568         write_cache(cache, sources_to_cache, mode)
569
570
571 def format_file_in_place(
572     src: Path,
573     fast: bool,
574     mode: FileMode,
575     write_back: WriteBack = WriteBack.NO,
576     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
577 ) -> bool:
578     """Format file under `src` path. Return True if changed.
579
580     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
581     code to the file.
582     `line_length` and `fast` options are passed to :func:`format_file_contents`.
583     """
584     if src.suffix == ".pyi":
585         mode = evolve(mode, is_pyi=True)
586
587     then = datetime.utcfromtimestamp(src.stat().st_mtime)
588     with open(src, "rb") as buf:
589         src_contents, encoding, newline = decode_bytes(buf.read())
590     try:
591         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
592     except NothingChanged:
593         return False
594
595     if write_back == write_back.YES:
596         with open(src, "w", encoding=encoding, newline=newline) as f:
597             f.write(dst_contents)
598     elif write_back == write_back.DIFF:
599         now = datetime.utcnow()
600         src_name = f"{src}\t{then} +0000"
601         dst_name = f"{src}\t{now} +0000"
602         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
603         if lock:
604             lock.acquire()
605         try:
606             f = io.TextIOWrapper(
607                 sys.stdout.buffer,
608                 encoding=encoding,
609                 newline=newline,
610                 write_through=True,
611             )
612             f.write(diff_contents)
613             f.detach()
614         finally:
615             if lock:
616                 lock.release()
617     return True
618
619
620 def format_stdin_to_stdout(
621     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
622 ) -> bool:
623     """Format file on stdin. Return True if changed.
624
625     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
626     write a diff to stdout. The `mode` argument is passed to
627     :func:`format_file_contents`.
628     """
629     then = datetime.utcnow()
630     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
631     dst = src
632     try:
633         dst = format_file_contents(src, fast=fast, mode=mode)
634         return True
635
636     except NothingChanged:
637         return False
638
639     finally:
640         f = io.TextIOWrapper(
641             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
642         )
643         if write_back == WriteBack.YES:
644             f.write(dst)
645         elif write_back == WriteBack.DIFF:
646             now = datetime.utcnow()
647             src_name = f"STDIN\t{then} +0000"
648             dst_name = f"STDOUT\t{now} +0000"
649             f.write(diff(src, dst, src_name, dst_name))
650         f.detach()
651
652
653 def format_file_contents(
654     src_contents: str, *, fast: bool, mode: FileMode
655 ) -> FileContent:
656     """Reformat contents a file and return new contents.
657
658     If `fast` is False, additionally confirm that the reformatted code is
659     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
660     `line_length` is passed to :func:`format_str`.
661     """
662     if src_contents.strip() == "":
663         raise NothingChanged
664
665     dst_contents = format_str(src_contents, mode=mode)
666     if src_contents == dst_contents:
667         raise NothingChanged
668
669     if not fast:
670         assert_equivalent(src_contents, dst_contents)
671         assert_stable(src_contents, dst_contents, mode=mode)
672     return dst_contents
673
674
675 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
676     """Reformat a string and return new contents.
677
678     `line_length` determines how many characters per line are allowed.
679     """
680     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
681     dst_contents = ""
682     future_imports = get_future_imports(src_node)
683     if mode.target_versions:
684         versions = mode.target_versions
685     else:
686         versions = detect_target_versions(src_node)
687     normalize_fmt_off(src_node)
688     lines = LineGenerator(
689         remove_u_prefix="unicode_literals" in future_imports
690         or supports_feature(versions, Feature.UNICODE_LITERALS),
691         is_pyi=mode.is_pyi,
692         normalize_strings=mode.string_normalization,
693     )
694     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
695     empty_line = Line()
696     after = 0
697     split_line_features = {
698         feature
699         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
700         if supports_feature(versions, feature)
701     }
702     for current_line in lines.visit(src_node):
703         for _ in range(after):
704             dst_contents += str(empty_line)
705         before, after = elt.maybe_empty_lines(current_line)
706         for _ in range(before):
707             dst_contents += str(empty_line)
708         for line in split_line(
709             current_line, line_length=mode.line_length, features=split_line_features
710         ):
711             dst_contents += str(line)
712     return dst_contents
713
714
715 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
716     """Return a tuple of (decoded_contents, encoding, newline).
717
718     `newline` is either CRLF or LF but `decoded_contents` is decoded with
719     universal newlines (i.e. only contains LF).
720     """
721     srcbuf = io.BytesIO(src)
722     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
723     if not lines:
724         return "", encoding, "\n"
725
726     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
727     srcbuf.seek(0)
728     with io.TextIOWrapper(srcbuf, encoding) as tiow:
729         return tiow.read(), encoding, newline
730
731
732 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
733     if not target_versions:
734         # No target_version specified, so try all grammars.
735         return [
736             pygram.python_grammar_no_print_statement_no_exec_statement,
737             pygram.python_grammar_no_print_statement,
738             pygram.python_grammar,
739         ]
740     elif all(version.is_python2() for version in target_versions):
741         # Python 2-only code, so try Python 2 grammars.
742         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
743     else:
744         # Python 3-compatible code, so only try Python 3 grammar.
745         return [pygram.python_grammar_no_print_statement_no_exec_statement]
746
747
748 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
749     """Given a string with source, return the lib2to3 Node."""
750     if src_txt[-1:] != "\n":
751         src_txt += "\n"
752
753     for grammar in get_grammars(set(target_versions)):
754         drv = driver.Driver(grammar, pytree.convert)
755         try:
756             result = drv.parse_string(src_txt, True)
757             break
758
759         except ParseError as pe:
760             lineno, column = pe.context[1]
761             lines = src_txt.splitlines()
762             try:
763                 faulty_line = lines[lineno - 1]
764             except IndexError:
765                 faulty_line = "<line number missing in source>"
766             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
767     else:
768         raise exc from None
769
770     if isinstance(result, Leaf):
771         result = Node(syms.file_input, [result])
772     return result
773
774
775 def lib2to3_unparse(node: Node) -> str:
776     """Given a lib2to3 node, return its string representation."""
777     code = str(node)
778     return code
779
780
781 T = TypeVar("T")
782
783
784 class Visitor(Generic[T]):
785     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
786
787     def visit(self, node: LN) -> Iterator[T]:
788         """Main method to visit `node` and its children.
789
790         It tries to find a `visit_*()` method for the given `node.type`, like
791         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
792         If no dedicated `visit_*()` method is found, chooses `visit_default()`
793         instead.
794
795         Then yields objects of type `T` from the selected visitor.
796         """
797         if node.type < 256:
798             name = token.tok_name[node.type]
799         else:
800             name = type_repr(node.type)
801         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
802
803     def visit_default(self, node: LN) -> Iterator[T]:
804         """Default `visit_*()` implementation. Recurses to children of `node`."""
805         if isinstance(node, Node):
806             for child in node.children:
807                 yield from self.visit(child)
808
809
810 @dataclass
811 class DebugVisitor(Visitor[T]):
812     tree_depth: int = 0
813
814     def visit_default(self, node: LN) -> Iterator[T]:
815         indent = " " * (2 * self.tree_depth)
816         if isinstance(node, Node):
817             _type = type_repr(node.type)
818             out(f"{indent}{_type}", fg="yellow")
819             self.tree_depth += 1
820             for child in node.children:
821                 yield from self.visit(child)
822
823             self.tree_depth -= 1
824             out(f"{indent}/{_type}", fg="yellow", bold=False)
825         else:
826             _type = token.tok_name.get(node.type, str(node.type))
827             out(f"{indent}{_type}", fg="blue", nl=False)
828             if node.prefix:
829                 # We don't have to handle prefixes for `Node` objects since
830                 # that delegates to the first child anyway.
831                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
832             out(f" {node.value!r}", fg="blue", bold=False)
833
834     @classmethod
835     def show(cls, code: Union[str, Leaf, Node]) -> None:
836         """Pretty-print the lib2to3 AST of a given string of `code`.
837
838         Convenience method for debugging.
839         """
840         v: DebugVisitor[None] = DebugVisitor()
841         if isinstance(code, str):
842             code = lib2to3_parse(code)
843         list(v.visit(code))
844
845
846 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
847 STATEMENT = {
848     syms.if_stmt,
849     syms.while_stmt,
850     syms.for_stmt,
851     syms.try_stmt,
852     syms.except_clause,
853     syms.with_stmt,
854     syms.funcdef,
855     syms.classdef,
856 }
857 STANDALONE_COMMENT = 153
858 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
859 LOGIC_OPERATORS = {"and", "or"}
860 COMPARATORS = {
861     token.LESS,
862     token.GREATER,
863     token.EQEQUAL,
864     token.NOTEQUAL,
865     token.LESSEQUAL,
866     token.GREATEREQUAL,
867 }
868 MATH_OPERATORS = {
869     token.VBAR,
870     token.CIRCUMFLEX,
871     token.AMPER,
872     token.LEFTSHIFT,
873     token.RIGHTSHIFT,
874     token.PLUS,
875     token.MINUS,
876     token.STAR,
877     token.SLASH,
878     token.DOUBLESLASH,
879     token.PERCENT,
880     token.AT,
881     token.TILDE,
882     token.DOUBLESTAR,
883 }
884 STARS = {token.STAR, token.DOUBLESTAR}
885 VARARGS_PARENTS = {
886     syms.arglist,
887     syms.argument,  # double star in arglist
888     syms.trailer,  # single argument to call
889     syms.typedargslist,
890     syms.varargslist,  # lambdas
891 }
892 UNPACKING_PARENTS = {
893     syms.atom,  # single element of a list or set literal
894     syms.dictsetmaker,
895     syms.listmaker,
896     syms.testlist_gexp,
897     syms.testlist_star_expr,
898 }
899 TEST_DESCENDANTS = {
900     syms.test,
901     syms.lambdef,
902     syms.or_test,
903     syms.and_test,
904     syms.not_test,
905     syms.comparison,
906     syms.star_expr,
907     syms.expr,
908     syms.xor_expr,
909     syms.and_expr,
910     syms.shift_expr,
911     syms.arith_expr,
912     syms.trailer,
913     syms.term,
914     syms.power,
915 }
916 ASSIGNMENTS = {
917     "=",
918     "+=",
919     "-=",
920     "*=",
921     "@=",
922     "/=",
923     "%=",
924     "&=",
925     "|=",
926     "^=",
927     "<<=",
928     ">>=",
929     "**=",
930     "//=",
931 }
932 COMPREHENSION_PRIORITY = 20
933 COMMA_PRIORITY = 18
934 TERNARY_PRIORITY = 16
935 LOGIC_PRIORITY = 14
936 STRING_PRIORITY = 12
937 COMPARATOR_PRIORITY = 10
938 MATH_PRIORITIES = {
939     token.VBAR: 9,
940     token.CIRCUMFLEX: 8,
941     token.AMPER: 7,
942     token.LEFTSHIFT: 6,
943     token.RIGHTSHIFT: 6,
944     token.PLUS: 5,
945     token.MINUS: 5,
946     token.STAR: 4,
947     token.SLASH: 4,
948     token.DOUBLESLASH: 4,
949     token.PERCENT: 4,
950     token.AT: 4,
951     token.TILDE: 3,
952     token.DOUBLESTAR: 2,
953 }
954 DOT_PRIORITY = 1
955
956
957 @dataclass
958 class BracketTracker:
959     """Keeps track of brackets on a line."""
960
961     depth: int = 0
962     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
963     delimiters: Dict[LeafID, Priority] = Factory(dict)
964     previous: Optional[Leaf] = None
965     _for_loop_depths: List[int] = Factory(list)
966     _lambda_argument_depths: List[int] = Factory(list)
967
968     def mark(self, leaf: Leaf) -> None:
969         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
970
971         All leaves receive an int `bracket_depth` field that stores how deep
972         within brackets a given leaf is. 0 means there are no enclosing brackets
973         that started on this line.
974
975         If a leaf is itself a closing bracket, it receives an `opening_bracket`
976         field that it forms a pair with. This is a one-directional link to
977         avoid reference cycles.
978
979         If a leaf is a delimiter (a token on which Black can split the line if
980         needed) and it's on depth 0, its `id()` is stored in the tracker's
981         `delimiters` field.
982         """
983         if leaf.type == token.COMMENT:
984             return
985
986         self.maybe_decrement_after_for_loop_variable(leaf)
987         self.maybe_decrement_after_lambda_arguments(leaf)
988         if leaf.type in CLOSING_BRACKETS:
989             self.depth -= 1
990             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
991             leaf.opening_bracket = opening_bracket
992         leaf.bracket_depth = self.depth
993         if self.depth == 0:
994             delim = is_split_before_delimiter(leaf, self.previous)
995             if delim and self.previous is not None:
996                 self.delimiters[id(self.previous)] = delim
997             else:
998                 delim = is_split_after_delimiter(leaf, self.previous)
999                 if delim:
1000                     self.delimiters[id(leaf)] = delim
1001         if leaf.type in OPENING_BRACKETS:
1002             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1003             self.depth += 1
1004         self.previous = leaf
1005         self.maybe_increment_lambda_arguments(leaf)
1006         self.maybe_increment_for_loop_variable(leaf)
1007
1008     def any_open_brackets(self) -> bool:
1009         """Return True if there is an yet unmatched open bracket on the line."""
1010         return bool(self.bracket_match)
1011
1012     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1013         """Return the highest priority of a delimiter found on the line.
1014
1015         Values are consistent with what `is_split_*_delimiter()` return.
1016         Raises ValueError on no delimiters.
1017         """
1018         return max(v for k, v in self.delimiters.items() if k not in exclude)
1019
1020     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1021         """Return the number of delimiters with the given `priority`.
1022
1023         If no `priority` is passed, defaults to max priority on the line.
1024         """
1025         if not self.delimiters:
1026             return 0
1027
1028         priority = priority or self.max_delimiter_priority()
1029         return sum(1 for p in self.delimiters.values() if p == priority)
1030
1031     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1032         """In a for loop, or comprehension, the variables are often unpacks.
1033
1034         To avoid splitting on the comma in this situation, increase the depth of
1035         tokens between `for` and `in`.
1036         """
1037         if leaf.type == token.NAME and leaf.value == "for":
1038             self.depth += 1
1039             self._for_loop_depths.append(self.depth)
1040             return True
1041
1042         return False
1043
1044     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1045         """See `maybe_increment_for_loop_variable` above for explanation."""
1046         if (
1047             self._for_loop_depths
1048             and self._for_loop_depths[-1] == self.depth
1049             and leaf.type == token.NAME
1050             and leaf.value == "in"
1051         ):
1052             self.depth -= 1
1053             self._for_loop_depths.pop()
1054             return True
1055
1056         return False
1057
1058     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1059         """In a lambda expression, there might be more than one argument.
1060
1061         To avoid splitting on the comma in this situation, increase the depth of
1062         tokens between `lambda` and `:`.
1063         """
1064         if leaf.type == token.NAME and leaf.value == "lambda":
1065             self.depth += 1
1066             self._lambda_argument_depths.append(self.depth)
1067             return True
1068
1069         return False
1070
1071     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1072         """See `maybe_increment_lambda_arguments` above for explanation."""
1073         if (
1074             self._lambda_argument_depths
1075             and self._lambda_argument_depths[-1] == self.depth
1076             and leaf.type == token.COLON
1077         ):
1078             self.depth -= 1
1079             self._lambda_argument_depths.pop()
1080             return True
1081
1082         return False
1083
1084     def get_open_lsqb(self) -> Optional[Leaf]:
1085         """Return the most recent opening square bracket (if any)."""
1086         return self.bracket_match.get((self.depth - 1, token.RSQB))
1087
1088
1089 @dataclass
1090 class Line:
1091     """Holds leaves and comments. Can be printed with `str(line)`."""
1092
1093     depth: int = 0
1094     leaves: List[Leaf] = Factory(list)
1095     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1096     bracket_tracker: BracketTracker = Factory(BracketTracker)
1097     inside_brackets: bool = False
1098     should_explode: bool = False
1099
1100     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1101         """Add a new `leaf` to the end of the line.
1102
1103         Unless `preformatted` is True, the `leaf` will receive a new consistent
1104         whitespace prefix and metadata applied by :class:`BracketTracker`.
1105         Trailing commas are maybe removed, unpacked for loop variables are
1106         demoted from being delimiters.
1107
1108         Inline comments are put aside.
1109         """
1110         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1111         if not has_value:
1112             return
1113
1114         if token.COLON == leaf.type and self.is_class_paren_empty:
1115             del self.leaves[-2:]
1116         if self.leaves and not preformatted:
1117             # Note: at this point leaf.prefix should be empty except for
1118             # imports, for which we only preserve newlines.
1119             leaf.prefix += whitespace(
1120                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1121             )
1122         if self.inside_brackets or not preformatted:
1123             self.bracket_tracker.mark(leaf)
1124             self.maybe_remove_trailing_comma(leaf)
1125         if not self.append_comment(leaf):
1126             self.leaves.append(leaf)
1127
1128     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1129         """Like :func:`append()` but disallow invalid standalone comment structure.
1130
1131         Raises ValueError when any `leaf` is appended after a standalone comment
1132         or when a standalone comment is not the first leaf on the line.
1133         """
1134         if self.bracket_tracker.depth == 0:
1135             if self.is_comment:
1136                 raise ValueError("cannot append to standalone comments")
1137
1138             if self.leaves and leaf.type == STANDALONE_COMMENT:
1139                 raise ValueError(
1140                     "cannot append standalone comments to a populated line"
1141                 )
1142
1143         self.append(leaf, preformatted=preformatted)
1144
1145     @property
1146     def is_comment(self) -> bool:
1147         """Is this line a standalone comment?"""
1148         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1149
1150     @property
1151     def is_decorator(self) -> bool:
1152         """Is this line a decorator?"""
1153         return bool(self) and self.leaves[0].type == token.AT
1154
1155     @property
1156     def is_import(self) -> bool:
1157         """Is this an import line?"""
1158         return bool(self) and is_import(self.leaves[0])
1159
1160     @property
1161     def is_class(self) -> bool:
1162         """Is this line a class definition?"""
1163         return (
1164             bool(self)
1165             and self.leaves[0].type == token.NAME
1166             and self.leaves[0].value == "class"
1167         )
1168
1169     @property
1170     def is_stub_class(self) -> bool:
1171         """Is this line a class definition with a body consisting only of "..."?"""
1172         return self.is_class and self.leaves[-3:] == [
1173             Leaf(token.DOT, ".") for _ in range(3)
1174         ]
1175
1176     @property
1177     def is_def(self) -> bool:
1178         """Is this a function definition? (Also returns True for async defs.)"""
1179         try:
1180             first_leaf = self.leaves[0]
1181         except IndexError:
1182             return False
1183
1184         try:
1185             second_leaf: Optional[Leaf] = self.leaves[1]
1186         except IndexError:
1187             second_leaf = None
1188         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1189             first_leaf.type == token.ASYNC
1190             and second_leaf is not None
1191             and second_leaf.type == token.NAME
1192             and second_leaf.value == "def"
1193         )
1194
1195     @property
1196     def is_class_paren_empty(self) -> bool:
1197         """Is this a class with no base classes but using parentheses?
1198
1199         Those are unnecessary and should be removed.
1200         """
1201         return (
1202             bool(self)
1203             and len(self.leaves) == 4
1204             and self.is_class
1205             and self.leaves[2].type == token.LPAR
1206             and self.leaves[2].value == "("
1207             and self.leaves[3].type == token.RPAR
1208             and self.leaves[3].value == ")"
1209         )
1210
1211     @property
1212     def is_triple_quoted_string(self) -> bool:
1213         """Is the line a triple quoted string?"""
1214         return (
1215             bool(self)
1216             and self.leaves[0].type == token.STRING
1217             and self.leaves[0].value.startswith(('"""', "'''"))
1218         )
1219
1220     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1221         """If so, needs to be split before emitting."""
1222         for leaf in self.leaves:
1223             if leaf.type == STANDALONE_COMMENT:
1224                 if leaf.bracket_depth <= depth_limit:
1225                     return True
1226         return False
1227
1228     def contains_inner_type_comments(self) -> bool:
1229         ignored_ids = set()
1230         try:
1231             last_leaf = self.leaves[-1]
1232             ignored_ids.add(id(last_leaf))
1233             if last_leaf.type == token.COMMA:
1234                 # When trailing commas are inserted by Black for consistency, comments
1235                 # after the previous last element are not moved (they don't have to,
1236                 # rendering will still be correct).  So we ignore trailing commas.
1237                 last_leaf = self.leaves[-2]
1238                 ignored_ids.add(id(last_leaf))
1239         except IndexError:
1240             return False
1241
1242         for leaf_id, comments in self.comments.items():
1243             if leaf_id in ignored_ids:
1244                 continue
1245
1246             for comment in comments:
1247                 if is_type_comment(comment):
1248                     return True
1249
1250         return False
1251
1252     def contains_multiline_strings(self) -> bool:
1253         for leaf in self.leaves:
1254             if is_multiline_string(leaf):
1255                 return True
1256
1257         return False
1258
1259     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1260         """Remove trailing comma if there is one and it's safe."""
1261         if not (
1262             self.leaves
1263             and self.leaves[-1].type == token.COMMA
1264             and closing.type in CLOSING_BRACKETS
1265         ):
1266             return False
1267
1268         if closing.type == token.RBRACE:
1269             self.remove_trailing_comma()
1270             return True
1271
1272         if closing.type == token.RSQB:
1273             comma = self.leaves[-1]
1274             if comma.parent and comma.parent.type == syms.listmaker:
1275                 self.remove_trailing_comma()
1276                 return True
1277
1278         # For parens let's check if it's safe to remove the comma.
1279         # Imports are always safe.
1280         if self.is_import:
1281             self.remove_trailing_comma()
1282             return True
1283
1284         # Otherwise, if the trailing one is the only one, we might mistakenly
1285         # change a tuple into a different type by removing the comma.
1286         depth = closing.bracket_depth + 1
1287         commas = 0
1288         opening = closing.opening_bracket
1289         for _opening_index, leaf in enumerate(self.leaves):
1290             if leaf is opening:
1291                 break
1292
1293         else:
1294             return False
1295
1296         for leaf in self.leaves[_opening_index + 1 :]:
1297             if leaf is closing:
1298                 break
1299
1300             bracket_depth = leaf.bracket_depth
1301             if bracket_depth == depth and leaf.type == token.COMMA:
1302                 commas += 1
1303                 if leaf.parent and leaf.parent.type == syms.arglist:
1304                     commas += 1
1305                     break
1306
1307         if commas > 1:
1308             self.remove_trailing_comma()
1309             return True
1310
1311         return False
1312
1313     def append_comment(self, comment: Leaf) -> bool:
1314         """Add an inline or standalone comment to the line."""
1315         if (
1316             comment.type == STANDALONE_COMMENT
1317             and self.bracket_tracker.any_open_brackets()
1318         ):
1319             comment.prefix = ""
1320             return False
1321
1322         if comment.type != token.COMMENT:
1323             return False
1324
1325         if not self.leaves:
1326             comment.type = STANDALONE_COMMENT
1327             comment.prefix = ""
1328             return False
1329
1330         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1331         return True
1332
1333     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1334         """Generate comments that should appear directly after `leaf`."""
1335         return self.comments.get(id(leaf), [])
1336
1337     def remove_trailing_comma(self) -> None:
1338         """Remove the trailing comma and moves the comments attached to it."""
1339         trailing_comma = self.leaves.pop()
1340         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1341         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1342             trailing_comma_comments
1343         )
1344
1345     def is_complex_subscript(self, leaf: Leaf) -> bool:
1346         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1347         open_lsqb = self.bracket_tracker.get_open_lsqb()
1348         if open_lsqb is None:
1349             return False
1350
1351         subscript_start = open_lsqb.next_sibling
1352
1353         if isinstance(subscript_start, Node):
1354             if subscript_start.type == syms.listmaker:
1355                 return False
1356
1357             if subscript_start.type == syms.subscriptlist:
1358                 subscript_start = child_towards(subscript_start, leaf)
1359         return subscript_start is not None and any(
1360             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1361         )
1362
1363     def __str__(self) -> str:
1364         """Render the line."""
1365         if not self:
1366             return "\n"
1367
1368         indent = "    " * self.depth
1369         leaves = iter(self.leaves)
1370         first = next(leaves)
1371         res = f"{first.prefix}{indent}{first.value}"
1372         for leaf in leaves:
1373             res += str(leaf)
1374         for comment in itertools.chain.from_iterable(self.comments.values()):
1375             res += str(comment)
1376         return res + "\n"
1377
1378     def __bool__(self) -> bool:
1379         """Return True if the line has leaves or comments."""
1380         return bool(self.leaves or self.comments)
1381
1382
1383 @dataclass
1384 class EmptyLineTracker:
1385     """Provides a stateful method that returns the number of potential extra
1386     empty lines needed before and after the currently processed line.
1387
1388     Note: this tracker works on lines that haven't been split yet.  It assumes
1389     the prefix of the first leaf consists of optional newlines.  Those newlines
1390     are consumed by `maybe_empty_lines()` and included in the computation.
1391     """
1392
1393     is_pyi: bool = False
1394     previous_line: Optional[Line] = None
1395     previous_after: int = 0
1396     previous_defs: List[int] = Factory(list)
1397
1398     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1399         """Return the number of extra empty lines before and after the `current_line`.
1400
1401         This is for separating `def`, `async def` and `class` with extra empty
1402         lines (two on module-level).
1403         """
1404         before, after = self._maybe_empty_lines(current_line)
1405         before -= self.previous_after
1406         self.previous_after = after
1407         self.previous_line = current_line
1408         return before, after
1409
1410     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1411         max_allowed = 1
1412         if current_line.depth == 0:
1413             max_allowed = 1 if self.is_pyi else 2
1414         if current_line.leaves:
1415             # Consume the first leaf's extra newlines.
1416             first_leaf = current_line.leaves[0]
1417             before = first_leaf.prefix.count("\n")
1418             before = min(before, max_allowed)
1419             first_leaf.prefix = ""
1420         else:
1421             before = 0
1422         depth = current_line.depth
1423         while self.previous_defs and self.previous_defs[-1] >= depth:
1424             self.previous_defs.pop()
1425             if self.is_pyi:
1426                 before = 0 if depth else 1
1427             else:
1428                 before = 1 if depth else 2
1429         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1430             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1431
1432         if (
1433             self.previous_line
1434             and self.previous_line.is_import
1435             and not current_line.is_import
1436             and depth == self.previous_line.depth
1437         ):
1438             return (before or 1), 0
1439
1440         if (
1441             self.previous_line
1442             and self.previous_line.is_class
1443             and current_line.is_triple_quoted_string
1444         ):
1445             return before, 1
1446
1447         return before, 0
1448
1449     def _maybe_empty_lines_for_class_or_def(
1450         self, current_line: Line, before: int
1451     ) -> Tuple[int, int]:
1452         if not current_line.is_decorator:
1453             self.previous_defs.append(current_line.depth)
1454         if self.previous_line is None:
1455             # Don't insert empty lines before the first line in the file.
1456             return 0, 0
1457
1458         if self.previous_line.is_decorator:
1459             return 0, 0
1460
1461         if self.previous_line.depth < current_line.depth and (
1462             self.previous_line.is_class or self.previous_line.is_def
1463         ):
1464             return 0, 0
1465
1466         if (
1467             self.previous_line.is_comment
1468             and self.previous_line.depth == current_line.depth
1469             and before == 0
1470         ):
1471             return 0, 0
1472
1473         if self.is_pyi:
1474             if self.previous_line.depth > current_line.depth:
1475                 newlines = 1
1476             elif current_line.is_class or self.previous_line.is_class:
1477                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1478                     # No blank line between classes with an empty body
1479                     newlines = 0
1480                 else:
1481                     newlines = 1
1482             elif current_line.is_def and not self.previous_line.is_def:
1483                 # Blank line between a block of functions and a block of non-functions
1484                 newlines = 1
1485             else:
1486                 newlines = 0
1487         else:
1488             newlines = 2
1489         if current_line.depth and newlines:
1490             newlines -= 1
1491         return newlines, 0
1492
1493
1494 @dataclass
1495 class LineGenerator(Visitor[Line]):
1496     """Generates reformatted Line objects.  Empty lines are not emitted.
1497
1498     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1499     in ways that will no longer stringify to valid Python code on the tree.
1500     """
1501
1502     is_pyi: bool = False
1503     normalize_strings: bool = True
1504     current_line: Line = Factory(Line)
1505     remove_u_prefix: bool = False
1506
1507     def line(self, indent: int = 0) -> Iterator[Line]:
1508         """Generate a line.
1509
1510         If the line is empty, only emit if it makes sense.
1511         If the line is too long, split it first and then generate.
1512
1513         If any lines were generated, set up a new current_line.
1514         """
1515         if not self.current_line:
1516             self.current_line.depth += indent
1517             return  # Line is empty, don't emit. Creating a new one unnecessary.
1518
1519         complete_line = self.current_line
1520         self.current_line = Line(depth=complete_line.depth + indent)
1521         yield complete_line
1522
1523     def visit_default(self, node: LN) -> Iterator[Line]:
1524         """Default `visit_*()` implementation. Recurses to children of `node`."""
1525         if isinstance(node, Leaf):
1526             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1527             for comment in generate_comments(node):
1528                 if any_open_brackets:
1529                     # any comment within brackets is subject to splitting
1530                     self.current_line.append(comment)
1531                 elif comment.type == token.COMMENT:
1532                     # regular trailing comment
1533                     self.current_line.append(comment)
1534                     yield from self.line()
1535
1536                 else:
1537                     # regular standalone comment
1538                     yield from self.line()
1539
1540                     self.current_line.append(comment)
1541                     yield from self.line()
1542
1543             normalize_prefix(node, inside_brackets=any_open_brackets)
1544             if self.normalize_strings and node.type == token.STRING:
1545                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1546                 normalize_string_quotes(node)
1547             if node.type == token.NUMBER:
1548                 normalize_numeric_literal(node)
1549             if node.type not in WHITESPACE:
1550                 self.current_line.append(node)
1551         yield from super().visit_default(node)
1552
1553     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1554         """Increase indentation level, maybe yield a line."""
1555         # In blib2to3 INDENT never holds comments.
1556         yield from self.line(+1)
1557         yield from self.visit_default(node)
1558
1559     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1560         """Decrease indentation level, maybe yield a line."""
1561         # The current line might still wait for trailing comments.  At DEDENT time
1562         # there won't be any (they would be prefixes on the preceding NEWLINE).
1563         # Emit the line then.
1564         yield from self.line()
1565
1566         # While DEDENT has no value, its prefix may contain standalone comments
1567         # that belong to the current indentation level.  Get 'em.
1568         yield from self.visit_default(node)
1569
1570         # Finally, emit the dedent.
1571         yield from self.line(-1)
1572
1573     def visit_stmt(
1574         self, node: Node, keywords: Set[str], parens: Set[str]
1575     ) -> Iterator[Line]:
1576         """Visit a statement.
1577
1578         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1579         `def`, `with`, `class`, `assert` and assignments.
1580
1581         The relevant Python language `keywords` for a given statement will be
1582         NAME leaves within it. This methods puts those on a separate line.
1583
1584         `parens` holds a set of string leaf values immediately after which
1585         invisible parens should be put.
1586         """
1587         normalize_invisible_parens(node, parens_after=parens)
1588         for child in node.children:
1589             if child.type == token.NAME and child.value in keywords:  # type: ignore
1590                 yield from self.line()
1591
1592             yield from self.visit(child)
1593
1594     def visit_suite(self, node: Node) -> Iterator[Line]:
1595         """Visit a suite."""
1596         if self.is_pyi and is_stub_suite(node):
1597             yield from self.visit(node.children[2])
1598         else:
1599             yield from self.visit_default(node)
1600
1601     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1602         """Visit a statement without nested statements."""
1603         is_suite_like = node.parent and node.parent.type in STATEMENT
1604         if is_suite_like:
1605             if self.is_pyi and is_stub_body(node):
1606                 yield from self.visit_default(node)
1607             else:
1608                 yield from self.line(+1)
1609                 yield from self.visit_default(node)
1610                 yield from self.line(-1)
1611
1612         else:
1613             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1614                 yield from self.line()
1615             yield from self.visit_default(node)
1616
1617     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1618         """Visit `async def`, `async for`, `async with`."""
1619         yield from self.line()
1620
1621         children = iter(node.children)
1622         for child in children:
1623             yield from self.visit(child)
1624
1625             if child.type == token.ASYNC:
1626                 break
1627
1628         internal_stmt = next(children)
1629         for child in internal_stmt.children:
1630             yield from self.visit(child)
1631
1632     def visit_decorators(self, node: Node) -> Iterator[Line]:
1633         """Visit decorators."""
1634         for child in node.children:
1635             yield from self.line()
1636             yield from self.visit(child)
1637
1638     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1639         """Remove a semicolon and put the other statement on a separate line."""
1640         yield from self.line()
1641
1642     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1643         """End of file. Process outstanding comments and end with a newline."""
1644         yield from self.visit_default(leaf)
1645         yield from self.line()
1646
1647     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1648         if not self.current_line.bracket_tracker.any_open_brackets():
1649             yield from self.line()
1650         yield from self.visit_default(leaf)
1651
1652     def __attrs_post_init__(self) -> None:
1653         """You are in a twisty little maze of passages."""
1654         v = self.visit_stmt
1655         Ø: Set[str] = set()
1656         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1657         self.visit_if_stmt = partial(
1658             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1659         )
1660         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1661         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1662         self.visit_try_stmt = partial(
1663             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1664         )
1665         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1666         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1667         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1668         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1669         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1670         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1671         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1672         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1673         self.visit_async_funcdef = self.visit_async_stmt
1674         self.visit_decorated = self.visit_decorators
1675
1676
1677 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1678 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1679 OPENING_BRACKETS = set(BRACKET.keys())
1680 CLOSING_BRACKETS = set(BRACKET.values())
1681 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1682 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1683
1684
1685 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1686     """Return whitespace prefix if needed for the given `leaf`.
1687
1688     `complex_subscript` signals whether the given leaf is part of a subscription
1689     which has non-trivial arguments, like arithmetic expressions or function calls.
1690     """
1691     NO = ""
1692     SPACE = " "
1693     DOUBLESPACE = "  "
1694     t = leaf.type
1695     p = leaf.parent
1696     v = leaf.value
1697     if t in ALWAYS_NO_SPACE:
1698         return NO
1699
1700     if t == token.COMMENT:
1701         return DOUBLESPACE
1702
1703     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1704     if t == token.COLON and p.type not in {
1705         syms.subscript,
1706         syms.subscriptlist,
1707         syms.sliceop,
1708     }:
1709         return NO
1710
1711     prev = leaf.prev_sibling
1712     if not prev:
1713         prevp = preceding_leaf(p)
1714         if not prevp or prevp.type in OPENING_BRACKETS:
1715             return NO
1716
1717         if t == token.COLON:
1718             if prevp.type == token.COLON:
1719                 return NO
1720
1721             elif prevp.type != token.COMMA and not complex_subscript:
1722                 return NO
1723
1724             return SPACE
1725
1726         if prevp.type == token.EQUAL:
1727             if prevp.parent:
1728                 if prevp.parent.type in {
1729                     syms.arglist,
1730                     syms.argument,
1731                     syms.parameters,
1732                     syms.varargslist,
1733                 }:
1734                     return NO
1735
1736                 elif prevp.parent.type == syms.typedargslist:
1737                     # A bit hacky: if the equal sign has whitespace, it means we
1738                     # previously found it's a typed argument.  So, we're using
1739                     # that, too.
1740                     return prevp.prefix
1741
1742         elif prevp.type in STARS:
1743             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1744                 return NO
1745
1746         elif prevp.type == token.COLON:
1747             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1748                 return SPACE if complex_subscript else NO
1749
1750         elif (
1751             prevp.parent
1752             and prevp.parent.type == syms.factor
1753             and prevp.type in MATH_OPERATORS
1754         ):
1755             return NO
1756
1757         elif (
1758             prevp.type == token.RIGHTSHIFT
1759             and prevp.parent
1760             and prevp.parent.type == syms.shift_expr
1761             and prevp.prev_sibling
1762             and prevp.prev_sibling.type == token.NAME
1763             and prevp.prev_sibling.value == "print"  # type: ignore
1764         ):
1765             # Python 2 print chevron
1766             return NO
1767
1768     elif prev.type in OPENING_BRACKETS:
1769         return NO
1770
1771     if p.type in {syms.parameters, syms.arglist}:
1772         # untyped function signatures or calls
1773         if not prev or prev.type != token.COMMA:
1774             return NO
1775
1776     elif p.type == syms.varargslist:
1777         # lambdas
1778         if prev and prev.type != token.COMMA:
1779             return NO
1780
1781     elif p.type == syms.typedargslist:
1782         # typed function signatures
1783         if not prev:
1784             return NO
1785
1786         if t == token.EQUAL:
1787             if prev.type != syms.tname:
1788                 return NO
1789
1790         elif prev.type == token.EQUAL:
1791             # A bit hacky: if the equal sign has whitespace, it means we
1792             # previously found it's a typed argument.  So, we're using that, too.
1793             return prev.prefix
1794
1795         elif prev.type != token.COMMA:
1796             return NO
1797
1798     elif p.type == syms.tname:
1799         # type names
1800         if not prev:
1801             prevp = preceding_leaf(p)
1802             if not prevp or prevp.type != token.COMMA:
1803                 return NO
1804
1805     elif p.type == syms.trailer:
1806         # attributes and calls
1807         if t == token.LPAR or t == token.RPAR:
1808             return NO
1809
1810         if not prev:
1811             if t == token.DOT:
1812                 prevp = preceding_leaf(p)
1813                 if not prevp or prevp.type != token.NUMBER:
1814                     return NO
1815
1816             elif t == token.LSQB:
1817                 return NO
1818
1819         elif prev.type != token.COMMA:
1820             return NO
1821
1822     elif p.type == syms.argument:
1823         # single argument
1824         if t == token.EQUAL:
1825             return NO
1826
1827         if not prev:
1828             prevp = preceding_leaf(p)
1829             if not prevp or prevp.type == token.LPAR:
1830                 return NO
1831
1832         elif prev.type in {token.EQUAL} | STARS:
1833             return NO
1834
1835     elif p.type == syms.decorator:
1836         # decorators
1837         return NO
1838
1839     elif p.type == syms.dotted_name:
1840         if prev:
1841             return NO
1842
1843         prevp = preceding_leaf(p)
1844         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1845             return NO
1846
1847     elif p.type == syms.classdef:
1848         if t == token.LPAR:
1849             return NO
1850
1851         if prev and prev.type == token.LPAR:
1852             return NO
1853
1854     elif p.type in {syms.subscript, syms.sliceop}:
1855         # indexing
1856         if not prev:
1857             assert p.parent is not None, "subscripts are always parented"
1858             if p.parent.type == syms.subscriptlist:
1859                 return SPACE
1860
1861             return NO
1862
1863         elif not complex_subscript:
1864             return NO
1865
1866     elif p.type == syms.atom:
1867         if prev and t == token.DOT:
1868             # dots, but not the first one.
1869             return NO
1870
1871     elif p.type == syms.dictsetmaker:
1872         # dict unpacking
1873         if prev and prev.type == token.DOUBLESTAR:
1874             return NO
1875
1876     elif p.type in {syms.factor, syms.star_expr}:
1877         # unary ops
1878         if not prev:
1879             prevp = preceding_leaf(p)
1880             if not prevp or prevp.type in OPENING_BRACKETS:
1881                 return NO
1882
1883             prevp_parent = prevp.parent
1884             assert prevp_parent is not None
1885             if prevp.type == token.COLON and prevp_parent.type in {
1886                 syms.subscript,
1887                 syms.sliceop,
1888             }:
1889                 return NO
1890
1891             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1892                 return NO
1893
1894         elif t in {token.NAME, token.NUMBER, token.STRING}:
1895             return NO
1896
1897     elif p.type == syms.import_from:
1898         if t == token.DOT:
1899             if prev and prev.type == token.DOT:
1900                 return NO
1901
1902         elif t == token.NAME:
1903             if v == "import":
1904                 return SPACE
1905
1906             if prev and prev.type == token.DOT:
1907                 return NO
1908
1909     elif p.type == syms.sliceop:
1910         return NO
1911
1912     return SPACE
1913
1914
1915 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1916     """Return the first leaf that precedes `node`, if any."""
1917     while node:
1918         res = node.prev_sibling
1919         if res:
1920             if isinstance(res, Leaf):
1921                 return res
1922
1923             try:
1924                 return list(res.leaves())[-1]
1925
1926             except IndexError:
1927                 return None
1928
1929         node = node.parent
1930     return None
1931
1932
1933 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1934     """Return the child of `ancestor` that contains `descendant`."""
1935     node: Optional[LN] = descendant
1936     while node and node.parent != ancestor:
1937         node = node.parent
1938     return node
1939
1940
1941 def container_of(leaf: Leaf) -> LN:
1942     """Return `leaf` or one of its ancestors that is the topmost container of it.
1943
1944     By "container" we mean a node where `leaf` is the very first child.
1945     """
1946     same_prefix = leaf.prefix
1947     container: LN = leaf
1948     while container:
1949         parent = container.parent
1950         if parent is None:
1951             break
1952
1953         if parent.children[0].prefix != same_prefix:
1954             break
1955
1956         if parent.type == syms.file_input:
1957             break
1958
1959         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1960             break
1961
1962         container = parent
1963     return container
1964
1965
1966 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1967     """Return the priority of the `leaf` delimiter, given a line break after it.
1968
1969     The delimiter priorities returned here are from those delimiters that would
1970     cause a line break after themselves.
1971
1972     Higher numbers are higher priority.
1973     """
1974     if leaf.type == token.COMMA:
1975         return COMMA_PRIORITY
1976
1977     return 0
1978
1979
1980 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1981     """Return the priority of the `leaf` delimiter, given a line break before it.
1982
1983     The delimiter priorities returned here are from those delimiters that would
1984     cause a line break before themselves.
1985
1986     Higher numbers are higher priority.
1987     """
1988     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1989         # * and ** might also be MATH_OPERATORS but in this case they are not.
1990         # Don't treat them as a delimiter.
1991         return 0
1992
1993     if (
1994         leaf.type == token.DOT
1995         and leaf.parent
1996         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1997         and (previous is None or previous.type in CLOSING_BRACKETS)
1998     ):
1999         return DOT_PRIORITY
2000
2001     if (
2002         leaf.type in MATH_OPERATORS
2003         and leaf.parent
2004         and leaf.parent.type not in {syms.factor, syms.star_expr}
2005     ):
2006         return MATH_PRIORITIES[leaf.type]
2007
2008     if leaf.type in COMPARATORS:
2009         return COMPARATOR_PRIORITY
2010
2011     if (
2012         leaf.type == token.STRING
2013         and previous is not None
2014         and previous.type == token.STRING
2015     ):
2016         return STRING_PRIORITY
2017
2018     if leaf.type not in {token.NAME, token.ASYNC}:
2019         return 0
2020
2021     if (
2022         leaf.value == "for"
2023         and leaf.parent
2024         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2025         or leaf.type == token.ASYNC
2026     ):
2027         if (
2028             not isinstance(leaf.prev_sibling, Leaf)
2029             or leaf.prev_sibling.value != "async"
2030         ):
2031             return COMPREHENSION_PRIORITY
2032
2033     if (
2034         leaf.value == "if"
2035         and leaf.parent
2036         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2037     ):
2038         return COMPREHENSION_PRIORITY
2039
2040     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2041         return TERNARY_PRIORITY
2042
2043     if leaf.value == "is":
2044         return COMPARATOR_PRIORITY
2045
2046     if (
2047         leaf.value == "in"
2048         and leaf.parent
2049         and leaf.parent.type in {syms.comp_op, syms.comparison}
2050         and not (
2051             previous is not None
2052             and previous.type == token.NAME
2053             and previous.value == "not"
2054         )
2055     ):
2056         return COMPARATOR_PRIORITY
2057
2058     if (
2059         leaf.value == "not"
2060         and leaf.parent
2061         and leaf.parent.type == syms.comp_op
2062         and not (
2063             previous is not None
2064             and previous.type == token.NAME
2065             and previous.value == "is"
2066         )
2067     ):
2068         return COMPARATOR_PRIORITY
2069
2070     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2071         return LOGIC_PRIORITY
2072
2073     return 0
2074
2075
2076 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2077 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2078
2079
2080 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2081     """Clean the prefix of the `leaf` and generate comments from it, if any.
2082
2083     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2084     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2085     move because it does away with modifying the grammar to include all the
2086     possible places in which comments can be placed.
2087
2088     The sad consequence for us though is that comments don't "belong" anywhere.
2089     This is why this function generates simple parentless Leaf objects for
2090     comments.  We simply don't know what the correct parent should be.
2091
2092     No matter though, we can live without this.  We really only need to
2093     differentiate between inline and standalone comments.  The latter don't
2094     share the line with any code.
2095
2096     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2097     are emitted with a fake STANDALONE_COMMENT token identifier.
2098     """
2099     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2100         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2101
2102
2103 @dataclass
2104 class ProtoComment:
2105     """Describes a piece of syntax that is a comment.
2106
2107     It's not a :class:`blib2to3.pytree.Leaf` so that:
2108
2109     * it can be cached (`Leaf` objects should not be reused more than once as
2110       they store their lineno, column, prefix, and parent information);
2111     * `newlines` and `consumed` fields are kept separate from the `value`. This
2112       simplifies handling of special marker comments like ``# fmt: off/on``.
2113     """
2114
2115     type: int  # token.COMMENT or STANDALONE_COMMENT
2116     value: str  # content of the comment
2117     newlines: int  # how many newlines before the comment
2118     consumed: int  # how many characters of the original leaf's prefix did we consume
2119
2120
2121 @lru_cache(maxsize=4096)
2122 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2123     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2124     result: List[ProtoComment] = []
2125     if not prefix or "#" not in prefix:
2126         return result
2127
2128     consumed = 0
2129     nlines = 0
2130     for index, line in enumerate(prefix.split("\n")):
2131         consumed += len(line) + 1  # adding the length of the split '\n'
2132         line = line.lstrip()
2133         if not line:
2134             nlines += 1
2135         if not line.startswith("#"):
2136             continue
2137
2138         if index == 0 and not is_endmarker:
2139             comment_type = token.COMMENT  # simple trailing comment
2140         else:
2141             comment_type = STANDALONE_COMMENT
2142         comment = make_comment(line)
2143         result.append(
2144             ProtoComment(
2145                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2146             )
2147         )
2148         nlines = 0
2149     return result
2150
2151
2152 def make_comment(content: str) -> str:
2153     """Return a consistently formatted comment from the given `content` string.
2154
2155     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2156     space between the hash sign and the content.
2157
2158     If `content` didn't start with a hash sign, one is provided.
2159     """
2160     content = content.rstrip()
2161     if not content:
2162         return "#"
2163
2164     if content[0] == "#":
2165         content = content[1:]
2166     if content and content[0] not in " !:#'%":
2167         content = " " + content
2168     return "#" + content
2169
2170
2171 def split_line(
2172     line: Line,
2173     line_length: int,
2174     inner: bool = False,
2175     features: Collection[Feature] = (),
2176 ) -> Iterator[Line]:
2177     """Split a `line` into potentially many lines.
2178
2179     They should fit in the allotted `line_length` but might not be able to.
2180     `inner` signifies that there were a pair of brackets somewhere around the
2181     current `line`, possibly transitively. This means we can fallback to splitting
2182     by delimiters if the LHS/RHS don't yield any results.
2183
2184     `features` are syntactical features that may be used in the output.
2185     """
2186     if line.is_comment:
2187         yield line
2188         return
2189
2190     line_str = str(line).strip("\n")
2191
2192     if (
2193         not line.contains_inner_type_comments()
2194         and not line.should_explode
2195         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2196     ):
2197         yield line
2198         return
2199
2200     split_funcs: List[SplitFunc]
2201     if line.is_def:
2202         split_funcs = [left_hand_split]
2203     else:
2204
2205         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2206             for omit in generate_trailers_to_omit(line, line_length):
2207                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2208                 if is_line_short_enough(lines[0], line_length=line_length):
2209                     yield from lines
2210                     return
2211
2212             # All splits failed, best effort split with no omits.
2213             # This mostly happens to multiline strings that are by definition
2214             # reported as not fitting a single line.
2215             yield from right_hand_split(line, line_length, features=features)
2216
2217         if line.inside_brackets:
2218             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2219         else:
2220             split_funcs = [rhs]
2221     for split_func in split_funcs:
2222         # We are accumulating lines in `result` because we might want to abort
2223         # mission and return the original line in the end, or attempt a different
2224         # split altogether.
2225         result: List[Line] = []
2226         try:
2227             for l in split_func(line, features):
2228                 if str(l).strip("\n") == line_str:
2229                     raise CannotSplit("Split function returned an unchanged result")
2230
2231                 result.extend(
2232                     split_line(
2233                         l, line_length=line_length, inner=True, features=features
2234                     )
2235                 )
2236         except CannotSplit:
2237             continue
2238
2239         else:
2240             yield from result
2241             break
2242
2243     else:
2244         yield line
2245
2246
2247 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2248     """Split line into many lines, starting with the first matching bracket pair.
2249
2250     Note: this usually looks weird, only use this for function definitions.
2251     Prefer RHS otherwise.  This is why this function is not symmetrical with
2252     :func:`right_hand_split` which also handles optional parentheses.
2253     """
2254     tail_leaves: List[Leaf] = []
2255     body_leaves: List[Leaf] = []
2256     head_leaves: List[Leaf] = []
2257     current_leaves = head_leaves
2258     matching_bracket = None
2259     for leaf in line.leaves:
2260         if (
2261             current_leaves is body_leaves
2262             and leaf.type in CLOSING_BRACKETS
2263             and leaf.opening_bracket is matching_bracket
2264         ):
2265             current_leaves = tail_leaves if body_leaves else head_leaves
2266         current_leaves.append(leaf)
2267         if current_leaves is head_leaves:
2268             if leaf.type in OPENING_BRACKETS:
2269                 matching_bracket = leaf
2270                 current_leaves = body_leaves
2271     if not matching_bracket:
2272         raise CannotSplit("No brackets found")
2273
2274     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2275     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2276     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2277     bracket_split_succeeded_or_raise(head, body, tail)
2278     for result in (head, body, tail):
2279         if result:
2280             yield result
2281
2282
2283 def right_hand_split(
2284     line: Line,
2285     line_length: int,
2286     features: Collection[Feature] = (),
2287     omit: Collection[LeafID] = (),
2288 ) -> Iterator[Line]:
2289     """Split line into many lines, starting with the last matching bracket pair.
2290
2291     If the split was by optional parentheses, attempt splitting without them, too.
2292     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2293     this split.
2294
2295     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2296     """
2297     tail_leaves: List[Leaf] = []
2298     body_leaves: List[Leaf] = []
2299     head_leaves: List[Leaf] = []
2300     current_leaves = tail_leaves
2301     opening_bracket = None
2302     closing_bracket = None
2303     for leaf in reversed(line.leaves):
2304         if current_leaves is body_leaves:
2305             if leaf is opening_bracket:
2306                 current_leaves = head_leaves if body_leaves else tail_leaves
2307         current_leaves.append(leaf)
2308         if current_leaves is tail_leaves:
2309             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2310                 opening_bracket = leaf.opening_bracket
2311                 closing_bracket = leaf
2312                 current_leaves = body_leaves
2313     if not (opening_bracket and closing_bracket and head_leaves):
2314         # If there is no opening or closing_bracket that means the split failed and
2315         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2316         # the matching `opening_bracket` wasn't available on `line` anymore.
2317         raise CannotSplit("No brackets found")
2318
2319     tail_leaves.reverse()
2320     body_leaves.reverse()
2321     head_leaves.reverse()
2322     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2323     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2324     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2325     bracket_split_succeeded_or_raise(head, body, tail)
2326     if (
2327         # the body shouldn't be exploded
2328         not body.should_explode
2329         # the opening bracket is an optional paren
2330         and opening_bracket.type == token.LPAR
2331         and not opening_bracket.value
2332         # the closing bracket is an optional paren
2333         and closing_bracket.type == token.RPAR
2334         and not closing_bracket.value
2335         # it's not an import (optional parens are the only thing we can split on
2336         # in this case; attempting a split without them is a waste of time)
2337         and not line.is_import
2338         # there are no standalone comments in the body
2339         and not body.contains_standalone_comments(0)
2340         # and we can actually remove the parens
2341         and can_omit_invisible_parens(body, line_length)
2342     ):
2343         omit = {id(closing_bracket), *omit}
2344         try:
2345             yield from right_hand_split(line, line_length, features=features, omit=omit)
2346             return
2347
2348         except CannotSplit:
2349             if not (
2350                 can_be_split(body)
2351                 or is_line_short_enough(body, line_length=line_length)
2352             ):
2353                 raise CannotSplit(
2354                     "Splitting failed, body is still too long and can't be split."
2355                 )
2356
2357             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2358                 raise CannotSplit(
2359                     "The current optional pair of parentheses is bound to fail to "
2360                     "satisfy the splitting algorithm because the head or the tail "
2361                     "contains multiline strings which by definition never fit one "
2362                     "line."
2363                 )
2364
2365     ensure_visible(opening_bracket)
2366     ensure_visible(closing_bracket)
2367     for result in (head, body, tail):
2368         if result:
2369             yield result
2370
2371
2372 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2373     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2374
2375     Do nothing otherwise.
2376
2377     A left- or right-hand split is based on a pair of brackets. Content before
2378     (and including) the opening bracket is left on one line, content inside the
2379     brackets is put on a separate line, and finally content starting with and
2380     following the closing bracket is put on a separate line.
2381
2382     Those are called `head`, `body`, and `tail`, respectively. If the split
2383     produced the same line (all content in `head`) or ended up with an empty `body`
2384     and the `tail` is just the closing bracket, then it's considered failed.
2385     """
2386     tail_len = len(str(tail).strip())
2387     if not body:
2388         if tail_len == 0:
2389             raise CannotSplit("Splitting brackets produced the same line")
2390
2391         elif tail_len < 3:
2392             raise CannotSplit(
2393                 f"Splitting brackets on an empty body to save "
2394                 f"{tail_len} characters is not worth it"
2395             )
2396
2397
2398 def bracket_split_build_line(
2399     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2400 ) -> Line:
2401     """Return a new line with given `leaves` and respective comments from `original`.
2402
2403     If `is_body` is True, the result line is one-indented inside brackets and as such
2404     has its first leaf's prefix normalized and a trailing comma added when expected.
2405     """
2406     result = Line(depth=original.depth)
2407     if is_body:
2408         result.inside_brackets = True
2409         result.depth += 1
2410         if leaves:
2411             # Since body is a new indent level, remove spurious leading whitespace.
2412             normalize_prefix(leaves[0], inside_brackets=True)
2413             # Ensure a trailing comma for imports, but be careful not to add one after
2414             # any comments.
2415             if original.is_import:
2416                 for i in range(len(leaves) - 1, -1, -1):
2417                     if leaves[i].type == STANDALONE_COMMENT:
2418                         continue
2419                     elif leaves[i].type == token.COMMA:
2420                         break
2421                     else:
2422                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2423                         break
2424     # Populate the line
2425     for leaf in leaves:
2426         result.append(leaf, preformatted=True)
2427         for comment_after in original.comments_after(leaf):
2428             result.append(comment_after, preformatted=True)
2429     if is_body:
2430         result.should_explode = should_explode(result, opening_bracket)
2431     return result
2432
2433
2434 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2435     """Normalize prefix of the first leaf in every line returned by `split_func`.
2436
2437     This is a decorator over relevant split functions.
2438     """
2439
2440     @wraps(split_func)
2441     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2442         for l in split_func(line, features):
2443             normalize_prefix(l.leaves[0], inside_brackets=True)
2444             yield l
2445
2446     return split_wrapper
2447
2448
2449 @dont_increase_indentation
2450 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2451     """Split according to delimiters of the highest priority.
2452
2453     If the appropriate Features are given, the split will add trailing commas
2454     also in function signatures and calls that contain `*` and `**`.
2455     """
2456     try:
2457         last_leaf = line.leaves[-1]
2458     except IndexError:
2459         raise CannotSplit("Line empty")
2460
2461     bt = line.bracket_tracker
2462     try:
2463         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2464     except ValueError:
2465         raise CannotSplit("No delimiters found")
2466
2467     if delimiter_priority == DOT_PRIORITY:
2468         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2469             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2470
2471     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2472     lowest_depth = sys.maxsize
2473     trailing_comma_safe = True
2474
2475     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2476         """Append `leaf` to current line or to new line if appending impossible."""
2477         nonlocal current_line
2478         try:
2479             current_line.append_safe(leaf, preformatted=True)
2480         except ValueError:
2481             yield current_line
2482
2483             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2484             current_line.append(leaf)
2485
2486     for leaf in line.leaves:
2487         yield from append_to_line(leaf)
2488
2489         for comment_after in line.comments_after(leaf):
2490             yield from append_to_line(comment_after)
2491
2492         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2493         if leaf.bracket_depth == lowest_depth:
2494             if is_vararg(leaf, within={syms.typedargslist}):
2495                 trailing_comma_safe = (
2496                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2497                 )
2498             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2499                 trailing_comma_safe = (
2500                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2501                 )
2502
2503         leaf_priority = bt.delimiters.get(id(leaf))
2504         if leaf_priority == delimiter_priority:
2505             yield current_line
2506
2507             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2508     if current_line:
2509         if (
2510             trailing_comma_safe
2511             and delimiter_priority == COMMA_PRIORITY
2512             and current_line.leaves[-1].type != token.COMMA
2513             and current_line.leaves[-1].type != STANDALONE_COMMENT
2514         ):
2515             current_line.append(Leaf(token.COMMA, ","))
2516         yield current_line
2517
2518
2519 @dont_increase_indentation
2520 def standalone_comment_split(
2521     line: Line, features: Collection[Feature] = ()
2522 ) -> Iterator[Line]:
2523     """Split standalone comments from the rest of the line."""
2524     if not line.contains_standalone_comments(0):
2525         raise CannotSplit("Line does not have any standalone comments")
2526
2527     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2528
2529     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2530         """Append `leaf` to current line or to new line if appending impossible."""
2531         nonlocal current_line
2532         try:
2533             current_line.append_safe(leaf, preformatted=True)
2534         except ValueError:
2535             yield current_line
2536
2537             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2538             current_line.append(leaf)
2539
2540     for leaf in line.leaves:
2541         yield from append_to_line(leaf)
2542
2543         for comment_after in line.comments_after(leaf):
2544             yield from append_to_line(comment_after)
2545
2546     if current_line:
2547         yield current_line
2548
2549
2550 def is_import(leaf: Leaf) -> bool:
2551     """Return True if the given leaf starts an import statement."""
2552     p = leaf.parent
2553     t = leaf.type
2554     v = leaf.value
2555     return bool(
2556         t == token.NAME
2557         and (
2558             (v == "import" and p and p.type == syms.import_name)
2559             or (v == "from" and p and p.type == syms.import_from)
2560         )
2561     )
2562
2563
2564 def is_type_comment(leaf: Leaf) -> bool:
2565     """Return True if the given leaf is a special comment.
2566     Only returns true for type comments for now."""
2567     t = leaf.type
2568     v = leaf.value
2569     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2570
2571
2572 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2573     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2574     else.
2575
2576     Note: don't use backslashes for formatting or you'll lose your voting rights.
2577     """
2578     if not inside_brackets:
2579         spl = leaf.prefix.split("#")
2580         if "\\" not in spl[0]:
2581             nl_count = spl[-1].count("\n")
2582             if len(spl) > 1:
2583                 nl_count -= 1
2584             leaf.prefix = "\n" * nl_count
2585             return
2586
2587     leaf.prefix = ""
2588
2589
2590 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2591     """Make all string prefixes lowercase.
2592
2593     If remove_u_prefix is given, also removes any u prefix from the string.
2594
2595     Note: Mutates its argument.
2596     """
2597     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2598     assert match is not None, f"failed to match string {leaf.value!r}"
2599     orig_prefix = match.group(1)
2600     new_prefix = orig_prefix.lower()
2601     if remove_u_prefix:
2602         new_prefix = new_prefix.replace("u", "")
2603     leaf.value = f"{new_prefix}{match.group(2)}"
2604
2605
2606 def normalize_string_quotes(leaf: Leaf) -> None:
2607     """Prefer double quotes but only if it doesn't cause more escaping.
2608
2609     Adds or removes backslashes as appropriate. Doesn't parse and fix
2610     strings nested in f-strings (yet).
2611
2612     Note: Mutates its argument.
2613     """
2614     value = leaf.value.lstrip("furbFURB")
2615     if value[:3] == '"""':
2616         return
2617
2618     elif value[:3] == "'''":
2619         orig_quote = "'''"
2620         new_quote = '"""'
2621     elif value[0] == '"':
2622         orig_quote = '"'
2623         new_quote = "'"
2624     else:
2625         orig_quote = "'"
2626         new_quote = '"'
2627     first_quote_pos = leaf.value.find(orig_quote)
2628     if first_quote_pos == -1:
2629         return  # There's an internal error
2630
2631     prefix = leaf.value[:first_quote_pos]
2632     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2633     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2634     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2635     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2636     if "r" in prefix.casefold():
2637         if unescaped_new_quote.search(body):
2638             # There's at least one unescaped new_quote in this raw string
2639             # so converting is impossible
2640             return
2641
2642         # Do not introduce or remove backslashes in raw strings
2643         new_body = body
2644     else:
2645         # remove unnecessary escapes
2646         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2647         if body != new_body:
2648             # Consider the string without unnecessary escapes as the original
2649             body = new_body
2650             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2651         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2652         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2653     if "f" in prefix.casefold():
2654         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2655         for m in matches:
2656             if "\\" in str(m):
2657                 # Do not introduce backslashes in interpolated expressions
2658                 return
2659     if new_quote == '"""' and new_body[-1:] == '"':
2660         # edge case:
2661         new_body = new_body[:-1] + '\\"'
2662     orig_escape_count = body.count("\\")
2663     new_escape_count = new_body.count("\\")
2664     if new_escape_count > orig_escape_count:
2665         return  # Do not introduce more escaping
2666
2667     if new_escape_count == orig_escape_count and orig_quote == '"':
2668         return  # Prefer double quotes
2669
2670     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2671
2672
2673 def normalize_numeric_literal(leaf: Leaf) -> None:
2674     """Normalizes numeric (float, int, and complex) literals.
2675
2676     All letters used in the representation are normalized to lowercase (except
2677     in Python 2 long literals).
2678     """
2679     text = leaf.value.lower()
2680     if text.startswith(("0o", "0b")):
2681         # Leave octal and binary literals alone.
2682         pass
2683     elif text.startswith("0x"):
2684         # Change hex literals to upper case.
2685         before, after = text[:2], text[2:]
2686         text = f"{before}{after.upper()}"
2687     elif "e" in text:
2688         before, after = text.split("e")
2689         sign = ""
2690         if after.startswith("-"):
2691             after = after[1:]
2692             sign = "-"
2693         elif after.startswith("+"):
2694             after = after[1:]
2695         before = format_float_or_int_string(before)
2696         text = f"{before}e{sign}{after}"
2697     elif text.endswith(("j", "l")):
2698         number = text[:-1]
2699         suffix = text[-1]
2700         # Capitalize in "2L" because "l" looks too similar to "1".
2701         if suffix == "l":
2702             suffix = "L"
2703         text = f"{format_float_or_int_string(number)}{suffix}"
2704     else:
2705         text = format_float_or_int_string(text)
2706     leaf.value = text
2707
2708
2709 def format_float_or_int_string(text: str) -> str:
2710     """Formats a float string like "1.0"."""
2711     if "." not in text:
2712         return text
2713
2714     before, after = text.split(".")
2715     return f"{before or 0}.{after or 0}"
2716
2717
2718 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2719     """Make existing optional parentheses invisible or create new ones.
2720
2721     `parens_after` is a set of string leaf values immeditely after which parens
2722     should be put.
2723
2724     Standardizes on visible parentheses for single-element tuples, and keeps
2725     existing visible parentheses for other tuples and generator expressions.
2726     """
2727     for pc in list_comments(node.prefix, is_endmarker=False):
2728         if pc.value in FMT_OFF:
2729             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2730             return
2731
2732     check_lpar = False
2733     for index, child in enumerate(list(node.children)):
2734         # Add parentheses around long tuple unpacking in assignments.
2735         if (
2736             index == 0
2737             and isinstance(child, Node)
2738             and child.type == syms.testlist_star_expr
2739         ):
2740             check_lpar = True
2741
2742         if check_lpar:
2743             if child.type == syms.atom:
2744                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2745                     lpar = Leaf(token.LPAR, "")
2746                     rpar = Leaf(token.RPAR, "")
2747                     index = child.remove() or 0
2748                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2749             elif is_one_tuple(child):
2750                 # wrap child in visible parentheses
2751                 lpar = Leaf(token.LPAR, "(")
2752                 rpar = Leaf(token.RPAR, ")")
2753                 child.remove()
2754                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2755             elif node.type == syms.import_from:
2756                 # "import from" nodes store parentheses directly as part of
2757                 # the statement
2758                 if child.type == token.LPAR:
2759                     # make parentheses invisible
2760                     child.value = ""  # type: ignore
2761                     node.children[-1].value = ""  # type: ignore
2762                 elif child.type != token.STAR:
2763                     # insert invisible parentheses
2764                     node.insert_child(index, Leaf(token.LPAR, ""))
2765                     node.append_child(Leaf(token.RPAR, ""))
2766                 break
2767
2768             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2769                 # wrap child in invisible parentheses
2770                 lpar = Leaf(token.LPAR, "")
2771                 rpar = Leaf(token.RPAR, "")
2772                 index = child.remove() or 0
2773                 prefix = child.prefix
2774                 child.prefix = ""
2775                 new_child = Node(syms.atom, [lpar, child, rpar])
2776                 new_child.prefix = prefix
2777                 node.insert_child(index, new_child)
2778
2779         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2780
2781
2782 def normalize_fmt_off(node: Node) -> None:
2783     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2784     try_again = True
2785     while try_again:
2786         try_again = convert_one_fmt_off_pair(node)
2787
2788
2789 def convert_one_fmt_off_pair(node: Node) -> bool:
2790     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2791
2792     Returns True if a pair was converted.
2793     """
2794     for leaf in node.leaves():
2795         previous_consumed = 0
2796         for comment in list_comments(leaf.prefix, is_endmarker=False):
2797             if comment.value in FMT_OFF:
2798                 # We only want standalone comments. If there's no previous leaf or
2799                 # the previous leaf is indentation, it's a standalone comment in
2800                 # disguise.
2801                 if comment.type != STANDALONE_COMMENT:
2802                     prev = preceding_leaf(leaf)
2803                     if prev and prev.type not in WHITESPACE:
2804                         continue
2805
2806                 ignored_nodes = list(generate_ignored_nodes(leaf))
2807                 if not ignored_nodes:
2808                     continue
2809
2810                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2811                 parent = first.parent
2812                 prefix = first.prefix
2813                 first.prefix = prefix[comment.consumed :]
2814                 hidden_value = (
2815                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2816                 )
2817                 if hidden_value.endswith("\n"):
2818                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2819                     # leaf (possibly followed by a DEDENT).
2820                     hidden_value = hidden_value[:-1]
2821                 first_idx = None
2822                 for ignored in ignored_nodes:
2823                     index = ignored.remove()
2824                     if first_idx is None:
2825                         first_idx = index
2826                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2827                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2828                 parent.insert_child(
2829                     first_idx,
2830                     Leaf(
2831                         STANDALONE_COMMENT,
2832                         hidden_value,
2833                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2834                     ),
2835                 )
2836                 return True
2837
2838             previous_consumed = comment.consumed
2839
2840     return False
2841
2842
2843 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2844     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2845
2846     Stops at the end of the block.
2847     """
2848     container: Optional[LN] = container_of(leaf)
2849     while container is not None and container.type != token.ENDMARKER:
2850         for comment in list_comments(container.prefix, is_endmarker=False):
2851             if comment.value in FMT_ON:
2852                 return
2853
2854         yield container
2855
2856         container = container.next_sibling
2857
2858
2859 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2860     """If it's safe, make the parens in the atom `node` invisible, recursively.
2861
2862     Returns whether the node should itself be wrapped in invisible parentheses.
2863
2864     """
2865     if (
2866         node.type != syms.atom
2867         or is_empty_tuple(node)
2868         or is_one_tuple(node)
2869         or (is_yield(node) and parent.type != syms.expr_stmt)
2870         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2871     ):
2872         return False
2873
2874     first = node.children[0]
2875     last = node.children[-1]
2876     if first.type == token.LPAR and last.type == token.RPAR:
2877         # make parentheses invisible
2878         first.value = ""  # type: ignore
2879         last.value = ""  # type: ignore
2880         if len(node.children) > 1:
2881             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2882         return False
2883
2884     return True
2885
2886
2887 def is_empty_tuple(node: LN) -> bool:
2888     """Return True if `node` holds an empty tuple."""
2889     return (
2890         node.type == syms.atom
2891         and len(node.children) == 2
2892         and node.children[0].type == token.LPAR
2893         and node.children[1].type == token.RPAR
2894     )
2895
2896
2897 def is_one_tuple(node: LN) -> bool:
2898     """Return True if `node` holds a tuple with one element, with or without parens."""
2899     if node.type == syms.atom:
2900         if len(node.children) != 3:
2901             return False
2902
2903         lpar, gexp, rpar = node.children
2904         if not (
2905             lpar.type == token.LPAR
2906             and gexp.type == syms.testlist_gexp
2907             and rpar.type == token.RPAR
2908         ):
2909             return False
2910
2911         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2912
2913     return (
2914         node.type in IMPLICIT_TUPLE
2915         and len(node.children) == 2
2916         and node.children[1].type == token.COMMA
2917     )
2918
2919
2920 def is_yield(node: LN) -> bool:
2921     """Return True if `node` holds a `yield` or `yield from` expression."""
2922     if node.type == syms.yield_expr:
2923         return True
2924
2925     if node.type == token.NAME and node.value == "yield":  # type: ignore
2926         return True
2927
2928     if node.type != syms.atom:
2929         return False
2930
2931     if len(node.children) != 3:
2932         return False
2933
2934     lpar, expr, rpar = node.children
2935     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2936         return is_yield(expr)
2937
2938     return False
2939
2940
2941 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2942     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2943
2944     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2945     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2946     extended iterable unpacking (PEP 3132) and additional unpacking
2947     generalizations (PEP 448).
2948     """
2949     if leaf.type not in STARS or not leaf.parent:
2950         return False
2951
2952     p = leaf.parent
2953     if p.type == syms.star_expr:
2954         # Star expressions are also used as assignment targets in extended
2955         # iterable unpacking (PEP 3132).  See what its parent is instead.
2956         if not p.parent:
2957             return False
2958
2959         p = p.parent
2960
2961     return p.type in within
2962
2963
2964 def is_multiline_string(leaf: Leaf) -> bool:
2965     """Return True if `leaf` is a multiline string that actually spans many lines."""
2966     value = leaf.value.lstrip("furbFURB")
2967     return value[:3] in {'"""', "'''"} and "\n" in value
2968
2969
2970 def is_stub_suite(node: Node) -> bool:
2971     """Return True if `node` is a suite with a stub body."""
2972     if (
2973         len(node.children) != 4
2974         or node.children[0].type != token.NEWLINE
2975         or node.children[1].type != token.INDENT
2976         or node.children[3].type != token.DEDENT
2977     ):
2978         return False
2979
2980     return is_stub_body(node.children[2])
2981
2982
2983 def is_stub_body(node: LN) -> bool:
2984     """Return True if `node` is a simple statement containing an ellipsis."""
2985     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2986         return False
2987
2988     if len(node.children) != 2:
2989         return False
2990
2991     child = node.children[0]
2992     return (
2993         child.type == syms.atom
2994         and len(child.children) == 3
2995         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2996     )
2997
2998
2999 def max_delimiter_priority_in_atom(node: LN) -> int:
3000     """Return maximum delimiter priority inside `node`.
3001
3002     This is specific to atoms with contents contained in a pair of parentheses.
3003     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3004     """
3005     if node.type != syms.atom:
3006         return 0
3007
3008     first = node.children[0]
3009     last = node.children[-1]
3010     if not (first.type == token.LPAR and last.type == token.RPAR):
3011         return 0
3012
3013     bt = BracketTracker()
3014     for c in node.children[1:-1]:
3015         if isinstance(c, Leaf):
3016             bt.mark(c)
3017         else:
3018             for leaf in c.leaves():
3019                 bt.mark(leaf)
3020     try:
3021         return bt.max_delimiter_priority()
3022
3023     except ValueError:
3024         return 0
3025
3026
3027 def ensure_visible(leaf: Leaf) -> None:
3028     """Make sure parentheses are visible.
3029
3030     They could be invisible as part of some statements (see
3031     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3032     """
3033     if leaf.type == token.LPAR:
3034         leaf.value = "("
3035     elif leaf.type == token.RPAR:
3036         leaf.value = ")"
3037
3038
3039 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3040     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3041
3042     if not (
3043         opening_bracket.parent
3044         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3045         and opening_bracket.value in "[{("
3046     ):
3047         return False
3048
3049     try:
3050         last_leaf = line.leaves[-1]
3051         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3052         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3053     except (IndexError, ValueError):
3054         return False
3055
3056     return max_priority == COMMA_PRIORITY
3057
3058
3059 def get_features_used(node: Node) -> Set[Feature]:
3060     """Return a set of (relatively) new Python features used in this file.
3061
3062     Currently looking for:
3063     - f-strings;
3064     - underscores in numeric literals; and
3065     - trailing commas after * or ** in function signatures and calls.
3066     """
3067     features: Set[Feature] = set()
3068     for n in node.pre_order():
3069         if n.type == token.STRING:
3070             value_head = n.value[:2]  # type: ignore
3071             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3072                 features.add(Feature.F_STRINGS)
3073
3074         elif n.type == token.NUMBER:
3075             if "_" in n.value:  # type: ignore
3076                 features.add(Feature.NUMERIC_UNDERSCORES)
3077
3078         elif (
3079             n.type in {syms.typedargslist, syms.arglist}
3080             and n.children
3081             and n.children[-1].type == token.COMMA
3082         ):
3083             if n.type == syms.typedargslist:
3084                 feature = Feature.TRAILING_COMMA_IN_DEF
3085             else:
3086                 feature = Feature.TRAILING_COMMA_IN_CALL
3087
3088             for ch in n.children:
3089                 if ch.type in STARS:
3090                     features.add(feature)
3091
3092                 if ch.type == syms.argument:
3093                     for argch in ch.children:
3094                         if argch.type in STARS:
3095                             features.add(feature)
3096
3097     return features
3098
3099
3100 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3101     """Detect the version to target based on the nodes used."""
3102     features = get_features_used(node)
3103     return {
3104         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3105     }
3106
3107
3108 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3109     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3110
3111     Brackets can be omitted if the entire trailer up to and including
3112     a preceding closing bracket fits in one line.
3113
3114     Yielded sets are cumulative (contain results of previous yields, too).  First
3115     set is empty.
3116     """
3117
3118     omit: Set[LeafID] = set()
3119     yield omit
3120
3121     length = 4 * line.depth
3122     opening_bracket = None
3123     closing_bracket = None
3124     inner_brackets: Set[LeafID] = set()
3125     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3126         length += leaf_length
3127         if length > line_length:
3128             break
3129
3130         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3131         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3132             break
3133
3134         if opening_bracket:
3135             if leaf is opening_bracket:
3136                 opening_bracket = None
3137             elif leaf.type in CLOSING_BRACKETS:
3138                 inner_brackets.add(id(leaf))
3139         elif leaf.type in CLOSING_BRACKETS:
3140             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3141                 # Empty brackets would fail a split so treat them as "inner"
3142                 # brackets (e.g. only add them to the `omit` set if another
3143                 # pair of brackets was good enough.
3144                 inner_brackets.add(id(leaf))
3145                 continue
3146
3147             if closing_bracket:
3148                 omit.add(id(closing_bracket))
3149                 omit.update(inner_brackets)
3150                 inner_brackets.clear()
3151                 yield omit
3152
3153             if leaf.value:
3154                 opening_bracket = leaf.opening_bracket
3155                 closing_bracket = leaf
3156
3157
3158 def get_future_imports(node: Node) -> Set[str]:
3159     """Return a set of __future__ imports in the file."""
3160     imports: Set[str] = set()
3161
3162     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3163         for child in children:
3164             if isinstance(child, Leaf):
3165                 if child.type == token.NAME:
3166                     yield child.value
3167             elif child.type == syms.import_as_name:
3168                 orig_name = child.children[0]
3169                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3170                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3171                 yield orig_name.value
3172             elif child.type == syms.import_as_names:
3173                 yield from get_imports_from_children(child.children)
3174             else:
3175                 raise AssertionError("Invalid syntax parsing imports")
3176
3177     for child in node.children:
3178         if child.type != syms.simple_stmt:
3179             break
3180         first_child = child.children[0]
3181         if isinstance(first_child, Leaf):
3182             # Continue looking if we see a docstring; otherwise stop.
3183             if (
3184                 len(child.children) == 2
3185                 and first_child.type == token.STRING
3186                 and child.children[1].type == token.NEWLINE
3187             ):
3188                 continue
3189             else:
3190                 break
3191         elif first_child.type == syms.import_from:
3192             module_name = first_child.children[1]
3193             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3194                 break
3195             imports |= set(get_imports_from_children(first_child.children[3:]))
3196         else:
3197             break
3198     return imports
3199
3200
3201 def gen_python_files_in_dir(
3202     path: Path,
3203     root: Path,
3204     include: Pattern[str],
3205     exclude: Pattern[str],
3206     report: "Report",
3207 ) -> Iterator[Path]:
3208     """Generate all files under `path` whose paths are not excluded by the
3209     `exclude` regex, but are included by the `include` regex.
3210
3211     Symbolic links pointing outside of the `root` directory are ignored.
3212
3213     `report` is where output about exclusions goes.
3214     """
3215     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3216     for child in path.iterdir():
3217         try:
3218             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3219         except ValueError:
3220             if child.is_symlink():
3221                 report.path_ignored(
3222                     child, f"is a symbolic link that points outside {root}"
3223                 )
3224                 continue
3225
3226             raise
3227
3228         if child.is_dir():
3229             normalized_path += "/"
3230         exclude_match = exclude.search(normalized_path)
3231         if exclude_match and exclude_match.group(0):
3232             report.path_ignored(child, f"matches the --exclude regular expression")
3233             continue
3234
3235         if child.is_dir():
3236             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3237
3238         elif child.is_file():
3239             include_match = include.search(normalized_path)
3240             if include_match:
3241                 yield child
3242
3243
3244 @lru_cache()
3245 def find_project_root(srcs: Iterable[str]) -> Path:
3246     """Return a directory containing .git, .hg, or pyproject.toml.
3247
3248     That directory can be one of the directories passed in `srcs` or their
3249     common parent.
3250
3251     If no directory in the tree contains a marker that would specify it's the
3252     project root, the root of the file system is returned.
3253     """
3254     if not srcs:
3255         return Path("/").resolve()
3256
3257     common_base = min(Path(src).resolve() for src in srcs)
3258     if common_base.is_dir():
3259         # Append a fake file so `parents` below returns `common_base_dir`, too.
3260         common_base /= "fake-file"
3261     for directory in common_base.parents:
3262         if (directory / ".git").is_dir():
3263             return directory
3264
3265         if (directory / ".hg").is_dir():
3266             return directory
3267
3268         if (directory / "pyproject.toml").is_file():
3269             return directory
3270
3271     return directory
3272
3273
3274 @dataclass
3275 class Report:
3276     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3277
3278     check: bool = False
3279     quiet: bool = False
3280     verbose: bool = False
3281     change_count: int = 0
3282     same_count: int = 0
3283     failure_count: int = 0
3284
3285     def done(self, src: Path, changed: Changed) -> None:
3286         """Increment the counter for successful reformatting. Write out a message."""
3287         if changed is Changed.YES:
3288             reformatted = "would reformat" if self.check else "reformatted"
3289             if self.verbose or not self.quiet:
3290                 out(f"{reformatted} {src}")
3291             self.change_count += 1
3292         else:
3293             if self.verbose:
3294                 if changed is Changed.NO:
3295                     msg = f"{src} already well formatted, good job."
3296                 else:
3297                     msg = f"{src} wasn't modified on disk since last run."
3298                 out(msg, bold=False)
3299             self.same_count += 1
3300
3301     def failed(self, src: Path, message: str) -> None:
3302         """Increment the counter for failed reformatting. Write out a message."""
3303         err(f"error: cannot format {src}: {message}")
3304         self.failure_count += 1
3305
3306     def path_ignored(self, path: Path, message: str) -> None:
3307         if self.verbose:
3308             out(f"{path} ignored: {message}", bold=False)
3309
3310     @property
3311     def return_code(self) -> int:
3312         """Return the exit code that the app should use.
3313
3314         This considers the current state of changed files and failures:
3315         - if there were any failures, return 123;
3316         - if any files were changed and --check is being used, return 1;
3317         - otherwise return 0.
3318         """
3319         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3320         # 126 we have special return codes reserved by the shell.
3321         if self.failure_count:
3322             return 123
3323
3324         elif self.change_count and self.check:
3325             return 1
3326
3327         return 0
3328
3329     def __str__(self) -> str:
3330         """Render a color report of the current state.
3331
3332         Use `click.unstyle` to remove colors.
3333         """
3334         if self.check:
3335             reformatted = "would be reformatted"
3336             unchanged = "would be left unchanged"
3337             failed = "would fail to reformat"
3338         else:
3339             reformatted = "reformatted"
3340             unchanged = "left unchanged"
3341             failed = "failed to reformat"
3342         report = []
3343         if self.change_count:
3344             s = "s" if self.change_count > 1 else ""
3345             report.append(
3346                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3347             )
3348         if self.same_count:
3349             s = "s" if self.same_count > 1 else ""
3350             report.append(f"{self.same_count} file{s} {unchanged}")
3351         if self.failure_count:
3352             s = "s" if self.failure_count > 1 else ""
3353             report.append(
3354                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3355             )
3356         return ", ".join(report) + "."
3357
3358
3359 def assert_equivalent(src: str, dst: str) -> None:
3360     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3361
3362     import ast
3363     import traceback
3364
3365     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3366         """Simple visitor generating strings to compare ASTs by content."""
3367         yield f"{'  ' * depth}{node.__class__.__name__}("
3368
3369         for field in sorted(node._fields):
3370             try:
3371                 value = getattr(node, field)
3372             except AttributeError:
3373                 continue
3374
3375             yield f"{'  ' * (depth+1)}{field}="
3376
3377             if isinstance(value, list):
3378                 for item in value:
3379                     # Ignore nested tuples within del statements, because we may insert
3380                     # parentheses and they change the AST.
3381                     if (
3382                         field == "targets"
3383                         and isinstance(node, ast.Delete)
3384                         and isinstance(item, ast.Tuple)
3385                     ):
3386                         for item in item.elts:
3387                             yield from _v(item, depth + 2)
3388                     elif isinstance(item, ast.AST):
3389                         yield from _v(item, depth + 2)
3390
3391             elif isinstance(value, ast.AST):
3392                 yield from _v(value, depth + 2)
3393
3394             else:
3395                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3396
3397         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3398
3399     try:
3400         src_ast = ast.parse(src)
3401     except Exception as exc:
3402         major, minor = sys.version_info[:2]
3403         raise AssertionError(
3404             f"cannot use --safe with this file; failed to parse source file "
3405             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3406             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3407         )
3408
3409     try:
3410         dst_ast = ast.parse(dst)
3411     except Exception as exc:
3412         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3413         raise AssertionError(
3414             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3415             f"Please report a bug on https://github.com/python/black/issues.  "
3416             f"This invalid output might be helpful: {log}"
3417         ) from None
3418
3419     src_ast_str = "\n".join(_v(src_ast))
3420     dst_ast_str = "\n".join(_v(dst_ast))
3421     if src_ast_str != dst_ast_str:
3422         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3423         raise AssertionError(
3424             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3425             f"the source.  "
3426             f"Please report a bug on https://github.com/python/black/issues.  "
3427             f"This diff might be helpful: {log}"
3428         ) from None
3429
3430
3431 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3432     """Raise AssertionError if `dst` reformats differently the second time."""
3433     newdst = format_str(dst, mode=mode)
3434     if dst != newdst:
3435         log = dump_to_file(
3436             diff(src, dst, "source", "first pass"),
3437             diff(dst, newdst, "first pass", "second pass"),
3438         )
3439         raise AssertionError(
3440             f"INTERNAL ERROR: Black produced different code on the second pass "
3441             f"of the formatter.  "
3442             f"Please report a bug on https://github.com/python/black/issues.  "
3443             f"This diff might be helpful: {log}"
3444         ) from None
3445
3446
3447 def dump_to_file(*output: str) -> str:
3448     """Dump `output` to a temporary file. Return path to the file."""
3449     import tempfile
3450
3451     with tempfile.NamedTemporaryFile(
3452         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3453     ) as f:
3454         for lines in output:
3455             f.write(lines)
3456             if lines and lines[-1] != "\n":
3457                 f.write("\n")
3458     return f.name
3459
3460
3461 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3462     """Return a unified diff string between strings `a` and `b`."""
3463     import difflib
3464
3465     a_lines = [line + "\n" for line in a.split("\n")]
3466     b_lines = [line + "\n" for line in b.split("\n")]
3467     return "".join(
3468         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3469     )
3470
3471
3472 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3473     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3474     err("Aborted!")
3475     for task in tasks:
3476         task.cancel()
3477
3478
3479 def shutdown(loop: BaseEventLoop) -> None:
3480     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3481     try:
3482         if sys.version_info[:2] >= (3, 7):
3483             all_tasks = asyncio.all_tasks
3484         else:
3485             all_tasks = asyncio.Task.all_tasks
3486         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3487         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3488         if not to_cancel:
3489             return
3490
3491         for task in to_cancel:
3492             task.cancel()
3493         loop.run_until_complete(
3494             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3495         )
3496     finally:
3497         # `concurrent.futures.Future` objects cannot be cancelled once they
3498         # are already running. There might be some when the `shutdown()` happened.
3499         # Silence their logger's spew about the event loop being closed.
3500         cf_logger = logging.getLogger("concurrent.futures")
3501         cf_logger.setLevel(logging.CRITICAL)
3502         loop.close()
3503
3504
3505 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3506     """Replace `regex` with `replacement` twice on `original`.
3507
3508     This is used by string normalization to perform replaces on
3509     overlapping matches.
3510     """
3511     return regex.sub(replacement, regex.sub(replacement, original))
3512
3513
3514 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3515     """Compile a regular expression string in `regex`.
3516
3517     If it contains newlines, use verbose mode.
3518     """
3519     if "\n" in regex:
3520         regex = "(?x)" + regex
3521     return re.compile(regex)
3522
3523
3524 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3525     """Like `reversed(enumerate(sequence))` if that were possible."""
3526     index = len(sequence) - 1
3527     for element in reversed(sequence):
3528         yield (index, element)
3529         index -= 1
3530
3531
3532 def enumerate_with_length(
3533     line: Line, reversed: bool = False
3534 ) -> Iterator[Tuple[Index, Leaf, int]]:
3535     """Return an enumeration of leaves with their length.
3536
3537     Stops prematurely on multiline strings and standalone comments.
3538     """
3539     op = cast(
3540         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3541         enumerate_reversed if reversed else enumerate,
3542     )
3543     for index, leaf in op(line.leaves):
3544         length = len(leaf.prefix) + len(leaf.value)
3545         if "\n" in leaf.value:
3546             return  # Multiline strings, we can't continue.
3547
3548         comment: Optional[Leaf]
3549         for comment in line.comments_after(leaf):
3550             length += len(comment.value)
3551
3552         yield index, leaf, length
3553
3554
3555 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3556     """Return True if `line` is no longer than `line_length`.
3557
3558     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3559     """
3560     if not line_str:
3561         line_str = str(line).strip("\n")
3562     return (
3563         len(line_str) <= line_length
3564         and "\n" not in line_str  # multiline strings
3565         and not line.contains_standalone_comments()
3566     )
3567
3568
3569 def can_be_split(line: Line) -> bool:
3570     """Return False if the line cannot be split *for sure*.
3571
3572     This is not an exhaustive search but a cheap heuristic that we can use to
3573     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3574     in unnecessary parentheses).
3575     """
3576     leaves = line.leaves
3577     if len(leaves) < 2:
3578         return False
3579
3580     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3581         call_count = 0
3582         dot_count = 0
3583         next = leaves[-1]
3584         for leaf in leaves[-2::-1]:
3585             if leaf.type in OPENING_BRACKETS:
3586                 if next.type not in CLOSING_BRACKETS:
3587                     return False
3588
3589                 call_count += 1
3590             elif leaf.type == token.DOT:
3591                 dot_count += 1
3592             elif leaf.type == token.NAME:
3593                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3594                     return False
3595
3596             elif leaf.type not in CLOSING_BRACKETS:
3597                 return False
3598
3599             if dot_count > 1 and call_count > 1:
3600                 return False
3601
3602     return True
3603
3604
3605 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3606     """Does `line` have a shape safe to reformat without optional parens around it?
3607
3608     Returns True for only a subset of potentially nice looking formattings but
3609     the point is to not return false positives that end up producing lines that
3610     are too long.
3611     """
3612     bt = line.bracket_tracker
3613     if not bt.delimiters:
3614         # Without delimiters the optional parentheses are useless.
3615         return True
3616
3617     max_priority = bt.max_delimiter_priority()
3618     if bt.delimiter_count_with_priority(max_priority) > 1:
3619         # With more than one delimiter of a kind the optional parentheses read better.
3620         return False
3621
3622     if max_priority == DOT_PRIORITY:
3623         # A single stranded method call doesn't require optional parentheses.
3624         return True
3625
3626     assert len(line.leaves) >= 2, "Stranded delimiter"
3627
3628     first = line.leaves[0]
3629     second = line.leaves[1]
3630     penultimate = line.leaves[-2]
3631     last = line.leaves[-1]
3632
3633     # With a single delimiter, omit if the expression starts or ends with
3634     # a bracket.
3635     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3636         remainder = False
3637         length = 4 * line.depth
3638         for _index, leaf, leaf_length in enumerate_with_length(line):
3639             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3640                 remainder = True
3641             if remainder:
3642                 length += leaf_length
3643                 if length > line_length:
3644                     break
3645
3646                 if leaf.type in OPENING_BRACKETS:
3647                     # There are brackets we can further split on.
3648                     remainder = False
3649
3650         else:
3651             # checked the entire string and line length wasn't exceeded
3652             if len(line.leaves) == _index + 1:
3653                 return True
3654
3655         # Note: we are not returning False here because a line might have *both*
3656         # a leading opening bracket and a trailing closing bracket.  If the
3657         # opening bracket doesn't match our rule, maybe the closing will.
3658
3659     if (
3660         last.type == token.RPAR
3661         or last.type == token.RBRACE
3662         or (
3663             # don't use indexing for omitting optional parentheses;
3664             # it looks weird
3665             last.type == token.RSQB
3666             and last.parent
3667             and last.parent.type != syms.trailer
3668         )
3669     ):
3670         if penultimate.type in OPENING_BRACKETS:
3671             # Empty brackets don't help.
3672             return False
3673
3674         if is_multiline_string(first):
3675             # Additional wrapping of a multiline string in this situation is
3676             # unnecessary.
3677             return True
3678
3679         length = 4 * line.depth
3680         seen_other_brackets = False
3681         for _index, leaf, leaf_length in enumerate_with_length(line):
3682             length += leaf_length
3683             if leaf is last.opening_bracket:
3684                 if seen_other_brackets or length <= line_length:
3685                     return True
3686
3687             elif leaf.type in OPENING_BRACKETS:
3688                 # There are brackets we can further split on.
3689                 seen_other_brackets = True
3690
3691     return False
3692
3693
3694 def get_cache_file(mode: FileMode) -> Path:
3695     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3696
3697
3698 def read_cache(mode: FileMode) -> Cache:
3699     """Read the cache if it exists and is well formed.
3700
3701     If it is not well formed, the call to write_cache later should resolve the issue.
3702     """
3703     cache_file = get_cache_file(mode)
3704     if not cache_file.exists():
3705         return {}
3706
3707     with cache_file.open("rb") as fobj:
3708         try:
3709             cache: Cache = pickle.load(fobj)
3710         except pickle.UnpicklingError:
3711             return {}
3712
3713     return cache
3714
3715
3716 def get_cache_info(path: Path) -> CacheInfo:
3717     """Return the information used to check if a file is already formatted or not."""
3718     stat = path.stat()
3719     return stat.st_mtime, stat.st_size
3720
3721
3722 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3723     """Split an iterable of paths in `sources` into two sets.
3724
3725     The first contains paths of files that modified on disk or are not in the
3726     cache. The other contains paths to non-modified files.
3727     """
3728     todo, done = set(), set()
3729     for src in sources:
3730         src = src.resolve()
3731         if cache.get(src) != get_cache_info(src):
3732             todo.add(src)
3733         else:
3734             done.add(src)
3735     return todo, done
3736
3737
3738 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3739     """Update the cache file."""
3740     cache_file = get_cache_file(mode)
3741     try:
3742         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3743         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3744         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3745             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3746         os.replace(f.name, cache_file)
3747     except OSError:
3748         pass
3749
3750
3751 def patch_click() -> None:
3752     """Make Click not crash.
3753
3754     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3755     default which restricts paths that it can access during the lifetime of the
3756     application.  Click refuses to work in this scenario by raising a RuntimeError.
3757
3758     In case of Black the likelihood that non-ASCII characters are going to be used in
3759     file paths is minimal since it's Python source code.  Moreover, this crash was
3760     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3761     """
3762     try:
3763         from click import core
3764         from click import _unicodefun  # type: ignore
3765     except ModuleNotFoundError:
3766         return
3767
3768     for module in (core, _unicodefun):
3769         if hasattr(module, "_verify_python3_env"):
3770             module._verify_python3_env = lambda: None
3771
3772
3773 def patched_main() -> None:
3774     freeze_support()
3775     patch_click()
3776     main()
3777
3778
3779 if __name__ == "__main__":
3780     patched_main()