]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove Lowercase Hex (PR #1692) from CHANGES.md (#2133)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from mypy_extensions import mypyc_attr
46
47 from appdirs import user_cache_dir
48 from dataclasses import dataclass, field, replace
49 import click
50 import toml
51
52 try:
53     from typed_ast import ast3, ast27
54 except ImportError:
55     if sys.version_info < (3, 8):
56         print(
57             "The typed_ast package is not installed.\n"
58             "You can install it with `python3 -m pip install typed-ast`.",
59             file=sys.stderr,
60         )
61         sys.exit(1)
62     else:
63         ast3 = ast27 = ast
64
65 from pathspec import PathSpec
66
67 # lib2to3 fork
68 from blib2to3.pytree import Node, Leaf, type_repr
69 from blib2to3 import pygram, pytree
70 from blib2to3.pgen2 import driver, token
71 from blib2to3.pgen2.grammar import Grammar
72 from blib2to3.pgen2.parse import ParseError
73
74 from _black_version import version as __version__
75
76 if sys.version_info < (3, 8):
77     from typing_extensions import Final
78 else:
79     from typing import Final
80
81 if TYPE_CHECKING:
82     import colorama  # noqa: F401
83
84 DEFAULT_LINE_LENGTH = 88
85 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
86 DEFAULT_INCLUDES = r"\.pyi?$"
87 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
88 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
89
90 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
91
92
93 # types
94 FileContent = str
95 Encoding = str
96 NewLine = str
97 Depth = int
98 NodeType = int
99 ParserState = int
100 LeafID = int
101 StringID = int
102 Priority = int
103 Index = int
104 LN = Union[Leaf, Node]
105 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
106 Timestamp = float
107 FileSize = int
108 CacheInfo = Tuple[Timestamp, FileSize]
109 Cache = Dict[str, CacheInfo]
110 out = partial(click.secho, bold=True, err=True)
111 err = partial(click.secho, fg="red", err=True)
112
113 pygram.initialize(CACHE_DIR)
114 syms = pygram.python_symbols
115
116
117 class NothingChanged(UserWarning):
118     """Raised when reformatted code is the same as source."""
119
120
121 class CannotTransform(Exception):
122     """Base class for errors raised by Transformers."""
123
124
125 class CannotSplit(CannotTransform):
126     """A readable split that fits the allotted line length is impossible."""
127
128
129 class InvalidInput(ValueError):
130     """Raised when input source code fails all parse attempts."""
131
132
133 class BracketMatchError(KeyError):
134     """Raised when an opening bracket is unable to be matched to a closing bracket."""
135
136
137 T = TypeVar("T")
138 E = TypeVar("E", bound=Exception)
139
140
141 class Ok(Generic[T]):
142     def __init__(self, value: T) -> None:
143         self._value = value
144
145     def ok(self) -> T:
146         return self._value
147
148
149 class Err(Generic[E]):
150     def __init__(self, e: E) -> None:
151         self._e = e
152
153     def err(self) -> E:
154         return self._e
155
156
157 # The 'Result' return type is used to implement an error-handling model heavily
158 # influenced by that used by the Rust programming language
159 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
160 Result = Union[Ok[T], Err[E]]
161 TResult = Result[T, CannotTransform]  # (T)ransform Result
162 TMatchResult = TResult[Index]
163
164
165 class WriteBack(Enum):
166     NO = 0
167     YES = 1
168     DIFF = 2
169     CHECK = 3
170     COLOR_DIFF = 4
171
172     @classmethod
173     def from_configuration(
174         cls, *, check: bool, diff: bool, color: bool = False
175     ) -> "WriteBack":
176         if check and not diff:
177             return cls.CHECK
178
179         if diff and color:
180             return cls.COLOR_DIFF
181
182         return cls.DIFF if diff else cls.YES
183
184
185 class Changed(Enum):
186     NO = 0
187     CACHED = 1
188     YES = 2
189
190
191 class TargetVersion(Enum):
192     PY27 = 2
193     PY33 = 3
194     PY34 = 4
195     PY35 = 5
196     PY36 = 6
197     PY37 = 7
198     PY38 = 8
199     PY39 = 9
200
201     def is_python2(self) -> bool:
202         return self is TargetVersion.PY27
203
204
205 class Feature(Enum):
206     # All string literals are unicode
207     UNICODE_LITERALS = 1
208     F_STRINGS = 2
209     NUMERIC_UNDERSCORES = 3
210     TRAILING_COMMA_IN_CALL = 4
211     TRAILING_COMMA_IN_DEF = 5
212     # The following two feature-flags are mutually exclusive, and exactly one should be
213     # set for every version of python.
214     ASYNC_IDENTIFIERS = 6
215     ASYNC_KEYWORDS = 7
216     ASSIGNMENT_EXPRESSIONS = 8
217     POS_ONLY_ARGUMENTS = 9
218     RELAXED_DECORATORS = 10
219     FORCE_OPTIONAL_PARENTHESES = 50
220
221
222 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
223     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
224     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
225     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
226     TargetVersion.PY35: {
227         Feature.UNICODE_LITERALS,
228         Feature.TRAILING_COMMA_IN_CALL,
229         Feature.ASYNC_IDENTIFIERS,
230     },
231     TargetVersion.PY36: {
232         Feature.UNICODE_LITERALS,
233         Feature.F_STRINGS,
234         Feature.NUMERIC_UNDERSCORES,
235         Feature.TRAILING_COMMA_IN_CALL,
236         Feature.TRAILING_COMMA_IN_DEF,
237         Feature.ASYNC_IDENTIFIERS,
238     },
239     TargetVersion.PY37: {
240         Feature.UNICODE_LITERALS,
241         Feature.F_STRINGS,
242         Feature.NUMERIC_UNDERSCORES,
243         Feature.TRAILING_COMMA_IN_CALL,
244         Feature.TRAILING_COMMA_IN_DEF,
245         Feature.ASYNC_KEYWORDS,
246     },
247     TargetVersion.PY38: {
248         Feature.UNICODE_LITERALS,
249         Feature.F_STRINGS,
250         Feature.NUMERIC_UNDERSCORES,
251         Feature.TRAILING_COMMA_IN_CALL,
252         Feature.TRAILING_COMMA_IN_DEF,
253         Feature.ASYNC_KEYWORDS,
254         Feature.ASSIGNMENT_EXPRESSIONS,
255         Feature.POS_ONLY_ARGUMENTS,
256     },
257     TargetVersion.PY39: {
258         Feature.UNICODE_LITERALS,
259         Feature.F_STRINGS,
260         Feature.NUMERIC_UNDERSCORES,
261         Feature.TRAILING_COMMA_IN_CALL,
262         Feature.TRAILING_COMMA_IN_DEF,
263         Feature.ASYNC_KEYWORDS,
264         Feature.ASSIGNMENT_EXPRESSIONS,
265         Feature.RELAXED_DECORATORS,
266         Feature.POS_ONLY_ARGUMENTS,
267     },
268 }
269
270
271 @dataclass
272 class Mode:
273     target_versions: Set[TargetVersion] = field(default_factory=set)
274     line_length: int = DEFAULT_LINE_LENGTH
275     string_normalization: bool = True
276     magic_trailing_comma: bool = True
277     experimental_string_processing: bool = False
278     is_pyi: bool = False
279
280     def get_cache_key(self) -> str:
281         if self.target_versions:
282             version_str = ",".join(
283                 str(version.value)
284                 for version in sorted(self.target_versions, key=lambda v: v.value)
285             )
286         else:
287             version_str = "-"
288         parts = [
289             version_str,
290             str(self.line_length),
291             str(int(self.string_normalization)),
292             str(int(self.is_pyi)),
293         ]
294         return ".".join(parts)
295
296
297 # Legacy name, left for integrations.
298 FileMode = Mode
299
300
301 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
302     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
303
304
305 def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
306     """Find the absolute filepath to a pyproject.toml if it exists"""
307     path_project_root = find_project_root(path_search_start)
308     path_pyproject_toml = path_project_root / "pyproject.toml"
309     if path_pyproject_toml.is_file():
310         return str(path_pyproject_toml)
311
312     path_user_pyproject_toml = find_user_pyproject_toml()
313     return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
314
315
316 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
317     """Parse a pyproject toml file, pulling out relevant parts for Black
318
319     If parsing fails, will raise a toml.TomlDecodeError
320     """
321     pyproject_toml = toml.load(path_config)
322     config = pyproject_toml.get("tool", {}).get("black", {})
323     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
324
325
326 def read_pyproject_toml(
327     ctx: click.Context, param: click.Parameter, value: Optional[str]
328 ) -> Optional[str]:
329     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
330
331     Returns the path to a successfully found and read configuration file, None
332     otherwise.
333     """
334     if not value:
335         value = find_pyproject_toml(ctx.params.get("src", ()))
336         if value is None:
337             return None
338
339     try:
340         config = parse_pyproject_toml(value)
341     except (toml.TomlDecodeError, OSError) as e:
342         raise click.FileError(
343             filename=value, hint=f"Error reading configuration file: {e}"
344         )
345
346     if not config:
347         return None
348     else:
349         # Sanitize the values to be Click friendly. For more information please see:
350         # https://github.com/psf/black/issues/1458
351         # https://github.com/pallets/click/issues/1567
352         config = {
353             k: str(v) if not isinstance(v, (list, dict)) else v
354             for k, v in config.items()
355         }
356
357     target_version = config.get("target_version")
358     if target_version is not None and not isinstance(target_version, list):
359         raise click.BadOptionUsage(
360             "target-version", "Config key target-version must be a list"
361         )
362
363     default_map: Dict[str, Any] = {}
364     if ctx.default_map:
365         default_map.update(ctx.default_map)
366     default_map.update(config)
367
368     ctx.default_map = default_map
369     return value
370
371
372 def target_version_option_callback(
373     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
374 ) -> List[TargetVersion]:
375     """Compute the target versions from a --target-version flag.
376
377     This is its own function because mypy couldn't infer the type correctly
378     when it was a lambda, causing mypyc trouble.
379     """
380     return [TargetVersion[val.upper()] for val in v]
381
382
383 def validate_regex(
384     ctx: click.Context,
385     param: click.Parameter,
386     value: Optional[str],
387 ) -> Optional[Pattern]:
388     try:
389         return re_compile_maybe_verbose(value) if value is not None else None
390     except re.error:
391         raise click.BadParameter("Not a valid regular expression")
392
393
394 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
395 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
396 @click.option(
397     "-l",
398     "--line-length",
399     type=int,
400     default=DEFAULT_LINE_LENGTH,
401     help="How many characters per line to allow.",
402     show_default=True,
403 )
404 @click.option(
405     "-t",
406     "--target-version",
407     type=click.Choice([v.name.lower() for v in TargetVersion]),
408     callback=target_version_option_callback,
409     multiple=True,
410     help=(
411         "Python versions that should be supported by Black's output. [default: per-file"
412         " auto-detection]"
413     ),
414 )
415 @click.option(
416     "--pyi",
417     is_flag=True,
418     help=(
419         "Format all input files like typing stubs regardless of file extension (useful"
420         " when piping source on standard input)."
421     ),
422 )
423 @click.option(
424     "-S",
425     "--skip-string-normalization",
426     is_flag=True,
427     help="Don't normalize string quotes or prefixes.",
428 )
429 @click.option(
430     "-C",
431     "--skip-magic-trailing-comma",
432     is_flag=True,
433     help="Don't use trailing commas as a reason to split lines.",
434 )
435 @click.option(
436     "--experimental-string-processing",
437     is_flag=True,
438     hidden=True,
439     help=(
440         "Experimental option that performs more normalization on string literals."
441         " Currently disabled because it leads to some crashes."
442     ),
443 )
444 @click.option(
445     "--check",
446     is_flag=True,
447     help=(
448         "Don't write the files back, just return the status. Return code 0 means"
449         " nothing would change. Return code 1 means some files would be reformatted."
450         " Return code 123 means there was an internal error."
451     ),
452 )
453 @click.option(
454     "--diff",
455     is_flag=True,
456     help="Don't write the files back, just output a diff for each file on stdout.",
457 )
458 @click.option(
459     "--color/--no-color",
460     is_flag=True,
461     help="Show colored diff. Only applies when `--diff` is given.",
462 )
463 @click.option(
464     "--fast/--safe",
465     is_flag=True,
466     help="If --fast given, skip temporary sanity checks. [default: --safe]",
467 )
468 @click.option(
469     "--include",
470     type=str,
471     default=DEFAULT_INCLUDES,
472     callback=validate_regex,
473     help=(
474         "A regular expression that matches files and directories that should be"
475         " included on recursive searches. An empty value means all files are included"
476         " regardless of the name. Use forward slashes for directories on all platforms"
477         " (Windows, too). Exclusions are calculated first, inclusions later."
478     ),
479     show_default=True,
480 )
481 @click.option(
482     "--exclude",
483     type=str,
484     default=DEFAULT_EXCLUDES,
485     callback=validate_regex,
486     help=(
487         "A regular expression that matches files and directories that should be"
488         " excluded on recursive searches. An empty value means no paths are excluded."
489         " Use forward slashes for directories on all platforms (Windows, too)."
490         " Exclusions are calculated first, inclusions later."
491     ),
492     show_default=True,
493 )
494 @click.option(
495     "--extend-exclude",
496     type=str,
497     callback=validate_regex,
498     help=(
499         "Like --exclude, but adds additional files and directories on top of the"
500         " excluded ones. (Useful if you simply want to add to the default)"
501     ),
502 )
503 @click.option(
504     "--force-exclude",
505     type=str,
506     callback=validate_regex,
507     help=(
508         "Like --exclude, but files and directories matching this regex will be "
509         "excluded even when they are passed explicitly as arguments."
510     ),
511 )
512 @click.option(
513     "--stdin-filename",
514     type=str,
515     help=(
516         "The name of the file when passing it through stdin. Useful to make "
517         "sure Black will respect --force-exclude option on some "
518         "editors that rely on using stdin."
519     ),
520 )
521 @click.option(
522     "-q",
523     "--quiet",
524     is_flag=True,
525     help=(
526         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
527         " those with 2>/dev/null."
528     ),
529 )
530 @click.option(
531     "-v",
532     "--verbose",
533     is_flag=True,
534     help=(
535         "Also emit messages to stderr about files that were not changed or were ignored"
536         " due to exclusion patterns."
537     ),
538 )
539 @click.version_option(version=__version__)
540 @click.argument(
541     "src",
542     nargs=-1,
543     type=click.Path(
544         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
545     ),
546     is_eager=True,
547 )
548 @click.option(
549     "--config",
550     type=click.Path(
551         exists=True,
552         file_okay=True,
553         dir_okay=False,
554         readable=True,
555         allow_dash=False,
556         path_type=str,
557     ),
558     is_eager=True,
559     callback=read_pyproject_toml,
560     help="Read configuration from FILE path.",
561 )
562 @click.pass_context
563 def main(
564     ctx: click.Context,
565     code: Optional[str],
566     line_length: int,
567     target_version: List[TargetVersion],
568     check: bool,
569     diff: bool,
570     color: bool,
571     fast: bool,
572     pyi: bool,
573     skip_string_normalization: bool,
574     skip_magic_trailing_comma: bool,
575     experimental_string_processing: bool,
576     quiet: bool,
577     verbose: bool,
578     include: Pattern,
579     exclude: Pattern,
580     extend_exclude: Optional[Pattern],
581     force_exclude: Optional[Pattern],
582     stdin_filename: Optional[str],
583     src: Tuple[str, ...],
584     config: Optional[str],
585 ) -> None:
586     """The uncompromising code formatter."""
587     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
588     if target_version:
589         versions = set(target_version)
590     else:
591         # We'll autodetect later.
592         versions = set()
593     mode = Mode(
594         target_versions=versions,
595         line_length=line_length,
596         is_pyi=pyi,
597         string_normalization=not skip_string_normalization,
598         magic_trailing_comma=not skip_magic_trailing_comma,
599         experimental_string_processing=experimental_string_processing,
600     )
601     if config and verbose:
602         out(f"Using configuration from {config}.", bold=False, fg="blue")
603     if code is not None:
604         print(format_str(code, mode=mode))
605         ctx.exit(0)
606     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
607     sources = get_sources(
608         ctx=ctx,
609         src=src,
610         quiet=quiet,
611         verbose=verbose,
612         include=include,
613         exclude=exclude,
614         extend_exclude=extend_exclude,
615         force_exclude=force_exclude,
616         report=report,
617         stdin_filename=stdin_filename,
618     )
619
620     path_empty(
621         sources,
622         "No Python files are present to be formatted. Nothing to do 😴",
623         quiet,
624         verbose,
625         ctx,
626     )
627
628     if len(sources) == 1:
629         reformat_one(
630             src=sources.pop(),
631             fast=fast,
632             write_back=write_back,
633             mode=mode,
634             report=report,
635         )
636     else:
637         reformat_many(
638             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
639         )
640
641     if verbose or not quiet:
642         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
643         click.secho(str(report), err=True)
644     ctx.exit(report.return_code)
645
646
647 def get_sources(
648     *,
649     ctx: click.Context,
650     src: Tuple[str, ...],
651     quiet: bool,
652     verbose: bool,
653     include: Pattern[str],
654     exclude: Pattern[str],
655     extend_exclude: Optional[Pattern[str]],
656     force_exclude: Optional[Pattern[str]],
657     report: "Report",
658     stdin_filename: Optional[str],
659 ) -> Set[Path]:
660     """Compute the set of files to be formatted."""
661
662     root = find_project_root(src)
663     sources: Set[Path] = set()
664     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
665     gitignore = get_gitignore(root)
666
667     for s in src:
668         if s == "-" and stdin_filename:
669             p = Path(stdin_filename)
670             is_stdin = True
671         else:
672             p = Path(s)
673             is_stdin = False
674
675         if is_stdin or p.is_file():
676             normalized_path = normalize_path_maybe_ignore(p, root, report)
677             if normalized_path is None:
678                 continue
679
680             normalized_path = "/" + normalized_path
681             # Hard-exclude any files that matches the `--force-exclude` regex.
682             if force_exclude:
683                 force_exclude_match = force_exclude.search(normalized_path)
684             else:
685                 force_exclude_match = None
686             if force_exclude_match and force_exclude_match.group(0):
687                 report.path_ignored(p, "matches the --force-exclude regular expression")
688                 continue
689
690             if is_stdin:
691                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
692
693             sources.add(p)
694         elif p.is_dir():
695             sources.update(
696                 gen_python_files(
697                     p.iterdir(),
698                     root,
699                     include,
700                     exclude,
701                     extend_exclude,
702                     force_exclude,
703                     report,
704                     gitignore,
705                 )
706             )
707         elif s == "-":
708             sources.add(p)
709         else:
710             err(f"invalid path: {s}")
711     return sources
712
713
714 def path_empty(
715     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
716 ) -> None:
717     """
718     Exit if there is no `src` provided for formatting
719     """
720     if not src and (verbose or not quiet):
721         out(msg)
722         ctx.exit(0)
723
724
725 def reformat_one(
726     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
727 ) -> None:
728     """Reformat a single file under `src` without spawning child processes.
729
730     `fast`, `write_back`, and `mode` options are passed to
731     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
732     """
733     try:
734         changed = Changed.NO
735
736         if str(src) == "-":
737             is_stdin = True
738         elif str(src).startswith(STDIN_PLACEHOLDER):
739             is_stdin = True
740             # Use the original name again in case we want to print something
741             # to the user
742             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
743         else:
744             is_stdin = False
745
746         if is_stdin:
747             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
748                 changed = Changed.YES
749         else:
750             cache: Cache = {}
751             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
752                 cache = read_cache(mode)
753                 res_src = src.resolve()
754                 res_src_s = str(res_src)
755                 if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
756                     changed = Changed.CACHED
757             if changed is not Changed.CACHED and format_file_in_place(
758                 src, fast=fast, write_back=write_back, mode=mode
759             ):
760                 changed = Changed.YES
761             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
762                 write_back is WriteBack.CHECK and changed is Changed.NO
763             ):
764                 write_cache(cache, [src], mode)
765         report.done(src, changed)
766     except Exception as exc:
767         if report.verbose:
768             traceback.print_exc()
769         report.failed(src, str(exc))
770
771
772 def reformat_many(
773     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
774 ) -> None:
775     """Reformat multiple files using a ProcessPoolExecutor."""
776     executor: Executor
777     loop = asyncio.get_event_loop()
778     worker_count = os.cpu_count()
779     if sys.platform == "win32":
780         # Work around https://bugs.python.org/issue26903
781         worker_count = min(worker_count, 60)
782     try:
783         executor = ProcessPoolExecutor(max_workers=worker_count)
784     except (ImportError, OSError):
785         # we arrive here if the underlying system does not support multi-processing
786         # like in AWS Lambda or Termux, in which case we gracefully fallback to
787         # a ThreadPoolExecutor with just a single worker (more workers would not do us
788         # any good due to the Global Interpreter Lock)
789         executor = ThreadPoolExecutor(max_workers=1)
790
791     try:
792         loop.run_until_complete(
793             schedule_formatting(
794                 sources=sources,
795                 fast=fast,
796                 write_back=write_back,
797                 mode=mode,
798                 report=report,
799                 loop=loop,
800                 executor=executor,
801             )
802         )
803     finally:
804         shutdown(loop)
805         if executor is not None:
806             executor.shutdown()
807
808
809 async def schedule_formatting(
810     sources: Set[Path],
811     fast: bool,
812     write_back: WriteBack,
813     mode: Mode,
814     report: "Report",
815     loop: asyncio.AbstractEventLoop,
816     executor: Executor,
817 ) -> None:
818     """Run formatting of `sources` in parallel using the provided `executor`.
819
820     (Use ProcessPoolExecutors for actual parallelism.)
821
822     `write_back`, `fast`, and `mode` options are passed to
823     :func:`format_file_in_place`.
824     """
825     cache: Cache = {}
826     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
827         cache = read_cache(mode)
828         sources, cached = filter_cached(cache, sources)
829         for src in sorted(cached):
830             report.done(src, Changed.CACHED)
831     if not sources:
832         return
833
834     cancelled = []
835     sources_to_cache = []
836     lock = None
837     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
838         # For diff output, we need locks to ensure we don't interleave output
839         # from different processes.
840         manager = Manager()
841         lock = manager.Lock()
842     tasks = {
843         asyncio.ensure_future(
844             loop.run_in_executor(
845                 executor, format_file_in_place, src, fast, mode, write_back, lock
846             )
847         ): src
848         for src in sorted(sources)
849     }
850     pending = tasks.keys()
851     try:
852         loop.add_signal_handler(signal.SIGINT, cancel, pending)
853         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
854     except NotImplementedError:
855         # There are no good alternatives for these on Windows.
856         pass
857     while pending:
858         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
859         for task in done:
860             src = tasks.pop(task)
861             if task.cancelled():
862                 cancelled.append(task)
863             elif task.exception():
864                 report.failed(src, str(task.exception()))
865             else:
866                 changed = Changed.YES if task.result() else Changed.NO
867                 # If the file was written back or was successfully checked as
868                 # well-formatted, store this information in the cache.
869                 if write_back is WriteBack.YES or (
870                     write_back is WriteBack.CHECK and changed is Changed.NO
871                 ):
872                     sources_to_cache.append(src)
873                 report.done(src, changed)
874     if cancelled:
875         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
876     if sources_to_cache:
877         write_cache(cache, sources_to_cache, mode)
878
879
880 def format_file_in_place(
881     src: Path,
882     fast: bool,
883     mode: Mode,
884     write_back: WriteBack = WriteBack.NO,
885     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
886 ) -> bool:
887     """Format file under `src` path. Return True if changed.
888
889     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
890     code to the file.
891     `mode` and `fast` options are passed to :func:`format_file_contents`.
892     """
893     if src.suffix == ".pyi":
894         mode = replace(mode, is_pyi=True)
895
896     then = datetime.utcfromtimestamp(src.stat().st_mtime)
897     with open(src, "rb") as buf:
898         src_contents, encoding, newline = decode_bytes(buf.read())
899     try:
900         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
901     except NothingChanged:
902         return False
903
904     if write_back == WriteBack.YES:
905         with open(src, "w", encoding=encoding, newline=newline) as f:
906             f.write(dst_contents)
907     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
908         now = datetime.utcnow()
909         src_name = f"{src}\t{then} +0000"
910         dst_name = f"{src}\t{now} +0000"
911         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
912
913         if write_back == WriteBack.COLOR_DIFF:
914             diff_contents = color_diff(diff_contents)
915
916         with lock or nullcontext():
917             f = io.TextIOWrapper(
918                 sys.stdout.buffer,
919                 encoding=encoding,
920                 newline=newline,
921                 write_through=True,
922             )
923             f = wrap_stream_for_windows(f)
924             f.write(diff_contents)
925             f.detach()
926
927     return True
928
929
930 def color_diff(contents: str) -> str:
931     """Inject the ANSI color codes to the diff."""
932     lines = contents.split("\n")
933     for i, line in enumerate(lines):
934         if line.startswith("+++") or line.startswith("---"):
935             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
936         elif line.startswith("@@"):
937             line = "\033[36m" + line + "\033[0m"  # cyan, reset
938         elif line.startswith("+"):
939             line = "\033[32m" + line + "\033[0m"  # green, reset
940         elif line.startswith("-"):
941             line = "\033[31m" + line + "\033[0m"  # red, reset
942         lines[i] = line
943     return "\n".join(lines)
944
945
946 def wrap_stream_for_windows(
947     f: io.TextIOWrapper,
948 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
949     """
950     Wrap stream with colorama's wrap_stream so colors are shown on Windows.
951
952     If `colorama` is unavailable, the original stream is returned unmodified.
953     Otherwise, the `wrap_stream()` function determines whether the stream needs
954     to be wrapped for a Windows environment and will accordingly either return
955     an `AnsiToWin32` wrapper or the original stream.
956     """
957     try:
958         from colorama.initialise import wrap_stream
959     except ImportError:
960         return f
961     else:
962         # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
963         return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
964
965
966 def format_stdin_to_stdout(
967     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
968 ) -> bool:
969     """Format file on stdin. Return True if changed.
970
971     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
972     write a diff to stdout. The `mode` argument is passed to
973     :func:`format_file_contents`.
974     """
975     then = datetime.utcnow()
976     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
977     dst = src
978     try:
979         dst = format_file_contents(src, fast=fast, mode=mode)
980         return True
981
982     except NothingChanged:
983         return False
984
985     finally:
986         f = io.TextIOWrapper(
987             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
988         )
989         if write_back == WriteBack.YES:
990             f.write(dst)
991         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
992             now = datetime.utcnow()
993             src_name = f"STDIN\t{then} +0000"
994             dst_name = f"STDOUT\t{now} +0000"
995             d = diff(src, dst, src_name, dst_name)
996             if write_back == WriteBack.COLOR_DIFF:
997                 d = color_diff(d)
998                 f = wrap_stream_for_windows(f)
999             f.write(d)
1000         f.detach()
1001
1002
1003 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
1004     """Reformat contents of a file and return new contents.
1005
1006     If `fast` is False, additionally confirm that the reformatted code is
1007     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
1008     `mode` is passed to :func:`format_str`.
1009     """
1010     if not src_contents.strip():
1011         raise NothingChanged
1012
1013     dst_contents = format_str(src_contents, mode=mode)
1014     if src_contents == dst_contents:
1015         raise NothingChanged
1016
1017     if not fast:
1018         assert_equivalent(src_contents, dst_contents)
1019
1020         # Forced second pass to work around optional trailing commas (becoming
1021         # forced trailing commas on pass 2) interacting differently with optional
1022         # parentheses.  Admittedly ugly.
1023         dst_contents_pass2 = format_str(dst_contents, mode=mode)
1024         if dst_contents != dst_contents_pass2:
1025             dst_contents = dst_contents_pass2
1026             assert_equivalent(src_contents, dst_contents, pass_num=2)
1027             assert_stable(src_contents, dst_contents, mode=mode)
1028         # Note: no need to explicitly call `assert_stable` if `dst_contents` was
1029         # the same as `dst_contents_pass2`.
1030     return dst_contents
1031
1032
1033 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
1034     """Reformat a string and return new contents.
1035
1036     `mode` determines formatting options, such as how many characters per line are
1037     allowed.  Example:
1038
1039     >>> import black
1040     >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
1041     def f(arg: str = "") -> None:
1042         ...
1043
1044     A more complex example:
1045
1046     >>> print(
1047     ...   black.format_str(
1048     ...     "def f(arg:str='')->None: hey",
1049     ...     mode=black.Mode(
1050     ...       target_versions={black.TargetVersion.PY36},
1051     ...       line_length=10,
1052     ...       string_normalization=False,
1053     ...       is_pyi=False,
1054     ...     ),
1055     ...   ),
1056     ... )
1057     def f(
1058         arg: str = '',
1059     ) -> None:
1060         hey
1061
1062     """
1063     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
1064     dst_contents = []
1065     future_imports = get_future_imports(src_node)
1066     if mode.target_versions:
1067         versions = mode.target_versions
1068     else:
1069         versions = detect_target_versions(src_node)
1070     normalize_fmt_off(src_node)
1071     lines = LineGenerator(
1072         mode=mode,
1073         remove_u_prefix="unicode_literals" in future_imports
1074         or supports_feature(versions, Feature.UNICODE_LITERALS),
1075     )
1076     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1077     empty_line = Line(mode=mode)
1078     after = 0
1079     split_line_features = {
1080         feature
1081         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1082         if supports_feature(versions, feature)
1083     }
1084     for current_line in lines.visit(src_node):
1085         dst_contents.append(str(empty_line) * after)
1086         before, after = elt.maybe_empty_lines(current_line)
1087         dst_contents.append(str(empty_line) * before)
1088         for line in transform_line(
1089             current_line, mode=mode, features=split_line_features
1090         ):
1091             dst_contents.append(str(line))
1092     return "".join(dst_contents)
1093
1094
1095 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1096     """Return a tuple of (decoded_contents, encoding, newline).
1097
1098     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1099     universal newlines (i.e. only contains LF).
1100     """
1101     srcbuf = io.BytesIO(src)
1102     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1103     if not lines:
1104         return "", encoding, "\n"
1105
1106     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1107     srcbuf.seek(0)
1108     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1109         return tiow.read(), encoding, newline
1110
1111
1112 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1113     if not target_versions:
1114         # No target_version specified, so try all grammars.
1115         return [
1116             # Python 3.7+
1117             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1118             # Python 3.0-3.6
1119             pygram.python_grammar_no_print_statement_no_exec_statement,
1120             # Python 2.7 with future print_function import
1121             pygram.python_grammar_no_print_statement,
1122             # Python 2.7
1123             pygram.python_grammar,
1124         ]
1125
1126     if all(version.is_python2() for version in target_versions):
1127         # Python 2-only code, so try Python 2 grammars.
1128         return [
1129             # Python 2.7 with future print_function import
1130             pygram.python_grammar_no_print_statement,
1131             # Python 2.7
1132             pygram.python_grammar,
1133         ]
1134
1135     # Python 3-compatible code, so only try Python 3 grammar.
1136     grammars = []
1137     # If we have to parse both, try to parse async as a keyword first
1138     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1139         # Python 3.7+
1140         grammars.append(
1141             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1142         )
1143     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1144         # Python 3.0-3.6
1145         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1146     # At least one of the above branches must have been taken, because every Python
1147     # version has exactly one of the two 'ASYNC_*' flags
1148     return grammars
1149
1150
1151 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1152     """Given a string with source, return the lib2to3 Node."""
1153     if not src_txt.endswith("\n"):
1154         src_txt += "\n"
1155
1156     for grammar in get_grammars(set(target_versions)):
1157         drv = driver.Driver(grammar, pytree.convert)
1158         try:
1159             result = drv.parse_string(src_txt, True)
1160             break
1161
1162         except ParseError as pe:
1163             lineno, column = pe.context[1]
1164             lines = src_txt.splitlines()
1165             try:
1166                 faulty_line = lines[lineno - 1]
1167             except IndexError:
1168                 faulty_line = "<line number missing in source>"
1169             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1170     else:
1171         raise exc from None
1172
1173     if isinstance(result, Leaf):
1174         result = Node(syms.file_input, [result])
1175     return result
1176
1177
1178 def lib2to3_unparse(node: Node) -> str:
1179     """Given a lib2to3 node, return its string representation."""
1180     code = str(node)
1181     return code
1182
1183
1184 class Visitor(Generic[T]):
1185     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1186
1187     def visit(self, node: LN) -> Iterator[T]:
1188         """Main method to visit `node` and its children.
1189
1190         It tries to find a `visit_*()` method for the given `node.type`, like
1191         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1192         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1193         instead.
1194
1195         Then yields objects of type `T` from the selected visitor.
1196         """
1197         if node.type < 256:
1198             name = token.tok_name[node.type]
1199         else:
1200             name = str(type_repr(node.type))
1201         # We explicitly branch on whether a visitor exists (instead of
1202         # using self.visit_default as the default arg to getattr) in order
1203         # to save needing to create a bound method object and so mypyc can
1204         # generate a native call to visit_default.
1205         visitf = getattr(self, f"visit_{name}", None)
1206         if visitf:
1207             yield from visitf(node)
1208         else:
1209             yield from self.visit_default(node)
1210
1211     def visit_default(self, node: LN) -> Iterator[T]:
1212         """Default `visit_*()` implementation. Recurses to children of `node`."""
1213         if isinstance(node, Node):
1214             for child in node.children:
1215                 yield from self.visit(child)
1216
1217
1218 @dataclass
1219 class DebugVisitor(Visitor[T]):
1220     tree_depth: int = 0
1221
1222     def visit_default(self, node: LN) -> Iterator[T]:
1223         indent = " " * (2 * self.tree_depth)
1224         if isinstance(node, Node):
1225             _type = type_repr(node.type)
1226             out(f"{indent}{_type}", fg="yellow")
1227             self.tree_depth += 1
1228             for child in node.children:
1229                 yield from self.visit(child)
1230
1231             self.tree_depth -= 1
1232             out(f"{indent}/{_type}", fg="yellow", bold=False)
1233         else:
1234             _type = token.tok_name.get(node.type, str(node.type))
1235             out(f"{indent}{_type}", fg="blue", nl=False)
1236             if node.prefix:
1237                 # We don't have to handle prefixes for `Node` objects since
1238                 # that delegates to the first child anyway.
1239                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1240             out(f" {node.value!r}", fg="blue", bold=False)
1241
1242     @classmethod
1243     def show(cls, code: Union[str, Leaf, Node]) -> None:
1244         """Pretty-print the lib2to3 AST of a given string of `code`.
1245
1246         Convenience method for debugging.
1247         """
1248         v: DebugVisitor[None] = DebugVisitor()
1249         if isinstance(code, str):
1250             code = lib2to3_parse(code)
1251         list(v.visit(code))
1252
1253
1254 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1255 STATEMENT: Final = {
1256     syms.if_stmt,
1257     syms.while_stmt,
1258     syms.for_stmt,
1259     syms.try_stmt,
1260     syms.except_clause,
1261     syms.with_stmt,
1262     syms.funcdef,
1263     syms.classdef,
1264 }
1265 STANDALONE_COMMENT: Final = 153
1266 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1267 LOGIC_OPERATORS: Final = {"and", "or"}
1268 COMPARATORS: Final = {
1269     token.LESS,
1270     token.GREATER,
1271     token.EQEQUAL,
1272     token.NOTEQUAL,
1273     token.LESSEQUAL,
1274     token.GREATEREQUAL,
1275 }
1276 MATH_OPERATORS: Final = {
1277     token.VBAR,
1278     token.CIRCUMFLEX,
1279     token.AMPER,
1280     token.LEFTSHIFT,
1281     token.RIGHTSHIFT,
1282     token.PLUS,
1283     token.MINUS,
1284     token.STAR,
1285     token.SLASH,
1286     token.DOUBLESLASH,
1287     token.PERCENT,
1288     token.AT,
1289     token.TILDE,
1290     token.DOUBLESTAR,
1291 }
1292 STARS: Final = {token.STAR, token.DOUBLESTAR}
1293 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1294 VARARGS_PARENTS: Final = {
1295     syms.arglist,
1296     syms.argument,  # double star in arglist
1297     syms.trailer,  # single argument to call
1298     syms.typedargslist,
1299     syms.varargslist,  # lambdas
1300 }
1301 UNPACKING_PARENTS: Final = {
1302     syms.atom,  # single element of a list or set literal
1303     syms.dictsetmaker,
1304     syms.listmaker,
1305     syms.testlist_gexp,
1306     syms.testlist_star_expr,
1307 }
1308 TEST_DESCENDANTS: Final = {
1309     syms.test,
1310     syms.lambdef,
1311     syms.or_test,
1312     syms.and_test,
1313     syms.not_test,
1314     syms.comparison,
1315     syms.star_expr,
1316     syms.expr,
1317     syms.xor_expr,
1318     syms.and_expr,
1319     syms.shift_expr,
1320     syms.arith_expr,
1321     syms.trailer,
1322     syms.term,
1323     syms.power,
1324 }
1325 ASSIGNMENTS: Final = {
1326     "=",
1327     "+=",
1328     "-=",
1329     "*=",
1330     "@=",
1331     "/=",
1332     "%=",
1333     "&=",
1334     "|=",
1335     "^=",
1336     "<<=",
1337     ">>=",
1338     "**=",
1339     "//=",
1340 }
1341 COMPREHENSION_PRIORITY: Final = 20
1342 COMMA_PRIORITY: Final = 18
1343 TERNARY_PRIORITY: Final = 16
1344 LOGIC_PRIORITY: Final = 14
1345 STRING_PRIORITY: Final = 12
1346 COMPARATOR_PRIORITY: Final = 10
1347 MATH_PRIORITIES: Final = {
1348     token.VBAR: 9,
1349     token.CIRCUMFLEX: 8,
1350     token.AMPER: 7,
1351     token.LEFTSHIFT: 6,
1352     token.RIGHTSHIFT: 6,
1353     token.PLUS: 5,
1354     token.MINUS: 5,
1355     token.STAR: 4,
1356     token.SLASH: 4,
1357     token.DOUBLESLASH: 4,
1358     token.PERCENT: 4,
1359     token.AT: 4,
1360     token.TILDE: 3,
1361     token.DOUBLESTAR: 2,
1362 }
1363 DOT_PRIORITY: Final = 1
1364
1365
1366 @dataclass
1367 class BracketTracker:
1368     """Keeps track of brackets on a line."""
1369
1370     depth: int = 0
1371     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1372     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1373     previous: Optional[Leaf] = None
1374     _for_loop_depths: List[int] = field(default_factory=list)
1375     _lambda_argument_depths: List[int] = field(default_factory=list)
1376     invisible: List[Leaf] = field(default_factory=list)
1377
1378     def mark(self, leaf: Leaf) -> None:
1379         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1380
1381         All leaves receive an int `bracket_depth` field that stores how deep
1382         within brackets a given leaf is. 0 means there are no enclosing brackets
1383         that started on this line.
1384
1385         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1386         field that it forms a pair with. This is a one-directional link to
1387         avoid reference cycles.
1388
1389         If a leaf is a delimiter (a token on which Black can split the line if
1390         needed) and it's on depth 0, its `id()` is stored in the tracker's
1391         `delimiters` field.
1392         """
1393         if leaf.type == token.COMMENT:
1394             return
1395
1396         self.maybe_decrement_after_for_loop_variable(leaf)
1397         self.maybe_decrement_after_lambda_arguments(leaf)
1398         if leaf.type in CLOSING_BRACKETS:
1399             self.depth -= 1
1400             try:
1401                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1402             except KeyError as e:
1403                 raise BracketMatchError(
1404                     "Unable to match a closing bracket to the following opening"
1405                     f" bracket: {leaf}"
1406                 ) from e
1407             leaf.opening_bracket = opening_bracket
1408             if not leaf.value:
1409                 self.invisible.append(leaf)
1410         leaf.bracket_depth = self.depth
1411         if self.depth == 0:
1412             delim = is_split_before_delimiter(leaf, self.previous)
1413             if delim and self.previous is not None:
1414                 self.delimiters[id(self.previous)] = delim
1415             else:
1416                 delim = is_split_after_delimiter(leaf, self.previous)
1417                 if delim:
1418                     self.delimiters[id(leaf)] = delim
1419         if leaf.type in OPENING_BRACKETS:
1420             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1421             self.depth += 1
1422             if not leaf.value:
1423                 self.invisible.append(leaf)
1424         self.previous = leaf
1425         self.maybe_increment_lambda_arguments(leaf)
1426         self.maybe_increment_for_loop_variable(leaf)
1427
1428     def any_open_brackets(self) -> bool:
1429         """Return True if there is an yet unmatched open bracket on the line."""
1430         return bool(self.bracket_match)
1431
1432     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1433         """Return the highest priority of a delimiter found on the line.
1434
1435         Values are consistent with what `is_split_*_delimiter()` return.
1436         Raises ValueError on no delimiters.
1437         """
1438         return max(v for k, v in self.delimiters.items() if k not in exclude)
1439
1440     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1441         """Return the number of delimiters with the given `priority`.
1442
1443         If no `priority` is passed, defaults to max priority on the line.
1444         """
1445         if not self.delimiters:
1446             return 0
1447
1448         priority = priority or self.max_delimiter_priority()
1449         return sum(1 for p in self.delimiters.values() if p == priority)
1450
1451     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1452         """In a for loop, or comprehension, the variables are often unpacks.
1453
1454         To avoid splitting on the comma in this situation, increase the depth of
1455         tokens between `for` and `in`.
1456         """
1457         if leaf.type == token.NAME and leaf.value == "for":
1458             self.depth += 1
1459             self._for_loop_depths.append(self.depth)
1460             return True
1461
1462         return False
1463
1464     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1465         """See `maybe_increment_for_loop_variable` above for explanation."""
1466         if (
1467             self._for_loop_depths
1468             and self._for_loop_depths[-1] == self.depth
1469             and leaf.type == token.NAME
1470             and leaf.value == "in"
1471         ):
1472             self.depth -= 1
1473             self._for_loop_depths.pop()
1474             return True
1475
1476         return False
1477
1478     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1479         """In a lambda expression, there might be more than one argument.
1480
1481         To avoid splitting on the comma in this situation, increase the depth of
1482         tokens between `lambda` and `:`.
1483         """
1484         if leaf.type == token.NAME and leaf.value == "lambda":
1485             self.depth += 1
1486             self._lambda_argument_depths.append(self.depth)
1487             return True
1488
1489         return False
1490
1491     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1492         """See `maybe_increment_lambda_arguments` above for explanation."""
1493         if (
1494             self._lambda_argument_depths
1495             and self._lambda_argument_depths[-1] == self.depth
1496             and leaf.type == token.COLON
1497         ):
1498             self.depth -= 1
1499             self._lambda_argument_depths.pop()
1500             return True
1501
1502         return False
1503
1504     def get_open_lsqb(self) -> Optional[Leaf]:
1505         """Return the most recent opening square bracket (if any)."""
1506         return self.bracket_match.get((self.depth - 1, token.RSQB))
1507
1508
1509 @dataclass
1510 class Line:
1511     """Holds leaves and comments. Can be printed with `str(line)`."""
1512
1513     mode: Mode
1514     depth: int = 0
1515     leaves: List[Leaf] = field(default_factory=list)
1516     # keys ordered like `leaves`
1517     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1518     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1519     inside_brackets: bool = False
1520     should_split_rhs: bool = False
1521     magic_trailing_comma: Optional[Leaf] = None
1522
1523     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1524         """Add a new `leaf` to the end of the line.
1525
1526         Unless `preformatted` is True, the `leaf` will receive a new consistent
1527         whitespace prefix and metadata applied by :class:`BracketTracker`.
1528         Trailing commas are maybe removed, unpacked for loop variables are
1529         demoted from being delimiters.
1530
1531         Inline comments are put aside.
1532         """
1533         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1534         if not has_value:
1535             return
1536
1537         if token.COLON == leaf.type and self.is_class_paren_empty:
1538             del self.leaves[-2:]
1539         if self.leaves and not preformatted:
1540             # Note: at this point leaf.prefix should be empty except for
1541             # imports, for which we only preserve newlines.
1542             leaf.prefix += whitespace(
1543                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1544             )
1545         if self.inside_brackets or not preformatted:
1546             self.bracket_tracker.mark(leaf)
1547             if self.mode.magic_trailing_comma:
1548                 if self.has_magic_trailing_comma(leaf):
1549                     self.magic_trailing_comma = leaf
1550             elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
1551                 self.remove_trailing_comma()
1552         if not self.append_comment(leaf):
1553             self.leaves.append(leaf)
1554
1555     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1556         """Like :func:`append()` but disallow invalid standalone comment structure.
1557
1558         Raises ValueError when any `leaf` is appended after a standalone comment
1559         or when a standalone comment is not the first leaf on the line.
1560         """
1561         if self.bracket_tracker.depth == 0:
1562             if self.is_comment:
1563                 raise ValueError("cannot append to standalone comments")
1564
1565             if self.leaves and leaf.type == STANDALONE_COMMENT:
1566                 raise ValueError(
1567                     "cannot append standalone comments to a populated line"
1568                 )
1569
1570         self.append(leaf, preformatted=preformatted)
1571
1572     @property
1573     def is_comment(self) -> bool:
1574         """Is this line a standalone comment?"""
1575         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1576
1577     @property
1578     def is_decorator(self) -> bool:
1579         """Is this line a decorator?"""
1580         return bool(self) and self.leaves[0].type == token.AT
1581
1582     @property
1583     def is_import(self) -> bool:
1584         """Is this an import line?"""
1585         return bool(self) and is_import(self.leaves[0])
1586
1587     @property
1588     def is_class(self) -> bool:
1589         """Is this line a class definition?"""
1590         return (
1591             bool(self)
1592             and self.leaves[0].type == token.NAME
1593             and self.leaves[0].value == "class"
1594         )
1595
1596     @property
1597     def is_stub_class(self) -> bool:
1598         """Is this line a class definition with a body consisting only of "..."?"""
1599         return self.is_class and self.leaves[-3:] == [
1600             Leaf(token.DOT, ".") for _ in range(3)
1601         ]
1602
1603     @property
1604     def is_def(self) -> bool:
1605         """Is this a function definition? (Also returns True for async defs.)"""
1606         try:
1607             first_leaf = self.leaves[0]
1608         except IndexError:
1609             return False
1610
1611         try:
1612             second_leaf: Optional[Leaf] = self.leaves[1]
1613         except IndexError:
1614             second_leaf = None
1615         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1616             first_leaf.type == token.ASYNC
1617             and second_leaf is not None
1618             and second_leaf.type == token.NAME
1619             and second_leaf.value == "def"
1620         )
1621
1622     @property
1623     def is_class_paren_empty(self) -> bool:
1624         """Is this a class with no base classes but using parentheses?
1625
1626         Those are unnecessary and should be removed.
1627         """
1628         return (
1629             bool(self)
1630             and len(self.leaves) == 4
1631             and self.is_class
1632             and self.leaves[2].type == token.LPAR
1633             and self.leaves[2].value == "("
1634             and self.leaves[3].type == token.RPAR
1635             and self.leaves[3].value == ")"
1636         )
1637
1638     @property
1639     def is_triple_quoted_string(self) -> bool:
1640         """Is the line a triple quoted string?"""
1641         return (
1642             bool(self)
1643             and self.leaves[0].type == token.STRING
1644             and self.leaves[0].value.startswith(('"""', "'''"))
1645         )
1646
1647     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1648         """If so, needs to be split before emitting."""
1649         for leaf in self.leaves:
1650             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1651                 return True
1652
1653         return False
1654
1655     def contains_uncollapsable_type_comments(self) -> bool:
1656         ignored_ids = set()
1657         try:
1658             last_leaf = self.leaves[-1]
1659             ignored_ids.add(id(last_leaf))
1660             if last_leaf.type == token.COMMA or (
1661                 last_leaf.type == token.RPAR and not last_leaf.value
1662             ):
1663                 # When trailing commas or optional parens are inserted by Black for
1664                 # consistency, comments after the previous last element are not moved
1665                 # (they don't have to, rendering will still be correct).  So we ignore
1666                 # trailing commas and invisible.
1667                 last_leaf = self.leaves[-2]
1668                 ignored_ids.add(id(last_leaf))
1669         except IndexError:
1670             return False
1671
1672         # A type comment is uncollapsable if it is attached to a leaf
1673         # that isn't at the end of the line (since that could cause it
1674         # to get associated to a different argument) or if there are
1675         # comments before it (since that could cause it to get hidden
1676         # behind a comment.
1677         comment_seen = False
1678         for leaf_id, comments in self.comments.items():
1679             for comment in comments:
1680                 if is_type_comment(comment):
1681                     if comment_seen or (
1682                         not is_type_comment(comment, " ignore")
1683                         and leaf_id not in ignored_ids
1684                     ):
1685                         return True
1686
1687                 comment_seen = True
1688
1689         return False
1690
1691     def contains_unsplittable_type_ignore(self) -> bool:
1692         if not self.leaves:
1693             return False
1694
1695         # If a 'type: ignore' is attached to the end of a line, we
1696         # can't split the line, because we can't know which of the
1697         # subexpressions the ignore was meant to apply to.
1698         #
1699         # We only want this to apply to actual physical lines from the
1700         # original source, though: we don't want the presence of a
1701         # 'type: ignore' at the end of a multiline expression to
1702         # justify pushing it all onto one line. Thus we
1703         # (unfortunately) need to check the actual source lines and
1704         # only report an unsplittable 'type: ignore' if this line was
1705         # one line in the original code.
1706
1707         # Grab the first and last line numbers, skipping generated leaves
1708         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1709         last_line = next(
1710             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1711         )
1712
1713         if first_line == last_line:
1714             # We look at the last two leaves since a comma or an
1715             # invisible paren could have been added at the end of the
1716             # line.
1717             for node in self.leaves[-2:]:
1718                 for comment in self.comments.get(id(node), []):
1719                     if is_type_comment(comment, " ignore"):
1720                         return True
1721
1722         return False
1723
1724     def contains_multiline_strings(self) -> bool:
1725         return any(is_multiline_string(leaf) for leaf in self.leaves)
1726
1727     def has_magic_trailing_comma(
1728         self, closing: Leaf, ensure_removable: bool = False
1729     ) -> bool:
1730         """Return True if we have a magic trailing comma, that is when:
1731         - there's a trailing comma here
1732         - it's not a one-tuple
1733         Additionally, if ensure_removable:
1734         - it's not from square bracket indexing
1735         """
1736         if not (
1737             closing.type in CLOSING_BRACKETS
1738             and self.leaves
1739             and self.leaves[-1].type == token.COMMA
1740         ):
1741             return False
1742
1743         if closing.type == token.RBRACE:
1744             return True
1745
1746         if closing.type == token.RSQB:
1747             if not ensure_removable:
1748                 return True
1749             comma = self.leaves[-1]
1750             return bool(comma.parent and comma.parent.type == syms.listmaker)
1751
1752         if self.is_import:
1753             return True
1754
1755         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1756             return True
1757
1758         return False
1759
1760     def append_comment(self, comment: Leaf) -> bool:
1761         """Add an inline or standalone comment to the line."""
1762         if (
1763             comment.type == STANDALONE_COMMENT
1764             and self.bracket_tracker.any_open_brackets()
1765         ):
1766             comment.prefix = ""
1767             return False
1768
1769         if comment.type != token.COMMENT:
1770             return False
1771
1772         if not self.leaves:
1773             comment.type = STANDALONE_COMMENT
1774             comment.prefix = ""
1775             return False
1776
1777         last_leaf = self.leaves[-1]
1778         if (
1779             last_leaf.type == token.RPAR
1780             and not last_leaf.value
1781             and last_leaf.parent
1782             and len(list(last_leaf.parent.leaves())) <= 3
1783             and not is_type_comment(comment)
1784         ):
1785             # Comments on an optional parens wrapping a single leaf should belong to
1786             # the wrapped node except if it's a type comment. Pinning the comment like
1787             # this avoids unstable formatting caused by comment migration.
1788             if len(self.leaves) < 2:
1789                 comment.type = STANDALONE_COMMENT
1790                 comment.prefix = ""
1791                 return False
1792
1793             last_leaf = self.leaves[-2]
1794         self.comments.setdefault(id(last_leaf), []).append(comment)
1795         return True
1796
1797     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1798         """Generate comments that should appear directly after `leaf`."""
1799         return self.comments.get(id(leaf), [])
1800
1801     def remove_trailing_comma(self) -> None:
1802         """Remove the trailing comma and moves the comments attached to it."""
1803         trailing_comma = self.leaves.pop()
1804         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1805         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1806             trailing_comma_comments
1807         )
1808
1809     def is_complex_subscript(self, leaf: Leaf) -> bool:
1810         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1811         open_lsqb = self.bracket_tracker.get_open_lsqb()
1812         if open_lsqb is None:
1813             return False
1814
1815         subscript_start = open_lsqb.next_sibling
1816
1817         if isinstance(subscript_start, Node):
1818             if subscript_start.type == syms.listmaker:
1819                 return False
1820
1821             if subscript_start.type == syms.subscriptlist:
1822                 subscript_start = child_towards(subscript_start, leaf)
1823         return subscript_start is not None and any(
1824             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1825         )
1826
1827     def clone(self) -> "Line":
1828         return Line(
1829             mode=self.mode,
1830             depth=self.depth,
1831             inside_brackets=self.inside_brackets,
1832             should_split_rhs=self.should_split_rhs,
1833             magic_trailing_comma=self.magic_trailing_comma,
1834         )
1835
1836     def __str__(self) -> str:
1837         """Render the line."""
1838         if not self:
1839             return "\n"
1840
1841         indent = "    " * self.depth
1842         leaves = iter(self.leaves)
1843         first = next(leaves)
1844         res = f"{first.prefix}{indent}{first.value}"
1845         for leaf in leaves:
1846             res += str(leaf)
1847         for comment in itertools.chain.from_iterable(self.comments.values()):
1848             res += str(comment)
1849
1850         return res + "\n"
1851
1852     def __bool__(self) -> bool:
1853         """Return True if the line has leaves or comments."""
1854         return bool(self.leaves or self.comments)
1855
1856
1857 @dataclass
1858 class EmptyLineTracker:
1859     """Provides a stateful method that returns the number of potential extra
1860     empty lines needed before and after the currently processed line.
1861
1862     Note: this tracker works on lines that haven't been split yet.  It assumes
1863     the prefix of the first leaf consists of optional newlines.  Those newlines
1864     are consumed by `maybe_empty_lines()` and included in the computation.
1865     """
1866
1867     is_pyi: bool = False
1868     previous_line: Optional[Line] = None
1869     previous_after: int = 0
1870     previous_defs: List[int] = field(default_factory=list)
1871
1872     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1873         """Return the number of extra empty lines before and after the `current_line`.
1874
1875         This is for separating `def`, `async def` and `class` with extra empty
1876         lines (two on module-level).
1877         """
1878         before, after = self._maybe_empty_lines(current_line)
1879         before = (
1880             # Black should not insert empty lines at the beginning
1881             # of the file
1882             0
1883             if self.previous_line is None
1884             else before - self.previous_after
1885         )
1886         self.previous_after = after
1887         self.previous_line = current_line
1888         return before, after
1889
1890     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1891         max_allowed = 1
1892         if current_line.depth == 0:
1893             max_allowed = 1 if self.is_pyi else 2
1894         if current_line.leaves:
1895             # Consume the first leaf's extra newlines.
1896             first_leaf = current_line.leaves[0]
1897             before = first_leaf.prefix.count("\n")
1898             before = min(before, max_allowed)
1899             first_leaf.prefix = ""
1900         else:
1901             before = 0
1902         depth = current_line.depth
1903         while self.previous_defs and self.previous_defs[-1] >= depth:
1904             self.previous_defs.pop()
1905             if self.is_pyi:
1906                 before = 0 if depth else 1
1907             else:
1908                 before = 1 if depth else 2
1909         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1910             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1911
1912         if (
1913             self.previous_line
1914             and self.previous_line.is_import
1915             and not current_line.is_import
1916             and depth == self.previous_line.depth
1917         ):
1918             return (before or 1), 0
1919
1920         if (
1921             self.previous_line
1922             and self.previous_line.is_class
1923             and current_line.is_triple_quoted_string
1924         ):
1925             return before, 1
1926
1927         return before, 0
1928
1929     def _maybe_empty_lines_for_class_or_def(
1930         self, current_line: Line, before: int
1931     ) -> Tuple[int, int]:
1932         if not current_line.is_decorator:
1933             self.previous_defs.append(current_line.depth)
1934         if self.previous_line is None:
1935             # Don't insert empty lines before the first line in the file.
1936             return 0, 0
1937
1938         if self.previous_line.is_decorator:
1939             if self.is_pyi and current_line.is_stub_class:
1940                 # Insert an empty line after a decorated stub class
1941                 return 0, 1
1942
1943             return 0, 0
1944
1945         if self.previous_line.depth < current_line.depth and (
1946             self.previous_line.is_class or self.previous_line.is_def
1947         ):
1948             return 0, 0
1949
1950         if (
1951             self.previous_line.is_comment
1952             and self.previous_line.depth == current_line.depth
1953             and before == 0
1954         ):
1955             return 0, 0
1956
1957         if self.is_pyi:
1958             if self.previous_line.depth > current_line.depth:
1959                 newlines = 1
1960             elif current_line.is_class or self.previous_line.is_class:
1961                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1962                     # No blank line between classes with an empty body
1963                     newlines = 0
1964                 else:
1965                     newlines = 1
1966             elif (
1967                 current_line.is_def or current_line.is_decorator
1968             ) and not self.previous_line.is_def:
1969                 # Blank line between a block of functions (maybe with preceding
1970                 # decorators) and a block of non-functions
1971                 newlines = 1
1972             else:
1973                 newlines = 0
1974         else:
1975             newlines = 2
1976         if current_line.depth and newlines:
1977             newlines -= 1
1978         return newlines, 0
1979
1980
1981 @dataclass
1982 class LineGenerator(Visitor[Line]):
1983     """Generates reformatted Line objects.  Empty lines are not emitted.
1984
1985     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1986     in ways that will no longer stringify to valid Python code on the tree.
1987     """
1988
1989     mode: Mode
1990     remove_u_prefix: bool = False
1991     current_line: Line = field(init=False)
1992
1993     def line(self, indent: int = 0) -> Iterator[Line]:
1994         """Generate a line.
1995
1996         If the line is empty, only emit if it makes sense.
1997         If the line is too long, split it first and then generate.
1998
1999         If any lines were generated, set up a new current_line.
2000         """
2001         if not self.current_line:
2002             self.current_line.depth += indent
2003             return  # Line is empty, don't emit. Creating a new one unnecessary.
2004
2005         complete_line = self.current_line
2006         self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
2007         yield complete_line
2008
2009     def visit_default(self, node: LN) -> Iterator[Line]:
2010         """Default `visit_*()` implementation. Recurses to children of `node`."""
2011         if isinstance(node, Leaf):
2012             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
2013             for comment in generate_comments(node):
2014                 if any_open_brackets:
2015                     # any comment within brackets is subject to splitting
2016                     self.current_line.append(comment)
2017                 elif comment.type == token.COMMENT:
2018                     # regular trailing comment
2019                     self.current_line.append(comment)
2020                     yield from self.line()
2021
2022                 else:
2023                     # regular standalone comment
2024                     yield from self.line()
2025
2026                     self.current_line.append(comment)
2027                     yield from self.line()
2028
2029             normalize_prefix(node, inside_brackets=any_open_brackets)
2030             if self.mode.string_normalization and node.type == token.STRING:
2031                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
2032                 normalize_string_quotes(node)
2033             if node.type == token.NUMBER:
2034                 normalize_numeric_literal(node)
2035             if node.type not in WHITESPACE:
2036                 self.current_line.append(node)
2037         yield from super().visit_default(node)
2038
2039     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
2040         """Increase indentation level, maybe yield a line."""
2041         # In blib2to3 INDENT never holds comments.
2042         yield from self.line(+1)
2043         yield from self.visit_default(node)
2044
2045     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
2046         """Decrease indentation level, maybe yield a line."""
2047         # The current line might still wait for trailing comments.  At DEDENT time
2048         # there won't be any (they would be prefixes on the preceding NEWLINE).
2049         # Emit the line then.
2050         yield from self.line()
2051
2052         # While DEDENT has no value, its prefix may contain standalone comments
2053         # that belong to the current indentation level.  Get 'em.
2054         yield from self.visit_default(node)
2055
2056         # Finally, emit the dedent.
2057         yield from self.line(-1)
2058
2059     def visit_stmt(
2060         self, node: Node, keywords: Set[str], parens: Set[str]
2061     ) -> Iterator[Line]:
2062         """Visit a statement.
2063
2064         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2065         `def`, `with`, `class`, `assert` and assignments.
2066
2067         The relevant Python language `keywords` for a given statement will be
2068         NAME leaves within it. This methods puts those on a separate line.
2069
2070         `parens` holds a set of string leaf values immediately after which
2071         invisible parens should be put.
2072         """
2073         normalize_invisible_parens(node, parens_after=parens)
2074         for child in node.children:
2075             if child.type == token.NAME and child.value in keywords:  # type: ignore
2076                 yield from self.line()
2077
2078             yield from self.visit(child)
2079
2080     def visit_suite(self, node: Node) -> Iterator[Line]:
2081         """Visit a suite."""
2082         if self.mode.is_pyi and is_stub_suite(node):
2083             yield from self.visit(node.children[2])
2084         else:
2085             yield from self.visit_default(node)
2086
2087     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2088         """Visit a statement without nested statements."""
2089         if first_child_is_arith(node):
2090             wrap_in_parentheses(node, node.children[0], visible=False)
2091         is_suite_like = node.parent and node.parent.type in STATEMENT
2092         if is_suite_like:
2093             if self.mode.is_pyi and is_stub_body(node):
2094                 yield from self.visit_default(node)
2095             else:
2096                 yield from self.line(+1)
2097                 yield from self.visit_default(node)
2098                 yield from self.line(-1)
2099
2100         else:
2101             if (
2102                 not self.mode.is_pyi
2103                 or not node.parent
2104                 or not is_stub_suite(node.parent)
2105             ):
2106                 yield from self.line()
2107             yield from self.visit_default(node)
2108
2109     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2110         """Visit `async def`, `async for`, `async with`."""
2111         yield from self.line()
2112
2113         children = iter(node.children)
2114         for child in children:
2115             yield from self.visit(child)
2116
2117             if child.type == token.ASYNC:
2118                 break
2119
2120         internal_stmt = next(children)
2121         for child in internal_stmt.children:
2122             yield from self.visit(child)
2123
2124     def visit_decorators(self, node: Node) -> Iterator[Line]:
2125         """Visit decorators."""
2126         for child in node.children:
2127             yield from self.line()
2128             yield from self.visit(child)
2129
2130     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2131         """Remove a semicolon and put the other statement on a separate line."""
2132         yield from self.line()
2133
2134     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2135         """End of file. Process outstanding comments and end with a newline."""
2136         yield from self.visit_default(leaf)
2137         yield from self.line()
2138
2139     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2140         if not self.current_line.bracket_tracker.any_open_brackets():
2141             yield from self.line()
2142         yield from self.visit_default(leaf)
2143
2144     def visit_factor(self, node: Node) -> Iterator[Line]:
2145         """Force parentheses between a unary op and a binary power:
2146
2147         -2 ** 8 -> -(2 ** 8)
2148         """
2149         _operator, operand = node.children
2150         if (
2151             operand.type == syms.power
2152             and len(operand.children) == 3
2153             and operand.children[1].type == token.DOUBLESTAR
2154         ):
2155             lpar = Leaf(token.LPAR, "(")
2156             rpar = Leaf(token.RPAR, ")")
2157             index = operand.remove() or 0
2158             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2159         yield from self.visit_default(node)
2160
2161     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2162         if is_docstring(leaf) and "\\\n" not in leaf.value:
2163             # We're ignoring docstrings with backslash newline escapes because changing
2164             # indentation of those changes the AST representation of the code.
2165             prefix = get_string_prefix(leaf.value)
2166             docstring = leaf.value[len(prefix) :]  # Remove the prefix
2167             quote_char = docstring[0]
2168             # A natural way to remove the outer quotes is to do:
2169             #   docstring = docstring.strip(quote_char)
2170             # but that breaks on """""x""" (which is '""x').
2171             # So we actually need to remove the first character and the next two
2172             # characters but only if they are the same as the first.
2173             quote_len = 1 if docstring[1] != quote_char else 3
2174             docstring = docstring[quote_len:-quote_len]
2175
2176             if is_multiline_string(leaf):
2177                 indent = " " * 4 * self.current_line.depth
2178                 docstring = fix_docstring(docstring, indent)
2179             else:
2180                 docstring = docstring.strip()
2181
2182             if docstring:
2183                 # Add some padding if the docstring starts / ends with a quote mark.
2184                 if docstring[0] == quote_char:
2185                     docstring = " " + docstring
2186                 if docstring[-1] == quote_char:
2187                     docstring = docstring + " "
2188             else:
2189                 # Add some padding if the docstring is empty.
2190                 docstring = " "
2191
2192             # We could enforce triple quotes at this point.
2193             quote = quote_char * quote_len
2194             leaf.value = prefix + quote + docstring + quote
2195
2196         yield from self.visit_default(leaf)
2197
2198     def __post_init__(self) -> None:
2199         """You are in a twisty little maze of passages."""
2200         self.current_line = Line(mode=self.mode)
2201
2202         v = self.visit_stmt
2203         Ø: Set[str] = set()
2204         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2205         self.visit_if_stmt = partial(
2206             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2207         )
2208         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2209         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2210         self.visit_try_stmt = partial(
2211             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2212         )
2213         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2214         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2215         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2216         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2217         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2218         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2219         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2220         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2221         self.visit_async_funcdef = self.visit_async_stmt
2222         self.visit_decorated = self.visit_decorators
2223
2224
2225 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2226 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2227 OPENING_BRACKETS = set(BRACKET.keys())
2228 CLOSING_BRACKETS = set(BRACKET.values())
2229 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2230 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2231
2232
2233 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2234     """Return whitespace prefix if needed for the given `leaf`.
2235
2236     `complex_subscript` signals whether the given leaf is part of a subscription
2237     which has non-trivial arguments, like arithmetic expressions or function calls.
2238     """
2239     NO = ""
2240     SPACE = " "
2241     DOUBLESPACE = "  "
2242     t = leaf.type
2243     p = leaf.parent
2244     v = leaf.value
2245     if t in ALWAYS_NO_SPACE:
2246         return NO
2247
2248     if t == token.COMMENT:
2249         return DOUBLESPACE
2250
2251     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2252     if t == token.COLON and p.type not in {
2253         syms.subscript,
2254         syms.subscriptlist,
2255         syms.sliceop,
2256     }:
2257         return NO
2258
2259     prev = leaf.prev_sibling
2260     if not prev:
2261         prevp = preceding_leaf(p)
2262         if not prevp or prevp.type in OPENING_BRACKETS:
2263             return NO
2264
2265         if t == token.COLON:
2266             if prevp.type == token.COLON:
2267                 return NO
2268
2269             elif prevp.type != token.COMMA and not complex_subscript:
2270                 return NO
2271
2272             return SPACE
2273
2274         if prevp.type == token.EQUAL:
2275             if prevp.parent:
2276                 if prevp.parent.type in {
2277                     syms.arglist,
2278                     syms.argument,
2279                     syms.parameters,
2280                     syms.varargslist,
2281                 }:
2282                     return NO
2283
2284                 elif prevp.parent.type == syms.typedargslist:
2285                     # A bit hacky: if the equal sign has whitespace, it means we
2286                     # previously found it's a typed argument.  So, we're using
2287                     # that, too.
2288                     return prevp.prefix
2289
2290         elif prevp.type in VARARGS_SPECIALS:
2291             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2292                 return NO
2293
2294         elif prevp.type == token.COLON:
2295             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2296                 return SPACE if complex_subscript else NO
2297
2298         elif (
2299             prevp.parent
2300             and prevp.parent.type == syms.factor
2301             and prevp.type in MATH_OPERATORS
2302         ):
2303             return NO
2304
2305         elif (
2306             prevp.type == token.RIGHTSHIFT
2307             and prevp.parent
2308             and prevp.parent.type == syms.shift_expr
2309             and prevp.prev_sibling
2310             and prevp.prev_sibling.type == token.NAME
2311             and prevp.prev_sibling.value == "print"  # type: ignore
2312         ):
2313             # Python 2 print chevron
2314             return NO
2315         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2316             # no space in decorators
2317             return NO
2318
2319     elif prev.type in OPENING_BRACKETS:
2320         return NO
2321
2322     if p.type in {syms.parameters, syms.arglist}:
2323         # untyped function signatures or calls
2324         if not prev or prev.type != token.COMMA:
2325             return NO
2326
2327     elif p.type == syms.varargslist:
2328         # lambdas
2329         if prev and prev.type != token.COMMA:
2330             return NO
2331
2332     elif p.type == syms.typedargslist:
2333         # typed function signatures
2334         if not prev:
2335             return NO
2336
2337         if t == token.EQUAL:
2338             if prev.type != syms.tname:
2339                 return NO
2340
2341         elif prev.type == token.EQUAL:
2342             # A bit hacky: if the equal sign has whitespace, it means we
2343             # previously found it's a typed argument.  So, we're using that, too.
2344             return prev.prefix
2345
2346         elif prev.type != token.COMMA:
2347             return NO
2348
2349     elif p.type == syms.tname:
2350         # type names
2351         if not prev:
2352             prevp = preceding_leaf(p)
2353             if not prevp or prevp.type != token.COMMA:
2354                 return NO
2355
2356     elif p.type == syms.trailer:
2357         # attributes and calls
2358         if t == token.LPAR or t == token.RPAR:
2359             return NO
2360
2361         if not prev:
2362             if t == token.DOT:
2363                 prevp = preceding_leaf(p)
2364                 if not prevp or prevp.type != token.NUMBER:
2365                     return NO
2366
2367             elif t == token.LSQB:
2368                 return NO
2369
2370         elif prev.type != token.COMMA:
2371             return NO
2372
2373     elif p.type == syms.argument:
2374         # single argument
2375         if t == token.EQUAL:
2376             return NO
2377
2378         if not prev:
2379             prevp = preceding_leaf(p)
2380             if not prevp or prevp.type == token.LPAR:
2381                 return NO
2382
2383         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2384             return NO
2385
2386     elif p.type == syms.decorator:
2387         # decorators
2388         return NO
2389
2390     elif p.type == syms.dotted_name:
2391         if prev:
2392             return NO
2393
2394         prevp = preceding_leaf(p)
2395         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2396             return NO
2397
2398     elif p.type == syms.classdef:
2399         if t == token.LPAR:
2400             return NO
2401
2402         if prev and prev.type == token.LPAR:
2403             return NO
2404
2405     elif p.type in {syms.subscript, syms.sliceop}:
2406         # indexing
2407         if not prev:
2408             assert p.parent is not None, "subscripts are always parented"
2409             if p.parent.type == syms.subscriptlist:
2410                 return SPACE
2411
2412             return NO
2413
2414         elif not complex_subscript:
2415             return NO
2416
2417     elif p.type == syms.atom:
2418         if prev and t == token.DOT:
2419             # dots, but not the first one.
2420             return NO
2421
2422     elif p.type == syms.dictsetmaker:
2423         # dict unpacking
2424         if prev and prev.type == token.DOUBLESTAR:
2425             return NO
2426
2427     elif p.type in {syms.factor, syms.star_expr}:
2428         # unary ops
2429         if not prev:
2430             prevp = preceding_leaf(p)
2431             if not prevp or prevp.type in OPENING_BRACKETS:
2432                 return NO
2433
2434             prevp_parent = prevp.parent
2435             assert prevp_parent is not None
2436             if prevp.type == token.COLON and prevp_parent.type in {
2437                 syms.subscript,
2438                 syms.sliceop,
2439             }:
2440                 return NO
2441
2442             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2443                 return NO
2444
2445         elif t in {token.NAME, token.NUMBER, token.STRING}:
2446             return NO
2447
2448     elif p.type == syms.import_from:
2449         if t == token.DOT:
2450             if prev and prev.type == token.DOT:
2451                 return NO
2452
2453         elif t == token.NAME:
2454             if v == "import":
2455                 return SPACE
2456
2457             if prev and prev.type == token.DOT:
2458                 return NO
2459
2460     elif p.type == syms.sliceop:
2461         return NO
2462
2463     return SPACE
2464
2465
2466 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2467     """Return the first leaf that precedes `node`, if any."""
2468     while node:
2469         res = node.prev_sibling
2470         if res:
2471             if isinstance(res, Leaf):
2472                 return res
2473
2474             try:
2475                 return list(res.leaves())[-1]
2476
2477             except IndexError:
2478                 return None
2479
2480         node = node.parent
2481     return None
2482
2483
2484 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2485     """Return if the `node` and its previous siblings match types against the provided
2486     list of tokens; the provided `node`has its type matched against the last element in
2487     the list.  `None` can be used as the first element to declare that the start of the
2488     list is anchored at the start of its parent's children."""
2489     if not tokens:
2490         return True
2491     if tokens[-1] is None:
2492         return node is None
2493     if not node:
2494         return False
2495     if node.type != tokens[-1]:
2496         return False
2497     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2498
2499
2500 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2501     """Return the child of `ancestor` that contains `descendant`."""
2502     node: Optional[LN] = descendant
2503     while node and node.parent != ancestor:
2504         node = node.parent
2505     return node
2506
2507
2508 def container_of(leaf: Leaf) -> LN:
2509     """Return `leaf` or one of its ancestors that is the topmost container of it.
2510
2511     By "container" we mean a node where `leaf` is the very first child.
2512     """
2513     same_prefix = leaf.prefix
2514     container: LN = leaf
2515     while container:
2516         parent = container.parent
2517         if parent is None:
2518             break
2519
2520         if parent.children[0].prefix != same_prefix:
2521             break
2522
2523         if parent.type == syms.file_input:
2524             break
2525
2526         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2527             break
2528
2529         container = parent
2530     return container
2531
2532
2533 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2534     """Return the priority of the `leaf` delimiter, given a line break after it.
2535
2536     The delimiter priorities returned here are from those delimiters that would
2537     cause a line break after themselves.
2538
2539     Higher numbers are higher priority.
2540     """
2541     if leaf.type == token.COMMA:
2542         return COMMA_PRIORITY
2543
2544     return 0
2545
2546
2547 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2548     """Return the priority of the `leaf` delimiter, given a line break before it.
2549
2550     The delimiter priorities returned here are from those delimiters that would
2551     cause a line break before themselves.
2552
2553     Higher numbers are higher priority.
2554     """
2555     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2556         # * and ** might also be MATH_OPERATORS but in this case they are not.
2557         # Don't treat them as a delimiter.
2558         return 0
2559
2560     if (
2561         leaf.type == token.DOT
2562         and leaf.parent
2563         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2564         and (previous is None or previous.type in CLOSING_BRACKETS)
2565     ):
2566         return DOT_PRIORITY
2567
2568     if (
2569         leaf.type in MATH_OPERATORS
2570         and leaf.parent
2571         and leaf.parent.type not in {syms.factor, syms.star_expr}
2572     ):
2573         return MATH_PRIORITIES[leaf.type]
2574
2575     if leaf.type in COMPARATORS:
2576         return COMPARATOR_PRIORITY
2577
2578     if (
2579         leaf.type == token.STRING
2580         and previous is not None
2581         and previous.type == token.STRING
2582     ):
2583         return STRING_PRIORITY
2584
2585     if leaf.type not in {token.NAME, token.ASYNC}:
2586         return 0
2587
2588     if (
2589         leaf.value == "for"
2590         and leaf.parent
2591         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2592         or leaf.type == token.ASYNC
2593     ):
2594         if (
2595             not isinstance(leaf.prev_sibling, Leaf)
2596             or leaf.prev_sibling.value != "async"
2597         ):
2598             return COMPREHENSION_PRIORITY
2599
2600     if (
2601         leaf.value == "if"
2602         and leaf.parent
2603         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2604     ):
2605         return COMPREHENSION_PRIORITY
2606
2607     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2608         return TERNARY_PRIORITY
2609
2610     if leaf.value == "is":
2611         return COMPARATOR_PRIORITY
2612
2613     if (
2614         leaf.value == "in"
2615         and leaf.parent
2616         and leaf.parent.type in {syms.comp_op, syms.comparison}
2617         and not (
2618             previous is not None
2619             and previous.type == token.NAME
2620             and previous.value == "not"
2621         )
2622     ):
2623         return COMPARATOR_PRIORITY
2624
2625     if (
2626         leaf.value == "not"
2627         and leaf.parent
2628         and leaf.parent.type == syms.comp_op
2629         and not (
2630             previous is not None
2631             and previous.type == token.NAME
2632             and previous.value == "is"
2633         )
2634     ):
2635         return COMPARATOR_PRIORITY
2636
2637     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2638         return LOGIC_PRIORITY
2639
2640     return 0
2641
2642
2643 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2644 FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
2645 FMT_PASS = {*FMT_OFF, *FMT_SKIP}
2646 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2647
2648
2649 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2650     """Clean the prefix of the `leaf` and generate comments from it, if any.
2651
2652     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2653     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2654     move because it does away with modifying the grammar to include all the
2655     possible places in which comments can be placed.
2656
2657     The sad consequence for us though is that comments don't "belong" anywhere.
2658     This is why this function generates simple parentless Leaf objects for
2659     comments.  We simply don't know what the correct parent should be.
2660
2661     No matter though, we can live without this.  We really only need to
2662     differentiate between inline and standalone comments.  The latter don't
2663     share the line with any code.
2664
2665     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2666     are emitted with a fake STANDALONE_COMMENT token identifier.
2667     """
2668     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2669         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2670
2671
2672 @dataclass
2673 class ProtoComment:
2674     """Describes a piece of syntax that is a comment.
2675
2676     It's not a :class:`blib2to3.pytree.Leaf` so that:
2677
2678     * it can be cached (`Leaf` objects should not be reused more than once as
2679       they store their lineno, column, prefix, and parent information);
2680     * `newlines` and `consumed` fields are kept separate from the `value`. This
2681       simplifies handling of special marker comments like ``# fmt: off/on``.
2682     """
2683
2684     type: int  # token.COMMENT or STANDALONE_COMMENT
2685     value: str  # content of the comment
2686     newlines: int  # how many newlines before the comment
2687     consumed: int  # how many characters of the original leaf's prefix did we consume
2688
2689
2690 @lru_cache(maxsize=4096)
2691 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2692     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2693     result: List[ProtoComment] = []
2694     if not prefix or "#" not in prefix:
2695         return result
2696
2697     consumed = 0
2698     nlines = 0
2699     ignored_lines = 0
2700     for index, line in enumerate(re.split("\r?\n", prefix)):
2701         consumed += len(line) + 1  # adding the length of the split '\n'
2702         line = line.lstrip()
2703         if not line:
2704             nlines += 1
2705         if not line.startswith("#"):
2706             # Escaped newlines outside of a comment are not really newlines at
2707             # all. We treat a single-line comment following an escaped newline
2708             # as a simple trailing comment.
2709             if line.endswith("\\"):
2710                 ignored_lines += 1
2711             continue
2712
2713         if index == ignored_lines and not is_endmarker:
2714             comment_type = token.COMMENT  # simple trailing comment
2715         else:
2716             comment_type = STANDALONE_COMMENT
2717         comment = make_comment(line)
2718         result.append(
2719             ProtoComment(
2720                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2721             )
2722         )
2723         nlines = 0
2724     return result
2725
2726
2727 def make_comment(content: str) -> str:
2728     """Return a consistently formatted comment from the given `content` string.
2729
2730     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2731     space between the hash sign and the content.
2732
2733     If `content` didn't start with a hash sign, one is provided.
2734     """
2735     content = content.rstrip()
2736     if not content:
2737         return "#"
2738
2739     if content[0] == "#":
2740         content = content[1:]
2741     NON_BREAKING_SPACE = " "
2742     if (
2743         content
2744         and content[0] == NON_BREAKING_SPACE
2745         and not content.lstrip().startswith("type:")
2746     ):
2747         content = " " + content[1:]  # Replace NBSP by a simple space
2748     if content and content[0] not in " !:#'%":
2749         content = " " + content
2750     return "#" + content
2751
2752
2753 def transform_line(
2754     line: Line, mode: Mode, features: Collection[Feature] = ()
2755 ) -> Iterator[Line]:
2756     """Transform a `line`, potentially splitting it into many lines.
2757
2758     They should fit in the allotted `line_length` but might not be able to.
2759
2760     `features` are syntactical features that may be used in the output.
2761     """
2762     if line.is_comment:
2763         yield line
2764         return
2765
2766     line_str = line_to_string(line)
2767
2768     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2769         """Initialize StringTransformer"""
2770         return ST(mode.line_length, mode.string_normalization)
2771
2772     string_merge = init_st(StringMerger)
2773     string_paren_strip = init_st(StringParenStripper)
2774     string_split = init_st(StringSplitter)
2775     string_paren_wrap = init_st(StringParenWrapper)
2776
2777     transformers: List[Transformer]
2778     if (
2779         not line.contains_uncollapsable_type_comments()
2780         and not line.should_split_rhs
2781         and not line.magic_trailing_comma
2782         and (
2783             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2784             or line.contains_unsplittable_type_ignore()
2785         )
2786         and not (line.inside_brackets and line.contains_standalone_comments())
2787     ):
2788         # Only apply basic string preprocessing, since lines shouldn't be split here.
2789         if mode.experimental_string_processing:
2790             transformers = [string_merge, string_paren_strip]
2791         else:
2792             transformers = []
2793     elif line.is_def:
2794         transformers = [left_hand_split]
2795     else:
2796
2797         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2798             """Wraps calls to `right_hand_split`.
2799
2800             The calls increasingly `omit` right-hand trailers (bracket pairs with
2801             content), meaning the trailers get glued together to split on another
2802             bracket pair instead.
2803             """
2804             for omit in generate_trailers_to_omit(line, mode.line_length):
2805                 lines = list(
2806                     right_hand_split(line, mode.line_length, features, omit=omit)
2807                 )
2808                 # Note: this check is only able to figure out if the first line of the
2809                 # *current* transformation fits in the line length.  This is true only
2810                 # for simple cases.  All others require running more transforms via
2811                 # `transform_line()`.  This check doesn't know if those would succeed.
2812                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2813                     yield from lines
2814                     return
2815
2816             # All splits failed, best effort split with no omits.
2817             # This mostly happens to multiline strings that are by definition
2818             # reported as not fitting a single line, as well as lines that contain
2819             # trailing commas (those have to be exploded).
2820             yield from right_hand_split(
2821                 line, line_length=mode.line_length, features=features
2822             )
2823
2824         if mode.experimental_string_processing:
2825             if line.inside_brackets:
2826                 transformers = [
2827                     string_merge,
2828                     string_paren_strip,
2829                     string_split,
2830                     delimiter_split,
2831                     standalone_comment_split,
2832                     string_paren_wrap,
2833                     rhs,
2834                 ]
2835             else:
2836                 transformers = [
2837                     string_merge,
2838                     string_paren_strip,
2839                     string_split,
2840                     string_paren_wrap,
2841                     rhs,
2842                 ]
2843         else:
2844             if line.inside_brackets:
2845                 transformers = [delimiter_split, standalone_comment_split, rhs]
2846             else:
2847                 transformers = [rhs]
2848
2849     for transform in transformers:
2850         # We are accumulating lines in `result` because we might want to abort
2851         # mission and return the original line in the end, or attempt a different
2852         # split altogether.
2853         try:
2854             result = run_transformer(line, transform, mode, features, line_str=line_str)
2855         except CannotTransform:
2856             continue
2857         else:
2858             yield from result
2859             break
2860
2861     else:
2862         yield line
2863
2864
2865 @dataclass  # type: ignore
2866 class StringTransformer(ABC):
2867     """
2868     An implementation of the Transformer protocol that relies on its
2869     subclasses overriding the template methods `do_match(...)` and
2870     `do_transform(...)`.
2871
2872     This Transformer works exclusively on strings (for example, by merging
2873     or splitting them).
2874
2875     The following sections can be found among the docstrings of each concrete
2876     StringTransformer subclass.
2877
2878     Requirements:
2879         Which requirements must be met of the given Line for this
2880         StringTransformer to be applied?
2881
2882     Transformations:
2883         If the given Line meets all of the above requirements, which string
2884         transformations can you expect to be applied to it by this
2885         StringTransformer?
2886
2887     Collaborations:
2888         What contractual agreements does this StringTransformer have with other
2889         StringTransfomers? Such collaborations should be eliminated/minimized
2890         as much as possible.
2891     """
2892
2893     line_length: int
2894     normalize_strings: bool
2895     __name__ = "StringTransformer"
2896
2897     @abstractmethod
2898     def do_match(self, line: Line) -> TMatchResult:
2899         """
2900         Returns:
2901             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2902             string, if a match was able to be made.
2903                 OR
2904             * Err(CannotTransform), if a match was not able to be made.
2905         """
2906
2907     @abstractmethod
2908     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2909         """
2910         Yields:
2911             * Ok(new_line) where new_line is the new transformed line.
2912                 OR
2913             * Err(CannotTransform) if the transformation failed for some reason. The
2914             `do_match(...)` template method should usually be used to reject
2915             the form of the given Line, but in some cases it is difficult to
2916             know whether or not a Line meets the StringTransformer's
2917             requirements until the transformation is already midway.
2918
2919         Side Effects:
2920             This method should NOT mutate @line directly, but it MAY mutate the
2921             Line's underlying Node structure. (WARNING: If the underlying Node
2922             structure IS altered, then this method should NOT be allowed to
2923             yield an CannotTransform after that point.)
2924         """
2925
2926     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2927         """
2928         StringTransformer instances have a call signature that mirrors that of
2929         the Transformer type.
2930
2931         Raises:
2932             CannotTransform(...) if the concrete StringTransformer class is unable
2933             to transform @line.
2934         """
2935         # Optimization to avoid calling `self.do_match(...)` when the line does
2936         # not contain any string.
2937         if not any(leaf.type == token.STRING for leaf in line.leaves):
2938             raise CannotTransform("There are no strings in this line.")
2939
2940         match_result = self.do_match(line)
2941
2942         if isinstance(match_result, Err):
2943             cant_transform = match_result.err()
2944             raise CannotTransform(
2945                 f"The string transformer {self.__class__.__name__} does not recognize"
2946                 " this line as one that it can transform."
2947             ) from cant_transform
2948
2949         string_idx = match_result.ok()
2950
2951         for line_result in self.do_transform(line, string_idx):
2952             if isinstance(line_result, Err):
2953                 cant_transform = line_result.err()
2954                 raise CannotTransform(
2955                     "StringTransformer failed while attempting to transform string."
2956                 ) from cant_transform
2957             line = line_result.ok()
2958             yield line
2959
2960
2961 @dataclass
2962 class CustomSplit:
2963     """A custom (i.e. manual) string split.
2964
2965     A single CustomSplit instance represents a single substring.
2966
2967     Examples:
2968         Consider the following string:
2969         ```
2970         "Hi there friend."
2971         " This is a custom"
2972         f" string {split}."
2973         ```
2974
2975         This string will correspond to the following three CustomSplit instances:
2976         ```
2977         CustomSplit(False, 16)
2978         CustomSplit(False, 17)
2979         CustomSplit(True, 16)
2980         ```
2981     """
2982
2983     has_prefix: bool
2984     break_idx: int
2985
2986
2987 class CustomSplitMapMixin:
2988     """
2989     This mixin class is used to map merged strings to a sequence of
2990     CustomSplits, which will then be used to re-split the strings iff none of
2991     the resultant substrings go over the configured max line length.
2992     """
2993
2994     _Key = Tuple[StringID, str]
2995     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2996
2997     @staticmethod
2998     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2999         """
3000         Returns:
3001             A unique identifier that is used internally to map @string to a
3002             group of custom splits.
3003         """
3004         return (id(string), string)
3005
3006     def add_custom_splits(
3007         self, string: str, custom_splits: Iterable[CustomSplit]
3008     ) -> None:
3009         """Custom Split Map Setter Method
3010
3011         Side Effects:
3012             Adds a mapping from @string to the custom splits @custom_splits.
3013         """
3014         key = self._get_key(string)
3015         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
3016
3017     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
3018         """Custom Split Map Getter Method
3019
3020         Returns:
3021             * A list of the custom splits that are mapped to @string, if any
3022             exist.
3023                 OR
3024             * [], otherwise.
3025
3026         Side Effects:
3027             Deletes the mapping between @string and its associated custom
3028             splits (which are returned to the caller).
3029         """
3030         key = self._get_key(string)
3031
3032         custom_splits = self._CUSTOM_SPLIT_MAP[key]
3033         del self._CUSTOM_SPLIT_MAP[key]
3034
3035         return list(custom_splits)
3036
3037     def has_custom_splits(self, string: str) -> bool:
3038         """
3039         Returns:
3040             True iff @string is associated with a set of custom splits.
3041         """
3042         key = self._get_key(string)
3043         return key in self._CUSTOM_SPLIT_MAP
3044
3045
3046 class StringMerger(CustomSplitMapMixin, StringTransformer):
3047     """StringTransformer that merges strings together.
3048
3049     Requirements:
3050         (A) The line contains adjacent strings such that ALL of the validation checks
3051         listed in StringMerger.__validate_msg(...)'s docstring pass.
3052             OR
3053         (B) The line contains a string which uses line continuation backslashes.
3054
3055     Transformations:
3056         Depending on which of the two requirements above where met, either:
3057
3058         (A) The string group associated with the target string is merged.
3059             OR
3060         (B) All line-continuation backslashes are removed from the target string.
3061
3062     Collaborations:
3063         StringMerger provides custom split information to StringSplitter.
3064     """
3065
3066     def do_match(self, line: Line) -> TMatchResult:
3067         LL = line.leaves
3068
3069         is_valid_index = is_valid_index_factory(LL)
3070
3071         for (i, leaf) in enumerate(LL):
3072             if (
3073                 leaf.type == token.STRING
3074                 and is_valid_index(i + 1)
3075                 and LL[i + 1].type == token.STRING
3076             ):
3077                 return Ok(i)
3078
3079             if leaf.type == token.STRING and "\\\n" in leaf.value:
3080                 return Ok(i)
3081
3082         return TErr("This line has no strings that need merging.")
3083
3084     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3085         new_line = line
3086         rblc_result = self.__remove_backslash_line_continuation_chars(
3087             new_line, string_idx
3088         )
3089         if isinstance(rblc_result, Ok):
3090             new_line = rblc_result.ok()
3091
3092         msg_result = self.__merge_string_group(new_line, string_idx)
3093         if isinstance(msg_result, Ok):
3094             new_line = msg_result.ok()
3095
3096         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3097             msg_cant_transform = msg_result.err()
3098             rblc_cant_transform = rblc_result.err()
3099             cant_transform = CannotTransform(
3100                 "StringMerger failed to merge any strings in this line."
3101             )
3102
3103             # Chain the errors together using `__cause__`.
3104             msg_cant_transform.__cause__ = rblc_cant_transform
3105             cant_transform.__cause__ = msg_cant_transform
3106
3107             yield Err(cant_transform)
3108         else:
3109             yield Ok(new_line)
3110
3111     @staticmethod
3112     def __remove_backslash_line_continuation_chars(
3113         line: Line, string_idx: int
3114     ) -> TResult[Line]:
3115         """
3116         Merge strings that were split across multiple lines using
3117         line-continuation backslashes.
3118
3119         Returns:
3120             Ok(new_line), if @line contains backslash line-continuation
3121             characters.
3122                 OR
3123             Err(CannotTransform), otherwise.
3124         """
3125         LL = line.leaves
3126
3127         string_leaf = LL[string_idx]
3128         if not (
3129             string_leaf.type == token.STRING
3130             and "\\\n" in string_leaf.value
3131             and not has_triple_quotes(string_leaf.value)
3132         ):
3133             return TErr(
3134                 f"String leaf {string_leaf} does not contain any backslash line"
3135                 " continuation characters."
3136             )
3137
3138         new_line = line.clone()
3139         new_line.comments = line.comments.copy()
3140         append_leaves(new_line, line, LL)
3141
3142         new_string_leaf = new_line.leaves[string_idx]
3143         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3144
3145         return Ok(new_line)
3146
3147     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3148         """
3149         Merges string group (i.e. set of adjacent strings) where the first
3150         string in the group is `line.leaves[string_idx]`.
3151
3152         Returns:
3153             Ok(new_line), if ALL of the validation checks found in
3154             __validate_msg(...) pass.
3155                 OR
3156             Err(CannotTransform), otherwise.
3157         """
3158         LL = line.leaves
3159
3160         is_valid_index = is_valid_index_factory(LL)
3161
3162         vresult = self.__validate_msg(line, string_idx)
3163         if isinstance(vresult, Err):
3164             return vresult
3165
3166         # If the string group is wrapped inside an Atom node, we must make sure
3167         # to later replace that Atom with our new (merged) string leaf.
3168         atom_node = LL[string_idx].parent
3169
3170         # We will place BREAK_MARK in between every two substrings that we
3171         # merge. We will then later go through our final result and use the
3172         # various instances of BREAK_MARK we find to add the right values to
3173         # the custom split map.
3174         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3175
3176         QUOTE = LL[string_idx].value[-1]
3177
3178         def make_naked(string: str, string_prefix: str) -> str:
3179             """Strip @string (i.e. make it a "naked" string)
3180
3181             Pre-conditions:
3182                 * assert_is_leaf_string(@string)
3183
3184             Returns:
3185                 A string that is identical to @string except that
3186                 @string_prefix has been stripped, the surrounding QUOTE
3187                 characters have been removed, and any remaining QUOTE
3188                 characters have been escaped.
3189             """
3190             assert_is_leaf_string(string)
3191
3192             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3193             naked_string = string[len(string_prefix) + 1 : -1]
3194             naked_string = re.sub(
3195                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3196             )
3197             return naked_string
3198
3199         # Holds the CustomSplit objects that will later be added to the custom
3200         # split map.
3201         custom_splits = []
3202
3203         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3204         prefix_tracker = []
3205
3206         # Sets the 'prefix' variable. This is the prefix that the final merged
3207         # string will have.
3208         next_str_idx = string_idx
3209         prefix = ""
3210         while (
3211             not prefix
3212             and is_valid_index(next_str_idx)
3213             and LL[next_str_idx].type == token.STRING
3214         ):
3215             prefix = get_string_prefix(LL[next_str_idx].value)
3216             next_str_idx += 1
3217
3218         # The next loop merges the string group. The final string will be
3219         # contained in 'S'.
3220         #
3221         # The following convenience variables are used:
3222         #
3223         #   S: string
3224         #   NS: naked string
3225         #   SS: next string
3226         #   NSS: naked next string
3227         S = ""
3228         NS = ""
3229         num_of_strings = 0
3230         next_str_idx = string_idx
3231         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3232             num_of_strings += 1
3233
3234             SS = LL[next_str_idx].value
3235             next_prefix = get_string_prefix(SS)
3236
3237             # If this is an f-string group but this substring is not prefixed
3238             # with 'f'...
3239             if "f" in prefix and "f" not in next_prefix:
3240                 # Then we must escape any braces contained in this substring.
3241                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3242
3243             NSS = make_naked(SS, next_prefix)
3244
3245             has_prefix = bool(next_prefix)
3246             prefix_tracker.append(has_prefix)
3247
3248             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3249             NS = make_naked(S, prefix)
3250
3251             next_str_idx += 1
3252
3253         S_leaf = Leaf(token.STRING, S)
3254         if self.normalize_strings:
3255             normalize_string_quotes(S_leaf)
3256
3257         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3258         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3259         for has_prefix in prefix_tracker:
3260             mark_idx = temp_string.find(BREAK_MARK)
3261             assert (
3262                 mark_idx >= 0
3263             ), "Logic error while filling the custom string breakpoint cache."
3264
3265             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3266             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3267             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3268
3269         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3270
3271         if atom_node is not None:
3272             replace_child(atom_node, string_leaf)
3273
3274         # Build the final line ('new_line') that this method will later return.
3275         new_line = line.clone()
3276         for (i, leaf) in enumerate(LL):
3277             if i == string_idx:
3278                 new_line.append(string_leaf)
3279
3280             if string_idx <= i < string_idx + num_of_strings:
3281                 for comment_leaf in line.comments_after(LL[i]):
3282                     new_line.append(comment_leaf, preformatted=True)
3283                 continue
3284
3285             append_leaves(new_line, line, [leaf])
3286
3287         self.add_custom_splits(string_leaf.value, custom_splits)
3288         return Ok(new_line)
3289
3290     @staticmethod
3291     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3292         """Validate (M)erge (S)tring (G)roup
3293
3294         Transform-time string validation logic for __merge_string_group(...).
3295
3296         Returns:
3297             * Ok(None), if ALL validation checks (listed below) pass.
3298                 OR
3299             * Err(CannotTransform), if any of the following are true:
3300                 - The target string group does not contain ANY stand-alone comments.
3301                 - The target string is not in a string group (i.e. it has no
3302                   adjacent strings).
3303                 - The string group has more than one inline comment.
3304                 - The string group has an inline comment that appears to be a pragma.
3305                 - The set of all string prefixes in the string group is of
3306                   length greater than one and is not equal to {"", "f"}.
3307                 - The string group consists of raw strings.
3308         """
3309         # We first check for "inner" stand-alone comments (i.e. stand-alone
3310         # comments that have a string leaf before them AND after them).
3311         for inc in [1, -1]:
3312             i = string_idx
3313             found_sa_comment = False
3314             is_valid_index = is_valid_index_factory(line.leaves)
3315             while is_valid_index(i) and line.leaves[i].type in [
3316                 token.STRING,
3317                 STANDALONE_COMMENT,
3318             ]:
3319                 if line.leaves[i].type == STANDALONE_COMMENT:
3320                     found_sa_comment = True
3321                 elif found_sa_comment:
3322                     return TErr(
3323                         "StringMerger does NOT merge string groups which contain "
3324                         "stand-alone comments."
3325                     )
3326
3327                 i += inc
3328
3329         num_of_inline_string_comments = 0
3330         set_of_prefixes = set()
3331         num_of_strings = 0
3332         for leaf in line.leaves[string_idx:]:
3333             if leaf.type != token.STRING:
3334                 # If the string group is trailed by a comma, we count the
3335                 # comments trailing the comma to be one of the string group's
3336                 # comments.
3337                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3338                     num_of_inline_string_comments += 1
3339                 break
3340
3341             if has_triple_quotes(leaf.value):
3342                 return TErr("StringMerger does NOT merge multiline strings.")
3343
3344             num_of_strings += 1
3345             prefix = get_string_prefix(leaf.value)
3346             if "r" in prefix:
3347                 return TErr("StringMerger does NOT merge raw strings.")
3348
3349             set_of_prefixes.add(prefix)
3350
3351             if id(leaf) in line.comments:
3352                 num_of_inline_string_comments += 1
3353                 if contains_pragma_comment(line.comments[id(leaf)]):
3354                     return TErr("Cannot merge strings which have pragma comments.")
3355
3356         if num_of_strings < 2:
3357             return TErr(
3358                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3359             )
3360
3361         if num_of_inline_string_comments > 1:
3362             return TErr(
3363                 f"Too many inline string comments ({num_of_inline_string_comments})."
3364             )
3365
3366         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3367             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3368
3369         return Ok(None)
3370
3371
3372 class StringParenStripper(StringTransformer):
3373     """StringTransformer that strips surrounding parentheses from strings.
3374
3375     Requirements:
3376         The line contains a string which is surrounded by parentheses and:
3377             - The target string is NOT the only argument to a function call.
3378             - The target string is NOT a "pointless" string.
3379             - If the target string contains a PERCENT, the brackets are not
3380               preceeded or followed by an operator with higher precedence than
3381               PERCENT.
3382
3383     Transformations:
3384         The parentheses mentioned in the 'Requirements' section are stripped.
3385
3386     Collaborations:
3387         StringParenStripper has its own inherent usefulness, but it is also
3388         relied on to clean up the parentheses created by StringParenWrapper (in
3389         the event that they are no longer needed).
3390     """
3391
3392     def do_match(self, line: Line) -> TMatchResult:
3393         LL = line.leaves
3394
3395         is_valid_index = is_valid_index_factory(LL)
3396
3397         for (idx, leaf) in enumerate(LL):
3398             # Should be a string...
3399             if leaf.type != token.STRING:
3400                 continue
3401
3402             # If this is a "pointless" string...
3403             if (
3404                 leaf.parent
3405                 and leaf.parent.parent
3406                 and leaf.parent.parent.type == syms.simple_stmt
3407             ):
3408                 continue
3409
3410             # Should be preceded by a non-empty LPAR...
3411             if (
3412                 not is_valid_index(idx - 1)
3413                 or LL[idx - 1].type != token.LPAR
3414                 or is_empty_lpar(LL[idx - 1])
3415             ):
3416                 continue
3417
3418             # That LPAR should NOT be preceded by a function name or a closing
3419             # bracket (which could be a function which returns a function or a
3420             # list/dictionary that contains a function)...
3421             if is_valid_index(idx - 2) and (
3422                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3423             ):
3424                 continue
3425
3426             string_idx = idx
3427
3428             # Skip the string trailer, if one exists.
3429             string_parser = StringParser()
3430             next_idx = string_parser.parse(LL, string_idx)
3431
3432             # if the leaves in the parsed string include a PERCENT, we need to
3433             # make sure the initial LPAR is NOT preceded by an operator with
3434             # higher or equal precedence to PERCENT
3435             if is_valid_index(idx - 2):
3436                 # mypy can't quite follow unless we name this
3437                 before_lpar = LL[idx - 2]
3438                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3439                     (
3440                         before_lpar.type
3441                         in {
3442                             token.STAR,
3443                             token.AT,
3444                             token.SLASH,
3445                             token.DOUBLESLASH,
3446                             token.PERCENT,
3447                             token.TILDE,
3448                             token.DOUBLESTAR,
3449                             token.AWAIT,
3450                             token.LSQB,
3451                             token.LPAR,
3452                         }
3453                     )
3454                     or (
3455                         # only unary PLUS/MINUS
3456                         before_lpar.parent
3457                         and before_lpar.parent.type == syms.factor
3458                         and (before_lpar.type in {token.PLUS, token.MINUS})
3459                     )
3460                 ):
3461                     continue
3462
3463             # Should be followed by a non-empty RPAR...
3464             if (
3465                 is_valid_index(next_idx)
3466                 and LL[next_idx].type == token.RPAR
3467                 and not is_empty_rpar(LL[next_idx])
3468             ):
3469                 # That RPAR should NOT be followed by anything with higher
3470                 # precedence than PERCENT
3471                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3472                     token.DOUBLESTAR,
3473                     token.LSQB,
3474                     token.LPAR,
3475                     token.DOT,
3476                 }:
3477                     continue
3478
3479                 return Ok(string_idx)
3480
3481         return TErr("This line has no strings wrapped in parens.")
3482
3483     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3484         LL = line.leaves
3485
3486         string_parser = StringParser()
3487         rpar_idx = string_parser.parse(LL, string_idx)
3488
3489         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3490             if line.comments_after(leaf):
3491                 yield TErr(
3492                     "Will not strip parentheses which have comments attached to them."
3493                 )
3494                 return
3495
3496         new_line = line.clone()
3497         new_line.comments = line.comments.copy()
3498         try:
3499             append_leaves(new_line, line, LL[: string_idx - 1])
3500         except BracketMatchError:
3501             # HACK: I believe there is currently a bug somewhere in
3502             # right_hand_split() that is causing brackets to not be tracked
3503             # properly by a shared BracketTracker.
3504             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3505
3506         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3507         LL[string_idx - 1].remove()
3508         replace_child(LL[string_idx], string_leaf)
3509         new_line.append(string_leaf)
3510
3511         append_leaves(
3512             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3513         )
3514
3515         LL[rpar_idx].remove()
3516
3517         yield Ok(new_line)
3518
3519
3520 class BaseStringSplitter(StringTransformer):
3521     """
3522     Abstract class for StringTransformers which transform a Line's strings by splitting
3523     them or placing them on their own lines where necessary to avoid going over
3524     the configured line length.
3525
3526     Requirements:
3527         * The target string value is responsible for the line going over the
3528         line length limit. It follows that after all of black's other line
3529         split methods have been exhausted, this line (or one of the resulting
3530         lines after all line splits are performed) would still be over the
3531         line_length limit unless we split this string.
3532             AND
3533         * The target string is NOT a "pointless" string (i.e. a string that has
3534         no parent or siblings).
3535             AND
3536         * The target string is not followed by an inline comment that appears
3537         to be a pragma.
3538             AND
3539         * The target string is not a multiline (i.e. triple-quote) string.
3540     """
3541
3542     @abstractmethod
3543     def do_splitter_match(self, line: Line) -> TMatchResult:
3544         """
3545         BaseStringSplitter asks its clients to override this method instead of
3546         `StringTransformer.do_match(...)`.
3547
3548         Follows the same protocol as `StringTransformer.do_match(...)`.
3549
3550         Refer to `help(StringTransformer.do_match)` for more information.
3551         """
3552
3553     def do_match(self, line: Line) -> TMatchResult:
3554         match_result = self.do_splitter_match(line)
3555         if isinstance(match_result, Err):
3556             return match_result
3557
3558         string_idx = match_result.ok()
3559         vresult = self.__validate(line, string_idx)
3560         if isinstance(vresult, Err):
3561             return vresult
3562
3563         return match_result
3564
3565     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3566         """
3567         Checks that @line meets all of the requirements listed in this classes'
3568         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3569         description of those requirements.
3570
3571         Returns:
3572             * Ok(None), if ALL of the requirements are met.
3573                 OR
3574             * Err(CannotTransform), if ANY of the requirements are NOT met.
3575         """
3576         LL = line.leaves
3577
3578         string_leaf = LL[string_idx]
3579
3580         max_string_length = self.__get_max_string_length(line, string_idx)
3581         if len(string_leaf.value) <= max_string_length:
3582             return TErr(
3583                 "The string itself is not what is causing this line to be too long."
3584             )
3585
3586         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3587             token.STRING,
3588             token.NEWLINE,
3589         ]:
3590             return TErr(
3591                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3592                 " no parent)."
3593             )
3594
3595         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3596             line.comments[id(line.leaves[string_idx])]
3597         ):
3598             return TErr(
3599                 "Line appears to end with an inline pragma comment. Splitting the line"
3600                 " could modify the pragma's behavior."
3601             )
3602
3603         if has_triple_quotes(string_leaf.value):
3604             return TErr("We cannot split multiline strings.")
3605
3606         return Ok(None)
3607
3608     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3609         """
3610         Calculates the max string length used when attempting to determine
3611         whether or not the target string is responsible for causing the line to
3612         go over the line length limit.
3613
3614         WARNING: This method is tightly coupled to both StringSplitter and
3615         (especially) StringParenWrapper. There is probably a better way to
3616         accomplish what is being done here.
3617
3618         Returns:
3619             max_string_length: such that `line.leaves[string_idx].value >
3620             max_string_length` implies that the target string IS responsible
3621             for causing this line to exceed the line length limit.
3622         """
3623         LL = line.leaves
3624
3625         is_valid_index = is_valid_index_factory(LL)
3626
3627         # We use the shorthand "WMA4" in comments to abbreviate "We must
3628         # account for". When giving examples, we use STRING to mean some/any
3629         # valid string.
3630         #
3631         # Finally, we use the following convenience variables:
3632         #
3633         #   P:  The leaf that is before the target string leaf.
3634         #   N:  The leaf that is after the target string leaf.
3635         #   NN: The leaf that is after N.
3636
3637         # WMA4 the whitespace at the beginning of the line.
3638         offset = line.depth * 4
3639
3640         if is_valid_index(string_idx - 1):
3641             p_idx = string_idx - 1
3642             if (
3643                 LL[string_idx - 1].type == token.LPAR
3644                 and LL[string_idx - 1].value == ""
3645                 and string_idx >= 2
3646             ):
3647                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3648                 p_idx -= 1
3649
3650             P = LL[p_idx]
3651             if P.type == token.PLUS:
3652                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3653                 offset += 2
3654
3655             if P.type == token.COMMA:
3656                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3657                 offset += 3
3658
3659             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3660                 # This conditional branch is meant to handle dictionary keys,
3661                 # variable assignments, 'return STRING' statement lines, and
3662                 # 'else STRING' ternary expression lines.
3663
3664                 # WMA4 a single space.
3665                 offset += 1
3666
3667                 # WMA4 the lengths of any leaves that came before that space,
3668                 # but after any closing bracket before that space.
3669                 for leaf in reversed(LL[: p_idx + 1]):
3670                     offset += len(str(leaf))
3671                     if leaf.type in CLOSING_BRACKETS:
3672                         break
3673
3674         if is_valid_index(string_idx + 1):
3675             N = LL[string_idx + 1]
3676             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3677                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3678                 N = LL[string_idx + 2]
3679
3680             if N.type == token.COMMA:
3681                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3682                 offset += 1
3683
3684             if is_valid_index(string_idx + 2):
3685                 NN = LL[string_idx + 2]
3686
3687                 if N.type == token.DOT and NN.type == token.NAME:
3688                     # This conditional branch is meant to handle method calls invoked
3689                     # off of a string literal up to and including the LPAR character.
3690
3691                     # WMA4 the '.' character.
3692                     offset += 1
3693
3694                     if (
3695                         is_valid_index(string_idx + 3)
3696                         and LL[string_idx + 3].type == token.LPAR
3697                     ):
3698                         # WMA4 the left parenthesis character.
3699                         offset += 1
3700
3701                     # WMA4 the length of the method's name.
3702                     offset += len(NN.value)
3703
3704         has_comments = False
3705         for comment_leaf in line.comments_after(LL[string_idx]):
3706             if not has_comments:
3707                 has_comments = True
3708                 # WMA4 two spaces before the '#' character.
3709                 offset += 2
3710
3711             # WMA4 the length of the inline comment.
3712             offset += len(comment_leaf.value)
3713
3714         max_string_length = self.line_length - offset
3715         return max_string_length
3716
3717
3718 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3719     """
3720     StringTransformer that splits "atom" strings (i.e. strings which exist on
3721     lines by themselves).
3722
3723     Requirements:
3724         * The line consists ONLY of a single string (with the exception of a
3725         '+' symbol which MAY exist at the start of the line), MAYBE a string
3726         trailer, and MAYBE a trailing comma.
3727             AND
3728         * All of the requirements listed in BaseStringSplitter's docstring.
3729
3730     Transformations:
3731         The string mentioned in the 'Requirements' section is split into as
3732         many substrings as necessary to adhere to the configured line length.
3733
3734         In the final set of substrings, no substring should be smaller than
3735         MIN_SUBSTR_SIZE characters.
3736
3737         The string will ONLY be split on spaces (i.e. each new substring should
3738         start with a space). Note that the string will NOT be split on a space
3739         which is escaped with a backslash.
3740
3741         If the string is an f-string, it will NOT be split in the middle of an
3742         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3743         else bar()} is an f-expression).
3744
3745         If the string that is being split has an associated set of custom split
3746         records and those custom splits will NOT result in any line going over
3747         the configured line length, those custom splits are used. Otherwise the
3748         string is split as late as possible (from left-to-right) while still
3749         adhering to the transformation rules listed above.
3750
3751     Collaborations:
3752         StringSplitter relies on StringMerger to construct the appropriate
3753         CustomSplit objects and add them to the custom split map.
3754     """
3755
3756     MIN_SUBSTR_SIZE = 6
3757     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3758     RE_FEXPR = r"""
3759     (?<!\{) (?:\{\{)* \{ (?!\{)
3760         (?:
3761             [^\{\}]
3762             | \{\{
3763             | \}\}
3764             | (?R)
3765         )+?
3766     (?<!\}) \} (?:\}\})* (?!\})
3767     """
3768
3769     def do_splitter_match(self, line: Line) -> TMatchResult:
3770         LL = line.leaves
3771
3772         is_valid_index = is_valid_index_factory(LL)
3773
3774         idx = 0
3775
3776         # The first leaf MAY be a '+' symbol...
3777         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3778             idx += 1
3779
3780         # The next/first leaf MAY be an empty LPAR...
3781         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3782             idx += 1
3783
3784         # The next/first leaf MUST be a string...
3785         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3786             return TErr("Line does not start with a string.")
3787
3788         string_idx = idx
3789
3790         # Skip the string trailer, if one exists.
3791         string_parser = StringParser()
3792         idx = string_parser.parse(LL, string_idx)
3793
3794         # That string MAY be followed by an empty RPAR...
3795         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3796             idx += 1
3797
3798         # That string / empty RPAR leaf MAY be followed by a comma...
3799         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3800             idx += 1
3801
3802         # But no more leaves are allowed...
3803         if is_valid_index(idx):
3804             return TErr("This line does not end with a string.")
3805
3806         return Ok(string_idx)
3807
3808     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3809         LL = line.leaves
3810
3811         QUOTE = LL[string_idx].value[-1]
3812
3813         is_valid_index = is_valid_index_factory(LL)
3814         insert_str_child = insert_str_child_factory(LL[string_idx])
3815
3816         prefix = get_string_prefix(LL[string_idx].value)
3817
3818         # We MAY choose to drop the 'f' prefix from substrings that don't
3819         # contain any f-expressions, but ONLY if the original f-string
3820         # contains at least one f-expression. Otherwise, we will alter the AST
3821         # of the program.
3822         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3823             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3824         )
3825
3826         first_string_line = True
3827         starts_with_plus = LL[0].type == token.PLUS
3828
3829         def line_needs_plus() -> bool:
3830             return first_string_line and starts_with_plus
3831
3832         def maybe_append_plus(new_line: Line) -> None:
3833             """
3834             Side Effects:
3835                 If @line starts with a plus and this is the first line we are
3836                 constructing, this function appends a PLUS leaf to @new_line
3837                 and replaces the old PLUS leaf in the node structure. Otherwise
3838                 this function does nothing.
3839             """
3840             if line_needs_plus():
3841                 plus_leaf = Leaf(token.PLUS, "+")
3842                 replace_child(LL[0], plus_leaf)
3843                 new_line.append(plus_leaf)
3844
3845         ends_with_comma = (
3846             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3847         )
3848
3849         def max_last_string() -> int:
3850             """
3851             Returns:
3852                 The max allowed length of the string value used for the last
3853                 line we will construct.
3854             """
3855             result = self.line_length
3856             result -= line.depth * 4
3857             result -= 1 if ends_with_comma else 0
3858             result -= 2 if line_needs_plus() else 0
3859             return result
3860
3861         # --- Calculate Max Break Index (for string value)
3862         # We start with the line length limit
3863         max_break_idx = self.line_length
3864         # The last index of a string of length N is N-1.
3865         max_break_idx -= 1
3866         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3867         max_break_idx -= line.depth * 4
3868         if max_break_idx < 0:
3869             yield TErr(
3870                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3871                 f" {line.depth}"
3872             )
3873             return
3874
3875         # Check if StringMerger registered any custom splits.
3876         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3877         # We use them ONLY if none of them would produce lines that exceed the
3878         # line limit.
3879         use_custom_breakpoints = bool(
3880             custom_splits
3881             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3882         )
3883
3884         # Temporary storage for the remaining chunk of the string line that
3885         # can't fit onto the line currently being constructed.
3886         rest_value = LL[string_idx].value
3887
3888         def more_splits_should_be_made() -> bool:
3889             """
3890             Returns:
3891                 True iff `rest_value` (the remaining string value from the last
3892                 split), should be split again.
3893             """
3894             if use_custom_breakpoints:
3895                 return len(custom_splits) > 1
3896             else:
3897                 return len(rest_value) > max_last_string()
3898
3899         string_line_results: List[Ok[Line]] = []
3900         while more_splits_should_be_made():
3901             if use_custom_breakpoints:
3902                 # Custom User Split (manual)
3903                 csplit = custom_splits.pop(0)
3904                 break_idx = csplit.break_idx
3905             else:
3906                 # Algorithmic Split (automatic)
3907                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3908                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3909                 if maybe_break_idx is None:
3910                     # If we are unable to algorithmically determine a good split
3911                     # and this string has custom splits registered to it, we
3912                     # fall back to using them--which means we have to start
3913                     # over from the beginning.
3914                     if custom_splits:
3915                         rest_value = LL[string_idx].value
3916                         string_line_results = []
3917                         first_string_line = True
3918                         use_custom_breakpoints = True
3919                         continue
3920
3921                     # Otherwise, we stop splitting here.
3922                     break
3923
3924                 break_idx = maybe_break_idx
3925
3926             # --- Construct `next_value`
3927             next_value = rest_value[:break_idx] + QUOTE
3928             if (
3929                 # Are we allowed to try to drop a pointless 'f' prefix?
3930                 drop_pointless_f_prefix
3931                 # If we are, will we be successful?
3932                 and next_value != self.__normalize_f_string(next_value, prefix)
3933             ):
3934                 # If the current custom split did NOT originally use a prefix,
3935                 # then `csplit.break_idx` will be off by one after removing
3936                 # the 'f' prefix.
3937                 break_idx = (
3938                     break_idx + 1
3939                     if use_custom_breakpoints and not csplit.has_prefix
3940                     else break_idx
3941                 )
3942                 next_value = rest_value[:break_idx] + QUOTE
3943                 next_value = self.__normalize_f_string(next_value, prefix)
3944
3945             # --- Construct `next_leaf`
3946             next_leaf = Leaf(token.STRING, next_value)
3947             insert_str_child(next_leaf)
3948             self.__maybe_normalize_string_quotes(next_leaf)
3949
3950             # --- Construct `next_line`
3951             next_line = line.clone()
3952             maybe_append_plus(next_line)
3953             next_line.append(next_leaf)
3954             string_line_results.append(Ok(next_line))
3955
3956             rest_value = prefix + QUOTE + rest_value[break_idx:]
3957             first_string_line = False
3958
3959         yield from string_line_results
3960
3961         if drop_pointless_f_prefix:
3962             rest_value = self.__normalize_f_string(rest_value, prefix)
3963
3964         rest_leaf = Leaf(token.STRING, rest_value)
3965         insert_str_child(rest_leaf)
3966
3967         # NOTE: I could not find a test case that verifies that the following
3968         # line is actually necessary, but it seems to be. Otherwise we risk
3969         # not normalizing the last substring, right?
3970         self.__maybe_normalize_string_quotes(rest_leaf)
3971
3972         last_line = line.clone()
3973         maybe_append_plus(last_line)
3974
3975         # If there are any leaves to the right of the target string...
3976         if is_valid_index(string_idx + 1):
3977             # We use `temp_value` here to determine how long the last line
3978             # would be if we were to append all the leaves to the right of the
3979             # target string to the last string line.
3980             temp_value = rest_value
3981             for leaf in LL[string_idx + 1 :]:
3982                 temp_value += str(leaf)
3983                 if leaf.type == token.LPAR:
3984                     break
3985
3986             # Try to fit them all on the same line with the last substring...
3987             if (
3988                 len(temp_value) <= max_last_string()
3989                 or LL[string_idx + 1].type == token.COMMA
3990             ):
3991                 last_line.append(rest_leaf)
3992                 append_leaves(last_line, line, LL[string_idx + 1 :])
3993                 yield Ok(last_line)
3994             # Otherwise, place the last substring on one line and everything
3995             # else on a line below that...
3996             else:
3997                 last_line.append(rest_leaf)
3998                 yield Ok(last_line)
3999
4000                 non_string_line = line.clone()
4001                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
4002                 yield Ok(non_string_line)
4003         # Else the target string was the last leaf...
4004         else:
4005             last_line.append(rest_leaf)
4006             last_line.comments = line.comments.copy()
4007             yield Ok(last_line)
4008
4009     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
4010         """
4011         This method contains the algorithm that StringSplitter uses to
4012         determine which character to split each string at.
4013
4014         Args:
4015             @string: The substring that we are attempting to split.
4016             @max_break_idx: The ideal break index. We will return this value if it
4017             meets all the necessary conditions. In the likely event that it
4018             doesn't we will try to find the closest index BELOW @max_break_idx
4019             that does. If that fails, we will expand our search by also
4020             considering all valid indices ABOVE @max_break_idx.
4021
4022         Pre-Conditions:
4023             * assert_is_leaf_string(@string)
4024             * 0 <= @max_break_idx < len(@string)
4025
4026         Returns:
4027             break_idx, if an index is able to be found that meets all of the
4028             conditions listed in the 'Transformations' section of this classes'
4029             docstring.
4030                 OR
4031             None, otherwise.
4032         """
4033         is_valid_index = is_valid_index_factory(string)
4034
4035         assert is_valid_index(max_break_idx)
4036         assert_is_leaf_string(string)
4037
4038         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
4039
4040         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
4041             """
4042             Yields:
4043                 All ranges of @string which, if @string were to be split there,
4044                 would result in the splitting of an f-expression (which is NOT
4045                 allowed).
4046             """
4047             nonlocal _fexpr_slices
4048
4049             if _fexpr_slices is None:
4050                 _fexpr_slices = []
4051                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
4052                     _fexpr_slices.append(match.span())
4053
4054             yield from _fexpr_slices
4055
4056         is_fstring = "f" in get_string_prefix(string)
4057
4058         def breaks_fstring_expression(i: Index) -> bool:
4059             """
4060             Returns:
4061                 True iff returning @i would result in the splitting of an
4062                 f-expression (which is NOT allowed).
4063             """
4064             if not is_fstring:
4065                 return False
4066
4067             for (start, end) in fexpr_slices():
4068                 if start <= i < end:
4069                     return True
4070
4071             return False
4072
4073         def passes_all_checks(i: Index) -> bool:
4074             """
4075             Returns:
4076                 True iff ALL of the conditions listed in the 'Transformations'
4077                 section of this classes' docstring would be be met by returning @i.
4078             """
4079             is_space = string[i] == " "
4080
4081             is_not_escaped = True
4082             j = i - 1
4083             while is_valid_index(j) and string[j] == "\\":
4084                 is_not_escaped = not is_not_escaped
4085                 j -= 1
4086
4087             is_big_enough = (
4088                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
4089                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
4090             )
4091             return (
4092                 is_space
4093                 and is_not_escaped
4094                 and is_big_enough
4095                 and not breaks_fstring_expression(i)
4096             )
4097
4098         # First, we check all indices BELOW @max_break_idx.
4099         break_idx = max_break_idx
4100         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
4101             break_idx -= 1
4102
4103         if not passes_all_checks(break_idx):
4104             # If that fails, we check all indices ABOVE @max_break_idx.
4105             #
4106             # If we are able to find a valid index here, the next line is going
4107             # to be longer than the specified line length, but it's probably
4108             # better than doing nothing at all.
4109             break_idx = max_break_idx + 1
4110             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
4111                 break_idx += 1
4112
4113             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
4114                 return None
4115
4116         return break_idx
4117
4118     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
4119         if self.normalize_strings:
4120             normalize_string_quotes(leaf)
4121
4122     def __normalize_f_string(self, string: str, prefix: str) -> str:
4123         """
4124         Pre-Conditions:
4125             * assert_is_leaf_string(@string)
4126
4127         Returns:
4128             * If @string is an f-string that contains no f-expressions, we
4129             return a string identical to @string except that the 'f' prefix
4130             has been stripped and all double braces (i.e. '{{' or '}}') have
4131             been normalized (i.e. turned into '{' or '}').
4132                 OR
4133             * Otherwise, we return @string.
4134         """
4135         assert_is_leaf_string(string)
4136
4137         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
4138             new_prefix = prefix.replace("f", "")
4139
4140             temp = string[len(prefix) :]
4141             temp = re.sub(r"\{\{", "{", temp)
4142             temp = re.sub(r"\}\}", "}", temp)
4143             new_string = temp
4144
4145             return f"{new_prefix}{new_string}"
4146         else:
4147             return string
4148
4149
4150 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4151     """
4152     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4153     exist on lines by themselves).
4154
4155     Requirements:
4156         All of the requirements listed in BaseStringSplitter's docstring in
4157         addition to the requirements listed below:
4158
4159         * The line is a return/yield statement, which returns/yields a string.
4160             OR
4161         * The line is part of a ternary expression (e.g. `x = y if cond else
4162         z`) such that the line starts with `else <string>`, where <string> is
4163         some string.
4164             OR
4165         * The line is an assert statement, which ends with a string.
4166             OR
4167         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4168         <string>`) such that the variable is being assigned the value of some
4169         string.
4170             OR
4171         * The line is a dictionary key assignment where some valid key is being
4172         assigned the value of some string.
4173
4174     Transformations:
4175         The chosen string is wrapped in parentheses and then split at the LPAR.
4176
4177         We then have one line which ends with an LPAR and another line that
4178         starts with the chosen string. The latter line is then split again at
4179         the RPAR. This results in the RPAR (and possibly a trailing comma)
4180         being placed on its own line.
4181
4182         NOTE: If any leaves exist to the right of the chosen string (except
4183         for a trailing comma, which would be placed after the RPAR), those
4184         leaves are placed inside the parentheses.  In effect, the chosen
4185         string is not necessarily being "wrapped" by parentheses. We can,
4186         however, count on the LPAR being placed directly before the chosen
4187         string.
4188
4189         In other words, StringParenWrapper creates "atom" strings. These
4190         can then be split again by StringSplitter, if necessary.
4191
4192     Collaborations:
4193         In the event that a string line split by StringParenWrapper is
4194         changed such that it no longer needs to be given its own line,
4195         StringParenWrapper relies on StringParenStripper to clean up the
4196         parentheses it created.
4197     """
4198
4199     def do_splitter_match(self, line: Line) -> TMatchResult:
4200         LL = line.leaves
4201
4202         string_idx = (
4203             self._return_match(LL)
4204             or self._else_match(LL)
4205             or self._assert_match(LL)
4206             or self._assign_match(LL)
4207             or self._dict_match(LL)
4208         )
4209
4210         if string_idx is not None:
4211             string_value = line.leaves[string_idx].value
4212             # If the string has no spaces...
4213             if " " not in string_value:
4214                 # And will still violate the line length limit when split...
4215                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4216                 if len(string_value) > max_string_length:
4217                     # And has no associated custom splits...
4218                     if not self.has_custom_splits(string_value):
4219                         # Then we should NOT put this string on its own line.
4220                         return TErr(
4221                             "We do not wrap long strings in parentheses when the"
4222                             " resultant line would still be over the specified line"
4223                             " length and can't be split further by StringSplitter."
4224                         )
4225             return Ok(string_idx)
4226
4227         return TErr("This line does not contain any non-atomic strings.")
4228
4229     @staticmethod
4230     def _return_match(LL: List[Leaf]) -> Optional[int]:
4231         """
4232         Returns:
4233             string_idx such that @LL[string_idx] is equal to our target (i.e.
4234             matched) string, if this line matches the return/yield statement
4235             requirements listed in the 'Requirements' section of this classes'
4236             docstring.
4237                 OR
4238             None, otherwise.
4239         """
4240         # If this line is apart of a return/yield statement and the first leaf
4241         # contains either the "return" or "yield" keywords...
4242         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4243             0
4244         ].value in ["return", "yield"]:
4245             is_valid_index = is_valid_index_factory(LL)
4246
4247             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4248             # The next visible leaf MUST contain a string...
4249             if is_valid_index(idx) and LL[idx].type == token.STRING:
4250                 return idx
4251
4252         return None
4253
4254     @staticmethod
4255     def _else_match(LL: List[Leaf]) -> Optional[int]:
4256         """
4257         Returns:
4258             string_idx such that @LL[string_idx] is equal to our target (i.e.
4259             matched) string, if this line matches the ternary expression
4260             requirements listed in the 'Requirements' section of this classes'
4261             docstring.
4262                 OR
4263             None, otherwise.
4264         """
4265         # If this line is apart of a ternary expression and the first leaf
4266         # contains the "else" keyword...
4267         if (
4268             parent_type(LL[0]) == syms.test
4269             and LL[0].type == token.NAME
4270             and LL[0].value == "else"
4271         ):
4272             is_valid_index = is_valid_index_factory(LL)
4273
4274             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4275             # The next visible leaf MUST contain a string...
4276             if is_valid_index(idx) and LL[idx].type == token.STRING:
4277                 return idx
4278
4279         return None
4280
4281     @staticmethod
4282     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4283         """
4284         Returns:
4285             string_idx such that @LL[string_idx] is equal to our target (i.e.
4286             matched) string, if this line matches the assert statement
4287             requirements listed in the 'Requirements' section of this classes'
4288             docstring.
4289                 OR
4290             None, otherwise.
4291         """
4292         # If this line is apart of an assert statement and the first leaf
4293         # contains the "assert" keyword...
4294         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4295             is_valid_index = is_valid_index_factory(LL)
4296
4297             for (i, leaf) in enumerate(LL):
4298                 # We MUST find a comma...
4299                 if leaf.type == token.COMMA:
4300                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4301
4302                     # That comma MUST be followed by a string...
4303                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4304                         string_idx = idx
4305
4306                         # Skip the string trailer, if one exists.
4307                         string_parser = StringParser()
4308                         idx = string_parser.parse(LL, string_idx)
4309
4310                         # But no more leaves are allowed...
4311                         if not is_valid_index(idx):
4312                             return string_idx
4313
4314         return None
4315
4316     @staticmethod
4317     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4318         """
4319         Returns:
4320             string_idx such that @LL[string_idx] is equal to our target (i.e.
4321             matched) string, if this line matches the assignment statement
4322             requirements listed in the 'Requirements' section of this classes'
4323             docstring.
4324                 OR
4325             None, otherwise.
4326         """
4327         # If this line is apart of an expression statement or is a function
4328         # argument AND the first leaf contains a variable name...
4329         if (
4330             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4331             and LL[0].type == token.NAME
4332         ):
4333             is_valid_index = is_valid_index_factory(LL)
4334
4335             for (i, leaf) in enumerate(LL):
4336                 # We MUST find either an '=' or '+=' symbol...
4337                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4338                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4339
4340                     # That symbol MUST be followed by a string...
4341                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4342                         string_idx = idx
4343
4344                         # Skip the string trailer, if one exists.
4345                         string_parser = StringParser()
4346                         idx = string_parser.parse(LL, string_idx)
4347
4348                         # The next leaf MAY be a comma iff this line is apart
4349                         # of a function argument...
4350                         if (
4351                             parent_type(LL[0]) == syms.argument
4352                             and is_valid_index(idx)
4353                             and LL[idx].type == token.COMMA
4354                         ):
4355                             idx += 1
4356
4357                         # But no more leaves are allowed...
4358                         if not is_valid_index(idx):
4359                             return string_idx
4360
4361         return None
4362
4363     @staticmethod
4364     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4365         """
4366         Returns:
4367             string_idx such that @LL[string_idx] is equal to our target (i.e.
4368             matched) string, if this line matches the dictionary key assignment
4369             statement requirements listed in the 'Requirements' section of this
4370             classes' docstring.
4371                 OR
4372             None, otherwise.
4373         """
4374         # If this line is apart of a dictionary key assignment...
4375         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4376             is_valid_index = is_valid_index_factory(LL)
4377
4378             for (i, leaf) in enumerate(LL):
4379                 # We MUST find a colon...
4380                 if leaf.type == token.COLON:
4381                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4382
4383                     # That colon MUST be followed by a string...
4384                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4385                         string_idx = idx
4386
4387                         # Skip the string trailer, if one exists.
4388                         string_parser = StringParser()
4389                         idx = string_parser.parse(LL, string_idx)
4390
4391                         # That string MAY be followed by a comma...
4392                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4393                             idx += 1
4394
4395                         # But no more leaves are allowed...
4396                         if not is_valid_index(idx):
4397                             return string_idx
4398
4399         return None
4400
4401     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4402         LL = line.leaves
4403
4404         is_valid_index = is_valid_index_factory(LL)
4405         insert_str_child = insert_str_child_factory(LL[string_idx])
4406
4407         comma_idx = -1
4408         ends_with_comma = False
4409         if LL[comma_idx].type == token.COMMA:
4410             ends_with_comma = True
4411
4412         leaves_to_steal_comments_from = [LL[string_idx]]
4413         if ends_with_comma:
4414             leaves_to_steal_comments_from.append(LL[comma_idx])
4415
4416         # --- First Line
4417         first_line = line.clone()
4418         left_leaves = LL[:string_idx]
4419
4420         # We have to remember to account for (possibly invisible) LPAR and RPAR
4421         # leaves that already wrapped the target string. If these leaves do
4422         # exist, we will replace them with our own LPAR and RPAR leaves.
4423         old_parens_exist = False
4424         if left_leaves and left_leaves[-1].type == token.LPAR:
4425             old_parens_exist = True
4426             leaves_to_steal_comments_from.append(left_leaves[-1])
4427             left_leaves.pop()
4428
4429         append_leaves(first_line, line, left_leaves)
4430
4431         lpar_leaf = Leaf(token.LPAR, "(")
4432         if old_parens_exist:
4433             replace_child(LL[string_idx - 1], lpar_leaf)
4434         else:
4435             insert_str_child(lpar_leaf)
4436         first_line.append(lpar_leaf)
4437
4438         # We throw inline comments that were originally to the right of the
4439         # target string to the top line. They will now be shown to the right of
4440         # the LPAR.
4441         for leaf in leaves_to_steal_comments_from:
4442             for comment_leaf in line.comments_after(leaf):
4443                 first_line.append(comment_leaf, preformatted=True)
4444
4445         yield Ok(first_line)
4446
4447         # --- Middle (String) Line
4448         # We only need to yield one (possibly too long) string line, since the
4449         # `StringSplitter` will break it down further if necessary.
4450         string_value = LL[string_idx].value
4451         string_line = Line(
4452             mode=line.mode,
4453             depth=line.depth + 1,
4454             inside_brackets=True,
4455             should_split_rhs=line.should_split_rhs,
4456             magic_trailing_comma=line.magic_trailing_comma,
4457         )
4458         string_leaf = Leaf(token.STRING, string_value)
4459         insert_str_child(string_leaf)
4460         string_line.append(string_leaf)
4461
4462         old_rpar_leaf = None
4463         if is_valid_index(string_idx + 1):
4464             right_leaves = LL[string_idx + 1 :]
4465             if ends_with_comma:
4466                 right_leaves.pop()
4467
4468             if old_parens_exist:
4469                 assert (
4470                     right_leaves and right_leaves[-1].type == token.RPAR
4471                 ), "Apparently, old parentheses do NOT exist?!"
4472                 old_rpar_leaf = right_leaves.pop()
4473
4474             append_leaves(string_line, line, right_leaves)
4475
4476         yield Ok(string_line)
4477
4478         # --- Last Line
4479         last_line = line.clone()
4480         last_line.bracket_tracker = first_line.bracket_tracker
4481
4482         new_rpar_leaf = Leaf(token.RPAR, ")")
4483         if old_rpar_leaf is not None:
4484             replace_child(old_rpar_leaf, new_rpar_leaf)
4485         else:
4486             insert_str_child(new_rpar_leaf)
4487         last_line.append(new_rpar_leaf)
4488
4489         # If the target string ended with a comma, we place this comma to the
4490         # right of the RPAR on the last line.
4491         if ends_with_comma:
4492             comma_leaf = Leaf(token.COMMA, ",")
4493             replace_child(LL[comma_idx], comma_leaf)
4494             last_line.append(comma_leaf)
4495
4496         yield Ok(last_line)
4497
4498
4499 class StringParser:
4500     """
4501     A state machine that aids in parsing a string's "trailer", which can be
4502     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4503     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4504     varY)`).
4505
4506     NOTE: A new StringParser object MUST be instantiated for each string
4507     trailer we need to parse.
4508
4509     Examples:
4510         We shall assume that `line` equals the `Line` object that corresponds
4511         to the following line of python code:
4512         ```
4513         x = "Some {}.".format("String") + some_other_string
4514         ```
4515
4516         Furthermore, we will assume that `string_idx` is some index such that:
4517         ```
4518         assert line.leaves[string_idx].value == "Some {}."
4519         ```
4520
4521         The following code snippet then holds:
4522         ```
4523         string_parser = StringParser()
4524         idx = string_parser.parse(line.leaves, string_idx)
4525         assert line.leaves[idx].type == token.PLUS
4526         ```
4527     """
4528
4529     DEFAULT_TOKEN = -1
4530
4531     # String Parser States
4532     START = 1
4533     DOT = 2
4534     NAME = 3
4535     PERCENT = 4
4536     SINGLE_FMT_ARG = 5
4537     LPAR = 6
4538     RPAR = 7
4539     DONE = 8
4540
4541     # Lookup Table for Next State
4542     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4543         # A string trailer may start with '.' OR '%'.
4544         (START, token.DOT): DOT,
4545         (START, token.PERCENT): PERCENT,
4546         (START, DEFAULT_TOKEN): DONE,
4547         # A '.' MUST be followed by an attribute or method name.
4548         (DOT, token.NAME): NAME,
4549         # A method name MUST be followed by an '(', whereas an attribute name
4550         # is the last symbol in the string trailer.
4551         (NAME, token.LPAR): LPAR,
4552         (NAME, DEFAULT_TOKEN): DONE,
4553         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4554         # string or variable name).
4555         (PERCENT, token.LPAR): LPAR,
4556         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4557         # If a '%' symbol is followed by a single argument, that argument is
4558         # the last leaf in the string trailer.
4559         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4560         # If present, a ')' symbol is the last symbol in a string trailer.
4561         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4562         # since they are treated as a special case by the parsing logic in this
4563         # classes' implementation.)
4564         (RPAR, DEFAULT_TOKEN): DONE,
4565     }
4566
4567     def __init__(self) -> None:
4568         self._state = self.START
4569         self._unmatched_lpars = 0
4570
4571     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4572         """
4573         Pre-conditions:
4574             * @leaves[@string_idx].type == token.STRING
4575
4576         Returns:
4577             The index directly after the last leaf which is apart of the string
4578             trailer, if a "trailer" exists.
4579                 OR
4580             @string_idx + 1, if no string "trailer" exists.
4581         """
4582         assert leaves[string_idx].type == token.STRING
4583
4584         idx = string_idx + 1
4585         while idx < len(leaves) and self._next_state(leaves[idx]):
4586             idx += 1
4587         return idx
4588
4589     def _next_state(self, leaf: Leaf) -> bool:
4590         """
4591         Pre-conditions:
4592             * On the first call to this function, @leaf MUST be the leaf that
4593             was directly after the string leaf in question (e.g. if our target
4594             string is `line.leaves[i]` then the first call to this method must
4595             be `line.leaves[i + 1]`).
4596             * On the next call to this function, the leaf parameter passed in
4597             MUST be the leaf directly following @leaf.
4598
4599         Returns:
4600             True iff @leaf is apart of the string's trailer.
4601         """
4602         # We ignore empty LPAR or RPAR leaves.
4603         if is_empty_par(leaf):
4604             return True
4605
4606         next_token = leaf.type
4607         if next_token == token.LPAR:
4608             self._unmatched_lpars += 1
4609
4610         current_state = self._state
4611
4612         # The LPAR parser state is a special case. We will return True until we
4613         # find the matching RPAR token.
4614         if current_state == self.LPAR:
4615             if next_token == token.RPAR:
4616                 self._unmatched_lpars -= 1
4617                 if self._unmatched_lpars == 0:
4618                     self._state = self.RPAR
4619         # Otherwise, we use a lookup table to determine the next state.
4620         else:
4621             # If the lookup table matches the current state to the next
4622             # token, we use the lookup table.
4623             if (current_state, next_token) in self._goto:
4624                 self._state = self._goto[current_state, next_token]
4625             else:
4626                 # Otherwise, we check if a the current state was assigned a
4627                 # default.
4628                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4629                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4630                 # If no default has been assigned, then this parser has a logic
4631                 # error.
4632                 else:
4633                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4634
4635             if self._state == self.DONE:
4636                 return False
4637
4638         return True
4639
4640
4641 def TErr(err_msg: str) -> Err[CannotTransform]:
4642     """(T)ransform Err
4643
4644     Convenience function used when working with the TResult type.
4645     """
4646     cant_transform = CannotTransform(err_msg)
4647     return Err(cant_transform)
4648
4649
4650 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4651     """
4652     Returns:
4653         True iff one of the comments in @comment_list is a pragma used by one
4654         of the more common static analysis tools for python (e.g. mypy, flake8,
4655         pylint).
4656     """
4657     for comment in comment_list:
4658         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4659             return True
4660
4661     return False
4662
4663
4664 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4665     """
4666     Factory for a convenience function that is used to orphan @string_leaf
4667     and then insert multiple new leaves into the same part of the node
4668     structure that @string_leaf had originally occupied.
4669
4670     Examples:
4671         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4672         string_leaf.parent`. Assume the node `N` has the following
4673         original structure:
4674
4675         Node(
4676             expr_stmt, [
4677                 Leaf(NAME, 'x'),
4678                 Leaf(EQUAL, '='),
4679                 Leaf(STRING, '"foo"'),
4680             ]
4681         )
4682
4683         We then run the code snippet shown below.
4684         ```
4685         insert_str_child = insert_str_child_factory(string_leaf)
4686
4687         lpar = Leaf(token.LPAR, '(')
4688         insert_str_child(lpar)
4689
4690         bar = Leaf(token.STRING, '"bar"')
4691         insert_str_child(bar)
4692
4693         rpar = Leaf(token.RPAR, ')')
4694         insert_str_child(rpar)
4695         ```
4696
4697         After which point, it follows that `string_leaf.parent is None` and
4698         the node `N` now has the following structure:
4699
4700         Node(
4701             expr_stmt, [
4702                 Leaf(NAME, 'x'),
4703                 Leaf(EQUAL, '='),
4704                 Leaf(LPAR, '('),
4705                 Leaf(STRING, '"bar"'),
4706                 Leaf(RPAR, ')'),
4707             ]
4708         )
4709     """
4710     string_parent = string_leaf.parent
4711     string_child_idx = string_leaf.remove()
4712
4713     def insert_str_child(child: LN) -> None:
4714         nonlocal string_child_idx
4715
4716         assert string_parent is not None
4717         assert string_child_idx is not None
4718
4719         string_parent.insert_child(string_child_idx, child)
4720         string_child_idx += 1
4721
4722     return insert_str_child
4723
4724
4725 def has_triple_quotes(string: str) -> bool:
4726     """
4727     Returns:
4728         True iff @string starts with three quotation characters.
4729     """
4730     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4731     return raw_string[:3] in {'"""', "'''"}
4732
4733
4734 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4735     """
4736     Returns:
4737         @node.parent.type, if @node is not None and has a parent.
4738             OR
4739         None, otherwise.
4740     """
4741     if node is None or node.parent is None:
4742         return None
4743
4744     return node.parent.type
4745
4746
4747 def is_empty_par(leaf: Leaf) -> bool:
4748     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4749
4750
4751 def is_empty_lpar(leaf: Leaf) -> bool:
4752     return leaf.type == token.LPAR and leaf.value == ""
4753
4754
4755 def is_empty_rpar(leaf: Leaf) -> bool:
4756     return leaf.type == token.RPAR and leaf.value == ""
4757
4758
4759 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4760     """
4761     Examples:
4762         ```
4763         my_list = [1, 2, 3]
4764
4765         is_valid_index = is_valid_index_factory(my_list)
4766
4767         assert is_valid_index(0)
4768         assert is_valid_index(2)
4769
4770         assert not is_valid_index(3)
4771         assert not is_valid_index(-1)
4772         ```
4773     """
4774
4775     def is_valid_index(idx: int) -> bool:
4776         """
4777         Returns:
4778             True iff @idx is positive AND seq[@idx] does NOT raise an
4779             IndexError.
4780         """
4781         return 0 <= idx < len(seq)
4782
4783     return is_valid_index
4784
4785
4786 def line_to_string(line: Line) -> str:
4787     """Returns the string representation of @line.
4788
4789     WARNING: This is known to be computationally expensive.
4790     """
4791     return str(line).strip("\n")
4792
4793
4794 def append_leaves(
4795     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4796 ) -> None:
4797     """
4798     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4799     underlying Node structure where appropriate.
4800
4801     All of the leaves in @leaves are duplicated. The duplicates are then
4802     appended to @new_line and used to replace their originals in the underlying
4803     Node structure. Any comments attached to the old leaves are reattached to
4804     the new leaves.
4805
4806     Pre-conditions:
4807         set(@leaves) is a subset of set(@old_line.leaves).
4808     """
4809     for old_leaf in leaves:
4810         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4811         replace_child(old_leaf, new_leaf)
4812         new_line.append(new_leaf, preformatted=preformatted)
4813
4814         for comment_leaf in old_line.comments_after(old_leaf):
4815             new_line.append(comment_leaf, preformatted=True)
4816
4817
4818 def replace_child(old_child: LN, new_child: LN) -> None:
4819     """
4820     Side Effects:
4821         * If @old_child.parent is set, replace @old_child with @new_child in
4822         @old_child's underlying Node structure.
4823             OR
4824         * Otherwise, this function does nothing.
4825     """
4826     parent = old_child.parent
4827     if not parent:
4828         return
4829
4830     child_idx = old_child.remove()
4831     if child_idx is not None:
4832         parent.insert_child(child_idx, new_child)
4833
4834
4835 def get_string_prefix(string: str) -> str:
4836     """
4837     Pre-conditions:
4838         * assert_is_leaf_string(@string)
4839
4840     Returns:
4841         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4842     """
4843     assert_is_leaf_string(string)
4844
4845     prefix = ""
4846     prefix_idx = 0
4847     while string[prefix_idx] in STRING_PREFIX_CHARS:
4848         prefix += string[prefix_idx].lower()
4849         prefix_idx += 1
4850
4851     return prefix
4852
4853
4854 def assert_is_leaf_string(string: str) -> None:
4855     """
4856     Checks the pre-condition that @string has the format that you would expect
4857     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4858     token.STRING`. A more precise description of the pre-conditions that are
4859     checked are listed below.
4860
4861     Pre-conditions:
4862         * @string starts with either ', ", <prefix>', or <prefix>" where
4863         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4864         * @string ends with a quote character (' or ").
4865
4866     Raises:
4867         AssertionError(...) if the pre-conditions listed above are not
4868         satisfied.
4869     """
4870     dquote_idx = string.find('"')
4871     squote_idx = string.find("'")
4872     if -1 in [dquote_idx, squote_idx]:
4873         quote_idx = max(dquote_idx, squote_idx)
4874     else:
4875         quote_idx = min(squote_idx, dquote_idx)
4876
4877     assert (
4878         0 <= quote_idx < len(string) - 1
4879     ), f"{string!r} is missing a starting quote character (' or \")."
4880     assert string[-1] in (
4881         "'",
4882         '"',
4883     ), f"{string!r} is missing an ending quote character (' or \")."
4884     assert set(string[:quote_idx]).issubset(
4885         set(STRING_PREFIX_CHARS)
4886     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4887
4888
4889 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4890     """Split line into many lines, starting with the first matching bracket pair.
4891
4892     Note: this usually looks weird, only use this for function definitions.
4893     Prefer RHS otherwise.  This is why this function is not symmetrical with
4894     :func:`right_hand_split` which also handles optional parentheses.
4895     """
4896     tail_leaves: List[Leaf] = []
4897     body_leaves: List[Leaf] = []
4898     head_leaves: List[Leaf] = []
4899     current_leaves = head_leaves
4900     matching_bracket: Optional[Leaf] = None
4901     for leaf in line.leaves:
4902         if (
4903             current_leaves is body_leaves
4904             and leaf.type in CLOSING_BRACKETS
4905             and leaf.opening_bracket is matching_bracket
4906         ):
4907             current_leaves = tail_leaves if body_leaves else head_leaves
4908         current_leaves.append(leaf)
4909         if current_leaves is head_leaves:
4910             if leaf.type in OPENING_BRACKETS:
4911                 matching_bracket = leaf
4912                 current_leaves = body_leaves
4913     if not matching_bracket:
4914         raise CannotSplit("No brackets found")
4915
4916     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4917     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4918     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4919     bracket_split_succeeded_or_raise(head, body, tail)
4920     for result in (head, body, tail):
4921         if result:
4922             yield result
4923
4924
4925 def right_hand_split(
4926     line: Line,
4927     line_length: int,
4928     features: Collection[Feature] = (),
4929     omit: Collection[LeafID] = (),
4930 ) -> Iterator[Line]:
4931     """Split line into many lines, starting with the last matching bracket pair.
4932
4933     If the split was by optional parentheses, attempt splitting without them, too.
4934     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4935     this split.
4936
4937     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4938     """
4939     tail_leaves: List[Leaf] = []
4940     body_leaves: List[Leaf] = []
4941     head_leaves: List[Leaf] = []
4942     current_leaves = tail_leaves
4943     opening_bracket: Optional[Leaf] = None
4944     closing_bracket: Optional[Leaf] = None
4945     for leaf in reversed(line.leaves):
4946         if current_leaves is body_leaves:
4947             if leaf is opening_bracket:
4948                 current_leaves = head_leaves if body_leaves else tail_leaves
4949         current_leaves.append(leaf)
4950         if current_leaves is tail_leaves:
4951             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4952                 opening_bracket = leaf.opening_bracket
4953                 closing_bracket = leaf
4954                 current_leaves = body_leaves
4955     if not (opening_bracket and closing_bracket and head_leaves):
4956         # If there is no opening or closing_bracket that means the split failed and
4957         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4958         # the matching `opening_bracket` wasn't available on `line` anymore.
4959         raise CannotSplit("No brackets found")
4960
4961     tail_leaves.reverse()
4962     body_leaves.reverse()
4963     head_leaves.reverse()
4964     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4965     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4966     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4967     bracket_split_succeeded_or_raise(head, body, tail)
4968     if (
4969         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4970         # the opening bracket is an optional paren
4971         and opening_bracket.type == token.LPAR
4972         and not opening_bracket.value
4973         # the closing bracket is an optional paren
4974         and closing_bracket.type == token.RPAR
4975         and not closing_bracket.value
4976         # it's not an import (optional parens are the only thing we can split on
4977         # in this case; attempting a split without them is a waste of time)
4978         and not line.is_import
4979         # there are no standalone comments in the body
4980         and not body.contains_standalone_comments(0)
4981         # and we can actually remove the parens
4982         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4983     ):
4984         omit = {id(closing_bracket), *omit}
4985         try:
4986             yield from right_hand_split(line, line_length, features=features, omit=omit)
4987             return
4988
4989         except CannotSplit:
4990             if not (
4991                 can_be_split(body)
4992                 or is_line_short_enough(body, line_length=line_length)
4993             ):
4994                 raise CannotSplit(
4995                     "Splitting failed, body is still too long and can't be split."
4996                 )
4997
4998             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4999                 raise CannotSplit(
5000                     "The current optional pair of parentheses is bound to fail to"
5001                     " satisfy the splitting algorithm because the head or the tail"
5002                     " contains multiline strings which by definition never fit one"
5003                     " line."
5004                 )
5005
5006     ensure_visible(opening_bracket)
5007     ensure_visible(closing_bracket)
5008     for result in (head, body, tail):
5009         if result:
5010             yield result
5011
5012
5013 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
5014     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
5015
5016     Do nothing otherwise.
5017
5018     A left- or right-hand split is based on a pair of brackets. Content before
5019     (and including) the opening bracket is left on one line, content inside the
5020     brackets is put on a separate line, and finally content starting with and
5021     following the closing bracket is put on a separate line.
5022
5023     Those are called `head`, `body`, and `tail`, respectively. If the split
5024     produced the same line (all content in `head`) or ended up with an empty `body`
5025     and the `tail` is just the closing bracket, then it's considered failed.
5026     """
5027     tail_len = len(str(tail).strip())
5028     if not body:
5029         if tail_len == 0:
5030             raise CannotSplit("Splitting brackets produced the same line")
5031
5032         elif tail_len < 3:
5033             raise CannotSplit(
5034                 f"Splitting brackets on an empty body to save {tail_len} characters is"
5035                 " not worth it"
5036             )
5037
5038
5039 def bracket_split_build_line(
5040     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
5041 ) -> Line:
5042     """Return a new line with given `leaves` and respective comments from `original`.
5043
5044     If `is_body` is True, the result line is one-indented inside brackets and as such
5045     has its first leaf's prefix normalized and a trailing comma added when expected.
5046     """
5047     result = Line(mode=original.mode, depth=original.depth)
5048     if is_body:
5049         result.inside_brackets = True
5050         result.depth += 1
5051         if leaves:
5052             # Since body is a new indent level, remove spurious leading whitespace.
5053             normalize_prefix(leaves[0], inside_brackets=True)
5054             # Ensure a trailing comma for imports and standalone function arguments, but
5055             # be careful not to add one after any comments or within type annotations.
5056             no_commas = (
5057                 original.is_def
5058                 and opening_bracket.value == "("
5059                 and not any(leaf.type == token.COMMA for leaf in leaves)
5060             )
5061
5062             if original.is_import or no_commas:
5063                 for i in range(len(leaves) - 1, -1, -1):
5064                     if leaves[i].type == STANDALONE_COMMENT:
5065                         continue
5066
5067                     if leaves[i].type != token.COMMA:
5068                         new_comma = Leaf(token.COMMA, ",")
5069                         leaves.insert(i + 1, new_comma)
5070                     break
5071
5072     # Populate the line
5073     for leaf in leaves:
5074         result.append(leaf, preformatted=True)
5075         for comment_after in original.comments_after(leaf):
5076             result.append(comment_after, preformatted=True)
5077     if is_body and should_split_line(result, opening_bracket):
5078         result.should_split_rhs = True
5079     return result
5080
5081
5082 def dont_increase_indentation(split_func: Transformer) -> Transformer:
5083     """Normalize prefix of the first leaf in every line returned by `split_func`.
5084
5085     This is a decorator over relevant split functions.
5086     """
5087
5088     @wraps(split_func)
5089     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5090         for line in split_func(line, features):
5091             normalize_prefix(line.leaves[0], inside_brackets=True)
5092             yield line
5093
5094     return split_wrapper
5095
5096
5097 @dont_increase_indentation
5098 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5099     """Split according to delimiters of the highest priority.
5100
5101     If the appropriate Features are given, the split will add trailing commas
5102     also in function signatures and calls that contain `*` and `**`.
5103     """
5104     try:
5105         last_leaf = line.leaves[-1]
5106     except IndexError:
5107         raise CannotSplit("Line empty")
5108
5109     bt = line.bracket_tracker
5110     try:
5111         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
5112     except ValueError:
5113         raise CannotSplit("No delimiters found")
5114
5115     if delimiter_priority == DOT_PRIORITY:
5116         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
5117             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
5118
5119     current_line = Line(
5120         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5121     )
5122     lowest_depth = sys.maxsize
5123     trailing_comma_safe = True
5124
5125     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5126         """Append `leaf` to current line or to new line if appending impossible."""
5127         nonlocal current_line
5128         try:
5129             current_line.append_safe(leaf, preformatted=True)
5130         except ValueError:
5131             yield current_line
5132
5133             current_line = Line(
5134                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5135             )
5136             current_line.append(leaf)
5137
5138     for leaf in line.leaves:
5139         yield from append_to_line(leaf)
5140
5141         for comment_after in line.comments_after(leaf):
5142             yield from append_to_line(comment_after)
5143
5144         lowest_depth = min(lowest_depth, leaf.bracket_depth)
5145         if leaf.bracket_depth == lowest_depth:
5146             if is_vararg(leaf, within={syms.typedargslist}):
5147                 trailing_comma_safe = (
5148                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
5149                 )
5150             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
5151                 trailing_comma_safe = (
5152                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
5153                 )
5154
5155         leaf_priority = bt.delimiters.get(id(leaf))
5156         if leaf_priority == delimiter_priority:
5157             yield current_line
5158
5159             current_line = Line(
5160                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5161             )
5162     if current_line:
5163         if (
5164             trailing_comma_safe
5165             and delimiter_priority == COMMA_PRIORITY
5166             and current_line.leaves[-1].type != token.COMMA
5167             and current_line.leaves[-1].type != STANDALONE_COMMENT
5168         ):
5169             new_comma = Leaf(token.COMMA, ",")
5170             current_line.append(new_comma)
5171         yield current_line
5172
5173
5174 @dont_increase_indentation
5175 def standalone_comment_split(
5176     line: Line, features: Collection[Feature] = ()
5177 ) -> Iterator[Line]:
5178     """Split standalone comments from the rest of the line."""
5179     if not line.contains_standalone_comments(0):
5180         raise CannotSplit("Line does not have any standalone comments")
5181
5182     current_line = Line(
5183         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5184     )
5185
5186     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5187         """Append `leaf` to current line or to new line if appending impossible."""
5188         nonlocal current_line
5189         try:
5190             current_line.append_safe(leaf, preformatted=True)
5191         except ValueError:
5192             yield current_line
5193
5194             current_line = Line(
5195                 line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5196             )
5197             current_line.append(leaf)
5198
5199     for leaf in line.leaves:
5200         yield from append_to_line(leaf)
5201
5202         for comment_after in line.comments_after(leaf):
5203             yield from append_to_line(comment_after)
5204
5205     if current_line:
5206         yield current_line
5207
5208
5209 def is_import(leaf: Leaf) -> bool:
5210     """Return True if the given leaf starts an import statement."""
5211     p = leaf.parent
5212     t = leaf.type
5213     v = leaf.value
5214     return bool(
5215         t == token.NAME
5216         and (
5217             (v == "import" and p and p.type == syms.import_name)
5218             or (v == "from" and p and p.type == syms.import_from)
5219         )
5220     )
5221
5222
5223 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5224     """Return True if the given leaf is a special comment.
5225     Only returns true for type comments for now."""
5226     t = leaf.type
5227     v = leaf.value
5228     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5229
5230
5231 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5232     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5233     else.
5234
5235     Note: don't use backslashes for formatting or you'll lose your voting rights.
5236     """
5237     if not inside_brackets:
5238         spl = leaf.prefix.split("#")
5239         if "\\" not in spl[0]:
5240             nl_count = spl[-1].count("\n")
5241             if len(spl) > 1:
5242                 nl_count -= 1
5243             leaf.prefix = "\n" * nl_count
5244             return
5245
5246     leaf.prefix = ""
5247
5248
5249 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5250     """Make all string prefixes lowercase.
5251
5252     If remove_u_prefix is given, also removes any u prefix from the string.
5253
5254     Note: Mutates its argument.
5255     """
5256     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5257     assert match is not None, f"failed to match string {leaf.value!r}"
5258     orig_prefix = match.group(1)
5259     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5260     if remove_u_prefix:
5261         new_prefix = new_prefix.replace("u", "")
5262     leaf.value = f"{new_prefix}{match.group(2)}"
5263
5264
5265 def normalize_string_quotes(leaf: Leaf) -> None:
5266     """Prefer double quotes but only if it doesn't cause more escaping.
5267
5268     Adds or removes backslashes as appropriate. Doesn't parse and fix
5269     strings nested in f-strings (yet).
5270
5271     Note: Mutates its argument.
5272     """
5273     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5274     if value[:3] == '"""':
5275         return
5276
5277     elif value[:3] == "'''":
5278         orig_quote = "'''"
5279         new_quote = '"""'
5280     elif value[0] == '"':
5281         orig_quote = '"'
5282         new_quote = "'"
5283     else:
5284         orig_quote = "'"
5285         new_quote = '"'
5286     first_quote_pos = leaf.value.find(orig_quote)
5287     if first_quote_pos == -1:
5288         return  # There's an internal error
5289
5290     prefix = leaf.value[:first_quote_pos]
5291     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5292     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5293     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5294     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5295     if "r" in prefix.casefold():
5296         if unescaped_new_quote.search(body):
5297             # There's at least one unescaped new_quote in this raw string
5298             # so converting is impossible
5299             return
5300
5301         # Do not introduce or remove backslashes in raw strings
5302         new_body = body
5303     else:
5304         # remove unnecessary escapes
5305         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5306         if body != new_body:
5307             # Consider the string without unnecessary escapes as the original
5308             body = new_body
5309             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5310         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5311         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5312     if "f" in prefix.casefold():
5313         matches = re.findall(
5314             r"""
5315             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5316                 ([^{].*?)  # contents of the brackets except if begins with {{
5317             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5318             """,
5319             new_body,
5320             re.VERBOSE,
5321         )
5322         for m in matches:
5323             if "\\" in str(m):
5324                 # Do not introduce backslashes in interpolated expressions
5325                 return
5326
5327     if new_quote == '"""' and new_body[-1:] == '"':
5328         # edge case:
5329         new_body = new_body[:-1] + '\\"'
5330     orig_escape_count = body.count("\\")
5331     new_escape_count = new_body.count("\\")
5332     if new_escape_count > orig_escape_count:
5333         return  # Do not introduce more escaping
5334
5335     if new_escape_count == orig_escape_count and orig_quote == '"':
5336         return  # Prefer double quotes
5337
5338     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5339
5340
5341 def normalize_numeric_literal(leaf: Leaf) -> None:
5342     """Normalizes numeric (float, int, and complex) literals.
5343
5344     All letters used in the representation are normalized to lowercase (except
5345     in Python 2 long literals).
5346     """
5347     text = leaf.value.lower()
5348     if text.startswith(("0o", "0b")):
5349         # Leave octal and binary literals alone.
5350         pass
5351     elif text.startswith("0x"):
5352         text = format_hex(text)
5353     elif "e" in text:
5354         text = format_scientific_notation(text)
5355     elif text.endswith(("j", "l")):
5356         text = format_long_or_complex_number(text)
5357     else:
5358         text = format_float_or_int_string(text)
5359     leaf.value = text
5360
5361
5362 def format_hex(text: str) -> str:
5363     """
5364     Formats a hexadecimal string like "0x12B3"
5365     """
5366     before, after = text[:2], text[2:]
5367     return f"{before}{after.upper()}"
5368
5369
5370 def format_scientific_notation(text: str) -> str:
5371     """Formats a numeric string utilizing scentific notation"""
5372     before, after = text.split("e")
5373     sign = ""
5374     if after.startswith("-"):
5375         after = after[1:]
5376         sign = "-"
5377     elif after.startswith("+"):
5378         after = after[1:]
5379     before = format_float_or_int_string(before)
5380     return f"{before}e{sign}{after}"
5381
5382
5383 def format_long_or_complex_number(text: str) -> str:
5384     """Formats a long or complex string like `10L` or `10j`"""
5385     number = text[:-1]
5386     suffix = text[-1]
5387     # Capitalize in "2L" because "l" looks too similar to "1".
5388     if suffix == "l":
5389         suffix = "L"
5390     return f"{format_float_or_int_string(number)}{suffix}"
5391
5392
5393 def format_float_or_int_string(text: str) -> str:
5394     """Formats a float string like "1.0"."""
5395     if "." not in text:
5396         return text
5397
5398     before, after = text.split(".")
5399     return f"{before or 0}.{after or 0}"
5400
5401
5402 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5403     """Make existing optional parentheses invisible or create new ones.
5404
5405     `parens_after` is a set of string leaf values immediately after which parens
5406     should be put.
5407
5408     Standardizes on visible parentheses for single-element tuples, and keeps
5409     existing visible parentheses for other tuples and generator expressions.
5410     """
5411     for pc in list_comments(node.prefix, is_endmarker=False):
5412         if pc.value in FMT_OFF:
5413             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5414             return
5415     check_lpar = False
5416     for index, child in enumerate(list(node.children)):
5417         # Fixes a bug where invisible parens are not properly stripped from
5418         # assignment statements that contain type annotations.
5419         if isinstance(child, Node) and child.type == syms.annassign:
5420             normalize_invisible_parens(child, parens_after=parens_after)
5421
5422         # Add parentheses around long tuple unpacking in assignments.
5423         if (
5424             index == 0
5425             and isinstance(child, Node)
5426             and child.type == syms.testlist_star_expr
5427         ):
5428             check_lpar = True
5429
5430         if check_lpar:
5431             if child.type == syms.atom:
5432                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5433                     wrap_in_parentheses(node, child, visible=False)
5434             elif is_one_tuple(child):
5435                 wrap_in_parentheses(node, child, visible=True)
5436             elif node.type == syms.import_from:
5437                 # "import from" nodes store parentheses directly as part of
5438                 # the statement
5439                 if child.type == token.LPAR:
5440                     # make parentheses invisible
5441                     child.value = ""  # type: ignore
5442                     node.children[-1].value = ""  # type: ignore
5443                 elif child.type != token.STAR:
5444                     # insert invisible parentheses
5445                     node.insert_child(index, Leaf(token.LPAR, ""))
5446                     node.append_child(Leaf(token.RPAR, ""))
5447                 break
5448
5449             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5450                 wrap_in_parentheses(node, child, visible=False)
5451
5452         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5453
5454
5455 def normalize_fmt_off(node: Node) -> None:
5456     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5457     try_again = True
5458     while try_again:
5459         try_again = convert_one_fmt_off_pair(node)
5460
5461
5462 def convert_one_fmt_off_pair(node: Node) -> bool:
5463     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5464
5465     Returns True if a pair was converted.
5466     """
5467     for leaf in node.leaves():
5468         previous_consumed = 0
5469         for comment in list_comments(leaf.prefix, is_endmarker=False):
5470             if comment.value not in FMT_PASS:
5471                 previous_consumed = comment.consumed
5472                 continue
5473             # We only want standalone comments. If there's no previous leaf or
5474             # the previous leaf is indentation, it's a standalone comment in
5475             # disguise.
5476             if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
5477                 prev = preceding_leaf(leaf)
5478                 if prev:
5479                     if comment.value in FMT_OFF and prev.type not in WHITESPACE:
5480                         continue
5481                     if comment.value in FMT_SKIP and prev.type in WHITESPACE:
5482                         continue
5483
5484             ignored_nodes = list(generate_ignored_nodes(leaf, comment))
5485             if not ignored_nodes:
5486                 continue
5487
5488             first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5489             parent = first.parent
5490             prefix = first.prefix
5491             first.prefix = prefix[comment.consumed :]
5492             hidden_value = "".join(str(n) for n in ignored_nodes)
5493             if comment.value in FMT_OFF:
5494                 hidden_value = comment.value + "\n" + hidden_value
5495             if comment.value in FMT_SKIP:
5496                 hidden_value += "  " + comment.value
5497             if hidden_value.endswith("\n"):
5498                 # That happens when one of the `ignored_nodes` ended with a NEWLINE
5499                 # leaf (possibly followed by a DEDENT).
5500                 hidden_value = hidden_value[:-1]
5501             first_idx: Optional[int] = None
5502             for ignored in ignored_nodes:
5503                 index = ignored.remove()
5504                 if first_idx is None:
5505                     first_idx = index
5506             assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5507             assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5508             parent.insert_child(
5509                 first_idx,
5510                 Leaf(
5511                     STANDALONE_COMMENT,
5512                     hidden_value,
5513                     prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5514                 ),
5515             )
5516             return True
5517
5518     return False
5519
5520
5521 def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
5522     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5523
5524     If comment is skip, returns leaf only.
5525     Stops at the end of the block.
5526     """
5527     container: Optional[LN] = container_of(leaf)
5528     if comment.value in FMT_SKIP:
5529         prev_sibling = leaf.prev_sibling
5530         if comment.value in leaf.prefix and prev_sibling is not None:
5531             leaf.prefix = leaf.prefix.replace(comment.value, "")
5532             siblings = [prev_sibling]
5533             while (
5534                 "\n" not in prev_sibling.prefix
5535                 and prev_sibling.prev_sibling is not None
5536             ):
5537                 prev_sibling = prev_sibling.prev_sibling
5538                 siblings.insert(0, prev_sibling)
5539             for sibling in siblings:
5540                 yield sibling
5541         elif leaf.parent is not None:
5542             yield leaf.parent
5543         return
5544     while container is not None and container.type != token.ENDMARKER:
5545         if is_fmt_on(container):
5546             return
5547
5548         # fix for fmt: on in children
5549         if contains_fmt_on_at_column(container, leaf.column):
5550             for child in container.children:
5551                 if contains_fmt_on_at_column(child, leaf.column):
5552                     return
5553                 yield child
5554         else:
5555             yield container
5556             container = container.next_sibling
5557
5558
5559 def is_fmt_on(container: LN) -> bool:
5560     """Determine whether formatting is switched on within a container.
5561     Determined by whether the last `# fmt:` comment is `on` or `off`.
5562     """
5563     fmt_on = False
5564     for comment in list_comments(container.prefix, is_endmarker=False):
5565         if comment.value in FMT_ON:
5566             fmt_on = True
5567         elif comment.value in FMT_OFF:
5568             fmt_on = False
5569     return fmt_on
5570
5571
5572 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5573     """Determine if children at a given column have formatting switched on."""
5574     for child in container.children:
5575         if (
5576             isinstance(child, Node)
5577             and first_leaf_column(child) == column
5578             or isinstance(child, Leaf)
5579             and child.column == column
5580         ):
5581             if is_fmt_on(child):
5582                 return True
5583
5584     return False
5585
5586
5587 def first_leaf_column(node: Node) -> Optional[int]:
5588     """Returns the column of the first leaf child of a node."""
5589     for child in node.children:
5590         if isinstance(child, Leaf):
5591             return child.column
5592     return None
5593
5594
5595 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5596     """If it's safe, make the parens in the atom `node` invisible, recursively.
5597     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5598     as they are redundant.
5599
5600     Returns whether the node should itself be wrapped in invisible parentheses.
5601
5602     """
5603
5604     if (
5605         node.type != syms.atom
5606         or is_empty_tuple(node)
5607         or is_one_tuple(node)
5608         or (is_yield(node) and parent.type != syms.expr_stmt)
5609         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5610     ):
5611         return False
5612
5613     if is_walrus_assignment(node):
5614         if parent.type in [syms.annassign, syms.expr_stmt]:
5615             return False
5616
5617     first = node.children[0]
5618     last = node.children[-1]
5619     if first.type == token.LPAR and last.type == token.RPAR:
5620         middle = node.children[1]
5621         # make parentheses invisible
5622         first.value = ""  # type: ignore
5623         last.value = ""  # type: ignore
5624         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5625
5626         if is_atom_with_invisible_parens(middle):
5627             # Strip the invisible parens from `middle` by replacing
5628             # it with the child in-between the invisible parens
5629             middle.replace(middle.children[1])
5630
5631         return False
5632
5633     return True
5634
5635
5636 def is_atom_with_invisible_parens(node: LN) -> bool:
5637     """Given a `LN`, determines whether it's an atom `node` with invisible
5638     parens. Useful in dedupe-ing and normalizing parens.
5639     """
5640     if isinstance(node, Leaf) or node.type != syms.atom:
5641         return False
5642
5643     first, last = node.children[0], node.children[-1]
5644     return (
5645         isinstance(first, Leaf)
5646         and first.type == token.LPAR
5647         and first.value == ""
5648         and isinstance(last, Leaf)
5649         and last.type == token.RPAR
5650         and last.value == ""
5651     )
5652
5653
5654 def is_empty_tuple(node: LN) -> bool:
5655     """Return True if `node` holds an empty tuple."""
5656     return (
5657         node.type == syms.atom
5658         and len(node.children) == 2
5659         and node.children[0].type == token.LPAR
5660         and node.children[1].type == token.RPAR
5661     )
5662
5663
5664 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5665     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5666
5667     Parenthesis can be optional. Returns None otherwise"""
5668     if len(node.children) != 3:
5669         return None
5670
5671     lpar, wrapped, rpar = node.children
5672     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5673         return None
5674
5675     return wrapped
5676
5677
5678 def first_child_is_arith(node: Node) -> bool:
5679     """Whether first child is an arithmetic or a binary arithmetic expression"""
5680     expr_types = {
5681         syms.arith_expr,
5682         syms.shift_expr,
5683         syms.xor_expr,
5684         syms.and_expr,
5685     }
5686     return bool(node.children and node.children[0].type in expr_types)
5687
5688
5689 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5690     """Wrap `child` in parentheses.
5691
5692     This replaces `child` with an atom holding the parentheses and the old
5693     child.  That requires moving the prefix.
5694
5695     If `visible` is False, the leaves will be valueless (and thus invisible).
5696     """
5697     lpar = Leaf(token.LPAR, "(" if visible else "")
5698     rpar = Leaf(token.RPAR, ")" if visible else "")
5699     prefix = child.prefix
5700     child.prefix = ""
5701     index = child.remove() or 0
5702     new_child = Node(syms.atom, [lpar, child, rpar])
5703     new_child.prefix = prefix
5704     parent.insert_child(index, new_child)
5705
5706
5707 def is_one_tuple(node: LN) -> bool:
5708     """Return True if `node` holds a tuple with one element, with or without parens."""
5709     if node.type == syms.atom:
5710         gexp = unwrap_singleton_parenthesis(node)
5711         if gexp is None or gexp.type != syms.testlist_gexp:
5712             return False
5713
5714         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5715
5716     return (
5717         node.type in IMPLICIT_TUPLE
5718         and len(node.children) == 2
5719         and node.children[1].type == token.COMMA
5720     )
5721
5722
5723 def is_walrus_assignment(node: LN) -> bool:
5724     """Return True iff `node` is of the shape ( test := test )"""
5725     inner = unwrap_singleton_parenthesis(node)
5726     return inner is not None and inner.type == syms.namedexpr_test
5727
5728
5729 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5730     """Return True iff `node` is a trailer valid in a simple decorator"""
5731     return node.type == syms.trailer and (
5732         (
5733             len(node.children) == 2
5734             and node.children[0].type == token.DOT
5735             and node.children[1].type == token.NAME
5736         )
5737         # last trailer can be arguments
5738         or (
5739             last
5740             and len(node.children) == 3
5741             and node.children[0].type == token.LPAR
5742             # and node.children[1].type == syms.argument
5743             and node.children[2].type == token.RPAR
5744         )
5745     )
5746
5747
5748 def is_simple_decorator_expression(node: LN) -> bool:
5749     """Return True iff `node` could be a 'dotted name' decorator
5750
5751     This function takes the node of the 'namedexpr_test' of the new decorator
5752     grammar and test if it would be valid under the old decorator grammar.
5753
5754     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5755     The new grammar is : decorator: @ namedexpr_test NEWLINE
5756     """
5757     if node.type == token.NAME:
5758         return True
5759     if node.type == syms.power:
5760         if node.children:
5761             return (
5762                 node.children[0].type == token.NAME
5763                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5764                 and (
5765                     len(node.children) < 2
5766                     or is_simple_decorator_trailer(node.children[-1], last=True)
5767                 )
5768             )
5769     return False
5770
5771
5772 def is_yield(node: LN) -> bool:
5773     """Return True if `node` holds a `yield` or `yield from` expression."""
5774     if node.type == syms.yield_expr:
5775         return True
5776
5777     if node.type == token.NAME and node.value == "yield":  # type: ignore
5778         return True
5779
5780     if node.type != syms.atom:
5781         return False
5782
5783     if len(node.children) != 3:
5784         return False
5785
5786     lpar, expr, rpar = node.children
5787     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5788         return is_yield(expr)
5789
5790     return False
5791
5792
5793 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5794     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5795
5796     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5797     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5798     extended iterable unpacking (PEP 3132) and additional unpacking
5799     generalizations (PEP 448).
5800     """
5801     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5802         return False
5803
5804     p = leaf.parent
5805     if p.type == syms.star_expr:
5806         # Star expressions are also used as assignment targets in extended
5807         # iterable unpacking (PEP 3132).  See what its parent is instead.
5808         if not p.parent:
5809             return False
5810
5811         p = p.parent
5812
5813     return p.type in within
5814
5815
5816 def is_multiline_string(leaf: Leaf) -> bool:
5817     """Return True if `leaf` is a multiline string that actually spans many lines."""
5818     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5819
5820
5821 def is_stub_suite(node: Node) -> bool:
5822     """Return True if `node` is a suite with a stub body."""
5823     if (
5824         len(node.children) != 4
5825         or node.children[0].type != token.NEWLINE
5826         or node.children[1].type != token.INDENT
5827         or node.children[3].type != token.DEDENT
5828     ):
5829         return False
5830
5831     return is_stub_body(node.children[2])
5832
5833
5834 def is_stub_body(node: LN) -> bool:
5835     """Return True if `node` is a simple statement containing an ellipsis."""
5836     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5837         return False
5838
5839     if len(node.children) != 2:
5840         return False
5841
5842     child = node.children[0]
5843     return (
5844         child.type == syms.atom
5845         and len(child.children) == 3
5846         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5847     )
5848
5849
5850 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5851     """Return maximum delimiter priority inside `node`.
5852
5853     This is specific to atoms with contents contained in a pair of parentheses.
5854     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5855     """
5856     if node.type != syms.atom:
5857         return 0
5858
5859     first = node.children[0]
5860     last = node.children[-1]
5861     if not (first.type == token.LPAR and last.type == token.RPAR):
5862         return 0
5863
5864     bt = BracketTracker()
5865     for c in node.children[1:-1]:
5866         if isinstance(c, Leaf):
5867             bt.mark(c)
5868         else:
5869             for leaf in c.leaves():
5870                 bt.mark(leaf)
5871     try:
5872         return bt.max_delimiter_priority()
5873
5874     except ValueError:
5875         return 0
5876
5877
5878 def ensure_visible(leaf: Leaf) -> None:
5879     """Make sure parentheses are visible.
5880
5881     They could be invisible as part of some statements (see
5882     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5883     """
5884     if leaf.type == token.LPAR:
5885         leaf.value = "("
5886     elif leaf.type == token.RPAR:
5887         leaf.value = ")"
5888
5889
5890 def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
5891     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5892
5893     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5894         return False
5895
5896     # We're essentially checking if the body is delimited by commas and there's more
5897     # than one of them (we're excluding the trailing comma and if the delimiter priority
5898     # is still commas, that means there's more).
5899     exclude = set()
5900     trailing_comma = False
5901     try:
5902         last_leaf = line.leaves[-1]
5903         if last_leaf.type == token.COMMA:
5904             trailing_comma = True
5905             exclude.add(id(last_leaf))
5906         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5907     except (IndexError, ValueError):
5908         return False
5909
5910     return max_priority == COMMA_PRIORITY and (
5911         (line.mode.magic_trailing_comma and trailing_comma)
5912         # always explode imports
5913         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5914     )
5915
5916
5917 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5918     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5919     if opening.type != token.LPAR and closing.type != token.RPAR:
5920         return False
5921
5922     depth = closing.bracket_depth + 1
5923     for _opening_index, leaf in enumerate(leaves):
5924         if leaf is opening:
5925             break
5926
5927     else:
5928         raise LookupError("Opening paren not found in `leaves`")
5929
5930     commas = 0
5931     _opening_index += 1
5932     for leaf in leaves[_opening_index:]:
5933         if leaf is closing:
5934             break
5935
5936         bracket_depth = leaf.bracket_depth
5937         if bracket_depth == depth and leaf.type == token.COMMA:
5938             commas += 1
5939             if leaf.parent and leaf.parent.type in {
5940                 syms.arglist,
5941                 syms.typedargslist,
5942             }:
5943                 commas += 1
5944                 break
5945
5946     return commas < 2
5947
5948
5949 def get_features_used(node: Node) -> Set[Feature]:
5950     """Return a set of (relatively) new Python features used in this file.
5951
5952     Currently looking for:
5953     - f-strings;
5954     - underscores in numeric literals;
5955     - trailing commas after * or ** in function signatures and calls;
5956     - positional only arguments in function signatures and lambdas;
5957     - assignment expression;
5958     - relaxed decorator syntax;
5959     """
5960     features: Set[Feature] = set()
5961     for n in node.pre_order():
5962         if n.type == token.STRING:
5963             value_head = n.value[:2]  # type: ignore
5964             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5965                 features.add(Feature.F_STRINGS)
5966
5967         elif n.type == token.NUMBER:
5968             if "_" in n.value:  # type: ignore
5969                 features.add(Feature.NUMERIC_UNDERSCORES)
5970
5971         elif n.type == token.SLASH:
5972             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5973                 features.add(Feature.POS_ONLY_ARGUMENTS)
5974
5975         elif n.type == token.COLONEQUAL:
5976             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5977
5978         elif n.type == syms.decorator:
5979             if len(n.children) > 1 and not is_simple_decorator_expression(
5980                 n.children[1]
5981             ):
5982                 features.add(Feature.RELAXED_DECORATORS)
5983
5984         elif (
5985             n.type in {syms.typedargslist, syms.arglist}
5986             and n.children
5987             and n.children[-1].type == token.COMMA
5988         ):
5989             if n.type == syms.typedargslist:
5990                 feature = Feature.TRAILING_COMMA_IN_DEF
5991             else:
5992                 feature = Feature.TRAILING_COMMA_IN_CALL
5993
5994             for ch in n.children:
5995                 if ch.type in STARS:
5996                     features.add(feature)
5997
5998                 if ch.type == syms.argument:
5999                     for argch in ch.children:
6000                         if argch.type in STARS:
6001                             features.add(feature)
6002
6003     return features
6004
6005
6006 def detect_target_versions(node: Node) -> Set[TargetVersion]:
6007     """Detect the version to target based on the nodes used."""
6008     features = get_features_used(node)
6009     return {
6010         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
6011     }
6012
6013
6014 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
6015     """Generate sets of closing bracket IDs that should be omitted in a RHS.
6016
6017     Brackets can be omitted if the entire trailer up to and including
6018     a preceding closing bracket fits in one line.
6019
6020     Yielded sets are cumulative (contain results of previous yields, too).  First
6021     set is empty, unless the line should explode, in which case bracket pairs until
6022     the one that needs to explode are omitted.
6023     """
6024
6025     omit: Set[LeafID] = set()
6026     if not line.magic_trailing_comma:
6027         yield omit
6028
6029     length = 4 * line.depth
6030     opening_bracket: Optional[Leaf] = None
6031     closing_bracket: Optional[Leaf] = None
6032     inner_brackets: Set[LeafID] = set()
6033     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
6034         length += leaf_length
6035         if length > line_length:
6036             break
6037
6038         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
6039         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
6040             break
6041
6042         if opening_bracket:
6043             if leaf is opening_bracket:
6044                 opening_bracket = None
6045             elif leaf.type in CLOSING_BRACKETS:
6046                 prev = line.leaves[index - 1] if index > 0 else None
6047                 if (
6048                     prev
6049                     and prev.type == token.COMMA
6050                     and not is_one_tuple_between(
6051                         leaf.opening_bracket, leaf, line.leaves
6052                     )
6053                 ):
6054                     # Never omit bracket pairs with trailing commas.
6055                     # We need to explode on those.
6056                     break
6057
6058                 inner_brackets.add(id(leaf))
6059         elif leaf.type in CLOSING_BRACKETS:
6060             prev = line.leaves[index - 1] if index > 0 else None
6061             if prev and prev.type in OPENING_BRACKETS:
6062                 # Empty brackets would fail a split so treat them as "inner"
6063                 # brackets (e.g. only add them to the `omit` set if another
6064                 # pair of brackets was good enough.
6065                 inner_brackets.add(id(leaf))
6066                 continue
6067
6068             if closing_bracket:
6069                 omit.add(id(closing_bracket))
6070                 omit.update(inner_brackets)
6071                 inner_brackets.clear()
6072                 yield omit
6073
6074             if (
6075                 prev
6076                 and prev.type == token.COMMA
6077                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
6078             ):
6079                 # Never omit bracket pairs with trailing commas.
6080                 # We need to explode on those.
6081                 break
6082
6083             if leaf.value:
6084                 opening_bracket = leaf.opening_bracket
6085                 closing_bracket = leaf
6086
6087
6088 def get_future_imports(node: Node) -> Set[str]:
6089     """Return a set of __future__ imports in the file."""
6090     imports: Set[str] = set()
6091
6092     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
6093         for child in children:
6094             if isinstance(child, Leaf):
6095                 if child.type == token.NAME:
6096                     yield child.value
6097
6098             elif child.type == syms.import_as_name:
6099                 orig_name = child.children[0]
6100                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
6101                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
6102                 yield orig_name.value
6103
6104             elif child.type == syms.import_as_names:
6105                 yield from get_imports_from_children(child.children)
6106
6107             else:
6108                 raise AssertionError("Invalid syntax parsing imports")
6109
6110     for child in node.children:
6111         if child.type != syms.simple_stmt:
6112             break
6113
6114         first_child = child.children[0]
6115         if isinstance(first_child, Leaf):
6116             # Continue looking if we see a docstring; otherwise stop.
6117             if (
6118                 len(child.children) == 2
6119                 and first_child.type == token.STRING
6120                 and child.children[1].type == token.NEWLINE
6121             ):
6122                 continue
6123
6124             break
6125
6126         elif first_child.type == syms.import_from:
6127             module_name = first_child.children[1]
6128             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
6129                 break
6130
6131             imports |= set(get_imports_from_children(first_child.children[3:]))
6132         else:
6133             break
6134
6135     return imports
6136
6137
6138 @lru_cache()
6139 def get_gitignore(root: Path) -> PathSpec:
6140     """Return a PathSpec matching gitignore content if present."""
6141     gitignore = root / ".gitignore"
6142     lines: List[str] = []
6143     if gitignore.is_file():
6144         with gitignore.open() as gf:
6145             lines = gf.readlines()
6146     return PathSpec.from_lines("gitwildmatch", lines)
6147
6148
6149 def normalize_path_maybe_ignore(
6150     path: Path, root: Path, report: "Report"
6151 ) -> Optional[str]:
6152     """Normalize `path`. May return `None` if `path` was ignored.
6153
6154     `report` is where "path ignored" output goes.
6155     """
6156     try:
6157         abspath = path if path.is_absolute() else Path.cwd() / path
6158         normalized_path = abspath.resolve().relative_to(root).as_posix()
6159     except OSError as e:
6160         report.path_ignored(path, f"cannot be read because {e}")
6161         return None
6162
6163     except ValueError:
6164         if path.is_symlink():
6165             report.path_ignored(path, f"is a symbolic link that points outside {root}")
6166             return None
6167
6168         raise
6169
6170     return normalized_path
6171
6172
6173 def path_is_excluded(
6174     normalized_path: str,
6175     pattern: Optional[Pattern[str]],
6176 ) -> bool:
6177     match = pattern.search(normalized_path) if pattern else None
6178     return bool(match and match.group(0))
6179
6180
6181 def gen_python_files(
6182     paths: Iterable[Path],
6183     root: Path,
6184     include: Optional[Pattern[str]],
6185     exclude: Pattern[str],
6186     extend_exclude: Optional[Pattern[str]],
6187     force_exclude: Optional[Pattern[str]],
6188     report: "Report",
6189     gitignore: PathSpec,
6190 ) -> Iterator[Path]:
6191     """Generate all files under `path` whose paths are not excluded by the
6192     `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
6193     but are included by the `include` regex.
6194
6195     Symbolic links pointing outside of the `root` directory are ignored.
6196
6197     `report` is where output about exclusions goes.
6198     """
6199     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
6200     for child in paths:
6201         normalized_path = normalize_path_maybe_ignore(child, root, report)
6202         if normalized_path is None:
6203             continue
6204
6205         # First ignore files matching .gitignore
6206         if gitignore.match_file(normalized_path):
6207             report.path_ignored(child, "matches the .gitignore file content")
6208             continue
6209
6210         # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
6211         normalized_path = "/" + normalized_path
6212         if child.is_dir():
6213             normalized_path += "/"
6214
6215         if path_is_excluded(normalized_path, exclude):
6216             report.path_ignored(child, "matches the --exclude regular expression")
6217             continue
6218
6219         if path_is_excluded(normalized_path, extend_exclude):
6220             report.path_ignored(
6221                 child, "matches the --extend-exclude regular expression"
6222             )
6223             continue
6224
6225         if path_is_excluded(normalized_path, force_exclude):
6226             report.path_ignored(child, "matches the --force-exclude regular expression")
6227             continue
6228
6229         if child.is_dir():
6230             yield from gen_python_files(
6231                 child.iterdir(),
6232                 root,
6233                 include,
6234                 exclude,
6235                 extend_exclude,
6236                 force_exclude,
6237                 report,
6238                 gitignore,
6239             )
6240
6241         elif child.is_file():
6242             include_match = include.search(normalized_path) if include else True
6243             if include_match:
6244                 yield child
6245
6246
6247 @lru_cache()
6248 def find_project_root(srcs: Tuple[str, ...]) -> Path:
6249     """Return a directory containing .git, .hg, or pyproject.toml.
6250
6251     That directory will be a common parent of all files and directories
6252     passed in `srcs`.
6253
6254     If no directory in the tree contains a marker that would specify it's the
6255     project root, the root of the file system is returned.
6256     """
6257     if not srcs:
6258         return Path("/").resolve()
6259
6260     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6261
6262     # A list of lists of parents for each 'src'. 'src' is included as a
6263     # "parent" of itself if it is a directory
6264     src_parents = [
6265         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6266     ]
6267
6268     common_base = max(
6269         set.intersection(*(set(parents) for parents in src_parents)),
6270         key=lambda path: path.parts,
6271     )
6272
6273     for directory in (common_base, *common_base.parents):
6274         if (directory / ".git").exists():
6275             return directory
6276
6277         if (directory / ".hg").is_dir():
6278             return directory
6279
6280         if (directory / "pyproject.toml").is_file():
6281             return directory
6282
6283     return directory
6284
6285
6286 @lru_cache()
6287 def find_user_pyproject_toml() -> Path:
6288     r"""Return the path to the top-level user configuration for black.
6289
6290     This looks for ~\.black on Windows and ~/.config/black on Linux and other
6291     Unix systems.
6292     """
6293     if sys.platform == "win32":
6294         # Windows
6295         user_config_path = Path.home() / ".black"
6296     else:
6297         config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config")
6298         user_config_path = Path(config_root).expanduser() / "black"
6299     return user_config_path.resolve()
6300
6301
6302 @dataclass
6303 class Report:
6304     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6305
6306     check: bool = False
6307     diff: bool = False
6308     quiet: bool = False
6309     verbose: bool = False
6310     change_count: int = 0
6311     same_count: int = 0
6312     failure_count: int = 0
6313
6314     def done(self, src: Path, changed: Changed) -> None:
6315         """Increment the counter for successful reformatting. Write out a message."""
6316         if changed is Changed.YES:
6317             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6318             if self.verbose or not self.quiet:
6319                 out(f"{reformatted} {src}")
6320             self.change_count += 1
6321         else:
6322             if self.verbose:
6323                 if changed is Changed.NO:
6324                     msg = f"{src} already well formatted, good job."
6325                 else:
6326                     msg = f"{src} wasn't modified on disk since last run."
6327                 out(msg, bold=False)
6328             self.same_count += 1
6329
6330     def failed(self, src: Path, message: str) -> None:
6331         """Increment the counter for failed reformatting. Write out a message."""
6332         err(f"error: cannot format {src}: {message}")
6333         self.failure_count += 1
6334
6335     def path_ignored(self, path: Path, message: str) -> None:
6336         if self.verbose:
6337             out(f"{path} ignored: {message}", bold=False)
6338
6339     @property
6340     def return_code(self) -> int:
6341         """Return the exit code that the app should use.
6342
6343         This considers the current state of changed files and failures:
6344         - if there were any failures, return 123;
6345         - if any files were changed and --check is being used, return 1;
6346         - otherwise return 0.
6347         """
6348         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6349         # 126 we have special return codes reserved by the shell.
6350         if self.failure_count:
6351             return 123
6352
6353         elif self.change_count and self.check:
6354             return 1
6355
6356         return 0
6357
6358     def __str__(self) -> str:
6359         """Render a color report of the current state.
6360
6361         Use `click.unstyle` to remove colors.
6362         """
6363         if self.check or self.diff:
6364             reformatted = "would be reformatted"
6365             unchanged = "would be left unchanged"
6366             failed = "would fail to reformat"
6367         else:
6368             reformatted = "reformatted"
6369             unchanged = "left unchanged"
6370             failed = "failed to reformat"
6371         report = []
6372         if self.change_count:
6373             s = "s" if self.change_count > 1 else ""
6374             report.append(
6375                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6376             )
6377         if self.same_count:
6378             s = "s" if self.same_count > 1 else ""
6379             report.append(f"{self.same_count} file{s} {unchanged}")
6380         if self.failure_count:
6381             s = "s" if self.failure_count > 1 else ""
6382             report.append(
6383                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6384             )
6385         return ", ".join(report) + "."
6386
6387
6388 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6389     filename = "<unknown>"
6390     if sys.version_info >= (3, 8):
6391         # TODO: support Python 4+ ;)
6392         for minor_version in range(sys.version_info[1], 4, -1):
6393             try:
6394                 return ast.parse(src, filename, feature_version=(3, minor_version))
6395             except SyntaxError:
6396                 continue
6397     else:
6398         for feature_version in (7, 6):
6399             try:
6400                 return ast3.parse(src, filename, feature_version=feature_version)
6401             except SyntaxError:
6402                 continue
6403     if ast27.__name__ == "ast":
6404         raise SyntaxError(
6405             "The requested source code has invalid Python 3 syntax.\n"
6406             "If you are trying to format Python 2 files please reinstall Black"
6407             " with the 'python2' extra: `python3 -m pip install black[python2]`."
6408         )
6409     return ast27.parse(src)
6410
6411
6412 def _fixup_ast_constants(
6413     node: Union[ast.AST, ast3.AST, ast27.AST]
6414 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6415     """Map ast nodes deprecated in 3.8 to Constant."""
6416     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6417         return ast.Constant(value=node.s)
6418
6419     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6420         return ast.Constant(value=node.n)
6421
6422     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6423         return ast.Constant(value=node.value)
6424
6425     return node
6426
6427
6428 def _stringify_ast(
6429     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6430 ) -> Iterator[str]:
6431     """Simple visitor generating strings to compare ASTs by content."""
6432
6433     node = _fixup_ast_constants(node)
6434
6435     yield f"{'  ' * depth}{node.__class__.__name__}("
6436
6437     for field in sorted(node._fields):  # noqa: F402
6438         # TypeIgnore has only one field 'lineno' which breaks this comparison
6439         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6440         if sys.version_info >= (3, 8):
6441             type_ignore_classes += (ast.TypeIgnore,)
6442         if isinstance(node, type_ignore_classes):
6443             break
6444
6445         try:
6446             value = getattr(node, field)
6447         except AttributeError:
6448             continue
6449
6450         yield f"{'  ' * (depth+1)}{field}="
6451
6452         if isinstance(value, list):
6453             for item in value:
6454                 # Ignore nested tuples within del statements, because we may insert
6455                 # parentheses and they change the AST.
6456                 if (
6457                     field == "targets"
6458                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6459                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6460                 ):
6461                     for item in item.elts:
6462                         yield from _stringify_ast(item, depth + 2)
6463
6464                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6465                     yield from _stringify_ast(item, depth + 2)
6466
6467         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6468             yield from _stringify_ast(value, depth + 2)
6469
6470         else:
6471             # Constant strings may be indented across newlines, if they are
6472             # docstrings; fold spaces after newlines when comparing. Similarly,
6473             # trailing and leading space may be removed.
6474             # Note that when formatting Python 2 code, at least with Windows
6475             # line-endings, docstrings can end up here as bytes instead of
6476             # str so make sure that we handle both cases.
6477             if (
6478                 isinstance(node, ast.Constant)
6479                 and field == "value"
6480                 and isinstance(value, (str, bytes))
6481             ):
6482                 lineend = "\n" if isinstance(value, str) else b"\n"
6483                 # To normalize, we strip any leading and trailing space from
6484                 # each line...
6485                 stripped = [line.strip() for line in value.splitlines()]
6486                 normalized = lineend.join(stripped)  # type: ignore[attr-defined]
6487                 # ...and remove any blank lines at the beginning and end of
6488                 # the whole string
6489                 normalized = normalized.strip()
6490             else:
6491                 normalized = value
6492             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6493
6494     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6495
6496
6497 def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
6498     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6499     try:
6500         src_ast = parse_ast(src)
6501     except Exception as exc:
6502         raise AssertionError(
6503             "cannot use --safe with this file; failed to parse source file.  AST"
6504             f" error message: {exc}"
6505         )
6506
6507     try:
6508         dst_ast = parse_ast(dst)
6509     except Exception as exc:
6510         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6511         raise AssertionError(
6512             f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
6513             "Please report a bug on https://github.com/psf/black/issues.  "
6514             f"This invalid output might be helpful: {log}"
6515         ) from None
6516
6517     src_ast_str = "\n".join(_stringify_ast(src_ast))
6518     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6519     if src_ast_str != dst_ast_str:
6520         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6521         raise AssertionError(
6522             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6523             f" source on pass {pass_num}.  Please report a bug on "
6524             f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
6525         ) from None
6526
6527
6528 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6529     """Raise AssertionError if `dst` reformats differently the second time."""
6530     newdst = format_str(dst, mode=mode)
6531     if dst != newdst:
6532         log = dump_to_file(
6533             str(mode),
6534             diff(src, dst, "source", "first pass"),
6535             diff(dst, newdst, "first pass", "second pass"),
6536         )
6537         raise AssertionError(
6538             "INTERNAL ERROR: Black produced different code on the second pass of the"
6539             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6540             f"  This diff might be helpful: {log}"
6541         ) from None
6542
6543
6544 @mypyc_attr(patchable=True)
6545 def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
6546     """Dump `output` to a temporary file. Return path to the file."""
6547     with tempfile.NamedTemporaryFile(
6548         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6549     ) as f:
6550         for lines in output:
6551             f.write(lines)
6552             if ensure_final_newline and lines and lines[-1] != "\n":
6553                 f.write("\n")
6554     return f.name
6555
6556
6557 @contextmanager
6558 def nullcontext() -> Iterator[None]:
6559     """Return an empty context manager.
6560
6561     To be used like `nullcontext` in Python 3.7.
6562     """
6563     yield
6564
6565
6566 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6567     """Return a unified diff string between strings `a` and `b`."""
6568     import difflib
6569
6570     a_lines = [line for line in a.splitlines(keepends=True)]
6571     b_lines = [line for line in b.splitlines(keepends=True)]
6572     diff_lines = []
6573     for line in difflib.unified_diff(
6574         a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
6575     ):
6576         # Work around https://bugs.python.org/issue2142
6577         # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
6578         if line[-1] == "\n":
6579             diff_lines.append(line)
6580         else:
6581             diff_lines.append(line + "\n")
6582             diff_lines.append("\\ No newline at end of file\n")
6583     return "".join(diff_lines)
6584
6585
6586 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6587     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6588     err("Aborted!")
6589     for task in tasks:
6590         task.cancel()
6591
6592
6593 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6594     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6595     try:
6596         if sys.version_info[:2] >= (3, 7):
6597             all_tasks = asyncio.all_tasks
6598         else:
6599             all_tasks = asyncio.Task.all_tasks
6600         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6601         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6602         if not to_cancel:
6603             return
6604
6605         for task in to_cancel:
6606             task.cancel()
6607         loop.run_until_complete(
6608             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6609         )
6610     finally:
6611         # `concurrent.futures.Future` objects cannot be cancelled once they
6612         # are already running. There might be some when the `shutdown()` happened.
6613         # Silence their logger's spew about the event loop being closed.
6614         cf_logger = logging.getLogger("concurrent.futures")
6615         cf_logger.setLevel(logging.CRITICAL)
6616         loop.close()
6617
6618
6619 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6620     """Replace `regex` with `replacement` twice on `original`.
6621
6622     This is used by string normalization to perform replaces on
6623     overlapping matches.
6624     """
6625     return regex.sub(replacement, regex.sub(replacement, original))
6626
6627
6628 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6629     """Compile a regular expression string in `regex`.
6630
6631     If it contains newlines, use verbose mode.
6632     """
6633     if "\n" in regex:
6634         regex = "(?x)" + regex
6635     compiled: Pattern[str] = re.compile(regex)
6636     return compiled
6637
6638
6639 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6640     """Like `reversed(enumerate(sequence))` if that were possible."""
6641     index = len(sequence) - 1
6642     for element in reversed(sequence):
6643         yield (index, element)
6644         index -= 1
6645
6646
6647 def enumerate_with_length(
6648     line: Line, reversed: bool = False
6649 ) -> Iterator[Tuple[Index, Leaf, int]]:
6650     """Return an enumeration of leaves with their length.
6651
6652     Stops prematurely on multiline strings and standalone comments.
6653     """
6654     op = cast(
6655         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6656         enumerate_reversed if reversed else enumerate,
6657     )
6658     for index, leaf in op(line.leaves):
6659         length = len(leaf.prefix) + len(leaf.value)
6660         if "\n" in leaf.value:
6661             return  # Multiline strings, we can't continue.
6662
6663         for comment in line.comments_after(leaf):
6664             length += len(comment.value)
6665
6666         yield index, leaf, length
6667
6668
6669 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6670     """Return True if `line` is no longer than `line_length`.
6671
6672     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6673     """
6674     if not line_str:
6675         line_str = line_to_string(line)
6676     return (
6677         len(line_str) <= line_length
6678         and "\n" not in line_str  # multiline strings
6679         and not line.contains_standalone_comments()
6680     )
6681
6682
6683 def can_be_split(line: Line) -> bool:
6684     """Return False if the line cannot be split *for sure*.
6685
6686     This is not an exhaustive search but a cheap heuristic that we can use to
6687     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6688     in unnecessary parentheses).
6689     """
6690     leaves = line.leaves
6691     if len(leaves) < 2:
6692         return False
6693
6694     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6695         call_count = 0
6696         dot_count = 0
6697         next = leaves[-1]
6698         for leaf in leaves[-2::-1]:
6699             if leaf.type in OPENING_BRACKETS:
6700                 if next.type not in CLOSING_BRACKETS:
6701                     return False
6702
6703                 call_count += 1
6704             elif leaf.type == token.DOT:
6705                 dot_count += 1
6706             elif leaf.type == token.NAME:
6707                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6708                     return False
6709
6710             elif leaf.type not in CLOSING_BRACKETS:
6711                 return False
6712
6713             if dot_count > 1 and call_count > 1:
6714                 return False
6715
6716     return True
6717
6718
6719 def can_omit_invisible_parens(
6720     line: Line,
6721     line_length: int,
6722     omit_on_explode: Collection[LeafID] = (),
6723 ) -> bool:
6724     """Does `line` have a shape safe to reformat without optional parens around it?
6725
6726     Returns True for only a subset of potentially nice looking formattings but
6727     the point is to not return false positives that end up producing lines that
6728     are too long.
6729     """
6730     bt = line.bracket_tracker
6731     if not bt.delimiters:
6732         # Without delimiters the optional parentheses are useless.
6733         return True
6734
6735     max_priority = bt.max_delimiter_priority()
6736     if bt.delimiter_count_with_priority(max_priority) > 1:
6737         # With more than one delimiter of a kind the optional parentheses read better.
6738         return False
6739
6740     if max_priority == DOT_PRIORITY:
6741         # A single stranded method call doesn't require optional parentheses.
6742         return True
6743
6744     assert len(line.leaves) >= 2, "Stranded delimiter"
6745
6746     # With a single delimiter, omit if the expression starts or ends with
6747     # a bracket.
6748     first = line.leaves[0]
6749     second = line.leaves[1]
6750     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6751         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6752             return True
6753
6754         # Note: we are not returning False here because a line might have *both*
6755         # a leading opening bracket and a trailing closing bracket.  If the
6756         # opening bracket doesn't match our rule, maybe the closing will.
6757
6758     penultimate = line.leaves[-2]
6759     last = line.leaves[-1]
6760     if line.magic_trailing_comma:
6761         try:
6762             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6763         except LookupError:
6764             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6765             return False
6766
6767     if (
6768         last.type == token.RPAR
6769         or last.type == token.RBRACE
6770         or (
6771             # don't use indexing for omitting optional parentheses;
6772             # it looks weird
6773             last.type == token.RSQB
6774             and last.parent
6775             and last.parent.type != syms.trailer
6776         )
6777     ):
6778         if penultimate.type in OPENING_BRACKETS:
6779             # Empty brackets don't help.
6780             return False
6781
6782         if is_multiline_string(first):
6783             # Additional wrapping of a multiline string in this situation is
6784             # unnecessary.
6785             return True
6786
6787         if line.magic_trailing_comma and penultimate.type == token.COMMA:
6788             # The rightmost non-omitted bracket pair is the one we want to explode on.
6789             return True
6790
6791         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6792             return True
6793
6794     return False
6795
6796
6797 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6798     """See `can_omit_invisible_parens`."""
6799     remainder = False
6800     length = 4 * line.depth
6801     _index = -1
6802     for _index, leaf, leaf_length in enumerate_with_length(line):
6803         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6804             remainder = True
6805         if remainder:
6806             length += leaf_length
6807             if length > line_length:
6808                 break
6809
6810             if leaf.type in OPENING_BRACKETS:
6811                 # There are brackets we can further split on.
6812                 remainder = False
6813
6814     else:
6815         # checked the entire string and line length wasn't exceeded
6816         if len(line.leaves) == _index + 1:
6817             return True
6818
6819     return False
6820
6821
6822 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6823     """See `can_omit_invisible_parens`."""
6824     length = 4 * line.depth
6825     seen_other_brackets = False
6826     for _index, leaf, leaf_length in enumerate_with_length(line):
6827         length += leaf_length
6828         if leaf is last.opening_bracket:
6829             if seen_other_brackets or length <= line_length:
6830                 return True
6831
6832         elif leaf.type in OPENING_BRACKETS:
6833             # There are brackets we can further split on.
6834             seen_other_brackets = True
6835
6836     return False
6837
6838
6839 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6840     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6841     stop_after = None
6842     last = None
6843     for leaf in reversed(leaves):
6844         if stop_after:
6845             if leaf is stop_after:
6846                 stop_after = None
6847             continue
6848
6849         if last:
6850             return leaf, last
6851
6852         if id(leaf) in omit:
6853             stop_after = leaf.opening_bracket
6854         else:
6855             last = leaf
6856     else:
6857         raise LookupError("Last two leaves were also skipped")
6858
6859
6860 def run_transformer(
6861     line: Line,
6862     transform: Transformer,
6863     mode: Mode,
6864     features: Collection[Feature],
6865     *,
6866     line_str: str = "",
6867 ) -> List[Line]:
6868     if not line_str:
6869         line_str = line_to_string(line)
6870     result: List[Line] = []
6871     for transformed_line in transform(line, features):
6872         if str(transformed_line).strip("\n") == line_str:
6873             raise CannotTransform("Line transformer returned an unchanged result")
6874
6875         result.extend(transform_line(transformed_line, mode=mode, features=features))
6876
6877     if not (
6878         transform.__name__ == "rhs"
6879         and line.bracket_tracker.invisible
6880         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6881         and not line.contains_multiline_strings()
6882         and not result[0].contains_uncollapsable_type_comments()
6883         and not result[0].contains_unsplittable_type_ignore()
6884         and not is_line_short_enough(result[0], line_length=mode.line_length)
6885     ):
6886         return result
6887
6888     line_copy = line.clone()
6889     append_leaves(line_copy, line, line.leaves)
6890     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6891     second_opinion = run_transformer(
6892         line_copy, transform, mode, features_fop, line_str=line_str
6893     )
6894     if all(
6895         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6896     ):
6897         result = second_opinion
6898     return result
6899
6900
6901 def get_cache_file(mode: Mode) -> Path:
6902     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6903
6904
6905 def read_cache(mode: Mode) -> Cache:
6906     """Read the cache if it exists and is well formed.
6907
6908     If it is not well formed, the call to write_cache later should resolve the issue.
6909     """
6910     cache_file = get_cache_file(mode)
6911     if not cache_file.exists():
6912         return {}
6913
6914     with cache_file.open("rb") as fobj:
6915         try:
6916             cache: Cache = pickle.load(fobj)
6917         except (pickle.UnpicklingError, ValueError):
6918             return {}
6919
6920     return cache
6921
6922
6923 def get_cache_info(path: Path) -> CacheInfo:
6924     """Return the information used to check if a file is already formatted or not."""
6925     stat = path.stat()
6926     return stat.st_mtime, stat.st_size
6927
6928
6929 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6930     """Split an iterable of paths in `sources` into two sets.
6931
6932     The first contains paths of files that modified on disk or are not in the
6933     cache. The other contains paths to non-modified files.
6934     """
6935     todo, done = set(), set()
6936     for src in sources:
6937         res_src = src.resolve()
6938         if cache.get(str(res_src)) != get_cache_info(res_src):
6939             todo.add(src)
6940         else:
6941             done.add(src)
6942     return todo, done
6943
6944
6945 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6946     """Update the cache file."""
6947     cache_file = get_cache_file(mode)
6948     try:
6949         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6950         new_cache = {
6951             **cache,
6952             **{str(src.resolve()): get_cache_info(src) for src in sources},
6953         }
6954         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6955             pickle.dump(new_cache, f, protocol=4)
6956         os.replace(f.name, cache_file)
6957     except OSError:
6958         pass
6959
6960
6961 def patch_click() -> None:
6962     """Make Click not crash.
6963
6964     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6965     default which restricts paths that it can access during the lifetime of the
6966     application.  Click refuses to work in this scenario by raising a RuntimeError.
6967
6968     In case of Black the likelihood that non-ASCII characters are going to be used in
6969     file paths is minimal since it's Python source code.  Moreover, this crash was
6970     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6971     """
6972     try:
6973         from click import core
6974         from click import _unicodefun  # type: ignore
6975     except ModuleNotFoundError:
6976         return
6977
6978     for module in (core, _unicodefun):
6979         if hasattr(module, "_verify_python3_env"):
6980             module._verify_python3_env = lambda: None
6981
6982
6983 def patched_main() -> None:
6984     freeze_support()
6985     patch_click()
6986     main()
6987
6988
6989 def is_docstring(leaf: Leaf) -> bool:
6990     if prev_siblings_are(
6991         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6992     ):
6993         return True
6994
6995     # Multiline docstring on the same line as the `def`.
6996     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6997         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6998         # grammar. We're safe to return True without further checks.
6999         return True
7000
7001     return False
7002
7003
7004 def lines_with_leading_tabs_expanded(s: str) -> List[str]:
7005     """
7006     Splits string into lines and expands only leading tabs (following the normal
7007     Python rules)
7008     """
7009     lines = []
7010     for line in s.splitlines():
7011         # Find the index of the first non-whitespace character after a string of
7012         # whitespace that includes at least one tab
7013         match = re.match(r"\s*\t+\s*(\S)", line)
7014         if match:
7015             first_non_whitespace_idx = match.start(1)
7016
7017             lines.append(
7018                 line[:first_non_whitespace_idx].expandtabs()
7019                 + line[first_non_whitespace_idx:]
7020             )
7021         else:
7022             lines.append(line)
7023     return lines
7024
7025
7026 def fix_docstring(docstring: str, prefix: str) -> str:
7027     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
7028     if not docstring:
7029         return ""
7030     lines = lines_with_leading_tabs_expanded(docstring)
7031     # Determine minimum indentation (first line doesn't count):
7032     indent = sys.maxsize
7033     for line in lines[1:]:
7034         stripped = line.lstrip()
7035         if stripped:
7036             indent = min(indent, len(line) - len(stripped))
7037     # Remove indentation (first line is special):
7038     trimmed = [lines[0].strip()]
7039     if indent < sys.maxsize:
7040         last_line_idx = len(lines) - 2
7041         for i, line in enumerate(lines[1:]):
7042             stripped_line = line[indent:].rstrip()
7043             if stripped_line or i == last_line_idx:
7044                 trimmed.append(prefix + stripped_line)
7045             else:
7046                 trimmed.append("")
7047     return "\n".join(trimmed)
7048
7049
7050 if __name__ == "__main__":
7051     patched_main()