]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

c1db8e6ef64812842be8efb4b48490a2c60927bb
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 class BracketMatchError(KeyError):
116     """Raised when an opening bracket is unable to be matched to a closing bracket."""
117
118
119 T = TypeVar("T")
120 E = TypeVar("E", bound=Exception)
121
122
123 class Ok(Generic[T]):
124     def __init__(self, value: T) -> None:
125         self._value = value
126
127     def ok(self) -> T:
128         return self._value
129
130
131 class Err(Generic[E]):
132     def __init__(self, e: E) -> None:
133         self._e = e
134
135     def err(self) -> E:
136         return self._e
137
138
139 # The 'Result' return type is used to implement an error-handling model heavily
140 # influenced by that used by the Rust programming language
141 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
142 Result = Union[Ok[T], Err[E]]
143 TResult = Result[T, CannotTransform]  # (T)ransform Result
144 TMatchResult = TResult[Index]
145
146
147 class WriteBack(Enum):
148     NO = 0
149     YES = 1
150     DIFF = 2
151     CHECK = 3
152     COLOR_DIFF = 4
153
154     @classmethod
155     def from_configuration(
156         cls, *, check: bool, diff: bool, color: bool = False
157     ) -> "WriteBack":
158         if check and not diff:
159             return cls.CHECK
160
161         if diff and color:
162             return cls.COLOR_DIFF
163
164         return cls.DIFF if diff else cls.YES
165
166
167 class Changed(Enum):
168     NO = 0
169     CACHED = 1
170     YES = 2
171
172
173 class TargetVersion(Enum):
174     PY27 = 2
175     PY33 = 3
176     PY34 = 4
177     PY35 = 5
178     PY36 = 6
179     PY37 = 7
180     PY38 = 8
181     PY39 = 9
182
183     def is_python2(self) -> bool:
184         return self is TargetVersion.PY27
185
186
187 class Feature(Enum):
188     # All string literals are unicode
189     UNICODE_LITERALS = 1
190     F_STRINGS = 2
191     NUMERIC_UNDERSCORES = 3
192     TRAILING_COMMA_IN_CALL = 4
193     TRAILING_COMMA_IN_DEF = 5
194     # The following two feature-flags are mutually exclusive, and exactly one should be
195     # set for every version of python.
196     ASYNC_IDENTIFIERS = 6
197     ASYNC_KEYWORDS = 7
198     ASSIGNMENT_EXPRESSIONS = 8
199     POS_ONLY_ARGUMENTS = 9
200     RELAXED_DECORATORS = 10
201     FORCE_OPTIONAL_PARENTHESES = 50
202
203
204 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
205     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
206     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
207     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
208     TargetVersion.PY35: {
209         Feature.UNICODE_LITERALS,
210         Feature.TRAILING_COMMA_IN_CALL,
211         Feature.ASYNC_IDENTIFIERS,
212     },
213     TargetVersion.PY36: {
214         Feature.UNICODE_LITERALS,
215         Feature.F_STRINGS,
216         Feature.NUMERIC_UNDERSCORES,
217         Feature.TRAILING_COMMA_IN_CALL,
218         Feature.TRAILING_COMMA_IN_DEF,
219         Feature.ASYNC_IDENTIFIERS,
220     },
221     TargetVersion.PY37: {
222         Feature.UNICODE_LITERALS,
223         Feature.F_STRINGS,
224         Feature.NUMERIC_UNDERSCORES,
225         Feature.TRAILING_COMMA_IN_CALL,
226         Feature.TRAILING_COMMA_IN_DEF,
227         Feature.ASYNC_KEYWORDS,
228     },
229     TargetVersion.PY38: {
230         Feature.UNICODE_LITERALS,
231         Feature.F_STRINGS,
232         Feature.NUMERIC_UNDERSCORES,
233         Feature.TRAILING_COMMA_IN_CALL,
234         Feature.TRAILING_COMMA_IN_DEF,
235         Feature.ASYNC_KEYWORDS,
236         Feature.ASSIGNMENT_EXPRESSIONS,
237         Feature.POS_ONLY_ARGUMENTS,
238     },
239     TargetVersion.PY39: {
240         Feature.UNICODE_LITERALS,
241         Feature.F_STRINGS,
242         Feature.NUMERIC_UNDERSCORES,
243         Feature.TRAILING_COMMA_IN_CALL,
244         Feature.TRAILING_COMMA_IN_DEF,
245         Feature.ASYNC_KEYWORDS,
246         Feature.ASSIGNMENT_EXPRESSIONS,
247         Feature.RELAXED_DECORATORS,
248         Feature.POS_ONLY_ARGUMENTS,
249     },
250 }
251
252
253 @dataclass
254 class Mode:
255     target_versions: Set[TargetVersion] = field(default_factory=set)
256     line_length: int = DEFAULT_LINE_LENGTH
257     string_normalization: bool = True
258     experimental_string_processing: bool = False
259     is_pyi: bool = False
260
261     def get_cache_key(self) -> str:
262         if self.target_versions:
263             version_str = ",".join(
264                 str(version.value)
265                 for version in sorted(self.target_versions, key=lambda v: v.value)
266             )
267         else:
268             version_str = "-"
269         parts = [
270             version_str,
271             str(self.line_length),
272             str(int(self.string_normalization)),
273             str(int(self.is_pyi)),
274         ]
275         return ".".join(parts)
276
277
278 # Legacy name, left for integrations.
279 FileMode = Mode
280
281
282 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
283     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
284
285
286 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
287     """Find the absolute filepath to a pyproject.toml if it exists"""
288     path_project_root = find_project_root(path_search_start)
289     path_pyproject_toml = path_project_root / "pyproject.toml"
290     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
291
292
293 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
294     """Parse a pyproject toml file, pulling out relevant parts for Black
295
296     If parsing fails, will raise a toml.TomlDecodeError
297     """
298     pyproject_toml = toml.load(path_config)
299     config = pyproject_toml.get("tool", {}).get("black", {})
300     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
301
302
303 def read_pyproject_toml(
304     ctx: click.Context, param: click.Parameter, value: Optional[str]
305 ) -> Optional[str]:
306     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
307
308     Returns the path to a successfully found and read configuration file, None
309     otherwise.
310     """
311     if not value:
312         value = find_pyproject_toml(ctx.params.get("src", ()))
313         if value is None:
314             return None
315
316     try:
317         config = parse_pyproject_toml(value)
318     except (toml.TomlDecodeError, OSError) as e:
319         raise click.FileError(
320             filename=value, hint=f"Error reading configuration file: {e}"
321         )
322
323     if not config:
324         return None
325     else:
326         # Sanitize the values to be Click friendly. For more information please see:
327         # https://github.com/psf/black/issues/1458
328         # https://github.com/pallets/click/issues/1567
329         config = {
330             k: str(v) if not isinstance(v, (list, dict)) else v
331             for k, v in config.items()
332         }
333
334     target_version = config.get("target_version")
335     if target_version is not None and not isinstance(target_version, list):
336         raise click.BadOptionUsage(
337             "target-version", "Config key target-version must be a list"
338         )
339
340     default_map: Dict[str, Any] = {}
341     if ctx.default_map:
342         default_map.update(ctx.default_map)
343     default_map.update(config)
344
345     ctx.default_map = default_map
346     return value
347
348
349 def target_version_option_callback(
350     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
351 ) -> List[TargetVersion]:
352     """Compute the target versions from a --target-version flag.
353
354     This is its own function because mypy couldn't infer the type correctly
355     when it was a lambda, causing mypyc trouble.
356     """
357     return [TargetVersion[val.upper()] for val in v]
358
359
360 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
361 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
362 @click.option(
363     "-l",
364     "--line-length",
365     type=int,
366     default=DEFAULT_LINE_LENGTH,
367     help="How many characters per line to allow.",
368     show_default=True,
369 )
370 @click.option(
371     "-t",
372     "--target-version",
373     type=click.Choice([v.name.lower() for v in TargetVersion]),
374     callback=target_version_option_callback,
375     multiple=True,
376     help=(
377         "Python versions that should be supported by Black's output. [default: per-file"
378         " auto-detection]"
379     ),
380 )
381 @click.option(
382     "--pyi",
383     is_flag=True,
384     help=(
385         "Format all input files like typing stubs regardless of file extension (useful"
386         " when piping source on standard input)."
387     ),
388 )
389 @click.option(
390     "-S",
391     "--skip-string-normalization",
392     is_flag=True,
393     help="Don't normalize string quotes or prefixes.",
394 )
395 @click.option(
396     "--experimental-string-processing",
397     is_flag=True,
398     hidden=True,
399     help=(
400         "Experimental option that performs more normalization on string literals."
401         " Currently disabled because it leads to some crashes."
402     ),
403 )
404 @click.option(
405     "--check",
406     is_flag=True,
407     help=(
408         "Don't write the files back, just return the status.  Return code 0 means"
409         " nothing would change.  Return code 1 means some files would be reformatted."
410         " Return code 123 means there was an internal error."
411     ),
412 )
413 @click.option(
414     "--diff",
415     is_flag=True,
416     help="Don't write the files back, just output a diff for each file on stdout.",
417 )
418 @click.option(
419     "--color/--no-color",
420     is_flag=True,
421     help="Show colored diff. Only applies when `--diff` is given.",
422 )
423 @click.option(
424     "--fast/--safe",
425     is_flag=True,
426     help="If --fast given, skip temporary sanity checks. [default: --safe]",
427 )
428 @click.option(
429     "--include",
430     type=str,
431     default=DEFAULT_INCLUDES,
432     help=(
433         "A regular expression that matches files and directories that should be"
434         " included on recursive searches.  An empty value means all files are included"
435         " regardless of the name.  Use forward slashes for directories on all platforms"
436         " (Windows, too).  Exclusions are calculated first, inclusions later."
437     ),
438     show_default=True,
439 )
440 @click.option(
441     "--exclude",
442     type=str,
443     default=DEFAULT_EXCLUDES,
444     help=(
445         "A regular expression that matches files and directories that should be"
446         " excluded on recursive searches.  An empty value means no paths are excluded."
447         " Use forward slashes for directories on all platforms (Windows, too). "
448         " Exclusions are calculated first, inclusions later."
449     ),
450     show_default=True,
451 )
452 @click.option(
453     "--force-exclude",
454     type=str,
455     help=(
456         "Like --exclude, but files and directories matching this regex will be "
457         "excluded even when they are passed explicitly as arguments"
458     ),
459 )
460 @click.option(
461     "-q",
462     "--quiet",
463     is_flag=True,
464     help=(
465         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
466         " those with 2>/dev/null."
467     ),
468 )
469 @click.option(
470     "-v",
471     "--verbose",
472     is_flag=True,
473     help=(
474         "Also emit messages to stderr about files that were not changed or were ignored"
475         " due to --exclude=."
476     ),
477 )
478 @click.version_option(version=__version__)
479 @click.argument(
480     "src",
481     nargs=-1,
482     type=click.Path(
483         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
484     ),
485     is_eager=True,
486 )
487 @click.option(
488     "--config",
489     type=click.Path(
490         exists=True,
491         file_okay=True,
492         dir_okay=False,
493         readable=True,
494         allow_dash=False,
495         path_type=str,
496     ),
497     is_eager=True,
498     callback=read_pyproject_toml,
499     help="Read configuration from FILE path.",
500 )
501 @click.pass_context
502 def main(
503     ctx: click.Context,
504     code: Optional[str],
505     line_length: int,
506     target_version: List[TargetVersion],
507     check: bool,
508     diff: bool,
509     color: bool,
510     fast: bool,
511     pyi: bool,
512     skip_string_normalization: bool,
513     experimental_string_processing: bool,
514     quiet: bool,
515     verbose: bool,
516     include: str,
517     exclude: str,
518     force_exclude: Optional[str],
519     src: Tuple[str, ...],
520     config: Optional[str],
521 ) -> None:
522     """The uncompromising code formatter."""
523     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
524     if target_version:
525         versions = set(target_version)
526     else:
527         # We'll autodetect later.
528         versions = set()
529     mode = Mode(
530         target_versions=versions,
531         line_length=line_length,
532         is_pyi=pyi,
533         string_normalization=not skip_string_normalization,
534         experimental_string_processing=experimental_string_processing,
535     )
536     if config and verbose:
537         out(f"Using configuration from {config}.", bold=False, fg="blue")
538     if code is not None:
539         print(format_str(code, mode=mode))
540         ctx.exit(0)
541     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
542     sources = get_sources(
543         ctx=ctx,
544         src=src,
545         quiet=quiet,
546         verbose=verbose,
547         include=include,
548         exclude=exclude,
549         force_exclude=force_exclude,
550         report=report,
551     )
552
553     path_empty(
554         sources,
555         "No Python files are present to be formatted. Nothing to do 😴",
556         quiet,
557         verbose,
558         ctx,
559     )
560
561     if len(sources) == 1:
562         reformat_one(
563             src=sources.pop(),
564             fast=fast,
565             write_back=write_back,
566             mode=mode,
567             report=report,
568         )
569     else:
570         reformat_many(
571             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
572         )
573
574     if verbose or not quiet:
575         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
576         click.secho(str(report), err=True)
577     ctx.exit(report.return_code)
578
579
580 def get_sources(
581     *,
582     ctx: click.Context,
583     src: Tuple[str, ...],
584     quiet: bool,
585     verbose: bool,
586     include: str,
587     exclude: str,
588     force_exclude: Optional[str],
589     report: "Report",
590 ) -> Set[Path]:
591     """Compute the set of files to be formatted."""
592     try:
593         include_regex = re_compile_maybe_verbose(include)
594     except re.error:
595         err(f"Invalid regular expression for include given: {include!r}")
596         ctx.exit(2)
597     try:
598         exclude_regex = re_compile_maybe_verbose(exclude)
599     except re.error:
600         err(f"Invalid regular expression for exclude given: {exclude!r}")
601         ctx.exit(2)
602     try:
603         force_exclude_regex = (
604             re_compile_maybe_verbose(force_exclude) if force_exclude else None
605         )
606     except re.error:
607         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
608         ctx.exit(2)
609
610     root = find_project_root(src)
611     sources: Set[Path] = set()
612     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
613     gitignore = get_gitignore(root)
614
615     for s in src:
616         p = Path(s)
617         if p.is_dir():
618             sources.update(
619                 gen_python_files(
620                     p.iterdir(),
621                     root,
622                     include_regex,
623                     exclude_regex,
624                     force_exclude_regex,
625                     report,
626                     gitignore,
627                 )
628             )
629         elif s == "-":
630             sources.add(p)
631         elif p.is_file():
632             normalized_path = normalize_path_maybe_ignore(p, root, report)
633             if normalized_path is None:
634                 continue
635
636             normalized_path = "/" + normalized_path
637             # Hard-exclude any files that matches the `--force-exclude` regex.
638             if force_exclude_regex:
639                 force_exclude_match = force_exclude_regex.search(normalized_path)
640             else:
641                 force_exclude_match = None
642             if force_exclude_match and force_exclude_match.group(0):
643                 report.path_ignored(p, "matches the --force-exclude regular expression")
644                 continue
645
646             sources.add(p)
647         else:
648             err(f"invalid path: {s}")
649     return sources
650
651
652 def path_empty(
653     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
654 ) -> None:
655     """
656     Exit if there is no `src` provided for formatting
657     """
658     if not src and (verbose or not quiet):
659         out(msg)
660         ctx.exit(0)
661
662
663 def reformat_one(
664     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
665 ) -> None:
666     """Reformat a single file under `src` without spawning child processes.
667
668     `fast`, `write_back`, and `mode` options are passed to
669     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
670     """
671     try:
672         changed = Changed.NO
673         if not src.is_file() and str(src) == "-":
674             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
675                 changed = Changed.YES
676         else:
677             cache: Cache = {}
678             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
679                 cache = read_cache(mode)
680                 res_src = src.resolve()
681                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
682                     changed = Changed.CACHED
683             if changed is not Changed.CACHED and format_file_in_place(
684                 src, fast=fast, write_back=write_back, mode=mode
685             ):
686                 changed = Changed.YES
687             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
688                 write_back is WriteBack.CHECK and changed is Changed.NO
689             ):
690                 write_cache(cache, [src], mode)
691         report.done(src, changed)
692     except Exception as exc:
693         if report.verbose:
694             traceback.print_exc()
695         report.failed(src, str(exc))
696
697
698 def reformat_many(
699     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
700 ) -> None:
701     """Reformat multiple files using a ProcessPoolExecutor."""
702     executor: Executor
703     loop = asyncio.get_event_loop()
704     worker_count = os.cpu_count()
705     if sys.platform == "win32":
706         # Work around https://bugs.python.org/issue26903
707         worker_count = min(worker_count, 61)
708     try:
709         executor = ProcessPoolExecutor(max_workers=worker_count)
710     except (ImportError, OSError):
711         # we arrive here if the underlying system does not support multi-processing
712         # like in AWS Lambda or Termux, in which case we gracefully fallback to
713         # a ThreadPollExecutor with just a single worker (more workers would not do us
714         # any good due to the Global Interpreter Lock)
715         executor = ThreadPoolExecutor(max_workers=1)
716
717     try:
718         loop.run_until_complete(
719             schedule_formatting(
720                 sources=sources,
721                 fast=fast,
722                 write_back=write_back,
723                 mode=mode,
724                 report=report,
725                 loop=loop,
726                 executor=executor,
727             )
728         )
729     finally:
730         shutdown(loop)
731         if executor is not None:
732             executor.shutdown()
733
734
735 async def schedule_formatting(
736     sources: Set[Path],
737     fast: bool,
738     write_back: WriteBack,
739     mode: Mode,
740     report: "Report",
741     loop: asyncio.AbstractEventLoop,
742     executor: Executor,
743 ) -> None:
744     """Run formatting of `sources` in parallel using the provided `executor`.
745
746     (Use ProcessPoolExecutors for actual parallelism.)
747
748     `write_back`, `fast`, and `mode` options are passed to
749     :func:`format_file_in_place`.
750     """
751     cache: Cache = {}
752     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
753         cache = read_cache(mode)
754         sources, cached = filter_cached(cache, sources)
755         for src in sorted(cached):
756             report.done(src, Changed.CACHED)
757     if not sources:
758         return
759
760     cancelled = []
761     sources_to_cache = []
762     lock = None
763     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
764         # For diff output, we need locks to ensure we don't interleave output
765         # from different processes.
766         manager = Manager()
767         lock = manager.Lock()
768     tasks = {
769         asyncio.ensure_future(
770             loop.run_in_executor(
771                 executor, format_file_in_place, src, fast, mode, write_back, lock
772             )
773         ): src
774         for src in sorted(sources)
775     }
776     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
777     try:
778         loop.add_signal_handler(signal.SIGINT, cancel, pending)
779         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
780     except NotImplementedError:
781         # There are no good alternatives for these on Windows.
782         pass
783     while pending:
784         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
785         for task in done:
786             src = tasks.pop(task)
787             if task.cancelled():
788                 cancelled.append(task)
789             elif task.exception():
790                 report.failed(src, str(task.exception()))
791             else:
792                 changed = Changed.YES if task.result() else Changed.NO
793                 # If the file was written back or was successfully checked as
794                 # well-formatted, store this information in the cache.
795                 if write_back is WriteBack.YES or (
796                     write_back is WriteBack.CHECK and changed is Changed.NO
797                 ):
798                     sources_to_cache.append(src)
799                 report.done(src, changed)
800     if cancelled:
801         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
802     if sources_to_cache:
803         write_cache(cache, sources_to_cache, mode)
804
805
806 def format_file_in_place(
807     src: Path,
808     fast: bool,
809     mode: Mode,
810     write_back: WriteBack = WriteBack.NO,
811     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
812 ) -> bool:
813     """Format file under `src` path. Return True if changed.
814
815     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
816     code to the file.
817     `mode` and `fast` options are passed to :func:`format_file_contents`.
818     """
819     if src.suffix == ".pyi":
820         mode = replace(mode, is_pyi=True)
821
822     then = datetime.utcfromtimestamp(src.stat().st_mtime)
823     with open(src, "rb") as buf:
824         src_contents, encoding, newline = decode_bytes(buf.read())
825     try:
826         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
827     except NothingChanged:
828         return False
829
830     if write_back == WriteBack.YES:
831         with open(src, "w", encoding=encoding, newline=newline) as f:
832             f.write(dst_contents)
833     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
834         now = datetime.utcnow()
835         src_name = f"{src}\t{then} +0000"
836         dst_name = f"{src}\t{now} +0000"
837         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
838
839         if write_back == write_back.COLOR_DIFF:
840             diff_contents = color_diff(diff_contents)
841
842         with lock or nullcontext():
843             f = io.TextIOWrapper(
844                 sys.stdout.buffer,
845                 encoding=encoding,
846                 newline=newline,
847                 write_through=True,
848             )
849             f = wrap_stream_for_windows(f)
850             f.write(diff_contents)
851             f.detach()
852
853     return True
854
855
856 def color_diff(contents: str) -> str:
857     """Inject the ANSI color codes to the diff."""
858     lines = contents.split("\n")
859     for i, line in enumerate(lines):
860         if line.startswith("+++") or line.startswith("---"):
861             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
862         if line.startswith("@@"):
863             line = "\033[36m" + line + "\033[0m"  # cyan, reset
864         if line.startswith("+"):
865             line = "\033[32m" + line + "\033[0m"  # green, reset
866         elif line.startswith("-"):
867             line = "\033[31m" + line + "\033[0m"  # red, reset
868         lines[i] = line
869     return "\n".join(lines)
870
871
872 def wrap_stream_for_windows(
873     f: io.TextIOWrapper,
874 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
875     """
876     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
877
878     If `colorama` is not found, then no change is made. If `colorama` does
879     exist, then it handles the logic to determine whether or not to change
880     things.
881     """
882     try:
883         from colorama import initialise
884
885         # We set `strip=False` so that we can don't have to modify
886         # test_express_diff_with_color.
887         f = initialise.wrap_stream(
888             f, convert=None, strip=False, autoreset=False, wrap=True
889         )
890
891         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
892         # which does not have a `detach()` method. So we fake one.
893         f.detach = lambda *args, **kwargs: None  # type: ignore
894     except ImportError:
895         pass
896
897     return f
898
899
900 def format_stdin_to_stdout(
901     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
902 ) -> bool:
903     """Format file on stdin. Return True if changed.
904
905     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
906     write a diff to stdout. The `mode` argument is passed to
907     :func:`format_file_contents`.
908     """
909     then = datetime.utcnow()
910     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
911     dst = src
912     try:
913         dst = format_file_contents(src, fast=fast, mode=mode)
914         return True
915
916     except NothingChanged:
917         return False
918
919     finally:
920         f = io.TextIOWrapper(
921             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
922         )
923         if write_back == WriteBack.YES:
924             f.write(dst)
925         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
926             now = datetime.utcnow()
927             src_name = f"STDIN\t{then} +0000"
928             dst_name = f"STDOUT\t{now} +0000"
929             d = diff(src, dst, src_name, dst_name)
930             if write_back == WriteBack.COLOR_DIFF:
931                 d = color_diff(d)
932                 f = wrap_stream_for_windows(f)
933             f.write(d)
934         f.detach()
935
936
937 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
938     """Reformat contents of a file and return new contents.
939
940     If `fast` is False, additionally confirm that the reformatted code is
941     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
942     `mode` is passed to :func:`format_str`.
943     """
944     if not src_contents.strip():
945         raise NothingChanged
946
947     dst_contents = format_str(src_contents, mode=mode)
948     if src_contents == dst_contents:
949         raise NothingChanged
950
951     if not fast:
952         assert_equivalent(src_contents, dst_contents)
953         assert_stable(src_contents, dst_contents, mode=mode)
954     return dst_contents
955
956
957 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
958     """Reformat a string and return new contents.
959
960     `mode` determines formatting options, such as how many characters per line are
961     allowed.  Example:
962
963     >>> import black
964     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
965     def f(arg: str = "") -> None:
966         ...
967
968     A more complex example:
969
970     >>> print(
971     ...   black.format_str(
972     ...     "def f(arg:str='')->None: hey",
973     ...     mode=black.Mode(
974     ...       target_versions={black.TargetVersion.PY36},
975     ...       line_length=10,
976     ...       string_normalization=False,
977     ...       is_pyi=False,
978     ...     ),
979     ...   ),
980     ... )
981     def f(
982         arg: str = '',
983     ) -> None:
984         hey
985
986     """
987     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
988     dst_contents = []
989     future_imports = get_future_imports(src_node)
990     if mode.target_versions:
991         versions = mode.target_versions
992     else:
993         versions = detect_target_versions(src_node)
994     normalize_fmt_off(src_node)
995     lines = LineGenerator(
996         remove_u_prefix="unicode_literals" in future_imports
997         or supports_feature(versions, Feature.UNICODE_LITERALS),
998         is_pyi=mode.is_pyi,
999         normalize_strings=mode.string_normalization,
1000     )
1001     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1002     empty_line = Line()
1003     after = 0
1004     split_line_features = {
1005         feature
1006         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1007         if supports_feature(versions, feature)
1008     }
1009     for current_line in lines.visit(src_node):
1010         dst_contents.append(str(empty_line) * after)
1011         before, after = elt.maybe_empty_lines(current_line)
1012         dst_contents.append(str(empty_line) * before)
1013         for line in transform_line(
1014             current_line, mode=mode, features=split_line_features
1015         ):
1016             dst_contents.append(str(line))
1017     return "".join(dst_contents)
1018
1019
1020 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1021     """Return a tuple of (decoded_contents, encoding, newline).
1022
1023     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1024     universal newlines (i.e. only contains LF).
1025     """
1026     srcbuf = io.BytesIO(src)
1027     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1028     if not lines:
1029         return "", encoding, "\n"
1030
1031     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1032     srcbuf.seek(0)
1033     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1034         return tiow.read(), encoding, newline
1035
1036
1037 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1038     if not target_versions:
1039         # No target_version specified, so try all grammars.
1040         return [
1041             # Python 3.7+
1042             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1043             # Python 3.0-3.6
1044             pygram.python_grammar_no_print_statement_no_exec_statement,
1045             # Python 2.7 with future print_function import
1046             pygram.python_grammar_no_print_statement,
1047             # Python 2.7
1048             pygram.python_grammar,
1049         ]
1050
1051     if all(version.is_python2() for version in target_versions):
1052         # Python 2-only code, so try Python 2 grammars.
1053         return [
1054             # Python 2.7 with future print_function import
1055             pygram.python_grammar_no_print_statement,
1056             # Python 2.7
1057             pygram.python_grammar,
1058         ]
1059
1060     # Python 3-compatible code, so only try Python 3 grammar.
1061     grammars = []
1062     # If we have to parse both, try to parse async as a keyword first
1063     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1064         # Python 3.7+
1065         grammars.append(
1066             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1067         )
1068     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1069         # Python 3.0-3.6
1070         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1071     # At least one of the above branches must have been taken, because every Python
1072     # version has exactly one of the two 'ASYNC_*' flags
1073     return grammars
1074
1075
1076 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1077     """Given a string with source, return the lib2to3 Node."""
1078     if not src_txt.endswith("\n"):
1079         src_txt += "\n"
1080
1081     for grammar in get_grammars(set(target_versions)):
1082         drv = driver.Driver(grammar, pytree.convert)
1083         try:
1084             result = drv.parse_string(src_txt, True)
1085             break
1086
1087         except ParseError as pe:
1088             lineno, column = pe.context[1]
1089             lines = src_txt.splitlines()
1090             try:
1091                 faulty_line = lines[lineno - 1]
1092             except IndexError:
1093                 faulty_line = "<line number missing in source>"
1094             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1095     else:
1096         raise exc from None
1097
1098     if isinstance(result, Leaf):
1099         result = Node(syms.file_input, [result])
1100     return result
1101
1102
1103 def lib2to3_unparse(node: Node) -> str:
1104     """Given a lib2to3 node, return its string representation."""
1105     code = str(node)
1106     return code
1107
1108
1109 class Visitor(Generic[T]):
1110     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1111
1112     def visit(self, node: LN) -> Iterator[T]:
1113         """Main method to visit `node` and its children.
1114
1115         It tries to find a `visit_*()` method for the given `node.type`, like
1116         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1117         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1118         instead.
1119
1120         Then yields objects of type `T` from the selected visitor.
1121         """
1122         if node.type < 256:
1123             name = token.tok_name[node.type]
1124         else:
1125             name = str(type_repr(node.type))
1126         # We explicitly branch on whether a visitor exists (instead of
1127         # using self.visit_default as the default arg to getattr) in order
1128         # to save needing to create a bound method object and so mypyc can
1129         # generate a native call to visit_default.
1130         visitf = getattr(self, f"visit_{name}", None)
1131         if visitf:
1132             yield from visitf(node)
1133         else:
1134             yield from self.visit_default(node)
1135
1136     def visit_default(self, node: LN) -> Iterator[T]:
1137         """Default `visit_*()` implementation. Recurses to children of `node`."""
1138         if isinstance(node, Node):
1139             for child in node.children:
1140                 yield from self.visit(child)
1141
1142
1143 @dataclass
1144 class DebugVisitor(Visitor[T]):
1145     tree_depth: int = 0
1146
1147     def visit_default(self, node: LN) -> Iterator[T]:
1148         indent = " " * (2 * self.tree_depth)
1149         if isinstance(node, Node):
1150             _type = type_repr(node.type)
1151             out(f"{indent}{_type}", fg="yellow")
1152             self.tree_depth += 1
1153             for child in node.children:
1154                 yield from self.visit(child)
1155
1156             self.tree_depth -= 1
1157             out(f"{indent}/{_type}", fg="yellow", bold=False)
1158         else:
1159             _type = token.tok_name.get(node.type, str(node.type))
1160             out(f"{indent}{_type}", fg="blue", nl=False)
1161             if node.prefix:
1162                 # We don't have to handle prefixes for `Node` objects since
1163                 # that delegates to the first child anyway.
1164                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1165             out(f" {node.value!r}", fg="blue", bold=False)
1166
1167     @classmethod
1168     def show(cls, code: Union[str, Leaf, Node]) -> None:
1169         """Pretty-print the lib2to3 AST of a given string of `code`.
1170
1171         Convenience method for debugging.
1172         """
1173         v: DebugVisitor[None] = DebugVisitor()
1174         if isinstance(code, str):
1175             code = lib2to3_parse(code)
1176         list(v.visit(code))
1177
1178
1179 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1180 STATEMENT: Final = {
1181     syms.if_stmt,
1182     syms.while_stmt,
1183     syms.for_stmt,
1184     syms.try_stmt,
1185     syms.except_clause,
1186     syms.with_stmt,
1187     syms.funcdef,
1188     syms.classdef,
1189 }
1190 STANDALONE_COMMENT: Final = 153
1191 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1192 LOGIC_OPERATORS: Final = {"and", "or"}
1193 COMPARATORS: Final = {
1194     token.LESS,
1195     token.GREATER,
1196     token.EQEQUAL,
1197     token.NOTEQUAL,
1198     token.LESSEQUAL,
1199     token.GREATEREQUAL,
1200 }
1201 MATH_OPERATORS: Final = {
1202     token.VBAR,
1203     token.CIRCUMFLEX,
1204     token.AMPER,
1205     token.LEFTSHIFT,
1206     token.RIGHTSHIFT,
1207     token.PLUS,
1208     token.MINUS,
1209     token.STAR,
1210     token.SLASH,
1211     token.DOUBLESLASH,
1212     token.PERCENT,
1213     token.AT,
1214     token.TILDE,
1215     token.DOUBLESTAR,
1216 }
1217 STARS: Final = {token.STAR, token.DOUBLESTAR}
1218 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1219 VARARGS_PARENTS: Final = {
1220     syms.arglist,
1221     syms.argument,  # double star in arglist
1222     syms.trailer,  # single argument to call
1223     syms.typedargslist,
1224     syms.varargslist,  # lambdas
1225 }
1226 UNPACKING_PARENTS: Final = {
1227     syms.atom,  # single element of a list or set literal
1228     syms.dictsetmaker,
1229     syms.listmaker,
1230     syms.testlist_gexp,
1231     syms.testlist_star_expr,
1232 }
1233 TEST_DESCENDANTS: Final = {
1234     syms.test,
1235     syms.lambdef,
1236     syms.or_test,
1237     syms.and_test,
1238     syms.not_test,
1239     syms.comparison,
1240     syms.star_expr,
1241     syms.expr,
1242     syms.xor_expr,
1243     syms.and_expr,
1244     syms.shift_expr,
1245     syms.arith_expr,
1246     syms.trailer,
1247     syms.term,
1248     syms.power,
1249 }
1250 ASSIGNMENTS: Final = {
1251     "=",
1252     "+=",
1253     "-=",
1254     "*=",
1255     "@=",
1256     "/=",
1257     "%=",
1258     "&=",
1259     "|=",
1260     "^=",
1261     "<<=",
1262     ">>=",
1263     "**=",
1264     "//=",
1265 }
1266 COMPREHENSION_PRIORITY: Final = 20
1267 COMMA_PRIORITY: Final = 18
1268 TERNARY_PRIORITY: Final = 16
1269 LOGIC_PRIORITY: Final = 14
1270 STRING_PRIORITY: Final = 12
1271 COMPARATOR_PRIORITY: Final = 10
1272 MATH_PRIORITIES: Final = {
1273     token.VBAR: 9,
1274     token.CIRCUMFLEX: 8,
1275     token.AMPER: 7,
1276     token.LEFTSHIFT: 6,
1277     token.RIGHTSHIFT: 6,
1278     token.PLUS: 5,
1279     token.MINUS: 5,
1280     token.STAR: 4,
1281     token.SLASH: 4,
1282     token.DOUBLESLASH: 4,
1283     token.PERCENT: 4,
1284     token.AT: 4,
1285     token.TILDE: 3,
1286     token.DOUBLESTAR: 2,
1287 }
1288 DOT_PRIORITY: Final = 1
1289
1290
1291 @dataclass
1292 class BracketTracker:
1293     """Keeps track of brackets on a line."""
1294
1295     depth: int = 0
1296     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1297     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1298     previous: Optional[Leaf] = None
1299     _for_loop_depths: List[int] = field(default_factory=list)
1300     _lambda_argument_depths: List[int] = field(default_factory=list)
1301     invisible: List[Leaf] = field(default_factory=list)
1302
1303     def mark(self, leaf: Leaf) -> None:
1304         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1305
1306         All leaves receive an int `bracket_depth` field that stores how deep
1307         within brackets a given leaf is. 0 means there are no enclosing brackets
1308         that started on this line.
1309
1310         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1311         field that it forms a pair with. This is a one-directional link to
1312         avoid reference cycles.
1313
1314         If a leaf is a delimiter (a token on which Black can split the line if
1315         needed) and it's on depth 0, its `id()` is stored in the tracker's
1316         `delimiters` field.
1317         """
1318         if leaf.type == token.COMMENT:
1319             return
1320
1321         self.maybe_decrement_after_for_loop_variable(leaf)
1322         self.maybe_decrement_after_lambda_arguments(leaf)
1323         if leaf.type in CLOSING_BRACKETS:
1324             self.depth -= 1
1325             try:
1326                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1327             except KeyError as e:
1328                 raise BracketMatchError(
1329                     "Unable to match a closing bracket to the following opening"
1330                     f" bracket: {leaf}"
1331                 ) from e
1332             leaf.opening_bracket = opening_bracket
1333             if not leaf.value:
1334                 self.invisible.append(leaf)
1335         leaf.bracket_depth = self.depth
1336         if self.depth == 0:
1337             delim = is_split_before_delimiter(leaf, self.previous)
1338             if delim and self.previous is not None:
1339                 self.delimiters[id(self.previous)] = delim
1340             else:
1341                 delim = is_split_after_delimiter(leaf, self.previous)
1342                 if delim:
1343                     self.delimiters[id(leaf)] = delim
1344         if leaf.type in OPENING_BRACKETS:
1345             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1346             self.depth += 1
1347             if not leaf.value:
1348                 self.invisible.append(leaf)
1349         self.previous = leaf
1350         self.maybe_increment_lambda_arguments(leaf)
1351         self.maybe_increment_for_loop_variable(leaf)
1352
1353     def any_open_brackets(self) -> bool:
1354         """Return True if there is an yet unmatched open bracket on the line."""
1355         return bool(self.bracket_match)
1356
1357     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1358         """Return the highest priority of a delimiter found on the line.
1359
1360         Values are consistent with what `is_split_*_delimiter()` return.
1361         Raises ValueError on no delimiters.
1362         """
1363         return max(v for k, v in self.delimiters.items() if k not in exclude)
1364
1365     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1366         """Return the number of delimiters with the given `priority`.
1367
1368         If no `priority` is passed, defaults to max priority on the line.
1369         """
1370         if not self.delimiters:
1371             return 0
1372
1373         priority = priority or self.max_delimiter_priority()
1374         return sum(1 for p in self.delimiters.values() if p == priority)
1375
1376     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1377         """In a for loop, or comprehension, the variables are often unpacks.
1378
1379         To avoid splitting on the comma in this situation, increase the depth of
1380         tokens between `for` and `in`.
1381         """
1382         if leaf.type == token.NAME and leaf.value == "for":
1383             self.depth += 1
1384             self._for_loop_depths.append(self.depth)
1385             return True
1386
1387         return False
1388
1389     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1390         """See `maybe_increment_for_loop_variable` above for explanation."""
1391         if (
1392             self._for_loop_depths
1393             and self._for_loop_depths[-1] == self.depth
1394             and leaf.type == token.NAME
1395             and leaf.value == "in"
1396         ):
1397             self.depth -= 1
1398             self._for_loop_depths.pop()
1399             return True
1400
1401         return False
1402
1403     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1404         """In a lambda expression, there might be more than one argument.
1405
1406         To avoid splitting on the comma in this situation, increase the depth of
1407         tokens between `lambda` and `:`.
1408         """
1409         if leaf.type == token.NAME and leaf.value == "lambda":
1410             self.depth += 1
1411             self._lambda_argument_depths.append(self.depth)
1412             return True
1413
1414         return False
1415
1416     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1417         """See `maybe_increment_lambda_arguments` above for explanation."""
1418         if (
1419             self._lambda_argument_depths
1420             and self._lambda_argument_depths[-1] == self.depth
1421             and leaf.type == token.COLON
1422         ):
1423             self.depth -= 1
1424             self._lambda_argument_depths.pop()
1425             return True
1426
1427         return False
1428
1429     def get_open_lsqb(self) -> Optional[Leaf]:
1430         """Return the most recent opening square bracket (if any)."""
1431         return self.bracket_match.get((self.depth - 1, token.RSQB))
1432
1433
1434 @dataclass
1435 class Line:
1436     """Holds leaves and comments. Can be printed with `str(line)`."""
1437
1438     depth: int = 0
1439     leaves: List[Leaf] = field(default_factory=list)
1440     # keys ordered like `leaves`
1441     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1442     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1443     inside_brackets: bool = False
1444     should_explode: bool = False
1445
1446     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1447         """Add a new `leaf` to the end of the line.
1448
1449         Unless `preformatted` is True, the `leaf` will receive a new consistent
1450         whitespace prefix and metadata applied by :class:`BracketTracker`.
1451         Trailing commas are maybe removed, unpacked for loop variables are
1452         demoted from being delimiters.
1453
1454         Inline comments are put aside.
1455         """
1456         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1457         if not has_value:
1458             return
1459
1460         if token.COLON == leaf.type and self.is_class_paren_empty:
1461             del self.leaves[-2:]
1462         if self.leaves and not preformatted:
1463             # Note: at this point leaf.prefix should be empty except for
1464             # imports, for which we only preserve newlines.
1465             leaf.prefix += whitespace(
1466                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1467             )
1468         if self.inside_brackets or not preformatted:
1469             self.bracket_tracker.mark(leaf)
1470             if self.maybe_should_explode(leaf):
1471                 self.should_explode = True
1472         if not self.append_comment(leaf):
1473             self.leaves.append(leaf)
1474
1475     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1476         """Like :func:`append()` but disallow invalid standalone comment structure.
1477
1478         Raises ValueError when any `leaf` is appended after a standalone comment
1479         or when a standalone comment is not the first leaf on the line.
1480         """
1481         if self.bracket_tracker.depth == 0:
1482             if self.is_comment:
1483                 raise ValueError("cannot append to standalone comments")
1484
1485             if self.leaves and leaf.type == STANDALONE_COMMENT:
1486                 raise ValueError(
1487                     "cannot append standalone comments to a populated line"
1488                 )
1489
1490         self.append(leaf, preformatted=preformatted)
1491
1492     @property
1493     def is_comment(self) -> bool:
1494         """Is this line a standalone comment?"""
1495         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1496
1497     @property
1498     def is_decorator(self) -> bool:
1499         """Is this line a decorator?"""
1500         return bool(self) and self.leaves[0].type == token.AT
1501
1502     @property
1503     def is_import(self) -> bool:
1504         """Is this an import line?"""
1505         return bool(self) and is_import(self.leaves[0])
1506
1507     @property
1508     def is_class(self) -> bool:
1509         """Is this line a class definition?"""
1510         return (
1511             bool(self)
1512             and self.leaves[0].type == token.NAME
1513             and self.leaves[0].value == "class"
1514         )
1515
1516     @property
1517     def is_stub_class(self) -> bool:
1518         """Is this line a class definition with a body consisting only of "..."?"""
1519         return self.is_class and self.leaves[-3:] == [
1520             Leaf(token.DOT, ".") for _ in range(3)
1521         ]
1522
1523     @property
1524     def is_def(self) -> bool:
1525         """Is this a function definition? (Also returns True for async defs.)"""
1526         try:
1527             first_leaf = self.leaves[0]
1528         except IndexError:
1529             return False
1530
1531         try:
1532             second_leaf: Optional[Leaf] = self.leaves[1]
1533         except IndexError:
1534             second_leaf = None
1535         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1536             first_leaf.type == token.ASYNC
1537             and second_leaf is not None
1538             and second_leaf.type == token.NAME
1539             and second_leaf.value == "def"
1540         )
1541
1542     @property
1543     def is_class_paren_empty(self) -> bool:
1544         """Is this a class with no base classes but using parentheses?
1545
1546         Those are unnecessary and should be removed.
1547         """
1548         return (
1549             bool(self)
1550             and len(self.leaves) == 4
1551             and self.is_class
1552             and self.leaves[2].type == token.LPAR
1553             and self.leaves[2].value == "("
1554             and self.leaves[3].type == token.RPAR
1555             and self.leaves[3].value == ")"
1556         )
1557
1558     @property
1559     def is_triple_quoted_string(self) -> bool:
1560         """Is the line a triple quoted string?"""
1561         return (
1562             bool(self)
1563             and self.leaves[0].type == token.STRING
1564             and self.leaves[0].value.startswith(('"""', "'''"))
1565         )
1566
1567     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1568         """If so, needs to be split before emitting."""
1569         for leaf in self.leaves:
1570             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1571                 return True
1572
1573         return False
1574
1575     def contains_uncollapsable_type_comments(self) -> bool:
1576         ignored_ids = set()
1577         try:
1578             last_leaf = self.leaves[-1]
1579             ignored_ids.add(id(last_leaf))
1580             if last_leaf.type == token.COMMA or (
1581                 last_leaf.type == token.RPAR and not last_leaf.value
1582             ):
1583                 # When trailing commas or optional parens are inserted by Black for
1584                 # consistency, comments after the previous last element are not moved
1585                 # (they don't have to, rendering will still be correct).  So we ignore
1586                 # trailing commas and invisible.
1587                 last_leaf = self.leaves[-2]
1588                 ignored_ids.add(id(last_leaf))
1589         except IndexError:
1590             return False
1591
1592         # A type comment is uncollapsable if it is attached to a leaf
1593         # that isn't at the end of the line (since that could cause it
1594         # to get associated to a different argument) or if there are
1595         # comments before it (since that could cause it to get hidden
1596         # behind a comment.
1597         comment_seen = False
1598         for leaf_id, comments in self.comments.items():
1599             for comment in comments:
1600                 if is_type_comment(comment):
1601                     if comment_seen or (
1602                         not is_type_comment(comment, " ignore")
1603                         and leaf_id not in ignored_ids
1604                     ):
1605                         return True
1606
1607                 comment_seen = True
1608
1609         return False
1610
1611     def contains_unsplittable_type_ignore(self) -> bool:
1612         if not self.leaves:
1613             return False
1614
1615         # If a 'type: ignore' is attached to the end of a line, we
1616         # can't split the line, because we can't know which of the
1617         # subexpressions the ignore was meant to apply to.
1618         #
1619         # We only want this to apply to actual physical lines from the
1620         # original source, though: we don't want the presence of a
1621         # 'type: ignore' at the end of a multiline expression to
1622         # justify pushing it all onto one line. Thus we
1623         # (unfortunately) need to check the actual source lines and
1624         # only report an unsplittable 'type: ignore' if this line was
1625         # one line in the original code.
1626
1627         # Grab the first and last line numbers, skipping generated leaves
1628         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1629         last_line = next(
1630             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1631         )
1632
1633         if first_line == last_line:
1634             # We look at the last two leaves since a comma or an
1635             # invisible paren could have been added at the end of the
1636             # line.
1637             for node in self.leaves[-2:]:
1638                 for comment in self.comments.get(id(node), []):
1639                     if is_type_comment(comment, " ignore"):
1640                         return True
1641
1642         return False
1643
1644     def contains_multiline_strings(self) -> bool:
1645         return any(is_multiline_string(leaf) for leaf in self.leaves)
1646
1647     def maybe_should_explode(self, closing: Leaf) -> bool:
1648         """Return True if this line should explode (always be split), that is when:
1649         - there's a trailing comma here; and
1650         - it's not a one-tuple.
1651         """
1652         if not (
1653             closing.type in CLOSING_BRACKETS
1654             and self.leaves
1655             and self.leaves[-1].type == token.COMMA
1656         ):
1657             return False
1658
1659         if closing.type in {token.RBRACE, token.RSQB}:
1660             return True
1661
1662         if self.is_import:
1663             return True
1664
1665         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1666             return True
1667
1668         return False
1669
1670     def append_comment(self, comment: Leaf) -> bool:
1671         """Add an inline or standalone comment to the line."""
1672         if (
1673             comment.type == STANDALONE_COMMENT
1674             and self.bracket_tracker.any_open_brackets()
1675         ):
1676             comment.prefix = ""
1677             return False
1678
1679         if comment.type != token.COMMENT:
1680             return False
1681
1682         if not self.leaves:
1683             comment.type = STANDALONE_COMMENT
1684             comment.prefix = ""
1685             return False
1686
1687         last_leaf = self.leaves[-1]
1688         if (
1689             last_leaf.type == token.RPAR
1690             and not last_leaf.value
1691             and last_leaf.parent
1692             and len(list(last_leaf.parent.leaves())) <= 3
1693             and not is_type_comment(comment)
1694         ):
1695             # Comments on an optional parens wrapping a single leaf should belong to
1696             # the wrapped node except if it's a type comment. Pinning the comment like
1697             # this avoids unstable formatting caused by comment migration.
1698             if len(self.leaves) < 2:
1699                 comment.type = STANDALONE_COMMENT
1700                 comment.prefix = ""
1701                 return False
1702
1703             last_leaf = self.leaves[-2]
1704         self.comments.setdefault(id(last_leaf), []).append(comment)
1705         return True
1706
1707     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1708         """Generate comments that should appear directly after `leaf`."""
1709         return self.comments.get(id(leaf), [])
1710
1711     def remove_trailing_comma(self) -> None:
1712         """Remove the trailing comma and moves the comments attached to it."""
1713         trailing_comma = self.leaves.pop()
1714         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1715         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1716             trailing_comma_comments
1717         )
1718
1719     def is_complex_subscript(self, leaf: Leaf) -> bool:
1720         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1721         open_lsqb = self.bracket_tracker.get_open_lsqb()
1722         if open_lsqb is None:
1723             return False
1724
1725         subscript_start = open_lsqb.next_sibling
1726
1727         if isinstance(subscript_start, Node):
1728             if subscript_start.type == syms.listmaker:
1729                 return False
1730
1731             if subscript_start.type == syms.subscriptlist:
1732                 subscript_start = child_towards(subscript_start, leaf)
1733         return subscript_start is not None and any(
1734             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1735         )
1736
1737     def clone(self) -> "Line":
1738         return Line(
1739             depth=self.depth,
1740             inside_brackets=self.inside_brackets,
1741             should_explode=self.should_explode,
1742         )
1743
1744     def __str__(self) -> str:
1745         """Render the line."""
1746         if not self:
1747             return "\n"
1748
1749         indent = "    " * self.depth
1750         leaves = iter(self.leaves)
1751         first = next(leaves)
1752         res = f"{first.prefix}{indent}{first.value}"
1753         for leaf in leaves:
1754             res += str(leaf)
1755         for comment in itertools.chain.from_iterable(self.comments.values()):
1756             res += str(comment)
1757
1758         return res + "\n"
1759
1760     def __bool__(self) -> bool:
1761         """Return True if the line has leaves or comments."""
1762         return bool(self.leaves or self.comments)
1763
1764
1765 @dataclass
1766 class EmptyLineTracker:
1767     """Provides a stateful method that returns the number of potential extra
1768     empty lines needed before and after the currently processed line.
1769
1770     Note: this tracker works on lines that haven't been split yet.  It assumes
1771     the prefix of the first leaf consists of optional newlines.  Those newlines
1772     are consumed by `maybe_empty_lines()` and included in the computation.
1773     """
1774
1775     is_pyi: bool = False
1776     previous_line: Optional[Line] = None
1777     previous_after: int = 0
1778     previous_defs: List[int] = field(default_factory=list)
1779
1780     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1781         """Return the number of extra empty lines before and after the `current_line`.
1782
1783         This is for separating `def`, `async def` and `class` with extra empty
1784         lines (two on module-level).
1785         """
1786         before, after = self._maybe_empty_lines(current_line)
1787         before = (
1788             # Black should not insert empty lines at the beginning
1789             # of the file
1790             0
1791             if self.previous_line is None
1792             else before - self.previous_after
1793         )
1794         self.previous_after = after
1795         self.previous_line = current_line
1796         return before, after
1797
1798     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1799         max_allowed = 1
1800         if current_line.depth == 0:
1801             max_allowed = 1 if self.is_pyi else 2
1802         if current_line.leaves:
1803             # Consume the first leaf's extra newlines.
1804             first_leaf = current_line.leaves[0]
1805             before = first_leaf.prefix.count("\n")
1806             before = min(before, max_allowed)
1807             first_leaf.prefix = ""
1808         else:
1809             before = 0
1810         depth = current_line.depth
1811         while self.previous_defs and self.previous_defs[-1] >= depth:
1812             self.previous_defs.pop()
1813             if self.is_pyi:
1814                 before = 0 if depth else 1
1815             else:
1816                 before = 1 if depth else 2
1817         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1818             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1819
1820         if (
1821             self.previous_line
1822             and self.previous_line.is_import
1823             and not current_line.is_import
1824             and depth == self.previous_line.depth
1825         ):
1826             return (before or 1), 0
1827
1828         if (
1829             self.previous_line
1830             and self.previous_line.is_class
1831             and current_line.is_triple_quoted_string
1832         ):
1833             return before, 1
1834
1835         return before, 0
1836
1837     def _maybe_empty_lines_for_class_or_def(
1838         self, current_line: Line, before: int
1839     ) -> Tuple[int, int]:
1840         if not current_line.is_decorator:
1841             self.previous_defs.append(current_line.depth)
1842         if self.previous_line is None:
1843             # Don't insert empty lines before the first line in the file.
1844             return 0, 0
1845
1846         if self.previous_line.is_decorator:
1847             if self.is_pyi and current_line.is_stub_class:
1848                 # Insert an empty line after a decorated stub class
1849                 return 0, 1
1850
1851             return 0, 0
1852
1853         if self.previous_line.depth < current_line.depth and (
1854             self.previous_line.is_class or self.previous_line.is_def
1855         ):
1856             return 0, 0
1857
1858         if (
1859             self.previous_line.is_comment
1860             and self.previous_line.depth == current_line.depth
1861             and before == 0
1862         ):
1863             return 0, 0
1864
1865         if self.is_pyi:
1866             if self.previous_line.depth > current_line.depth:
1867                 newlines = 1
1868             elif current_line.is_class or self.previous_line.is_class:
1869                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1870                     # No blank line between classes with an empty body
1871                     newlines = 0
1872                 else:
1873                     newlines = 1
1874             elif (
1875                 current_line.is_def or current_line.is_decorator
1876             ) and not self.previous_line.is_def:
1877                 # Blank line between a block of functions (maybe with preceding
1878                 # decorators) and a block of non-functions
1879                 newlines = 1
1880             else:
1881                 newlines = 0
1882         else:
1883             newlines = 2
1884         if current_line.depth and newlines:
1885             newlines -= 1
1886         return newlines, 0
1887
1888
1889 @dataclass
1890 class LineGenerator(Visitor[Line]):
1891     """Generates reformatted Line objects.  Empty lines are not emitted.
1892
1893     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1894     in ways that will no longer stringify to valid Python code on the tree.
1895     """
1896
1897     is_pyi: bool = False
1898     normalize_strings: bool = True
1899     current_line: Line = field(default_factory=Line)
1900     remove_u_prefix: bool = False
1901
1902     def line(self, indent: int = 0) -> Iterator[Line]:
1903         """Generate a line.
1904
1905         If the line is empty, only emit if it makes sense.
1906         If the line is too long, split it first and then generate.
1907
1908         If any lines were generated, set up a new current_line.
1909         """
1910         if not self.current_line:
1911             self.current_line.depth += indent
1912             return  # Line is empty, don't emit. Creating a new one unnecessary.
1913
1914         complete_line = self.current_line
1915         self.current_line = Line(depth=complete_line.depth + indent)
1916         yield complete_line
1917
1918     def visit_default(self, node: LN) -> Iterator[Line]:
1919         """Default `visit_*()` implementation. Recurses to children of `node`."""
1920         if isinstance(node, Leaf):
1921             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1922             for comment in generate_comments(node):
1923                 if any_open_brackets:
1924                     # any comment within brackets is subject to splitting
1925                     self.current_line.append(comment)
1926                 elif comment.type == token.COMMENT:
1927                     # regular trailing comment
1928                     self.current_line.append(comment)
1929                     yield from self.line()
1930
1931                 else:
1932                     # regular standalone comment
1933                     yield from self.line()
1934
1935                     self.current_line.append(comment)
1936                     yield from self.line()
1937
1938             normalize_prefix(node, inside_brackets=any_open_brackets)
1939             if self.normalize_strings and node.type == token.STRING:
1940                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1941                 normalize_string_quotes(node)
1942             if node.type == token.NUMBER:
1943                 normalize_numeric_literal(node)
1944             if node.type not in WHITESPACE:
1945                 self.current_line.append(node)
1946         yield from super().visit_default(node)
1947
1948     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1949         """Increase indentation level, maybe yield a line."""
1950         # In blib2to3 INDENT never holds comments.
1951         yield from self.line(+1)
1952         yield from self.visit_default(node)
1953
1954     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1955         """Decrease indentation level, maybe yield a line."""
1956         # The current line might still wait for trailing comments.  At DEDENT time
1957         # there won't be any (they would be prefixes on the preceding NEWLINE).
1958         # Emit the line then.
1959         yield from self.line()
1960
1961         # While DEDENT has no value, its prefix may contain standalone comments
1962         # that belong to the current indentation level.  Get 'em.
1963         yield from self.visit_default(node)
1964
1965         # Finally, emit the dedent.
1966         yield from self.line(-1)
1967
1968     def visit_stmt(
1969         self, node: Node, keywords: Set[str], parens: Set[str]
1970     ) -> Iterator[Line]:
1971         """Visit a statement.
1972
1973         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1974         `def`, `with`, `class`, `assert` and assignments.
1975
1976         The relevant Python language `keywords` for a given statement will be
1977         NAME leaves within it. This methods puts those on a separate line.
1978
1979         `parens` holds a set of string leaf values immediately after which
1980         invisible parens should be put.
1981         """
1982         normalize_invisible_parens(node, parens_after=parens)
1983         for child in node.children:
1984             if child.type == token.NAME and child.value in keywords:  # type: ignore
1985                 yield from self.line()
1986
1987             yield from self.visit(child)
1988
1989     def visit_suite(self, node: Node) -> Iterator[Line]:
1990         """Visit a suite."""
1991         if self.is_pyi and is_stub_suite(node):
1992             yield from self.visit(node.children[2])
1993         else:
1994             yield from self.visit_default(node)
1995
1996     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1997         """Visit a statement without nested statements."""
1998         is_suite_like = node.parent and node.parent.type in STATEMENT
1999         if is_suite_like:
2000             if self.is_pyi and is_stub_body(node):
2001                 yield from self.visit_default(node)
2002             else:
2003                 yield from self.line(+1)
2004                 yield from self.visit_default(node)
2005                 yield from self.line(-1)
2006
2007         else:
2008             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2009                 yield from self.line()
2010             yield from self.visit_default(node)
2011
2012     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2013         """Visit `async def`, `async for`, `async with`."""
2014         yield from self.line()
2015
2016         children = iter(node.children)
2017         for child in children:
2018             yield from self.visit(child)
2019
2020             if child.type == token.ASYNC:
2021                 break
2022
2023         internal_stmt = next(children)
2024         for child in internal_stmt.children:
2025             yield from self.visit(child)
2026
2027     def visit_decorators(self, node: Node) -> Iterator[Line]:
2028         """Visit decorators."""
2029         for child in node.children:
2030             yield from self.line()
2031             yield from self.visit(child)
2032
2033     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2034         """Remove a semicolon and put the other statement on a separate line."""
2035         yield from self.line()
2036
2037     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2038         """End of file. Process outstanding comments and end with a newline."""
2039         yield from self.visit_default(leaf)
2040         yield from self.line()
2041
2042     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2043         if not self.current_line.bracket_tracker.any_open_brackets():
2044             yield from self.line()
2045         yield from self.visit_default(leaf)
2046
2047     def visit_factor(self, node: Node) -> Iterator[Line]:
2048         """Force parentheses between a unary op and a binary power:
2049
2050         -2 ** 8 -> -(2 ** 8)
2051         """
2052         _operator, operand = node.children
2053         if (
2054             operand.type == syms.power
2055             and len(operand.children) == 3
2056             and operand.children[1].type == token.DOUBLESTAR
2057         ):
2058             lpar = Leaf(token.LPAR, "(")
2059             rpar = Leaf(token.RPAR, ")")
2060             index = operand.remove() or 0
2061             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2062         yield from self.visit_default(node)
2063
2064     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2065         if is_docstring(leaf) and "\\\n" not in leaf.value:
2066             # We're ignoring docstrings with backslash newline escapes because changing
2067             # indentation of those changes the AST representation of the code.
2068             prefix = get_string_prefix(leaf.value)
2069             lead_len = len(prefix) + 3
2070             tail_len = -3
2071             indent = " " * 4 * self.current_line.depth
2072             docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
2073             if docstring:
2074                 if leaf.value[lead_len - 1] == docstring[0]:
2075                     docstring = " " + docstring
2076                 if leaf.value[tail_len + 1] == docstring[-1]:
2077                     docstring = docstring + " "
2078             leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
2079
2080         yield from self.visit_default(leaf)
2081
2082     def __post_init__(self) -> None:
2083         """You are in a twisty little maze of passages."""
2084         v = self.visit_stmt
2085         Ø: Set[str] = set()
2086         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2087         self.visit_if_stmt = partial(
2088             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2089         )
2090         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2091         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2092         self.visit_try_stmt = partial(
2093             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2094         )
2095         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2096         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2097         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2098         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2099         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2100         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2101         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2102         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2103         self.visit_async_funcdef = self.visit_async_stmt
2104         self.visit_decorated = self.visit_decorators
2105
2106
2107 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2108 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2109 OPENING_BRACKETS = set(BRACKET.keys())
2110 CLOSING_BRACKETS = set(BRACKET.values())
2111 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2112 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2113
2114
2115 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2116     """Return whitespace prefix if needed for the given `leaf`.
2117
2118     `complex_subscript` signals whether the given leaf is part of a subscription
2119     which has non-trivial arguments, like arithmetic expressions or function calls.
2120     """
2121     NO = ""
2122     SPACE = " "
2123     DOUBLESPACE = "  "
2124     t = leaf.type
2125     p = leaf.parent
2126     v = leaf.value
2127     if t in ALWAYS_NO_SPACE:
2128         return NO
2129
2130     if t == token.COMMENT:
2131         return DOUBLESPACE
2132
2133     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2134     if t == token.COLON and p.type not in {
2135         syms.subscript,
2136         syms.subscriptlist,
2137         syms.sliceop,
2138     }:
2139         return NO
2140
2141     prev = leaf.prev_sibling
2142     if not prev:
2143         prevp = preceding_leaf(p)
2144         if not prevp or prevp.type in OPENING_BRACKETS:
2145             return NO
2146
2147         if t == token.COLON:
2148             if prevp.type == token.COLON:
2149                 return NO
2150
2151             elif prevp.type != token.COMMA and not complex_subscript:
2152                 return NO
2153
2154             return SPACE
2155
2156         if prevp.type == token.EQUAL:
2157             if prevp.parent:
2158                 if prevp.parent.type in {
2159                     syms.arglist,
2160                     syms.argument,
2161                     syms.parameters,
2162                     syms.varargslist,
2163                 }:
2164                     return NO
2165
2166                 elif prevp.parent.type == syms.typedargslist:
2167                     # A bit hacky: if the equal sign has whitespace, it means we
2168                     # previously found it's a typed argument.  So, we're using
2169                     # that, too.
2170                     return prevp.prefix
2171
2172         elif prevp.type in VARARGS_SPECIALS:
2173             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2174                 return NO
2175
2176         elif prevp.type == token.COLON:
2177             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2178                 return SPACE if complex_subscript else NO
2179
2180         elif (
2181             prevp.parent
2182             and prevp.parent.type == syms.factor
2183             and prevp.type in MATH_OPERATORS
2184         ):
2185             return NO
2186
2187         elif (
2188             prevp.type == token.RIGHTSHIFT
2189             and prevp.parent
2190             and prevp.parent.type == syms.shift_expr
2191             and prevp.prev_sibling
2192             and prevp.prev_sibling.type == token.NAME
2193             and prevp.prev_sibling.value == "print"  # type: ignore
2194         ):
2195             # Python 2 print chevron
2196             return NO
2197         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2198             # no space in decorators
2199             return NO
2200
2201     elif prev.type in OPENING_BRACKETS:
2202         return NO
2203
2204     if p.type in {syms.parameters, syms.arglist}:
2205         # untyped function signatures or calls
2206         if not prev or prev.type != token.COMMA:
2207             return NO
2208
2209     elif p.type == syms.varargslist:
2210         # lambdas
2211         if prev and prev.type != token.COMMA:
2212             return NO
2213
2214     elif p.type == syms.typedargslist:
2215         # typed function signatures
2216         if not prev:
2217             return NO
2218
2219         if t == token.EQUAL:
2220             if prev.type != syms.tname:
2221                 return NO
2222
2223         elif prev.type == token.EQUAL:
2224             # A bit hacky: if the equal sign has whitespace, it means we
2225             # previously found it's a typed argument.  So, we're using that, too.
2226             return prev.prefix
2227
2228         elif prev.type != token.COMMA:
2229             return NO
2230
2231     elif p.type == syms.tname:
2232         # type names
2233         if not prev:
2234             prevp = preceding_leaf(p)
2235             if not prevp or prevp.type != token.COMMA:
2236                 return NO
2237
2238     elif p.type == syms.trailer:
2239         # attributes and calls
2240         if t == token.LPAR or t == token.RPAR:
2241             return NO
2242
2243         if not prev:
2244             if t == token.DOT:
2245                 prevp = preceding_leaf(p)
2246                 if not prevp or prevp.type != token.NUMBER:
2247                     return NO
2248
2249             elif t == token.LSQB:
2250                 return NO
2251
2252         elif prev.type != token.COMMA:
2253             return NO
2254
2255     elif p.type == syms.argument:
2256         # single argument
2257         if t == token.EQUAL:
2258             return NO
2259
2260         if not prev:
2261             prevp = preceding_leaf(p)
2262             if not prevp or prevp.type == token.LPAR:
2263                 return NO
2264
2265         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2266             return NO
2267
2268     elif p.type == syms.decorator:
2269         # decorators
2270         return NO
2271
2272     elif p.type == syms.dotted_name:
2273         if prev:
2274             return NO
2275
2276         prevp = preceding_leaf(p)
2277         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2278             return NO
2279
2280     elif p.type == syms.classdef:
2281         if t == token.LPAR:
2282             return NO
2283
2284         if prev and prev.type == token.LPAR:
2285             return NO
2286
2287     elif p.type in {syms.subscript, syms.sliceop}:
2288         # indexing
2289         if not prev:
2290             assert p.parent is not None, "subscripts are always parented"
2291             if p.parent.type == syms.subscriptlist:
2292                 return SPACE
2293
2294             return NO
2295
2296         elif not complex_subscript:
2297             return NO
2298
2299     elif p.type == syms.atom:
2300         if prev and t == token.DOT:
2301             # dots, but not the first one.
2302             return NO
2303
2304     elif p.type == syms.dictsetmaker:
2305         # dict unpacking
2306         if prev and prev.type == token.DOUBLESTAR:
2307             return NO
2308
2309     elif p.type in {syms.factor, syms.star_expr}:
2310         # unary ops
2311         if not prev:
2312             prevp = preceding_leaf(p)
2313             if not prevp or prevp.type in OPENING_BRACKETS:
2314                 return NO
2315
2316             prevp_parent = prevp.parent
2317             assert prevp_parent is not None
2318             if prevp.type == token.COLON and prevp_parent.type in {
2319                 syms.subscript,
2320                 syms.sliceop,
2321             }:
2322                 return NO
2323
2324             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2325                 return NO
2326
2327         elif t in {token.NAME, token.NUMBER, token.STRING}:
2328             return NO
2329
2330     elif p.type == syms.import_from:
2331         if t == token.DOT:
2332             if prev and prev.type == token.DOT:
2333                 return NO
2334
2335         elif t == token.NAME:
2336             if v == "import":
2337                 return SPACE
2338
2339             if prev and prev.type == token.DOT:
2340                 return NO
2341
2342     elif p.type == syms.sliceop:
2343         return NO
2344
2345     return SPACE
2346
2347
2348 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2349     """Return the first leaf that precedes `node`, if any."""
2350     while node:
2351         res = node.prev_sibling
2352         if res:
2353             if isinstance(res, Leaf):
2354                 return res
2355
2356             try:
2357                 return list(res.leaves())[-1]
2358
2359             except IndexError:
2360                 return None
2361
2362         node = node.parent
2363     return None
2364
2365
2366 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2367     """Return if the `node` and its previous siblings match types against the provided
2368     list of tokens; the provided `node`has its type matched against the last element in
2369     the list.  `None` can be used as the first element to declare that the start of the
2370     list is anchored at the start of its parent's children."""
2371     if not tokens:
2372         return True
2373     if tokens[-1] is None:
2374         return node is None
2375     if not node:
2376         return False
2377     if node.type != tokens[-1]:
2378         return False
2379     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2380
2381
2382 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2383     """Return the child of `ancestor` that contains `descendant`."""
2384     node: Optional[LN] = descendant
2385     while node and node.parent != ancestor:
2386         node = node.parent
2387     return node
2388
2389
2390 def container_of(leaf: Leaf) -> LN:
2391     """Return `leaf` or one of its ancestors that is the topmost container of it.
2392
2393     By "container" we mean a node where `leaf` is the very first child.
2394     """
2395     same_prefix = leaf.prefix
2396     container: LN = leaf
2397     while container:
2398         parent = container.parent
2399         if parent is None:
2400             break
2401
2402         if parent.children[0].prefix != same_prefix:
2403             break
2404
2405         if parent.type == syms.file_input:
2406             break
2407
2408         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2409             break
2410
2411         container = parent
2412     return container
2413
2414
2415 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2416     """Return the priority of the `leaf` delimiter, given a line break after it.
2417
2418     The delimiter priorities returned here are from those delimiters that would
2419     cause a line break after themselves.
2420
2421     Higher numbers are higher priority.
2422     """
2423     if leaf.type == token.COMMA:
2424         return COMMA_PRIORITY
2425
2426     return 0
2427
2428
2429 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2430     """Return the priority of the `leaf` delimiter, given a line break before it.
2431
2432     The delimiter priorities returned here are from those delimiters that would
2433     cause a line break before themselves.
2434
2435     Higher numbers are higher priority.
2436     """
2437     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2438         # * and ** might also be MATH_OPERATORS but in this case they are not.
2439         # Don't treat them as a delimiter.
2440         return 0
2441
2442     if (
2443         leaf.type == token.DOT
2444         and leaf.parent
2445         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2446         and (previous is None or previous.type in CLOSING_BRACKETS)
2447     ):
2448         return DOT_PRIORITY
2449
2450     if (
2451         leaf.type in MATH_OPERATORS
2452         and leaf.parent
2453         and leaf.parent.type not in {syms.factor, syms.star_expr}
2454     ):
2455         return MATH_PRIORITIES[leaf.type]
2456
2457     if leaf.type in COMPARATORS:
2458         return COMPARATOR_PRIORITY
2459
2460     if (
2461         leaf.type == token.STRING
2462         and previous is not None
2463         and previous.type == token.STRING
2464     ):
2465         return STRING_PRIORITY
2466
2467     if leaf.type not in {token.NAME, token.ASYNC}:
2468         return 0
2469
2470     if (
2471         leaf.value == "for"
2472         and leaf.parent
2473         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2474         or leaf.type == token.ASYNC
2475     ):
2476         if (
2477             not isinstance(leaf.prev_sibling, Leaf)
2478             or leaf.prev_sibling.value != "async"
2479         ):
2480             return COMPREHENSION_PRIORITY
2481
2482     if (
2483         leaf.value == "if"
2484         and leaf.parent
2485         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2486     ):
2487         return COMPREHENSION_PRIORITY
2488
2489     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2490         return TERNARY_PRIORITY
2491
2492     if leaf.value == "is":
2493         return COMPARATOR_PRIORITY
2494
2495     if (
2496         leaf.value == "in"
2497         and leaf.parent
2498         and leaf.parent.type in {syms.comp_op, syms.comparison}
2499         and not (
2500             previous is not None
2501             and previous.type == token.NAME
2502             and previous.value == "not"
2503         )
2504     ):
2505         return COMPARATOR_PRIORITY
2506
2507     if (
2508         leaf.value == "not"
2509         and leaf.parent
2510         and leaf.parent.type == syms.comp_op
2511         and not (
2512             previous is not None
2513             and previous.type == token.NAME
2514             and previous.value == "is"
2515         )
2516     ):
2517         return COMPARATOR_PRIORITY
2518
2519     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2520         return LOGIC_PRIORITY
2521
2522     return 0
2523
2524
2525 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2526 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2527
2528
2529 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2530     """Clean the prefix of the `leaf` and generate comments from it, if any.
2531
2532     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2533     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2534     move because it does away with modifying the grammar to include all the
2535     possible places in which comments can be placed.
2536
2537     The sad consequence for us though is that comments don't "belong" anywhere.
2538     This is why this function generates simple parentless Leaf objects for
2539     comments.  We simply don't know what the correct parent should be.
2540
2541     No matter though, we can live without this.  We really only need to
2542     differentiate between inline and standalone comments.  The latter don't
2543     share the line with any code.
2544
2545     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2546     are emitted with a fake STANDALONE_COMMENT token identifier.
2547     """
2548     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2549         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2550
2551
2552 @dataclass
2553 class ProtoComment:
2554     """Describes a piece of syntax that is a comment.
2555
2556     It's not a :class:`blib2to3.pytree.Leaf` so that:
2557
2558     * it can be cached (`Leaf` objects should not be reused more than once as
2559       they store their lineno, column, prefix, and parent information);
2560     * `newlines` and `consumed` fields are kept separate from the `value`. This
2561       simplifies handling of special marker comments like ``# fmt: off/on``.
2562     """
2563
2564     type: int  # token.COMMENT or STANDALONE_COMMENT
2565     value: str  # content of the comment
2566     newlines: int  # how many newlines before the comment
2567     consumed: int  # how many characters of the original leaf's prefix did we consume
2568
2569
2570 @lru_cache(maxsize=4096)
2571 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2572     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2573     result: List[ProtoComment] = []
2574     if not prefix or "#" not in prefix:
2575         return result
2576
2577     consumed = 0
2578     nlines = 0
2579     ignored_lines = 0
2580     for index, line in enumerate(prefix.split("\n")):
2581         consumed += len(line) + 1  # adding the length of the split '\n'
2582         line = line.lstrip()
2583         if not line:
2584             nlines += 1
2585         if not line.startswith("#"):
2586             # Escaped newlines outside of a comment are not really newlines at
2587             # all. We treat a single-line comment following an escaped newline
2588             # as a simple trailing comment.
2589             if line.endswith("\\"):
2590                 ignored_lines += 1
2591             continue
2592
2593         if index == ignored_lines and not is_endmarker:
2594             comment_type = token.COMMENT  # simple trailing comment
2595         else:
2596             comment_type = STANDALONE_COMMENT
2597         comment = make_comment(line)
2598         result.append(
2599             ProtoComment(
2600                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2601             )
2602         )
2603         nlines = 0
2604     return result
2605
2606
2607 def make_comment(content: str) -> str:
2608     """Return a consistently formatted comment from the given `content` string.
2609
2610     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2611     space between the hash sign and the content.
2612
2613     If `content` didn't start with a hash sign, one is provided.
2614     """
2615     content = content.rstrip()
2616     if not content:
2617         return "#"
2618
2619     if content[0] == "#":
2620         content = content[1:]
2621     if content and content[0] not in " !:#'%":
2622         content = " " + content
2623     return "#" + content
2624
2625
2626 def transform_line(
2627     line: Line, mode: Mode, features: Collection[Feature] = ()
2628 ) -> Iterator[Line]:
2629     """Transform a `line`, potentially splitting it into many lines.
2630
2631     They should fit in the allotted `line_length` but might not be able to.
2632
2633     `features` are syntactical features that may be used in the output.
2634     """
2635     if line.is_comment:
2636         yield line
2637         return
2638
2639     line_str = line_to_string(line)
2640
2641     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2642         """Initialize StringTransformer"""
2643         return ST(mode.line_length, mode.string_normalization)
2644
2645     string_merge = init_st(StringMerger)
2646     string_paren_strip = init_st(StringParenStripper)
2647     string_split = init_st(StringSplitter)
2648     string_paren_wrap = init_st(StringParenWrapper)
2649
2650     transformers: List[Transformer]
2651     if (
2652         not line.contains_uncollapsable_type_comments()
2653         and not line.should_explode
2654         and (
2655             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2656             or line.contains_unsplittable_type_ignore()
2657         )
2658         and not (line.inside_brackets and line.contains_standalone_comments())
2659     ):
2660         # Only apply basic string preprocessing, since lines shouldn't be split here.
2661         if mode.experimental_string_processing:
2662             transformers = [string_merge, string_paren_strip]
2663         else:
2664             transformers = []
2665     elif line.is_def:
2666         transformers = [left_hand_split]
2667     else:
2668
2669         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2670             """Wraps calls to `right_hand_split`.
2671
2672             The calls increasingly `omit` right-hand trailers (bracket pairs with
2673             content), meaning the trailers get glued together to split on another
2674             bracket pair instead.
2675             """
2676             for omit in generate_trailers_to_omit(line, mode.line_length):
2677                 lines = list(
2678                     right_hand_split(line, mode.line_length, features, omit=omit)
2679                 )
2680                 # Note: this check is only able to figure out if the first line of the
2681                 # *current* transformation fits in the line length.  This is true only
2682                 # for simple cases.  All others require running more transforms via
2683                 # `transform_line()`.  This check doesn't know if those would succeed.
2684                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2685                     yield from lines
2686                     return
2687
2688             # All splits failed, best effort split with no omits.
2689             # This mostly happens to multiline strings that are by definition
2690             # reported as not fitting a single line, as well as lines that contain
2691             # trailing commas (those have to be exploded).
2692             yield from right_hand_split(
2693                 line, line_length=mode.line_length, features=features
2694             )
2695
2696         if mode.experimental_string_processing:
2697             if line.inside_brackets:
2698                 transformers = [
2699                     string_merge,
2700                     string_paren_strip,
2701                     string_split,
2702                     delimiter_split,
2703                     standalone_comment_split,
2704                     string_paren_wrap,
2705                     rhs,
2706                 ]
2707             else:
2708                 transformers = [
2709                     string_merge,
2710                     string_paren_strip,
2711                     string_split,
2712                     string_paren_wrap,
2713                     rhs,
2714                 ]
2715         else:
2716             if line.inside_brackets:
2717                 transformers = [delimiter_split, standalone_comment_split, rhs]
2718             else:
2719                 transformers = [rhs]
2720
2721     for transform in transformers:
2722         # We are accumulating lines in `result` because we might want to abort
2723         # mission and return the original line in the end, or attempt a different
2724         # split altogether.
2725         try:
2726             result = run_transformer(line, transform, mode, features, line_str=line_str)
2727         except CannotTransform:
2728             continue
2729         else:
2730             yield from result
2731             break
2732
2733     else:
2734         yield line
2735
2736
2737 @dataclass  # type: ignore
2738 class StringTransformer(ABC):
2739     """
2740     An implementation of the Transformer protocol that relies on its
2741     subclasses overriding the template methods `do_match(...)` and
2742     `do_transform(...)`.
2743
2744     This Transformer works exclusively on strings (for example, by merging
2745     or splitting them).
2746
2747     The following sections can be found among the docstrings of each concrete
2748     StringTransformer subclass.
2749
2750     Requirements:
2751         Which requirements must be met of the given Line for this
2752         StringTransformer to be applied?
2753
2754     Transformations:
2755         If the given Line meets all of the above requirements, which string
2756         transformations can you expect to be applied to it by this
2757         StringTransformer?
2758
2759     Collaborations:
2760         What contractual agreements does this StringTransformer have with other
2761         StringTransfomers? Such collaborations should be eliminated/minimized
2762         as much as possible.
2763     """
2764
2765     line_length: int
2766     normalize_strings: bool
2767     __name__ = "StringTransformer"
2768
2769     @abstractmethod
2770     def do_match(self, line: Line) -> TMatchResult:
2771         """
2772         Returns:
2773             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2774             string, if a match was able to be made.
2775                 OR
2776             * Err(CannotTransform), if a match was not able to be made.
2777         """
2778
2779     @abstractmethod
2780     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2781         """
2782         Yields:
2783             * Ok(new_line) where new_line is the new transformed line.
2784                 OR
2785             * Err(CannotTransform) if the transformation failed for some reason. The
2786             `do_match(...)` template method should usually be used to reject
2787             the form of the given Line, but in some cases it is difficult to
2788             know whether or not a Line meets the StringTransformer's
2789             requirements until the transformation is already midway.
2790
2791         Side Effects:
2792             This method should NOT mutate @line directly, but it MAY mutate the
2793             Line's underlying Node structure. (WARNING: If the underlying Node
2794             structure IS altered, then this method should NOT be allowed to
2795             yield an CannotTransform after that point.)
2796         """
2797
2798     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2799         """
2800         StringTransformer instances have a call signature that mirrors that of
2801         the Transformer type.
2802
2803         Raises:
2804             CannotTransform(...) if the concrete StringTransformer class is unable
2805             to transform @line.
2806         """
2807         # Optimization to avoid calling `self.do_match(...)` when the line does
2808         # not contain any string.
2809         if not any(leaf.type == token.STRING for leaf in line.leaves):
2810             raise CannotTransform("There are no strings in this line.")
2811
2812         match_result = self.do_match(line)
2813
2814         if isinstance(match_result, Err):
2815             cant_transform = match_result.err()
2816             raise CannotTransform(
2817                 f"The string transformer {self.__class__.__name__} does not recognize"
2818                 " this line as one that it can transform."
2819             ) from cant_transform
2820
2821         string_idx = match_result.ok()
2822
2823         for line_result in self.do_transform(line, string_idx):
2824             if isinstance(line_result, Err):
2825                 cant_transform = line_result.err()
2826                 raise CannotTransform(
2827                     "StringTransformer failed while attempting to transform string."
2828                 ) from cant_transform
2829             line = line_result.ok()
2830             yield line
2831
2832
2833 @dataclass
2834 class CustomSplit:
2835     """A custom (i.e. manual) string split.
2836
2837     A single CustomSplit instance represents a single substring.
2838
2839     Examples:
2840         Consider the following string:
2841         ```
2842         "Hi there friend."
2843         " This is a custom"
2844         f" string {split}."
2845         ```
2846
2847         This string will correspond to the following three CustomSplit instances:
2848         ```
2849         CustomSplit(False, 16)
2850         CustomSplit(False, 17)
2851         CustomSplit(True, 16)
2852         ```
2853     """
2854
2855     has_prefix: bool
2856     break_idx: int
2857
2858
2859 class CustomSplitMapMixin:
2860     """
2861     This mixin class is used to map merged strings to a sequence of
2862     CustomSplits, which will then be used to re-split the strings iff none of
2863     the resultant substrings go over the configured max line length.
2864     """
2865
2866     _Key = Tuple[StringID, str]
2867     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2868
2869     @staticmethod
2870     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2871         """
2872         Returns:
2873             A unique identifier that is used internally to map @string to a
2874             group of custom splits.
2875         """
2876         return (id(string), string)
2877
2878     def add_custom_splits(
2879         self, string: str, custom_splits: Iterable[CustomSplit]
2880     ) -> None:
2881         """Custom Split Map Setter Method
2882
2883         Side Effects:
2884             Adds a mapping from @string to the custom splits @custom_splits.
2885         """
2886         key = self._get_key(string)
2887         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2888
2889     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2890         """Custom Split Map Getter Method
2891
2892         Returns:
2893             * A list of the custom splits that are mapped to @string, if any
2894             exist.
2895                 OR
2896             * [], otherwise.
2897
2898         Side Effects:
2899             Deletes the mapping between @string and its associated custom
2900             splits (which are returned to the caller).
2901         """
2902         key = self._get_key(string)
2903
2904         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2905         del self._CUSTOM_SPLIT_MAP[key]
2906
2907         return list(custom_splits)
2908
2909     def has_custom_splits(self, string: str) -> bool:
2910         """
2911         Returns:
2912             True iff @string is associated with a set of custom splits.
2913         """
2914         key = self._get_key(string)
2915         return key in self._CUSTOM_SPLIT_MAP
2916
2917
2918 class StringMerger(CustomSplitMapMixin, StringTransformer):
2919     """StringTransformer that merges strings together.
2920
2921     Requirements:
2922         (A) The line contains adjacent strings such that ALL of the validation checks
2923         listed in StringMerger.__validate_msg(...)'s docstring pass.
2924             OR
2925         (B) The line contains a string which uses line continuation backslashes.
2926
2927     Transformations:
2928         Depending on which of the two requirements above where met, either:
2929
2930         (A) The string group associated with the target string is merged.
2931             OR
2932         (B) All line-continuation backslashes are removed from the target string.
2933
2934     Collaborations:
2935         StringMerger provides custom split information to StringSplitter.
2936     """
2937
2938     def do_match(self, line: Line) -> TMatchResult:
2939         LL = line.leaves
2940
2941         is_valid_index = is_valid_index_factory(LL)
2942
2943         for (i, leaf) in enumerate(LL):
2944             if (
2945                 leaf.type == token.STRING
2946                 and is_valid_index(i + 1)
2947                 and LL[i + 1].type == token.STRING
2948             ):
2949                 return Ok(i)
2950
2951             if leaf.type == token.STRING and "\\\n" in leaf.value:
2952                 return Ok(i)
2953
2954         return TErr("This line has no strings that need merging.")
2955
2956     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2957         new_line = line
2958         rblc_result = self.__remove_backslash_line_continuation_chars(
2959             new_line, string_idx
2960         )
2961         if isinstance(rblc_result, Ok):
2962             new_line = rblc_result.ok()
2963
2964         msg_result = self.__merge_string_group(new_line, string_idx)
2965         if isinstance(msg_result, Ok):
2966             new_line = msg_result.ok()
2967
2968         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
2969             msg_cant_transform = msg_result.err()
2970             rblc_cant_transform = rblc_result.err()
2971             cant_transform = CannotTransform(
2972                 "StringMerger failed to merge any strings in this line."
2973             )
2974
2975             # Chain the errors together using `__cause__`.
2976             msg_cant_transform.__cause__ = rblc_cant_transform
2977             cant_transform.__cause__ = msg_cant_transform
2978
2979             yield Err(cant_transform)
2980         else:
2981             yield Ok(new_line)
2982
2983     @staticmethod
2984     def __remove_backslash_line_continuation_chars(
2985         line: Line, string_idx: int
2986     ) -> TResult[Line]:
2987         """
2988         Merge strings that were split across multiple lines using
2989         line-continuation backslashes.
2990
2991         Returns:
2992             Ok(new_line), if @line contains backslash line-continuation
2993             characters.
2994                 OR
2995             Err(CannotTransform), otherwise.
2996         """
2997         LL = line.leaves
2998
2999         string_leaf = LL[string_idx]
3000         if not (
3001             string_leaf.type == token.STRING
3002             and "\\\n" in string_leaf.value
3003             and not has_triple_quotes(string_leaf.value)
3004         ):
3005             return TErr(
3006                 f"String leaf {string_leaf} does not contain any backslash line"
3007                 " continuation characters."
3008             )
3009
3010         new_line = line.clone()
3011         new_line.comments = line.comments.copy()
3012         append_leaves(new_line, line, LL)
3013
3014         new_string_leaf = new_line.leaves[string_idx]
3015         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3016
3017         return Ok(new_line)
3018
3019     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3020         """
3021         Merges string group (i.e. set of adjacent strings) where the first
3022         string in the group is `line.leaves[string_idx]`.
3023
3024         Returns:
3025             Ok(new_line), if ALL of the validation checks found in
3026             __validate_msg(...) pass.
3027                 OR
3028             Err(CannotTransform), otherwise.
3029         """
3030         LL = line.leaves
3031
3032         is_valid_index = is_valid_index_factory(LL)
3033
3034         vresult = self.__validate_msg(line, string_idx)
3035         if isinstance(vresult, Err):
3036             return vresult
3037
3038         # If the string group is wrapped inside an Atom node, we must make sure
3039         # to later replace that Atom with our new (merged) string leaf.
3040         atom_node = LL[string_idx].parent
3041
3042         # We will place BREAK_MARK in between every two substrings that we
3043         # merge. We will then later go through our final result and use the
3044         # various instances of BREAK_MARK we find to add the right values to
3045         # the custom split map.
3046         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3047
3048         QUOTE = LL[string_idx].value[-1]
3049
3050         def make_naked(string: str, string_prefix: str) -> str:
3051             """Strip @string (i.e. make it a "naked" string)
3052
3053             Pre-conditions:
3054                 * assert_is_leaf_string(@string)
3055
3056             Returns:
3057                 A string that is identical to @string except that
3058                 @string_prefix has been stripped, the surrounding QUOTE
3059                 characters have been removed, and any remaining QUOTE
3060                 characters have been escaped.
3061             """
3062             assert_is_leaf_string(string)
3063
3064             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3065             naked_string = string[len(string_prefix) + 1 : -1]
3066             naked_string = re.sub(
3067                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3068             )
3069             return naked_string
3070
3071         # Holds the CustomSplit objects that will later be added to the custom
3072         # split map.
3073         custom_splits = []
3074
3075         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3076         prefix_tracker = []
3077
3078         # Sets the 'prefix' variable. This is the prefix that the final merged
3079         # string will have.
3080         next_str_idx = string_idx
3081         prefix = ""
3082         while (
3083             not prefix
3084             and is_valid_index(next_str_idx)
3085             and LL[next_str_idx].type == token.STRING
3086         ):
3087             prefix = get_string_prefix(LL[next_str_idx].value)
3088             next_str_idx += 1
3089
3090         # The next loop merges the string group. The final string will be
3091         # contained in 'S'.
3092         #
3093         # The following convenience variables are used:
3094         #
3095         #   S: string
3096         #   NS: naked string
3097         #   SS: next string
3098         #   NSS: naked next string
3099         S = ""
3100         NS = ""
3101         num_of_strings = 0
3102         next_str_idx = string_idx
3103         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3104             num_of_strings += 1
3105
3106             SS = LL[next_str_idx].value
3107             next_prefix = get_string_prefix(SS)
3108
3109             # If this is an f-string group but this substring is not prefixed
3110             # with 'f'...
3111             if "f" in prefix and "f" not in next_prefix:
3112                 # Then we must escape any braces contained in this substring.
3113                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3114
3115             NSS = make_naked(SS, next_prefix)
3116
3117             has_prefix = bool(next_prefix)
3118             prefix_tracker.append(has_prefix)
3119
3120             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3121             NS = make_naked(S, prefix)
3122
3123             next_str_idx += 1
3124
3125         S_leaf = Leaf(token.STRING, S)
3126         if self.normalize_strings:
3127             normalize_string_quotes(S_leaf)
3128
3129         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3130         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3131         for has_prefix in prefix_tracker:
3132             mark_idx = temp_string.find(BREAK_MARK)
3133             assert (
3134                 mark_idx >= 0
3135             ), "Logic error while filling the custom string breakpoint cache."
3136
3137             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3138             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3139             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3140
3141         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3142
3143         if atom_node is not None:
3144             replace_child(atom_node, string_leaf)
3145
3146         # Build the final line ('new_line') that this method will later return.
3147         new_line = line.clone()
3148         for (i, leaf) in enumerate(LL):
3149             if i == string_idx:
3150                 new_line.append(string_leaf)
3151
3152             if string_idx <= i < string_idx + num_of_strings:
3153                 for comment_leaf in line.comments_after(LL[i]):
3154                     new_line.append(comment_leaf, preformatted=True)
3155                 continue
3156
3157             append_leaves(new_line, line, [leaf])
3158
3159         self.add_custom_splits(string_leaf.value, custom_splits)
3160         return Ok(new_line)
3161
3162     @staticmethod
3163     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3164         """Validate (M)erge (S)tring (G)roup
3165
3166         Transform-time string validation logic for __merge_string_group(...).
3167
3168         Returns:
3169             * Ok(None), if ALL validation checks (listed below) pass.
3170                 OR
3171             * Err(CannotTransform), if any of the following are true:
3172                 - The target string group does not contain ANY stand-alone comments.
3173                 - The target string is not in a string group (i.e. it has no
3174                   adjacent strings).
3175                 - The string group has more than one inline comment.
3176                 - The string group has an inline comment that appears to be a pragma.
3177                 - The set of all string prefixes in the string group is of
3178                   length greater than one and is not equal to {"", "f"}.
3179                 - The string group consists of raw strings.
3180         """
3181         # We first check for "inner" stand-alone comments (i.e. stand-alone
3182         # comments that have a string leaf before them AND after them).
3183         for inc in [1, -1]:
3184             i = string_idx
3185             found_sa_comment = False
3186             is_valid_index = is_valid_index_factory(line.leaves)
3187             while is_valid_index(i) and line.leaves[i].type in [
3188                 token.STRING,
3189                 STANDALONE_COMMENT,
3190             ]:
3191                 if line.leaves[i].type == STANDALONE_COMMENT:
3192                     found_sa_comment = True
3193                 elif found_sa_comment:
3194                     return TErr(
3195                         "StringMerger does NOT merge string groups which contain "
3196                         "stand-alone comments."
3197                     )
3198
3199                 i += inc
3200
3201         num_of_inline_string_comments = 0
3202         set_of_prefixes = set()
3203         num_of_strings = 0
3204         for leaf in line.leaves[string_idx:]:
3205             if leaf.type != token.STRING:
3206                 # If the string group is trailed by a comma, we count the
3207                 # comments trailing the comma to be one of the string group's
3208                 # comments.
3209                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3210                     num_of_inline_string_comments += 1
3211                 break
3212
3213             if has_triple_quotes(leaf.value):
3214                 return TErr("StringMerger does NOT merge multiline strings.")
3215
3216             num_of_strings += 1
3217             prefix = get_string_prefix(leaf.value)
3218             if "r" in prefix:
3219                 return TErr("StringMerger does NOT merge raw strings.")
3220
3221             set_of_prefixes.add(prefix)
3222
3223             if id(leaf) in line.comments:
3224                 num_of_inline_string_comments += 1
3225                 if contains_pragma_comment(line.comments[id(leaf)]):
3226                     return TErr("Cannot merge strings which have pragma comments.")
3227
3228         if num_of_strings < 2:
3229             return TErr(
3230                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3231             )
3232
3233         if num_of_inline_string_comments > 1:
3234             return TErr(
3235                 f"Too many inline string comments ({num_of_inline_string_comments})."
3236             )
3237
3238         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3239             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3240
3241         return Ok(None)
3242
3243
3244 class StringParenStripper(StringTransformer):
3245     """StringTransformer that strips surrounding parentheses from strings.
3246
3247     Requirements:
3248         The line contains a string which is surrounded by parentheses and:
3249             - The target string is NOT the only argument to a function call).
3250             - If the target string contains a PERCENT, the brackets are not
3251               preceeded or followed by an operator with higher precedence than
3252               PERCENT.
3253
3254     Transformations:
3255         The parentheses mentioned in the 'Requirements' section are stripped.
3256
3257     Collaborations:
3258         StringParenStripper has its own inherent usefulness, but it is also
3259         relied on to clean up the parentheses created by StringParenWrapper (in
3260         the event that they are no longer needed).
3261     """
3262
3263     def do_match(self, line: Line) -> TMatchResult:
3264         LL = line.leaves
3265
3266         is_valid_index = is_valid_index_factory(LL)
3267
3268         for (idx, leaf) in enumerate(LL):
3269             # Should be a string...
3270             if leaf.type != token.STRING:
3271                 continue
3272
3273             # Should be preceded by a non-empty LPAR...
3274             if (
3275                 not is_valid_index(idx - 1)
3276                 or LL[idx - 1].type != token.LPAR
3277                 or is_empty_lpar(LL[idx - 1])
3278             ):
3279                 continue
3280
3281             # That LPAR should NOT be preceded by a function name or a closing
3282             # bracket (which could be a function which returns a function or a
3283             # list/dictionary that contains a function)...
3284             if is_valid_index(idx - 2) and (
3285                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3286             ):
3287                 continue
3288
3289             string_idx = idx
3290
3291             # Skip the string trailer, if one exists.
3292             string_parser = StringParser()
3293             next_idx = string_parser.parse(LL, string_idx)
3294
3295             # if the leaves in the parsed string include a PERCENT, we need to
3296             # make sure the initial LPAR is NOT preceded by an operator with
3297             # higher or equal precedence to PERCENT
3298             if is_valid_index(idx - 2):
3299                 # mypy can't quite follow unless we name this
3300                 before_lpar = LL[idx - 2]
3301                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3302                     (
3303                         before_lpar.type
3304                         in {
3305                             token.STAR,
3306                             token.AT,
3307                             token.SLASH,
3308                             token.DOUBLESLASH,
3309                             token.PERCENT,
3310                             token.TILDE,
3311                             token.DOUBLESTAR,
3312                             token.AWAIT,
3313                             token.LSQB,
3314                             token.LPAR,
3315                         }
3316                     )
3317                     or (
3318                         # only unary PLUS/MINUS
3319                         before_lpar.parent
3320                         and before_lpar.parent.type == syms.factor
3321                         and (before_lpar.type in {token.PLUS, token.MINUS})
3322                     )
3323                 ):
3324                     continue
3325
3326             # Should be followed by a non-empty RPAR...
3327             if (
3328                 is_valid_index(next_idx)
3329                 and LL[next_idx].type == token.RPAR
3330                 and not is_empty_rpar(LL[next_idx])
3331             ):
3332                 # That RPAR should NOT be followed by anything with higher
3333                 # precedence than PERCENT
3334                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3335                     token.DOUBLESTAR,
3336                     token.LSQB,
3337                     token.LPAR,
3338                     token.DOT,
3339                 }:
3340                     continue
3341
3342                 return Ok(string_idx)
3343
3344         return TErr("This line has no strings wrapped in parens.")
3345
3346     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3347         LL = line.leaves
3348
3349         string_parser = StringParser()
3350         rpar_idx = string_parser.parse(LL, string_idx)
3351
3352         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3353             if line.comments_after(leaf):
3354                 yield TErr(
3355                     "Will not strip parentheses which have comments attached to them."
3356                 )
3357                 return
3358
3359         new_line = line.clone()
3360         new_line.comments = line.comments.copy()
3361         try:
3362             append_leaves(new_line, line, LL[: string_idx - 1])
3363         except BracketMatchError:
3364             # HACK: I believe there is currently a bug somewhere in
3365             # right_hand_split() that is causing brackets to not be tracked
3366             # properly by a shared BracketTracker.
3367             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3368
3369         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3370         LL[string_idx - 1].remove()
3371         replace_child(LL[string_idx], string_leaf)
3372         new_line.append(string_leaf)
3373
3374         append_leaves(
3375             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3376         )
3377
3378         LL[rpar_idx].remove()
3379
3380         yield Ok(new_line)
3381
3382
3383 class BaseStringSplitter(StringTransformer):
3384     """
3385     Abstract class for StringTransformers which transform a Line's strings by splitting
3386     them or placing them on their own lines where necessary to avoid going over
3387     the configured line length.
3388
3389     Requirements:
3390         * The target string value is responsible for the line going over the
3391         line length limit. It follows that after all of black's other line
3392         split methods have been exhausted, this line (or one of the resulting
3393         lines after all line splits are performed) would still be over the
3394         line_length limit unless we split this string.
3395             AND
3396         * The target string is NOT a "pointless" string (i.e. a string that has
3397         no parent or siblings).
3398             AND
3399         * The target string is not followed by an inline comment that appears
3400         to be a pragma.
3401             AND
3402         * The target string is not a multiline (i.e. triple-quote) string.
3403     """
3404
3405     @abstractmethod
3406     def do_splitter_match(self, line: Line) -> TMatchResult:
3407         """
3408         BaseStringSplitter asks its clients to override this method instead of
3409         `StringTransformer.do_match(...)`.
3410
3411         Follows the same protocol as `StringTransformer.do_match(...)`.
3412
3413         Refer to `help(StringTransformer.do_match)` for more information.
3414         """
3415
3416     def do_match(self, line: Line) -> TMatchResult:
3417         match_result = self.do_splitter_match(line)
3418         if isinstance(match_result, Err):
3419             return match_result
3420
3421         string_idx = match_result.ok()
3422         vresult = self.__validate(line, string_idx)
3423         if isinstance(vresult, Err):
3424             return vresult
3425
3426         return match_result
3427
3428     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3429         """
3430         Checks that @line meets all of the requirements listed in this classes'
3431         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3432         description of those requirements.
3433
3434         Returns:
3435             * Ok(None), if ALL of the requirements are met.
3436                 OR
3437             * Err(CannotTransform), if ANY of the requirements are NOT met.
3438         """
3439         LL = line.leaves
3440
3441         string_leaf = LL[string_idx]
3442
3443         max_string_length = self.__get_max_string_length(line, string_idx)
3444         if len(string_leaf.value) <= max_string_length:
3445             return TErr(
3446                 "The string itself is not what is causing this line to be too long."
3447             )
3448
3449         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3450             token.STRING,
3451             token.NEWLINE,
3452         ]:
3453             return TErr(
3454                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3455                 " no parent)."
3456             )
3457
3458         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3459             line.comments[id(line.leaves[string_idx])]
3460         ):
3461             return TErr(
3462                 "Line appears to end with an inline pragma comment. Splitting the line"
3463                 " could modify the pragma's behavior."
3464             )
3465
3466         if has_triple_quotes(string_leaf.value):
3467             return TErr("We cannot split multiline strings.")
3468
3469         return Ok(None)
3470
3471     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3472         """
3473         Calculates the max string length used when attempting to determine
3474         whether or not the target string is responsible for causing the line to
3475         go over the line length limit.
3476
3477         WARNING: This method is tightly coupled to both StringSplitter and
3478         (especially) StringParenWrapper. There is probably a better way to
3479         accomplish what is being done here.
3480
3481         Returns:
3482             max_string_length: such that `line.leaves[string_idx].value >
3483             max_string_length` implies that the target string IS responsible
3484             for causing this line to exceed the line length limit.
3485         """
3486         LL = line.leaves
3487
3488         is_valid_index = is_valid_index_factory(LL)
3489
3490         # We use the shorthand "WMA4" in comments to abbreviate "We must
3491         # account for". When giving examples, we use STRING to mean some/any
3492         # valid string.
3493         #
3494         # Finally, we use the following convenience variables:
3495         #
3496         #   P:  The leaf that is before the target string leaf.
3497         #   N:  The leaf that is after the target string leaf.
3498         #   NN: The leaf that is after N.
3499
3500         # WMA4 the whitespace at the beginning of the line.
3501         offset = line.depth * 4
3502
3503         if is_valid_index(string_idx - 1):
3504             p_idx = string_idx - 1
3505             if (
3506                 LL[string_idx - 1].type == token.LPAR
3507                 and LL[string_idx - 1].value == ""
3508                 and string_idx >= 2
3509             ):
3510                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3511                 p_idx -= 1
3512
3513             P = LL[p_idx]
3514             if P.type == token.PLUS:
3515                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3516                 offset += 2
3517
3518             if P.type == token.COMMA:
3519                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3520                 offset += 3
3521
3522             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3523                 # This conditional branch is meant to handle dictionary keys,
3524                 # variable assignments, 'return STRING' statement lines, and
3525                 # 'else STRING' ternary expression lines.
3526
3527                 # WMA4 a single space.
3528                 offset += 1
3529
3530                 # WMA4 the lengths of any leaves that came before that space,
3531                 # but after any closing bracket before that space.
3532                 for leaf in reversed(LL[: p_idx + 1]):
3533                     offset += len(str(leaf))
3534                     if leaf.type in CLOSING_BRACKETS:
3535                         break
3536
3537         if is_valid_index(string_idx + 1):
3538             N = LL[string_idx + 1]
3539             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3540                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3541                 N = LL[string_idx + 2]
3542
3543             if N.type == token.COMMA:
3544                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3545                 offset += 1
3546
3547             if is_valid_index(string_idx + 2):
3548                 NN = LL[string_idx + 2]
3549
3550                 if N.type == token.DOT and NN.type == token.NAME:
3551                     # This conditional branch is meant to handle method calls invoked
3552                     # off of a string literal up to and including the LPAR character.
3553
3554                     # WMA4 the '.' character.
3555                     offset += 1
3556
3557                     if (
3558                         is_valid_index(string_idx + 3)
3559                         and LL[string_idx + 3].type == token.LPAR
3560                     ):
3561                         # WMA4 the left parenthesis character.
3562                         offset += 1
3563
3564                     # WMA4 the length of the method's name.
3565                     offset += len(NN.value)
3566
3567         has_comments = False
3568         for comment_leaf in line.comments_after(LL[string_idx]):
3569             if not has_comments:
3570                 has_comments = True
3571                 # WMA4 two spaces before the '#' character.
3572                 offset += 2
3573
3574             # WMA4 the length of the inline comment.
3575             offset += len(comment_leaf.value)
3576
3577         max_string_length = self.line_length - offset
3578         return max_string_length
3579
3580
3581 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3582     """
3583     StringTransformer that splits "atom" strings (i.e. strings which exist on
3584     lines by themselves).
3585
3586     Requirements:
3587         * The line consists ONLY of a single string (with the exception of a
3588         '+' symbol which MAY exist at the start of the line), MAYBE a string
3589         trailer, and MAYBE a trailing comma.
3590             AND
3591         * All of the requirements listed in BaseStringSplitter's docstring.
3592
3593     Transformations:
3594         The string mentioned in the 'Requirements' section is split into as
3595         many substrings as necessary to adhere to the configured line length.
3596
3597         In the final set of substrings, no substring should be smaller than
3598         MIN_SUBSTR_SIZE characters.
3599
3600         The string will ONLY be split on spaces (i.e. each new substring should
3601         start with a space).
3602
3603         If the string is an f-string, it will NOT be split in the middle of an
3604         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3605         else bar()} is an f-expression).
3606
3607         If the string that is being split has an associated set of custom split
3608         records and those custom splits will NOT result in any line going over
3609         the configured line length, those custom splits are used. Otherwise the
3610         string is split as late as possible (from left-to-right) while still
3611         adhering to the transformation rules listed above.
3612
3613     Collaborations:
3614         StringSplitter relies on StringMerger to construct the appropriate
3615         CustomSplit objects and add them to the custom split map.
3616     """
3617
3618     MIN_SUBSTR_SIZE = 6
3619     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3620     RE_FEXPR = r"""
3621     (?<!\{)\{
3622         (?:
3623             [^\{\}]
3624             | \{\{
3625             | \}\}
3626         )+?
3627     (?<!\})(?:\}\})*\}(?!\})
3628     """
3629
3630     def do_splitter_match(self, line: Line) -> TMatchResult:
3631         LL = line.leaves
3632
3633         is_valid_index = is_valid_index_factory(LL)
3634
3635         idx = 0
3636
3637         # The first leaf MAY be a '+' symbol...
3638         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3639             idx += 1
3640
3641         # The next/first leaf MAY be an empty LPAR...
3642         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3643             idx += 1
3644
3645         # The next/first leaf MUST be a string...
3646         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3647             return TErr("Line does not start with a string.")
3648
3649         string_idx = idx
3650
3651         # Skip the string trailer, if one exists.
3652         string_parser = StringParser()
3653         idx = string_parser.parse(LL, string_idx)
3654
3655         # That string MAY be followed by an empty RPAR...
3656         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3657             idx += 1
3658
3659         # That string / empty RPAR leaf MAY be followed by a comma...
3660         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3661             idx += 1
3662
3663         # But no more leaves are allowed...
3664         if is_valid_index(idx):
3665             return TErr("This line does not end with a string.")
3666
3667         return Ok(string_idx)
3668
3669     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3670         LL = line.leaves
3671
3672         QUOTE = LL[string_idx].value[-1]
3673
3674         is_valid_index = is_valid_index_factory(LL)
3675         insert_str_child = insert_str_child_factory(LL[string_idx])
3676
3677         prefix = get_string_prefix(LL[string_idx].value)
3678
3679         # We MAY choose to drop the 'f' prefix from substrings that don't
3680         # contain any f-expressions, but ONLY if the original f-string
3681         # contains at least one f-expression. Otherwise, we will alter the AST
3682         # of the program.
3683         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3684             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3685         )
3686
3687         first_string_line = True
3688         starts_with_plus = LL[0].type == token.PLUS
3689
3690         def line_needs_plus() -> bool:
3691             return first_string_line and starts_with_plus
3692
3693         def maybe_append_plus(new_line: Line) -> None:
3694             """
3695             Side Effects:
3696                 If @line starts with a plus and this is the first line we are
3697                 constructing, this function appends a PLUS leaf to @new_line
3698                 and replaces the old PLUS leaf in the node structure. Otherwise
3699                 this function does nothing.
3700             """
3701             if line_needs_plus():
3702                 plus_leaf = Leaf(token.PLUS, "+")
3703                 replace_child(LL[0], plus_leaf)
3704                 new_line.append(plus_leaf)
3705
3706         ends_with_comma = (
3707             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3708         )
3709
3710         def max_last_string() -> int:
3711             """
3712             Returns:
3713                 The max allowed length of the string value used for the last
3714                 line we will construct.
3715             """
3716             result = self.line_length
3717             result -= line.depth * 4
3718             result -= 1 if ends_with_comma else 0
3719             result -= 2 if line_needs_plus() else 0
3720             return result
3721
3722         # --- Calculate Max Break Index (for string value)
3723         # We start with the line length limit
3724         max_break_idx = self.line_length
3725         # The last index of a string of length N is N-1.
3726         max_break_idx -= 1
3727         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3728         max_break_idx -= line.depth * 4
3729         if max_break_idx < 0:
3730             yield TErr(
3731                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3732                 f" {line.depth}"
3733             )
3734             return
3735
3736         # Check if StringMerger registered any custom splits.
3737         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3738         # We use them ONLY if none of them would produce lines that exceed the
3739         # line limit.
3740         use_custom_breakpoints = bool(
3741             custom_splits
3742             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3743         )
3744
3745         # Temporary storage for the remaining chunk of the string line that
3746         # can't fit onto the line currently being constructed.
3747         rest_value = LL[string_idx].value
3748
3749         def more_splits_should_be_made() -> bool:
3750             """
3751             Returns:
3752                 True iff `rest_value` (the remaining string value from the last
3753                 split), should be split again.
3754             """
3755             if use_custom_breakpoints:
3756                 return len(custom_splits) > 1
3757             else:
3758                 return len(rest_value) > max_last_string()
3759
3760         string_line_results: List[Ok[Line]] = []
3761         while more_splits_should_be_made():
3762             if use_custom_breakpoints:
3763                 # Custom User Split (manual)
3764                 csplit = custom_splits.pop(0)
3765                 break_idx = csplit.break_idx
3766             else:
3767                 # Algorithmic Split (automatic)
3768                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3769                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3770                 if maybe_break_idx is None:
3771                     # If we are unable to algorithmically determine a good split
3772                     # and this string has custom splits registered to it, we
3773                     # fall back to using them--which means we have to start
3774                     # over from the beginning.
3775                     if custom_splits:
3776                         rest_value = LL[string_idx].value
3777                         string_line_results = []
3778                         first_string_line = True
3779                         use_custom_breakpoints = True
3780                         continue
3781
3782                     # Otherwise, we stop splitting here.
3783                     break
3784
3785                 break_idx = maybe_break_idx
3786
3787             # --- Construct `next_value`
3788             next_value = rest_value[:break_idx] + QUOTE
3789             if (
3790                 # Are we allowed to try to drop a pointless 'f' prefix?
3791                 drop_pointless_f_prefix
3792                 # If we are, will we be successful?
3793                 and next_value != self.__normalize_f_string(next_value, prefix)
3794             ):
3795                 # If the current custom split did NOT originally use a prefix,
3796                 # then `csplit.break_idx` will be off by one after removing
3797                 # the 'f' prefix.
3798                 break_idx = (
3799                     break_idx + 1
3800                     if use_custom_breakpoints and not csplit.has_prefix
3801                     else break_idx
3802                 )
3803                 next_value = rest_value[:break_idx] + QUOTE
3804                 next_value = self.__normalize_f_string(next_value, prefix)
3805
3806             # --- Construct `next_leaf`
3807             next_leaf = Leaf(token.STRING, next_value)
3808             insert_str_child(next_leaf)
3809             self.__maybe_normalize_string_quotes(next_leaf)
3810
3811             # --- Construct `next_line`
3812             next_line = line.clone()
3813             maybe_append_plus(next_line)
3814             next_line.append(next_leaf)
3815             string_line_results.append(Ok(next_line))
3816
3817             rest_value = prefix + QUOTE + rest_value[break_idx:]
3818             first_string_line = False
3819
3820         yield from string_line_results
3821
3822         if drop_pointless_f_prefix:
3823             rest_value = self.__normalize_f_string(rest_value, prefix)
3824
3825         rest_leaf = Leaf(token.STRING, rest_value)
3826         insert_str_child(rest_leaf)
3827
3828         # NOTE: I could not find a test case that verifies that the following
3829         # line is actually necessary, but it seems to be. Otherwise we risk
3830         # not normalizing the last substring, right?
3831         self.__maybe_normalize_string_quotes(rest_leaf)
3832
3833         last_line = line.clone()
3834         maybe_append_plus(last_line)
3835
3836         # If there are any leaves to the right of the target string...
3837         if is_valid_index(string_idx + 1):
3838             # We use `temp_value` here to determine how long the last line
3839             # would be if we were to append all the leaves to the right of the
3840             # target string to the last string line.
3841             temp_value = rest_value
3842             for leaf in LL[string_idx + 1 :]:
3843                 temp_value += str(leaf)
3844                 if leaf.type == token.LPAR:
3845                     break
3846
3847             # Try to fit them all on the same line with the last substring...
3848             if (
3849                 len(temp_value) <= max_last_string()
3850                 or LL[string_idx + 1].type == token.COMMA
3851             ):
3852                 last_line.append(rest_leaf)
3853                 append_leaves(last_line, line, LL[string_idx + 1 :])
3854                 yield Ok(last_line)
3855             # Otherwise, place the last substring on one line and everything
3856             # else on a line below that...
3857             else:
3858                 last_line.append(rest_leaf)
3859                 yield Ok(last_line)
3860
3861                 non_string_line = line.clone()
3862                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3863                 yield Ok(non_string_line)
3864         # Else the target string was the last leaf...
3865         else:
3866             last_line.append(rest_leaf)
3867             last_line.comments = line.comments.copy()
3868             yield Ok(last_line)
3869
3870     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3871         """
3872         This method contains the algorithm that StringSplitter uses to
3873         determine which character to split each string at.
3874
3875         Args:
3876             @string: The substring that we are attempting to split.
3877             @max_break_idx: The ideal break index. We will return this value if it
3878             meets all the necessary conditions. In the likely event that it
3879             doesn't we will try to find the closest index BELOW @max_break_idx
3880             that does. If that fails, we will expand our search by also
3881             considering all valid indices ABOVE @max_break_idx.
3882
3883         Pre-Conditions:
3884             * assert_is_leaf_string(@string)
3885             * 0 <= @max_break_idx < len(@string)
3886
3887         Returns:
3888             break_idx, if an index is able to be found that meets all of the
3889             conditions listed in the 'Transformations' section of this classes'
3890             docstring.
3891                 OR
3892             None, otherwise.
3893         """
3894         is_valid_index = is_valid_index_factory(string)
3895
3896         assert is_valid_index(max_break_idx)
3897         assert_is_leaf_string(string)
3898
3899         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3900
3901         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3902             """
3903             Yields:
3904                 All ranges of @string which, if @string were to be split there,
3905                 would result in the splitting of an f-expression (which is NOT
3906                 allowed).
3907             """
3908             nonlocal _fexpr_slices
3909
3910             if _fexpr_slices is None:
3911                 _fexpr_slices = []
3912                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3913                     _fexpr_slices.append(match.span())
3914
3915             yield from _fexpr_slices
3916
3917         is_fstring = "f" in get_string_prefix(string)
3918
3919         def breaks_fstring_expression(i: Index) -> bool:
3920             """
3921             Returns:
3922                 True iff returning @i would result in the splitting of an
3923                 f-expression (which is NOT allowed).
3924             """
3925             if not is_fstring:
3926                 return False
3927
3928             for (start, end) in fexpr_slices():
3929                 if start <= i < end:
3930                     return True
3931
3932             return False
3933
3934         def passes_all_checks(i: Index) -> bool:
3935             """
3936             Returns:
3937                 True iff ALL of the conditions listed in the 'Transformations'
3938                 section of this classes' docstring would be be met by returning @i.
3939             """
3940             is_space = string[i] == " "
3941             is_big_enough = (
3942                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3943                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3944             )
3945             return is_space and is_big_enough and not breaks_fstring_expression(i)
3946
3947         # First, we check all indices BELOW @max_break_idx.
3948         break_idx = max_break_idx
3949         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3950             break_idx -= 1
3951
3952         if not passes_all_checks(break_idx):
3953             # If that fails, we check all indices ABOVE @max_break_idx.
3954             #
3955             # If we are able to find a valid index here, the next line is going
3956             # to be longer than the specified line length, but it's probably
3957             # better than doing nothing at all.
3958             break_idx = max_break_idx + 1
3959             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3960                 break_idx += 1
3961
3962             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3963                 return None
3964
3965         return break_idx
3966
3967     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3968         if self.normalize_strings:
3969             normalize_string_quotes(leaf)
3970
3971     def __normalize_f_string(self, string: str, prefix: str) -> str:
3972         """
3973         Pre-Conditions:
3974             * assert_is_leaf_string(@string)
3975
3976         Returns:
3977             * If @string is an f-string that contains no f-expressions, we
3978             return a string identical to @string except that the 'f' prefix
3979             has been stripped and all double braces (i.e. '{{' or '}}') have
3980             been normalized (i.e. turned into '{' or '}').
3981                 OR
3982             * Otherwise, we return @string.
3983         """
3984         assert_is_leaf_string(string)
3985
3986         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3987             new_prefix = prefix.replace("f", "")
3988
3989             temp = string[len(prefix) :]
3990             temp = re.sub(r"\{\{", "{", temp)
3991             temp = re.sub(r"\}\}", "}", temp)
3992             new_string = temp
3993
3994             return f"{new_prefix}{new_string}"
3995         else:
3996             return string
3997
3998
3999 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4000     """
4001     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4002     exist on lines by themselves).
4003
4004     Requirements:
4005         All of the requirements listed in BaseStringSplitter's docstring in
4006         addition to the requirements listed below:
4007
4008         * The line is a return/yield statement, which returns/yields a string.
4009             OR
4010         * The line is part of a ternary expression (e.g. `x = y if cond else
4011         z`) such that the line starts with `else <string>`, where <string> is
4012         some string.
4013             OR
4014         * The line is an assert statement, which ends with a string.
4015             OR
4016         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4017         <string>`) such that the variable is being assigned the value of some
4018         string.
4019             OR
4020         * The line is a dictionary key assignment where some valid key is being
4021         assigned the value of some string.
4022
4023     Transformations:
4024         The chosen string is wrapped in parentheses and then split at the LPAR.
4025
4026         We then have one line which ends with an LPAR and another line that
4027         starts with the chosen string. The latter line is then split again at
4028         the RPAR. This results in the RPAR (and possibly a trailing comma)
4029         being placed on its own line.
4030
4031         NOTE: If any leaves exist to the right of the chosen string (except
4032         for a trailing comma, which would be placed after the RPAR), those
4033         leaves are placed inside the parentheses.  In effect, the chosen
4034         string is not necessarily being "wrapped" by parentheses. We can,
4035         however, count on the LPAR being placed directly before the chosen
4036         string.
4037
4038         In other words, StringParenWrapper creates "atom" strings. These
4039         can then be split again by StringSplitter, if necessary.
4040
4041     Collaborations:
4042         In the event that a string line split by StringParenWrapper is
4043         changed such that it no longer needs to be given its own line,
4044         StringParenWrapper relies on StringParenStripper to clean up the
4045         parentheses it created.
4046     """
4047
4048     def do_splitter_match(self, line: Line) -> TMatchResult:
4049         LL = line.leaves
4050
4051         string_idx = (
4052             self._return_match(LL)
4053             or self._else_match(LL)
4054             or self._assert_match(LL)
4055             or self._assign_match(LL)
4056             or self._dict_match(LL)
4057         )
4058
4059         if string_idx is not None:
4060             string_value = line.leaves[string_idx].value
4061             # If the string has no spaces...
4062             if " " not in string_value:
4063                 # And will still violate the line length limit when split...
4064                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4065                 if len(string_value) > max_string_length:
4066                     # And has no associated custom splits...
4067                     if not self.has_custom_splits(string_value):
4068                         # Then we should NOT put this string on its own line.
4069                         return TErr(
4070                             "We do not wrap long strings in parentheses when the"
4071                             " resultant line would still be over the specified line"
4072                             " length and can't be split further by StringSplitter."
4073                         )
4074             return Ok(string_idx)
4075
4076         return TErr("This line does not contain any non-atomic strings.")
4077
4078     @staticmethod
4079     def _return_match(LL: List[Leaf]) -> Optional[int]:
4080         """
4081         Returns:
4082             string_idx such that @LL[string_idx] is equal to our target (i.e.
4083             matched) string, if this line matches the return/yield statement
4084             requirements listed in the 'Requirements' section of this classes'
4085             docstring.
4086                 OR
4087             None, otherwise.
4088         """
4089         # If this line is apart of a return/yield statement and the first leaf
4090         # contains either the "return" or "yield" keywords...
4091         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4092             0
4093         ].value in ["return", "yield"]:
4094             is_valid_index = is_valid_index_factory(LL)
4095
4096             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4097             # The next visible leaf MUST contain a string...
4098             if is_valid_index(idx) and LL[idx].type == token.STRING:
4099                 return idx
4100
4101         return None
4102
4103     @staticmethod
4104     def _else_match(LL: List[Leaf]) -> Optional[int]:
4105         """
4106         Returns:
4107             string_idx such that @LL[string_idx] is equal to our target (i.e.
4108             matched) string, if this line matches the ternary expression
4109             requirements listed in the 'Requirements' section of this classes'
4110             docstring.
4111                 OR
4112             None, otherwise.
4113         """
4114         # If this line is apart of a ternary expression and the first leaf
4115         # contains the "else" keyword...
4116         if (
4117             parent_type(LL[0]) == syms.test
4118             and LL[0].type == token.NAME
4119             and LL[0].value == "else"
4120         ):
4121             is_valid_index = is_valid_index_factory(LL)
4122
4123             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4124             # The next visible leaf MUST contain a string...
4125             if is_valid_index(idx) and LL[idx].type == token.STRING:
4126                 return idx
4127
4128         return None
4129
4130     @staticmethod
4131     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4132         """
4133         Returns:
4134             string_idx such that @LL[string_idx] is equal to our target (i.e.
4135             matched) string, if this line matches the assert statement
4136             requirements listed in the 'Requirements' section of this classes'
4137             docstring.
4138                 OR
4139             None, otherwise.
4140         """
4141         # If this line is apart of an assert statement and the first leaf
4142         # contains the "assert" keyword...
4143         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4144             is_valid_index = is_valid_index_factory(LL)
4145
4146             for (i, leaf) in enumerate(LL):
4147                 # We MUST find a comma...
4148                 if leaf.type == token.COMMA:
4149                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4150
4151                     # That comma MUST be followed by a string...
4152                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4153                         string_idx = idx
4154
4155                         # Skip the string trailer, if one exists.
4156                         string_parser = StringParser()
4157                         idx = string_parser.parse(LL, string_idx)
4158
4159                         # But no more leaves are allowed...
4160                         if not is_valid_index(idx):
4161                             return string_idx
4162
4163         return None
4164
4165     @staticmethod
4166     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4167         """
4168         Returns:
4169             string_idx such that @LL[string_idx] is equal to our target (i.e.
4170             matched) string, if this line matches the assignment statement
4171             requirements listed in the 'Requirements' section of this classes'
4172             docstring.
4173                 OR
4174             None, otherwise.
4175         """
4176         # If this line is apart of an expression statement or is a function
4177         # argument AND the first leaf contains a variable name...
4178         if (
4179             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4180             and LL[0].type == token.NAME
4181         ):
4182             is_valid_index = is_valid_index_factory(LL)
4183
4184             for (i, leaf) in enumerate(LL):
4185                 # We MUST find either an '=' or '+=' symbol...
4186                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4187                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4188
4189                     # That symbol MUST be followed by a string...
4190                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4191                         string_idx = idx
4192
4193                         # Skip the string trailer, if one exists.
4194                         string_parser = StringParser()
4195                         idx = string_parser.parse(LL, string_idx)
4196
4197                         # The next leaf MAY be a comma iff this line is apart
4198                         # of a function argument...
4199                         if (
4200                             parent_type(LL[0]) == syms.argument
4201                             and is_valid_index(idx)
4202                             and LL[idx].type == token.COMMA
4203                         ):
4204                             idx += 1
4205
4206                         # But no more leaves are allowed...
4207                         if not is_valid_index(idx):
4208                             return string_idx
4209
4210         return None
4211
4212     @staticmethod
4213     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4214         """
4215         Returns:
4216             string_idx such that @LL[string_idx] is equal to our target (i.e.
4217             matched) string, if this line matches the dictionary key assignment
4218             statement requirements listed in the 'Requirements' section of this
4219             classes' docstring.
4220                 OR
4221             None, otherwise.
4222         """
4223         # If this line is apart of a dictionary key assignment...
4224         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4225             is_valid_index = is_valid_index_factory(LL)
4226
4227             for (i, leaf) in enumerate(LL):
4228                 # We MUST find a colon...
4229                 if leaf.type == token.COLON:
4230                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4231
4232                     # That colon MUST be followed by a string...
4233                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4234                         string_idx = idx
4235
4236                         # Skip the string trailer, if one exists.
4237                         string_parser = StringParser()
4238                         idx = string_parser.parse(LL, string_idx)
4239
4240                         # That string MAY be followed by a comma...
4241                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4242                             idx += 1
4243
4244                         # But no more leaves are allowed...
4245                         if not is_valid_index(idx):
4246                             return string_idx
4247
4248         return None
4249
4250     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4251         LL = line.leaves
4252
4253         is_valid_index = is_valid_index_factory(LL)
4254         insert_str_child = insert_str_child_factory(LL[string_idx])
4255
4256         comma_idx = -1
4257         ends_with_comma = False
4258         if LL[comma_idx].type == token.COMMA:
4259             ends_with_comma = True
4260
4261         leaves_to_steal_comments_from = [LL[string_idx]]
4262         if ends_with_comma:
4263             leaves_to_steal_comments_from.append(LL[comma_idx])
4264
4265         # --- First Line
4266         first_line = line.clone()
4267         left_leaves = LL[:string_idx]
4268
4269         # We have to remember to account for (possibly invisible) LPAR and RPAR
4270         # leaves that already wrapped the target string. If these leaves do
4271         # exist, we will replace them with our own LPAR and RPAR leaves.
4272         old_parens_exist = False
4273         if left_leaves and left_leaves[-1].type == token.LPAR:
4274             old_parens_exist = True
4275             leaves_to_steal_comments_from.append(left_leaves[-1])
4276             left_leaves.pop()
4277
4278         append_leaves(first_line, line, left_leaves)
4279
4280         lpar_leaf = Leaf(token.LPAR, "(")
4281         if old_parens_exist:
4282             replace_child(LL[string_idx - 1], lpar_leaf)
4283         else:
4284             insert_str_child(lpar_leaf)
4285         first_line.append(lpar_leaf)
4286
4287         # We throw inline comments that were originally to the right of the
4288         # target string to the top line. They will now be shown to the right of
4289         # the LPAR.
4290         for leaf in leaves_to_steal_comments_from:
4291             for comment_leaf in line.comments_after(leaf):
4292                 first_line.append(comment_leaf, preformatted=True)
4293
4294         yield Ok(first_line)
4295
4296         # --- Middle (String) Line
4297         # We only need to yield one (possibly too long) string line, since the
4298         # `StringSplitter` will break it down further if necessary.
4299         string_value = LL[string_idx].value
4300         string_line = Line(
4301             depth=line.depth + 1,
4302             inside_brackets=True,
4303             should_explode=line.should_explode,
4304         )
4305         string_leaf = Leaf(token.STRING, string_value)
4306         insert_str_child(string_leaf)
4307         string_line.append(string_leaf)
4308
4309         old_rpar_leaf = None
4310         if is_valid_index(string_idx + 1):
4311             right_leaves = LL[string_idx + 1 :]
4312             if ends_with_comma:
4313                 right_leaves.pop()
4314
4315             if old_parens_exist:
4316                 assert (
4317                     right_leaves and right_leaves[-1].type == token.RPAR
4318                 ), "Apparently, old parentheses do NOT exist?!"
4319                 old_rpar_leaf = right_leaves.pop()
4320
4321             append_leaves(string_line, line, right_leaves)
4322
4323         yield Ok(string_line)
4324
4325         # --- Last Line
4326         last_line = line.clone()
4327         last_line.bracket_tracker = first_line.bracket_tracker
4328
4329         new_rpar_leaf = Leaf(token.RPAR, ")")
4330         if old_rpar_leaf is not None:
4331             replace_child(old_rpar_leaf, new_rpar_leaf)
4332         else:
4333             insert_str_child(new_rpar_leaf)
4334         last_line.append(new_rpar_leaf)
4335
4336         # If the target string ended with a comma, we place this comma to the
4337         # right of the RPAR on the last line.
4338         if ends_with_comma:
4339             comma_leaf = Leaf(token.COMMA, ",")
4340             replace_child(LL[comma_idx], comma_leaf)
4341             last_line.append(comma_leaf)
4342
4343         yield Ok(last_line)
4344
4345
4346 class StringParser:
4347     """
4348     A state machine that aids in parsing a string's "trailer", which can be
4349     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4350     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4351     varY)`).
4352
4353     NOTE: A new StringParser object MUST be instantiated for each string
4354     trailer we need to parse.
4355
4356     Examples:
4357         We shall assume that `line` equals the `Line` object that corresponds
4358         to the following line of python code:
4359         ```
4360         x = "Some {}.".format("String") + some_other_string
4361         ```
4362
4363         Furthermore, we will assume that `string_idx` is some index such that:
4364         ```
4365         assert line.leaves[string_idx].value == "Some {}."
4366         ```
4367
4368         The following code snippet then holds:
4369         ```
4370         string_parser = StringParser()
4371         idx = string_parser.parse(line.leaves, string_idx)
4372         assert line.leaves[idx].type == token.PLUS
4373         ```
4374     """
4375
4376     DEFAULT_TOKEN = -1
4377
4378     # String Parser States
4379     START = 1
4380     DOT = 2
4381     NAME = 3
4382     PERCENT = 4
4383     SINGLE_FMT_ARG = 5
4384     LPAR = 6
4385     RPAR = 7
4386     DONE = 8
4387
4388     # Lookup Table for Next State
4389     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4390         # A string trailer may start with '.' OR '%'.
4391         (START, token.DOT): DOT,
4392         (START, token.PERCENT): PERCENT,
4393         (START, DEFAULT_TOKEN): DONE,
4394         # A '.' MUST be followed by an attribute or method name.
4395         (DOT, token.NAME): NAME,
4396         # A method name MUST be followed by an '(', whereas an attribute name
4397         # is the last symbol in the string trailer.
4398         (NAME, token.LPAR): LPAR,
4399         (NAME, DEFAULT_TOKEN): DONE,
4400         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4401         # string or variable name).
4402         (PERCENT, token.LPAR): LPAR,
4403         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4404         # If a '%' symbol is followed by a single argument, that argument is
4405         # the last leaf in the string trailer.
4406         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4407         # If present, a ')' symbol is the last symbol in a string trailer.
4408         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4409         # since they are treated as a special case by the parsing logic in this
4410         # classes' implementation.)
4411         (RPAR, DEFAULT_TOKEN): DONE,
4412     }
4413
4414     def __init__(self) -> None:
4415         self._state = self.START
4416         self._unmatched_lpars = 0
4417
4418     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4419         """
4420         Pre-conditions:
4421             * @leaves[@string_idx].type == token.STRING
4422
4423         Returns:
4424             The index directly after the last leaf which is apart of the string
4425             trailer, if a "trailer" exists.
4426                 OR
4427             @string_idx + 1, if no string "trailer" exists.
4428         """
4429         assert leaves[string_idx].type == token.STRING
4430
4431         idx = string_idx + 1
4432         while idx < len(leaves) and self._next_state(leaves[idx]):
4433             idx += 1
4434         return idx
4435
4436     def _next_state(self, leaf: Leaf) -> bool:
4437         """
4438         Pre-conditions:
4439             * On the first call to this function, @leaf MUST be the leaf that
4440             was directly after the string leaf in question (e.g. if our target
4441             string is `line.leaves[i]` then the first call to this method must
4442             be `line.leaves[i + 1]`).
4443             * On the next call to this function, the leaf parameter passed in
4444             MUST be the leaf directly following @leaf.
4445
4446         Returns:
4447             True iff @leaf is apart of the string's trailer.
4448         """
4449         # We ignore empty LPAR or RPAR leaves.
4450         if is_empty_par(leaf):
4451             return True
4452
4453         next_token = leaf.type
4454         if next_token == token.LPAR:
4455             self._unmatched_lpars += 1
4456
4457         current_state = self._state
4458
4459         # The LPAR parser state is a special case. We will return True until we
4460         # find the matching RPAR token.
4461         if current_state == self.LPAR:
4462             if next_token == token.RPAR:
4463                 self._unmatched_lpars -= 1
4464                 if self._unmatched_lpars == 0:
4465                     self._state = self.RPAR
4466         # Otherwise, we use a lookup table to determine the next state.
4467         else:
4468             # If the lookup table matches the current state to the next
4469             # token, we use the lookup table.
4470             if (current_state, next_token) in self._goto:
4471                 self._state = self._goto[current_state, next_token]
4472             else:
4473                 # Otherwise, we check if a the current state was assigned a
4474                 # default.
4475                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4476                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4477                 # If no default has been assigned, then this parser has a logic
4478                 # error.
4479                 else:
4480                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4481
4482             if self._state == self.DONE:
4483                 return False
4484
4485         return True
4486
4487
4488 def TErr(err_msg: str) -> Err[CannotTransform]:
4489     """(T)ransform Err
4490
4491     Convenience function used when working with the TResult type.
4492     """
4493     cant_transform = CannotTransform(err_msg)
4494     return Err(cant_transform)
4495
4496
4497 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4498     """
4499     Returns:
4500         True iff one of the comments in @comment_list is a pragma used by one
4501         of the more common static analysis tools for python (e.g. mypy, flake8,
4502         pylint).
4503     """
4504     for comment in comment_list:
4505         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4506             return True
4507
4508     return False
4509
4510
4511 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4512     """
4513     Factory for a convenience function that is used to orphan @string_leaf
4514     and then insert multiple new leaves into the same part of the node
4515     structure that @string_leaf had originally occupied.
4516
4517     Examples:
4518         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4519         string_leaf.parent`. Assume the node `N` has the following
4520         original structure:
4521
4522         Node(
4523             expr_stmt, [
4524                 Leaf(NAME, 'x'),
4525                 Leaf(EQUAL, '='),
4526                 Leaf(STRING, '"foo"'),
4527             ]
4528         )
4529
4530         We then run the code snippet shown below.
4531         ```
4532         insert_str_child = insert_str_child_factory(string_leaf)
4533
4534         lpar = Leaf(token.LPAR, '(')
4535         insert_str_child(lpar)
4536
4537         bar = Leaf(token.STRING, '"bar"')
4538         insert_str_child(bar)
4539
4540         rpar = Leaf(token.RPAR, ')')
4541         insert_str_child(rpar)
4542         ```
4543
4544         After which point, it follows that `string_leaf.parent is None` and
4545         the node `N` now has the following structure:
4546
4547         Node(
4548             expr_stmt, [
4549                 Leaf(NAME, 'x'),
4550                 Leaf(EQUAL, '='),
4551                 Leaf(LPAR, '('),
4552                 Leaf(STRING, '"bar"'),
4553                 Leaf(RPAR, ')'),
4554             ]
4555         )
4556     """
4557     string_parent = string_leaf.parent
4558     string_child_idx = string_leaf.remove()
4559
4560     def insert_str_child(child: LN) -> None:
4561         nonlocal string_child_idx
4562
4563         assert string_parent is not None
4564         assert string_child_idx is not None
4565
4566         string_parent.insert_child(string_child_idx, child)
4567         string_child_idx += 1
4568
4569     return insert_str_child
4570
4571
4572 def has_triple_quotes(string: str) -> bool:
4573     """
4574     Returns:
4575         True iff @string starts with three quotation characters.
4576     """
4577     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4578     return raw_string[:3] in {'"""', "'''"}
4579
4580
4581 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4582     """
4583     Returns:
4584         @node.parent.type, if @node is not None and has a parent.
4585             OR
4586         None, otherwise.
4587     """
4588     if node is None or node.parent is None:
4589         return None
4590
4591     return node.parent.type
4592
4593
4594 def is_empty_par(leaf: Leaf) -> bool:
4595     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4596
4597
4598 def is_empty_lpar(leaf: Leaf) -> bool:
4599     return leaf.type == token.LPAR and leaf.value == ""
4600
4601
4602 def is_empty_rpar(leaf: Leaf) -> bool:
4603     return leaf.type == token.RPAR and leaf.value == ""
4604
4605
4606 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4607     """
4608     Examples:
4609         ```
4610         my_list = [1, 2, 3]
4611
4612         is_valid_index = is_valid_index_factory(my_list)
4613
4614         assert is_valid_index(0)
4615         assert is_valid_index(2)
4616
4617         assert not is_valid_index(3)
4618         assert not is_valid_index(-1)
4619         ```
4620     """
4621
4622     def is_valid_index(idx: int) -> bool:
4623         """
4624         Returns:
4625             True iff @idx is positive AND seq[@idx] does NOT raise an
4626             IndexError.
4627         """
4628         return 0 <= idx < len(seq)
4629
4630     return is_valid_index
4631
4632
4633 def line_to_string(line: Line) -> str:
4634     """Returns the string representation of @line.
4635
4636     WARNING: This is known to be computationally expensive.
4637     """
4638     return str(line).strip("\n")
4639
4640
4641 def append_leaves(
4642     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4643 ) -> None:
4644     """
4645     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4646     underlying Node structure where appropriate.
4647
4648     All of the leaves in @leaves are duplicated. The duplicates are then
4649     appended to @new_line and used to replace their originals in the underlying
4650     Node structure. Any comments attached to the old leaves are reattached to
4651     the new leaves.
4652
4653     Pre-conditions:
4654         set(@leaves) is a subset of set(@old_line.leaves).
4655     """
4656     for old_leaf in leaves:
4657         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4658         replace_child(old_leaf, new_leaf)
4659         new_line.append(new_leaf, preformatted=preformatted)
4660
4661         for comment_leaf in old_line.comments_after(old_leaf):
4662             new_line.append(comment_leaf, preformatted=True)
4663
4664
4665 def replace_child(old_child: LN, new_child: LN) -> None:
4666     """
4667     Side Effects:
4668         * If @old_child.parent is set, replace @old_child with @new_child in
4669         @old_child's underlying Node structure.
4670             OR
4671         * Otherwise, this function does nothing.
4672     """
4673     parent = old_child.parent
4674     if not parent:
4675         return
4676
4677     child_idx = old_child.remove()
4678     if child_idx is not None:
4679         parent.insert_child(child_idx, new_child)
4680
4681
4682 def get_string_prefix(string: str) -> str:
4683     """
4684     Pre-conditions:
4685         * assert_is_leaf_string(@string)
4686
4687     Returns:
4688         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4689     """
4690     assert_is_leaf_string(string)
4691
4692     prefix = ""
4693     prefix_idx = 0
4694     while string[prefix_idx] in STRING_PREFIX_CHARS:
4695         prefix += string[prefix_idx].lower()
4696         prefix_idx += 1
4697
4698     return prefix
4699
4700
4701 def assert_is_leaf_string(string: str) -> None:
4702     """
4703     Checks the pre-condition that @string has the format that you would expect
4704     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4705     token.STRING`. A more precise description of the pre-conditions that are
4706     checked are listed below.
4707
4708     Pre-conditions:
4709         * @string starts with either ', ", <prefix>', or <prefix>" where
4710         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4711         * @string ends with a quote character (' or ").
4712
4713     Raises:
4714         AssertionError(...) if the pre-conditions listed above are not
4715         satisfied.
4716     """
4717     dquote_idx = string.find('"')
4718     squote_idx = string.find("'")
4719     if -1 in [dquote_idx, squote_idx]:
4720         quote_idx = max(dquote_idx, squote_idx)
4721     else:
4722         quote_idx = min(squote_idx, dquote_idx)
4723
4724     assert (
4725         0 <= quote_idx < len(string) - 1
4726     ), f"{string!r} is missing a starting quote character (' or \")."
4727     assert string[-1] in (
4728         "'",
4729         '"',
4730     ), f"{string!r} is missing an ending quote character (' or \")."
4731     assert set(string[:quote_idx]).issubset(
4732         set(STRING_PREFIX_CHARS)
4733     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4734
4735
4736 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4737     """Split line into many lines, starting with the first matching bracket pair.
4738
4739     Note: this usually looks weird, only use this for function definitions.
4740     Prefer RHS otherwise.  This is why this function is not symmetrical with
4741     :func:`right_hand_split` which also handles optional parentheses.
4742     """
4743     tail_leaves: List[Leaf] = []
4744     body_leaves: List[Leaf] = []
4745     head_leaves: List[Leaf] = []
4746     current_leaves = head_leaves
4747     matching_bracket: Optional[Leaf] = None
4748     for leaf in line.leaves:
4749         if (
4750             current_leaves is body_leaves
4751             and leaf.type in CLOSING_BRACKETS
4752             and leaf.opening_bracket is matching_bracket
4753         ):
4754             current_leaves = tail_leaves if body_leaves else head_leaves
4755         current_leaves.append(leaf)
4756         if current_leaves is head_leaves:
4757             if leaf.type in OPENING_BRACKETS:
4758                 matching_bracket = leaf
4759                 current_leaves = body_leaves
4760     if not matching_bracket:
4761         raise CannotSplit("No brackets found")
4762
4763     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4764     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4765     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4766     bracket_split_succeeded_or_raise(head, body, tail)
4767     for result in (head, body, tail):
4768         if result:
4769             yield result
4770
4771
4772 def right_hand_split(
4773     line: Line,
4774     line_length: int,
4775     features: Collection[Feature] = (),
4776     omit: Collection[LeafID] = (),
4777 ) -> Iterator[Line]:
4778     """Split line into many lines, starting with the last matching bracket pair.
4779
4780     If the split was by optional parentheses, attempt splitting without them, too.
4781     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4782     this split.
4783
4784     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4785     """
4786     tail_leaves: List[Leaf] = []
4787     body_leaves: List[Leaf] = []
4788     head_leaves: List[Leaf] = []
4789     current_leaves = tail_leaves
4790     opening_bracket: Optional[Leaf] = None
4791     closing_bracket: Optional[Leaf] = None
4792     for leaf in reversed(line.leaves):
4793         if current_leaves is body_leaves:
4794             if leaf is opening_bracket:
4795                 current_leaves = head_leaves if body_leaves else tail_leaves
4796         current_leaves.append(leaf)
4797         if current_leaves is tail_leaves:
4798             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4799                 opening_bracket = leaf.opening_bracket
4800                 closing_bracket = leaf
4801                 current_leaves = body_leaves
4802     if not (opening_bracket and closing_bracket and head_leaves):
4803         # If there is no opening or closing_bracket that means the split failed and
4804         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4805         # the matching `opening_bracket` wasn't available on `line` anymore.
4806         raise CannotSplit("No brackets found")
4807
4808     tail_leaves.reverse()
4809     body_leaves.reverse()
4810     head_leaves.reverse()
4811     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4812     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4813     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4814     bracket_split_succeeded_or_raise(head, body, tail)
4815     if (
4816         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4817         # the opening bracket is an optional paren
4818         and opening_bracket.type == token.LPAR
4819         and not opening_bracket.value
4820         # the closing bracket is an optional paren
4821         and closing_bracket.type == token.RPAR
4822         and not closing_bracket.value
4823         # it's not an import (optional parens are the only thing we can split on
4824         # in this case; attempting a split without them is a waste of time)
4825         and not line.is_import
4826         # there are no standalone comments in the body
4827         and not body.contains_standalone_comments(0)
4828         # and we can actually remove the parens
4829         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4830     ):
4831         omit = {id(closing_bracket), *omit}
4832         try:
4833             yield from right_hand_split(line, line_length, features=features, omit=omit)
4834             return
4835
4836         except CannotSplit:
4837             if not (
4838                 can_be_split(body)
4839                 or is_line_short_enough(body, line_length=line_length)
4840             ):
4841                 raise CannotSplit(
4842                     "Splitting failed, body is still too long and can't be split."
4843                 )
4844
4845             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4846                 raise CannotSplit(
4847                     "The current optional pair of parentheses is bound to fail to"
4848                     " satisfy the splitting algorithm because the head or the tail"
4849                     " contains multiline strings which by definition never fit one"
4850                     " line."
4851                 )
4852
4853     ensure_visible(opening_bracket)
4854     ensure_visible(closing_bracket)
4855     for result in (head, body, tail):
4856         if result:
4857             yield result
4858
4859
4860 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4861     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4862
4863     Do nothing otherwise.
4864
4865     A left- or right-hand split is based on a pair of brackets. Content before
4866     (and including) the opening bracket is left on one line, content inside the
4867     brackets is put on a separate line, and finally content starting with and
4868     following the closing bracket is put on a separate line.
4869
4870     Those are called `head`, `body`, and `tail`, respectively. If the split
4871     produced the same line (all content in `head`) or ended up with an empty `body`
4872     and the `tail` is just the closing bracket, then it's considered failed.
4873     """
4874     tail_len = len(str(tail).strip())
4875     if not body:
4876         if tail_len == 0:
4877             raise CannotSplit("Splitting brackets produced the same line")
4878
4879         elif tail_len < 3:
4880             raise CannotSplit(
4881                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4882                 " not worth it"
4883             )
4884
4885
4886 def bracket_split_build_line(
4887     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4888 ) -> Line:
4889     """Return a new line with given `leaves` and respective comments from `original`.
4890
4891     If `is_body` is True, the result line is one-indented inside brackets and as such
4892     has its first leaf's prefix normalized and a trailing comma added when expected.
4893     """
4894     result = Line(depth=original.depth)
4895     if is_body:
4896         result.inside_brackets = True
4897         result.depth += 1
4898         if leaves:
4899             # Since body is a new indent level, remove spurious leading whitespace.
4900             normalize_prefix(leaves[0], inside_brackets=True)
4901             # Ensure a trailing comma for imports and standalone function arguments, but
4902             # be careful not to add one after any comments or within type annotations.
4903             no_commas = (
4904                 original.is_def
4905                 and opening_bracket.value == "("
4906                 and not any(leaf.type == token.COMMA for leaf in leaves)
4907             )
4908
4909             if original.is_import or no_commas:
4910                 for i in range(len(leaves) - 1, -1, -1):
4911                     if leaves[i].type == STANDALONE_COMMENT:
4912                         continue
4913
4914                     if leaves[i].type != token.COMMA:
4915                         new_comma = Leaf(token.COMMA, ",")
4916                         leaves.insert(i + 1, new_comma)
4917                     break
4918
4919     # Populate the line
4920     for leaf in leaves:
4921         result.append(leaf, preformatted=True)
4922         for comment_after in original.comments_after(leaf):
4923             result.append(comment_after, preformatted=True)
4924     if is_body and should_split_body_explode(result, opening_bracket):
4925         result.should_explode = True
4926     return result
4927
4928
4929 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4930     """Normalize prefix of the first leaf in every line returned by `split_func`.
4931
4932     This is a decorator over relevant split functions.
4933     """
4934
4935     @wraps(split_func)
4936     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4937         for line in split_func(line, features):
4938             normalize_prefix(line.leaves[0], inside_brackets=True)
4939             yield line
4940
4941     return split_wrapper
4942
4943
4944 @dont_increase_indentation
4945 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4946     """Split according to delimiters of the highest priority.
4947
4948     If the appropriate Features are given, the split will add trailing commas
4949     also in function signatures and calls that contain `*` and `**`.
4950     """
4951     try:
4952         last_leaf = line.leaves[-1]
4953     except IndexError:
4954         raise CannotSplit("Line empty")
4955
4956     bt = line.bracket_tracker
4957     try:
4958         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4959     except ValueError:
4960         raise CannotSplit("No delimiters found")
4961
4962     if delimiter_priority == DOT_PRIORITY:
4963         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4964             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4965
4966     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4967     lowest_depth = sys.maxsize
4968     trailing_comma_safe = True
4969
4970     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4971         """Append `leaf` to current line or to new line if appending impossible."""
4972         nonlocal current_line
4973         try:
4974             current_line.append_safe(leaf, preformatted=True)
4975         except ValueError:
4976             yield current_line
4977
4978             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4979             current_line.append(leaf)
4980
4981     for leaf in line.leaves:
4982         yield from append_to_line(leaf)
4983
4984         for comment_after in line.comments_after(leaf):
4985             yield from append_to_line(comment_after)
4986
4987         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4988         if leaf.bracket_depth == lowest_depth:
4989             if is_vararg(leaf, within={syms.typedargslist}):
4990                 trailing_comma_safe = (
4991                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4992                 )
4993             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4994                 trailing_comma_safe = (
4995                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4996                 )
4997
4998         leaf_priority = bt.delimiters.get(id(leaf))
4999         if leaf_priority == delimiter_priority:
5000             yield current_line
5001
5002             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5003     if current_line:
5004         if (
5005             trailing_comma_safe
5006             and delimiter_priority == COMMA_PRIORITY
5007             and current_line.leaves[-1].type != token.COMMA
5008             and current_line.leaves[-1].type != STANDALONE_COMMENT
5009         ):
5010             new_comma = Leaf(token.COMMA, ",")
5011             current_line.append(new_comma)
5012         yield current_line
5013
5014
5015 @dont_increase_indentation
5016 def standalone_comment_split(
5017     line: Line, features: Collection[Feature] = ()
5018 ) -> Iterator[Line]:
5019     """Split standalone comments from the rest of the line."""
5020     if not line.contains_standalone_comments(0):
5021         raise CannotSplit("Line does not have any standalone comments")
5022
5023     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5024
5025     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5026         """Append `leaf` to current line or to new line if appending impossible."""
5027         nonlocal current_line
5028         try:
5029             current_line.append_safe(leaf, preformatted=True)
5030         except ValueError:
5031             yield current_line
5032
5033             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5034             current_line.append(leaf)
5035
5036     for leaf in line.leaves:
5037         yield from append_to_line(leaf)
5038
5039         for comment_after in line.comments_after(leaf):
5040             yield from append_to_line(comment_after)
5041
5042     if current_line:
5043         yield current_line
5044
5045
5046 def is_import(leaf: Leaf) -> bool:
5047     """Return True if the given leaf starts an import statement."""
5048     p = leaf.parent
5049     t = leaf.type
5050     v = leaf.value
5051     return bool(
5052         t == token.NAME
5053         and (
5054             (v == "import" and p and p.type == syms.import_name)
5055             or (v == "from" and p and p.type == syms.import_from)
5056         )
5057     )
5058
5059
5060 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5061     """Return True if the given leaf is a special comment.
5062     Only returns true for type comments for now."""
5063     t = leaf.type
5064     v = leaf.value
5065     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5066
5067
5068 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5069     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5070     else.
5071
5072     Note: don't use backslashes for formatting or you'll lose your voting rights.
5073     """
5074     if not inside_brackets:
5075         spl = leaf.prefix.split("#")
5076         if "\\" not in spl[0]:
5077             nl_count = spl[-1].count("\n")
5078             if len(spl) > 1:
5079                 nl_count -= 1
5080             leaf.prefix = "\n" * nl_count
5081             return
5082
5083     leaf.prefix = ""
5084
5085
5086 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5087     """Make all string prefixes lowercase.
5088
5089     If remove_u_prefix is given, also removes any u prefix from the string.
5090
5091     Note: Mutates its argument.
5092     """
5093     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5094     assert match is not None, f"failed to match string {leaf.value!r}"
5095     orig_prefix = match.group(1)
5096     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5097     if remove_u_prefix:
5098         new_prefix = new_prefix.replace("u", "")
5099     leaf.value = f"{new_prefix}{match.group(2)}"
5100
5101
5102 def normalize_string_quotes(leaf: Leaf) -> None:
5103     """Prefer double quotes but only if it doesn't cause more escaping.
5104
5105     Adds or removes backslashes as appropriate. Doesn't parse and fix
5106     strings nested in f-strings (yet).
5107
5108     Note: Mutates its argument.
5109     """
5110     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5111     if value[:3] == '"""':
5112         return
5113
5114     elif value[:3] == "'''":
5115         orig_quote = "'''"
5116         new_quote = '"""'
5117     elif value[0] == '"':
5118         orig_quote = '"'
5119         new_quote = "'"
5120     else:
5121         orig_quote = "'"
5122         new_quote = '"'
5123     first_quote_pos = leaf.value.find(orig_quote)
5124     if first_quote_pos == -1:
5125         return  # There's an internal error
5126
5127     prefix = leaf.value[:first_quote_pos]
5128     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5129     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5130     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5131     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5132     if "r" in prefix.casefold():
5133         if unescaped_new_quote.search(body):
5134             # There's at least one unescaped new_quote in this raw string
5135             # so converting is impossible
5136             return
5137
5138         # Do not introduce or remove backslashes in raw strings
5139         new_body = body
5140     else:
5141         # remove unnecessary escapes
5142         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5143         if body != new_body:
5144             # Consider the string without unnecessary escapes as the original
5145             body = new_body
5146             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5147         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5148         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5149     if "f" in prefix.casefold():
5150         matches = re.findall(
5151             r"""
5152             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5153                 ([^{].*?)  # contents of the brackets except if begins with {{
5154             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5155             """,
5156             new_body,
5157             re.VERBOSE,
5158         )
5159         for m in matches:
5160             if "\\" in str(m):
5161                 # Do not introduce backslashes in interpolated expressions
5162                 return
5163
5164     if new_quote == '"""' and new_body[-1:] == '"':
5165         # edge case:
5166         new_body = new_body[:-1] + '\\"'
5167     orig_escape_count = body.count("\\")
5168     new_escape_count = new_body.count("\\")
5169     if new_escape_count > orig_escape_count:
5170         return  # Do not introduce more escaping
5171
5172     if new_escape_count == orig_escape_count and orig_quote == '"':
5173         return  # Prefer double quotes
5174
5175     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5176
5177
5178 def normalize_numeric_literal(leaf: Leaf) -> None:
5179     """Normalizes numeric (float, int, and complex) literals.
5180
5181     All letters used in the representation are normalized to lowercase (except
5182     in Python 2 long literals).
5183     """
5184     text = leaf.value.lower()
5185     if text.startswith(("0o", "0b")):
5186         # Leave octal and binary literals alone.
5187         pass
5188     elif text.startswith("0x"):
5189         # Change hex literals to upper case.
5190         before, after = text[:2], text[2:]
5191         text = f"{before}{after.upper()}"
5192     elif "e" in text:
5193         before, after = text.split("e")
5194         sign = ""
5195         if after.startswith("-"):
5196             after = after[1:]
5197             sign = "-"
5198         elif after.startswith("+"):
5199             after = after[1:]
5200         before = format_float_or_int_string(before)
5201         text = f"{before}e{sign}{after}"
5202     elif text.endswith(("j", "l")):
5203         number = text[:-1]
5204         suffix = text[-1]
5205         # Capitalize in "2L" because "l" looks too similar to "1".
5206         if suffix == "l":
5207             suffix = "L"
5208         text = f"{format_float_or_int_string(number)}{suffix}"
5209     else:
5210         text = format_float_or_int_string(text)
5211     leaf.value = text
5212
5213
5214 def format_float_or_int_string(text: str) -> str:
5215     """Formats a float string like "1.0"."""
5216     if "." not in text:
5217         return text
5218
5219     before, after = text.split(".")
5220     return f"{before or 0}.{after or 0}"
5221
5222
5223 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5224     """Make existing optional parentheses invisible or create new ones.
5225
5226     `parens_after` is a set of string leaf values immediately after which parens
5227     should be put.
5228
5229     Standardizes on visible parentheses for single-element tuples, and keeps
5230     existing visible parentheses for other tuples and generator expressions.
5231     """
5232     for pc in list_comments(node.prefix, is_endmarker=False):
5233         if pc.value in FMT_OFF:
5234             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5235             return
5236     check_lpar = False
5237     for index, child in enumerate(list(node.children)):
5238         # Fixes a bug where invisible parens are not properly stripped from
5239         # assignment statements that contain type annotations.
5240         if isinstance(child, Node) and child.type == syms.annassign:
5241             normalize_invisible_parens(child, parens_after=parens_after)
5242
5243         # Add parentheses around long tuple unpacking in assignments.
5244         if (
5245             index == 0
5246             and isinstance(child, Node)
5247             and child.type == syms.testlist_star_expr
5248         ):
5249             check_lpar = True
5250
5251         if check_lpar:
5252             if is_walrus_assignment(child):
5253                 pass
5254
5255             elif child.type == syms.atom:
5256                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5257                     wrap_in_parentheses(node, child, visible=False)
5258             elif is_one_tuple(child):
5259                 wrap_in_parentheses(node, child, visible=True)
5260             elif node.type == syms.import_from:
5261                 # "import from" nodes store parentheses directly as part of
5262                 # the statement
5263                 if child.type == token.LPAR:
5264                     # make parentheses invisible
5265                     child.value = ""  # type: ignore
5266                     node.children[-1].value = ""  # type: ignore
5267                 elif child.type != token.STAR:
5268                     # insert invisible parentheses
5269                     node.insert_child(index, Leaf(token.LPAR, ""))
5270                     node.append_child(Leaf(token.RPAR, ""))
5271                 break
5272
5273             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5274                 wrap_in_parentheses(node, child, visible=False)
5275
5276         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5277
5278
5279 def normalize_fmt_off(node: Node) -> None:
5280     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5281     try_again = True
5282     while try_again:
5283         try_again = convert_one_fmt_off_pair(node)
5284
5285
5286 def convert_one_fmt_off_pair(node: Node) -> bool:
5287     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5288
5289     Returns True if a pair was converted.
5290     """
5291     for leaf in node.leaves():
5292         previous_consumed = 0
5293         for comment in list_comments(leaf.prefix, is_endmarker=False):
5294             if comment.value in FMT_OFF:
5295                 # We only want standalone comments. If there's no previous leaf or
5296                 # the previous leaf is indentation, it's a standalone comment in
5297                 # disguise.
5298                 if comment.type != STANDALONE_COMMENT:
5299                     prev = preceding_leaf(leaf)
5300                     if prev and prev.type not in WHITESPACE:
5301                         continue
5302
5303                 ignored_nodes = list(generate_ignored_nodes(leaf))
5304                 if not ignored_nodes:
5305                     continue
5306
5307                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5308                 parent = first.parent
5309                 prefix = first.prefix
5310                 first.prefix = prefix[comment.consumed :]
5311                 hidden_value = (
5312                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5313                 )
5314                 if hidden_value.endswith("\n"):
5315                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5316                     # leaf (possibly followed by a DEDENT).
5317                     hidden_value = hidden_value[:-1]
5318                 first_idx: Optional[int] = None
5319                 for ignored in ignored_nodes:
5320                     index = ignored.remove()
5321                     if first_idx is None:
5322                         first_idx = index
5323                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5324                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5325                 parent.insert_child(
5326                     first_idx,
5327                     Leaf(
5328                         STANDALONE_COMMENT,
5329                         hidden_value,
5330                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5331                     ),
5332                 )
5333                 return True
5334
5335             previous_consumed = comment.consumed
5336
5337     return False
5338
5339
5340 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5341     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5342
5343     Stops at the end of the block.
5344     """
5345     container: Optional[LN] = container_of(leaf)
5346     while container is not None and container.type != token.ENDMARKER:
5347         if is_fmt_on(container):
5348             return
5349
5350         # fix for fmt: on in children
5351         if contains_fmt_on_at_column(container, leaf.column):
5352             for child in container.children:
5353                 if contains_fmt_on_at_column(child, leaf.column):
5354                     return
5355                 yield child
5356         else:
5357             yield container
5358             container = container.next_sibling
5359
5360
5361 def is_fmt_on(container: LN) -> bool:
5362     """Determine whether formatting is switched on within a container.
5363     Determined by whether the last `# fmt:` comment is `on` or `off`.
5364     """
5365     fmt_on = False
5366     for comment in list_comments(container.prefix, is_endmarker=False):
5367         if comment.value in FMT_ON:
5368             fmt_on = True
5369         elif comment.value in FMT_OFF:
5370             fmt_on = False
5371     return fmt_on
5372
5373
5374 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5375     """Determine if children at a given column have formatting switched on."""
5376     for child in container.children:
5377         if (
5378             isinstance(child, Node)
5379             and first_leaf_column(child) == column
5380             or isinstance(child, Leaf)
5381             and child.column == column
5382         ):
5383             if is_fmt_on(child):
5384                 return True
5385
5386     return False
5387
5388
5389 def first_leaf_column(node: Node) -> Optional[int]:
5390     """Returns the column of the first leaf child of a node."""
5391     for child in node.children:
5392         if isinstance(child, Leaf):
5393             return child.column
5394     return None
5395
5396
5397 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5398     """If it's safe, make the parens in the atom `node` invisible, recursively.
5399     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5400     as they are redundant.
5401
5402     Returns whether the node should itself be wrapped in invisible parentheses.
5403
5404     """
5405     if (
5406         node.type != syms.atom
5407         or is_empty_tuple(node)
5408         or is_one_tuple(node)
5409         or (is_yield(node) and parent.type != syms.expr_stmt)
5410         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5411     ):
5412         return False
5413
5414     first = node.children[0]
5415     last = node.children[-1]
5416     if first.type == token.LPAR and last.type == token.RPAR:
5417         middle = node.children[1]
5418         # make parentheses invisible
5419         first.value = ""  # type: ignore
5420         last.value = ""  # type: ignore
5421         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5422
5423         if is_atom_with_invisible_parens(middle):
5424             # Strip the invisible parens from `middle` by replacing
5425             # it with the child in-between the invisible parens
5426             middle.replace(middle.children[1])
5427
5428         return False
5429
5430     return True
5431
5432
5433 def is_atom_with_invisible_parens(node: LN) -> bool:
5434     """Given a `LN`, determines whether it's an atom `node` with invisible
5435     parens. Useful in dedupe-ing and normalizing parens.
5436     """
5437     if isinstance(node, Leaf) or node.type != syms.atom:
5438         return False
5439
5440     first, last = node.children[0], node.children[-1]
5441     return (
5442         isinstance(first, Leaf)
5443         and first.type == token.LPAR
5444         and first.value == ""
5445         and isinstance(last, Leaf)
5446         and last.type == token.RPAR
5447         and last.value == ""
5448     )
5449
5450
5451 def is_empty_tuple(node: LN) -> bool:
5452     """Return True if `node` holds an empty tuple."""
5453     return (
5454         node.type == syms.atom
5455         and len(node.children) == 2
5456         and node.children[0].type == token.LPAR
5457         and node.children[1].type == token.RPAR
5458     )
5459
5460
5461 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5462     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5463
5464     Parenthesis can be optional. Returns None otherwise"""
5465     if len(node.children) != 3:
5466         return None
5467
5468     lpar, wrapped, rpar = node.children
5469     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5470         return None
5471
5472     return wrapped
5473
5474
5475 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5476     """Wrap `child` in parentheses.
5477
5478     This replaces `child` with an atom holding the parentheses and the old
5479     child.  That requires moving the prefix.
5480
5481     If `visible` is False, the leaves will be valueless (and thus invisible).
5482     """
5483     lpar = Leaf(token.LPAR, "(" if visible else "")
5484     rpar = Leaf(token.RPAR, ")" if visible else "")
5485     prefix = child.prefix
5486     child.prefix = ""
5487     index = child.remove() or 0
5488     new_child = Node(syms.atom, [lpar, child, rpar])
5489     new_child.prefix = prefix
5490     parent.insert_child(index, new_child)
5491
5492
5493 def is_one_tuple(node: LN) -> bool:
5494     """Return True if `node` holds a tuple with one element, with or without parens."""
5495     if node.type == syms.atom:
5496         gexp = unwrap_singleton_parenthesis(node)
5497         if gexp is None or gexp.type != syms.testlist_gexp:
5498             return False
5499
5500         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5501
5502     return (
5503         node.type in IMPLICIT_TUPLE
5504         and len(node.children) == 2
5505         and node.children[1].type == token.COMMA
5506     )
5507
5508
5509 def is_walrus_assignment(node: LN) -> bool:
5510     """Return True iff `node` is of the shape ( test := test )"""
5511     inner = unwrap_singleton_parenthesis(node)
5512     return inner is not None and inner.type == syms.namedexpr_test
5513
5514
5515 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5516     """Return True iff `node` is a trailer valid in a simple decorator"""
5517     return node.type == syms.trailer and (
5518         (
5519             len(node.children) == 2
5520             and node.children[0].type == token.DOT
5521             and node.children[1].type == token.NAME
5522         )
5523         # last trailer can be arguments
5524         or (
5525             last
5526             and len(node.children) == 3
5527             and node.children[0].type == token.LPAR
5528             # and node.children[1].type == syms.argument
5529             and node.children[2].type == token.RPAR
5530         )
5531     )
5532
5533
5534 def is_simple_decorator_expression(node: LN) -> bool:
5535     """Return True iff `node` could be a 'dotted name' decorator
5536
5537     This function takes the node of the 'namedexpr_test' of the new decorator
5538     grammar and test if it would be valid under the old decorator grammar.
5539
5540     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5541     The new grammar is : decorator: @ namedexpr_test NEWLINE
5542     """
5543     if node.type == token.NAME:
5544         return True
5545     if node.type == syms.power:
5546         if node.children:
5547             return (
5548                 node.children[0].type == token.NAME
5549                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5550                 and (
5551                     len(node.children) < 2
5552                     or is_simple_decorator_trailer(node.children[-1], last=True)
5553                 )
5554             )
5555     return False
5556
5557
5558 def is_yield(node: LN) -> bool:
5559     """Return True if `node` holds a `yield` or `yield from` expression."""
5560     if node.type == syms.yield_expr:
5561         return True
5562
5563     if node.type == token.NAME and node.value == "yield":  # type: ignore
5564         return True
5565
5566     if node.type != syms.atom:
5567         return False
5568
5569     if len(node.children) != 3:
5570         return False
5571
5572     lpar, expr, rpar = node.children
5573     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5574         return is_yield(expr)
5575
5576     return False
5577
5578
5579 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5580     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5581
5582     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5583     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5584     extended iterable unpacking (PEP 3132) and additional unpacking
5585     generalizations (PEP 448).
5586     """
5587     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5588         return False
5589
5590     p = leaf.parent
5591     if p.type == syms.star_expr:
5592         # Star expressions are also used as assignment targets in extended
5593         # iterable unpacking (PEP 3132).  See what its parent is instead.
5594         if not p.parent:
5595             return False
5596
5597         p = p.parent
5598
5599     return p.type in within
5600
5601
5602 def is_multiline_string(leaf: Leaf) -> bool:
5603     """Return True if `leaf` is a multiline string that actually spans many lines."""
5604     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5605
5606
5607 def is_stub_suite(node: Node) -> bool:
5608     """Return True if `node` is a suite with a stub body."""
5609     if (
5610         len(node.children) != 4
5611         or node.children[0].type != token.NEWLINE
5612         or node.children[1].type != token.INDENT
5613         or node.children[3].type != token.DEDENT
5614     ):
5615         return False
5616
5617     return is_stub_body(node.children[2])
5618
5619
5620 def is_stub_body(node: LN) -> bool:
5621     """Return True if `node` is a simple statement containing an ellipsis."""
5622     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5623         return False
5624
5625     if len(node.children) != 2:
5626         return False
5627
5628     child = node.children[0]
5629     return (
5630         child.type == syms.atom
5631         and len(child.children) == 3
5632         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5633     )
5634
5635
5636 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5637     """Return maximum delimiter priority inside `node`.
5638
5639     This is specific to atoms with contents contained in a pair of parentheses.
5640     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5641     """
5642     if node.type != syms.atom:
5643         return 0
5644
5645     first = node.children[0]
5646     last = node.children[-1]
5647     if not (first.type == token.LPAR and last.type == token.RPAR):
5648         return 0
5649
5650     bt = BracketTracker()
5651     for c in node.children[1:-1]:
5652         if isinstance(c, Leaf):
5653             bt.mark(c)
5654         else:
5655             for leaf in c.leaves():
5656                 bt.mark(leaf)
5657     try:
5658         return bt.max_delimiter_priority()
5659
5660     except ValueError:
5661         return 0
5662
5663
5664 def ensure_visible(leaf: Leaf) -> None:
5665     """Make sure parentheses are visible.
5666
5667     They could be invisible as part of some statements (see
5668     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5669     """
5670     if leaf.type == token.LPAR:
5671         leaf.value = "("
5672     elif leaf.type == token.RPAR:
5673         leaf.value = ")"
5674
5675
5676 def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
5677     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5678
5679     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5680         return False
5681
5682     # We're essentially checking if the body is delimited by commas and there's more
5683     # than one of them (we're excluding the trailing comma and if the delimiter priority
5684     # is still commas, that means there's more).
5685     exclude = set()
5686     trailing_comma = False
5687     try:
5688         last_leaf = line.leaves[-1]
5689         if last_leaf.type == token.COMMA:
5690             trailing_comma = True
5691             exclude.add(id(last_leaf))
5692         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5693     except (IndexError, ValueError):
5694         return False
5695
5696     return max_priority == COMMA_PRIORITY and (
5697         trailing_comma
5698         # always explode imports
5699         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5700     )
5701
5702
5703 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5704     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5705     if opening.type != token.LPAR and closing.type != token.RPAR:
5706         return False
5707
5708     depth = closing.bracket_depth + 1
5709     for _opening_index, leaf in enumerate(leaves):
5710         if leaf is opening:
5711             break
5712
5713     else:
5714         raise LookupError("Opening paren not found in `leaves`")
5715
5716     commas = 0
5717     _opening_index += 1
5718     for leaf in leaves[_opening_index:]:
5719         if leaf is closing:
5720             break
5721
5722         bracket_depth = leaf.bracket_depth
5723         if bracket_depth == depth and leaf.type == token.COMMA:
5724             commas += 1
5725             if leaf.parent and leaf.parent.type in {
5726                 syms.arglist,
5727                 syms.typedargslist,
5728             }:
5729                 commas += 1
5730                 break
5731
5732     return commas < 2
5733
5734
5735 def get_features_used(node: Node) -> Set[Feature]:
5736     """Return a set of (relatively) new Python features used in this file.
5737
5738     Currently looking for:
5739     - f-strings;
5740     - underscores in numeric literals;
5741     - trailing commas after * or ** in function signatures and calls;
5742     - positional only arguments in function signatures and lambdas;
5743     - assignment expression;
5744     - relaxed decorator syntax;
5745     """
5746     features: Set[Feature] = set()
5747     for n in node.pre_order():
5748         if n.type == token.STRING:
5749             value_head = n.value[:2]  # type: ignore
5750             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5751                 features.add(Feature.F_STRINGS)
5752
5753         elif n.type == token.NUMBER:
5754             if "_" in n.value:  # type: ignore
5755                 features.add(Feature.NUMERIC_UNDERSCORES)
5756
5757         elif n.type == token.SLASH:
5758             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5759                 features.add(Feature.POS_ONLY_ARGUMENTS)
5760
5761         elif n.type == token.COLONEQUAL:
5762             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5763
5764         elif n.type == syms.decorator:
5765             if len(n.children) > 1 and not is_simple_decorator_expression(
5766                 n.children[1]
5767             ):
5768                 features.add(Feature.RELAXED_DECORATORS)
5769
5770         elif (
5771             n.type in {syms.typedargslist, syms.arglist}
5772             and n.children
5773             and n.children[-1].type == token.COMMA
5774         ):
5775             if n.type == syms.typedargslist:
5776                 feature = Feature.TRAILING_COMMA_IN_DEF
5777             else:
5778                 feature = Feature.TRAILING_COMMA_IN_CALL
5779
5780             for ch in n.children:
5781                 if ch.type in STARS:
5782                     features.add(feature)
5783
5784                 if ch.type == syms.argument:
5785                     for argch in ch.children:
5786                         if argch.type in STARS:
5787                             features.add(feature)
5788
5789     return features
5790
5791
5792 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5793     """Detect the version to target based on the nodes used."""
5794     features = get_features_used(node)
5795     return {
5796         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5797     }
5798
5799
5800 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5801     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5802
5803     Brackets can be omitted if the entire trailer up to and including
5804     a preceding closing bracket fits in one line.
5805
5806     Yielded sets are cumulative (contain results of previous yields, too).  First
5807     set is empty, unless the line should explode, in which case bracket pairs until
5808     the one that needs to explode are omitted.
5809     """
5810
5811     omit: Set[LeafID] = set()
5812     if not line.should_explode:
5813         yield omit
5814
5815     length = 4 * line.depth
5816     opening_bracket: Optional[Leaf] = None
5817     closing_bracket: Optional[Leaf] = None
5818     inner_brackets: Set[LeafID] = set()
5819     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5820         length += leaf_length
5821         if length > line_length:
5822             break
5823
5824         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5825         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5826             break
5827
5828         if opening_bracket:
5829             if leaf is opening_bracket:
5830                 opening_bracket = None
5831             elif leaf.type in CLOSING_BRACKETS:
5832                 prev = line.leaves[index - 1] if index > 0 else None
5833                 if (
5834                     line.should_explode
5835                     and prev
5836                     and prev.type == token.COMMA
5837                     and not is_one_tuple_between(
5838                         leaf.opening_bracket, leaf, line.leaves
5839                     )
5840                 ):
5841                     # Never omit bracket pairs with trailing commas.
5842                     # We need to explode on those.
5843                     break
5844
5845                 inner_brackets.add(id(leaf))
5846         elif leaf.type in CLOSING_BRACKETS:
5847             prev = line.leaves[index - 1] if index > 0 else None
5848             if prev and prev.type in OPENING_BRACKETS:
5849                 # Empty brackets would fail a split so treat them as "inner"
5850                 # brackets (e.g. only add them to the `omit` set if another
5851                 # pair of brackets was good enough.
5852                 inner_brackets.add(id(leaf))
5853                 continue
5854
5855             if closing_bracket:
5856                 omit.add(id(closing_bracket))
5857                 omit.update(inner_brackets)
5858                 inner_brackets.clear()
5859                 yield omit
5860
5861             if (
5862                 line.should_explode
5863                 and prev
5864                 and prev.type == token.COMMA
5865                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
5866             ):
5867                 # Never omit bracket pairs with trailing commas.
5868                 # We need to explode on those.
5869                 break
5870
5871             if leaf.value:
5872                 opening_bracket = leaf.opening_bracket
5873                 closing_bracket = leaf
5874
5875
5876 def get_future_imports(node: Node) -> Set[str]:
5877     """Return a set of __future__ imports in the file."""
5878     imports: Set[str] = set()
5879
5880     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5881         for child in children:
5882             if isinstance(child, Leaf):
5883                 if child.type == token.NAME:
5884                     yield child.value
5885
5886             elif child.type == syms.import_as_name:
5887                 orig_name = child.children[0]
5888                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5889                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5890                 yield orig_name.value
5891
5892             elif child.type == syms.import_as_names:
5893                 yield from get_imports_from_children(child.children)
5894
5895             else:
5896                 raise AssertionError("Invalid syntax parsing imports")
5897
5898     for child in node.children:
5899         if child.type != syms.simple_stmt:
5900             break
5901
5902         first_child = child.children[0]
5903         if isinstance(first_child, Leaf):
5904             # Continue looking if we see a docstring; otherwise stop.
5905             if (
5906                 len(child.children) == 2
5907                 and first_child.type == token.STRING
5908                 and child.children[1].type == token.NEWLINE
5909             ):
5910                 continue
5911
5912             break
5913
5914         elif first_child.type == syms.import_from:
5915             module_name = first_child.children[1]
5916             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5917                 break
5918
5919             imports |= set(get_imports_from_children(first_child.children[3:]))
5920         else:
5921             break
5922
5923     return imports
5924
5925
5926 @lru_cache()
5927 def get_gitignore(root: Path) -> PathSpec:
5928     """ Return a PathSpec matching gitignore content if present."""
5929     gitignore = root / ".gitignore"
5930     lines: List[str] = []
5931     if gitignore.is_file():
5932         with gitignore.open() as gf:
5933             lines = gf.readlines()
5934     return PathSpec.from_lines("gitwildmatch", lines)
5935
5936
5937 def normalize_path_maybe_ignore(
5938     path: Path, root: Path, report: "Report"
5939 ) -> Optional[str]:
5940     """Normalize `path`. May return `None` if `path` was ignored.
5941
5942     `report` is where "path ignored" output goes.
5943     """
5944     try:
5945         abspath = path if path.is_absolute() else Path.cwd() / path
5946         normalized_path = abspath.resolve().relative_to(root).as_posix()
5947     except OSError as e:
5948         report.path_ignored(path, f"cannot be read because {e}")
5949         return None
5950
5951     except ValueError:
5952         if path.is_symlink():
5953             report.path_ignored(path, f"is a symbolic link that points outside {root}")
5954             return None
5955
5956         raise
5957
5958     return normalized_path
5959
5960
5961 def gen_python_files(
5962     paths: Iterable[Path],
5963     root: Path,
5964     include: Optional[Pattern[str]],
5965     exclude: Pattern[str],
5966     force_exclude: Optional[Pattern[str]],
5967     report: "Report",
5968     gitignore: PathSpec,
5969 ) -> Iterator[Path]:
5970     """Generate all files under `path` whose paths are not excluded by the
5971     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
5972
5973     Symbolic links pointing outside of the `root` directory are ignored.
5974
5975     `report` is where output about exclusions goes.
5976     """
5977     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5978     for child in paths:
5979         normalized_path = normalize_path_maybe_ignore(child, root, report)
5980         if normalized_path is None:
5981             continue
5982
5983         # First ignore files matching .gitignore
5984         if gitignore.match_file(normalized_path):
5985             report.path_ignored(child, "matches the .gitignore file content")
5986             continue
5987
5988         # Then ignore with `--exclude` and `--force-exclude` options.
5989         normalized_path = "/" + normalized_path
5990         if child.is_dir():
5991             normalized_path += "/"
5992
5993         exclude_match = exclude.search(normalized_path) if exclude else None
5994         if exclude_match and exclude_match.group(0):
5995             report.path_ignored(child, "matches the --exclude regular expression")
5996             continue
5997
5998         force_exclude_match = (
5999             force_exclude.search(normalized_path) if force_exclude else None
6000         )
6001         if force_exclude_match and force_exclude_match.group(0):
6002             report.path_ignored(child, "matches the --force-exclude regular expression")
6003             continue
6004
6005         if child.is_dir():
6006             yield from gen_python_files(
6007                 child.iterdir(),
6008                 root,
6009                 include,
6010                 exclude,
6011                 force_exclude,
6012                 report,
6013                 gitignore,
6014             )
6015
6016         elif child.is_file():
6017             include_match = include.search(normalized_path) if include else True
6018             if include_match:
6019                 yield child
6020
6021
6022 @lru_cache()
6023 def find_project_root(srcs: Iterable[str]) -> Path:
6024     """Return a directory containing .git, .hg, or pyproject.toml.
6025
6026     That directory will be a common parent of all files and directories
6027     passed in `srcs`.
6028
6029     If no directory in the tree contains a marker that would specify it's the
6030     project root, the root of the file system is returned.
6031     """
6032     if not srcs:
6033         return Path("/").resolve()
6034
6035     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6036
6037     # A list of lists of parents for each 'src'. 'src' is included as a
6038     # "parent" of itself if it is a directory
6039     src_parents = [
6040         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6041     ]
6042
6043     common_base = max(
6044         set.intersection(*(set(parents) for parents in src_parents)),
6045         key=lambda path: path.parts,
6046     )
6047
6048     for directory in (common_base, *common_base.parents):
6049         if (directory / ".git").exists():
6050             return directory
6051
6052         if (directory / ".hg").is_dir():
6053             return directory
6054
6055         if (directory / "pyproject.toml").is_file():
6056             return directory
6057
6058     return directory
6059
6060
6061 @dataclass
6062 class Report:
6063     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6064
6065     check: bool = False
6066     diff: bool = False
6067     quiet: bool = False
6068     verbose: bool = False
6069     change_count: int = 0
6070     same_count: int = 0
6071     failure_count: int = 0
6072
6073     def done(self, src: Path, changed: Changed) -> None:
6074         """Increment the counter for successful reformatting. Write out a message."""
6075         if changed is Changed.YES:
6076             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6077             if self.verbose or not self.quiet:
6078                 out(f"{reformatted} {src}")
6079             self.change_count += 1
6080         else:
6081             if self.verbose:
6082                 if changed is Changed.NO:
6083                     msg = f"{src} already well formatted, good job."
6084                 else:
6085                     msg = f"{src} wasn't modified on disk since last run."
6086                 out(msg, bold=False)
6087             self.same_count += 1
6088
6089     def failed(self, src: Path, message: str) -> None:
6090         """Increment the counter for failed reformatting. Write out a message."""
6091         err(f"error: cannot format {src}: {message}")
6092         self.failure_count += 1
6093
6094     def path_ignored(self, path: Path, message: str) -> None:
6095         if self.verbose:
6096             out(f"{path} ignored: {message}", bold=False)
6097
6098     @property
6099     def return_code(self) -> int:
6100         """Return the exit code that the app should use.
6101
6102         This considers the current state of changed files and failures:
6103         - if there were any failures, return 123;
6104         - if any files were changed and --check is being used, return 1;
6105         - otherwise return 0.
6106         """
6107         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6108         # 126 we have special return codes reserved by the shell.
6109         if self.failure_count:
6110             return 123
6111
6112         elif self.change_count and self.check:
6113             return 1
6114
6115         return 0
6116
6117     def __str__(self) -> str:
6118         """Render a color report of the current state.
6119
6120         Use `click.unstyle` to remove colors.
6121         """
6122         if self.check or self.diff:
6123             reformatted = "would be reformatted"
6124             unchanged = "would be left unchanged"
6125             failed = "would fail to reformat"
6126         else:
6127             reformatted = "reformatted"
6128             unchanged = "left unchanged"
6129             failed = "failed to reformat"
6130         report = []
6131         if self.change_count:
6132             s = "s" if self.change_count > 1 else ""
6133             report.append(
6134                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6135             )
6136         if self.same_count:
6137             s = "s" if self.same_count > 1 else ""
6138             report.append(f"{self.same_count} file{s} {unchanged}")
6139         if self.failure_count:
6140             s = "s" if self.failure_count > 1 else ""
6141             report.append(
6142                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6143             )
6144         return ", ".join(report) + "."
6145
6146
6147 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6148     filename = "<unknown>"
6149     if sys.version_info >= (3, 8):
6150         # TODO: support Python 4+ ;)
6151         for minor_version in range(sys.version_info[1], 4, -1):
6152             try:
6153                 return ast.parse(src, filename, feature_version=(3, minor_version))
6154             except SyntaxError:
6155                 continue
6156     else:
6157         for feature_version in (7, 6):
6158             try:
6159                 return ast3.parse(src, filename, feature_version=feature_version)
6160             except SyntaxError:
6161                 continue
6162
6163     return ast27.parse(src)
6164
6165
6166 def _fixup_ast_constants(
6167     node: Union[ast.AST, ast3.AST, ast27.AST]
6168 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6169     """Map ast nodes deprecated in 3.8 to Constant."""
6170     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6171         return ast.Constant(value=node.s)
6172
6173     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6174         return ast.Constant(value=node.n)
6175
6176     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6177         return ast.Constant(value=node.value)
6178
6179     return node
6180
6181
6182 def _stringify_ast(
6183     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6184 ) -> Iterator[str]:
6185     """Simple visitor generating strings to compare ASTs by content."""
6186
6187     node = _fixup_ast_constants(node)
6188
6189     yield f"{'  ' * depth}{node.__class__.__name__}("
6190
6191     for field in sorted(node._fields):  # noqa: F402
6192         # TypeIgnore has only one field 'lineno' which breaks this comparison
6193         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6194         if sys.version_info >= (3, 8):
6195             type_ignore_classes += (ast.TypeIgnore,)
6196         if isinstance(node, type_ignore_classes):
6197             break
6198
6199         try:
6200             value = getattr(node, field)
6201         except AttributeError:
6202             continue
6203
6204         yield f"{'  ' * (depth+1)}{field}="
6205
6206         if isinstance(value, list):
6207             for item in value:
6208                 # Ignore nested tuples within del statements, because we may insert
6209                 # parentheses and they change the AST.
6210                 if (
6211                     field == "targets"
6212                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6213                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6214                 ):
6215                     for item in item.elts:
6216                         yield from _stringify_ast(item, depth + 2)
6217
6218                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6219                     yield from _stringify_ast(item, depth + 2)
6220
6221         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6222             yield from _stringify_ast(value, depth + 2)
6223
6224         else:
6225             # Constant strings may be indented across newlines, if they are
6226             # docstrings; fold spaces after newlines when comparing. Similarly,
6227             # trailing and leading space may be removed.
6228             if (
6229                 isinstance(node, ast.Constant)
6230                 and field == "value"
6231                 and isinstance(value, str)
6232             ):
6233                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6234             else:
6235                 normalized = value
6236             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6237
6238     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6239
6240
6241 def assert_equivalent(src: str, dst: str) -> None:
6242     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6243     try:
6244         src_ast = parse_ast(src)
6245     except Exception as exc:
6246         raise AssertionError(
6247             "cannot use --safe with this file; failed to parse source file.  AST"
6248             f" error message: {exc}"
6249         )
6250
6251     try:
6252         dst_ast = parse_ast(dst)
6253     except Exception as exc:
6254         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6255         raise AssertionError(
6256             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6257             " on https://github.com/psf/black/issues.  This invalid output might be"
6258             f" helpful: {log}"
6259         ) from None
6260
6261     src_ast_str = "\n".join(_stringify_ast(src_ast))
6262     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6263     if src_ast_str != dst_ast_str:
6264         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6265         raise AssertionError(
6266             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6267             " source.  Please report a bug on https://github.com/psf/black/issues. "
6268             f" This diff might be helpful: {log}"
6269         ) from None
6270
6271
6272 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6273     """Raise AssertionError if `dst` reformats differently the second time."""
6274     newdst = format_str(dst, mode=mode)
6275     if dst != newdst:
6276         log = dump_to_file(
6277             str(mode),
6278             diff(src, dst, "source", "first pass"),
6279             diff(dst, newdst, "first pass", "second pass"),
6280         )
6281         raise AssertionError(
6282             "INTERNAL ERROR: Black produced different code on the second pass of the"
6283             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6284             f"  This diff might be helpful: {log}"
6285         ) from None
6286
6287
6288 @mypyc_attr(patchable=True)
6289 def dump_to_file(*output: str) -> str:
6290     """Dump `output` to a temporary file. Return path to the file."""
6291     with tempfile.NamedTemporaryFile(
6292         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6293     ) as f:
6294         for lines in output:
6295             f.write(lines)
6296             if lines and lines[-1] != "\n":
6297                 f.write("\n")
6298     return f.name
6299
6300
6301 @contextmanager
6302 def nullcontext() -> Iterator[None]:
6303     """Return an empty context manager.
6304
6305     To be used like `nullcontext` in Python 3.7.
6306     """
6307     yield
6308
6309
6310 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6311     """Return a unified diff string between strings `a` and `b`."""
6312     import difflib
6313
6314     a_lines = [line + "\n" for line in a.splitlines()]
6315     b_lines = [line + "\n" for line in b.splitlines()]
6316     return "".join(
6317         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6318     )
6319
6320
6321 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6322     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6323     err("Aborted!")
6324     for task in tasks:
6325         task.cancel()
6326
6327
6328 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6329     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6330     try:
6331         if sys.version_info[:2] >= (3, 7):
6332             all_tasks = asyncio.all_tasks
6333         else:
6334             all_tasks = asyncio.Task.all_tasks
6335         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6336         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6337         if not to_cancel:
6338             return
6339
6340         for task in to_cancel:
6341             task.cancel()
6342         loop.run_until_complete(
6343             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6344         )
6345     finally:
6346         # `concurrent.futures.Future` objects cannot be cancelled once they
6347         # are already running. There might be some when the `shutdown()` happened.
6348         # Silence their logger's spew about the event loop being closed.
6349         cf_logger = logging.getLogger("concurrent.futures")
6350         cf_logger.setLevel(logging.CRITICAL)
6351         loop.close()
6352
6353
6354 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6355     """Replace `regex` with `replacement` twice on `original`.
6356
6357     This is used by string normalization to perform replaces on
6358     overlapping matches.
6359     """
6360     return regex.sub(replacement, regex.sub(replacement, original))
6361
6362
6363 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6364     """Compile a regular expression string in `regex`.
6365
6366     If it contains newlines, use verbose mode.
6367     """
6368     if "\n" in regex:
6369         regex = "(?x)" + regex
6370     compiled: Pattern[str] = re.compile(regex)
6371     return compiled
6372
6373
6374 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6375     """Like `reversed(enumerate(sequence))` if that were possible."""
6376     index = len(sequence) - 1
6377     for element in reversed(sequence):
6378         yield (index, element)
6379         index -= 1
6380
6381
6382 def enumerate_with_length(
6383     line: Line, reversed: bool = False
6384 ) -> Iterator[Tuple[Index, Leaf, int]]:
6385     """Return an enumeration of leaves with their length.
6386
6387     Stops prematurely on multiline strings and standalone comments.
6388     """
6389     op = cast(
6390         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6391         enumerate_reversed if reversed else enumerate,
6392     )
6393     for index, leaf in op(line.leaves):
6394         length = len(leaf.prefix) + len(leaf.value)
6395         if "\n" in leaf.value:
6396             return  # Multiline strings, we can't continue.
6397
6398         for comment in line.comments_after(leaf):
6399             length += len(comment.value)
6400
6401         yield index, leaf, length
6402
6403
6404 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6405     """Return True if `line` is no longer than `line_length`.
6406
6407     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6408     """
6409     if not line_str:
6410         line_str = line_to_string(line)
6411     return (
6412         len(line_str) <= line_length
6413         and "\n" not in line_str  # multiline strings
6414         and not line.contains_standalone_comments()
6415     )
6416
6417
6418 def can_be_split(line: Line) -> bool:
6419     """Return False if the line cannot be split *for sure*.
6420
6421     This is not an exhaustive search but a cheap heuristic that we can use to
6422     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6423     in unnecessary parentheses).
6424     """
6425     leaves = line.leaves
6426     if len(leaves) < 2:
6427         return False
6428
6429     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6430         call_count = 0
6431         dot_count = 0
6432         next = leaves[-1]
6433         for leaf in leaves[-2::-1]:
6434             if leaf.type in OPENING_BRACKETS:
6435                 if next.type not in CLOSING_BRACKETS:
6436                     return False
6437
6438                 call_count += 1
6439             elif leaf.type == token.DOT:
6440                 dot_count += 1
6441             elif leaf.type == token.NAME:
6442                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6443                     return False
6444
6445             elif leaf.type not in CLOSING_BRACKETS:
6446                 return False
6447
6448             if dot_count > 1 and call_count > 1:
6449                 return False
6450
6451     return True
6452
6453
6454 def can_omit_invisible_parens(
6455     line: Line,
6456     line_length: int,
6457     omit_on_explode: Collection[LeafID] = (),
6458 ) -> bool:
6459     """Does `line` have a shape safe to reformat without optional parens around it?
6460
6461     Returns True for only a subset of potentially nice looking formattings but
6462     the point is to not return false positives that end up producing lines that
6463     are too long.
6464     """
6465     bt = line.bracket_tracker
6466     if not bt.delimiters:
6467         # Without delimiters the optional parentheses are useless.
6468         return True
6469
6470     max_priority = bt.max_delimiter_priority()
6471     if bt.delimiter_count_with_priority(max_priority) > 1:
6472         # With more than one delimiter of a kind the optional parentheses read better.
6473         return False
6474
6475     if max_priority == DOT_PRIORITY:
6476         # A single stranded method call doesn't require optional parentheses.
6477         return True
6478
6479     assert len(line.leaves) >= 2, "Stranded delimiter"
6480
6481     # With a single delimiter, omit if the expression starts or ends with
6482     # a bracket.
6483     first = line.leaves[0]
6484     second = line.leaves[1]
6485     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6486         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6487             return True
6488
6489         # Note: we are not returning False here because a line might have *both*
6490         # a leading opening bracket and a trailing closing bracket.  If the
6491         # opening bracket doesn't match our rule, maybe the closing will.
6492
6493     penultimate = line.leaves[-2]
6494     last = line.leaves[-1]
6495     if line.should_explode:
6496         try:
6497             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6498         except LookupError:
6499             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6500             return False
6501
6502     if (
6503         last.type == token.RPAR
6504         or last.type == token.RBRACE
6505         or (
6506             # don't use indexing for omitting optional parentheses;
6507             # it looks weird
6508             last.type == token.RSQB
6509             and last.parent
6510             and last.parent.type != syms.trailer
6511         )
6512     ):
6513         if penultimate.type in OPENING_BRACKETS:
6514             # Empty brackets don't help.
6515             return False
6516
6517         if is_multiline_string(first):
6518             # Additional wrapping of a multiline string in this situation is
6519             # unnecessary.
6520             return True
6521
6522         if line.should_explode and penultimate.type == token.COMMA:
6523             # The rightmost non-omitted bracket pair is the one we want to explode on.
6524             return True
6525
6526         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6527             return True
6528
6529     return False
6530
6531
6532 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6533     """See `can_omit_invisible_parens`."""
6534     remainder = False
6535     length = 4 * line.depth
6536     _index = -1
6537     for _index, leaf, leaf_length in enumerate_with_length(line):
6538         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6539             remainder = True
6540         if remainder:
6541             length += leaf_length
6542             if length > line_length:
6543                 break
6544
6545             if leaf.type in OPENING_BRACKETS:
6546                 # There are brackets we can further split on.
6547                 remainder = False
6548
6549     else:
6550         # checked the entire string and line length wasn't exceeded
6551         if len(line.leaves) == _index + 1:
6552             return True
6553
6554     return False
6555
6556
6557 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6558     """See `can_omit_invisible_parens`."""
6559     length = 4 * line.depth
6560     seen_other_brackets = False
6561     for _index, leaf, leaf_length in enumerate_with_length(line):
6562         length += leaf_length
6563         if leaf is last.opening_bracket:
6564             if seen_other_brackets or length <= line_length:
6565                 return True
6566
6567         elif leaf.type in OPENING_BRACKETS:
6568             # There are brackets we can further split on.
6569             seen_other_brackets = True
6570
6571     return False
6572
6573
6574 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6575     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6576     stop_after = None
6577     last = None
6578     for leaf in reversed(leaves):
6579         if stop_after:
6580             if leaf is stop_after:
6581                 stop_after = None
6582             continue
6583
6584         if last:
6585             return leaf, last
6586
6587         if id(leaf) in omit:
6588             stop_after = leaf.opening_bracket
6589         else:
6590             last = leaf
6591     else:
6592         raise LookupError("Last two leaves were also skipped")
6593
6594
6595 def run_transformer(
6596     line: Line,
6597     transform: Transformer,
6598     mode: Mode,
6599     features: Collection[Feature],
6600     *,
6601     line_str: str = "",
6602 ) -> List[Line]:
6603     if not line_str:
6604         line_str = line_to_string(line)
6605     result: List[Line] = []
6606     for transformed_line in transform(line, features):
6607         if str(transformed_line).strip("\n") == line_str:
6608             raise CannotTransform("Line transformer returned an unchanged result")
6609
6610         result.extend(transform_line(transformed_line, mode=mode, features=features))
6611
6612     if not (
6613         transform.__name__ == "rhs"
6614         and line.bracket_tracker.invisible
6615         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6616         and not line.contains_multiline_strings()
6617         and not result[0].contains_uncollapsable_type_comments()
6618         and not result[0].contains_unsplittable_type_ignore()
6619         and not is_line_short_enough(result[0], line_length=mode.line_length)
6620     ):
6621         return result
6622
6623     line_copy = line.clone()
6624     append_leaves(line_copy, line, line.leaves)
6625     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6626     second_opinion = run_transformer(
6627         line_copy, transform, mode, features_fop, line_str=line_str
6628     )
6629     if all(
6630         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6631     ):
6632         result = second_opinion
6633     return result
6634
6635
6636 def get_cache_file(mode: Mode) -> Path:
6637     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6638
6639
6640 def read_cache(mode: Mode) -> Cache:
6641     """Read the cache if it exists and is well formed.
6642
6643     If it is not well formed, the call to write_cache later should resolve the issue.
6644     """
6645     cache_file = get_cache_file(mode)
6646     if not cache_file.exists():
6647         return {}
6648
6649     with cache_file.open("rb") as fobj:
6650         try:
6651             cache: Cache = pickle.load(fobj)
6652         except (pickle.UnpicklingError, ValueError):
6653             return {}
6654
6655     return cache
6656
6657
6658 def get_cache_info(path: Path) -> CacheInfo:
6659     """Return the information used to check if a file is already formatted or not."""
6660     stat = path.stat()
6661     return stat.st_mtime, stat.st_size
6662
6663
6664 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6665     """Split an iterable of paths in `sources` into two sets.
6666
6667     The first contains paths of files that modified on disk or are not in the
6668     cache. The other contains paths to non-modified files.
6669     """
6670     todo, done = set(), set()
6671     for src in sources:
6672         src = src.resolve()
6673         if cache.get(src) != get_cache_info(src):
6674             todo.add(src)
6675         else:
6676             done.add(src)
6677     return todo, done
6678
6679
6680 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6681     """Update the cache file."""
6682     cache_file = get_cache_file(mode)
6683     try:
6684         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6685         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6686         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6687             pickle.dump(new_cache, f, protocol=4)
6688         os.replace(f.name, cache_file)
6689     except OSError:
6690         pass
6691
6692
6693 def patch_click() -> None:
6694     """Make Click not crash.
6695
6696     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6697     default which restricts paths that it can access during the lifetime of the
6698     application.  Click refuses to work in this scenario by raising a RuntimeError.
6699
6700     In case of Black the likelihood that non-ASCII characters are going to be used in
6701     file paths is minimal since it's Python source code.  Moreover, this crash was
6702     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6703     """
6704     try:
6705         from click import core
6706         from click import _unicodefun  # type: ignore
6707     except ModuleNotFoundError:
6708         return
6709
6710     for module in (core, _unicodefun):
6711         if hasattr(module, "_verify_python3_env"):
6712             module._verify_python3_env = lambda: None
6713
6714
6715 def patched_main() -> None:
6716     freeze_support()
6717     patch_click()
6718     main()
6719
6720
6721 def is_docstring(leaf: Leaf) -> bool:
6722     if not is_multiline_string(leaf):
6723         # For the purposes of docstring re-indentation, we don't need to do anything
6724         # with single-line docstrings.
6725         return False
6726
6727     if prev_siblings_are(
6728         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6729     ):
6730         return True
6731
6732     # Multiline docstring on the same line as the `def`.
6733     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6734         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6735         # grammar. We're safe to return True without further checks.
6736         return True
6737
6738     return False
6739
6740
6741 def fix_docstring(docstring: str, prefix: str) -> str:
6742     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6743     if not docstring:
6744         return ""
6745     # Convert tabs to spaces (following the normal Python rules)
6746     # and split into a list of lines:
6747     lines = docstring.expandtabs().splitlines()
6748     # Determine minimum indentation (first line doesn't count):
6749     indent = sys.maxsize
6750     for line in lines[1:]:
6751         stripped = line.lstrip()
6752         if stripped:
6753             indent = min(indent, len(line) - len(stripped))
6754     # Remove indentation (first line is special):
6755     trimmed = [lines[0].strip()]
6756     if indent < sys.maxsize:
6757         last_line_idx = len(lines) - 2
6758         for i, line in enumerate(lines[1:]):
6759             stripped_line = line[indent:].rstrip()
6760             if stripped_line or i == last_line_idx:
6761                 trimmed.append(prefix + stripped_line)
6762             else:
6763                 trimmed.append("")
6764     return "\n".join(trimmed)
6765
6766
6767 if __name__ == "__main__":
6768     patched_main()