]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

24e9d4edaaaeb5fa8792e9f0c21f0e9e863d1047
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 class BracketMatchError(KeyError):
116     """Raised when an opening bracket is unable to be matched to a closing bracket."""
117
118
119 T = TypeVar("T")
120 E = TypeVar("E", bound=Exception)
121
122
123 class Ok(Generic[T]):
124     def __init__(self, value: T) -> None:
125         self._value = value
126
127     def ok(self) -> T:
128         return self._value
129
130
131 class Err(Generic[E]):
132     def __init__(self, e: E) -> None:
133         self._e = e
134
135     def err(self) -> E:
136         return self._e
137
138
139 # The 'Result' return type is used to implement an error-handling model heavily
140 # influenced by that used by the Rust programming language
141 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
142 Result = Union[Ok[T], Err[E]]
143 TResult = Result[T, CannotTransform]  # (T)ransform Result
144 TMatchResult = TResult[Index]
145
146
147 class WriteBack(Enum):
148     NO = 0
149     YES = 1
150     DIFF = 2
151     CHECK = 3
152     COLOR_DIFF = 4
153
154     @classmethod
155     def from_configuration(
156         cls, *, check: bool, diff: bool, color: bool = False
157     ) -> "WriteBack":
158         if check and not diff:
159             return cls.CHECK
160
161         if diff and color:
162             return cls.COLOR_DIFF
163
164         return cls.DIFF if diff else cls.YES
165
166
167 class Changed(Enum):
168     NO = 0
169     CACHED = 1
170     YES = 2
171
172
173 class TargetVersion(Enum):
174     PY27 = 2
175     PY33 = 3
176     PY34 = 4
177     PY35 = 5
178     PY36 = 6
179     PY37 = 7
180     PY38 = 8
181     PY39 = 9
182
183     def is_python2(self) -> bool:
184         return self is TargetVersion.PY27
185
186
187 class Feature(Enum):
188     # All string literals are unicode
189     UNICODE_LITERALS = 1
190     F_STRINGS = 2
191     NUMERIC_UNDERSCORES = 3
192     TRAILING_COMMA_IN_CALL = 4
193     TRAILING_COMMA_IN_DEF = 5
194     # The following two feature-flags are mutually exclusive, and exactly one should be
195     # set for every version of python.
196     ASYNC_IDENTIFIERS = 6
197     ASYNC_KEYWORDS = 7
198     ASSIGNMENT_EXPRESSIONS = 8
199     POS_ONLY_ARGUMENTS = 9
200     RELAXED_DECORATORS = 10
201     FORCE_OPTIONAL_PARENTHESES = 50
202
203
204 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
205     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
206     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
207     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
208     TargetVersion.PY35: {
209         Feature.UNICODE_LITERALS,
210         Feature.TRAILING_COMMA_IN_CALL,
211         Feature.ASYNC_IDENTIFIERS,
212     },
213     TargetVersion.PY36: {
214         Feature.UNICODE_LITERALS,
215         Feature.F_STRINGS,
216         Feature.NUMERIC_UNDERSCORES,
217         Feature.TRAILING_COMMA_IN_CALL,
218         Feature.TRAILING_COMMA_IN_DEF,
219         Feature.ASYNC_IDENTIFIERS,
220     },
221     TargetVersion.PY37: {
222         Feature.UNICODE_LITERALS,
223         Feature.F_STRINGS,
224         Feature.NUMERIC_UNDERSCORES,
225         Feature.TRAILING_COMMA_IN_CALL,
226         Feature.TRAILING_COMMA_IN_DEF,
227         Feature.ASYNC_KEYWORDS,
228     },
229     TargetVersion.PY38: {
230         Feature.UNICODE_LITERALS,
231         Feature.F_STRINGS,
232         Feature.NUMERIC_UNDERSCORES,
233         Feature.TRAILING_COMMA_IN_CALL,
234         Feature.TRAILING_COMMA_IN_DEF,
235         Feature.ASYNC_KEYWORDS,
236         Feature.ASSIGNMENT_EXPRESSIONS,
237         Feature.POS_ONLY_ARGUMENTS,
238     },
239     TargetVersion.PY39: {
240         Feature.UNICODE_LITERALS,
241         Feature.F_STRINGS,
242         Feature.NUMERIC_UNDERSCORES,
243         Feature.TRAILING_COMMA_IN_CALL,
244         Feature.TRAILING_COMMA_IN_DEF,
245         Feature.ASYNC_KEYWORDS,
246         Feature.ASSIGNMENT_EXPRESSIONS,
247         Feature.RELAXED_DECORATORS,
248         Feature.POS_ONLY_ARGUMENTS,
249     },
250 }
251
252
253 @dataclass
254 class Mode:
255     target_versions: Set[TargetVersion] = field(default_factory=set)
256     line_length: int = DEFAULT_LINE_LENGTH
257     string_normalization: bool = True
258     experimental_string_processing: bool = False
259     is_pyi: bool = False
260
261     def get_cache_key(self) -> str:
262         if self.target_versions:
263             version_str = ",".join(
264                 str(version.value)
265                 for version in sorted(self.target_versions, key=lambda v: v.value)
266             )
267         else:
268             version_str = "-"
269         parts = [
270             version_str,
271             str(self.line_length),
272             str(int(self.string_normalization)),
273             str(int(self.is_pyi)),
274         ]
275         return ".".join(parts)
276
277
278 # Legacy name, left for integrations.
279 FileMode = Mode
280
281
282 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
283     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
284
285
286 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
287     """Find the absolute filepath to a pyproject.toml if it exists"""
288     path_project_root = find_project_root(path_search_start)
289     path_pyproject_toml = path_project_root / "pyproject.toml"
290     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
291
292
293 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
294     """Parse a pyproject toml file, pulling out relevant parts for Black
295
296     If parsing fails, will raise a toml.TomlDecodeError
297     """
298     pyproject_toml = toml.load(path_config)
299     config = pyproject_toml.get("tool", {}).get("black", {})
300     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
301
302
303 def read_pyproject_toml(
304     ctx: click.Context, param: click.Parameter, value: Optional[str]
305 ) -> Optional[str]:
306     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
307
308     Returns the path to a successfully found and read configuration file, None
309     otherwise.
310     """
311     if not value:
312         value = find_pyproject_toml(ctx.params.get("src", ()))
313         if value is None:
314             return None
315
316     try:
317         config = parse_pyproject_toml(value)
318     except (toml.TomlDecodeError, OSError) as e:
319         raise click.FileError(
320             filename=value, hint=f"Error reading configuration file: {e}"
321         )
322
323     if not config:
324         return None
325     else:
326         # Sanitize the values to be Click friendly. For more information please see:
327         # https://github.com/psf/black/issues/1458
328         # https://github.com/pallets/click/issues/1567
329         config = {
330             k: str(v) if not isinstance(v, (list, dict)) else v
331             for k, v in config.items()
332         }
333
334     target_version = config.get("target_version")
335     if target_version is not None and not isinstance(target_version, list):
336         raise click.BadOptionUsage(
337             "target-version", "Config key target-version must be a list"
338         )
339
340     default_map: Dict[str, Any] = {}
341     if ctx.default_map:
342         default_map.update(ctx.default_map)
343     default_map.update(config)
344
345     ctx.default_map = default_map
346     return value
347
348
349 def target_version_option_callback(
350     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
351 ) -> List[TargetVersion]:
352     """Compute the target versions from a --target-version flag.
353
354     This is its own function because mypy couldn't infer the type correctly
355     when it was a lambda, causing mypyc trouble.
356     """
357     return [TargetVersion[val.upper()] for val in v]
358
359
360 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
361 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
362 @click.option(
363     "-l",
364     "--line-length",
365     type=int,
366     default=DEFAULT_LINE_LENGTH,
367     help="How many characters per line to allow.",
368     show_default=True,
369 )
370 @click.option(
371     "-t",
372     "--target-version",
373     type=click.Choice([v.name.lower() for v in TargetVersion]),
374     callback=target_version_option_callback,
375     multiple=True,
376     help=(
377         "Python versions that should be supported by Black's output. [default: per-file"
378         " auto-detection]"
379     ),
380 )
381 @click.option(
382     "--pyi",
383     is_flag=True,
384     help=(
385         "Format all input files like typing stubs regardless of file extension (useful"
386         " when piping source on standard input)."
387     ),
388 )
389 @click.option(
390     "-S",
391     "--skip-string-normalization",
392     is_flag=True,
393     help="Don't normalize string quotes or prefixes.",
394 )
395 @click.option(
396     "--experimental-string-processing",
397     is_flag=True,
398     hidden=True,
399     help=(
400         "Experimental option that performs more normalization on string literals."
401         " Currently disabled because it leads to some crashes."
402     ),
403 )
404 @click.option(
405     "--check",
406     is_flag=True,
407     help=(
408         "Don't write the files back, just return the status.  Return code 0 means"
409         " nothing would change.  Return code 1 means some files would be reformatted."
410         " Return code 123 means there was an internal error."
411     ),
412 )
413 @click.option(
414     "--diff",
415     is_flag=True,
416     help="Don't write the files back, just output a diff for each file on stdout.",
417 )
418 @click.option(
419     "--color/--no-color",
420     is_flag=True,
421     help="Show colored diff. Only applies when `--diff` is given.",
422 )
423 @click.option(
424     "--fast/--safe",
425     is_flag=True,
426     help="If --fast given, skip temporary sanity checks. [default: --safe]",
427 )
428 @click.option(
429     "--include",
430     type=str,
431     default=DEFAULT_INCLUDES,
432     help=(
433         "A regular expression that matches files and directories that should be"
434         " included on recursive searches.  An empty value means all files are included"
435         " regardless of the name.  Use forward slashes for directories on all platforms"
436         " (Windows, too).  Exclusions are calculated first, inclusions later."
437     ),
438     show_default=True,
439 )
440 @click.option(
441     "--exclude",
442     type=str,
443     default=DEFAULT_EXCLUDES,
444     help=(
445         "A regular expression that matches files and directories that should be"
446         " excluded on recursive searches.  An empty value means no paths are excluded."
447         " Use forward slashes for directories on all platforms (Windows, too). "
448         " Exclusions are calculated first, inclusions later."
449     ),
450     show_default=True,
451 )
452 @click.option(
453     "--force-exclude",
454     type=str,
455     help=(
456         "Like --exclude, but files and directories matching this regex will be "
457         "excluded even when they are passed explicitly as arguments."
458     ),
459 )
460 @click.option(
461     "-q",
462     "--quiet",
463     is_flag=True,
464     help=(
465         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
466         " those with 2>/dev/null."
467     ),
468 )
469 @click.option(
470     "-v",
471     "--verbose",
472     is_flag=True,
473     help=(
474         "Also emit messages to stderr about files that were not changed or were ignored"
475         " due to --exclude=."
476     ),
477 )
478 @click.version_option(version=__version__)
479 @click.argument(
480     "src",
481     nargs=-1,
482     type=click.Path(
483         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
484     ),
485     is_eager=True,
486 )
487 @click.option(
488     "--config",
489     type=click.Path(
490         exists=True,
491         file_okay=True,
492         dir_okay=False,
493         readable=True,
494         allow_dash=False,
495         path_type=str,
496     ),
497     is_eager=True,
498     callback=read_pyproject_toml,
499     help="Read configuration from FILE path.",
500 )
501 @click.pass_context
502 def main(
503     ctx: click.Context,
504     code: Optional[str],
505     line_length: int,
506     target_version: List[TargetVersion],
507     check: bool,
508     diff: bool,
509     color: bool,
510     fast: bool,
511     pyi: bool,
512     skip_string_normalization: bool,
513     experimental_string_processing: bool,
514     quiet: bool,
515     verbose: bool,
516     include: str,
517     exclude: str,
518     force_exclude: Optional[str],
519     src: Tuple[str, ...],
520     config: Optional[str],
521 ) -> None:
522     """The uncompromising code formatter."""
523     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
524     if target_version:
525         versions = set(target_version)
526     else:
527         # We'll autodetect later.
528         versions = set()
529     mode = Mode(
530         target_versions=versions,
531         line_length=line_length,
532         is_pyi=pyi,
533         string_normalization=not skip_string_normalization,
534         experimental_string_processing=experimental_string_processing,
535     )
536     if config and verbose:
537         out(f"Using configuration from {config}.", bold=False, fg="blue")
538     if code is not None:
539         print(format_str(code, mode=mode))
540         ctx.exit(0)
541     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
542     sources = get_sources(
543         ctx=ctx,
544         src=src,
545         quiet=quiet,
546         verbose=verbose,
547         include=include,
548         exclude=exclude,
549         force_exclude=force_exclude,
550         report=report,
551     )
552
553     path_empty(
554         sources,
555         "No Python files are present to be formatted. Nothing to do 😴",
556         quiet,
557         verbose,
558         ctx,
559     )
560
561     if len(sources) == 1:
562         reformat_one(
563             src=sources.pop(),
564             fast=fast,
565             write_back=write_back,
566             mode=mode,
567             report=report,
568         )
569     else:
570         reformat_many(
571             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
572         )
573
574     if verbose or not quiet:
575         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
576         click.secho(str(report), err=True)
577     ctx.exit(report.return_code)
578
579
580 def get_sources(
581     *,
582     ctx: click.Context,
583     src: Tuple[str, ...],
584     quiet: bool,
585     verbose: bool,
586     include: str,
587     exclude: str,
588     force_exclude: Optional[str],
589     report: "Report",
590 ) -> Set[Path]:
591     """Compute the set of files to be formatted."""
592     try:
593         include_regex = re_compile_maybe_verbose(include)
594     except re.error:
595         err(f"Invalid regular expression for include given: {include!r}")
596         ctx.exit(2)
597     try:
598         exclude_regex = re_compile_maybe_verbose(exclude)
599     except re.error:
600         err(f"Invalid regular expression for exclude given: {exclude!r}")
601         ctx.exit(2)
602     try:
603         force_exclude_regex = (
604             re_compile_maybe_verbose(force_exclude) if force_exclude else None
605         )
606     except re.error:
607         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
608         ctx.exit(2)
609
610     root = find_project_root(src)
611     sources: Set[Path] = set()
612     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
613     gitignore = get_gitignore(root)
614
615     for s in src:
616         p = Path(s)
617         if p.is_dir():
618             sources.update(
619                 gen_python_files(
620                     p.iterdir(),
621                     root,
622                     include_regex,
623                     exclude_regex,
624                     force_exclude_regex,
625                     report,
626                     gitignore,
627                 )
628             )
629         elif s == "-":
630             sources.add(p)
631         elif p.is_file():
632             normalized_path = normalize_path_maybe_ignore(p, root, report)
633             if normalized_path is None:
634                 continue
635
636             normalized_path = "/" + normalized_path
637             # Hard-exclude any files that matches the `--force-exclude` regex.
638             if force_exclude_regex:
639                 force_exclude_match = force_exclude_regex.search(normalized_path)
640             else:
641                 force_exclude_match = None
642             if force_exclude_match and force_exclude_match.group(0):
643                 report.path_ignored(p, "matches the --force-exclude regular expression")
644                 continue
645
646             sources.add(p)
647         else:
648             err(f"invalid path: {s}")
649     return sources
650
651
652 def path_empty(
653     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
654 ) -> None:
655     """
656     Exit if there is no `src` provided for formatting
657     """
658     if not src and (verbose or not quiet):
659         out(msg)
660         ctx.exit(0)
661
662
663 def reformat_one(
664     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
665 ) -> None:
666     """Reformat a single file under `src` without spawning child processes.
667
668     `fast`, `write_back`, and `mode` options are passed to
669     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
670     """
671     try:
672         changed = Changed.NO
673         if not src.is_file() and str(src) == "-":
674             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
675                 changed = Changed.YES
676         else:
677             cache: Cache = {}
678             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
679                 cache = read_cache(mode)
680                 res_src = src.resolve()
681                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
682                     changed = Changed.CACHED
683             if changed is not Changed.CACHED and format_file_in_place(
684                 src, fast=fast, write_back=write_back, mode=mode
685             ):
686                 changed = Changed.YES
687             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
688                 write_back is WriteBack.CHECK and changed is Changed.NO
689             ):
690                 write_cache(cache, [src], mode)
691         report.done(src, changed)
692     except Exception as exc:
693         if report.verbose:
694             traceback.print_exc()
695         report.failed(src, str(exc))
696
697
698 def reformat_many(
699     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
700 ) -> None:
701     """Reformat multiple files using a ProcessPoolExecutor."""
702     executor: Executor
703     loop = asyncio.get_event_loop()
704     worker_count = os.cpu_count()
705     if sys.platform == "win32":
706         # Work around https://bugs.python.org/issue26903
707         worker_count = min(worker_count, 61)
708     try:
709         executor = ProcessPoolExecutor(max_workers=worker_count)
710     except (ImportError, OSError):
711         # we arrive here if the underlying system does not support multi-processing
712         # like in AWS Lambda or Termux, in which case we gracefully fallback to
713         # a ThreadPollExecutor with just a single worker (more workers would not do us
714         # any good due to the Global Interpreter Lock)
715         executor = ThreadPoolExecutor(max_workers=1)
716
717     try:
718         loop.run_until_complete(
719             schedule_formatting(
720                 sources=sources,
721                 fast=fast,
722                 write_back=write_back,
723                 mode=mode,
724                 report=report,
725                 loop=loop,
726                 executor=executor,
727             )
728         )
729     finally:
730         shutdown(loop)
731         if executor is not None:
732             executor.shutdown()
733
734
735 async def schedule_formatting(
736     sources: Set[Path],
737     fast: bool,
738     write_back: WriteBack,
739     mode: Mode,
740     report: "Report",
741     loop: asyncio.AbstractEventLoop,
742     executor: Executor,
743 ) -> None:
744     """Run formatting of `sources` in parallel using the provided `executor`.
745
746     (Use ProcessPoolExecutors for actual parallelism.)
747
748     `write_back`, `fast`, and `mode` options are passed to
749     :func:`format_file_in_place`.
750     """
751     cache: Cache = {}
752     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
753         cache = read_cache(mode)
754         sources, cached = filter_cached(cache, sources)
755         for src in sorted(cached):
756             report.done(src, Changed.CACHED)
757     if not sources:
758         return
759
760     cancelled = []
761     sources_to_cache = []
762     lock = None
763     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
764         # For diff output, we need locks to ensure we don't interleave output
765         # from different processes.
766         manager = Manager()
767         lock = manager.Lock()
768     tasks = {
769         asyncio.ensure_future(
770             loop.run_in_executor(
771                 executor, format_file_in_place, src, fast, mode, write_back, lock
772             )
773         ): src
774         for src in sorted(sources)
775     }
776     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
777     try:
778         loop.add_signal_handler(signal.SIGINT, cancel, pending)
779         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
780     except NotImplementedError:
781         # There are no good alternatives for these on Windows.
782         pass
783     while pending:
784         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
785         for task in done:
786             src = tasks.pop(task)
787             if task.cancelled():
788                 cancelled.append(task)
789             elif task.exception():
790                 report.failed(src, str(task.exception()))
791             else:
792                 changed = Changed.YES if task.result() else Changed.NO
793                 # If the file was written back or was successfully checked as
794                 # well-formatted, store this information in the cache.
795                 if write_back is WriteBack.YES or (
796                     write_back is WriteBack.CHECK and changed is Changed.NO
797                 ):
798                     sources_to_cache.append(src)
799                 report.done(src, changed)
800     if cancelled:
801         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
802     if sources_to_cache:
803         write_cache(cache, sources_to_cache, mode)
804
805
806 def format_file_in_place(
807     src: Path,
808     fast: bool,
809     mode: Mode,
810     write_back: WriteBack = WriteBack.NO,
811     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
812 ) -> bool:
813     """Format file under `src` path. Return True if changed.
814
815     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
816     code to the file.
817     `mode` and `fast` options are passed to :func:`format_file_contents`.
818     """
819     if src.suffix == ".pyi":
820         mode = replace(mode, is_pyi=True)
821
822     then = datetime.utcfromtimestamp(src.stat().st_mtime)
823     with open(src, "rb") as buf:
824         src_contents, encoding, newline = decode_bytes(buf.read())
825     try:
826         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
827     except NothingChanged:
828         return False
829
830     if write_back == WriteBack.YES:
831         with open(src, "w", encoding=encoding, newline=newline) as f:
832             f.write(dst_contents)
833     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
834         now = datetime.utcnow()
835         src_name = f"{src}\t{then} +0000"
836         dst_name = f"{src}\t{now} +0000"
837         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
838
839         if write_back == write_back.COLOR_DIFF:
840             diff_contents = color_diff(diff_contents)
841
842         with lock or nullcontext():
843             f = io.TextIOWrapper(
844                 sys.stdout.buffer,
845                 encoding=encoding,
846                 newline=newline,
847                 write_through=True,
848             )
849             f = wrap_stream_for_windows(f)
850             f.write(diff_contents)
851             f.detach()
852
853     return True
854
855
856 def color_diff(contents: str) -> str:
857     """Inject the ANSI color codes to the diff."""
858     lines = contents.split("\n")
859     for i, line in enumerate(lines):
860         if line.startswith("+++") or line.startswith("---"):
861             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
862         elif line.startswith("@@"):
863             line = "\033[36m" + line + "\033[0m"  # cyan, reset
864         elif line.startswith("+"):
865             line = "\033[32m" + line + "\033[0m"  # green, reset
866         elif line.startswith("-"):
867             line = "\033[31m" + line + "\033[0m"  # red, reset
868         lines[i] = line
869     return "\n".join(lines)
870
871
872 def wrap_stream_for_windows(
873     f: io.TextIOWrapper,
874 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
875     """
876     Wrap stream with colorama's wrap_stream so colors are shown on Windows.
877
878     If `colorama` is unavailable, the original stream is returned unmodified.
879     Otherwise, the `wrap_stream()` function determines whether the stream needs
880     to be wrapped for a Windows environment and will accordingly either return
881     an `AnsiToWin32` wrapper or the original stream.
882     """
883     try:
884         from colorama.initialise import wrap_stream
885     except ImportError:
886         return f
887     else:
888         # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
889         return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
890
891
892 def format_stdin_to_stdout(
893     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
894 ) -> bool:
895     """Format file on stdin. Return True if changed.
896
897     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
898     write a diff to stdout. The `mode` argument is passed to
899     :func:`format_file_contents`.
900     """
901     then = datetime.utcnow()
902     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
903     dst = src
904     try:
905         dst = format_file_contents(src, fast=fast, mode=mode)
906         return True
907
908     except NothingChanged:
909         return False
910
911     finally:
912         f = io.TextIOWrapper(
913             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
914         )
915         if write_back == WriteBack.YES:
916             f.write(dst)
917         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
918             now = datetime.utcnow()
919             src_name = f"STDIN\t{then} +0000"
920             dst_name = f"STDOUT\t{now} +0000"
921             d = diff(src, dst, src_name, dst_name)
922             if write_back == WriteBack.COLOR_DIFF:
923                 d = color_diff(d)
924                 f = wrap_stream_for_windows(f)
925             f.write(d)
926         f.detach()
927
928
929 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
930     """Reformat contents of a file and return new contents.
931
932     If `fast` is False, additionally confirm that the reformatted code is
933     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
934     `mode` is passed to :func:`format_str`.
935     """
936     if not src_contents.strip():
937         raise NothingChanged
938
939     dst_contents = format_str(src_contents, mode=mode)
940     if src_contents == dst_contents:
941         raise NothingChanged
942
943     if not fast:
944         assert_equivalent(src_contents, dst_contents)
945         assert_stable(src_contents, dst_contents, mode=mode)
946     return dst_contents
947
948
949 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
950     """Reformat a string and return new contents.
951
952     `mode` determines formatting options, such as how many characters per line are
953     allowed.  Example:
954
955     >>> import black
956     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
957     def f(arg: str = "") -> None:
958         ...
959
960     A more complex example:
961
962     >>> print(
963     ...   black.format_str(
964     ...     "def f(arg:str='')->None: hey",
965     ...     mode=black.Mode(
966     ...       target_versions={black.TargetVersion.PY36},
967     ...       line_length=10,
968     ...       string_normalization=False,
969     ...       is_pyi=False,
970     ...     ),
971     ...   ),
972     ... )
973     def f(
974         arg: str = '',
975     ) -> None:
976         hey
977
978     """
979     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
980     dst_contents = []
981     future_imports = get_future_imports(src_node)
982     if mode.target_versions:
983         versions = mode.target_versions
984     else:
985         versions = detect_target_versions(src_node)
986     normalize_fmt_off(src_node)
987     lines = LineGenerator(
988         remove_u_prefix="unicode_literals" in future_imports
989         or supports_feature(versions, Feature.UNICODE_LITERALS),
990         is_pyi=mode.is_pyi,
991         normalize_strings=mode.string_normalization,
992     )
993     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
994     empty_line = Line()
995     after = 0
996     split_line_features = {
997         feature
998         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
999         if supports_feature(versions, feature)
1000     }
1001     for current_line in lines.visit(src_node):
1002         dst_contents.append(str(empty_line) * after)
1003         before, after = elt.maybe_empty_lines(current_line)
1004         dst_contents.append(str(empty_line) * before)
1005         for line in transform_line(
1006             current_line, mode=mode, features=split_line_features
1007         ):
1008             dst_contents.append(str(line))
1009     return "".join(dst_contents)
1010
1011
1012 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1013     """Return a tuple of (decoded_contents, encoding, newline).
1014
1015     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1016     universal newlines (i.e. only contains LF).
1017     """
1018     srcbuf = io.BytesIO(src)
1019     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1020     if not lines:
1021         return "", encoding, "\n"
1022
1023     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1024     srcbuf.seek(0)
1025     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1026         return tiow.read(), encoding, newline
1027
1028
1029 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1030     if not target_versions:
1031         # No target_version specified, so try all grammars.
1032         return [
1033             # Python 3.7+
1034             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1035             # Python 3.0-3.6
1036             pygram.python_grammar_no_print_statement_no_exec_statement,
1037             # Python 2.7 with future print_function import
1038             pygram.python_grammar_no_print_statement,
1039             # Python 2.7
1040             pygram.python_grammar,
1041         ]
1042
1043     if all(version.is_python2() for version in target_versions):
1044         # Python 2-only code, so try Python 2 grammars.
1045         return [
1046             # Python 2.7 with future print_function import
1047             pygram.python_grammar_no_print_statement,
1048             # Python 2.7
1049             pygram.python_grammar,
1050         ]
1051
1052     # Python 3-compatible code, so only try Python 3 grammar.
1053     grammars = []
1054     # If we have to parse both, try to parse async as a keyword first
1055     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1056         # Python 3.7+
1057         grammars.append(
1058             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1059         )
1060     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1061         # Python 3.0-3.6
1062         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1063     # At least one of the above branches must have been taken, because every Python
1064     # version has exactly one of the two 'ASYNC_*' flags
1065     return grammars
1066
1067
1068 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1069     """Given a string with source, return the lib2to3 Node."""
1070     if not src_txt.endswith("\n"):
1071         src_txt += "\n"
1072
1073     for grammar in get_grammars(set(target_versions)):
1074         drv = driver.Driver(grammar, pytree.convert)
1075         try:
1076             result = drv.parse_string(src_txt, True)
1077             break
1078
1079         except ParseError as pe:
1080             lineno, column = pe.context[1]
1081             lines = src_txt.splitlines()
1082             try:
1083                 faulty_line = lines[lineno - 1]
1084             except IndexError:
1085                 faulty_line = "<line number missing in source>"
1086             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1087     else:
1088         raise exc from None
1089
1090     if isinstance(result, Leaf):
1091         result = Node(syms.file_input, [result])
1092     return result
1093
1094
1095 def lib2to3_unparse(node: Node) -> str:
1096     """Given a lib2to3 node, return its string representation."""
1097     code = str(node)
1098     return code
1099
1100
1101 class Visitor(Generic[T]):
1102     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1103
1104     def visit(self, node: LN) -> Iterator[T]:
1105         """Main method to visit `node` and its children.
1106
1107         It tries to find a `visit_*()` method for the given `node.type`, like
1108         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1109         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1110         instead.
1111
1112         Then yields objects of type `T` from the selected visitor.
1113         """
1114         if node.type < 256:
1115             name = token.tok_name[node.type]
1116         else:
1117             name = str(type_repr(node.type))
1118         # We explicitly branch on whether a visitor exists (instead of
1119         # using self.visit_default as the default arg to getattr) in order
1120         # to save needing to create a bound method object and so mypyc can
1121         # generate a native call to visit_default.
1122         visitf = getattr(self, f"visit_{name}", None)
1123         if visitf:
1124             yield from visitf(node)
1125         else:
1126             yield from self.visit_default(node)
1127
1128     def visit_default(self, node: LN) -> Iterator[T]:
1129         """Default `visit_*()` implementation. Recurses to children of `node`."""
1130         if isinstance(node, Node):
1131             for child in node.children:
1132                 yield from self.visit(child)
1133
1134
1135 @dataclass
1136 class DebugVisitor(Visitor[T]):
1137     tree_depth: int = 0
1138
1139     def visit_default(self, node: LN) -> Iterator[T]:
1140         indent = " " * (2 * self.tree_depth)
1141         if isinstance(node, Node):
1142             _type = type_repr(node.type)
1143             out(f"{indent}{_type}", fg="yellow")
1144             self.tree_depth += 1
1145             for child in node.children:
1146                 yield from self.visit(child)
1147
1148             self.tree_depth -= 1
1149             out(f"{indent}/{_type}", fg="yellow", bold=False)
1150         else:
1151             _type = token.tok_name.get(node.type, str(node.type))
1152             out(f"{indent}{_type}", fg="blue", nl=False)
1153             if node.prefix:
1154                 # We don't have to handle prefixes for `Node` objects since
1155                 # that delegates to the first child anyway.
1156                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1157             out(f" {node.value!r}", fg="blue", bold=False)
1158
1159     @classmethod
1160     def show(cls, code: Union[str, Leaf, Node]) -> None:
1161         """Pretty-print the lib2to3 AST of a given string of `code`.
1162
1163         Convenience method for debugging.
1164         """
1165         v: DebugVisitor[None] = DebugVisitor()
1166         if isinstance(code, str):
1167             code = lib2to3_parse(code)
1168         list(v.visit(code))
1169
1170
1171 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1172 STATEMENT: Final = {
1173     syms.if_stmt,
1174     syms.while_stmt,
1175     syms.for_stmt,
1176     syms.try_stmt,
1177     syms.except_clause,
1178     syms.with_stmt,
1179     syms.funcdef,
1180     syms.classdef,
1181 }
1182 STANDALONE_COMMENT: Final = 153
1183 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1184 LOGIC_OPERATORS: Final = {"and", "or"}
1185 COMPARATORS: Final = {
1186     token.LESS,
1187     token.GREATER,
1188     token.EQEQUAL,
1189     token.NOTEQUAL,
1190     token.LESSEQUAL,
1191     token.GREATEREQUAL,
1192 }
1193 MATH_OPERATORS: Final = {
1194     token.VBAR,
1195     token.CIRCUMFLEX,
1196     token.AMPER,
1197     token.LEFTSHIFT,
1198     token.RIGHTSHIFT,
1199     token.PLUS,
1200     token.MINUS,
1201     token.STAR,
1202     token.SLASH,
1203     token.DOUBLESLASH,
1204     token.PERCENT,
1205     token.AT,
1206     token.TILDE,
1207     token.DOUBLESTAR,
1208 }
1209 STARS: Final = {token.STAR, token.DOUBLESTAR}
1210 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1211 VARARGS_PARENTS: Final = {
1212     syms.arglist,
1213     syms.argument,  # double star in arglist
1214     syms.trailer,  # single argument to call
1215     syms.typedargslist,
1216     syms.varargslist,  # lambdas
1217 }
1218 UNPACKING_PARENTS: Final = {
1219     syms.atom,  # single element of a list or set literal
1220     syms.dictsetmaker,
1221     syms.listmaker,
1222     syms.testlist_gexp,
1223     syms.testlist_star_expr,
1224 }
1225 TEST_DESCENDANTS: Final = {
1226     syms.test,
1227     syms.lambdef,
1228     syms.or_test,
1229     syms.and_test,
1230     syms.not_test,
1231     syms.comparison,
1232     syms.star_expr,
1233     syms.expr,
1234     syms.xor_expr,
1235     syms.and_expr,
1236     syms.shift_expr,
1237     syms.arith_expr,
1238     syms.trailer,
1239     syms.term,
1240     syms.power,
1241 }
1242 ASSIGNMENTS: Final = {
1243     "=",
1244     "+=",
1245     "-=",
1246     "*=",
1247     "@=",
1248     "/=",
1249     "%=",
1250     "&=",
1251     "|=",
1252     "^=",
1253     "<<=",
1254     ">>=",
1255     "**=",
1256     "//=",
1257 }
1258 COMPREHENSION_PRIORITY: Final = 20
1259 COMMA_PRIORITY: Final = 18
1260 TERNARY_PRIORITY: Final = 16
1261 LOGIC_PRIORITY: Final = 14
1262 STRING_PRIORITY: Final = 12
1263 COMPARATOR_PRIORITY: Final = 10
1264 MATH_PRIORITIES: Final = {
1265     token.VBAR: 9,
1266     token.CIRCUMFLEX: 8,
1267     token.AMPER: 7,
1268     token.LEFTSHIFT: 6,
1269     token.RIGHTSHIFT: 6,
1270     token.PLUS: 5,
1271     token.MINUS: 5,
1272     token.STAR: 4,
1273     token.SLASH: 4,
1274     token.DOUBLESLASH: 4,
1275     token.PERCENT: 4,
1276     token.AT: 4,
1277     token.TILDE: 3,
1278     token.DOUBLESTAR: 2,
1279 }
1280 DOT_PRIORITY: Final = 1
1281
1282
1283 @dataclass
1284 class BracketTracker:
1285     """Keeps track of brackets on a line."""
1286
1287     depth: int = 0
1288     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1289     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1290     previous: Optional[Leaf] = None
1291     _for_loop_depths: List[int] = field(default_factory=list)
1292     _lambda_argument_depths: List[int] = field(default_factory=list)
1293     invisible: List[Leaf] = field(default_factory=list)
1294
1295     def mark(self, leaf: Leaf) -> None:
1296         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1297
1298         All leaves receive an int `bracket_depth` field that stores how deep
1299         within brackets a given leaf is. 0 means there are no enclosing brackets
1300         that started on this line.
1301
1302         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1303         field that it forms a pair with. This is a one-directional link to
1304         avoid reference cycles.
1305
1306         If a leaf is a delimiter (a token on which Black can split the line if
1307         needed) and it's on depth 0, its `id()` is stored in the tracker's
1308         `delimiters` field.
1309         """
1310         if leaf.type == token.COMMENT:
1311             return
1312
1313         self.maybe_decrement_after_for_loop_variable(leaf)
1314         self.maybe_decrement_after_lambda_arguments(leaf)
1315         if leaf.type in CLOSING_BRACKETS:
1316             self.depth -= 1
1317             try:
1318                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1319             except KeyError as e:
1320                 raise BracketMatchError(
1321                     "Unable to match a closing bracket to the following opening"
1322                     f" bracket: {leaf}"
1323                 ) from e
1324             leaf.opening_bracket = opening_bracket
1325             if not leaf.value:
1326                 self.invisible.append(leaf)
1327         leaf.bracket_depth = self.depth
1328         if self.depth == 0:
1329             delim = is_split_before_delimiter(leaf, self.previous)
1330             if delim and self.previous is not None:
1331                 self.delimiters[id(self.previous)] = delim
1332             else:
1333                 delim = is_split_after_delimiter(leaf, self.previous)
1334                 if delim:
1335                     self.delimiters[id(leaf)] = delim
1336         if leaf.type in OPENING_BRACKETS:
1337             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1338             self.depth += 1
1339             if not leaf.value:
1340                 self.invisible.append(leaf)
1341         self.previous = leaf
1342         self.maybe_increment_lambda_arguments(leaf)
1343         self.maybe_increment_for_loop_variable(leaf)
1344
1345     def any_open_brackets(self) -> bool:
1346         """Return True if there is an yet unmatched open bracket on the line."""
1347         return bool(self.bracket_match)
1348
1349     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1350         """Return the highest priority of a delimiter found on the line.
1351
1352         Values are consistent with what `is_split_*_delimiter()` return.
1353         Raises ValueError on no delimiters.
1354         """
1355         return max(v for k, v in self.delimiters.items() if k not in exclude)
1356
1357     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1358         """Return the number of delimiters with the given `priority`.
1359
1360         If no `priority` is passed, defaults to max priority on the line.
1361         """
1362         if not self.delimiters:
1363             return 0
1364
1365         priority = priority or self.max_delimiter_priority()
1366         return sum(1 for p in self.delimiters.values() if p == priority)
1367
1368     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1369         """In a for loop, or comprehension, the variables are often unpacks.
1370
1371         To avoid splitting on the comma in this situation, increase the depth of
1372         tokens between `for` and `in`.
1373         """
1374         if leaf.type == token.NAME and leaf.value == "for":
1375             self.depth += 1
1376             self._for_loop_depths.append(self.depth)
1377             return True
1378
1379         return False
1380
1381     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1382         """See `maybe_increment_for_loop_variable` above for explanation."""
1383         if (
1384             self._for_loop_depths
1385             and self._for_loop_depths[-1] == self.depth
1386             and leaf.type == token.NAME
1387             and leaf.value == "in"
1388         ):
1389             self.depth -= 1
1390             self._for_loop_depths.pop()
1391             return True
1392
1393         return False
1394
1395     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1396         """In a lambda expression, there might be more than one argument.
1397
1398         To avoid splitting on the comma in this situation, increase the depth of
1399         tokens between `lambda` and `:`.
1400         """
1401         if leaf.type == token.NAME and leaf.value == "lambda":
1402             self.depth += 1
1403             self._lambda_argument_depths.append(self.depth)
1404             return True
1405
1406         return False
1407
1408     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1409         """See `maybe_increment_lambda_arguments` above for explanation."""
1410         if (
1411             self._lambda_argument_depths
1412             and self._lambda_argument_depths[-1] == self.depth
1413             and leaf.type == token.COLON
1414         ):
1415             self.depth -= 1
1416             self._lambda_argument_depths.pop()
1417             return True
1418
1419         return False
1420
1421     def get_open_lsqb(self) -> Optional[Leaf]:
1422         """Return the most recent opening square bracket (if any)."""
1423         return self.bracket_match.get((self.depth - 1, token.RSQB))
1424
1425
1426 @dataclass
1427 class Line:
1428     """Holds leaves and comments. Can be printed with `str(line)`."""
1429
1430     depth: int = 0
1431     leaves: List[Leaf] = field(default_factory=list)
1432     # keys ordered like `leaves`
1433     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1434     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1435     inside_brackets: bool = False
1436     should_explode: bool = False
1437
1438     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1439         """Add a new `leaf` to the end of the line.
1440
1441         Unless `preformatted` is True, the `leaf` will receive a new consistent
1442         whitespace prefix and metadata applied by :class:`BracketTracker`.
1443         Trailing commas are maybe removed, unpacked for loop variables are
1444         demoted from being delimiters.
1445
1446         Inline comments are put aside.
1447         """
1448         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1449         if not has_value:
1450             return
1451
1452         if token.COLON == leaf.type and self.is_class_paren_empty:
1453             del self.leaves[-2:]
1454         if self.leaves and not preformatted:
1455             # Note: at this point leaf.prefix should be empty except for
1456             # imports, for which we only preserve newlines.
1457             leaf.prefix += whitespace(
1458                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1459             )
1460         if self.inside_brackets or not preformatted:
1461             self.bracket_tracker.mark(leaf)
1462             if self.maybe_should_explode(leaf):
1463                 self.should_explode = True
1464         if not self.append_comment(leaf):
1465             self.leaves.append(leaf)
1466
1467     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1468         """Like :func:`append()` but disallow invalid standalone comment structure.
1469
1470         Raises ValueError when any `leaf` is appended after a standalone comment
1471         or when a standalone comment is not the first leaf on the line.
1472         """
1473         if self.bracket_tracker.depth == 0:
1474             if self.is_comment:
1475                 raise ValueError("cannot append to standalone comments")
1476
1477             if self.leaves and leaf.type == STANDALONE_COMMENT:
1478                 raise ValueError(
1479                     "cannot append standalone comments to a populated line"
1480                 )
1481
1482         self.append(leaf, preformatted=preformatted)
1483
1484     @property
1485     def is_comment(self) -> bool:
1486         """Is this line a standalone comment?"""
1487         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1488
1489     @property
1490     def is_decorator(self) -> bool:
1491         """Is this line a decorator?"""
1492         return bool(self) and self.leaves[0].type == token.AT
1493
1494     @property
1495     def is_import(self) -> bool:
1496         """Is this an import line?"""
1497         return bool(self) and is_import(self.leaves[0])
1498
1499     @property
1500     def is_class(self) -> bool:
1501         """Is this line a class definition?"""
1502         return (
1503             bool(self)
1504             and self.leaves[0].type == token.NAME
1505             and self.leaves[0].value == "class"
1506         )
1507
1508     @property
1509     def is_stub_class(self) -> bool:
1510         """Is this line a class definition with a body consisting only of "..."?"""
1511         return self.is_class and self.leaves[-3:] == [
1512             Leaf(token.DOT, ".") for _ in range(3)
1513         ]
1514
1515     @property
1516     def is_def(self) -> bool:
1517         """Is this a function definition? (Also returns True for async defs.)"""
1518         try:
1519             first_leaf = self.leaves[0]
1520         except IndexError:
1521             return False
1522
1523         try:
1524             second_leaf: Optional[Leaf] = self.leaves[1]
1525         except IndexError:
1526             second_leaf = None
1527         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1528             first_leaf.type == token.ASYNC
1529             and second_leaf is not None
1530             and second_leaf.type == token.NAME
1531             and second_leaf.value == "def"
1532         )
1533
1534     @property
1535     def is_class_paren_empty(self) -> bool:
1536         """Is this a class with no base classes but using parentheses?
1537
1538         Those are unnecessary and should be removed.
1539         """
1540         return (
1541             bool(self)
1542             and len(self.leaves) == 4
1543             and self.is_class
1544             and self.leaves[2].type == token.LPAR
1545             and self.leaves[2].value == "("
1546             and self.leaves[3].type == token.RPAR
1547             and self.leaves[3].value == ")"
1548         )
1549
1550     @property
1551     def is_triple_quoted_string(self) -> bool:
1552         """Is the line a triple quoted string?"""
1553         return (
1554             bool(self)
1555             and self.leaves[0].type == token.STRING
1556             and self.leaves[0].value.startswith(('"""', "'''"))
1557         )
1558
1559     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1560         """If so, needs to be split before emitting."""
1561         for leaf in self.leaves:
1562             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1563                 return True
1564
1565         return False
1566
1567     def contains_uncollapsable_type_comments(self) -> bool:
1568         ignored_ids = set()
1569         try:
1570             last_leaf = self.leaves[-1]
1571             ignored_ids.add(id(last_leaf))
1572             if last_leaf.type == token.COMMA or (
1573                 last_leaf.type == token.RPAR and not last_leaf.value
1574             ):
1575                 # When trailing commas or optional parens are inserted by Black for
1576                 # consistency, comments after the previous last element are not moved
1577                 # (they don't have to, rendering will still be correct).  So we ignore
1578                 # trailing commas and invisible.
1579                 last_leaf = self.leaves[-2]
1580                 ignored_ids.add(id(last_leaf))
1581         except IndexError:
1582             return False
1583
1584         # A type comment is uncollapsable if it is attached to a leaf
1585         # that isn't at the end of the line (since that could cause it
1586         # to get associated to a different argument) or if there are
1587         # comments before it (since that could cause it to get hidden
1588         # behind a comment.
1589         comment_seen = False
1590         for leaf_id, comments in self.comments.items():
1591             for comment in comments:
1592                 if is_type_comment(comment):
1593                     if comment_seen or (
1594                         not is_type_comment(comment, " ignore")
1595                         and leaf_id not in ignored_ids
1596                     ):
1597                         return True
1598
1599                 comment_seen = True
1600
1601         return False
1602
1603     def contains_unsplittable_type_ignore(self) -> bool:
1604         if not self.leaves:
1605             return False
1606
1607         # If a 'type: ignore' is attached to the end of a line, we
1608         # can't split the line, because we can't know which of the
1609         # subexpressions the ignore was meant to apply to.
1610         #
1611         # We only want this to apply to actual physical lines from the
1612         # original source, though: we don't want the presence of a
1613         # 'type: ignore' at the end of a multiline expression to
1614         # justify pushing it all onto one line. Thus we
1615         # (unfortunately) need to check the actual source lines and
1616         # only report an unsplittable 'type: ignore' if this line was
1617         # one line in the original code.
1618
1619         # Grab the first and last line numbers, skipping generated leaves
1620         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1621         last_line = next(
1622             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1623         )
1624
1625         if first_line == last_line:
1626             # We look at the last two leaves since a comma or an
1627             # invisible paren could have been added at the end of the
1628             # line.
1629             for node in self.leaves[-2:]:
1630                 for comment in self.comments.get(id(node), []):
1631                     if is_type_comment(comment, " ignore"):
1632                         return True
1633
1634         return False
1635
1636     def contains_multiline_strings(self) -> bool:
1637         return any(is_multiline_string(leaf) for leaf in self.leaves)
1638
1639     def maybe_should_explode(self, closing: Leaf) -> bool:
1640         """Return True if this line should explode (always be split), that is when:
1641         - there's a trailing comma here; and
1642         - it's not a one-tuple.
1643         """
1644         if not (
1645             closing.type in CLOSING_BRACKETS
1646             and self.leaves
1647             and self.leaves[-1].type == token.COMMA
1648         ):
1649             return False
1650
1651         if closing.type in {token.RBRACE, token.RSQB}:
1652             return True
1653
1654         if self.is_import:
1655             return True
1656
1657         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1658             return True
1659
1660         return False
1661
1662     def append_comment(self, comment: Leaf) -> bool:
1663         """Add an inline or standalone comment to the line."""
1664         if (
1665             comment.type == STANDALONE_COMMENT
1666             and self.bracket_tracker.any_open_brackets()
1667         ):
1668             comment.prefix = ""
1669             return False
1670
1671         if comment.type != token.COMMENT:
1672             return False
1673
1674         if not self.leaves:
1675             comment.type = STANDALONE_COMMENT
1676             comment.prefix = ""
1677             return False
1678
1679         last_leaf = self.leaves[-1]
1680         if (
1681             last_leaf.type == token.RPAR
1682             and not last_leaf.value
1683             and last_leaf.parent
1684             and len(list(last_leaf.parent.leaves())) <= 3
1685             and not is_type_comment(comment)
1686         ):
1687             # Comments on an optional parens wrapping a single leaf should belong to
1688             # the wrapped node except if it's a type comment. Pinning the comment like
1689             # this avoids unstable formatting caused by comment migration.
1690             if len(self.leaves) < 2:
1691                 comment.type = STANDALONE_COMMENT
1692                 comment.prefix = ""
1693                 return False
1694
1695             last_leaf = self.leaves[-2]
1696         self.comments.setdefault(id(last_leaf), []).append(comment)
1697         return True
1698
1699     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1700         """Generate comments that should appear directly after `leaf`."""
1701         return self.comments.get(id(leaf), [])
1702
1703     def remove_trailing_comma(self) -> None:
1704         """Remove the trailing comma and moves the comments attached to it."""
1705         trailing_comma = self.leaves.pop()
1706         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1707         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1708             trailing_comma_comments
1709         )
1710
1711     def is_complex_subscript(self, leaf: Leaf) -> bool:
1712         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1713         open_lsqb = self.bracket_tracker.get_open_lsqb()
1714         if open_lsqb is None:
1715             return False
1716
1717         subscript_start = open_lsqb.next_sibling
1718
1719         if isinstance(subscript_start, Node):
1720             if subscript_start.type == syms.listmaker:
1721                 return False
1722
1723             if subscript_start.type == syms.subscriptlist:
1724                 subscript_start = child_towards(subscript_start, leaf)
1725         return subscript_start is not None and any(
1726             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1727         )
1728
1729     def clone(self) -> "Line":
1730         return Line(
1731             depth=self.depth,
1732             inside_brackets=self.inside_brackets,
1733             should_explode=self.should_explode,
1734         )
1735
1736     def __str__(self) -> str:
1737         """Render the line."""
1738         if not self:
1739             return "\n"
1740
1741         indent = "    " * self.depth
1742         leaves = iter(self.leaves)
1743         first = next(leaves)
1744         res = f"{first.prefix}{indent}{first.value}"
1745         for leaf in leaves:
1746             res += str(leaf)
1747         for comment in itertools.chain.from_iterable(self.comments.values()):
1748             res += str(comment)
1749
1750         return res + "\n"
1751
1752     def __bool__(self) -> bool:
1753         """Return True if the line has leaves or comments."""
1754         return bool(self.leaves or self.comments)
1755
1756
1757 @dataclass
1758 class EmptyLineTracker:
1759     """Provides a stateful method that returns the number of potential extra
1760     empty lines needed before and after the currently processed line.
1761
1762     Note: this tracker works on lines that haven't been split yet.  It assumes
1763     the prefix of the first leaf consists of optional newlines.  Those newlines
1764     are consumed by `maybe_empty_lines()` and included in the computation.
1765     """
1766
1767     is_pyi: bool = False
1768     previous_line: Optional[Line] = None
1769     previous_after: int = 0
1770     previous_defs: List[int] = field(default_factory=list)
1771
1772     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1773         """Return the number of extra empty lines before and after the `current_line`.
1774
1775         This is for separating `def`, `async def` and `class` with extra empty
1776         lines (two on module-level).
1777         """
1778         before, after = self._maybe_empty_lines(current_line)
1779         before = (
1780             # Black should not insert empty lines at the beginning
1781             # of the file
1782             0
1783             if self.previous_line is None
1784             else before - self.previous_after
1785         )
1786         self.previous_after = after
1787         self.previous_line = current_line
1788         return before, after
1789
1790     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1791         max_allowed = 1
1792         if current_line.depth == 0:
1793             max_allowed = 1 if self.is_pyi else 2
1794         if current_line.leaves:
1795             # Consume the first leaf's extra newlines.
1796             first_leaf = current_line.leaves[0]
1797             before = first_leaf.prefix.count("\n")
1798             before = min(before, max_allowed)
1799             first_leaf.prefix = ""
1800         else:
1801             before = 0
1802         depth = current_line.depth
1803         while self.previous_defs and self.previous_defs[-1] >= depth:
1804             self.previous_defs.pop()
1805             if self.is_pyi:
1806                 before = 0 if depth else 1
1807             else:
1808                 before = 1 if depth else 2
1809         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1810             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1811
1812         if (
1813             self.previous_line
1814             and self.previous_line.is_import
1815             and not current_line.is_import
1816             and depth == self.previous_line.depth
1817         ):
1818             return (before or 1), 0
1819
1820         if (
1821             self.previous_line
1822             and self.previous_line.is_class
1823             and current_line.is_triple_quoted_string
1824         ):
1825             return before, 1
1826
1827         return before, 0
1828
1829     def _maybe_empty_lines_for_class_or_def(
1830         self, current_line: Line, before: int
1831     ) -> Tuple[int, int]:
1832         if not current_line.is_decorator:
1833             self.previous_defs.append(current_line.depth)
1834         if self.previous_line is None:
1835             # Don't insert empty lines before the first line in the file.
1836             return 0, 0
1837
1838         if self.previous_line.is_decorator:
1839             if self.is_pyi and current_line.is_stub_class:
1840                 # Insert an empty line after a decorated stub class
1841                 return 0, 1
1842
1843             return 0, 0
1844
1845         if self.previous_line.depth < current_line.depth and (
1846             self.previous_line.is_class or self.previous_line.is_def
1847         ):
1848             return 0, 0
1849
1850         if (
1851             self.previous_line.is_comment
1852             and self.previous_line.depth == current_line.depth
1853             and before == 0
1854         ):
1855             return 0, 0
1856
1857         if self.is_pyi:
1858             if self.previous_line.depth > current_line.depth:
1859                 newlines = 1
1860             elif current_line.is_class or self.previous_line.is_class:
1861                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1862                     # No blank line between classes with an empty body
1863                     newlines = 0
1864                 else:
1865                     newlines = 1
1866             elif (
1867                 current_line.is_def or current_line.is_decorator
1868             ) and not self.previous_line.is_def:
1869                 # Blank line between a block of functions (maybe with preceding
1870                 # decorators) and a block of non-functions
1871                 newlines = 1
1872             else:
1873                 newlines = 0
1874         else:
1875             newlines = 2
1876         if current_line.depth and newlines:
1877             newlines -= 1
1878         return newlines, 0
1879
1880
1881 @dataclass
1882 class LineGenerator(Visitor[Line]):
1883     """Generates reformatted Line objects.  Empty lines are not emitted.
1884
1885     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1886     in ways that will no longer stringify to valid Python code on the tree.
1887     """
1888
1889     is_pyi: bool = False
1890     normalize_strings: bool = True
1891     current_line: Line = field(default_factory=Line)
1892     remove_u_prefix: bool = False
1893
1894     def line(self, indent: int = 0) -> Iterator[Line]:
1895         """Generate a line.
1896
1897         If the line is empty, only emit if it makes sense.
1898         If the line is too long, split it first and then generate.
1899
1900         If any lines were generated, set up a new current_line.
1901         """
1902         if not self.current_line:
1903             self.current_line.depth += indent
1904             return  # Line is empty, don't emit. Creating a new one unnecessary.
1905
1906         complete_line = self.current_line
1907         self.current_line = Line(depth=complete_line.depth + indent)
1908         yield complete_line
1909
1910     def visit_default(self, node: LN) -> Iterator[Line]:
1911         """Default `visit_*()` implementation. Recurses to children of `node`."""
1912         if isinstance(node, Leaf):
1913             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1914             for comment in generate_comments(node):
1915                 if any_open_brackets:
1916                     # any comment within brackets is subject to splitting
1917                     self.current_line.append(comment)
1918                 elif comment.type == token.COMMENT:
1919                     # regular trailing comment
1920                     self.current_line.append(comment)
1921                     yield from self.line()
1922
1923                 else:
1924                     # regular standalone comment
1925                     yield from self.line()
1926
1927                     self.current_line.append(comment)
1928                     yield from self.line()
1929
1930             normalize_prefix(node, inside_brackets=any_open_brackets)
1931             if self.normalize_strings and node.type == token.STRING:
1932                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1933                 normalize_string_quotes(node)
1934             if node.type == token.NUMBER:
1935                 normalize_numeric_literal(node)
1936             if node.type not in WHITESPACE:
1937                 self.current_line.append(node)
1938         yield from super().visit_default(node)
1939
1940     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1941         """Increase indentation level, maybe yield a line."""
1942         # In blib2to3 INDENT never holds comments.
1943         yield from self.line(+1)
1944         yield from self.visit_default(node)
1945
1946     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1947         """Decrease indentation level, maybe yield a line."""
1948         # The current line might still wait for trailing comments.  At DEDENT time
1949         # there won't be any (they would be prefixes on the preceding NEWLINE).
1950         # Emit the line then.
1951         yield from self.line()
1952
1953         # While DEDENT has no value, its prefix may contain standalone comments
1954         # that belong to the current indentation level.  Get 'em.
1955         yield from self.visit_default(node)
1956
1957         # Finally, emit the dedent.
1958         yield from self.line(-1)
1959
1960     def visit_stmt(
1961         self, node: Node, keywords: Set[str], parens: Set[str]
1962     ) -> Iterator[Line]:
1963         """Visit a statement.
1964
1965         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1966         `def`, `with`, `class`, `assert` and assignments.
1967
1968         The relevant Python language `keywords` for a given statement will be
1969         NAME leaves within it. This methods puts those on a separate line.
1970
1971         `parens` holds a set of string leaf values immediately after which
1972         invisible parens should be put.
1973         """
1974         normalize_invisible_parens(node, parens_after=parens)
1975         for child in node.children:
1976             if child.type == token.NAME and child.value in keywords:  # type: ignore
1977                 yield from self.line()
1978
1979             yield from self.visit(child)
1980
1981     def visit_suite(self, node: Node) -> Iterator[Line]:
1982         """Visit a suite."""
1983         if self.is_pyi and is_stub_suite(node):
1984             yield from self.visit(node.children[2])
1985         else:
1986             yield from self.visit_default(node)
1987
1988     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1989         """Visit a statement without nested statements."""
1990         is_suite_like = node.parent and node.parent.type in STATEMENT
1991         if is_suite_like:
1992             if self.is_pyi and is_stub_body(node):
1993                 yield from self.visit_default(node)
1994             else:
1995                 yield from self.line(+1)
1996                 yield from self.visit_default(node)
1997                 yield from self.line(-1)
1998
1999         else:
2000             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2001                 yield from self.line()
2002             yield from self.visit_default(node)
2003
2004     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2005         """Visit `async def`, `async for`, `async with`."""
2006         yield from self.line()
2007
2008         children = iter(node.children)
2009         for child in children:
2010             yield from self.visit(child)
2011
2012             if child.type == token.ASYNC:
2013                 break
2014
2015         internal_stmt = next(children)
2016         for child in internal_stmt.children:
2017             yield from self.visit(child)
2018
2019     def visit_decorators(self, node: Node) -> Iterator[Line]:
2020         """Visit decorators."""
2021         for child in node.children:
2022             yield from self.line()
2023             yield from self.visit(child)
2024
2025     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2026         """Remove a semicolon and put the other statement on a separate line."""
2027         yield from self.line()
2028
2029     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2030         """End of file. Process outstanding comments and end with a newline."""
2031         yield from self.visit_default(leaf)
2032         yield from self.line()
2033
2034     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2035         if not self.current_line.bracket_tracker.any_open_brackets():
2036             yield from self.line()
2037         yield from self.visit_default(leaf)
2038
2039     def visit_factor(self, node: Node) -> Iterator[Line]:
2040         """Force parentheses between a unary op and a binary power:
2041
2042         -2 ** 8 -> -(2 ** 8)
2043         """
2044         _operator, operand = node.children
2045         if (
2046             operand.type == syms.power
2047             and len(operand.children) == 3
2048             and operand.children[1].type == token.DOUBLESTAR
2049         ):
2050             lpar = Leaf(token.LPAR, "(")
2051             rpar = Leaf(token.RPAR, ")")
2052             index = operand.remove() or 0
2053             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2054         yield from self.visit_default(node)
2055
2056     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2057         if is_docstring(leaf) and "\\\n" not in leaf.value:
2058             # We're ignoring docstrings with backslash newline escapes because changing
2059             # indentation of those changes the AST representation of the code.
2060             prefix = get_string_prefix(leaf.value)
2061             lead_len = len(prefix) + 3
2062             tail_len = -3
2063             indent = " " * 4 * self.current_line.depth
2064             docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
2065             if docstring:
2066                 if leaf.value[lead_len - 1] == docstring[0]:
2067                     docstring = " " + docstring
2068                 if leaf.value[tail_len + 1] == docstring[-1]:
2069                     docstring = docstring + " "
2070             leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
2071
2072         yield from self.visit_default(leaf)
2073
2074     def __post_init__(self) -> None:
2075         """You are in a twisty little maze of passages."""
2076         v = self.visit_stmt
2077         Ø: Set[str] = set()
2078         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2079         self.visit_if_stmt = partial(
2080             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2081         )
2082         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2083         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2084         self.visit_try_stmt = partial(
2085             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2086         )
2087         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2088         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2089         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2090         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2091         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2092         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2093         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2094         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2095         self.visit_async_funcdef = self.visit_async_stmt
2096         self.visit_decorated = self.visit_decorators
2097
2098
2099 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2100 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2101 OPENING_BRACKETS = set(BRACKET.keys())
2102 CLOSING_BRACKETS = set(BRACKET.values())
2103 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2104 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2105
2106
2107 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2108     """Return whitespace prefix if needed for the given `leaf`.
2109
2110     `complex_subscript` signals whether the given leaf is part of a subscription
2111     which has non-trivial arguments, like arithmetic expressions or function calls.
2112     """
2113     NO = ""
2114     SPACE = " "
2115     DOUBLESPACE = "  "
2116     t = leaf.type
2117     p = leaf.parent
2118     v = leaf.value
2119     if t in ALWAYS_NO_SPACE:
2120         return NO
2121
2122     if t == token.COMMENT:
2123         return DOUBLESPACE
2124
2125     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2126     if t == token.COLON and p.type not in {
2127         syms.subscript,
2128         syms.subscriptlist,
2129         syms.sliceop,
2130     }:
2131         return NO
2132
2133     prev = leaf.prev_sibling
2134     if not prev:
2135         prevp = preceding_leaf(p)
2136         if not prevp or prevp.type in OPENING_BRACKETS:
2137             return NO
2138
2139         if t == token.COLON:
2140             if prevp.type == token.COLON:
2141                 return NO
2142
2143             elif prevp.type != token.COMMA and not complex_subscript:
2144                 return NO
2145
2146             return SPACE
2147
2148         if prevp.type == token.EQUAL:
2149             if prevp.parent:
2150                 if prevp.parent.type in {
2151                     syms.arglist,
2152                     syms.argument,
2153                     syms.parameters,
2154                     syms.varargslist,
2155                 }:
2156                     return NO
2157
2158                 elif prevp.parent.type == syms.typedargslist:
2159                     # A bit hacky: if the equal sign has whitespace, it means we
2160                     # previously found it's a typed argument.  So, we're using
2161                     # that, too.
2162                     return prevp.prefix
2163
2164         elif prevp.type in VARARGS_SPECIALS:
2165             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2166                 return NO
2167
2168         elif prevp.type == token.COLON:
2169             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2170                 return SPACE if complex_subscript else NO
2171
2172         elif (
2173             prevp.parent
2174             and prevp.parent.type == syms.factor
2175             and prevp.type in MATH_OPERATORS
2176         ):
2177             return NO
2178
2179         elif (
2180             prevp.type == token.RIGHTSHIFT
2181             and prevp.parent
2182             and prevp.parent.type == syms.shift_expr
2183             and prevp.prev_sibling
2184             and prevp.prev_sibling.type == token.NAME
2185             and prevp.prev_sibling.value == "print"  # type: ignore
2186         ):
2187             # Python 2 print chevron
2188             return NO
2189         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2190             # no space in decorators
2191             return NO
2192
2193     elif prev.type in OPENING_BRACKETS:
2194         return NO
2195
2196     if p.type in {syms.parameters, syms.arglist}:
2197         # untyped function signatures or calls
2198         if not prev or prev.type != token.COMMA:
2199             return NO
2200
2201     elif p.type == syms.varargslist:
2202         # lambdas
2203         if prev and prev.type != token.COMMA:
2204             return NO
2205
2206     elif p.type == syms.typedargslist:
2207         # typed function signatures
2208         if not prev:
2209             return NO
2210
2211         if t == token.EQUAL:
2212             if prev.type != syms.tname:
2213                 return NO
2214
2215         elif prev.type == token.EQUAL:
2216             # A bit hacky: if the equal sign has whitespace, it means we
2217             # previously found it's a typed argument.  So, we're using that, too.
2218             return prev.prefix
2219
2220         elif prev.type != token.COMMA:
2221             return NO
2222
2223     elif p.type == syms.tname:
2224         # type names
2225         if not prev:
2226             prevp = preceding_leaf(p)
2227             if not prevp or prevp.type != token.COMMA:
2228                 return NO
2229
2230     elif p.type == syms.trailer:
2231         # attributes and calls
2232         if t == token.LPAR or t == token.RPAR:
2233             return NO
2234
2235         if not prev:
2236             if t == token.DOT:
2237                 prevp = preceding_leaf(p)
2238                 if not prevp or prevp.type != token.NUMBER:
2239                     return NO
2240
2241             elif t == token.LSQB:
2242                 return NO
2243
2244         elif prev.type != token.COMMA:
2245             return NO
2246
2247     elif p.type == syms.argument:
2248         # single argument
2249         if t == token.EQUAL:
2250             return NO
2251
2252         if not prev:
2253             prevp = preceding_leaf(p)
2254             if not prevp or prevp.type == token.LPAR:
2255                 return NO
2256
2257         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2258             return NO
2259
2260     elif p.type == syms.decorator:
2261         # decorators
2262         return NO
2263
2264     elif p.type == syms.dotted_name:
2265         if prev:
2266             return NO
2267
2268         prevp = preceding_leaf(p)
2269         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2270             return NO
2271
2272     elif p.type == syms.classdef:
2273         if t == token.LPAR:
2274             return NO
2275
2276         if prev and prev.type == token.LPAR:
2277             return NO
2278
2279     elif p.type in {syms.subscript, syms.sliceop}:
2280         # indexing
2281         if not prev:
2282             assert p.parent is not None, "subscripts are always parented"
2283             if p.parent.type == syms.subscriptlist:
2284                 return SPACE
2285
2286             return NO
2287
2288         elif not complex_subscript:
2289             return NO
2290
2291     elif p.type == syms.atom:
2292         if prev and t == token.DOT:
2293             # dots, but not the first one.
2294             return NO
2295
2296     elif p.type == syms.dictsetmaker:
2297         # dict unpacking
2298         if prev and prev.type == token.DOUBLESTAR:
2299             return NO
2300
2301     elif p.type in {syms.factor, syms.star_expr}:
2302         # unary ops
2303         if not prev:
2304             prevp = preceding_leaf(p)
2305             if not prevp or prevp.type in OPENING_BRACKETS:
2306                 return NO
2307
2308             prevp_parent = prevp.parent
2309             assert prevp_parent is not None
2310             if prevp.type == token.COLON and prevp_parent.type in {
2311                 syms.subscript,
2312                 syms.sliceop,
2313             }:
2314                 return NO
2315
2316             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2317                 return NO
2318
2319         elif t in {token.NAME, token.NUMBER, token.STRING}:
2320             return NO
2321
2322     elif p.type == syms.import_from:
2323         if t == token.DOT:
2324             if prev and prev.type == token.DOT:
2325                 return NO
2326
2327         elif t == token.NAME:
2328             if v == "import":
2329                 return SPACE
2330
2331             if prev and prev.type == token.DOT:
2332                 return NO
2333
2334     elif p.type == syms.sliceop:
2335         return NO
2336
2337     return SPACE
2338
2339
2340 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2341     """Return the first leaf that precedes `node`, if any."""
2342     while node:
2343         res = node.prev_sibling
2344         if res:
2345             if isinstance(res, Leaf):
2346                 return res
2347
2348             try:
2349                 return list(res.leaves())[-1]
2350
2351             except IndexError:
2352                 return None
2353
2354         node = node.parent
2355     return None
2356
2357
2358 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2359     """Return if the `node` and its previous siblings match types against the provided
2360     list of tokens; the provided `node`has its type matched against the last element in
2361     the list.  `None` can be used as the first element to declare that the start of the
2362     list is anchored at the start of its parent's children."""
2363     if not tokens:
2364         return True
2365     if tokens[-1] is None:
2366         return node is None
2367     if not node:
2368         return False
2369     if node.type != tokens[-1]:
2370         return False
2371     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2372
2373
2374 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2375     """Return the child of `ancestor` that contains `descendant`."""
2376     node: Optional[LN] = descendant
2377     while node and node.parent != ancestor:
2378         node = node.parent
2379     return node
2380
2381
2382 def container_of(leaf: Leaf) -> LN:
2383     """Return `leaf` or one of its ancestors that is the topmost container of it.
2384
2385     By "container" we mean a node where `leaf` is the very first child.
2386     """
2387     same_prefix = leaf.prefix
2388     container: LN = leaf
2389     while container:
2390         parent = container.parent
2391         if parent is None:
2392             break
2393
2394         if parent.children[0].prefix != same_prefix:
2395             break
2396
2397         if parent.type == syms.file_input:
2398             break
2399
2400         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2401             break
2402
2403         container = parent
2404     return container
2405
2406
2407 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2408     """Return the priority of the `leaf` delimiter, given a line break after it.
2409
2410     The delimiter priorities returned here are from those delimiters that would
2411     cause a line break after themselves.
2412
2413     Higher numbers are higher priority.
2414     """
2415     if leaf.type == token.COMMA:
2416         return COMMA_PRIORITY
2417
2418     return 0
2419
2420
2421 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2422     """Return the priority of the `leaf` delimiter, given a line break before it.
2423
2424     The delimiter priorities returned here are from those delimiters that would
2425     cause a line break before themselves.
2426
2427     Higher numbers are higher priority.
2428     """
2429     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2430         # * and ** might also be MATH_OPERATORS but in this case they are not.
2431         # Don't treat them as a delimiter.
2432         return 0
2433
2434     if (
2435         leaf.type == token.DOT
2436         and leaf.parent
2437         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2438         and (previous is None or previous.type in CLOSING_BRACKETS)
2439     ):
2440         return DOT_PRIORITY
2441
2442     if (
2443         leaf.type in MATH_OPERATORS
2444         and leaf.parent
2445         and leaf.parent.type not in {syms.factor, syms.star_expr}
2446     ):
2447         return MATH_PRIORITIES[leaf.type]
2448
2449     if leaf.type in COMPARATORS:
2450         return COMPARATOR_PRIORITY
2451
2452     if (
2453         leaf.type == token.STRING
2454         and previous is not None
2455         and previous.type == token.STRING
2456     ):
2457         return STRING_PRIORITY
2458
2459     if leaf.type not in {token.NAME, token.ASYNC}:
2460         return 0
2461
2462     if (
2463         leaf.value == "for"
2464         and leaf.parent
2465         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2466         or leaf.type == token.ASYNC
2467     ):
2468         if (
2469             not isinstance(leaf.prev_sibling, Leaf)
2470             or leaf.prev_sibling.value != "async"
2471         ):
2472             return COMPREHENSION_PRIORITY
2473
2474     if (
2475         leaf.value == "if"
2476         and leaf.parent
2477         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2478     ):
2479         return COMPREHENSION_PRIORITY
2480
2481     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2482         return TERNARY_PRIORITY
2483
2484     if leaf.value == "is":
2485         return COMPARATOR_PRIORITY
2486
2487     if (
2488         leaf.value == "in"
2489         and leaf.parent
2490         and leaf.parent.type in {syms.comp_op, syms.comparison}
2491         and not (
2492             previous is not None
2493             and previous.type == token.NAME
2494             and previous.value == "not"
2495         )
2496     ):
2497         return COMPARATOR_PRIORITY
2498
2499     if (
2500         leaf.value == "not"
2501         and leaf.parent
2502         and leaf.parent.type == syms.comp_op
2503         and not (
2504             previous is not None
2505             and previous.type == token.NAME
2506             and previous.value == "is"
2507         )
2508     ):
2509         return COMPARATOR_PRIORITY
2510
2511     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2512         return LOGIC_PRIORITY
2513
2514     return 0
2515
2516
2517 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2518 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2519
2520
2521 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2522     """Clean the prefix of the `leaf` and generate comments from it, if any.
2523
2524     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2525     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2526     move because it does away with modifying the grammar to include all the
2527     possible places in which comments can be placed.
2528
2529     The sad consequence for us though is that comments don't "belong" anywhere.
2530     This is why this function generates simple parentless Leaf objects for
2531     comments.  We simply don't know what the correct parent should be.
2532
2533     No matter though, we can live without this.  We really only need to
2534     differentiate between inline and standalone comments.  The latter don't
2535     share the line with any code.
2536
2537     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2538     are emitted with a fake STANDALONE_COMMENT token identifier.
2539     """
2540     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2541         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2542
2543
2544 @dataclass
2545 class ProtoComment:
2546     """Describes a piece of syntax that is a comment.
2547
2548     It's not a :class:`blib2to3.pytree.Leaf` so that:
2549
2550     * it can be cached (`Leaf` objects should not be reused more than once as
2551       they store their lineno, column, prefix, and parent information);
2552     * `newlines` and `consumed` fields are kept separate from the `value`. This
2553       simplifies handling of special marker comments like ``# fmt: off/on``.
2554     """
2555
2556     type: int  # token.COMMENT or STANDALONE_COMMENT
2557     value: str  # content of the comment
2558     newlines: int  # how many newlines before the comment
2559     consumed: int  # how many characters of the original leaf's prefix did we consume
2560
2561
2562 @lru_cache(maxsize=4096)
2563 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2564     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2565     result: List[ProtoComment] = []
2566     if not prefix or "#" not in prefix:
2567         return result
2568
2569     consumed = 0
2570     nlines = 0
2571     ignored_lines = 0
2572     for index, line in enumerate(prefix.split("\n")):
2573         consumed += len(line) + 1  # adding the length of the split '\n'
2574         line = line.lstrip()
2575         if not line:
2576             nlines += 1
2577         if not line.startswith("#"):
2578             # Escaped newlines outside of a comment are not really newlines at
2579             # all. We treat a single-line comment following an escaped newline
2580             # as a simple trailing comment.
2581             if line.endswith("\\"):
2582                 ignored_lines += 1
2583             continue
2584
2585         if index == ignored_lines and not is_endmarker:
2586             comment_type = token.COMMENT  # simple trailing comment
2587         else:
2588             comment_type = STANDALONE_COMMENT
2589         comment = make_comment(line)
2590         result.append(
2591             ProtoComment(
2592                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2593             )
2594         )
2595         nlines = 0
2596     return result
2597
2598
2599 def make_comment(content: str) -> str:
2600     """Return a consistently formatted comment from the given `content` string.
2601
2602     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2603     space between the hash sign and the content.
2604
2605     If `content` didn't start with a hash sign, one is provided.
2606     """
2607     content = content.rstrip()
2608     if not content:
2609         return "#"
2610
2611     if content[0] == "#":
2612         content = content[1:]
2613     if content and content[0] not in " !:#'%":
2614         content = " " + content
2615     return "#" + content
2616
2617
2618 def transform_line(
2619     line: Line, mode: Mode, features: Collection[Feature] = ()
2620 ) -> Iterator[Line]:
2621     """Transform a `line`, potentially splitting it into many lines.
2622
2623     They should fit in the allotted `line_length` but might not be able to.
2624
2625     `features` are syntactical features that may be used in the output.
2626     """
2627     if line.is_comment:
2628         yield line
2629         return
2630
2631     line_str = line_to_string(line)
2632
2633     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2634         """Initialize StringTransformer"""
2635         return ST(mode.line_length, mode.string_normalization)
2636
2637     string_merge = init_st(StringMerger)
2638     string_paren_strip = init_st(StringParenStripper)
2639     string_split = init_st(StringSplitter)
2640     string_paren_wrap = init_st(StringParenWrapper)
2641
2642     transformers: List[Transformer]
2643     if (
2644         not line.contains_uncollapsable_type_comments()
2645         and not line.should_explode
2646         and (
2647             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2648             or line.contains_unsplittable_type_ignore()
2649         )
2650         and not (line.inside_brackets and line.contains_standalone_comments())
2651     ):
2652         # Only apply basic string preprocessing, since lines shouldn't be split here.
2653         if mode.experimental_string_processing:
2654             transformers = [string_merge, string_paren_strip]
2655         else:
2656             transformers = []
2657     elif line.is_def:
2658         transformers = [left_hand_split]
2659     else:
2660
2661         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2662             """Wraps calls to `right_hand_split`.
2663
2664             The calls increasingly `omit` right-hand trailers (bracket pairs with
2665             content), meaning the trailers get glued together to split on another
2666             bracket pair instead.
2667             """
2668             for omit in generate_trailers_to_omit(line, mode.line_length):
2669                 lines = list(
2670                     right_hand_split(line, mode.line_length, features, omit=omit)
2671                 )
2672                 # Note: this check is only able to figure out if the first line of the
2673                 # *current* transformation fits in the line length.  This is true only
2674                 # for simple cases.  All others require running more transforms via
2675                 # `transform_line()`.  This check doesn't know if those would succeed.
2676                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2677                     yield from lines
2678                     return
2679
2680             # All splits failed, best effort split with no omits.
2681             # This mostly happens to multiline strings that are by definition
2682             # reported as not fitting a single line, as well as lines that contain
2683             # trailing commas (those have to be exploded).
2684             yield from right_hand_split(
2685                 line, line_length=mode.line_length, features=features
2686             )
2687
2688         if mode.experimental_string_processing:
2689             if line.inside_brackets:
2690                 transformers = [
2691                     string_merge,
2692                     string_paren_strip,
2693                     string_split,
2694                     delimiter_split,
2695                     standalone_comment_split,
2696                     string_paren_wrap,
2697                     rhs,
2698                 ]
2699             else:
2700                 transformers = [
2701                     string_merge,
2702                     string_paren_strip,
2703                     string_split,
2704                     string_paren_wrap,
2705                     rhs,
2706                 ]
2707         else:
2708             if line.inside_brackets:
2709                 transformers = [delimiter_split, standalone_comment_split, rhs]
2710             else:
2711                 transformers = [rhs]
2712
2713     for transform in transformers:
2714         # We are accumulating lines in `result` because we might want to abort
2715         # mission and return the original line in the end, or attempt a different
2716         # split altogether.
2717         try:
2718             result = run_transformer(line, transform, mode, features, line_str=line_str)
2719         except CannotTransform:
2720             continue
2721         else:
2722             yield from result
2723             break
2724
2725     else:
2726         yield line
2727
2728
2729 @dataclass  # type: ignore
2730 class StringTransformer(ABC):
2731     """
2732     An implementation of the Transformer protocol that relies on its
2733     subclasses overriding the template methods `do_match(...)` and
2734     `do_transform(...)`.
2735
2736     This Transformer works exclusively on strings (for example, by merging
2737     or splitting them).
2738
2739     The following sections can be found among the docstrings of each concrete
2740     StringTransformer subclass.
2741
2742     Requirements:
2743         Which requirements must be met of the given Line for this
2744         StringTransformer to be applied?
2745
2746     Transformations:
2747         If the given Line meets all of the above requirements, which string
2748         transformations can you expect to be applied to it by this
2749         StringTransformer?
2750
2751     Collaborations:
2752         What contractual agreements does this StringTransformer have with other
2753         StringTransfomers? Such collaborations should be eliminated/minimized
2754         as much as possible.
2755     """
2756
2757     line_length: int
2758     normalize_strings: bool
2759     __name__ = "StringTransformer"
2760
2761     @abstractmethod
2762     def do_match(self, line: Line) -> TMatchResult:
2763         """
2764         Returns:
2765             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2766             string, if a match was able to be made.
2767                 OR
2768             * Err(CannotTransform), if a match was not able to be made.
2769         """
2770
2771     @abstractmethod
2772     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2773         """
2774         Yields:
2775             * Ok(new_line) where new_line is the new transformed line.
2776                 OR
2777             * Err(CannotTransform) if the transformation failed for some reason. The
2778             `do_match(...)` template method should usually be used to reject
2779             the form of the given Line, but in some cases it is difficult to
2780             know whether or not a Line meets the StringTransformer's
2781             requirements until the transformation is already midway.
2782
2783         Side Effects:
2784             This method should NOT mutate @line directly, but it MAY mutate the
2785             Line's underlying Node structure. (WARNING: If the underlying Node
2786             structure IS altered, then this method should NOT be allowed to
2787             yield an CannotTransform after that point.)
2788         """
2789
2790     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2791         """
2792         StringTransformer instances have a call signature that mirrors that of
2793         the Transformer type.
2794
2795         Raises:
2796             CannotTransform(...) if the concrete StringTransformer class is unable
2797             to transform @line.
2798         """
2799         # Optimization to avoid calling `self.do_match(...)` when the line does
2800         # not contain any string.
2801         if not any(leaf.type == token.STRING for leaf in line.leaves):
2802             raise CannotTransform("There are no strings in this line.")
2803
2804         match_result = self.do_match(line)
2805
2806         if isinstance(match_result, Err):
2807             cant_transform = match_result.err()
2808             raise CannotTransform(
2809                 f"The string transformer {self.__class__.__name__} does not recognize"
2810                 " this line as one that it can transform."
2811             ) from cant_transform
2812
2813         string_idx = match_result.ok()
2814
2815         for line_result in self.do_transform(line, string_idx):
2816             if isinstance(line_result, Err):
2817                 cant_transform = line_result.err()
2818                 raise CannotTransform(
2819                     "StringTransformer failed while attempting to transform string."
2820                 ) from cant_transform
2821             line = line_result.ok()
2822             yield line
2823
2824
2825 @dataclass
2826 class CustomSplit:
2827     """A custom (i.e. manual) string split.
2828
2829     A single CustomSplit instance represents a single substring.
2830
2831     Examples:
2832         Consider the following string:
2833         ```
2834         "Hi there friend."
2835         " This is a custom"
2836         f" string {split}."
2837         ```
2838
2839         This string will correspond to the following three CustomSplit instances:
2840         ```
2841         CustomSplit(False, 16)
2842         CustomSplit(False, 17)
2843         CustomSplit(True, 16)
2844         ```
2845     """
2846
2847     has_prefix: bool
2848     break_idx: int
2849
2850
2851 class CustomSplitMapMixin:
2852     """
2853     This mixin class is used to map merged strings to a sequence of
2854     CustomSplits, which will then be used to re-split the strings iff none of
2855     the resultant substrings go over the configured max line length.
2856     """
2857
2858     _Key = Tuple[StringID, str]
2859     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2860
2861     @staticmethod
2862     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2863         """
2864         Returns:
2865             A unique identifier that is used internally to map @string to a
2866             group of custom splits.
2867         """
2868         return (id(string), string)
2869
2870     def add_custom_splits(
2871         self, string: str, custom_splits: Iterable[CustomSplit]
2872     ) -> None:
2873         """Custom Split Map Setter Method
2874
2875         Side Effects:
2876             Adds a mapping from @string to the custom splits @custom_splits.
2877         """
2878         key = self._get_key(string)
2879         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2880
2881     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2882         """Custom Split Map Getter Method
2883
2884         Returns:
2885             * A list of the custom splits that are mapped to @string, if any
2886             exist.
2887                 OR
2888             * [], otherwise.
2889
2890         Side Effects:
2891             Deletes the mapping between @string and its associated custom
2892             splits (which are returned to the caller).
2893         """
2894         key = self._get_key(string)
2895
2896         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2897         del self._CUSTOM_SPLIT_MAP[key]
2898
2899         return list(custom_splits)
2900
2901     def has_custom_splits(self, string: str) -> bool:
2902         """
2903         Returns:
2904             True iff @string is associated with a set of custom splits.
2905         """
2906         key = self._get_key(string)
2907         return key in self._CUSTOM_SPLIT_MAP
2908
2909
2910 class StringMerger(CustomSplitMapMixin, StringTransformer):
2911     """StringTransformer that merges strings together.
2912
2913     Requirements:
2914         (A) The line contains adjacent strings such that ALL of the validation checks
2915         listed in StringMerger.__validate_msg(...)'s docstring pass.
2916             OR
2917         (B) The line contains a string which uses line continuation backslashes.
2918
2919     Transformations:
2920         Depending on which of the two requirements above where met, either:
2921
2922         (A) The string group associated with the target string is merged.
2923             OR
2924         (B) All line-continuation backslashes are removed from the target string.
2925
2926     Collaborations:
2927         StringMerger provides custom split information to StringSplitter.
2928     """
2929
2930     def do_match(self, line: Line) -> TMatchResult:
2931         LL = line.leaves
2932
2933         is_valid_index = is_valid_index_factory(LL)
2934
2935         for (i, leaf) in enumerate(LL):
2936             if (
2937                 leaf.type == token.STRING
2938                 and is_valid_index(i + 1)
2939                 and LL[i + 1].type == token.STRING
2940             ):
2941                 return Ok(i)
2942
2943             if leaf.type == token.STRING and "\\\n" in leaf.value:
2944                 return Ok(i)
2945
2946         return TErr("This line has no strings that need merging.")
2947
2948     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2949         new_line = line
2950         rblc_result = self.__remove_backslash_line_continuation_chars(
2951             new_line, string_idx
2952         )
2953         if isinstance(rblc_result, Ok):
2954             new_line = rblc_result.ok()
2955
2956         msg_result = self.__merge_string_group(new_line, string_idx)
2957         if isinstance(msg_result, Ok):
2958             new_line = msg_result.ok()
2959
2960         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
2961             msg_cant_transform = msg_result.err()
2962             rblc_cant_transform = rblc_result.err()
2963             cant_transform = CannotTransform(
2964                 "StringMerger failed to merge any strings in this line."
2965             )
2966
2967             # Chain the errors together using `__cause__`.
2968             msg_cant_transform.__cause__ = rblc_cant_transform
2969             cant_transform.__cause__ = msg_cant_transform
2970
2971             yield Err(cant_transform)
2972         else:
2973             yield Ok(new_line)
2974
2975     @staticmethod
2976     def __remove_backslash_line_continuation_chars(
2977         line: Line, string_idx: int
2978     ) -> TResult[Line]:
2979         """
2980         Merge strings that were split across multiple lines using
2981         line-continuation backslashes.
2982
2983         Returns:
2984             Ok(new_line), if @line contains backslash line-continuation
2985             characters.
2986                 OR
2987             Err(CannotTransform), otherwise.
2988         """
2989         LL = line.leaves
2990
2991         string_leaf = LL[string_idx]
2992         if not (
2993             string_leaf.type == token.STRING
2994             and "\\\n" in string_leaf.value
2995             and not has_triple_quotes(string_leaf.value)
2996         ):
2997             return TErr(
2998                 f"String leaf {string_leaf} does not contain any backslash line"
2999                 " continuation characters."
3000             )
3001
3002         new_line = line.clone()
3003         new_line.comments = line.comments.copy()
3004         append_leaves(new_line, line, LL)
3005
3006         new_string_leaf = new_line.leaves[string_idx]
3007         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3008
3009         return Ok(new_line)
3010
3011     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3012         """
3013         Merges string group (i.e. set of adjacent strings) where the first
3014         string in the group is `line.leaves[string_idx]`.
3015
3016         Returns:
3017             Ok(new_line), if ALL of the validation checks found in
3018             __validate_msg(...) pass.
3019                 OR
3020             Err(CannotTransform), otherwise.
3021         """
3022         LL = line.leaves
3023
3024         is_valid_index = is_valid_index_factory(LL)
3025
3026         vresult = self.__validate_msg(line, string_idx)
3027         if isinstance(vresult, Err):
3028             return vresult
3029
3030         # If the string group is wrapped inside an Atom node, we must make sure
3031         # to later replace that Atom with our new (merged) string leaf.
3032         atom_node = LL[string_idx].parent
3033
3034         # We will place BREAK_MARK in between every two substrings that we
3035         # merge. We will then later go through our final result and use the
3036         # various instances of BREAK_MARK we find to add the right values to
3037         # the custom split map.
3038         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3039
3040         QUOTE = LL[string_idx].value[-1]
3041
3042         def make_naked(string: str, string_prefix: str) -> str:
3043             """Strip @string (i.e. make it a "naked" string)
3044
3045             Pre-conditions:
3046                 * assert_is_leaf_string(@string)
3047
3048             Returns:
3049                 A string that is identical to @string except that
3050                 @string_prefix has been stripped, the surrounding QUOTE
3051                 characters have been removed, and any remaining QUOTE
3052                 characters have been escaped.
3053             """
3054             assert_is_leaf_string(string)
3055
3056             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3057             naked_string = string[len(string_prefix) + 1 : -1]
3058             naked_string = re.sub(
3059                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3060             )
3061             return naked_string
3062
3063         # Holds the CustomSplit objects that will later be added to the custom
3064         # split map.
3065         custom_splits = []
3066
3067         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3068         prefix_tracker = []
3069
3070         # Sets the 'prefix' variable. This is the prefix that the final merged
3071         # string will have.
3072         next_str_idx = string_idx
3073         prefix = ""
3074         while (
3075             not prefix
3076             and is_valid_index(next_str_idx)
3077             and LL[next_str_idx].type == token.STRING
3078         ):
3079             prefix = get_string_prefix(LL[next_str_idx].value)
3080             next_str_idx += 1
3081
3082         # The next loop merges the string group. The final string will be
3083         # contained in 'S'.
3084         #
3085         # The following convenience variables are used:
3086         #
3087         #   S: string
3088         #   NS: naked string
3089         #   SS: next string
3090         #   NSS: naked next string
3091         S = ""
3092         NS = ""
3093         num_of_strings = 0
3094         next_str_idx = string_idx
3095         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3096             num_of_strings += 1
3097
3098             SS = LL[next_str_idx].value
3099             next_prefix = get_string_prefix(SS)
3100
3101             # If this is an f-string group but this substring is not prefixed
3102             # with 'f'...
3103             if "f" in prefix and "f" not in next_prefix:
3104                 # Then we must escape any braces contained in this substring.
3105                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3106
3107             NSS = make_naked(SS, next_prefix)
3108
3109             has_prefix = bool(next_prefix)
3110             prefix_tracker.append(has_prefix)
3111
3112             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3113             NS = make_naked(S, prefix)
3114
3115             next_str_idx += 1
3116
3117         S_leaf = Leaf(token.STRING, S)
3118         if self.normalize_strings:
3119             normalize_string_quotes(S_leaf)
3120
3121         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3122         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3123         for has_prefix in prefix_tracker:
3124             mark_idx = temp_string.find(BREAK_MARK)
3125             assert (
3126                 mark_idx >= 0
3127             ), "Logic error while filling the custom string breakpoint cache."
3128
3129             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3130             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3131             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3132
3133         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3134
3135         if atom_node is not None:
3136             replace_child(atom_node, string_leaf)
3137
3138         # Build the final line ('new_line') that this method will later return.
3139         new_line = line.clone()
3140         for (i, leaf) in enumerate(LL):
3141             if i == string_idx:
3142                 new_line.append(string_leaf)
3143
3144             if string_idx <= i < string_idx + num_of_strings:
3145                 for comment_leaf in line.comments_after(LL[i]):
3146                     new_line.append(comment_leaf, preformatted=True)
3147                 continue
3148
3149             append_leaves(new_line, line, [leaf])
3150
3151         self.add_custom_splits(string_leaf.value, custom_splits)
3152         return Ok(new_line)
3153
3154     @staticmethod
3155     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3156         """Validate (M)erge (S)tring (G)roup
3157
3158         Transform-time string validation logic for __merge_string_group(...).
3159
3160         Returns:
3161             * Ok(None), if ALL validation checks (listed below) pass.
3162                 OR
3163             * Err(CannotTransform), if any of the following are true:
3164                 - The target string group does not contain ANY stand-alone comments.
3165                 - The target string is not in a string group (i.e. it has no
3166                   adjacent strings).
3167                 - The string group has more than one inline comment.
3168                 - The string group has an inline comment that appears to be a pragma.
3169                 - The set of all string prefixes in the string group is of
3170                   length greater than one and is not equal to {"", "f"}.
3171                 - The string group consists of raw strings.
3172         """
3173         # We first check for "inner" stand-alone comments (i.e. stand-alone
3174         # comments that have a string leaf before them AND after them).
3175         for inc in [1, -1]:
3176             i = string_idx
3177             found_sa_comment = False
3178             is_valid_index = is_valid_index_factory(line.leaves)
3179             while is_valid_index(i) and line.leaves[i].type in [
3180                 token.STRING,
3181                 STANDALONE_COMMENT,
3182             ]:
3183                 if line.leaves[i].type == STANDALONE_COMMENT:
3184                     found_sa_comment = True
3185                 elif found_sa_comment:
3186                     return TErr(
3187                         "StringMerger does NOT merge string groups which contain "
3188                         "stand-alone comments."
3189                     )
3190
3191                 i += inc
3192
3193         num_of_inline_string_comments = 0
3194         set_of_prefixes = set()
3195         num_of_strings = 0
3196         for leaf in line.leaves[string_idx:]:
3197             if leaf.type != token.STRING:
3198                 # If the string group is trailed by a comma, we count the
3199                 # comments trailing the comma to be one of the string group's
3200                 # comments.
3201                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3202                     num_of_inline_string_comments += 1
3203                 break
3204
3205             if has_triple_quotes(leaf.value):
3206                 return TErr("StringMerger does NOT merge multiline strings.")
3207
3208             num_of_strings += 1
3209             prefix = get_string_prefix(leaf.value)
3210             if "r" in prefix:
3211                 return TErr("StringMerger does NOT merge raw strings.")
3212
3213             set_of_prefixes.add(prefix)
3214
3215             if id(leaf) in line.comments:
3216                 num_of_inline_string_comments += 1
3217                 if contains_pragma_comment(line.comments[id(leaf)]):
3218                     return TErr("Cannot merge strings which have pragma comments.")
3219
3220         if num_of_strings < 2:
3221             return TErr(
3222                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3223             )
3224
3225         if num_of_inline_string_comments > 1:
3226             return TErr(
3227                 f"Too many inline string comments ({num_of_inline_string_comments})."
3228             )
3229
3230         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3231             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3232
3233         return Ok(None)
3234
3235
3236 class StringParenStripper(StringTransformer):
3237     """StringTransformer that strips surrounding parentheses from strings.
3238
3239     Requirements:
3240         The line contains a string which is surrounded by parentheses and:
3241             - The target string is NOT the only argument to a function call).
3242             - If the target string contains a PERCENT, the brackets are not
3243               preceeded or followed by an operator with higher precedence than
3244               PERCENT.
3245
3246     Transformations:
3247         The parentheses mentioned in the 'Requirements' section are stripped.
3248
3249     Collaborations:
3250         StringParenStripper has its own inherent usefulness, but it is also
3251         relied on to clean up the parentheses created by StringParenWrapper (in
3252         the event that they are no longer needed).
3253     """
3254
3255     def do_match(self, line: Line) -> TMatchResult:
3256         LL = line.leaves
3257
3258         is_valid_index = is_valid_index_factory(LL)
3259
3260         for (idx, leaf) in enumerate(LL):
3261             # Should be a string...
3262             if leaf.type != token.STRING:
3263                 continue
3264
3265             # Should be preceded by a non-empty LPAR...
3266             if (
3267                 not is_valid_index(idx - 1)
3268                 or LL[idx - 1].type != token.LPAR
3269                 or is_empty_lpar(LL[idx - 1])
3270             ):
3271                 continue
3272
3273             # That LPAR should NOT be preceded by a function name or a closing
3274             # bracket (which could be a function which returns a function or a
3275             # list/dictionary that contains a function)...
3276             if is_valid_index(idx - 2) and (
3277                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3278             ):
3279                 continue
3280
3281             string_idx = idx
3282
3283             # Skip the string trailer, if one exists.
3284             string_parser = StringParser()
3285             next_idx = string_parser.parse(LL, string_idx)
3286
3287             # if the leaves in the parsed string include a PERCENT, we need to
3288             # make sure the initial LPAR is NOT preceded by an operator with
3289             # higher or equal precedence to PERCENT
3290             if is_valid_index(idx - 2):
3291                 # mypy can't quite follow unless we name this
3292                 before_lpar = LL[idx - 2]
3293                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3294                     (
3295                         before_lpar.type
3296                         in {
3297                             token.STAR,
3298                             token.AT,
3299                             token.SLASH,
3300                             token.DOUBLESLASH,
3301                             token.PERCENT,
3302                             token.TILDE,
3303                             token.DOUBLESTAR,
3304                             token.AWAIT,
3305                             token.LSQB,
3306                             token.LPAR,
3307                         }
3308                     )
3309                     or (
3310                         # only unary PLUS/MINUS
3311                         before_lpar.parent
3312                         and before_lpar.parent.type == syms.factor
3313                         and (before_lpar.type in {token.PLUS, token.MINUS})
3314                     )
3315                 ):
3316                     continue
3317
3318             # Should be followed by a non-empty RPAR...
3319             if (
3320                 is_valid_index(next_idx)
3321                 and LL[next_idx].type == token.RPAR
3322                 and not is_empty_rpar(LL[next_idx])
3323             ):
3324                 # That RPAR should NOT be followed by anything with higher
3325                 # precedence than PERCENT
3326                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3327                     token.DOUBLESTAR,
3328                     token.LSQB,
3329                     token.LPAR,
3330                     token.DOT,
3331                 }:
3332                     continue
3333
3334                 return Ok(string_idx)
3335
3336         return TErr("This line has no strings wrapped in parens.")
3337
3338     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3339         LL = line.leaves
3340
3341         string_parser = StringParser()
3342         rpar_idx = string_parser.parse(LL, string_idx)
3343
3344         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3345             if line.comments_after(leaf):
3346                 yield TErr(
3347                     "Will not strip parentheses which have comments attached to them."
3348                 )
3349                 return
3350
3351         new_line = line.clone()
3352         new_line.comments = line.comments.copy()
3353         try:
3354             append_leaves(new_line, line, LL[: string_idx - 1])
3355         except BracketMatchError:
3356             # HACK: I believe there is currently a bug somewhere in
3357             # right_hand_split() that is causing brackets to not be tracked
3358             # properly by a shared BracketTracker.
3359             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3360
3361         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3362         LL[string_idx - 1].remove()
3363         replace_child(LL[string_idx], string_leaf)
3364         new_line.append(string_leaf)
3365
3366         append_leaves(
3367             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3368         )
3369
3370         LL[rpar_idx].remove()
3371
3372         yield Ok(new_line)
3373
3374
3375 class BaseStringSplitter(StringTransformer):
3376     """
3377     Abstract class for StringTransformers which transform a Line's strings by splitting
3378     them or placing them on their own lines where necessary to avoid going over
3379     the configured line length.
3380
3381     Requirements:
3382         * The target string value is responsible for the line going over the
3383         line length limit. It follows that after all of black's other line
3384         split methods have been exhausted, this line (or one of the resulting
3385         lines after all line splits are performed) would still be over the
3386         line_length limit unless we split this string.
3387             AND
3388         * The target string is NOT a "pointless" string (i.e. a string that has
3389         no parent or siblings).
3390             AND
3391         * The target string is not followed by an inline comment that appears
3392         to be a pragma.
3393             AND
3394         * The target string is not a multiline (i.e. triple-quote) string.
3395     """
3396
3397     @abstractmethod
3398     def do_splitter_match(self, line: Line) -> TMatchResult:
3399         """
3400         BaseStringSplitter asks its clients to override this method instead of
3401         `StringTransformer.do_match(...)`.
3402
3403         Follows the same protocol as `StringTransformer.do_match(...)`.
3404
3405         Refer to `help(StringTransformer.do_match)` for more information.
3406         """
3407
3408     def do_match(self, line: Line) -> TMatchResult:
3409         match_result = self.do_splitter_match(line)
3410         if isinstance(match_result, Err):
3411             return match_result
3412
3413         string_idx = match_result.ok()
3414         vresult = self.__validate(line, string_idx)
3415         if isinstance(vresult, Err):
3416             return vresult
3417
3418         return match_result
3419
3420     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3421         """
3422         Checks that @line meets all of the requirements listed in this classes'
3423         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3424         description of those requirements.
3425
3426         Returns:
3427             * Ok(None), if ALL of the requirements are met.
3428                 OR
3429             * Err(CannotTransform), if ANY of the requirements are NOT met.
3430         """
3431         LL = line.leaves
3432
3433         string_leaf = LL[string_idx]
3434
3435         max_string_length = self.__get_max_string_length(line, string_idx)
3436         if len(string_leaf.value) <= max_string_length:
3437             return TErr(
3438                 "The string itself is not what is causing this line to be too long."
3439             )
3440
3441         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3442             token.STRING,
3443             token.NEWLINE,
3444         ]:
3445             return TErr(
3446                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3447                 " no parent)."
3448             )
3449
3450         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3451             line.comments[id(line.leaves[string_idx])]
3452         ):
3453             return TErr(
3454                 "Line appears to end with an inline pragma comment. Splitting the line"
3455                 " could modify the pragma's behavior."
3456             )
3457
3458         if has_triple_quotes(string_leaf.value):
3459             return TErr("We cannot split multiline strings.")
3460
3461         return Ok(None)
3462
3463     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3464         """
3465         Calculates the max string length used when attempting to determine
3466         whether or not the target string is responsible for causing the line to
3467         go over the line length limit.
3468
3469         WARNING: This method is tightly coupled to both StringSplitter and
3470         (especially) StringParenWrapper. There is probably a better way to
3471         accomplish what is being done here.
3472
3473         Returns:
3474             max_string_length: such that `line.leaves[string_idx].value >
3475             max_string_length` implies that the target string IS responsible
3476             for causing this line to exceed the line length limit.
3477         """
3478         LL = line.leaves
3479
3480         is_valid_index = is_valid_index_factory(LL)
3481
3482         # We use the shorthand "WMA4" in comments to abbreviate "We must
3483         # account for". When giving examples, we use STRING to mean some/any
3484         # valid string.
3485         #
3486         # Finally, we use the following convenience variables:
3487         #
3488         #   P:  The leaf that is before the target string leaf.
3489         #   N:  The leaf that is after the target string leaf.
3490         #   NN: The leaf that is after N.
3491
3492         # WMA4 the whitespace at the beginning of the line.
3493         offset = line.depth * 4
3494
3495         if is_valid_index(string_idx - 1):
3496             p_idx = string_idx - 1
3497             if (
3498                 LL[string_idx - 1].type == token.LPAR
3499                 and LL[string_idx - 1].value == ""
3500                 and string_idx >= 2
3501             ):
3502                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3503                 p_idx -= 1
3504
3505             P = LL[p_idx]
3506             if P.type == token.PLUS:
3507                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3508                 offset += 2
3509
3510             if P.type == token.COMMA:
3511                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3512                 offset += 3
3513
3514             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3515                 # This conditional branch is meant to handle dictionary keys,
3516                 # variable assignments, 'return STRING' statement lines, and
3517                 # 'else STRING' ternary expression lines.
3518
3519                 # WMA4 a single space.
3520                 offset += 1
3521
3522                 # WMA4 the lengths of any leaves that came before that space,
3523                 # but after any closing bracket before that space.
3524                 for leaf in reversed(LL[: p_idx + 1]):
3525                     offset += len(str(leaf))
3526                     if leaf.type in CLOSING_BRACKETS:
3527                         break
3528
3529         if is_valid_index(string_idx + 1):
3530             N = LL[string_idx + 1]
3531             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3532                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3533                 N = LL[string_idx + 2]
3534
3535             if N.type == token.COMMA:
3536                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3537                 offset += 1
3538
3539             if is_valid_index(string_idx + 2):
3540                 NN = LL[string_idx + 2]
3541
3542                 if N.type == token.DOT and NN.type == token.NAME:
3543                     # This conditional branch is meant to handle method calls invoked
3544                     # off of a string literal up to and including the LPAR character.
3545
3546                     # WMA4 the '.' character.
3547                     offset += 1
3548
3549                     if (
3550                         is_valid_index(string_idx + 3)
3551                         and LL[string_idx + 3].type == token.LPAR
3552                     ):
3553                         # WMA4 the left parenthesis character.
3554                         offset += 1
3555
3556                     # WMA4 the length of the method's name.
3557                     offset += len(NN.value)
3558
3559         has_comments = False
3560         for comment_leaf in line.comments_after(LL[string_idx]):
3561             if not has_comments:
3562                 has_comments = True
3563                 # WMA4 two spaces before the '#' character.
3564                 offset += 2
3565
3566             # WMA4 the length of the inline comment.
3567             offset += len(comment_leaf.value)
3568
3569         max_string_length = self.line_length - offset
3570         return max_string_length
3571
3572
3573 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3574     """
3575     StringTransformer that splits "atom" strings (i.e. strings which exist on
3576     lines by themselves).
3577
3578     Requirements:
3579         * The line consists ONLY of a single string (with the exception of a
3580         '+' symbol which MAY exist at the start of the line), MAYBE a string
3581         trailer, and MAYBE a trailing comma.
3582             AND
3583         * All of the requirements listed in BaseStringSplitter's docstring.
3584
3585     Transformations:
3586         The string mentioned in the 'Requirements' section is split into as
3587         many substrings as necessary to adhere to the configured line length.
3588
3589         In the final set of substrings, no substring should be smaller than
3590         MIN_SUBSTR_SIZE characters.
3591
3592         The string will ONLY be split on spaces (i.e. each new substring should
3593         start with a space).
3594
3595         If the string is an f-string, it will NOT be split in the middle of an
3596         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3597         else bar()} is an f-expression).
3598
3599         If the string that is being split has an associated set of custom split
3600         records and those custom splits will NOT result in any line going over
3601         the configured line length, those custom splits are used. Otherwise the
3602         string is split as late as possible (from left-to-right) while still
3603         adhering to the transformation rules listed above.
3604
3605     Collaborations:
3606         StringSplitter relies on StringMerger to construct the appropriate
3607         CustomSplit objects and add them to the custom split map.
3608     """
3609
3610     MIN_SUBSTR_SIZE = 6
3611     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3612     RE_FEXPR = r"""
3613     (?<!\{)\{
3614         (?:
3615             [^\{\}]
3616             | \{\{
3617             | \}\}
3618         )+?
3619     (?<!\})(?:\}\})*\}(?!\})
3620     """
3621
3622     def do_splitter_match(self, line: Line) -> TMatchResult:
3623         LL = line.leaves
3624
3625         is_valid_index = is_valid_index_factory(LL)
3626
3627         idx = 0
3628
3629         # The first leaf MAY be a '+' symbol...
3630         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3631             idx += 1
3632
3633         # The next/first leaf MAY be an empty LPAR...
3634         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3635             idx += 1
3636
3637         # The next/first leaf MUST be a string...
3638         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3639             return TErr("Line does not start with a string.")
3640
3641         string_idx = idx
3642
3643         # Skip the string trailer, if one exists.
3644         string_parser = StringParser()
3645         idx = string_parser.parse(LL, string_idx)
3646
3647         # That string MAY be followed by an empty RPAR...
3648         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3649             idx += 1
3650
3651         # That string / empty RPAR leaf MAY be followed by a comma...
3652         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3653             idx += 1
3654
3655         # But no more leaves are allowed...
3656         if is_valid_index(idx):
3657             return TErr("This line does not end with a string.")
3658
3659         return Ok(string_idx)
3660
3661     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3662         LL = line.leaves
3663
3664         QUOTE = LL[string_idx].value[-1]
3665
3666         is_valid_index = is_valid_index_factory(LL)
3667         insert_str_child = insert_str_child_factory(LL[string_idx])
3668
3669         prefix = get_string_prefix(LL[string_idx].value)
3670
3671         # We MAY choose to drop the 'f' prefix from substrings that don't
3672         # contain any f-expressions, but ONLY if the original f-string
3673         # contains at least one f-expression. Otherwise, we will alter the AST
3674         # of the program.
3675         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3676             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3677         )
3678
3679         first_string_line = True
3680         starts_with_plus = LL[0].type == token.PLUS
3681
3682         def line_needs_plus() -> bool:
3683             return first_string_line and starts_with_plus
3684
3685         def maybe_append_plus(new_line: Line) -> None:
3686             """
3687             Side Effects:
3688                 If @line starts with a plus and this is the first line we are
3689                 constructing, this function appends a PLUS leaf to @new_line
3690                 and replaces the old PLUS leaf in the node structure. Otherwise
3691                 this function does nothing.
3692             """
3693             if line_needs_plus():
3694                 plus_leaf = Leaf(token.PLUS, "+")
3695                 replace_child(LL[0], plus_leaf)
3696                 new_line.append(plus_leaf)
3697
3698         ends_with_comma = (
3699             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3700         )
3701
3702         def max_last_string() -> int:
3703             """
3704             Returns:
3705                 The max allowed length of the string value used for the last
3706                 line we will construct.
3707             """
3708             result = self.line_length
3709             result -= line.depth * 4
3710             result -= 1 if ends_with_comma else 0
3711             result -= 2 if line_needs_plus() else 0
3712             return result
3713
3714         # --- Calculate Max Break Index (for string value)
3715         # We start with the line length limit
3716         max_break_idx = self.line_length
3717         # The last index of a string of length N is N-1.
3718         max_break_idx -= 1
3719         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3720         max_break_idx -= line.depth * 4
3721         if max_break_idx < 0:
3722             yield TErr(
3723                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3724                 f" {line.depth}"
3725             )
3726             return
3727
3728         # Check if StringMerger registered any custom splits.
3729         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3730         # We use them ONLY if none of them would produce lines that exceed the
3731         # line limit.
3732         use_custom_breakpoints = bool(
3733             custom_splits
3734             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3735         )
3736
3737         # Temporary storage for the remaining chunk of the string line that
3738         # can't fit onto the line currently being constructed.
3739         rest_value = LL[string_idx].value
3740
3741         def more_splits_should_be_made() -> bool:
3742             """
3743             Returns:
3744                 True iff `rest_value` (the remaining string value from the last
3745                 split), should be split again.
3746             """
3747             if use_custom_breakpoints:
3748                 return len(custom_splits) > 1
3749             else:
3750                 return len(rest_value) > max_last_string()
3751
3752         string_line_results: List[Ok[Line]] = []
3753         while more_splits_should_be_made():
3754             if use_custom_breakpoints:
3755                 # Custom User Split (manual)
3756                 csplit = custom_splits.pop(0)
3757                 break_idx = csplit.break_idx
3758             else:
3759                 # Algorithmic Split (automatic)
3760                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3761                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3762                 if maybe_break_idx is None:
3763                     # If we are unable to algorithmically determine a good split
3764                     # and this string has custom splits registered to it, we
3765                     # fall back to using them--which means we have to start
3766                     # over from the beginning.
3767                     if custom_splits:
3768                         rest_value = LL[string_idx].value
3769                         string_line_results = []
3770                         first_string_line = True
3771                         use_custom_breakpoints = True
3772                         continue
3773
3774                     # Otherwise, we stop splitting here.
3775                     break
3776
3777                 break_idx = maybe_break_idx
3778
3779             # --- Construct `next_value`
3780             next_value = rest_value[:break_idx] + QUOTE
3781             if (
3782                 # Are we allowed to try to drop a pointless 'f' prefix?
3783                 drop_pointless_f_prefix
3784                 # If we are, will we be successful?
3785                 and next_value != self.__normalize_f_string(next_value, prefix)
3786             ):
3787                 # If the current custom split did NOT originally use a prefix,
3788                 # then `csplit.break_idx` will be off by one after removing
3789                 # the 'f' prefix.
3790                 break_idx = (
3791                     break_idx + 1
3792                     if use_custom_breakpoints and not csplit.has_prefix
3793                     else break_idx
3794                 )
3795                 next_value = rest_value[:break_idx] + QUOTE
3796                 next_value = self.__normalize_f_string(next_value, prefix)
3797
3798             # --- Construct `next_leaf`
3799             next_leaf = Leaf(token.STRING, next_value)
3800             insert_str_child(next_leaf)
3801             self.__maybe_normalize_string_quotes(next_leaf)
3802
3803             # --- Construct `next_line`
3804             next_line = line.clone()
3805             maybe_append_plus(next_line)
3806             next_line.append(next_leaf)
3807             string_line_results.append(Ok(next_line))
3808
3809             rest_value = prefix + QUOTE + rest_value[break_idx:]
3810             first_string_line = False
3811
3812         yield from string_line_results
3813
3814         if drop_pointless_f_prefix:
3815             rest_value = self.__normalize_f_string(rest_value, prefix)
3816
3817         rest_leaf = Leaf(token.STRING, rest_value)
3818         insert_str_child(rest_leaf)
3819
3820         # NOTE: I could not find a test case that verifies that the following
3821         # line is actually necessary, but it seems to be. Otherwise we risk
3822         # not normalizing the last substring, right?
3823         self.__maybe_normalize_string_quotes(rest_leaf)
3824
3825         last_line = line.clone()
3826         maybe_append_plus(last_line)
3827
3828         # If there are any leaves to the right of the target string...
3829         if is_valid_index(string_idx + 1):
3830             # We use `temp_value` here to determine how long the last line
3831             # would be if we were to append all the leaves to the right of the
3832             # target string to the last string line.
3833             temp_value = rest_value
3834             for leaf in LL[string_idx + 1 :]:
3835                 temp_value += str(leaf)
3836                 if leaf.type == token.LPAR:
3837                     break
3838
3839             # Try to fit them all on the same line with the last substring...
3840             if (
3841                 len(temp_value) <= max_last_string()
3842                 or LL[string_idx + 1].type == token.COMMA
3843             ):
3844                 last_line.append(rest_leaf)
3845                 append_leaves(last_line, line, LL[string_idx + 1 :])
3846                 yield Ok(last_line)
3847             # Otherwise, place the last substring on one line and everything
3848             # else on a line below that...
3849             else:
3850                 last_line.append(rest_leaf)
3851                 yield Ok(last_line)
3852
3853                 non_string_line = line.clone()
3854                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3855                 yield Ok(non_string_line)
3856         # Else the target string was the last leaf...
3857         else:
3858             last_line.append(rest_leaf)
3859             last_line.comments = line.comments.copy()
3860             yield Ok(last_line)
3861
3862     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3863         """
3864         This method contains the algorithm that StringSplitter uses to
3865         determine which character to split each string at.
3866
3867         Args:
3868             @string: The substring that we are attempting to split.
3869             @max_break_idx: The ideal break index. We will return this value if it
3870             meets all the necessary conditions. In the likely event that it
3871             doesn't we will try to find the closest index BELOW @max_break_idx
3872             that does. If that fails, we will expand our search by also
3873             considering all valid indices ABOVE @max_break_idx.
3874
3875         Pre-Conditions:
3876             * assert_is_leaf_string(@string)
3877             * 0 <= @max_break_idx < len(@string)
3878
3879         Returns:
3880             break_idx, if an index is able to be found that meets all of the
3881             conditions listed in the 'Transformations' section of this classes'
3882             docstring.
3883                 OR
3884             None, otherwise.
3885         """
3886         is_valid_index = is_valid_index_factory(string)
3887
3888         assert is_valid_index(max_break_idx)
3889         assert_is_leaf_string(string)
3890
3891         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3892
3893         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3894             """
3895             Yields:
3896                 All ranges of @string which, if @string were to be split there,
3897                 would result in the splitting of an f-expression (which is NOT
3898                 allowed).
3899             """
3900             nonlocal _fexpr_slices
3901
3902             if _fexpr_slices is None:
3903                 _fexpr_slices = []
3904                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3905                     _fexpr_slices.append(match.span())
3906
3907             yield from _fexpr_slices
3908
3909         is_fstring = "f" in get_string_prefix(string)
3910
3911         def breaks_fstring_expression(i: Index) -> bool:
3912             """
3913             Returns:
3914                 True iff returning @i would result in the splitting of an
3915                 f-expression (which is NOT allowed).
3916             """
3917             if not is_fstring:
3918                 return False
3919
3920             for (start, end) in fexpr_slices():
3921                 if start <= i < end:
3922                     return True
3923
3924             return False
3925
3926         def passes_all_checks(i: Index) -> bool:
3927             """
3928             Returns:
3929                 True iff ALL of the conditions listed in the 'Transformations'
3930                 section of this classes' docstring would be be met by returning @i.
3931             """
3932             is_space = string[i] == " "
3933             is_big_enough = (
3934                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3935                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3936             )
3937             return is_space and is_big_enough and not breaks_fstring_expression(i)
3938
3939         # First, we check all indices BELOW @max_break_idx.
3940         break_idx = max_break_idx
3941         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3942             break_idx -= 1
3943
3944         if not passes_all_checks(break_idx):
3945             # If that fails, we check all indices ABOVE @max_break_idx.
3946             #
3947             # If we are able to find a valid index here, the next line is going
3948             # to be longer than the specified line length, but it's probably
3949             # better than doing nothing at all.
3950             break_idx = max_break_idx + 1
3951             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3952                 break_idx += 1
3953
3954             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3955                 return None
3956
3957         return break_idx
3958
3959     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3960         if self.normalize_strings:
3961             normalize_string_quotes(leaf)
3962
3963     def __normalize_f_string(self, string: str, prefix: str) -> str:
3964         """
3965         Pre-Conditions:
3966             * assert_is_leaf_string(@string)
3967
3968         Returns:
3969             * If @string is an f-string that contains no f-expressions, we
3970             return a string identical to @string except that the 'f' prefix
3971             has been stripped and all double braces (i.e. '{{' or '}}') have
3972             been normalized (i.e. turned into '{' or '}').
3973                 OR
3974             * Otherwise, we return @string.
3975         """
3976         assert_is_leaf_string(string)
3977
3978         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3979             new_prefix = prefix.replace("f", "")
3980
3981             temp = string[len(prefix) :]
3982             temp = re.sub(r"\{\{", "{", temp)
3983             temp = re.sub(r"\}\}", "}", temp)
3984             new_string = temp
3985
3986             return f"{new_prefix}{new_string}"
3987         else:
3988             return string
3989
3990
3991 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
3992     """
3993     StringTransformer that splits non-"atom" strings (i.e. strings that do not
3994     exist on lines by themselves).
3995
3996     Requirements:
3997         All of the requirements listed in BaseStringSplitter's docstring in
3998         addition to the requirements listed below:
3999
4000         * The line is a return/yield statement, which returns/yields a string.
4001             OR
4002         * The line is part of a ternary expression (e.g. `x = y if cond else
4003         z`) such that the line starts with `else <string>`, where <string> is
4004         some string.
4005             OR
4006         * The line is an assert statement, which ends with a string.
4007             OR
4008         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4009         <string>`) such that the variable is being assigned the value of some
4010         string.
4011             OR
4012         * The line is a dictionary key assignment where some valid key is being
4013         assigned the value of some string.
4014
4015     Transformations:
4016         The chosen string is wrapped in parentheses and then split at the LPAR.
4017
4018         We then have one line which ends with an LPAR and another line that
4019         starts with the chosen string. The latter line is then split again at
4020         the RPAR. This results in the RPAR (and possibly a trailing comma)
4021         being placed on its own line.
4022
4023         NOTE: If any leaves exist to the right of the chosen string (except
4024         for a trailing comma, which would be placed after the RPAR), those
4025         leaves are placed inside the parentheses.  In effect, the chosen
4026         string is not necessarily being "wrapped" by parentheses. We can,
4027         however, count on the LPAR being placed directly before the chosen
4028         string.
4029
4030         In other words, StringParenWrapper creates "atom" strings. These
4031         can then be split again by StringSplitter, if necessary.
4032
4033     Collaborations:
4034         In the event that a string line split by StringParenWrapper is
4035         changed such that it no longer needs to be given its own line,
4036         StringParenWrapper relies on StringParenStripper to clean up the
4037         parentheses it created.
4038     """
4039
4040     def do_splitter_match(self, line: Line) -> TMatchResult:
4041         LL = line.leaves
4042
4043         string_idx = (
4044             self._return_match(LL)
4045             or self._else_match(LL)
4046             or self._assert_match(LL)
4047             or self._assign_match(LL)
4048             or self._dict_match(LL)
4049         )
4050
4051         if string_idx is not None:
4052             string_value = line.leaves[string_idx].value
4053             # If the string has no spaces...
4054             if " " not in string_value:
4055                 # And will still violate the line length limit when split...
4056                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4057                 if len(string_value) > max_string_length:
4058                     # And has no associated custom splits...
4059                     if not self.has_custom_splits(string_value):
4060                         # Then we should NOT put this string on its own line.
4061                         return TErr(
4062                             "We do not wrap long strings in parentheses when the"
4063                             " resultant line would still be over the specified line"
4064                             " length and can't be split further by StringSplitter."
4065                         )
4066             return Ok(string_idx)
4067
4068         return TErr("This line does not contain any non-atomic strings.")
4069
4070     @staticmethod
4071     def _return_match(LL: List[Leaf]) -> Optional[int]:
4072         """
4073         Returns:
4074             string_idx such that @LL[string_idx] is equal to our target (i.e.
4075             matched) string, if this line matches the return/yield statement
4076             requirements listed in the 'Requirements' section of this classes'
4077             docstring.
4078                 OR
4079             None, otherwise.
4080         """
4081         # If this line is apart of a return/yield statement and the first leaf
4082         # contains either the "return" or "yield" keywords...
4083         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4084             0
4085         ].value in ["return", "yield"]:
4086             is_valid_index = is_valid_index_factory(LL)
4087
4088             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4089             # The next visible leaf MUST contain a string...
4090             if is_valid_index(idx) and LL[idx].type == token.STRING:
4091                 return idx
4092
4093         return None
4094
4095     @staticmethod
4096     def _else_match(LL: List[Leaf]) -> Optional[int]:
4097         """
4098         Returns:
4099             string_idx such that @LL[string_idx] is equal to our target (i.e.
4100             matched) string, if this line matches the ternary expression
4101             requirements listed in the 'Requirements' section of this classes'
4102             docstring.
4103                 OR
4104             None, otherwise.
4105         """
4106         # If this line is apart of a ternary expression and the first leaf
4107         # contains the "else" keyword...
4108         if (
4109             parent_type(LL[0]) == syms.test
4110             and LL[0].type == token.NAME
4111             and LL[0].value == "else"
4112         ):
4113             is_valid_index = is_valid_index_factory(LL)
4114
4115             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4116             # The next visible leaf MUST contain a string...
4117             if is_valid_index(idx) and LL[idx].type == token.STRING:
4118                 return idx
4119
4120         return None
4121
4122     @staticmethod
4123     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4124         """
4125         Returns:
4126             string_idx such that @LL[string_idx] is equal to our target (i.e.
4127             matched) string, if this line matches the assert statement
4128             requirements listed in the 'Requirements' section of this classes'
4129             docstring.
4130                 OR
4131             None, otherwise.
4132         """
4133         # If this line is apart of an assert statement and the first leaf
4134         # contains the "assert" keyword...
4135         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4136             is_valid_index = is_valid_index_factory(LL)
4137
4138             for (i, leaf) in enumerate(LL):
4139                 # We MUST find a comma...
4140                 if leaf.type == token.COMMA:
4141                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4142
4143                     # That comma MUST be followed by a string...
4144                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4145                         string_idx = idx
4146
4147                         # Skip the string trailer, if one exists.
4148                         string_parser = StringParser()
4149                         idx = string_parser.parse(LL, string_idx)
4150
4151                         # But no more leaves are allowed...
4152                         if not is_valid_index(idx):
4153                             return string_idx
4154
4155         return None
4156
4157     @staticmethod
4158     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4159         """
4160         Returns:
4161             string_idx such that @LL[string_idx] is equal to our target (i.e.
4162             matched) string, if this line matches the assignment statement
4163             requirements listed in the 'Requirements' section of this classes'
4164             docstring.
4165                 OR
4166             None, otherwise.
4167         """
4168         # If this line is apart of an expression statement or is a function
4169         # argument AND the first leaf contains a variable name...
4170         if (
4171             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4172             and LL[0].type == token.NAME
4173         ):
4174             is_valid_index = is_valid_index_factory(LL)
4175
4176             for (i, leaf) in enumerate(LL):
4177                 # We MUST find either an '=' or '+=' symbol...
4178                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4179                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4180
4181                     # That symbol MUST be followed by a string...
4182                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4183                         string_idx = idx
4184
4185                         # Skip the string trailer, if one exists.
4186                         string_parser = StringParser()
4187                         idx = string_parser.parse(LL, string_idx)
4188
4189                         # The next leaf MAY be a comma iff this line is apart
4190                         # of a function argument...
4191                         if (
4192                             parent_type(LL[0]) == syms.argument
4193                             and is_valid_index(idx)
4194                             and LL[idx].type == token.COMMA
4195                         ):
4196                             idx += 1
4197
4198                         # But no more leaves are allowed...
4199                         if not is_valid_index(idx):
4200                             return string_idx
4201
4202         return None
4203
4204     @staticmethod
4205     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4206         """
4207         Returns:
4208             string_idx such that @LL[string_idx] is equal to our target (i.e.
4209             matched) string, if this line matches the dictionary key assignment
4210             statement requirements listed in the 'Requirements' section of this
4211             classes' docstring.
4212                 OR
4213             None, otherwise.
4214         """
4215         # If this line is apart of a dictionary key assignment...
4216         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4217             is_valid_index = is_valid_index_factory(LL)
4218
4219             for (i, leaf) in enumerate(LL):
4220                 # We MUST find a colon...
4221                 if leaf.type == token.COLON:
4222                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4223
4224                     # That colon MUST be followed by a string...
4225                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4226                         string_idx = idx
4227
4228                         # Skip the string trailer, if one exists.
4229                         string_parser = StringParser()
4230                         idx = string_parser.parse(LL, string_idx)
4231
4232                         # That string MAY be followed by a comma...
4233                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4234                             idx += 1
4235
4236                         # But no more leaves are allowed...
4237                         if not is_valid_index(idx):
4238                             return string_idx
4239
4240         return None
4241
4242     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4243         LL = line.leaves
4244
4245         is_valid_index = is_valid_index_factory(LL)
4246         insert_str_child = insert_str_child_factory(LL[string_idx])
4247
4248         comma_idx = -1
4249         ends_with_comma = False
4250         if LL[comma_idx].type == token.COMMA:
4251             ends_with_comma = True
4252
4253         leaves_to_steal_comments_from = [LL[string_idx]]
4254         if ends_with_comma:
4255             leaves_to_steal_comments_from.append(LL[comma_idx])
4256
4257         # --- First Line
4258         first_line = line.clone()
4259         left_leaves = LL[:string_idx]
4260
4261         # We have to remember to account for (possibly invisible) LPAR and RPAR
4262         # leaves that already wrapped the target string. If these leaves do
4263         # exist, we will replace them with our own LPAR and RPAR leaves.
4264         old_parens_exist = False
4265         if left_leaves and left_leaves[-1].type == token.LPAR:
4266             old_parens_exist = True
4267             leaves_to_steal_comments_from.append(left_leaves[-1])
4268             left_leaves.pop()
4269
4270         append_leaves(first_line, line, left_leaves)
4271
4272         lpar_leaf = Leaf(token.LPAR, "(")
4273         if old_parens_exist:
4274             replace_child(LL[string_idx - 1], lpar_leaf)
4275         else:
4276             insert_str_child(lpar_leaf)
4277         first_line.append(lpar_leaf)
4278
4279         # We throw inline comments that were originally to the right of the
4280         # target string to the top line. They will now be shown to the right of
4281         # the LPAR.
4282         for leaf in leaves_to_steal_comments_from:
4283             for comment_leaf in line.comments_after(leaf):
4284                 first_line.append(comment_leaf, preformatted=True)
4285
4286         yield Ok(first_line)
4287
4288         # --- Middle (String) Line
4289         # We only need to yield one (possibly too long) string line, since the
4290         # `StringSplitter` will break it down further if necessary.
4291         string_value = LL[string_idx].value
4292         string_line = Line(
4293             depth=line.depth + 1,
4294             inside_brackets=True,
4295             should_explode=line.should_explode,
4296         )
4297         string_leaf = Leaf(token.STRING, string_value)
4298         insert_str_child(string_leaf)
4299         string_line.append(string_leaf)
4300
4301         old_rpar_leaf = None
4302         if is_valid_index(string_idx + 1):
4303             right_leaves = LL[string_idx + 1 :]
4304             if ends_with_comma:
4305                 right_leaves.pop()
4306
4307             if old_parens_exist:
4308                 assert (
4309                     right_leaves and right_leaves[-1].type == token.RPAR
4310                 ), "Apparently, old parentheses do NOT exist?!"
4311                 old_rpar_leaf = right_leaves.pop()
4312
4313             append_leaves(string_line, line, right_leaves)
4314
4315         yield Ok(string_line)
4316
4317         # --- Last Line
4318         last_line = line.clone()
4319         last_line.bracket_tracker = first_line.bracket_tracker
4320
4321         new_rpar_leaf = Leaf(token.RPAR, ")")
4322         if old_rpar_leaf is not None:
4323             replace_child(old_rpar_leaf, new_rpar_leaf)
4324         else:
4325             insert_str_child(new_rpar_leaf)
4326         last_line.append(new_rpar_leaf)
4327
4328         # If the target string ended with a comma, we place this comma to the
4329         # right of the RPAR on the last line.
4330         if ends_with_comma:
4331             comma_leaf = Leaf(token.COMMA, ",")
4332             replace_child(LL[comma_idx], comma_leaf)
4333             last_line.append(comma_leaf)
4334
4335         yield Ok(last_line)
4336
4337
4338 class StringParser:
4339     """
4340     A state machine that aids in parsing a string's "trailer", which can be
4341     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4342     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4343     varY)`).
4344
4345     NOTE: A new StringParser object MUST be instantiated for each string
4346     trailer we need to parse.
4347
4348     Examples:
4349         We shall assume that `line` equals the `Line` object that corresponds
4350         to the following line of python code:
4351         ```
4352         x = "Some {}.".format("String") + some_other_string
4353         ```
4354
4355         Furthermore, we will assume that `string_idx` is some index such that:
4356         ```
4357         assert line.leaves[string_idx].value == "Some {}."
4358         ```
4359
4360         The following code snippet then holds:
4361         ```
4362         string_parser = StringParser()
4363         idx = string_parser.parse(line.leaves, string_idx)
4364         assert line.leaves[idx].type == token.PLUS
4365         ```
4366     """
4367
4368     DEFAULT_TOKEN = -1
4369
4370     # String Parser States
4371     START = 1
4372     DOT = 2
4373     NAME = 3
4374     PERCENT = 4
4375     SINGLE_FMT_ARG = 5
4376     LPAR = 6
4377     RPAR = 7
4378     DONE = 8
4379
4380     # Lookup Table for Next State
4381     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4382         # A string trailer may start with '.' OR '%'.
4383         (START, token.DOT): DOT,
4384         (START, token.PERCENT): PERCENT,
4385         (START, DEFAULT_TOKEN): DONE,
4386         # A '.' MUST be followed by an attribute or method name.
4387         (DOT, token.NAME): NAME,
4388         # A method name MUST be followed by an '(', whereas an attribute name
4389         # is the last symbol in the string trailer.
4390         (NAME, token.LPAR): LPAR,
4391         (NAME, DEFAULT_TOKEN): DONE,
4392         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4393         # string or variable name).
4394         (PERCENT, token.LPAR): LPAR,
4395         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4396         # If a '%' symbol is followed by a single argument, that argument is
4397         # the last leaf in the string trailer.
4398         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4399         # If present, a ')' symbol is the last symbol in a string trailer.
4400         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4401         # since they are treated as a special case by the parsing logic in this
4402         # classes' implementation.)
4403         (RPAR, DEFAULT_TOKEN): DONE,
4404     }
4405
4406     def __init__(self) -> None:
4407         self._state = self.START
4408         self._unmatched_lpars = 0
4409
4410     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4411         """
4412         Pre-conditions:
4413             * @leaves[@string_idx].type == token.STRING
4414
4415         Returns:
4416             The index directly after the last leaf which is apart of the string
4417             trailer, if a "trailer" exists.
4418                 OR
4419             @string_idx + 1, if no string "trailer" exists.
4420         """
4421         assert leaves[string_idx].type == token.STRING
4422
4423         idx = string_idx + 1
4424         while idx < len(leaves) and self._next_state(leaves[idx]):
4425             idx += 1
4426         return idx
4427
4428     def _next_state(self, leaf: Leaf) -> bool:
4429         """
4430         Pre-conditions:
4431             * On the first call to this function, @leaf MUST be the leaf that
4432             was directly after the string leaf in question (e.g. if our target
4433             string is `line.leaves[i]` then the first call to this method must
4434             be `line.leaves[i + 1]`).
4435             * On the next call to this function, the leaf parameter passed in
4436             MUST be the leaf directly following @leaf.
4437
4438         Returns:
4439             True iff @leaf is apart of the string's trailer.
4440         """
4441         # We ignore empty LPAR or RPAR leaves.
4442         if is_empty_par(leaf):
4443             return True
4444
4445         next_token = leaf.type
4446         if next_token == token.LPAR:
4447             self._unmatched_lpars += 1
4448
4449         current_state = self._state
4450
4451         # The LPAR parser state is a special case. We will return True until we
4452         # find the matching RPAR token.
4453         if current_state == self.LPAR:
4454             if next_token == token.RPAR:
4455                 self._unmatched_lpars -= 1
4456                 if self._unmatched_lpars == 0:
4457                     self._state = self.RPAR
4458         # Otherwise, we use a lookup table to determine the next state.
4459         else:
4460             # If the lookup table matches the current state to the next
4461             # token, we use the lookup table.
4462             if (current_state, next_token) in self._goto:
4463                 self._state = self._goto[current_state, next_token]
4464             else:
4465                 # Otherwise, we check if a the current state was assigned a
4466                 # default.
4467                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4468                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4469                 # If no default has been assigned, then this parser has a logic
4470                 # error.
4471                 else:
4472                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4473
4474             if self._state == self.DONE:
4475                 return False
4476
4477         return True
4478
4479
4480 def TErr(err_msg: str) -> Err[CannotTransform]:
4481     """(T)ransform Err
4482
4483     Convenience function used when working with the TResult type.
4484     """
4485     cant_transform = CannotTransform(err_msg)
4486     return Err(cant_transform)
4487
4488
4489 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4490     """
4491     Returns:
4492         True iff one of the comments in @comment_list is a pragma used by one
4493         of the more common static analysis tools for python (e.g. mypy, flake8,
4494         pylint).
4495     """
4496     for comment in comment_list:
4497         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4498             return True
4499
4500     return False
4501
4502
4503 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4504     """
4505     Factory for a convenience function that is used to orphan @string_leaf
4506     and then insert multiple new leaves into the same part of the node
4507     structure that @string_leaf had originally occupied.
4508
4509     Examples:
4510         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4511         string_leaf.parent`. Assume the node `N` has the following
4512         original structure:
4513
4514         Node(
4515             expr_stmt, [
4516                 Leaf(NAME, 'x'),
4517                 Leaf(EQUAL, '='),
4518                 Leaf(STRING, '"foo"'),
4519             ]
4520         )
4521
4522         We then run the code snippet shown below.
4523         ```
4524         insert_str_child = insert_str_child_factory(string_leaf)
4525
4526         lpar = Leaf(token.LPAR, '(')
4527         insert_str_child(lpar)
4528
4529         bar = Leaf(token.STRING, '"bar"')
4530         insert_str_child(bar)
4531
4532         rpar = Leaf(token.RPAR, ')')
4533         insert_str_child(rpar)
4534         ```
4535
4536         After which point, it follows that `string_leaf.parent is None` and
4537         the node `N` now has the following structure:
4538
4539         Node(
4540             expr_stmt, [
4541                 Leaf(NAME, 'x'),
4542                 Leaf(EQUAL, '='),
4543                 Leaf(LPAR, '('),
4544                 Leaf(STRING, '"bar"'),
4545                 Leaf(RPAR, ')'),
4546             ]
4547         )
4548     """
4549     string_parent = string_leaf.parent
4550     string_child_idx = string_leaf.remove()
4551
4552     def insert_str_child(child: LN) -> None:
4553         nonlocal string_child_idx
4554
4555         assert string_parent is not None
4556         assert string_child_idx is not None
4557
4558         string_parent.insert_child(string_child_idx, child)
4559         string_child_idx += 1
4560
4561     return insert_str_child
4562
4563
4564 def has_triple_quotes(string: str) -> bool:
4565     """
4566     Returns:
4567         True iff @string starts with three quotation characters.
4568     """
4569     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4570     return raw_string[:3] in {'"""', "'''"}
4571
4572
4573 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4574     """
4575     Returns:
4576         @node.parent.type, if @node is not None and has a parent.
4577             OR
4578         None, otherwise.
4579     """
4580     if node is None or node.parent is None:
4581         return None
4582
4583     return node.parent.type
4584
4585
4586 def is_empty_par(leaf: Leaf) -> bool:
4587     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4588
4589
4590 def is_empty_lpar(leaf: Leaf) -> bool:
4591     return leaf.type == token.LPAR and leaf.value == ""
4592
4593
4594 def is_empty_rpar(leaf: Leaf) -> bool:
4595     return leaf.type == token.RPAR and leaf.value == ""
4596
4597
4598 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4599     """
4600     Examples:
4601         ```
4602         my_list = [1, 2, 3]
4603
4604         is_valid_index = is_valid_index_factory(my_list)
4605
4606         assert is_valid_index(0)
4607         assert is_valid_index(2)
4608
4609         assert not is_valid_index(3)
4610         assert not is_valid_index(-1)
4611         ```
4612     """
4613
4614     def is_valid_index(idx: int) -> bool:
4615         """
4616         Returns:
4617             True iff @idx is positive AND seq[@idx] does NOT raise an
4618             IndexError.
4619         """
4620         return 0 <= idx < len(seq)
4621
4622     return is_valid_index
4623
4624
4625 def line_to_string(line: Line) -> str:
4626     """Returns the string representation of @line.
4627
4628     WARNING: This is known to be computationally expensive.
4629     """
4630     return str(line).strip("\n")
4631
4632
4633 def append_leaves(
4634     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4635 ) -> None:
4636     """
4637     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4638     underlying Node structure where appropriate.
4639
4640     All of the leaves in @leaves are duplicated. The duplicates are then
4641     appended to @new_line and used to replace their originals in the underlying
4642     Node structure. Any comments attached to the old leaves are reattached to
4643     the new leaves.
4644
4645     Pre-conditions:
4646         set(@leaves) is a subset of set(@old_line.leaves).
4647     """
4648     for old_leaf in leaves:
4649         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4650         replace_child(old_leaf, new_leaf)
4651         new_line.append(new_leaf, preformatted=preformatted)
4652
4653         for comment_leaf in old_line.comments_after(old_leaf):
4654             new_line.append(comment_leaf, preformatted=True)
4655
4656
4657 def replace_child(old_child: LN, new_child: LN) -> None:
4658     """
4659     Side Effects:
4660         * If @old_child.parent is set, replace @old_child with @new_child in
4661         @old_child's underlying Node structure.
4662             OR
4663         * Otherwise, this function does nothing.
4664     """
4665     parent = old_child.parent
4666     if not parent:
4667         return
4668
4669     child_idx = old_child.remove()
4670     if child_idx is not None:
4671         parent.insert_child(child_idx, new_child)
4672
4673
4674 def get_string_prefix(string: str) -> str:
4675     """
4676     Pre-conditions:
4677         * assert_is_leaf_string(@string)
4678
4679     Returns:
4680         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4681     """
4682     assert_is_leaf_string(string)
4683
4684     prefix = ""
4685     prefix_idx = 0
4686     while string[prefix_idx] in STRING_PREFIX_CHARS:
4687         prefix += string[prefix_idx].lower()
4688         prefix_idx += 1
4689
4690     return prefix
4691
4692
4693 def assert_is_leaf_string(string: str) -> None:
4694     """
4695     Checks the pre-condition that @string has the format that you would expect
4696     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4697     token.STRING`. A more precise description of the pre-conditions that are
4698     checked are listed below.
4699
4700     Pre-conditions:
4701         * @string starts with either ', ", <prefix>', or <prefix>" where
4702         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4703         * @string ends with a quote character (' or ").
4704
4705     Raises:
4706         AssertionError(...) if the pre-conditions listed above are not
4707         satisfied.
4708     """
4709     dquote_idx = string.find('"')
4710     squote_idx = string.find("'")
4711     if -1 in [dquote_idx, squote_idx]:
4712         quote_idx = max(dquote_idx, squote_idx)
4713     else:
4714         quote_idx = min(squote_idx, dquote_idx)
4715
4716     assert (
4717         0 <= quote_idx < len(string) - 1
4718     ), f"{string!r} is missing a starting quote character (' or \")."
4719     assert string[-1] in (
4720         "'",
4721         '"',
4722     ), f"{string!r} is missing an ending quote character (' or \")."
4723     assert set(string[:quote_idx]).issubset(
4724         set(STRING_PREFIX_CHARS)
4725     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4726
4727
4728 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4729     """Split line into many lines, starting with the first matching bracket pair.
4730
4731     Note: this usually looks weird, only use this for function definitions.
4732     Prefer RHS otherwise.  This is why this function is not symmetrical with
4733     :func:`right_hand_split` which also handles optional parentheses.
4734     """
4735     tail_leaves: List[Leaf] = []
4736     body_leaves: List[Leaf] = []
4737     head_leaves: List[Leaf] = []
4738     current_leaves = head_leaves
4739     matching_bracket: Optional[Leaf] = None
4740     for leaf in line.leaves:
4741         if (
4742             current_leaves is body_leaves
4743             and leaf.type in CLOSING_BRACKETS
4744             and leaf.opening_bracket is matching_bracket
4745         ):
4746             current_leaves = tail_leaves if body_leaves else head_leaves
4747         current_leaves.append(leaf)
4748         if current_leaves is head_leaves:
4749             if leaf.type in OPENING_BRACKETS:
4750                 matching_bracket = leaf
4751                 current_leaves = body_leaves
4752     if not matching_bracket:
4753         raise CannotSplit("No brackets found")
4754
4755     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4756     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4757     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4758     bracket_split_succeeded_or_raise(head, body, tail)
4759     for result in (head, body, tail):
4760         if result:
4761             yield result
4762
4763
4764 def right_hand_split(
4765     line: Line,
4766     line_length: int,
4767     features: Collection[Feature] = (),
4768     omit: Collection[LeafID] = (),
4769 ) -> Iterator[Line]:
4770     """Split line into many lines, starting with the last matching bracket pair.
4771
4772     If the split was by optional parentheses, attempt splitting without them, too.
4773     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4774     this split.
4775
4776     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4777     """
4778     tail_leaves: List[Leaf] = []
4779     body_leaves: List[Leaf] = []
4780     head_leaves: List[Leaf] = []
4781     current_leaves = tail_leaves
4782     opening_bracket: Optional[Leaf] = None
4783     closing_bracket: Optional[Leaf] = None
4784     for leaf in reversed(line.leaves):
4785         if current_leaves is body_leaves:
4786             if leaf is opening_bracket:
4787                 current_leaves = head_leaves if body_leaves else tail_leaves
4788         current_leaves.append(leaf)
4789         if current_leaves is tail_leaves:
4790             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4791                 opening_bracket = leaf.opening_bracket
4792                 closing_bracket = leaf
4793                 current_leaves = body_leaves
4794     if not (opening_bracket and closing_bracket and head_leaves):
4795         # If there is no opening or closing_bracket that means the split failed and
4796         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4797         # the matching `opening_bracket` wasn't available on `line` anymore.
4798         raise CannotSplit("No brackets found")
4799
4800     tail_leaves.reverse()
4801     body_leaves.reverse()
4802     head_leaves.reverse()
4803     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4804     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4805     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4806     bracket_split_succeeded_or_raise(head, body, tail)
4807     if (
4808         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4809         # the opening bracket is an optional paren
4810         and opening_bracket.type == token.LPAR
4811         and not opening_bracket.value
4812         # the closing bracket is an optional paren
4813         and closing_bracket.type == token.RPAR
4814         and not closing_bracket.value
4815         # it's not an import (optional parens are the only thing we can split on
4816         # in this case; attempting a split without them is a waste of time)
4817         and not line.is_import
4818         # there are no standalone comments in the body
4819         and not body.contains_standalone_comments(0)
4820         # and we can actually remove the parens
4821         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4822     ):
4823         omit = {id(closing_bracket), *omit}
4824         try:
4825             yield from right_hand_split(line, line_length, features=features, omit=omit)
4826             return
4827
4828         except CannotSplit:
4829             if not (
4830                 can_be_split(body)
4831                 or is_line_short_enough(body, line_length=line_length)
4832             ):
4833                 raise CannotSplit(
4834                     "Splitting failed, body is still too long and can't be split."
4835                 )
4836
4837             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4838                 raise CannotSplit(
4839                     "The current optional pair of parentheses is bound to fail to"
4840                     " satisfy the splitting algorithm because the head or the tail"
4841                     " contains multiline strings which by definition never fit one"
4842                     " line."
4843                 )
4844
4845     ensure_visible(opening_bracket)
4846     ensure_visible(closing_bracket)
4847     for result in (head, body, tail):
4848         if result:
4849             yield result
4850
4851
4852 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4853     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4854
4855     Do nothing otherwise.
4856
4857     A left- or right-hand split is based on a pair of brackets. Content before
4858     (and including) the opening bracket is left on one line, content inside the
4859     brackets is put on a separate line, and finally content starting with and
4860     following the closing bracket is put on a separate line.
4861
4862     Those are called `head`, `body`, and `tail`, respectively. If the split
4863     produced the same line (all content in `head`) or ended up with an empty `body`
4864     and the `tail` is just the closing bracket, then it's considered failed.
4865     """
4866     tail_len = len(str(tail).strip())
4867     if not body:
4868         if tail_len == 0:
4869             raise CannotSplit("Splitting brackets produced the same line")
4870
4871         elif tail_len < 3:
4872             raise CannotSplit(
4873                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4874                 " not worth it"
4875             )
4876
4877
4878 def bracket_split_build_line(
4879     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4880 ) -> Line:
4881     """Return a new line with given `leaves` and respective comments from `original`.
4882
4883     If `is_body` is True, the result line is one-indented inside brackets and as such
4884     has its first leaf's prefix normalized and a trailing comma added when expected.
4885     """
4886     result = Line(depth=original.depth)
4887     if is_body:
4888         result.inside_brackets = True
4889         result.depth += 1
4890         if leaves:
4891             # Since body is a new indent level, remove spurious leading whitespace.
4892             normalize_prefix(leaves[0], inside_brackets=True)
4893             # Ensure a trailing comma for imports and standalone function arguments, but
4894             # be careful not to add one after any comments or within type annotations.
4895             no_commas = (
4896                 original.is_def
4897                 and opening_bracket.value == "("
4898                 and not any(leaf.type == token.COMMA for leaf in leaves)
4899             )
4900
4901             if original.is_import or no_commas:
4902                 for i in range(len(leaves) - 1, -1, -1):
4903                     if leaves[i].type == STANDALONE_COMMENT:
4904                         continue
4905
4906                     if leaves[i].type != token.COMMA:
4907                         new_comma = Leaf(token.COMMA, ",")
4908                         leaves.insert(i + 1, new_comma)
4909                     break
4910
4911     # Populate the line
4912     for leaf in leaves:
4913         result.append(leaf, preformatted=True)
4914         for comment_after in original.comments_after(leaf):
4915             result.append(comment_after, preformatted=True)
4916     if is_body and should_split_body_explode(result, opening_bracket):
4917         result.should_explode = True
4918     return result
4919
4920
4921 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4922     """Normalize prefix of the first leaf in every line returned by `split_func`.
4923
4924     This is a decorator over relevant split functions.
4925     """
4926
4927     @wraps(split_func)
4928     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4929         for line in split_func(line, features):
4930             normalize_prefix(line.leaves[0], inside_brackets=True)
4931             yield line
4932
4933     return split_wrapper
4934
4935
4936 @dont_increase_indentation
4937 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4938     """Split according to delimiters of the highest priority.
4939
4940     If the appropriate Features are given, the split will add trailing commas
4941     also in function signatures and calls that contain `*` and `**`.
4942     """
4943     try:
4944         last_leaf = line.leaves[-1]
4945     except IndexError:
4946         raise CannotSplit("Line empty")
4947
4948     bt = line.bracket_tracker
4949     try:
4950         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4951     except ValueError:
4952         raise CannotSplit("No delimiters found")
4953
4954     if delimiter_priority == DOT_PRIORITY:
4955         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4956             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4957
4958     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4959     lowest_depth = sys.maxsize
4960     trailing_comma_safe = True
4961
4962     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4963         """Append `leaf` to current line or to new line if appending impossible."""
4964         nonlocal current_line
4965         try:
4966             current_line.append_safe(leaf, preformatted=True)
4967         except ValueError:
4968             yield current_line
4969
4970             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4971             current_line.append(leaf)
4972
4973     for leaf in line.leaves:
4974         yield from append_to_line(leaf)
4975
4976         for comment_after in line.comments_after(leaf):
4977             yield from append_to_line(comment_after)
4978
4979         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4980         if leaf.bracket_depth == lowest_depth:
4981             if is_vararg(leaf, within={syms.typedargslist}):
4982                 trailing_comma_safe = (
4983                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4984                 )
4985             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4986                 trailing_comma_safe = (
4987                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4988                 )
4989
4990         leaf_priority = bt.delimiters.get(id(leaf))
4991         if leaf_priority == delimiter_priority:
4992             yield current_line
4993
4994             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4995     if current_line:
4996         if (
4997             trailing_comma_safe
4998             and delimiter_priority == COMMA_PRIORITY
4999             and current_line.leaves[-1].type != token.COMMA
5000             and current_line.leaves[-1].type != STANDALONE_COMMENT
5001         ):
5002             new_comma = Leaf(token.COMMA, ",")
5003             current_line.append(new_comma)
5004         yield current_line
5005
5006
5007 @dont_increase_indentation
5008 def standalone_comment_split(
5009     line: Line, features: Collection[Feature] = ()
5010 ) -> Iterator[Line]:
5011     """Split standalone comments from the rest of the line."""
5012     if not line.contains_standalone_comments(0):
5013         raise CannotSplit("Line does not have any standalone comments")
5014
5015     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5016
5017     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5018         """Append `leaf` to current line or to new line if appending impossible."""
5019         nonlocal current_line
5020         try:
5021             current_line.append_safe(leaf, preformatted=True)
5022         except ValueError:
5023             yield current_line
5024
5025             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5026             current_line.append(leaf)
5027
5028     for leaf in line.leaves:
5029         yield from append_to_line(leaf)
5030
5031         for comment_after in line.comments_after(leaf):
5032             yield from append_to_line(comment_after)
5033
5034     if current_line:
5035         yield current_line
5036
5037
5038 def is_import(leaf: Leaf) -> bool:
5039     """Return True if the given leaf starts an import statement."""
5040     p = leaf.parent
5041     t = leaf.type
5042     v = leaf.value
5043     return bool(
5044         t == token.NAME
5045         and (
5046             (v == "import" and p and p.type == syms.import_name)
5047             or (v == "from" and p and p.type == syms.import_from)
5048         )
5049     )
5050
5051
5052 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5053     """Return True if the given leaf is a special comment.
5054     Only returns true for type comments for now."""
5055     t = leaf.type
5056     v = leaf.value
5057     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5058
5059
5060 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5061     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5062     else.
5063
5064     Note: don't use backslashes for formatting or you'll lose your voting rights.
5065     """
5066     if not inside_brackets:
5067         spl = leaf.prefix.split("#")
5068         if "\\" not in spl[0]:
5069             nl_count = spl[-1].count("\n")
5070             if len(spl) > 1:
5071                 nl_count -= 1
5072             leaf.prefix = "\n" * nl_count
5073             return
5074
5075     leaf.prefix = ""
5076
5077
5078 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5079     """Make all string prefixes lowercase.
5080
5081     If remove_u_prefix is given, also removes any u prefix from the string.
5082
5083     Note: Mutates its argument.
5084     """
5085     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5086     assert match is not None, f"failed to match string {leaf.value!r}"
5087     orig_prefix = match.group(1)
5088     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5089     if remove_u_prefix:
5090         new_prefix = new_prefix.replace("u", "")
5091     leaf.value = f"{new_prefix}{match.group(2)}"
5092
5093
5094 def normalize_string_quotes(leaf: Leaf) -> None:
5095     """Prefer double quotes but only if it doesn't cause more escaping.
5096
5097     Adds or removes backslashes as appropriate. Doesn't parse and fix
5098     strings nested in f-strings (yet).
5099
5100     Note: Mutates its argument.
5101     """
5102     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5103     if value[:3] == '"""':
5104         return
5105
5106     elif value[:3] == "'''":
5107         orig_quote = "'''"
5108         new_quote = '"""'
5109     elif value[0] == '"':
5110         orig_quote = '"'
5111         new_quote = "'"
5112     else:
5113         orig_quote = "'"
5114         new_quote = '"'
5115     first_quote_pos = leaf.value.find(orig_quote)
5116     if first_quote_pos == -1:
5117         return  # There's an internal error
5118
5119     prefix = leaf.value[:first_quote_pos]
5120     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5121     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5122     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5123     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5124     if "r" in prefix.casefold():
5125         if unescaped_new_quote.search(body):
5126             # There's at least one unescaped new_quote in this raw string
5127             # so converting is impossible
5128             return
5129
5130         # Do not introduce or remove backslashes in raw strings
5131         new_body = body
5132     else:
5133         # remove unnecessary escapes
5134         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5135         if body != new_body:
5136             # Consider the string without unnecessary escapes as the original
5137             body = new_body
5138             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5139         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5140         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5141     if "f" in prefix.casefold():
5142         matches = re.findall(
5143             r"""
5144             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5145                 ([^{].*?)  # contents of the brackets except if begins with {{
5146             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5147             """,
5148             new_body,
5149             re.VERBOSE,
5150         )
5151         for m in matches:
5152             if "\\" in str(m):
5153                 # Do not introduce backslashes in interpolated expressions
5154                 return
5155
5156     if new_quote == '"""' and new_body[-1:] == '"':
5157         # edge case:
5158         new_body = new_body[:-1] + '\\"'
5159     orig_escape_count = body.count("\\")
5160     new_escape_count = new_body.count("\\")
5161     if new_escape_count > orig_escape_count:
5162         return  # Do not introduce more escaping
5163
5164     if new_escape_count == orig_escape_count and orig_quote == '"':
5165         return  # Prefer double quotes
5166
5167     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5168
5169
5170 def normalize_numeric_literal(leaf: Leaf) -> None:
5171     """Normalizes numeric (float, int, and complex) literals.
5172
5173     All letters used in the representation are normalized to lowercase (except
5174     in Python 2 long literals).
5175     """
5176     text = leaf.value.lower()
5177     if text.startswith(("0o", "0b")):
5178         # Leave octal and binary literals alone.
5179         pass
5180     elif text.startswith("0x"):
5181         # Change hex literals to upper case.
5182         before, after = text[:2], text[2:]
5183         text = f"{before}{after.upper()}"
5184     elif "e" in text:
5185         before, after = text.split("e")
5186         sign = ""
5187         if after.startswith("-"):
5188             after = after[1:]
5189             sign = "-"
5190         elif after.startswith("+"):
5191             after = after[1:]
5192         before = format_float_or_int_string(before)
5193         text = f"{before}e{sign}{after}"
5194     elif text.endswith(("j", "l")):
5195         number = text[:-1]
5196         suffix = text[-1]
5197         # Capitalize in "2L" because "l" looks too similar to "1".
5198         if suffix == "l":
5199             suffix = "L"
5200         text = f"{format_float_or_int_string(number)}{suffix}"
5201     else:
5202         text = format_float_or_int_string(text)
5203     leaf.value = text
5204
5205
5206 def format_float_or_int_string(text: str) -> str:
5207     """Formats a float string like "1.0"."""
5208     if "." not in text:
5209         return text
5210
5211     before, after = text.split(".")
5212     return f"{before or 0}.{after or 0}"
5213
5214
5215 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5216     """Make existing optional parentheses invisible or create new ones.
5217
5218     `parens_after` is a set of string leaf values immediately after which parens
5219     should be put.
5220
5221     Standardizes on visible parentheses for single-element tuples, and keeps
5222     existing visible parentheses for other tuples and generator expressions.
5223     """
5224     for pc in list_comments(node.prefix, is_endmarker=False):
5225         if pc.value in FMT_OFF:
5226             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5227             return
5228     check_lpar = False
5229     for index, child in enumerate(list(node.children)):
5230         # Fixes a bug where invisible parens are not properly stripped from
5231         # assignment statements that contain type annotations.
5232         if isinstance(child, Node) and child.type == syms.annassign:
5233             normalize_invisible_parens(child, parens_after=parens_after)
5234
5235         # Add parentheses around long tuple unpacking in assignments.
5236         if (
5237             index == 0
5238             and isinstance(child, Node)
5239             and child.type == syms.testlist_star_expr
5240         ):
5241             check_lpar = True
5242
5243         if check_lpar:
5244             if is_walrus_assignment(child):
5245                 pass
5246
5247             elif child.type == syms.atom:
5248                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5249                     wrap_in_parentheses(node, child, visible=False)
5250             elif is_one_tuple(child):
5251                 wrap_in_parentheses(node, child, visible=True)
5252             elif node.type == syms.import_from:
5253                 # "import from" nodes store parentheses directly as part of
5254                 # the statement
5255                 if child.type == token.LPAR:
5256                     # make parentheses invisible
5257                     child.value = ""  # type: ignore
5258                     node.children[-1].value = ""  # type: ignore
5259                 elif child.type != token.STAR:
5260                     # insert invisible parentheses
5261                     node.insert_child(index, Leaf(token.LPAR, ""))
5262                     node.append_child(Leaf(token.RPAR, ""))
5263                 break
5264
5265             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5266                 wrap_in_parentheses(node, child, visible=False)
5267
5268         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5269
5270
5271 def normalize_fmt_off(node: Node) -> None:
5272     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5273     try_again = True
5274     while try_again:
5275         try_again = convert_one_fmt_off_pair(node)
5276
5277
5278 def convert_one_fmt_off_pair(node: Node) -> bool:
5279     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5280
5281     Returns True if a pair was converted.
5282     """
5283     for leaf in node.leaves():
5284         previous_consumed = 0
5285         for comment in list_comments(leaf.prefix, is_endmarker=False):
5286             if comment.value in FMT_OFF:
5287                 # We only want standalone comments. If there's no previous leaf or
5288                 # the previous leaf is indentation, it's a standalone comment in
5289                 # disguise.
5290                 if comment.type != STANDALONE_COMMENT:
5291                     prev = preceding_leaf(leaf)
5292                     if prev and prev.type not in WHITESPACE:
5293                         continue
5294
5295                 ignored_nodes = list(generate_ignored_nodes(leaf))
5296                 if not ignored_nodes:
5297                     continue
5298
5299                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5300                 parent = first.parent
5301                 prefix = first.prefix
5302                 first.prefix = prefix[comment.consumed :]
5303                 hidden_value = (
5304                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5305                 )
5306                 if hidden_value.endswith("\n"):
5307                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5308                     # leaf (possibly followed by a DEDENT).
5309                     hidden_value = hidden_value[:-1]
5310                 first_idx: Optional[int] = None
5311                 for ignored in ignored_nodes:
5312                     index = ignored.remove()
5313                     if first_idx is None:
5314                         first_idx = index
5315                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5316                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5317                 parent.insert_child(
5318                     first_idx,
5319                     Leaf(
5320                         STANDALONE_COMMENT,
5321                         hidden_value,
5322                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5323                     ),
5324                 )
5325                 return True
5326
5327             previous_consumed = comment.consumed
5328
5329     return False
5330
5331
5332 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5333     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5334
5335     Stops at the end of the block.
5336     """
5337     container: Optional[LN] = container_of(leaf)
5338     while container is not None and container.type != token.ENDMARKER:
5339         if is_fmt_on(container):
5340             return
5341
5342         # fix for fmt: on in children
5343         if contains_fmt_on_at_column(container, leaf.column):
5344             for child in container.children:
5345                 if contains_fmt_on_at_column(child, leaf.column):
5346                     return
5347                 yield child
5348         else:
5349             yield container
5350             container = container.next_sibling
5351
5352
5353 def is_fmt_on(container: LN) -> bool:
5354     """Determine whether formatting is switched on within a container.
5355     Determined by whether the last `# fmt:` comment is `on` or `off`.
5356     """
5357     fmt_on = False
5358     for comment in list_comments(container.prefix, is_endmarker=False):
5359         if comment.value in FMT_ON:
5360             fmt_on = True
5361         elif comment.value in FMT_OFF:
5362             fmt_on = False
5363     return fmt_on
5364
5365
5366 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5367     """Determine if children at a given column have formatting switched on."""
5368     for child in container.children:
5369         if (
5370             isinstance(child, Node)
5371             and first_leaf_column(child) == column
5372             or isinstance(child, Leaf)
5373             and child.column == column
5374         ):
5375             if is_fmt_on(child):
5376                 return True
5377
5378     return False
5379
5380
5381 def first_leaf_column(node: Node) -> Optional[int]:
5382     """Returns the column of the first leaf child of a node."""
5383     for child in node.children:
5384         if isinstance(child, Leaf):
5385             return child.column
5386     return None
5387
5388
5389 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5390     """If it's safe, make the parens in the atom `node` invisible, recursively.
5391     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5392     as they are redundant.
5393
5394     Returns whether the node should itself be wrapped in invisible parentheses.
5395
5396     """
5397     if (
5398         node.type != syms.atom
5399         or is_empty_tuple(node)
5400         or is_one_tuple(node)
5401         or (is_yield(node) and parent.type != syms.expr_stmt)
5402         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5403     ):
5404         return False
5405
5406     first = node.children[0]
5407     last = node.children[-1]
5408     if first.type == token.LPAR and last.type == token.RPAR:
5409         middle = node.children[1]
5410         # make parentheses invisible
5411         first.value = ""  # type: ignore
5412         last.value = ""  # type: ignore
5413         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5414
5415         if is_atom_with_invisible_parens(middle):
5416             # Strip the invisible parens from `middle` by replacing
5417             # it with the child in-between the invisible parens
5418             middle.replace(middle.children[1])
5419
5420         return False
5421
5422     return True
5423
5424
5425 def is_atom_with_invisible_parens(node: LN) -> bool:
5426     """Given a `LN`, determines whether it's an atom `node` with invisible
5427     parens. Useful in dedupe-ing and normalizing parens.
5428     """
5429     if isinstance(node, Leaf) or node.type != syms.atom:
5430         return False
5431
5432     first, last = node.children[0], node.children[-1]
5433     return (
5434         isinstance(first, Leaf)
5435         and first.type == token.LPAR
5436         and first.value == ""
5437         and isinstance(last, Leaf)
5438         and last.type == token.RPAR
5439         and last.value == ""
5440     )
5441
5442
5443 def is_empty_tuple(node: LN) -> bool:
5444     """Return True if `node` holds an empty tuple."""
5445     return (
5446         node.type == syms.atom
5447         and len(node.children) == 2
5448         and node.children[0].type == token.LPAR
5449         and node.children[1].type == token.RPAR
5450     )
5451
5452
5453 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5454     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5455
5456     Parenthesis can be optional. Returns None otherwise"""
5457     if len(node.children) != 3:
5458         return None
5459
5460     lpar, wrapped, rpar = node.children
5461     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5462         return None
5463
5464     return wrapped
5465
5466
5467 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5468     """Wrap `child` in parentheses.
5469
5470     This replaces `child` with an atom holding the parentheses and the old
5471     child.  That requires moving the prefix.
5472
5473     If `visible` is False, the leaves will be valueless (and thus invisible).
5474     """
5475     lpar = Leaf(token.LPAR, "(" if visible else "")
5476     rpar = Leaf(token.RPAR, ")" if visible else "")
5477     prefix = child.prefix
5478     child.prefix = ""
5479     index = child.remove() or 0
5480     new_child = Node(syms.atom, [lpar, child, rpar])
5481     new_child.prefix = prefix
5482     parent.insert_child(index, new_child)
5483
5484
5485 def is_one_tuple(node: LN) -> bool:
5486     """Return True if `node` holds a tuple with one element, with or without parens."""
5487     if node.type == syms.atom:
5488         gexp = unwrap_singleton_parenthesis(node)
5489         if gexp is None or gexp.type != syms.testlist_gexp:
5490             return False
5491
5492         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5493
5494     return (
5495         node.type in IMPLICIT_TUPLE
5496         and len(node.children) == 2
5497         and node.children[1].type == token.COMMA
5498     )
5499
5500
5501 def is_walrus_assignment(node: LN) -> bool:
5502     """Return True iff `node` is of the shape ( test := test )"""
5503     inner = unwrap_singleton_parenthesis(node)
5504     return inner is not None and inner.type == syms.namedexpr_test
5505
5506
5507 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5508     """Return True iff `node` is a trailer valid in a simple decorator"""
5509     return node.type == syms.trailer and (
5510         (
5511             len(node.children) == 2
5512             and node.children[0].type == token.DOT
5513             and node.children[1].type == token.NAME
5514         )
5515         # last trailer can be arguments
5516         or (
5517             last
5518             and len(node.children) == 3
5519             and node.children[0].type == token.LPAR
5520             # and node.children[1].type == syms.argument
5521             and node.children[2].type == token.RPAR
5522         )
5523     )
5524
5525
5526 def is_simple_decorator_expression(node: LN) -> bool:
5527     """Return True iff `node` could be a 'dotted name' decorator
5528
5529     This function takes the node of the 'namedexpr_test' of the new decorator
5530     grammar and test if it would be valid under the old decorator grammar.
5531
5532     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5533     The new grammar is : decorator: @ namedexpr_test NEWLINE
5534     """
5535     if node.type == token.NAME:
5536         return True
5537     if node.type == syms.power:
5538         if node.children:
5539             return (
5540                 node.children[0].type == token.NAME
5541                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5542                 and (
5543                     len(node.children) < 2
5544                     or is_simple_decorator_trailer(node.children[-1], last=True)
5545                 )
5546             )
5547     return False
5548
5549
5550 def is_yield(node: LN) -> bool:
5551     """Return True if `node` holds a `yield` or `yield from` expression."""
5552     if node.type == syms.yield_expr:
5553         return True
5554
5555     if node.type == token.NAME and node.value == "yield":  # type: ignore
5556         return True
5557
5558     if node.type != syms.atom:
5559         return False
5560
5561     if len(node.children) != 3:
5562         return False
5563
5564     lpar, expr, rpar = node.children
5565     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5566         return is_yield(expr)
5567
5568     return False
5569
5570
5571 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5572     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5573
5574     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5575     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5576     extended iterable unpacking (PEP 3132) and additional unpacking
5577     generalizations (PEP 448).
5578     """
5579     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5580         return False
5581
5582     p = leaf.parent
5583     if p.type == syms.star_expr:
5584         # Star expressions are also used as assignment targets in extended
5585         # iterable unpacking (PEP 3132).  See what its parent is instead.
5586         if not p.parent:
5587             return False
5588
5589         p = p.parent
5590
5591     return p.type in within
5592
5593
5594 def is_multiline_string(leaf: Leaf) -> bool:
5595     """Return True if `leaf` is a multiline string that actually spans many lines."""
5596     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5597
5598
5599 def is_stub_suite(node: Node) -> bool:
5600     """Return True if `node` is a suite with a stub body."""
5601     if (
5602         len(node.children) != 4
5603         or node.children[0].type != token.NEWLINE
5604         or node.children[1].type != token.INDENT
5605         or node.children[3].type != token.DEDENT
5606     ):
5607         return False
5608
5609     return is_stub_body(node.children[2])
5610
5611
5612 def is_stub_body(node: LN) -> bool:
5613     """Return True if `node` is a simple statement containing an ellipsis."""
5614     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5615         return False
5616
5617     if len(node.children) != 2:
5618         return False
5619
5620     child = node.children[0]
5621     return (
5622         child.type == syms.atom
5623         and len(child.children) == 3
5624         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5625     )
5626
5627
5628 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5629     """Return maximum delimiter priority inside `node`.
5630
5631     This is specific to atoms with contents contained in a pair of parentheses.
5632     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5633     """
5634     if node.type != syms.atom:
5635         return 0
5636
5637     first = node.children[0]
5638     last = node.children[-1]
5639     if not (first.type == token.LPAR and last.type == token.RPAR):
5640         return 0
5641
5642     bt = BracketTracker()
5643     for c in node.children[1:-1]:
5644         if isinstance(c, Leaf):
5645             bt.mark(c)
5646         else:
5647             for leaf in c.leaves():
5648                 bt.mark(leaf)
5649     try:
5650         return bt.max_delimiter_priority()
5651
5652     except ValueError:
5653         return 0
5654
5655
5656 def ensure_visible(leaf: Leaf) -> None:
5657     """Make sure parentheses are visible.
5658
5659     They could be invisible as part of some statements (see
5660     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5661     """
5662     if leaf.type == token.LPAR:
5663         leaf.value = "("
5664     elif leaf.type == token.RPAR:
5665         leaf.value = ")"
5666
5667
5668 def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
5669     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5670
5671     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5672         return False
5673
5674     # We're essentially checking if the body is delimited by commas and there's more
5675     # than one of them (we're excluding the trailing comma and if the delimiter priority
5676     # is still commas, that means there's more).
5677     exclude = set()
5678     trailing_comma = False
5679     try:
5680         last_leaf = line.leaves[-1]
5681         if last_leaf.type == token.COMMA:
5682             trailing_comma = True
5683             exclude.add(id(last_leaf))
5684         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5685     except (IndexError, ValueError):
5686         return False
5687
5688     return max_priority == COMMA_PRIORITY and (
5689         trailing_comma
5690         # always explode imports
5691         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5692     )
5693
5694
5695 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5696     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5697     if opening.type != token.LPAR and closing.type != token.RPAR:
5698         return False
5699
5700     depth = closing.bracket_depth + 1
5701     for _opening_index, leaf in enumerate(leaves):
5702         if leaf is opening:
5703             break
5704
5705     else:
5706         raise LookupError("Opening paren not found in `leaves`")
5707
5708     commas = 0
5709     _opening_index += 1
5710     for leaf in leaves[_opening_index:]:
5711         if leaf is closing:
5712             break
5713
5714         bracket_depth = leaf.bracket_depth
5715         if bracket_depth == depth and leaf.type == token.COMMA:
5716             commas += 1
5717             if leaf.parent and leaf.parent.type in {
5718                 syms.arglist,
5719                 syms.typedargslist,
5720             }:
5721                 commas += 1
5722                 break
5723
5724     return commas < 2
5725
5726
5727 def get_features_used(node: Node) -> Set[Feature]:
5728     """Return a set of (relatively) new Python features used in this file.
5729
5730     Currently looking for:
5731     - f-strings;
5732     - underscores in numeric literals;
5733     - trailing commas after * or ** in function signatures and calls;
5734     - positional only arguments in function signatures and lambdas;
5735     - assignment expression;
5736     - relaxed decorator syntax;
5737     """
5738     features: Set[Feature] = set()
5739     for n in node.pre_order():
5740         if n.type == token.STRING:
5741             value_head = n.value[:2]  # type: ignore
5742             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5743                 features.add(Feature.F_STRINGS)
5744
5745         elif n.type == token.NUMBER:
5746             if "_" in n.value:  # type: ignore
5747                 features.add(Feature.NUMERIC_UNDERSCORES)
5748
5749         elif n.type == token.SLASH:
5750             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5751                 features.add(Feature.POS_ONLY_ARGUMENTS)
5752
5753         elif n.type == token.COLONEQUAL:
5754             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5755
5756         elif n.type == syms.decorator:
5757             if len(n.children) > 1 and not is_simple_decorator_expression(
5758                 n.children[1]
5759             ):
5760                 features.add(Feature.RELAXED_DECORATORS)
5761
5762         elif (
5763             n.type in {syms.typedargslist, syms.arglist}
5764             and n.children
5765             and n.children[-1].type == token.COMMA
5766         ):
5767             if n.type == syms.typedargslist:
5768                 feature = Feature.TRAILING_COMMA_IN_DEF
5769             else:
5770                 feature = Feature.TRAILING_COMMA_IN_CALL
5771
5772             for ch in n.children:
5773                 if ch.type in STARS:
5774                     features.add(feature)
5775
5776                 if ch.type == syms.argument:
5777                     for argch in ch.children:
5778                         if argch.type in STARS:
5779                             features.add(feature)
5780
5781     return features
5782
5783
5784 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5785     """Detect the version to target based on the nodes used."""
5786     features = get_features_used(node)
5787     return {
5788         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5789     }
5790
5791
5792 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5793     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5794
5795     Brackets can be omitted if the entire trailer up to and including
5796     a preceding closing bracket fits in one line.
5797
5798     Yielded sets are cumulative (contain results of previous yields, too).  First
5799     set is empty, unless the line should explode, in which case bracket pairs until
5800     the one that needs to explode are omitted.
5801     """
5802
5803     omit: Set[LeafID] = set()
5804     if not line.should_explode:
5805         yield omit
5806
5807     length = 4 * line.depth
5808     opening_bracket: Optional[Leaf] = None
5809     closing_bracket: Optional[Leaf] = None
5810     inner_brackets: Set[LeafID] = set()
5811     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5812         length += leaf_length
5813         if length > line_length:
5814             break
5815
5816         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5817         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5818             break
5819
5820         if opening_bracket:
5821             if leaf is opening_bracket:
5822                 opening_bracket = None
5823             elif leaf.type in CLOSING_BRACKETS:
5824                 prev = line.leaves[index - 1] if index > 0 else None
5825                 if (
5826                     line.should_explode
5827                     and prev
5828                     and prev.type == token.COMMA
5829                     and not is_one_tuple_between(
5830                         leaf.opening_bracket, leaf, line.leaves
5831                     )
5832                 ):
5833                     # Never omit bracket pairs with trailing commas.
5834                     # We need to explode on those.
5835                     break
5836
5837                 inner_brackets.add(id(leaf))
5838         elif leaf.type in CLOSING_BRACKETS:
5839             prev = line.leaves[index - 1] if index > 0 else None
5840             if prev and prev.type in OPENING_BRACKETS:
5841                 # Empty brackets would fail a split so treat them as "inner"
5842                 # brackets (e.g. only add them to the `omit` set if another
5843                 # pair of brackets was good enough.
5844                 inner_brackets.add(id(leaf))
5845                 continue
5846
5847             if closing_bracket:
5848                 omit.add(id(closing_bracket))
5849                 omit.update(inner_brackets)
5850                 inner_brackets.clear()
5851                 yield omit
5852
5853             if (
5854                 line.should_explode
5855                 and prev
5856                 and prev.type == token.COMMA
5857                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
5858             ):
5859                 # Never omit bracket pairs with trailing commas.
5860                 # We need to explode on those.
5861                 break
5862
5863             if leaf.value:
5864                 opening_bracket = leaf.opening_bracket
5865                 closing_bracket = leaf
5866
5867
5868 def get_future_imports(node: Node) -> Set[str]:
5869     """Return a set of __future__ imports in the file."""
5870     imports: Set[str] = set()
5871
5872     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5873         for child in children:
5874             if isinstance(child, Leaf):
5875                 if child.type == token.NAME:
5876                     yield child.value
5877
5878             elif child.type == syms.import_as_name:
5879                 orig_name = child.children[0]
5880                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5881                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5882                 yield orig_name.value
5883
5884             elif child.type == syms.import_as_names:
5885                 yield from get_imports_from_children(child.children)
5886
5887             else:
5888                 raise AssertionError("Invalid syntax parsing imports")
5889
5890     for child in node.children:
5891         if child.type != syms.simple_stmt:
5892             break
5893
5894         first_child = child.children[0]
5895         if isinstance(first_child, Leaf):
5896             # Continue looking if we see a docstring; otherwise stop.
5897             if (
5898                 len(child.children) == 2
5899                 and first_child.type == token.STRING
5900                 and child.children[1].type == token.NEWLINE
5901             ):
5902                 continue
5903
5904             break
5905
5906         elif first_child.type == syms.import_from:
5907             module_name = first_child.children[1]
5908             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5909                 break
5910
5911             imports |= set(get_imports_from_children(first_child.children[3:]))
5912         else:
5913             break
5914
5915     return imports
5916
5917
5918 @lru_cache()
5919 def get_gitignore(root: Path) -> PathSpec:
5920     """ Return a PathSpec matching gitignore content if present."""
5921     gitignore = root / ".gitignore"
5922     lines: List[str] = []
5923     if gitignore.is_file():
5924         with gitignore.open() as gf:
5925             lines = gf.readlines()
5926     return PathSpec.from_lines("gitwildmatch", lines)
5927
5928
5929 def normalize_path_maybe_ignore(
5930     path: Path, root: Path, report: "Report"
5931 ) -> Optional[str]:
5932     """Normalize `path`. May return `None` if `path` was ignored.
5933
5934     `report` is where "path ignored" output goes.
5935     """
5936     try:
5937         abspath = path if path.is_absolute() else Path.cwd() / path
5938         normalized_path = abspath.resolve().relative_to(root).as_posix()
5939     except OSError as e:
5940         report.path_ignored(path, f"cannot be read because {e}")
5941         return None
5942
5943     except ValueError:
5944         if path.is_symlink():
5945             report.path_ignored(path, f"is a symbolic link that points outside {root}")
5946             return None
5947
5948         raise
5949
5950     return normalized_path
5951
5952
5953 def gen_python_files(
5954     paths: Iterable[Path],
5955     root: Path,
5956     include: Optional[Pattern[str]],
5957     exclude: Pattern[str],
5958     force_exclude: Optional[Pattern[str]],
5959     report: "Report",
5960     gitignore: PathSpec,
5961 ) -> Iterator[Path]:
5962     """Generate all files under `path` whose paths are not excluded by the
5963     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
5964
5965     Symbolic links pointing outside of the `root` directory are ignored.
5966
5967     `report` is where output about exclusions goes.
5968     """
5969     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5970     for child in paths:
5971         normalized_path = normalize_path_maybe_ignore(child, root, report)
5972         if normalized_path is None:
5973             continue
5974
5975         # First ignore files matching .gitignore
5976         if gitignore.match_file(normalized_path):
5977             report.path_ignored(child, "matches the .gitignore file content")
5978             continue
5979
5980         # Then ignore with `--exclude` and `--force-exclude` options.
5981         normalized_path = "/" + normalized_path
5982         if child.is_dir():
5983             normalized_path += "/"
5984
5985         exclude_match = exclude.search(normalized_path) if exclude else None
5986         if exclude_match and exclude_match.group(0):
5987             report.path_ignored(child, "matches the --exclude regular expression")
5988             continue
5989
5990         force_exclude_match = (
5991             force_exclude.search(normalized_path) if force_exclude else None
5992         )
5993         if force_exclude_match and force_exclude_match.group(0):
5994             report.path_ignored(child, "matches the --force-exclude regular expression")
5995             continue
5996
5997         if child.is_dir():
5998             yield from gen_python_files(
5999                 child.iterdir(),
6000                 root,
6001                 include,
6002                 exclude,
6003                 force_exclude,
6004                 report,
6005                 gitignore,
6006             )
6007
6008         elif child.is_file():
6009             include_match = include.search(normalized_path) if include else True
6010             if include_match:
6011                 yield child
6012
6013
6014 @lru_cache()
6015 def find_project_root(srcs: Iterable[str]) -> Path:
6016     """Return a directory containing .git, .hg, or pyproject.toml.
6017
6018     That directory will be a common parent of all files and directories
6019     passed in `srcs`.
6020
6021     If no directory in the tree contains a marker that would specify it's the
6022     project root, the root of the file system is returned.
6023     """
6024     if not srcs:
6025         return Path("/").resolve()
6026
6027     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6028
6029     # A list of lists of parents for each 'src'. 'src' is included as a
6030     # "parent" of itself if it is a directory
6031     src_parents = [
6032         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6033     ]
6034
6035     common_base = max(
6036         set.intersection(*(set(parents) for parents in src_parents)),
6037         key=lambda path: path.parts,
6038     )
6039
6040     for directory in (common_base, *common_base.parents):
6041         if (directory / ".git").exists():
6042             return directory
6043
6044         if (directory / ".hg").is_dir():
6045             return directory
6046
6047         if (directory / "pyproject.toml").is_file():
6048             return directory
6049
6050     return directory
6051
6052
6053 @dataclass
6054 class Report:
6055     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6056
6057     check: bool = False
6058     diff: bool = False
6059     quiet: bool = False
6060     verbose: bool = False
6061     change_count: int = 0
6062     same_count: int = 0
6063     failure_count: int = 0
6064
6065     def done(self, src: Path, changed: Changed) -> None:
6066         """Increment the counter for successful reformatting. Write out a message."""
6067         if changed is Changed.YES:
6068             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6069             if self.verbose or not self.quiet:
6070                 out(f"{reformatted} {src}")
6071             self.change_count += 1
6072         else:
6073             if self.verbose:
6074                 if changed is Changed.NO:
6075                     msg = f"{src} already well formatted, good job."
6076                 else:
6077                     msg = f"{src} wasn't modified on disk since last run."
6078                 out(msg, bold=False)
6079             self.same_count += 1
6080
6081     def failed(self, src: Path, message: str) -> None:
6082         """Increment the counter for failed reformatting. Write out a message."""
6083         err(f"error: cannot format {src}: {message}")
6084         self.failure_count += 1
6085
6086     def path_ignored(self, path: Path, message: str) -> None:
6087         if self.verbose:
6088             out(f"{path} ignored: {message}", bold=False)
6089
6090     @property
6091     def return_code(self) -> int:
6092         """Return the exit code that the app should use.
6093
6094         This considers the current state of changed files and failures:
6095         - if there were any failures, return 123;
6096         - if any files were changed and --check is being used, return 1;
6097         - otherwise return 0.
6098         """
6099         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6100         # 126 we have special return codes reserved by the shell.
6101         if self.failure_count:
6102             return 123
6103
6104         elif self.change_count and self.check:
6105             return 1
6106
6107         return 0
6108
6109     def __str__(self) -> str:
6110         """Render a color report of the current state.
6111
6112         Use `click.unstyle` to remove colors.
6113         """
6114         if self.check or self.diff:
6115             reformatted = "would be reformatted"
6116             unchanged = "would be left unchanged"
6117             failed = "would fail to reformat"
6118         else:
6119             reformatted = "reformatted"
6120             unchanged = "left unchanged"
6121             failed = "failed to reformat"
6122         report = []
6123         if self.change_count:
6124             s = "s" if self.change_count > 1 else ""
6125             report.append(
6126                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6127             )
6128         if self.same_count:
6129             s = "s" if self.same_count > 1 else ""
6130             report.append(f"{self.same_count} file{s} {unchanged}")
6131         if self.failure_count:
6132             s = "s" if self.failure_count > 1 else ""
6133             report.append(
6134                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6135             )
6136         return ", ".join(report) + "."
6137
6138
6139 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6140     filename = "<unknown>"
6141     if sys.version_info >= (3, 8):
6142         # TODO: support Python 4+ ;)
6143         for minor_version in range(sys.version_info[1], 4, -1):
6144             try:
6145                 return ast.parse(src, filename, feature_version=(3, minor_version))
6146             except SyntaxError:
6147                 continue
6148     else:
6149         for feature_version in (7, 6):
6150             try:
6151                 return ast3.parse(src, filename, feature_version=feature_version)
6152             except SyntaxError:
6153                 continue
6154
6155     return ast27.parse(src)
6156
6157
6158 def _fixup_ast_constants(
6159     node: Union[ast.AST, ast3.AST, ast27.AST]
6160 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6161     """Map ast nodes deprecated in 3.8 to Constant."""
6162     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6163         return ast.Constant(value=node.s)
6164
6165     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6166         return ast.Constant(value=node.n)
6167
6168     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6169         return ast.Constant(value=node.value)
6170
6171     return node
6172
6173
6174 def _stringify_ast(
6175     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6176 ) -> Iterator[str]:
6177     """Simple visitor generating strings to compare ASTs by content."""
6178
6179     node = _fixup_ast_constants(node)
6180
6181     yield f"{'  ' * depth}{node.__class__.__name__}("
6182
6183     for field in sorted(node._fields):  # noqa: F402
6184         # TypeIgnore has only one field 'lineno' which breaks this comparison
6185         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6186         if sys.version_info >= (3, 8):
6187             type_ignore_classes += (ast.TypeIgnore,)
6188         if isinstance(node, type_ignore_classes):
6189             break
6190
6191         try:
6192             value = getattr(node, field)
6193         except AttributeError:
6194             continue
6195
6196         yield f"{'  ' * (depth+1)}{field}="
6197
6198         if isinstance(value, list):
6199             for item in value:
6200                 # Ignore nested tuples within del statements, because we may insert
6201                 # parentheses and they change the AST.
6202                 if (
6203                     field == "targets"
6204                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6205                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6206                 ):
6207                     for item in item.elts:
6208                         yield from _stringify_ast(item, depth + 2)
6209
6210                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6211                     yield from _stringify_ast(item, depth + 2)
6212
6213         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6214             yield from _stringify_ast(value, depth + 2)
6215
6216         else:
6217             # Constant strings may be indented across newlines, if they are
6218             # docstrings; fold spaces after newlines when comparing. Similarly,
6219             # trailing and leading space may be removed.
6220             if (
6221                 isinstance(node, ast.Constant)
6222                 and field == "value"
6223                 and isinstance(value, str)
6224             ):
6225                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6226             else:
6227                 normalized = value
6228             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6229
6230     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6231
6232
6233 def assert_equivalent(src: str, dst: str) -> None:
6234     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6235     try:
6236         src_ast = parse_ast(src)
6237     except Exception as exc:
6238         raise AssertionError(
6239             "cannot use --safe with this file; failed to parse source file.  AST"
6240             f" error message: {exc}"
6241         )
6242
6243     try:
6244         dst_ast = parse_ast(dst)
6245     except Exception as exc:
6246         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6247         raise AssertionError(
6248             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6249             " on https://github.com/psf/black/issues.  This invalid output might be"
6250             f" helpful: {log}"
6251         ) from None
6252
6253     src_ast_str = "\n".join(_stringify_ast(src_ast))
6254     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6255     if src_ast_str != dst_ast_str:
6256         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6257         raise AssertionError(
6258             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6259             " source.  Please report a bug on https://github.com/psf/black/issues. "
6260             f" This diff might be helpful: {log}"
6261         ) from None
6262
6263
6264 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6265     """Raise AssertionError if `dst` reformats differently the second time."""
6266     newdst = format_str(dst, mode=mode)
6267     if dst != newdst:
6268         log = dump_to_file(
6269             str(mode),
6270             diff(src, dst, "source", "first pass"),
6271             diff(dst, newdst, "first pass", "second pass"),
6272         )
6273         raise AssertionError(
6274             "INTERNAL ERROR: Black produced different code on the second pass of the"
6275             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6276             f"  This diff might be helpful: {log}"
6277         ) from None
6278
6279
6280 @mypyc_attr(patchable=True)
6281 def dump_to_file(*output: str) -> str:
6282     """Dump `output` to a temporary file. Return path to the file."""
6283     with tempfile.NamedTemporaryFile(
6284         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6285     ) as f:
6286         for lines in output:
6287             f.write(lines)
6288             if lines and lines[-1] != "\n":
6289                 f.write("\n")
6290     return f.name
6291
6292
6293 @contextmanager
6294 def nullcontext() -> Iterator[None]:
6295     """Return an empty context manager.
6296
6297     To be used like `nullcontext` in Python 3.7.
6298     """
6299     yield
6300
6301
6302 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6303     """Return a unified diff string between strings `a` and `b`."""
6304     import difflib
6305
6306     a_lines = [line + "\n" for line in a.splitlines()]
6307     b_lines = [line + "\n" for line in b.splitlines()]
6308     return "".join(
6309         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6310     )
6311
6312
6313 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6314     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6315     err("Aborted!")
6316     for task in tasks:
6317         task.cancel()
6318
6319
6320 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6321     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6322     try:
6323         if sys.version_info[:2] >= (3, 7):
6324             all_tasks = asyncio.all_tasks
6325         else:
6326             all_tasks = asyncio.Task.all_tasks
6327         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6328         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6329         if not to_cancel:
6330             return
6331
6332         for task in to_cancel:
6333             task.cancel()
6334         loop.run_until_complete(
6335             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6336         )
6337     finally:
6338         # `concurrent.futures.Future` objects cannot be cancelled once they
6339         # are already running. There might be some when the `shutdown()` happened.
6340         # Silence their logger's spew about the event loop being closed.
6341         cf_logger = logging.getLogger("concurrent.futures")
6342         cf_logger.setLevel(logging.CRITICAL)
6343         loop.close()
6344
6345
6346 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6347     """Replace `regex` with `replacement` twice on `original`.
6348
6349     This is used by string normalization to perform replaces on
6350     overlapping matches.
6351     """
6352     return regex.sub(replacement, regex.sub(replacement, original))
6353
6354
6355 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6356     """Compile a regular expression string in `regex`.
6357
6358     If it contains newlines, use verbose mode.
6359     """
6360     if "\n" in regex:
6361         regex = "(?x)" + regex
6362     compiled: Pattern[str] = re.compile(regex)
6363     return compiled
6364
6365
6366 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6367     """Like `reversed(enumerate(sequence))` if that were possible."""
6368     index = len(sequence) - 1
6369     for element in reversed(sequence):
6370         yield (index, element)
6371         index -= 1
6372
6373
6374 def enumerate_with_length(
6375     line: Line, reversed: bool = False
6376 ) -> Iterator[Tuple[Index, Leaf, int]]:
6377     """Return an enumeration of leaves with their length.
6378
6379     Stops prematurely on multiline strings and standalone comments.
6380     """
6381     op = cast(
6382         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6383         enumerate_reversed if reversed else enumerate,
6384     )
6385     for index, leaf in op(line.leaves):
6386         length = len(leaf.prefix) + len(leaf.value)
6387         if "\n" in leaf.value:
6388             return  # Multiline strings, we can't continue.
6389
6390         for comment in line.comments_after(leaf):
6391             length += len(comment.value)
6392
6393         yield index, leaf, length
6394
6395
6396 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6397     """Return True if `line` is no longer than `line_length`.
6398
6399     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6400     """
6401     if not line_str:
6402         line_str = line_to_string(line)
6403     return (
6404         len(line_str) <= line_length
6405         and "\n" not in line_str  # multiline strings
6406         and not line.contains_standalone_comments()
6407     )
6408
6409
6410 def can_be_split(line: Line) -> bool:
6411     """Return False if the line cannot be split *for sure*.
6412
6413     This is not an exhaustive search but a cheap heuristic that we can use to
6414     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6415     in unnecessary parentheses).
6416     """
6417     leaves = line.leaves
6418     if len(leaves) < 2:
6419         return False
6420
6421     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6422         call_count = 0
6423         dot_count = 0
6424         next = leaves[-1]
6425         for leaf in leaves[-2::-1]:
6426             if leaf.type in OPENING_BRACKETS:
6427                 if next.type not in CLOSING_BRACKETS:
6428                     return False
6429
6430                 call_count += 1
6431             elif leaf.type == token.DOT:
6432                 dot_count += 1
6433             elif leaf.type == token.NAME:
6434                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6435                     return False
6436
6437             elif leaf.type not in CLOSING_BRACKETS:
6438                 return False
6439
6440             if dot_count > 1 and call_count > 1:
6441                 return False
6442
6443     return True
6444
6445
6446 def can_omit_invisible_parens(
6447     line: Line,
6448     line_length: int,
6449     omit_on_explode: Collection[LeafID] = (),
6450 ) -> bool:
6451     """Does `line` have a shape safe to reformat without optional parens around it?
6452
6453     Returns True for only a subset of potentially nice looking formattings but
6454     the point is to not return false positives that end up producing lines that
6455     are too long.
6456     """
6457     bt = line.bracket_tracker
6458     if not bt.delimiters:
6459         # Without delimiters the optional parentheses are useless.
6460         return True
6461
6462     max_priority = bt.max_delimiter_priority()
6463     if bt.delimiter_count_with_priority(max_priority) > 1:
6464         # With more than one delimiter of a kind the optional parentheses read better.
6465         return False
6466
6467     if max_priority == DOT_PRIORITY:
6468         # A single stranded method call doesn't require optional parentheses.
6469         return True
6470
6471     assert len(line.leaves) >= 2, "Stranded delimiter"
6472
6473     # With a single delimiter, omit if the expression starts or ends with
6474     # a bracket.
6475     first = line.leaves[0]
6476     second = line.leaves[1]
6477     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6478         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6479             return True
6480
6481         # Note: we are not returning False here because a line might have *both*
6482         # a leading opening bracket and a trailing closing bracket.  If the
6483         # opening bracket doesn't match our rule, maybe the closing will.
6484
6485     penultimate = line.leaves[-2]
6486     last = line.leaves[-1]
6487     if line.should_explode:
6488         try:
6489             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6490         except LookupError:
6491             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6492             return False
6493
6494     if (
6495         last.type == token.RPAR
6496         or last.type == token.RBRACE
6497         or (
6498             # don't use indexing for omitting optional parentheses;
6499             # it looks weird
6500             last.type == token.RSQB
6501             and last.parent
6502             and last.parent.type != syms.trailer
6503         )
6504     ):
6505         if penultimate.type in OPENING_BRACKETS:
6506             # Empty brackets don't help.
6507             return False
6508
6509         if is_multiline_string(first):
6510             # Additional wrapping of a multiline string in this situation is
6511             # unnecessary.
6512             return True
6513
6514         if line.should_explode and penultimate.type == token.COMMA:
6515             # The rightmost non-omitted bracket pair is the one we want to explode on.
6516             return True
6517
6518         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6519             return True
6520
6521     return False
6522
6523
6524 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6525     """See `can_omit_invisible_parens`."""
6526     remainder = False
6527     length = 4 * line.depth
6528     _index = -1
6529     for _index, leaf, leaf_length in enumerate_with_length(line):
6530         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6531             remainder = True
6532         if remainder:
6533             length += leaf_length
6534             if length > line_length:
6535                 break
6536
6537             if leaf.type in OPENING_BRACKETS:
6538                 # There are brackets we can further split on.
6539                 remainder = False
6540
6541     else:
6542         # checked the entire string and line length wasn't exceeded
6543         if len(line.leaves) == _index + 1:
6544             return True
6545
6546     return False
6547
6548
6549 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6550     """See `can_omit_invisible_parens`."""
6551     length = 4 * line.depth
6552     seen_other_brackets = False
6553     for _index, leaf, leaf_length in enumerate_with_length(line):
6554         length += leaf_length
6555         if leaf is last.opening_bracket:
6556             if seen_other_brackets or length <= line_length:
6557                 return True
6558
6559         elif leaf.type in OPENING_BRACKETS:
6560             # There are brackets we can further split on.
6561             seen_other_brackets = True
6562
6563     return False
6564
6565
6566 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6567     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6568     stop_after = None
6569     last = None
6570     for leaf in reversed(leaves):
6571         if stop_after:
6572             if leaf is stop_after:
6573                 stop_after = None
6574             continue
6575
6576         if last:
6577             return leaf, last
6578
6579         if id(leaf) in omit:
6580             stop_after = leaf.opening_bracket
6581         else:
6582             last = leaf
6583     else:
6584         raise LookupError("Last two leaves were also skipped")
6585
6586
6587 def run_transformer(
6588     line: Line,
6589     transform: Transformer,
6590     mode: Mode,
6591     features: Collection[Feature],
6592     *,
6593     line_str: str = "",
6594 ) -> List[Line]:
6595     if not line_str:
6596         line_str = line_to_string(line)
6597     result: List[Line] = []
6598     for transformed_line in transform(line, features):
6599         if str(transformed_line).strip("\n") == line_str:
6600             raise CannotTransform("Line transformer returned an unchanged result")
6601
6602         result.extend(transform_line(transformed_line, mode=mode, features=features))
6603
6604     if not (
6605         transform.__name__ == "rhs"
6606         and line.bracket_tracker.invisible
6607         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6608         and not line.contains_multiline_strings()
6609         and not result[0].contains_uncollapsable_type_comments()
6610         and not result[0].contains_unsplittable_type_ignore()
6611         and not is_line_short_enough(result[0], line_length=mode.line_length)
6612     ):
6613         return result
6614
6615     line_copy = line.clone()
6616     append_leaves(line_copy, line, line.leaves)
6617     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6618     second_opinion = run_transformer(
6619         line_copy, transform, mode, features_fop, line_str=line_str
6620     )
6621     if all(
6622         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6623     ):
6624         result = second_opinion
6625     return result
6626
6627
6628 def get_cache_file(mode: Mode) -> Path:
6629     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6630
6631
6632 def read_cache(mode: Mode) -> Cache:
6633     """Read the cache if it exists and is well formed.
6634
6635     If it is not well formed, the call to write_cache later should resolve the issue.
6636     """
6637     cache_file = get_cache_file(mode)
6638     if not cache_file.exists():
6639         return {}
6640
6641     with cache_file.open("rb") as fobj:
6642         try:
6643             cache: Cache = pickle.load(fobj)
6644         except (pickle.UnpicklingError, ValueError):
6645             return {}
6646
6647     return cache
6648
6649
6650 def get_cache_info(path: Path) -> CacheInfo:
6651     """Return the information used to check if a file is already formatted or not."""
6652     stat = path.stat()
6653     return stat.st_mtime, stat.st_size
6654
6655
6656 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6657     """Split an iterable of paths in `sources` into two sets.
6658
6659     The first contains paths of files that modified on disk or are not in the
6660     cache. The other contains paths to non-modified files.
6661     """
6662     todo, done = set(), set()
6663     for src in sources:
6664         src = src.resolve()
6665         if cache.get(src) != get_cache_info(src):
6666             todo.add(src)
6667         else:
6668             done.add(src)
6669     return todo, done
6670
6671
6672 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6673     """Update the cache file."""
6674     cache_file = get_cache_file(mode)
6675     try:
6676         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6677         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6678         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6679             pickle.dump(new_cache, f, protocol=4)
6680         os.replace(f.name, cache_file)
6681     except OSError:
6682         pass
6683
6684
6685 def patch_click() -> None:
6686     """Make Click not crash.
6687
6688     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6689     default which restricts paths that it can access during the lifetime of the
6690     application.  Click refuses to work in this scenario by raising a RuntimeError.
6691
6692     In case of Black the likelihood that non-ASCII characters are going to be used in
6693     file paths is minimal since it's Python source code.  Moreover, this crash was
6694     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6695     """
6696     try:
6697         from click import core
6698         from click import _unicodefun  # type: ignore
6699     except ModuleNotFoundError:
6700         return
6701
6702     for module in (core, _unicodefun):
6703         if hasattr(module, "_verify_python3_env"):
6704             module._verify_python3_env = lambda: None
6705
6706
6707 def patched_main() -> None:
6708     freeze_support()
6709     patch_click()
6710     main()
6711
6712
6713 def is_docstring(leaf: Leaf) -> bool:
6714     if not is_multiline_string(leaf):
6715         # For the purposes of docstring re-indentation, we don't need to do anything
6716         # with single-line docstrings.
6717         return False
6718
6719     if prev_siblings_are(
6720         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6721     ):
6722         return True
6723
6724     # Multiline docstring on the same line as the `def`.
6725     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6726         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6727         # grammar. We're safe to return True without further checks.
6728         return True
6729
6730     return False
6731
6732
6733 def fix_docstring(docstring: str, prefix: str) -> str:
6734     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6735     if not docstring:
6736         return ""
6737     # Convert tabs to spaces (following the normal Python rules)
6738     # and split into a list of lines:
6739     lines = docstring.expandtabs().splitlines()
6740     # Determine minimum indentation (first line doesn't count):
6741     indent = sys.maxsize
6742     for line in lines[1:]:
6743         stripped = line.lstrip()
6744         if stripped:
6745             indent = min(indent, len(line) - len(stripped))
6746     # Remove indentation (first line is special):
6747     trimmed = [lines[0].strip()]
6748     if indent < sys.maxsize:
6749         last_line_idx = len(lines) - 2
6750         for i, line in enumerate(lines[1:]):
6751             stripped_line = line[indent:].rstrip()
6752             if stripped_line or i == last_line_idx:
6753                 trimmed.append(prefix + stripped_line)
6754             else:
6755                 trimmed.append("")
6756     return "\n".join(trimmed)
6757
6758
6759 if __name__ == "__main__":
6760     patched_main()