]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add --skip-magic-trailing-comma (#1824)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from mypy_extensions import mypyc_attr
46
47 from appdirs import user_cache_dir
48 from dataclasses import dataclass, field, replace
49 import click
50 import toml
51 from typed_ast import ast3, ast27
52 from pathspec import PathSpec
53
54 # lib2to3 fork
55 from blib2to3.pytree import Node, Leaf, type_repr
56 from blib2to3 import pygram, pytree
57 from blib2to3.pgen2 import driver, token
58 from blib2to3.pgen2.grammar import Grammar
59 from blib2to3.pgen2.parse import ParseError
60
61 from _black_version import version as __version__
62
63 if sys.version_info < (3, 8):
64     from typing_extensions import Final
65 else:
66     from typing import Final
67
68 if TYPE_CHECKING:
69     import colorama  # noqa: F401
70
71 DEFAULT_LINE_LENGTH = 88
72 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
73 DEFAULT_INCLUDES = r"\.pyi?$"
74 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
75 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
76
77 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
78
79
80 # types
81 FileContent = str
82 Encoding = str
83 NewLine = str
84 Depth = int
85 NodeType = int
86 ParserState = int
87 LeafID = int
88 StringID = int
89 Priority = int
90 Index = int
91 LN = Union[Leaf, Node]
92 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
93 Timestamp = float
94 FileSize = int
95 CacheInfo = Tuple[Timestamp, FileSize]
96 Cache = Dict[Path, CacheInfo]
97 out = partial(click.secho, bold=True, err=True)
98 err = partial(click.secho, fg="red", err=True)
99
100 pygram.initialize(CACHE_DIR)
101 syms = pygram.python_symbols
102
103
104 class NothingChanged(UserWarning):
105     """Raised when reformatted code is the same as source."""
106
107
108 class CannotTransform(Exception):
109     """Base class for errors raised by Transformers."""
110
111
112 class CannotSplit(CannotTransform):
113     """A readable split that fits the allotted line length is impossible."""
114
115
116 class InvalidInput(ValueError):
117     """Raised when input source code fails all parse attempts."""
118
119
120 class BracketMatchError(KeyError):
121     """Raised when an opening bracket is unable to be matched to a closing bracket."""
122
123
124 T = TypeVar("T")
125 E = TypeVar("E", bound=Exception)
126
127
128 class Ok(Generic[T]):
129     def __init__(self, value: T) -> None:
130         self._value = value
131
132     def ok(self) -> T:
133         return self._value
134
135
136 class Err(Generic[E]):
137     def __init__(self, e: E) -> None:
138         self._e = e
139
140     def err(self) -> E:
141         return self._e
142
143
144 # The 'Result' return type is used to implement an error-handling model heavily
145 # influenced by that used by the Rust programming language
146 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
147 Result = Union[Ok[T], Err[E]]
148 TResult = Result[T, CannotTransform]  # (T)ransform Result
149 TMatchResult = TResult[Index]
150
151
152 class WriteBack(Enum):
153     NO = 0
154     YES = 1
155     DIFF = 2
156     CHECK = 3
157     COLOR_DIFF = 4
158
159     @classmethod
160     def from_configuration(
161         cls, *, check: bool, diff: bool, color: bool = False
162     ) -> "WriteBack":
163         if check and not diff:
164             return cls.CHECK
165
166         if diff and color:
167             return cls.COLOR_DIFF
168
169         return cls.DIFF if diff else cls.YES
170
171
172 class Changed(Enum):
173     NO = 0
174     CACHED = 1
175     YES = 2
176
177
178 class TargetVersion(Enum):
179     PY27 = 2
180     PY33 = 3
181     PY34 = 4
182     PY35 = 5
183     PY36 = 6
184     PY37 = 7
185     PY38 = 8
186     PY39 = 9
187
188     def is_python2(self) -> bool:
189         return self is TargetVersion.PY27
190
191
192 class Feature(Enum):
193     # All string literals are unicode
194     UNICODE_LITERALS = 1
195     F_STRINGS = 2
196     NUMERIC_UNDERSCORES = 3
197     TRAILING_COMMA_IN_CALL = 4
198     TRAILING_COMMA_IN_DEF = 5
199     # The following two feature-flags are mutually exclusive, and exactly one should be
200     # set for every version of python.
201     ASYNC_IDENTIFIERS = 6
202     ASYNC_KEYWORDS = 7
203     ASSIGNMENT_EXPRESSIONS = 8
204     POS_ONLY_ARGUMENTS = 9
205     RELAXED_DECORATORS = 10
206     FORCE_OPTIONAL_PARENTHESES = 50
207
208
209 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
210     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
211     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
212     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
213     TargetVersion.PY35: {
214         Feature.UNICODE_LITERALS,
215         Feature.TRAILING_COMMA_IN_CALL,
216         Feature.ASYNC_IDENTIFIERS,
217     },
218     TargetVersion.PY36: {
219         Feature.UNICODE_LITERALS,
220         Feature.F_STRINGS,
221         Feature.NUMERIC_UNDERSCORES,
222         Feature.TRAILING_COMMA_IN_CALL,
223         Feature.TRAILING_COMMA_IN_DEF,
224         Feature.ASYNC_IDENTIFIERS,
225     },
226     TargetVersion.PY37: {
227         Feature.UNICODE_LITERALS,
228         Feature.F_STRINGS,
229         Feature.NUMERIC_UNDERSCORES,
230         Feature.TRAILING_COMMA_IN_CALL,
231         Feature.TRAILING_COMMA_IN_DEF,
232         Feature.ASYNC_KEYWORDS,
233     },
234     TargetVersion.PY38: {
235         Feature.UNICODE_LITERALS,
236         Feature.F_STRINGS,
237         Feature.NUMERIC_UNDERSCORES,
238         Feature.TRAILING_COMMA_IN_CALL,
239         Feature.TRAILING_COMMA_IN_DEF,
240         Feature.ASYNC_KEYWORDS,
241         Feature.ASSIGNMENT_EXPRESSIONS,
242         Feature.POS_ONLY_ARGUMENTS,
243     },
244     TargetVersion.PY39: {
245         Feature.UNICODE_LITERALS,
246         Feature.F_STRINGS,
247         Feature.NUMERIC_UNDERSCORES,
248         Feature.TRAILING_COMMA_IN_CALL,
249         Feature.TRAILING_COMMA_IN_DEF,
250         Feature.ASYNC_KEYWORDS,
251         Feature.ASSIGNMENT_EXPRESSIONS,
252         Feature.RELAXED_DECORATORS,
253         Feature.POS_ONLY_ARGUMENTS,
254     },
255 }
256
257
258 @dataclass
259 class Mode:
260     target_versions: Set[TargetVersion] = field(default_factory=set)
261     line_length: int = DEFAULT_LINE_LENGTH
262     string_normalization: bool = True
263     magic_trailing_comma: bool = True
264     experimental_string_processing: bool = False
265     is_pyi: bool = False
266
267     def get_cache_key(self) -> str:
268         if self.target_versions:
269             version_str = ",".join(
270                 str(version.value)
271                 for version in sorted(self.target_versions, key=lambda v: v.value)
272             )
273         else:
274             version_str = "-"
275         parts = [
276             version_str,
277             str(self.line_length),
278             str(int(self.string_normalization)),
279             str(int(self.is_pyi)),
280         ]
281         return ".".join(parts)
282
283
284 # Legacy name, left for integrations.
285 FileMode = Mode
286
287
288 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
289     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
290
291
292 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
293     """Find the absolute filepath to a pyproject.toml if it exists"""
294     path_project_root = find_project_root(path_search_start)
295     path_pyproject_toml = path_project_root / "pyproject.toml"
296     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
297
298
299 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
300     """Parse a pyproject toml file, pulling out relevant parts for Black
301
302     If parsing fails, will raise a toml.TomlDecodeError
303     """
304     pyproject_toml = toml.load(path_config)
305     config = pyproject_toml.get("tool", {}).get("black", {})
306     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
307
308
309 def read_pyproject_toml(
310     ctx: click.Context, param: click.Parameter, value: Optional[str]
311 ) -> Optional[str]:
312     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
313
314     Returns the path to a successfully found and read configuration file, None
315     otherwise.
316     """
317     if not value:
318         value = find_pyproject_toml(ctx.params.get("src", ()))
319         if value is None:
320             return None
321
322     try:
323         config = parse_pyproject_toml(value)
324     except (toml.TomlDecodeError, OSError) as e:
325         raise click.FileError(
326             filename=value, hint=f"Error reading configuration file: {e}"
327         )
328
329     if not config:
330         return None
331     else:
332         # Sanitize the values to be Click friendly. For more information please see:
333         # https://github.com/psf/black/issues/1458
334         # https://github.com/pallets/click/issues/1567
335         config = {
336             k: str(v) if not isinstance(v, (list, dict)) else v
337             for k, v in config.items()
338         }
339
340     target_version = config.get("target_version")
341     if target_version is not None and not isinstance(target_version, list):
342         raise click.BadOptionUsage(
343             "target-version", "Config key target-version must be a list"
344         )
345
346     default_map: Dict[str, Any] = {}
347     if ctx.default_map:
348         default_map.update(ctx.default_map)
349     default_map.update(config)
350
351     ctx.default_map = default_map
352     return value
353
354
355 def target_version_option_callback(
356     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
357 ) -> List[TargetVersion]:
358     """Compute the target versions from a --target-version flag.
359
360     This is its own function because mypy couldn't infer the type correctly
361     when it was a lambda, causing mypyc trouble.
362     """
363     return [TargetVersion[val.upper()] for val in v]
364
365
366 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
367 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
368 @click.option(
369     "-l",
370     "--line-length",
371     type=int,
372     default=DEFAULT_LINE_LENGTH,
373     help="How many characters per line to allow.",
374     show_default=True,
375 )
376 @click.option(
377     "-t",
378     "--target-version",
379     type=click.Choice([v.name.lower() for v in TargetVersion]),
380     callback=target_version_option_callback,
381     multiple=True,
382     help=(
383         "Python versions that should be supported by Black's output. [default: per-file"
384         " auto-detection]"
385     ),
386 )
387 @click.option(
388     "--pyi",
389     is_flag=True,
390     help=(
391         "Format all input files like typing stubs regardless of file extension (useful"
392         " when piping source on standard input)."
393     ),
394 )
395 @click.option(
396     "-S",
397     "--skip-string-normalization",
398     is_flag=True,
399     help="Don't normalize string quotes or prefixes.",
400 )
401 @click.option(
402     "-C",
403     "--skip-magic-trailing-comma",
404     is_flag=True,
405     help="Don't use trailing commas as a reason to split lines.",
406 )
407 @click.option(
408     "--experimental-string-processing",
409     is_flag=True,
410     hidden=True,
411     help=(
412         "Experimental option that performs more normalization on string literals."
413         " Currently disabled because it leads to some crashes."
414     ),
415 )
416 @click.option(
417     "--check",
418     is_flag=True,
419     help=(
420         "Don't write the files back, just return the status.  Return code 0 means"
421         " nothing would change.  Return code 1 means some files would be reformatted."
422         " Return code 123 means there was an internal error."
423     ),
424 )
425 @click.option(
426     "--diff",
427     is_flag=True,
428     help="Don't write the files back, just output a diff for each file on stdout.",
429 )
430 @click.option(
431     "--color/--no-color",
432     is_flag=True,
433     help="Show colored diff. Only applies when `--diff` is given.",
434 )
435 @click.option(
436     "--fast/--safe",
437     is_flag=True,
438     help="If --fast given, skip temporary sanity checks. [default: --safe]",
439 )
440 @click.option(
441     "--include",
442     type=str,
443     default=DEFAULT_INCLUDES,
444     help=(
445         "A regular expression that matches files and directories that should be"
446         " included on recursive searches.  An empty value means all files are included"
447         " regardless of the name.  Use forward slashes for directories on all platforms"
448         " (Windows, too).  Exclusions are calculated first, inclusions later."
449     ),
450     show_default=True,
451 )
452 @click.option(
453     "--exclude",
454     type=str,
455     default=DEFAULT_EXCLUDES,
456     help=(
457         "A regular expression that matches files and directories that should be"
458         " excluded on recursive searches.  An empty value means no paths are excluded."
459         " Use forward slashes for directories on all platforms (Windows, too). "
460         " Exclusions are calculated first, inclusions later."
461     ),
462     show_default=True,
463 )
464 @click.option(
465     "--force-exclude",
466     type=str,
467     help=(
468         "Like --exclude, but files and directories matching this regex will be "
469         "excluded even when they are passed explicitly as arguments."
470     ),
471 )
472 @click.option(
473     "--stdin-filename",
474     type=str,
475     help=(
476         "The name of the file when passing it through stdin. Useful to make "
477         "sure Black will respect --force-exclude option on some "
478         "editors that rely on using stdin."
479     ),
480 )
481 @click.option(
482     "-q",
483     "--quiet",
484     is_flag=True,
485     help=(
486         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
487         " those with 2>/dev/null."
488     ),
489 )
490 @click.option(
491     "-v",
492     "--verbose",
493     is_flag=True,
494     help=(
495         "Also emit messages to stderr about files that were not changed or were ignored"
496         " due to --exclude=."
497     ),
498 )
499 @click.version_option(version=__version__)
500 @click.argument(
501     "src",
502     nargs=-1,
503     type=click.Path(
504         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
505     ),
506     is_eager=True,
507 )
508 @click.option(
509     "--config",
510     type=click.Path(
511         exists=True,
512         file_okay=True,
513         dir_okay=False,
514         readable=True,
515         allow_dash=False,
516         path_type=str,
517     ),
518     is_eager=True,
519     callback=read_pyproject_toml,
520     help="Read configuration from FILE path.",
521 )
522 @click.pass_context
523 def main(
524     ctx: click.Context,
525     code: Optional[str],
526     line_length: int,
527     target_version: List[TargetVersion],
528     check: bool,
529     diff: bool,
530     color: bool,
531     fast: bool,
532     pyi: bool,
533     skip_string_normalization: bool,
534     skip_magic_trailing_comma: bool,
535     experimental_string_processing: bool,
536     quiet: bool,
537     verbose: bool,
538     include: str,
539     exclude: str,
540     force_exclude: Optional[str],
541     stdin_filename: Optional[str],
542     src: Tuple[str, ...],
543     config: Optional[str],
544 ) -> None:
545     """The uncompromising code formatter."""
546     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
547     if target_version:
548         versions = set(target_version)
549     else:
550         # We'll autodetect later.
551         versions = set()
552     mode = Mode(
553         target_versions=versions,
554         line_length=line_length,
555         is_pyi=pyi,
556         string_normalization=not skip_string_normalization,
557         magic_trailing_comma=not skip_magic_trailing_comma,
558         experimental_string_processing=experimental_string_processing,
559     )
560     if config and verbose:
561         out(f"Using configuration from {config}.", bold=False, fg="blue")
562     if code is not None:
563         print(format_str(code, mode=mode))
564         ctx.exit(0)
565     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
566     sources = get_sources(
567         ctx=ctx,
568         src=src,
569         quiet=quiet,
570         verbose=verbose,
571         include=include,
572         exclude=exclude,
573         force_exclude=force_exclude,
574         report=report,
575         stdin_filename=stdin_filename,
576     )
577
578     path_empty(
579         sources,
580         "No Python files are present to be formatted. Nothing to do 😴",
581         quiet,
582         verbose,
583         ctx,
584     )
585
586     if len(sources) == 1:
587         reformat_one(
588             src=sources.pop(),
589             fast=fast,
590             write_back=write_back,
591             mode=mode,
592             report=report,
593         )
594     else:
595         reformat_many(
596             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
597         )
598
599     if verbose or not quiet:
600         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
601         click.secho(str(report), err=True)
602     ctx.exit(report.return_code)
603
604
605 def get_sources(
606     *,
607     ctx: click.Context,
608     src: Tuple[str, ...],
609     quiet: bool,
610     verbose: bool,
611     include: str,
612     exclude: str,
613     force_exclude: Optional[str],
614     report: "Report",
615     stdin_filename: Optional[str],
616 ) -> Set[Path]:
617     """Compute the set of files to be formatted."""
618     try:
619         include_regex = re_compile_maybe_verbose(include)
620     except re.error:
621         err(f"Invalid regular expression for include given: {include!r}")
622         ctx.exit(2)
623     try:
624         exclude_regex = re_compile_maybe_verbose(exclude)
625     except re.error:
626         err(f"Invalid regular expression for exclude given: {exclude!r}")
627         ctx.exit(2)
628     try:
629         force_exclude_regex = (
630             re_compile_maybe_verbose(force_exclude) if force_exclude else None
631         )
632     except re.error:
633         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
634         ctx.exit(2)
635
636     root = find_project_root(src)
637     sources: Set[Path] = set()
638     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
639     gitignore = get_gitignore(root)
640
641     for s in src:
642         if s == "-" and stdin_filename:
643             p = Path(stdin_filename)
644             is_stdin = True
645         else:
646             p = Path(s)
647             is_stdin = False
648
649         if is_stdin or p.is_file():
650             normalized_path = normalize_path_maybe_ignore(p, root, report)
651             if normalized_path is None:
652                 continue
653
654             normalized_path = "/" + normalized_path
655             # Hard-exclude any files that matches the `--force-exclude` regex.
656             if force_exclude_regex:
657                 force_exclude_match = force_exclude_regex.search(normalized_path)
658             else:
659                 force_exclude_match = None
660             if force_exclude_match and force_exclude_match.group(0):
661                 report.path_ignored(p, "matches the --force-exclude regular expression")
662                 continue
663
664             if is_stdin:
665                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
666
667             sources.add(p)
668         elif p.is_dir():
669             sources.update(
670                 gen_python_files(
671                     p.iterdir(),
672                     root,
673                     include_regex,
674                     exclude_regex,
675                     force_exclude_regex,
676                     report,
677                     gitignore,
678                 )
679             )
680         elif s == "-":
681             sources.add(p)
682         else:
683             err(f"invalid path: {s}")
684     return sources
685
686
687 def path_empty(
688     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
689 ) -> None:
690     """
691     Exit if there is no `src` provided for formatting
692     """
693     if not src and (verbose or not quiet):
694         out(msg)
695         ctx.exit(0)
696
697
698 def reformat_one(
699     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
700 ) -> None:
701     """Reformat a single file under `src` without spawning child processes.
702
703     `fast`, `write_back`, and `mode` options are passed to
704     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
705     """
706     try:
707         changed = Changed.NO
708
709         if str(src) == "-":
710             is_stdin = True
711         elif str(src).startswith(STDIN_PLACEHOLDER):
712             is_stdin = True
713             # Use the original name again in case we want to print something
714             # to the user
715             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
716         else:
717             is_stdin = False
718
719         if is_stdin:
720             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
721                 changed = Changed.YES
722         else:
723             cache: Cache = {}
724             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
725                 cache = read_cache(mode)
726                 res_src = src.resolve()
727                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
728                     changed = Changed.CACHED
729             if changed is not Changed.CACHED and format_file_in_place(
730                 src, fast=fast, write_back=write_back, mode=mode
731             ):
732                 changed = Changed.YES
733             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
734                 write_back is WriteBack.CHECK and changed is Changed.NO
735             ):
736                 write_cache(cache, [src], mode)
737         report.done(src, changed)
738     except Exception as exc:
739         if report.verbose:
740             traceback.print_exc()
741         report.failed(src, str(exc))
742
743
744 def reformat_many(
745     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
746 ) -> None:
747     """Reformat multiple files using a ProcessPoolExecutor."""
748     executor: Executor
749     loop = asyncio.get_event_loop()
750     worker_count = os.cpu_count()
751     if sys.platform == "win32":
752         # Work around https://bugs.python.org/issue26903
753         worker_count = min(worker_count, 60)
754     try:
755         executor = ProcessPoolExecutor(max_workers=worker_count)
756     except (ImportError, OSError):
757         # we arrive here if the underlying system does not support multi-processing
758         # like in AWS Lambda or Termux, in which case we gracefully fallback to
759         # a ThreadPollExecutor with just a single worker (more workers would not do us
760         # any good due to the Global Interpreter Lock)
761         executor = ThreadPoolExecutor(max_workers=1)
762
763     try:
764         loop.run_until_complete(
765             schedule_formatting(
766                 sources=sources,
767                 fast=fast,
768                 write_back=write_back,
769                 mode=mode,
770                 report=report,
771                 loop=loop,
772                 executor=executor,
773             )
774         )
775     finally:
776         shutdown(loop)
777         if executor is not None:
778             executor.shutdown()
779
780
781 async def schedule_formatting(
782     sources: Set[Path],
783     fast: bool,
784     write_back: WriteBack,
785     mode: Mode,
786     report: "Report",
787     loop: asyncio.AbstractEventLoop,
788     executor: Executor,
789 ) -> None:
790     """Run formatting of `sources` in parallel using the provided `executor`.
791
792     (Use ProcessPoolExecutors for actual parallelism.)
793
794     `write_back`, `fast`, and `mode` options are passed to
795     :func:`format_file_in_place`.
796     """
797     cache: Cache = {}
798     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
799         cache = read_cache(mode)
800         sources, cached = filter_cached(cache, sources)
801         for src in sorted(cached):
802             report.done(src, Changed.CACHED)
803     if not sources:
804         return
805
806     cancelled = []
807     sources_to_cache = []
808     lock = None
809     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
810         # For diff output, we need locks to ensure we don't interleave output
811         # from different processes.
812         manager = Manager()
813         lock = manager.Lock()
814     tasks = {
815         asyncio.ensure_future(
816             loop.run_in_executor(
817                 executor, format_file_in_place, src, fast, mode, write_back, lock
818             )
819         ): src
820         for src in sorted(sources)
821     }
822     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
823     try:
824         loop.add_signal_handler(signal.SIGINT, cancel, pending)
825         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
826     except NotImplementedError:
827         # There are no good alternatives for these on Windows.
828         pass
829     while pending:
830         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
831         for task in done:
832             src = tasks.pop(task)
833             if task.cancelled():
834                 cancelled.append(task)
835             elif task.exception():
836                 report.failed(src, str(task.exception()))
837             else:
838                 changed = Changed.YES if task.result() else Changed.NO
839                 # If the file was written back or was successfully checked as
840                 # well-formatted, store this information in the cache.
841                 if write_back is WriteBack.YES or (
842                     write_back is WriteBack.CHECK and changed is Changed.NO
843                 ):
844                     sources_to_cache.append(src)
845                 report.done(src, changed)
846     if cancelled:
847         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
848     if sources_to_cache:
849         write_cache(cache, sources_to_cache, mode)
850
851
852 def format_file_in_place(
853     src: Path,
854     fast: bool,
855     mode: Mode,
856     write_back: WriteBack = WriteBack.NO,
857     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
858 ) -> bool:
859     """Format file under `src` path. Return True if changed.
860
861     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
862     code to the file.
863     `mode` and `fast` options are passed to :func:`format_file_contents`.
864     """
865     if src.suffix == ".pyi":
866         mode = replace(mode, is_pyi=True)
867
868     then = datetime.utcfromtimestamp(src.stat().st_mtime)
869     with open(src, "rb") as buf:
870         src_contents, encoding, newline = decode_bytes(buf.read())
871     try:
872         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
873     except NothingChanged:
874         return False
875
876     if write_back == WriteBack.YES:
877         with open(src, "w", encoding=encoding, newline=newline) as f:
878             f.write(dst_contents)
879     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
880         now = datetime.utcnow()
881         src_name = f"{src}\t{then} +0000"
882         dst_name = f"{src}\t{now} +0000"
883         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
884
885         if write_back == write_back.COLOR_DIFF:
886             diff_contents = color_diff(diff_contents)
887
888         with lock or nullcontext():
889             f = io.TextIOWrapper(
890                 sys.stdout.buffer,
891                 encoding=encoding,
892                 newline=newline,
893                 write_through=True,
894             )
895             f = wrap_stream_for_windows(f)
896             f.write(diff_contents)
897             f.detach()
898
899     return True
900
901
902 def color_diff(contents: str) -> str:
903     """Inject the ANSI color codes to the diff."""
904     lines = contents.split("\n")
905     for i, line in enumerate(lines):
906         if line.startswith("+++") or line.startswith("---"):
907             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
908         elif line.startswith("@@"):
909             line = "\033[36m" + line + "\033[0m"  # cyan, reset
910         elif line.startswith("+"):
911             line = "\033[32m" + line + "\033[0m"  # green, reset
912         elif line.startswith("-"):
913             line = "\033[31m" + line + "\033[0m"  # red, reset
914         lines[i] = line
915     return "\n".join(lines)
916
917
918 def wrap_stream_for_windows(
919     f: io.TextIOWrapper,
920 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
921     """
922     Wrap stream with colorama's wrap_stream so colors are shown on Windows.
923
924     If `colorama` is unavailable, the original stream is returned unmodified.
925     Otherwise, the `wrap_stream()` function determines whether the stream needs
926     to be wrapped for a Windows environment and will accordingly either return
927     an `AnsiToWin32` wrapper or the original stream.
928     """
929     try:
930         from colorama.initialise import wrap_stream
931     except ImportError:
932         return f
933     else:
934         # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
935         return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
936
937
938 def format_stdin_to_stdout(
939     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
940 ) -> bool:
941     """Format file on stdin. Return True if changed.
942
943     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
944     write a diff to stdout. The `mode` argument is passed to
945     :func:`format_file_contents`.
946     """
947     then = datetime.utcnow()
948     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
949     dst = src
950     try:
951         dst = format_file_contents(src, fast=fast, mode=mode)
952         return True
953
954     except NothingChanged:
955         return False
956
957     finally:
958         f = io.TextIOWrapper(
959             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
960         )
961         if write_back == WriteBack.YES:
962             f.write(dst)
963         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
964             now = datetime.utcnow()
965             src_name = f"STDIN\t{then} +0000"
966             dst_name = f"STDOUT\t{now} +0000"
967             d = diff(src, dst, src_name, dst_name)
968             if write_back == WriteBack.COLOR_DIFF:
969                 d = color_diff(d)
970                 f = wrap_stream_for_windows(f)
971             f.write(d)
972         f.detach()
973
974
975 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
976     """Reformat contents of a file and return new contents.
977
978     If `fast` is False, additionally confirm that the reformatted code is
979     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
980     `mode` is passed to :func:`format_str`.
981     """
982     if not src_contents.strip():
983         raise NothingChanged
984
985     dst_contents = format_str(src_contents, mode=mode)
986     if src_contents == dst_contents:
987         raise NothingChanged
988
989     if not fast:
990         assert_equivalent(src_contents, dst_contents)
991         assert_stable(src_contents, dst_contents, mode=mode)
992     return dst_contents
993
994
995 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
996     """Reformat a string and return new contents.
997
998     `mode` determines formatting options, such as how many characters per line are
999     allowed.  Example:
1000
1001     >>> import black
1002     >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
1003     def f(arg: str = "") -> None:
1004         ...
1005
1006     A more complex example:
1007
1008     >>> print(
1009     ...   black.format_str(
1010     ...     "def f(arg:str='')->None: hey",
1011     ...     mode=black.Mode(
1012     ...       target_versions={black.TargetVersion.PY36},
1013     ...       line_length=10,
1014     ...       string_normalization=False,
1015     ...       is_pyi=False,
1016     ...     ),
1017     ...   ),
1018     ... )
1019     def f(
1020         arg: str = '',
1021     ) -> None:
1022         hey
1023
1024     """
1025     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
1026     dst_contents = []
1027     future_imports = get_future_imports(src_node)
1028     if mode.target_versions:
1029         versions = mode.target_versions
1030     else:
1031         versions = detect_target_versions(src_node)
1032     normalize_fmt_off(src_node)
1033     lines = LineGenerator(
1034         mode=mode,
1035         remove_u_prefix="unicode_literals" in future_imports
1036         or supports_feature(versions, Feature.UNICODE_LITERALS),
1037     )
1038     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1039     empty_line = Line(mode=mode)
1040     after = 0
1041     split_line_features = {
1042         feature
1043         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1044         if supports_feature(versions, feature)
1045     }
1046     for current_line in lines.visit(src_node):
1047         dst_contents.append(str(empty_line) * after)
1048         before, after = elt.maybe_empty_lines(current_line)
1049         dst_contents.append(str(empty_line) * before)
1050         for line in transform_line(
1051             current_line, mode=mode, features=split_line_features
1052         ):
1053             dst_contents.append(str(line))
1054     return "".join(dst_contents)
1055
1056
1057 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1058     """Return a tuple of (decoded_contents, encoding, newline).
1059
1060     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1061     universal newlines (i.e. only contains LF).
1062     """
1063     srcbuf = io.BytesIO(src)
1064     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1065     if not lines:
1066         return "", encoding, "\n"
1067
1068     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1069     srcbuf.seek(0)
1070     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1071         return tiow.read(), encoding, newline
1072
1073
1074 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1075     if not target_versions:
1076         # No target_version specified, so try all grammars.
1077         return [
1078             # Python 3.7+
1079             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1080             # Python 3.0-3.6
1081             pygram.python_grammar_no_print_statement_no_exec_statement,
1082             # Python 2.7 with future print_function import
1083             pygram.python_grammar_no_print_statement,
1084             # Python 2.7
1085             pygram.python_grammar,
1086         ]
1087
1088     if all(version.is_python2() for version in target_versions):
1089         # Python 2-only code, so try Python 2 grammars.
1090         return [
1091             # Python 2.7 with future print_function import
1092             pygram.python_grammar_no_print_statement,
1093             # Python 2.7
1094             pygram.python_grammar,
1095         ]
1096
1097     # Python 3-compatible code, so only try Python 3 grammar.
1098     grammars = []
1099     # If we have to parse both, try to parse async as a keyword first
1100     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1101         # Python 3.7+
1102         grammars.append(
1103             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1104         )
1105     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1106         # Python 3.0-3.6
1107         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1108     # At least one of the above branches must have been taken, because every Python
1109     # version has exactly one of the two 'ASYNC_*' flags
1110     return grammars
1111
1112
1113 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1114     """Given a string with source, return the lib2to3 Node."""
1115     if not src_txt.endswith("\n"):
1116         src_txt += "\n"
1117
1118     for grammar in get_grammars(set(target_versions)):
1119         drv = driver.Driver(grammar, pytree.convert)
1120         try:
1121             result = drv.parse_string(src_txt, True)
1122             break
1123
1124         except ParseError as pe:
1125             lineno, column = pe.context[1]
1126             lines = src_txt.splitlines()
1127             try:
1128                 faulty_line = lines[lineno - 1]
1129             except IndexError:
1130                 faulty_line = "<line number missing in source>"
1131             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1132     else:
1133         raise exc from None
1134
1135     if isinstance(result, Leaf):
1136         result = Node(syms.file_input, [result])
1137     return result
1138
1139
1140 def lib2to3_unparse(node: Node) -> str:
1141     """Given a lib2to3 node, return its string representation."""
1142     code = str(node)
1143     return code
1144
1145
1146 class Visitor(Generic[T]):
1147     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1148
1149     def visit(self, node: LN) -> Iterator[T]:
1150         """Main method to visit `node` and its children.
1151
1152         It tries to find a `visit_*()` method for the given `node.type`, like
1153         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1154         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1155         instead.
1156
1157         Then yields objects of type `T` from the selected visitor.
1158         """
1159         if node.type < 256:
1160             name = token.tok_name[node.type]
1161         else:
1162             name = str(type_repr(node.type))
1163         # We explicitly branch on whether a visitor exists (instead of
1164         # using self.visit_default as the default arg to getattr) in order
1165         # to save needing to create a bound method object and so mypyc can
1166         # generate a native call to visit_default.
1167         visitf = getattr(self, f"visit_{name}", None)
1168         if visitf:
1169             yield from visitf(node)
1170         else:
1171             yield from self.visit_default(node)
1172
1173     def visit_default(self, node: LN) -> Iterator[T]:
1174         """Default `visit_*()` implementation. Recurses to children of `node`."""
1175         if isinstance(node, Node):
1176             for child in node.children:
1177                 yield from self.visit(child)
1178
1179
1180 @dataclass
1181 class DebugVisitor(Visitor[T]):
1182     tree_depth: int = 0
1183
1184     def visit_default(self, node: LN) -> Iterator[T]:
1185         indent = " " * (2 * self.tree_depth)
1186         if isinstance(node, Node):
1187             _type = type_repr(node.type)
1188             out(f"{indent}{_type}", fg="yellow")
1189             self.tree_depth += 1
1190             for child in node.children:
1191                 yield from self.visit(child)
1192
1193             self.tree_depth -= 1
1194             out(f"{indent}/{_type}", fg="yellow", bold=False)
1195         else:
1196             _type = token.tok_name.get(node.type, str(node.type))
1197             out(f"{indent}{_type}", fg="blue", nl=False)
1198             if node.prefix:
1199                 # We don't have to handle prefixes for `Node` objects since
1200                 # that delegates to the first child anyway.
1201                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1202             out(f" {node.value!r}", fg="blue", bold=False)
1203
1204     @classmethod
1205     def show(cls, code: Union[str, Leaf, Node]) -> None:
1206         """Pretty-print the lib2to3 AST of a given string of `code`.
1207
1208         Convenience method for debugging.
1209         """
1210         v: DebugVisitor[None] = DebugVisitor()
1211         if isinstance(code, str):
1212             code = lib2to3_parse(code)
1213         list(v.visit(code))
1214
1215
1216 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1217 STATEMENT: Final = {
1218     syms.if_stmt,
1219     syms.while_stmt,
1220     syms.for_stmt,
1221     syms.try_stmt,
1222     syms.except_clause,
1223     syms.with_stmt,
1224     syms.funcdef,
1225     syms.classdef,
1226 }
1227 STANDALONE_COMMENT: Final = 153
1228 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1229 LOGIC_OPERATORS: Final = {"and", "or"}
1230 COMPARATORS: Final = {
1231     token.LESS,
1232     token.GREATER,
1233     token.EQEQUAL,
1234     token.NOTEQUAL,
1235     token.LESSEQUAL,
1236     token.GREATEREQUAL,
1237 }
1238 MATH_OPERATORS: Final = {
1239     token.VBAR,
1240     token.CIRCUMFLEX,
1241     token.AMPER,
1242     token.LEFTSHIFT,
1243     token.RIGHTSHIFT,
1244     token.PLUS,
1245     token.MINUS,
1246     token.STAR,
1247     token.SLASH,
1248     token.DOUBLESLASH,
1249     token.PERCENT,
1250     token.AT,
1251     token.TILDE,
1252     token.DOUBLESTAR,
1253 }
1254 STARS: Final = {token.STAR, token.DOUBLESTAR}
1255 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1256 VARARGS_PARENTS: Final = {
1257     syms.arglist,
1258     syms.argument,  # double star in arglist
1259     syms.trailer,  # single argument to call
1260     syms.typedargslist,
1261     syms.varargslist,  # lambdas
1262 }
1263 UNPACKING_PARENTS: Final = {
1264     syms.atom,  # single element of a list or set literal
1265     syms.dictsetmaker,
1266     syms.listmaker,
1267     syms.testlist_gexp,
1268     syms.testlist_star_expr,
1269 }
1270 TEST_DESCENDANTS: Final = {
1271     syms.test,
1272     syms.lambdef,
1273     syms.or_test,
1274     syms.and_test,
1275     syms.not_test,
1276     syms.comparison,
1277     syms.star_expr,
1278     syms.expr,
1279     syms.xor_expr,
1280     syms.and_expr,
1281     syms.shift_expr,
1282     syms.arith_expr,
1283     syms.trailer,
1284     syms.term,
1285     syms.power,
1286 }
1287 ASSIGNMENTS: Final = {
1288     "=",
1289     "+=",
1290     "-=",
1291     "*=",
1292     "@=",
1293     "/=",
1294     "%=",
1295     "&=",
1296     "|=",
1297     "^=",
1298     "<<=",
1299     ">>=",
1300     "**=",
1301     "//=",
1302 }
1303 COMPREHENSION_PRIORITY: Final = 20
1304 COMMA_PRIORITY: Final = 18
1305 TERNARY_PRIORITY: Final = 16
1306 LOGIC_PRIORITY: Final = 14
1307 STRING_PRIORITY: Final = 12
1308 COMPARATOR_PRIORITY: Final = 10
1309 MATH_PRIORITIES: Final = {
1310     token.VBAR: 9,
1311     token.CIRCUMFLEX: 8,
1312     token.AMPER: 7,
1313     token.LEFTSHIFT: 6,
1314     token.RIGHTSHIFT: 6,
1315     token.PLUS: 5,
1316     token.MINUS: 5,
1317     token.STAR: 4,
1318     token.SLASH: 4,
1319     token.DOUBLESLASH: 4,
1320     token.PERCENT: 4,
1321     token.AT: 4,
1322     token.TILDE: 3,
1323     token.DOUBLESTAR: 2,
1324 }
1325 DOT_PRIORITY: Final = 1
1326
1327
1328 @dataclass
1329 class BracketTracker:
1330     """Keeps track of brackets on a line."""
1331
1332     depth: int = 0
1333     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1334     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1335     previous: Optional[Leaf] = None
1336     _for_loop_depths: List[int] = field(default_factory=list)
1337     _lambda_argument_depths: List[int] = field(default_factory=list)
1338     invisible: List[Leaf] = field(default_factory=list)
1339
1340     def mark(self, leaf: Leaf) -> None:
1341         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1342
1343         All leaves receive an int `bracket_depth` field that stores how deep
1344         within brackets a given leaf is. 0 means there are no enclosing brackets
1345         that started on this line.
1346
1347         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1348         field that it forms a pair with. This is a one-directional link to
1349         avoid reference cycles.
1350
1351         If a leaf is a delimiter (a token on which Black can split the line if
1352         needed) and it's on depth 0, its `id()` is stored in the tracker's
1353         `delimiters` field.
1354         """
1355         if leaf.type == token.COMMENT:
1356             return
1357
1358         self.maybe_decrement_after_for_loop_variable(leaf)
1359         self.maybe_decrement_after_lambda_arguments(leaf)
1360         if leaf.type in CLOSING_BRACKETS:
1361             self.depth -= 1
1362             try:
1363                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1364             except KeyError as e:
1365                 raise BracketMatchError(
1366                     "Unable to match a closing bracket to the following opening"
1367                     f" bracket: {leaf}"
1368                 ) from e
1369             leaf.opening_bracket = opening_bracket
1370             if not leaf.value:
1371                 self.invisible.append(leaf)
1372         leaf.bracket_depth = self.depth
1373         if self.depth == 0:
1374             delim = is_split_before_delimiter(leaf, self.previous)
1375             if delim and self.previous is not None:
1376                 self.delimiters[id(self.previous)] = delim
1377             else:
1378                 delim = is_split_after_delimiter(leaf, self.previous)
1379                 if delim:
1380                     self.delimiters[id(leaf)] = delim
1381         if leaf.type in OPENING_BRACKETS:
1382             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1383             self.depth += 1
1384             if not leaf.value:
1385                 self.invisible.append(leaf)
1386         self.previous = leaf
1387         self.maybe_increment_lambda_arguments(leaf)
1388         self.maybe_increment_for_loop_variable(leaf)
1389
1390     def any_open_brackets(self) -> bool:
1391         """Return True if there is an yet unmatched open bracket on the line."""
1392         return bool(self.bracket_match)
1393
1394     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1395         """Return the highest priority of a delimiter found on the line.
1396
1397         Values are consistent with what `is_split_*_delimiter()` return.
1398         Raises ValueError on no delimiters.
1399         """
1400         return max(v for k, v in self.delimiters.items() if k not in exclude)
1401
1402     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1403         """Return the number of delimiters with the given `priority`.
1404
1405         If no `priority` is passed, defaults to max priority on the line.
1406         """
1407         if not self.delimiters:
1408             return 0
1409
1410         priority = priority or self.max_delimiter_priority()
1411         return sum(1 for p in self.delimiters.values() if p == priority)
1412
1413     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1414         """In a for loop, or comprehension, the variables are often unpacks.
1415
1416         To avoid splitting on the comma in this situation, increase the depth of
1417         tokens between `for` and `in`.
1418         """
1419         if leaf.type == token.NAME and leaf.value == "for":
1420             self.depth += 1
1421             self._for_loop_depths.append(self.depth)
1422             return True
1423
1424         return False
1425
1426     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1427         """See `maybe_increment_for_loop_variable` above for explanation."""
1428         if (
1429             self._for_loop_depths
1430             and self._for_loop_depths[-1] == self.depth
1431             and leaf.type == token.NAME
1432             and leaf.value == "in"
1433         ):
1434             self.depth -= 1
1435             self._for_loop_depths.pop()
1436             return True
1437
1438         return False
1439
1440     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1441         """In a lambda expression, there might be more than one argument.
1442
1443         To avoid splitting on the comma in this situation, increase the depth of
1444         tokens between `lambda` and `:`.
1445         """
1446         if leaf.type == token.NAME and leaf.value == "lambda":
1447             self.depth += 1
1448             self._lambda_argument_depths.append(self.depth)
1449             return True
1450
1451         return False
1452
1453     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1454         """See `maybe_increment_lambda_arguments` above for explanation."""
1455         if (
1456             self._lambda_argument_depths
1457             and self._lambda_argument_depths[-1] == self.depth
1458             and leaf.type == token.COLON
1459         ):
1460             self.depth -= 1
1461             self._lambda_argument_depths.pop()
1462             return True
1463
1464         return False
1465
1466     def get_open_lsqb(self) -> Optional[Leaf]:
1467         """Return the most recent opening square bracket (if any)."""
1468         return self.bracket_match.get((self.depth - 1, token.RSQB))
1469
1470
1471 @dataclass
1472 class Line:
1473     """Holds leaves and comments. Can be printed with `str(line)`."""
1474
1475     mode: Mode
1476     depth: int = 0
1477     leaves: List[Leaf] = field(default_factory=list)
1478     # keys ordered like `leaves`
1479     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1480     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1481     inside_brackets: bool = False
1482     should_explode: bool = False
1483
1484     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1485         """Add a new `leaf` to the end of the line.
1486
1487         Unless `preformatted` is True, the `leaf` will receive a new consistent
1488         whitespace prefix and metadata applied by :class:`BracketTracker`.
1489         Trailing commas are maybe removed, unpacked for loop variables are
1490         demoted from being delimiters.
1491
1492         Inline comments are put aside.
1493         """
1494         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1495         if not has_value:
1496             return
1497
1498         if token.COLON == leaf.type and self.is_class_paren_empty:
1499             del self.leaves[-2:]
1500         if self.leaves and not preformatted:
1501             # Note: at this point leaf.prefix should be empty except for
1502             # imports, for which we only preserve newlines.
1503             leaf.prefix += whitespace(
1504                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1505             )
1506         if self.inside_brackets or not preformatted:
1507             self.bracket_tracker.mark(leaf)
1508             if self.mode.magic_trailing_comma:
1509                 if self.has_magic_trailing_comma(leaf):
1510                     self.should_explode = True
1511             elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
1512                 self.remove_trailing_comma()
1513         if not self.append_comment(leaf):
1514             self.leaves.append(leaf)
1515
1516     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1517         """Like :func:`append()` but disallow invalid standalone comment structure.
1518
1519         Raises ValueError when any `leaf` is appended after a standalone comment
1520         or when a standalone comment is not the first leaf on the line.
1521         """
1522         if self.bracket_tracker.depth == 0:
1523             if self.is_comment:
1524                 raise ValueError("cannot append to standalone comments")
1525
1526             if self.leaves and leaf.type == STANDALONE_COMMENT:
1527                 raise ValueError(
1528                     "cannot append standalone comments to a populated line"
1529                 )
1530
1531         self.append(leaf, preformatted=preformatted)
1532
1533     @property
1534     def is_comment(self) -> bool:
1535         """Is this line a standalone comment?"""
1536         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1537
1538     @property
1539     def is_decorator(self) -> bool:
1540         """Is this line a decorator?"""
1541         return bool(self) and self.leaves[0].type == token.AT
1542
1543     @property
1544     def is_import(self) -> bool:
1545         """Is this an import line?"""
1546         return bool(self) and is_import(self.leaves[0])
1547
1548     @property
1549     def is_class(self) -> bool:
1550         """Is this line a class definition?"""
1551         return (
1552             bool(self)
1553             and self.leaves[0].type == token.NAME
1554             and self.leaves[0].value == "class"
1555         )
1556
1557     @property
1558     def is_stub_class(self) -> bool:
1559         """Is this line a class definition with a body consisting only of "..."?"""
1560         return self.is_class and self.leaves[-3:] == [
1561             Leaf(token.DOT, ".") for _ in range(3)
1562         ]
1563
1564     @property
1565     def is_def(self) -> bool:
1566         """Is this a function definition? (Also returns True for async defs.)"""
1567         try:
1568             first_leaf = self.leaves[0]
1569         except IndexError:
1570             return False
1571
1572         try:
1573             second_leaf: Optional[Leaf] = self.leaves[1]
1574         except IndexError:
1575             second_leaf = None
1576         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1577             first_leaf.type == token.ASYNC
1578             and second_leaf is not None
1579             and second_leaf.type == token.NAME
1580             and second_leaf.value == "def"
1581         )
1582
1583     @property
1584     def is_class_paren_empty(self) -> bool:
1585         """Is this a class with no base classes but using parentheses?
1586
1587         Those are unnecessary and should be removed.
1588         """
1589         return (
1590             bool(self)
1591             and len(self.leaves) == 4
1592             and self.is_class
1593             and self.leaves[2].type == token.LPAR
1594             and self.leaves[2].value == "("
1595             and self.leaves[3].type == token.RPAR
1596             and self.leaves[3].value == ")"
1597         )
1598
1599     @property
1600     def is_triple_quoted_string(self) -> bool:
1601         """Is the line a triple quoted string?"""
1602         return (
1603             bool(self)
1604             and self.leaves[0].type == token.STRING
1605             and self.leaves[0].value.startswith(('"""', "'''"))
1606         )
1607
1608     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1609         """If so, needs to be split before emitting."""
1610         for leaf in self.leaves:
1611             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1612                 return True
1613
1614         return False
1615
1616     def contains_uncollapsable_type_comments(self) -> bool:
1617         ignored_ids = set()
1618         try:
1619             last_leaf = self.leaves[-1]
1620             ignored_ids.add(id(last_leaf))
1621             if last_leaf.type == token.COMMA or (
1622                 last_leaf.type == token.RPAR and not last_leaf.value
1623             ):
1624                 # When trailing commas or optional parens are inserted by Black for
1625                 # consistency, comments after the previous last element are not moved
1626                 # (they don't have to, rendering will still be correct).  So we ignore
1627                 # trailing commas and invisible.
1628                 last_leaf = self.leaves[-2]
1629                 ignored_ids.add(id(last_leaf))
1630         except IndexError:
1631             return False
1632
1633         # A type comment is uncollapsable if it is attached to a leaf
1634         # that isn't at the end of the line (since that could cause it
1635         # to get associated to a different argument) or if there are
1636         # comments before it (since that could cause it to get hidden
1637         # behind a comment.
1638         comment_seen = False
1639         for leaf_id, comments in self.comments.items():
1640             for comment in comments:
1641                 if is_type_comment(comment):
1642                     if comment_seen or (
1643                         not is_type_comment(comment, " ignore")
1644                         and leaf_id not in ignored_ids
1645                     ):
1646                         return True
1647
1648                 comment_seen = True
1649
1650         return False
1651
1652     def contains_unsplittable_type_ignore(self) -> bool:
1653         if not self.leaves:
1654             return False
1655
1656         # If a 'type: ignore' is attached to the end of a line, we
1657         # can't split the line, because we can't know which of the
1658         # subexpressions the ignore was meant to apply to.
1659         #
1660         # We only want this to apply to actual physical lines from the
1661         # original source, though: we don't want the presence of a
1662         # 'type: ignore' at the end of a multiline expression to
1663         # justify pushing it all onto one line. Thus we
1664         # (unfortunately) need to check the actual source lines and
1665         # only report an unsplittable 'type: ignore' if this line was
1666         # one line in the original code.
1667
1668         # Grab the first and last line numbers, skipping generated leaves
1669         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1670         last_line = next(
1671             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1672         )
1673
1674         if first_line == last_line:
1675             # We look at the last two leaves since a comma or an
1676             # invisible paren could have been added at the end of the
1677             # line.
1678             for node in self.leaves[-2:]:
1679                 for comment in self.comments.get(id(node), []):
1680                     if is_type_comment(comment, " ignore"):
1681                         return True
1682
1683         return False
1684
1685     def contains_multiline_strings(self) -> bool:
1686         return any(is_multiline_string(leaf) for leaf in self.leaves)
1687
1688     def has_magic_trailing_comma(
1689         self, closing: Leaf, ensure_removable: bool = False
1690     ) -> bool:
1691         """Return True if we have a magic trailing comma, that is when:
1692         - there's a trailing comma here
1693         - it's not a one-tuple
1694         Additionally, if ensure_removable:
1695         - it's not from square bracket indexing
1696         """
1697         if not (
1698             closing.type in CLOSING_BRACKETS
1699             and self.leaves
1700             and self.leaves[-1].type == token.COMMA
1701         ):
1702             return False
1703
1704         if closing.type == token.RBRACE:
1705             return True
1706
1707         if closing.type == token.RSQB:
1708             if not ensure_removable:
1709                 return True
1710             comma = self.leaves[-1]
1711             return bool(comma.parent and comma.parent.type == syms.listmaker)
1712
1713         if self.is_import:
1714             return True
1715
1716         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1717             return True
1718
1719         return False
1720
1721     def append_comment(self, comment: Leaf) -> bool:
1722         """Add an inline or standalone comment to the line."""
1723         if (
1724             comment.type == STANDALONE_COMMENT
1725             and self.bracket_tracker.any_open_brackets()
1726         ):
1727             comment.prefix = ""
1728             return False
1729
1730         if comment.type != token.COMMENT:
1731             return False
1732
1733         if not self.leaves:
1734             comment.type = STANDALONE_COMMENT
1735             comment.prefix = ""
1736             return False
1737
1738         last_leaf = self.leaves[-1]
1739         if (
1740             last_leaf.type == token.RPAR
1741             and not last_leaf.value
1742             and last_leaf.parent
1743             and len(list(last_leaf.parent.leaves())) <= 3
1744             and not is_type_comment(comment)
1745         ):
1746             # Comments on an optional parens wrapping a single leaf should belong to
1747             # the wrapped node except if it's a type comment. Pinning the comment like
1748             # this avoids unstable formatting caused by comment migration.
1749             if len(self.leaves) < 2:
1750                 comment.type = STANDALONE_COMMENT
1751                 comment.prefix = ""
1752                 return False
1753
1754             last_leaf = self.leaves[-2]
1755         self.comments.setdefault(id(last_leaf), []).append(comment)
1756         return True
1757
1758     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1759         """Generate comments that should appear directly after `leaf`."""
1760         return self.comments.get(id(leaf), [])
1761
1762     def remove_trailing_comma(self) -> None:
1763         """Remove the trailing comma and moves the comments attached to it."""
1764         trailing_comma = self.leaves.pop()
1765         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1766         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1767             trailing_comma_comments
1768         )
1769
1770     def is_complex_subscript(self, leaf: Leaf) -> bool:
1771         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1772         open_lsqb = self.bracket_tracker.get_open_lsqb()
1773         if open_lsqb is None:
1774             return False
1775
1776         subscript_start = open_lsqb.next_sibling
1777
1778         if isinstance(subscript_start, Node):
1779             if subscript_start.type == syms.listmaker:
1780                 return False
1781
1782             if subscript_start.type == syms.subscriptlist:
1783                 subscript_start = child_towards(subscript_start, leaf)
1784         return subscript_start is not None and any(
1785             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1786         )
1787
1788     def clone(self) -> "Line":
1789         return Line(
1790             mode=self.mode,
1791             depth=self.depth,
1792             inside_brackets=self.inside_brackets,
1793             should_explode=self.should_explode,
1794         )
1795
1796     def __str__(self) -> str:
1797         """Render the line."""
1798         if not self:
1799             return "\n"
1800
1801         indent = "    " * self.depth
1802         leaves = iter(self.leaves)
1803         first = next(leaves)
1804         res = f"{first.prefix}{indent}{first.value}"
1805         for leaf in leaves:
1806             res += str(leaf)
1807         for comment in itertools.chain.from_iterable(self.comments.values()):
1808             res += str(comment)
1809
1810         return res + "\n"
1811
1812     def __bool__(self) -> bool:
1813         """Return True if the line has leaves or comments."""
1814         return bool(self.leaves or self.comments)
1815
1816
1817 @dataclass
1818 class EmptyLineTracker:
1819     """Provides a stateful method that returns the number of potential extra
1820     empty lines needed before and after the currently processed line.
1821
1822     Note: this tracker works on lines that haven't been split yet.  It assumes
1823     the prefix of the first leaf consists of optional newlines.  Those newlines
1824     are consumed by `maybe_empty_lines()` and included in the computation.
1825     """
1826
1827     is_pyi: bool = False
1828     previous_line: Optional[Line] = None
1829     previous_after: int = 0
1830     previous_defs: List[int] = field(default_factory=list)
1831
1832     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1833         """Return the number of extra empty lines before and after the `current_line`.
1834
1835         This is for separating `def`, `async def` and `class` with extra empty
1836         lines (two on module-level).
1837         """
1838         before, after = self._maybe_empty_lines(current_line)
1839         before = (
1840             # Black should not insert empty lines at the beginning
1841             # of the file
1842             0
1843             if self.previous_line is None
1844             else before - self.previous_after
1845         )
1846         self.previous_after = after
1847         self.previous_line = current_line
1848         return before, after
1849
1850     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1851         max_allowed = 1
1852         if current_line.depth == 0:
1853             max_allowed = 1 if self.is_pyi else 2
1854         if current_line.leaves:
1855             # Consume the first leaf's extra newlines.
1856             first_leaf = current_line.leaves[0]
1857             before = first_leaf.prefix.count("\n")
1858             before = min(before, max_allowed)
1859             first_leaf.prefix = ""
1860         else:
1861             before = 0
1862         depth = current_line.depth
1863         while self.previous_defs and self.previous_defs[-1] >= depth:
1864             self.previous_defs.pop()
1865             if self.is_pyi:
1866                 before = 0 if depth else 1
1867             else:
1868                 before = 1 if depth else 2
1869         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1870             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1871
1872         if (
1873             self.previous_line
1874             and self.previous_line.is_import
1875             and not current_line.is_import
1876             and depth == self.previous_line.depth
1877         ):
1878             return (before or 1), 0
1879
1880         if (
1881             self.previous_line
1882             and self.previous_line.is_class
1883             and current_line.is_triple_quoted_string
1884         ):
1885             return before, 1
1886
1887         return before, 0
1888
1889     def _maybe_empty_lines_for_class_or_def(
1890         self, current_line: Line, before: int
1891     ) -> Tuple[int, int]:
1892         if not current_line.is_decorator:
1893             self.previous_defs.append(current_line.depth)
1894         if self.previous_line is None:
1895             # Don't insert empty lines before the first line in the file.
1896             return 0, 0
1897
1898         if self.previous_line.is_decorator:
1899             if self.is_pyi and current_line.is_stub_class:
1900                 # Insert an empty line after a decorated stub class
1901                 return 0, 1
1902
1903             return 0, 0
1904
1905         if self.previous_line.depth < current_line.depth and (
1906             self.previous_line.is_class or self.previous_line.is_def
1907         ):
1908             return 0, 0
1909
1910         if (
1911             self.previous_line.is_comment
1912             and self.previous_line.depth == current_line.depth
1913             and before == 0
1914         ):
1915             return 0, 0
1916
1917         if self.is_pyi:
1918             if self.previous_line.depth > current_line.depth:
1919                 newlines = 1
1920             elif current_line.is_class or self.previous_line.is_class:
1921                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1922                     # No blank line between classes with an empty body
1923                     newlines = 0
1924                 else:
1925                     newlines = 1
1926             elif (
1927                 current_line.is_def or current_line.is_decorator
1928             ) and not self.previous_line.is_def:
1929                 # Blank line between a block of functions (maybe with preceding
1930                 # decorators) and a block of non-functions
1931                 newlines = 1
1932             else:
1933                 newlines = 0
1934         else:
1935             newlines = 2
1936         if current_line.depth and newlines:
1937             newlines -= 1
1938         return newlines, 0
1939
1940
1941 @dataclass
1942 class LineGenerator(Visitor[Line]):
1943     """Generates reformatted Line objects.  Empty lines are not emitted.
1944
1945     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1946     in ways that will no longer stringify to valid Python code on the tree.
1947     """
1948
1949     mode: Mode
1950     remove_u_prefix: bool = False
1951     current_line: Line = field(init=False)
1952
1953     def line(self, indent: int = 0) -> Iterator[Line]:
1954         """Generate a line.
1955
1956         If the line is empty, only emit if it makes sense.
1957         If the line is too long, split it first and then generate.
1958
1959         If any lines were generated, set up a new current_line.
1960         """
1961         if not self.current_line:
1962             self.current_line.depth += indent
1963             return  # Line is empty, don't emit. Creating a new one unnecessary.
1964
1965         complete_line = self.current_line
1966         self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
1967         yield complete_line
1968
1969     def visit_default(self, node: LN) -> Iterator[Line]:
1970         """Default `visit_*()` implementation. Recurses to children of `node`."""
1971         if isinstance(node, Leaf):
1972             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1973             for comment in generate_comments(node):
1974                 if any_open_brackets:
1975                     # any comment within brackets is subject to splitting
1976                     self.current_line.append(comment)
1977                 elif comment.type == token.COMMENT:
1978                     # regular trailing comment
1979                     self.current_line.append(comment)
1980                     yield from self.line()
1981
1982                 else:
1983                     # regular standalone comment
1984                     yield from self.line()
1985
1986                     self.current_line.append(comment)
1987                     yield from self.line()
1988
1989             normalize_prefix(node, inside_brackets=any_open_brackets)
1990             if self.mode.string_normalization and node.type == token.STRING:
1991                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1992                 normalize_string_quotes(node)
1993             if node.type == token.NUMBER:
1994                 normalize_numeric_literal(node)
1995             if node.type not in WHITESPACE:
1996                 self.current_line.append(node)
1997         yield from super().visit_default(node)
1998
1999     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
2000         """Increase indentation level, maybe yield a line."""
2001         # In blib2to3 INDENT never holds comments.
2002         yield from self.line(+1)
2003         yield from self.visit_default(node)
2004
2005     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
2006         """Decrease indentation level, maybe yield a line."""
2007         # The current line might still wait for trailing comments.  At DEDENT time
2008         # there won't be any (they would be prefixes on the preceding NEWLINE).
2009         # Emit the line then.
2010         yield from self.line()
2011
2012         # While DEDENT has no value, its prefix may contain standalone comments
2013         # that belong to the current indentation level.  Get 'em.
2014         yield from self.visit_default(node)
2015
2016         # Finally, emit the dedent.
2017         yield from self.line(-1)
2018
2019     def visit_stmt(
2020         self, node: Node, keywords: Set[str], parens: Set[str]
2021     ) -> Iterator[Line]:
2022         """Visit a statement.
2023
2024         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2025         `def`, `with`, `class`, `assert` and assignments.
2026
2027         The relevant Python language `keywords` for a given statement will be
2028         NAME leaves within it. This methods puts those on a separate line.
2029
2030         `parens` holds a set of string leaf values immediately after which
2031         invisible parens should be put.
2032         """
2033         normalize_invisible_parens(node, parens_after=parens)
2034         for child in node.children:
2035             if child.type == token.NAME and child.value in keywords:  # type: ignore
2036                 yield from self.line()
2037
2038             yield from self.visit(child)
2039
2040     def visit_suite(self, node: Node) -> Iterator[Line]:
2041         """Visit a suite."""
2042         if self.mode.is_pyi and is_stub_suite(node):
2043             yield from self.visit(node.children[2])
2044         else:
2045             yield from self.visit_default(node)
2046
2047     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2048         """Visit a statement without nested statements."""
2049         is_suite_like = node.parent and node.parent.type in STATEMENT
2050         if is_suite_like:
2051             if self.mode.is_pyi and is_stub_body(node):
2052                 yield from self.visit_default(node)
2053             else:
2054                 yield from self.line(+1)
2055                 yield from self.visit_default(node)
2056                 yield from self.line(-1)
2057
2058         else:
2059             if (
2060                 not self.mode.is_pyi
2061                 or not node.parent
2062                 or not is_stub_suite(node.parent)
2063             ):
2064                 yield from self.line()
2065             yield from self.visit_default(node)
2066
2067     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2068         """Visit `async def`, `async for`, `async with`."""
2069         yield from self.line()
2070
2071         children = iter(node.children)
2072         for child in children:
2073             yield from self.visit(child)
2074
2075             if child.type == token.ASYNC:
2076                 break
2077
2078         internal_stmt = next(children)
2079         for child in internal_stmt.children:
2080             yield from self.visit(child)
2081
2082     def visit_decorators(self, node: Node) -> Iterator[Line]:
2083         """Visit decorators."""
2084         for child in node.children:
2085             yield from self.line()
2086             yield from self.visit(child)
2087
2088     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2089         """Remove a semicolon and put the other statement on a separate line."""
2090         yield from self.line()
2091
2092     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2093         """End of file. Process outstanding comments and end with a newline."""
2094         yield from self.visit_default(leaf)
2095         yield from self.line()
2096
2097     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2098         if not self.current_line.bracket_tracker.any_open_brackets():
2099             yield from self.line()
2100         yield from self.visit_default(leaf)
2101
2102     def visit_factor(self, node: Node) -> Iterator[Line]:
2103         """Force parentheses between a unary op and a binary power:
2104
2105         -2 ** 8 -> -(2 ** 8)
2106         """
2107         _operator, operand = node.children
2108         if (
2109             operand.type == syms.power
2110             and len(operand.children) == 3
2111             and operand.children[1].type == token.DOUBLESTAR
2112         ):
2113             lpar = Leaf(token.LPAR, "(")
2114             rpar = Leaf(token.RPAR, ")")
2115             index = operand.remove() or 0
2116             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2117         yield from self.visit_default(node)
2118
2119     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2120         if is_docstring(leaf) and "\\\n" not in leaf.value:
2121             # We're ignoring docstrings with backslash newline escapes because changing
2122             # indentation of those changes the AST representation of the code.
2123             prefix = get_string_prefix(leaf.value)
2124             lead_len = len(prefix) + 3
2125             tail_len = -3
2126             indent = " " * 4 * self.current_line.depth
2127             docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
2128             if docstring:
2129                 if leaf.value[lead_len - 1] == docstring[0]:
2130                     docstring = " " + docstring
2131                 if leaf.value[tail_len + 1] == docstring[-1]:
2132                     docstring = docstring + " "
2133             leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
2134
2135         yield from self.visit_default(leaf)
2136
2137     def __post_init__(self) -> None:
2138         """You are in a twisty little maze of passages."""
2139         self.current_line = Line(mode=self.mode)
2140
2141         v = self.visit_stmt
2142         Ø: Set[str] = set()
2143         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2144         self.visit_if_stmt = partial(
2145             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2146         )
2147         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2148         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2149         self.visit_try_stmt = partial(
2150             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2151         )
2152         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2153         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2154         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2155         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2156         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2157         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2158         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2159         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2160         self.visit_async_funcdef = self.visit_async_stmt
2161         self.visit_decorated = self.visit_decorators
2162
2163
2164 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2165 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2166 OPENING_BRACKETS = set(BRACKET.keys())
2167 CLOSING_BRACKETS = set(BRACKET.values())
2168 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2169 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2170
2171
2172 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2173     """Return whitespace prefix if needed for the given `leaf`.
2174
2175     `complex_subscript` signals whether the given leaf is part of a subscription
2176     which has non-trivial arguments, like arithmetic expressions or function calls.
2177     """
2178     NO = ""
2179     SPACE = " "
2180     DOUBLESPACE = "  "
2181     t = leaf.type
2182     p = leaf.parent
2183     v = leaf.value
2184     if t in ALWAYS_NO_SPACE:
2185         return NO
2186
2187     if t == token.COMMENT:
2188         return DOUBLESPACE
2189
2190     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2191     if t == token.COLON and p.type not in {
2192         syms.subscript,
2193         syms.subscriptlist,
2194         syms.sliceop,
2195     }:
2196         return NO
2197
2198     prev = leaf.prev_sibling
2199     if not prev:
2200         prevp = preceding_leaf(p)
2201         if not prevp or prevp.type in OPENING_BRACKETS:
2202             return NO
2203
2204         if t == token.COLON:
2205             if prevp.type == token.COLON:
2206                 return NO
2207
2208             elif prevp.type != token.COMMA and not complex_subscript:
2209                 return NO
2210
2211             return SPACE
2212
2213         if prevp.type == token.EQUAL:
2214             if prevp.parent:
2215                 if prevp.parent.type in {
2216                     syms.arglist,
2217                     syms.argument,
2218                     syms.parameters,
2219                     syms.varargslist,
2220                 }:
2221                     return NO
2222
2223                 elif prevp.parent.type == syms.typedargslist:
2224                     # A bit hacky: if the equal sign has whitespace, it means we
2225                     # previously found it's a typed argument.  So, we're using
2226                     # that, too.
2227                     return prevp.prefix
2228
2229         elif prevp.type in VARARGS_SPECIALS:
2230             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2231                 return NO
2232
2233         elif prevp.type == token.COLON:
2234             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2235                 return SPACE if complex_subscript else NO
2236
2237         elif (
2238             prevp.parent
2239             and prevp.parent.type == syms.factor
2240             and prevp.type in MATH_OPERATORS
2241         ):
2242             return NO
2243
2244         elif (
2245             prevp.type == token.RIGHTSHIFT
2246             and prevp.parent
2247             and prevp.parent.type == syms.shift_expr
2248             and prevp.prev_sibling
2249             and prevp.prev_sibling.type == token.NAME
2250             and prevp.prev_sibling.value == "print"  # type: ignore
2251         ):
2252             # Python 2 print chevron
2253             return NO
2254         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2255             # no space in decorators
2256             return NO
2257
2258     elif prev.type in OPENING_BRACKETS:
2259         return NO
2260
2261     if p.type in {syms.parameters, syms.arglist}:
2262         # untyped function signatures or calls
2263         if not prev or prev.type != token.COMMA:
2264             return NO
2265
2266     elif p.type == syms.varargslist:
2267         # lambdas
2268         if prev and prev.type != token.COMMA:
2269             return NO
2270
2271     elif p.type == syms.typedargslist:
2272         # typed function signatures
2273         if not prev:
2274             return NO
2275
2276         if t == token.EQUAL:
2277             if prev.type != syms.tname:
2278                 return NO
2279
2280         elif prev.type == token.EQUAL:
2281             # A bit hacky: if the equal sign has whitespace, it means we
2282             # previously found it's a typed argument.  So, we're using that, too.
2283             return prev.prefix
2284
2285         elif prev.type != token.COMMA:
2286             return NO
2287
2288     elif p.type == syms.tname:
2289         # type names
2290         if not prev:
2291             prevp = preceding_leaf(p)
2292             if not prevp or prevp.type != token.COMMA:
2293                 return NO
2294
2295     elif p.type == syms.trailer:
2296         # attributes and calls
2297         if t == token.LPAR or t == token.RPAR:
2298             return NO
2299
2300         if not prev:
2301             if t == token.DOT:
2302                 prevp = preceding_leaf(p)
2303                 if not prevp or prevp.type != token.NUMBER:
2304                     return NO
2305
2306             elif t == token.LSQB:
2307                 return NO
2308
2309         elif prev.type != token.COMMA:
2310             return NO
2311
2312     elif p.type == syms.argument:
2313         # single argument
2314         if t == token.EQUAL:
2315             return NO
2316
2317         if not prev:
2318             prevp = preceding_leaf(p)
2319             if not prevp or prevp.type == token.LPAR:
2320                 return NO
2321
2322         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2323             return NO
2324
2325     elif p.type == syms.decorator:
2326         # decorators
2327         return NO
2328
2329     elif p.type == syms.dotted_name:
2330         if prev:
2331             return NO
2332
2333         prevp = preceding_leaf(p)
2334         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2335             return NO
2336
2337     elif p.type == syms.classdef:
2338         if t == token.LPAR:
2339             return NO
2340
2341         if prev and prev.type == token.LPAR:
2342             return NO
2343
2344     elif p.type in {syms.subscript, syms.sliceop}:
2345         # indexing
2346         if not prev:
2347             assert p.parent is not None, "subscripts are always parented"
2348             if p.parent.type == syms.subscriptlist:
2349                 return SPACE
2350
2351             return NO
2352
2353         elif not complex_subscript:
2354             return NO
2355
2356     elif p.type == syms.atom:
2357         if prev and t == token.DOT:
2358             # dots, but not the first one.
2359             return NO
2360
2361     elif p.type == syms.dictsetmaker:
2362         # dict unpacking
2363         if prev and prev.type == token.DOUBLESTAR:
2364             return NO
2365
2366     elif p.type in {syms.factor, syms.star_expr}:
2367         # unary ops
2368         if not prev:
2369             prevp = preceding_leaf(p)
2370             if not prevp or prevp.type in OPENING_BRACKETS:
2371                 return NO
2372
2373             prevp_parent = prevp.parent
2374             assert prevp_parent is not None
2375             if prevp.type == token.COLON and prevp_parent.type in {
2376                 syms.subscript,
2377                 syms.sliceop,
2378             }:
2379                 return NO
2380
2381             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2382                 return NO
2383
2384         elif t in {token.NAME, token.NUMBER, token.STRING}:
2385             return NO
2386
2387     elif p.type == syms.import_from:
2388         if t == token.DOT:
2389             if prev and prev.type == token.DOT:
2390                 return NO
2391
2392         elif t == token.NAME:
2393             if v == "import":
2394                 return SPACE
2395
2396             if prev and prev.type == token.DOT:
2397                 return NO
2398
2399     elif p.type == syms.sliceop:
2400         return NO
2401
2402     return SPACE
2403
2404
2405 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2406     """Return the first leaf that precedes `node`, if any."""
2407     while node:
2408         res = node.prev_sibling
2409         if res:
2410             if isinstance(res, Leaf):
2411                 return res
2412
2413             try:
2414                 return list(res.leaves())[-1]
2415
2416             except IndexError:
2417                 return None
2418
2419         node = node.parent
2420     return None
2421
2422
2423 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2424     """Return if the `node` and its previous siblings match types against the provided
2425     list of tokens; the provided `node`has its type matched against the last element in
2426     the list.  `None` can be used as the first element to declare that the start of the
2427     list is anchored at the start of its parent's children."""
2428     if not tokens:
2429         return True
2430     if tokens[-1] is None:
2431         return node is None
2432     if not node:
2433         return False
2434     if node.type != tokens[-1]:
2435         return False
2436     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2437
2438
2439 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2440     """Return the child of `ancestor` that contains `descendant`."""
2441     node: Optional[LN] = descendant
2442     while node and node.parent != ancestor:
2443         node = node.parent
2444     return node
2445
2446
2447 def container_of(leaf: Leaf) -> LN:
2448     """Return `leaf` or one of its ancestors that is the topmost container of it.
2449
2450     By "container" we mean a node where `leaf` is the very first child.
2451     """
2452     same_prefix = leaf.prefix
2453     container: LN = leaf
2454     while container:
2455         parent = container.parent
2456         if parent is None:
2457             break
2458
2459         if parent.children[0].prefix != same_prefix:
2460             break
2461
2462         if parent.type == syms.file_input:
2463             break
2464
2465         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2466             break
2467
2468         container = parent
2469     return container
2470
2471
2472 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2473     """Return the priority of the `leaf` delimiter, given a line break after it.
2474
2475     The delimiter priorities returned here are from those delimiters that would
2476     cause a line break after themselves.
2477
2478     Higher numbers are higher priority.
2479     """
2480     if leaf.type == token.COMMA:
2481         return COMMA_PRIORITY
2482
2483     return 0
2484
2485
2486 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2487     """Return the priority of the `leaf` delimiter, given a line break before it.
2488
2489     The delimiter priorities returned here are from those delimiters that would
2490     cause a line break before themselves.
2491
2492     Higher numbers are higher priority.
2493     """
2494     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2495         # * and ** might also be MATH_OPERATORS but in this case they are not.
2496         # Don't treat them as a delimiter.
2497         return 0
2498
2499     if (
2500         leaf.type == token.DOT
2501         and leaf.parent
2502         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2503         and (previous is None or previous.type in CLOSING_BRACKETS)
2504     ):
2505         return DOT_PRIORITY
2506
2507     if (
2508         leaf.type in MATH_OPERATORS
2509         and leaf.parent
2510         and leaf.parent.type not in {syms.factor, syms.star_expr}
2511     ):
2512         return MATH_PRIORITIES[leaf.type]
2513
2514     if leaf.type in COMPARATORS:
2515         return COMPARATOR_PRIORITY
2516
2517     if (
2518         leaf.type == token.STRING
2519         and previous is not None
2520         and previous.type == token.STRING
2521     ):
2522         return STRING_PRIORITY
2523
2524     if leaf.type not in {token.NAME, token.ASYNC}:
2525         return 0
2526
2527     if (
2528         leaf.value == "for"
2529         and leaf.parent
2530         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2531         or leaf.type == token.ASYNC
2532     ):
2533         if (
2534             not isinstance(leaf.prev_sibling, Leaf)
2535             or leaf.prev_sibling.value != "async"
2536         ):
2537             return COMPREHENSION_PRIORITY
2538
2539     if (
2540         leaf.value == "if"
2541         and leaf.parent
2542         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2543     ):
2544         return COMPREHENSION_PRIORITY
2545
2546     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2547         return TERNARY_PRIORITY
2548
2549     if leaf.value == "is":
2550         return COMPARATOR_PRIORITY
2551
2552     if (
2553         leaf.value == "in"
2554         and leaf.parent
2555         and leaf.parent.type in {syms.comp_op, syms.comparison}
2556         and not (
2557             previous is not None
2558             and previous.type == token.NAME
2559             and previous.value == "not"
2560         )
2561     ):
2562         return COMPARATOR_PRIORITY
2563
2564     if (
2565         leaf.value == "not"
2566         and leaf.parent
2567         and leaf.parent.type == syms.comp_op
2568         and not (
2569             previous is not None
2570             and previous.type == token.NAME
2571             and previous.value == "is"
2572         )
2573     ):
2574         return COMPARATOR_PRIORITY
2575
2576     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2577         return LOGIC_PRIORITY
2578
2579     return 0
2580
2581
2582 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2583 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2584
2585
2586 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2587     """Clean the prefix of the `leaf` and generate comments from it, if any.
2588
2589     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2590     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2591     move because it does away with modifying the grammar to include all the
2592     possible places in which comments can be placed.
2593
2594     The sad consequence for us though is that comments don't "belong" anywhere.
2595     This is why this function generates simple parentless Leaf objects for
2596     comments.  We simply don't know what the correct parent should be.
2597
2598     No matter though, we can live without this.  We really only need to
2599     differentiate between inline and standalone comments.  The latter don't
2600     share the line with any code.
2601
2602     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2603     are emitted with a fake STANDALONE_COMMENT token identifier.
2604     """
2605     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2606         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2607
2608
2609 @dataclass
2610 class ProtoComment:
2611     """Describes a piece of syntax that is a comment.
2612
2613     It's not a :class:`blib2to3.pytree.Leaf` so that:
2614
2615     * it can be cached (`Leaf` objects should not be reused more than once as
2616       they store their lineno, column, prefix, and parent information);
2617     * `newlines` and `consumed` fields are kept separate from the `value`. This
2618       simplifies handling of special marker comments like ``# fmt: off/on``.
2619     """
2620
2621     type: int  # token.COMMENT or STANDALONE_COMMENT
2622     value: str  # content of the comment
2623     newlines: int  # how many newlines before the comment
2624     consumed: int  # how many characters of the original leaf's prefix did we consume
2625
2626
2627 @lru_cache(maxsize=4096)
2628 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2629     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2630     result: List[ProtoComment] = []
2631     if not prefix or "#" not in prefix:
2632         return result
2633
2634     consumed = 0
2635     nlines = 0
2636     ignored_lines = 0
2637     for index, line in enumerate(prefix.split("\n")):
2638         consumed += len(line) + 1  # adding the length of the split '\n'
2639         line = line.lstrip()
2640         if not line:
2641             nlines += 1
2642         if not line.startswith("#"):
2643             # Escaped newlines outside of a comment are not really newlines at
2644             # all. We treat a single-line comment following an escaped newline
2645             # as a simple trailing comment.
2646             if line.endswith("\\"):
2647                 ignored_lines += 1
2648             continue
2649
2650         if index == ignored_lines and not is_endmarker:
2651             comment_type = token.COMMENT  # simple trailing comment
2652         else:
2653             comment_type = STANDALONE_COMMENT
2654         comment = make_comment(line)
2655         result.append(
2656             ProtoComment(
2657                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2658             )
2659         )
2660         nlines = 0
2661     return result
2662
2663
2664 def make_comment(content: str) -> str:
2665     """Return a consistently formatted comment from the given `content` string.
2666
2667     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2668     space between the hash sign and the content.
2669
2670     If `content` didn't start with a hash sign, one is provided.
2671     """
2672     content = content.rstrip()
2673     if not content:
2674         return "#"
2675
2676     if content[0] == "#":
2677         content = content[1:]
2678     if content and content[0] not in " !:#'%":
2679         content = " " + content
2680     return "#" + content
2681
2682
2683 def transform_line(
2684     line: Line, mode: Mode, features: Collection[Feature] = ()
2685 ) -> Iterator[Line]:
2686     """Transform a `line`, potentially splitting it into many lines.
2687
2688     They should fit in the allotted `line_length` but might not be able to.
2689
2690     `features` are syntactical features that may be used in the output.
2691     """
2692     if line.is_comment:
2693         yield line
2694         return
2695
2696     line_str = line_to_string(line)
2697
2698     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2699         """Initialize StringTransformer"""
2700         return ST(mode.line_length, mode.string_normalization)
2701
2702     string_merge = init_st(StringMerger)
2703     string_paren_strip = init_st(StringParenStripper)
2704     string_split = init_st(StringSplitter)
2705     string_paren_wrap = init_st(StringParenWrapper)
2706
2707     transformers: List[Transformer]
2708     if (
2709         not line.contains_uncollapsable_type_comments()
2710         and not line.should_explode
2711         and (
2712             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2713             or line.contains_unsplittable_type_ignore()
2714         )
2715         and not (line.inside_brackets and line.contains_standalone_comments())
2716     ):
2717         # Only apply basic string preprocessing, since lines shouldn't be split here.
2718         if mode.experimental_string_processing:
2719             transformers = [string_merge, string_paren_strip]
2720         else:
2721             transformers = []
2722     elif line.is_def:
2723         transformers = [left_hand_split]
2724     else:
2725
2726         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2727             """Wraps calls to `right_hand_split`.
2728
2729             The calls increasingly `omit` right-hand trailers (bracket pairs with
2730             content), meaning the trailers get glued together to split on another
2731             bracket pair instead.
2732             """
2733             for omit in generate_trailers_to_omit(line, mode.line_length):
2734                 lines = list(
2735                     right_hand_split(line, mode.line_length, features, omit=omit)
2736                 )
2737                 # Note: this check is only able to figure out if the first line of the
2738                 # *current* transformation fits in the line length.  This is true only
2739                 # for simple cases.  All others require running more transforms via
2740                 # `transform_line()`.  This check doesn't know if those would succeed.
2741                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2742                     yield from lines
2743                     return
2744
2745             # All splits failed, best effort split with no omits.
2746             # This mostly happens to multiline strings that are by definition
2747             # reported as not fitting a single line, as well as lines that contain
2748             # trailing commas (those have to be exploded).
2749             yield from right_hand_split(
2750                 line, line_length=mode.line_length, features=features
2751             )
2752
2753         if mode.experimental_string_processing:
2754             if line.inside_brackets:
2755                 transformers = [
2756                     string_merge,
2757                     string_paren_strip,
2758                     string_split,
2759                     delimiter_split,
2760                     standalone_comment_split,
2761                     string_paren_wrap,
2762                     rhs,
2763                 ]
2764             else:
2765                 transformers = [
2766                     string_merge,
2767                     string_paren_strip,
2768                     string_split,
2769                     string_paren_wrap,
2770                     rhs,
2771                 ]
2772         else:
2773             if line.inside_brackets:
2774                 transformers = [delimiter_split, standalone_comment_split, rhs]
2775             else:
2776                 transformers = [rhs]
2777
2778     for transform in transformers:
2779         # We are accumulating lines in `result` because we might want to abort
2780         # mission and return the original line in the end, or attempt a different
2781         # split altogether.
2782         try:
2783             result = run_transformer(line, transform, mode, features, line_str=line_str)
2784         except CannotTransform:
2785             continue
2786         else:
2787             yield from result
2788             break
2789
2790     else:
2791         yield line
2792
2793
2794 @dataclass  # type: ignore
2795 class StringTransformer(ABC):
2796     """
2797     An implementation of the Transformer protocol that relies on its
2798     subclasses overriding the template methods `do_match(...)` and
2799     `do_transform(...)`.
2800
2801     This Transformer works exclusively on strings (for example, by merging
2802     or splitting them).
2803
2804     The following sections can be found among the docstrings of each concrete
2805     StringTransformer subclass.
2806
2807     Requirements:
2808         Which requirements must be met of the given Line for this
2809         StringTransformer to be applied?
2810
2811     Transformations:
2812         If the given Line meets all of the above requirements, which string
2813         transformations can you expect to be applied to it by this
2814         StringTransformer?
2815
2816     Collaborations:
2817         What contractual agreements does this StringTransformer have with other
2818         StringTransfomers? Such collaborations should be eliminated/minimized
2819         as much as possible.
2820     """
2821
2822     line_length: int
2823     normalize_strings: bool
2824     __name__ = "StringTransformer"
2825
2826     @abstractmethod
2827     def do_match(self, line: Line) -> TMatchResult:
2828         """
2829         Returns:
2830             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2831             string, if a match was able to be made.
2832                 OR
2833             * Err(CannotTransform), if a match was not able to be made.
2834         """
2835
2836     @abstractmethod
2837     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2838         """
2839         Yields:
2840             * Ok(new_line) where new_line is the new transformed line.
2841                 OR
2842             * Err(CannotTransform) if the transformation failed for some reason. The
2843             `do_match(...)` template method should usually be used to reject
2844             the form of the given Line, but in some cases it is difficult to
2845             know whether or not a Line meets the StringTransformer's
2846             requirements until the transformation is already midway.
2847
2848         Side Effects:
2849             This method should NOT mutate @line directly, but it MAY mutate the
2850             Line's underlying Node structure. (WARNING: If the underlying Node
2851             structure IS altered, then this method should NOT be allowed to
2852             yield an CannotTransform after that point.)
2853         """
2854
2855     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2856         """
2857         StringTransformer instances have a call signature that mirrors that of
2858         the Transformer type.
2859
2860         Raises:
2861             CannotTransform(...) if the concrete StringTransformer class is unable
2862             to transform @line.
2863         """
2864         # Optimization to avoid calling `self.do_match(...)` when the line does
2865         # not contain any string.
2866         if not any(leaf.type == token.STRING for leaf in line.leaves):
2867             raise CannotTransform("There are no strings in this line.")
2868
2869         match_result = self.do_match(line)
2870
2871         if isinstance(match_result, Err):
2872             cant_transform = match_result.err()
2873             raise CannotTransform(
2874                 f"The string transformer {self.__class__.__name__} does not recognize"
2875                 " this line as one that it can transform."
2876             ) from cant_transform
2877
2878         string_idx = match_result.ok()
2879
2880         for line_result in self.do_transform(line, string_idx):
2881             if isinstance(line_result, Err):
2882                 cant_transform = line_result.err()
2883                 raise CannotTransform(
2884                     "StringTransformer failed while attempting to transform string."
2885                 ) from cant_transform
2886             line = line_result.ok()
2887             yield line
2888
2889
2890 @dataclass
2891 class CustomSplit:
2892     """A custom (i.e. manual) string split.
2893
2894     A single CustomSplit instance represents a single substring.
2895
2896     Examples:
2897         Consider the following string:
2898         ```
2899         "Hi there friend."
2900         " This is a custom"
2901         f" string {split}."
2902         ```
2903
2904         This string will correspond to the following three CustomSplit instances:
2905         ```
2906         CustomSplit(False, 16)
2907         CustomSplit(False, 17)
2908         CustomSplit(True, 16)
2909         ```
2910     """
2911
2912     has_prefix: bool
2913     break_idx: int
2914
2915
2916 class CustomSplitMapMixin:
2917     """
2918     This mixin class is used to map merged strings to a sequence of
2919     CustomSplits, which will then be used to re-split the strings iff none of
2920     the resultant substrings go over the configured max line length.
2921     """
2922
2923     _Key = Tuple[StringID, str]
2924     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2925
2926     @staticmethod
2927     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2928         """
2929         Returns:
2930             A unique identifier that is used internally to map @string to a
2931             group of custom splits.
2932         """
2933         return (id(string), string)
2934
2935     def add_custom_splits(
2936         self, string: str, custom_splits: Iterable[CustomSplit]
2937     ) -> None:
2938         """Custom Split Map Setter Method
2939
2940         Side Effects:
2941             Adds a mapping from @string to the custom splits @custom_splits.
2942         """
2943         key = self._get_key(string)
2944         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2945
2946     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2947         """Custom Split Map Getter Method
2948
2949         Returns:
2950             * A list of the custom splits that are mapped to @string, if any
2951             exist.
2952                 OR
2953             * [], otherwise.
2954
2955         Side Effects:
2956             Deletes the mapping between @string and its associated custom
2957             splits (which are returned to the caller).
2958         """
2959         key = self._get_key(string)
2960
2961         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2962         del self._CUSTOM_SPLIT_MAP[key]
2963
2964         return list(custom_splits)
2965
2966     def has_custom_splits(self, string: str) -> bool:
2967         """
2968         Returns:
2969             True iff @string is associated with a set of custom splits.
2970         """
2971         key = self._get_key(string)
2972         return key in self._CUSTOM_SPLIT_MAP
2973
2974
2975 class StringMerger(CustomSplitMapMixin, StringTransformer):
2976     """StringTransformer that merges strings together.
2977
2978     Requirements:
2979         (A) The line contains adjacent strings such that ALL of the validation checks
2980         listed in StringMerger.__validate_msg(...)'s docstring pass.
2981             OR
2982         (B) The line contains a string which uses line continuation backslashes.
2983
2984     Transformations:
2985         Depending on which of the two requirements above where met, either:
2986
2987         (A) The string group associated with the target string is merged.
2988             OR
2989         (B) All line-continuation backslashes are removed from the target string.
2990
2991     Collaborations:
2992         StringMerger provides custom split information to StringSplitter.
2993     """
2994
2995     def do_match(self, line: Line) -> TMatchResult:
2996         LL = line.leaves
2997
2998         is_valid_index = is_valid_index_factory(LL)
2999
3000         for (i, leaf) in enumerate(LL):
3001             if (
3002                 leaf.type == token.STRING
3003                 and is_valid_index(i + 1)
3004                 and LL[i + 1].type == token.STRING
3005             ):
3006                 return Ok(i)
3007
3008             if leaf.type == token.STRING and "\\\n" in leaf.value:
3009                 return Ok(i)
3010
3011         return TErr("This line has no strings that need merging.")
3012
3013     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3014         new_line = line
3015         rblc_result = self.__remove_backslash_line_continuation_chars(
3016             new_line, string_idx
3017         )
3018         if isinstance(rblc_result, Ok):
3019             new_line = rblc_result.ok()
3020
3021         msg_result = self.__merge_string_group(new_line, string_idx)
3022         if isinstance(msg_result, Ok):
3023             new_line = msg_result.ok()
3024
3025         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3026             msg_cant_transform = msg_result.err()
3027             rblc_cant_transform = rblc_result.err()
3028             cant_transform = CannotTransform(
3029                 "StringMerger failed to merge any strings in this line."
3030             )
3031
3032             # Chain the errors together using `__cause__`.
3033             msg_cant_transform.__cause__ = rblc_cant_transform
3034             cant_transform.__cause__ = msg_cant_transform
3035
3036             yield Err(cant_transform)
3037         else:
3038             yield Ok(new_line)
3039
3040     @staticmethod
3041     def __remove_backslash_line_continuation_chars(
3042         line: Line, string_idx: int
3043     ) -> TResult[Line]:
3044         """
3045         Merge strings that were split across multiple lines using
3046         line-continuation backslashes.
3047
3048         Returns:
3049             Ok(new_line), if @line contains backslash line-continuation
3050             characters.
3051                 OR
3052             Err(CannotTransform), otherwise.
3053         """
3054         LL = line.leaves
3055
3056         string_leaf = LL[string_idx]
3057         if not (
3058             string_leaf.type == token.STRING
3059             and "\\\n" in string_leaf.value
3060             and not has_triple_quotes(string_leaf.value)
3061         ):
3062             return TErr(
3063                 f"String leaf {string_leaf} does not contain any backslash line"
3064                 " continuation characters."
3065             )
3066
3067         new_line = line.clone()
3068         new_line.comments = line.comments.copy()
3069         append_leaves(new_line, line, LL)
3070
3071         new_string_leaf = new_line.leaves[string_idx]
3072         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3073
3074         return Ok(new_line)
3075
3076     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3077         """
3078         Merges string group (i.e. set of adjacent strings) where the first
3079         string in the group is `line.leaves[string_idx]`.
3080
3081         Returns:
3082             Ok(new_line), if ALL of the validation checks found in
3083             __validate_msg(...) pass.
3084                 OR
3085             Err(CannotTransform), otherwise.
3086         """
3087         LL = line.leaves
3088
3089         is_valid_index = is_valid_index_factory(LL)
3090
3091         vresult = self.__validate_msg(line, string_idx)
3092         if isinstance(vresult, Err):
3093             return vresult
3094
3095         # If the string group is wrapped inside an Atom node, we must make sure
3096         # to later replace that Atom with our new (merged) string leaf.
3097         atom_node = LL[string_idx].parent
3098
3099         # We will place BREAK_MARK in between every two substrings that we
3100         # merge. We will then later go through our final result and use the
3101         # various instances of BREAK_MARK we find to add the right values to
3102         # the custom split map.
3103         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3104
3105         QUOTE = LL[string_idx].value[-1]
3106
3107         def make_naked(string: str, string_prefix: str) -> str:
3108             """Strip @string (i.e. make it a "naked" string)
3109
3110             Pre-conditions:
3111                 * assert_is_leaf_string(@string)
3112
3113             Returns:
3114                 A string that is identical to @string except that
3115                 @string_prefix has been stripped, the surrounding QUOTE
3116                 characters have been removed, and any remaining QUOTE
3117                 characters have been escaped.
3118             """
3119             assert_is_leaf_string(string)
3120
3121             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3122             naked_string = string[len(string_prefix) + 1 : -1]
3123             naked_string = re.sub(
3124                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3125             )
3126             return naked_string
3127
3128         # Holds the CustomSplit objects that will later be added to the custom
3129         # split map.
3130         custom_splits = []
3131
3132         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3133         prefix_tracker = []
3134
3135         # Sets the 'prefix' variable. This is the prefix that the final merged
3136         # string will have.
3137         next_str_idx = string_idx
3138         prefix = ""
3139         while (
3140             not prefix
3141             and is_valid_index(next_str_idx)
3142             and LL[next_str_idx].type == token.STRING
3143         ):
3144             prefix = get_string_prefix(LL[next_str_idx].value)
3145             next_str_idx += 1
3146
3147         # The next loop merges the string group. The final string will be
3148         # contained in 'S'.
3149         #
3150         # The following convenience variables are used:
3151         #
3152         #   S: string
3153         #   NS: naked string
3154         #   SS: next string
3155         #   NSS: naked next string
3156         S = ""
3157         NS = ""
3158         num_of_strings = 0
3159         next_str_idx = string_idx
3160         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3161             num_of_strings += 1
3162
3163             SS = LL[next_str_idx].value
3164             next_prefix = get_string_prefix(SS)
3165
3166             # If this is an f-string group but this substring is not prefixed
3167             # with 'f'...
3168             if "f" in prefix and "f" not in next_prefix:
3169                 # Then we must escape any braces contained in this substring.
3170                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3171
3172             NSS = make_naked(SS, next_prefix)
3173
3174             has_prefix = bool(next_prefix)
3175             prefix_tracker.append(has_prefix)
3176
3177             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3178             NS = make_naked(S, prefix)
3179
3180             next_str_idx += 1
3181
3182         S_leaf = Leaf(token.STRING, S)
3183         if self.normalize_strings:
3184             normalize_string_quotes(S_leaf)
3185
3186         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3187         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3188         for has_prefix in prefix_tracker:
3189             mark_idx = temp_string.find(BREAK_MARK)
3190             assert (
3191                 mark_idx >= 0
3192             ), "Logic error while filling the custom string breakpoint cache."
3193
3194             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3195             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3196             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3197
3198         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3199
3200         if atom_node is not None:
3201             replace_child(atom_node, string_leaf)
3202
3203         # Build the final line ('new_line') that this method will later return.
3204         new_line = line.clone()
3205         for (i, leaf) in enumerate(LL):
3206             if i == string_idx:
3207                 new_line.append(string_leaf)
3208
3209             if string_idx <= i < string_idx + num_of_strings:
3210                 for comment_leaf in line.comments_after(LL[i]):
3211                     new_line.append(comment_leaf, preformatted=True)
3212                 continue
3213
3214             append_leaves(new_line, line, [leaf])
3215
3216         self.add_custom_splits(string_leaf.value, custom_splits)
3217         return Ok(new_line)
3218
3219     @staticmethod
3220     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3221         """Validate (M)erge (S)tring (G)roup
3222
3223         Transform-time string validation logic for __merge_string_group(...).
3224
3225         Returns:
3226             * Ok(None), if ALL validation checks (listed below) pass.
3227                 OR
3228             * Err(CannotTransform), if any of the following are true:
3229                 - The target string group does not contain ANY stand-alone comments.
3230                 - The target string is not in a string group (i.e. it has no
3231                   adjacent strings).
3232                 - The string group has more than one inline comment.
3233                 - The string group has an inline comment that appears to be a pragma.
3234                 - The set of all string prefixes in the string group is of
3235                   length greater than one and is not equal to {"", "f"}.
3236                 - The string group consists of raw strings.
3237         """
3238         # We first check for "inner" stand-alone comments (i.e. stand-alone
3239         # comments that have a string leaf before them AND after them).
3240         for inc in [1, -1]:
3241             i = string_idx
3242             found_sa_comment = False
3243             is_valid_index = is_valid_index_factory(line.leaves)
3244             while is_valid_index(i) and line.leaves[i].type in [
3245                 token.STRING,
3246                 STANDALONE_COMMENT,
3247             ]:
3248                 if line.leaves[i].type == STANDALONE_COMMENT:
3249                     found_sa_comment = True
3250                 elif found_sa_comment:
3251                     return TErr(
3252                         "StringMerger does NOT merge string groups which contain "
3253                         "stand-alone comments."
3254                     )
3255
3256                 i += inc
3257
3258         num_of_inline_string_comments = 0
3259         set_of_prefixes = set()
3260         num_of_strings = 0
3261         for leaf in line.leaves[string_idx:]:
3262             if leaf.type != token.STRING:
3263                 # If the string group is trailed by a comma, we count the
3264                 # comments trailing the comma to be one of the string group's
3265                 # comments.
3266                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3267                     num_of_inline_string_comments += 1
3268                 break
3269
3270             if has_triple_quotes(leaf.value):
3271                 return TErr("StringMerger does NOT merge multiline strings.")
3272
3273             num_of_strings += 1
3274             prefix = get_string_prefix(leaf.value)
3275             if "r" in prefix:
3276                 return TErr("StringMerger does NOT merge raw strings.")
3277
3278             set_of_prefixes.add(prefix)
3279
3280             if id(leaf) in line.comments:
3281                 num_of_inline_string_comments += 1
3282                 if contains_pragma_comment(line.comments[id(leaf)]):
3283                     return TErr("Cannot merge strings which have pragma comments.")
3284
3285         if num_of_strings < 2:
3286             return TErr(
3287                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3288             )
3289
3290         if num_of_inline_string_comments > 1:
3291             return TErr(
3292                 f"Too many inline string comments ({num_of_inline_string_comments})."
3293             )
3294
3295         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3296             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3297
3298         return Ok(None)
3299
3300
3301 class StringParenStripper(StringTransformer):
3302     """StringTransformer that strips surrounding parentheses from strings.
3303
3304     Requirements:
3305         The line contains a string which is surrounded by parentheses and:
3306             - The target string is NOT the only argument to a function call.
3307             - The target string is NOT a "pointless" string.
3308             - If the target string contains a PERCENT, the brackets are not
3309               preceeded or followed by an operator with higher precedence than
3310               PERCENT.
3311
3312     Transformations:
3313         The parentheses mentioned in the 'Requirements' section are stripped.
3314
3315     Collaborations:
3316         StringParenStripper has its own inherent usefulness, but it is also
3317         relied on to clean up the parentheses created by StringParenWrapper (in
3318         the event that they are no longer needed).
3319     """
3320
3321     def do_match(self, line: Line) -> TMatchResult:
3322         LL = line.leaves
3323
3324         is_valid_index = is_valid_index_factory(LL)
3325
3326         for (idx, leaf) in enumerate(LL):
3327             # Should be a string...
3328             if leaf.type != token.STRING:
3329                 continue
3330
3331             # If this is a "pointless" string...
3332             if (
3333                 leaf.parent
3334                 and leaf.parent.parent
3335                 and leaf.parent.parent.type == syms.simple_stmt
3336             ):
3337                 continue
3338
3339             # Should be preceded by a non-empty LPAR...
3340             if (
3341                 not is_valid_index(idx - 1)
3342                 or LL[idx - 1].type != token.LPAR
3343                 or is_empty_lpar(LL[idx - 1])
3344             ):
3345                 continue
3346
3347             # That LPAR should NOT be preceded by a function name or a closing
3348             # bracket (which could be a function which returns a function or a
3349             # list/dictionary that contains a function)...
3350             if is_valid_index(idx - 2) and (
3351                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3352             ):
3353                 continue
3354
3355             string_idx = idx
3356
3357             # Skip the string trailer, if one exists.
3358             string_parser = StringParser()
3359             next_idx = string_parser.parse(LL, string_idx)
3360
3361             # if the leaves in the parsed string include a PERCENT, we need to
3362             # make sure the initial LPAR is NOT preceded by an operator with
3363             # higher or equal precedence to PERCENT
3364             if is_valid_index(idx - 2):
3365                 # mypy can't quite follow unless we name this
3366                 before_lpar = LL[idx - 2]
3367                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3368                     (
3369                         before_lpar.type
3370                         in {
3371                             token.STAR,
3372                             token.AT,
3373                             token.SLASH,
3374                             token.DOUBLESLASH,
3375                             token.PERCENT,
3376                             token.TILDE,
3377                             token.DOUBLESTAR,
3378                             token.AWAIT,
3379                             token.LSQB,
3380                             token.LPAR,
3381                         }
3382                     )
3383                     or (
3384                         # only unary PLUS/MINUS
3385                         before_lpar.parent
3386                         and before_lpar.parent.type == syms.factor
3387                         and (before_lpar.type in {token.PLUS, token.MINUS})
3388                     )
3389                 ):
3390                     continue
3391
3392             # Should be followed by a non-empty RPAR...
3393             if (
3394                 is_valid_index(next_idx)
3395                 and LL[next_idx].type == token.RPAR
3396                 and not is_empty_rpar(LL[next_idx])
3397             ):
3398                 # That RPAR should NOT be followed by anything with higher
3399                 # precedence than PERCENT
3400                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3401                     token.DOUBLESTAR,
3402                     token.LSQB,
3403                     token.LPAR,
3404                     token.DOT,
3405                 }:
3406                     continue
3407
3408                 return Ok(string_idx)
3409
3410         return TErr("This line has no strings wrapped in parens.")
3411
3412     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3413         LL = line.leaves
3414
3415         string_parser = StringParser()
3416         rpar_idx = string_parser.parse(LL, string_idx)
3417
3418         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3419             if line.comments_after(leaf):
3420                 yield TErr(
3421                     "Will not strip parentheses which have comments attached to them."
3422                 )
3423                 return
3424
3425         new_line = line.clone()
3426         new_line.comments = line.comments.copy()
3427         try:
3428             append_leaves(new_line, line, LL[: string_idx - 1])
3429         except BracketMatchError:
3430             # HACK: I believe there is currently a bug somewhere in
3431             # right_hand_split() that is causing brackets to not be tracked
3432             # properly by a shared BracketTracker.
3433             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3434
3435         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3436         LL[string_idx - 1].remove()
3437         replace_child(LL[string_idx], string_leaf)
3438         new_line.append(string_leaf)
3439
3440         append_leaves(
3441             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3442         )
3443
3444         LL[rpar_idx].remove()
3445
3446         yield Ok(new_line)
3447
3448
3449 class BaseStringSplitter(StringTransformer):
3450     """
3451     Abstract class for StringTransformers which transform a Line's strings by splitting
3452     them or placing them on their own lines where necessary to avoid going over
3453     the configured line length.
3454
3455     Requirements:
3456         * The target string value is responsible for the line going over the
3457         line length limit. It follows that after all of black's other line
3458         split methods have been exhausted, this line (or one of the resulting
3459         lines after all line splits are performed) would still be over the
3460         line_length limit unless we split this string.
3461             AND
3462         * The target string is NOT a "pointless" string (i.e. a string that has
3463         no parent or siblings).
3464             AND
3465         * The target string is not followed by an inline comment that appears
3466         to be a pragma.
3467             AND
3468         * The target string is not a multiline (i.e. triple-quote) string.
3469     """
3470
3471     @abstractmethod
3472     def do_splitter_match(self, line: Line) -> TMatchResult:
3473         """
3474         BaseStringSplitter asks its clients to override this method instead of
3475         `StringTransformer.do_match(...)`.
3476
3477         Follows the same protocol as `StringTransformer.do_match(...)`.
3478
3479         Refer to `help(StringTransformer.do_match)` for more information.
3480         """
3481
3482     def do_match(self, line: Line) -> TMatchResult:
3483         match_result = self.do_splitter_match(line)
3484         if isinstance(match_result, Err):
3485             return match_result
3486
3487         string_idx = match_result.ok()
3488         vresult = self.__validate(line, string_idx)
3489         if isinstance(vresult, Err):
3490             return vresult
3491
3492         return match_result
3493
3494     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3495         """
3496         Checks that @line meets all of the requirements listed in this classes'
3497         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3498         description of those requirements.
3499
3500         Returns:
3501             * Ok(None), if ALL of the requirements are met.
3502                 OR
3503             * Err(CannotTransform), if ANY of the requirements are NOT met.
3504         """
3505         LL = line.leaves
3506
3507         string_leaf = LL[string_idx]
3508
3509         max_string_length = self.__get_max_string_length(line, string_idx)
3510         if len(string_leaf.value) <= max_string_length:
3511             return TErr(
3512                 "The string itself is not what is causing this line to be too long."
3513             )
3514
3515         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3516             token.STRING,
3517             token.NEWLINE,
3518         ]:
3519             return TErr(
3520                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3521                 " no parent)."
3522             )
3523
3524         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3525             line.comments[id(line.leaves[string_idx])]
3526         ):
3527             return TErr(
3528                 "Line appears to end with an inline pragma comment. Splitting the line"
3529                 " could modify the pragma's behavior."
3530             )
3531
3532         if has_triple_quotes(string_leaf.value):
3533             return TErr("We cannot split multiline strings.")
3534
3535         return Ok(None)
3536
3537     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3538         """
3539         Calculates the max string length used when attempting to determine
3540         whether or not the target string is responsible for causing the line to
3541         go over the line length limit.
3542
3543         WARNING: This method is tightly coupled to both StringSplitter and
3544         (especially) StringParenWrapper. There is probably a better way to
3545         accomplish what is being done here.
3546
3547         Returns:
3548             max_string_length: such that `line.leaves[string_idx].value >
3549             max_string_length` implies that the target string IS responsible
3550             for causing this line to exceed the line length limit.
3551         """
3552         LL = line.leaves
3553
3554         is_valid_index = is_valid_index_factory(LL)
3555
3556         # We use the shorthand "WMA4" in comments to abbreviate "We must
3557         # account for". When giving examples, we use STRING to mean some/any
3558         # valid string.
3559         #
3560         # Finally, we use the following convenience variables:
3561         #
3562         #   P:  The leaf that is before the target string leaf.
3563         #   N:  The leaf that is after the target string leaf.
3564         #   NN: The leaf that is after N.
3565
3566         # WMA4 the whitespace at the beginning of the line.
3567         offset = line.depth * 4
3568
3569         if is_valid_index(string_idx - 1):
3570             p_idx = string_idx - 1
3571             if (
3572                 LL[string_idx - 1].type == token.LPAR
3573                 and LL[string_idx - 1].value == ""
3574                 and string_idx >= 2
3575             ):
3576                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3577                 p_idx -= 1
3578
3579             P = LL[p_idx]
3580             if P.type == token.PLUS:
3581                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3582                 offset += 2
3583
3584             if P.type == token.COMMA:
3585                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3586                 offset += 3
3587
3588             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3589                 # This conditional branch is meant to handle dictionary keys,
3590                 # variable assignments, 'return STRING' statement lines, and
3591                 # 'else STRING' ternary expression lines.
3592
3593                 # WMA4 a single space.
3594                 offset += 1
3595
3596                 # WMA4 the lengths of any leaves that came before that space,
3597                 # but after any closing bracket before that space.
3598                 for leaf in reversed(LL[: p_idx + 1]):
3599                     offset += len(str(leaf))
3600                     if leaf.type in CLOSING_BRACKETS:
3601                         break
3602
3603         if is_valid_index(string_idx + 1):
3604             N = LL[string_idx + 1]
3605             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3606                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3607                 N = LL[string_idx + 2]
3608
3609             if N.type == token.COMMA:
3610                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3611                 offset += 1
3612
3613             if is_valid_index(string_idx + 2):
3614                 NN = LL[string_idx + 2]
3615
3616                 if N.type == token.DOT and NN.type == token.NAME:
3617                     # This conditional branch is meant to handle method calls invoked
3618                     # off of a string literal up to and including the LPAR character.
3619
3620                     # WMA4 the '.' character.
3621                     offset += 1
3622
3623                     if (
3624                         is_valid_index(string_idx + 3)
3625                         and LL[string_idx + 3].type == token.LPAR
3626                     ):
3627                         # WMA4 the left parenthesis character.
3628                         offset += 1
3629
3630                     # WMA4 the length of the method's name.
3631                     offset += len(NN.value)
3632
3633         has_comments = False
3634         for comment_leaf in line.comments_after(LL[string_idx]):
3635             if not has_comments:
3636                 has_comments = True
3637                 # WMA4 two spaces before the '#' character.
3638                 offset += 2
3639
3640             # WMA4 the length of the inline comment.
3641             offset += len(comment_leaf.value)
3642
3643         max_string_length = self.line_length - offset
3644         return max_string_length
3645
3646
3647 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3648     """
3649     StringTransformer that splits "atom" strings (i.e. strings which exist on
3650     lines by themselves).
3651
3652     Requirements:
3653         * The line consists ONLY of a single string (with the exception of a
3654         '+' symbol which MAY exist at the start of the line), MAYBE a string
3655         trailer, and MAYBE a trailing comma.
3656             AND
3657         * All of the requirements listed in BaseStringSplitter's docstring.
3658
3659     Transformations:
3660         The string mentioned in the 'Requirements' section is split into as
3661         many substrings as necessary to adhere to the configured line length.
3662
3663         In the final set of substrings, no substring should be smaller than
3664         MIN_SUBSTR_SIZE characters.
3665
3666         The string will ONLY be split on spaces (i.e. each new substring should
3667         start with a space). Note that the string will NOT be split on a space
3668         which is escaped with a backslash.
3669
3670         If the string is an f-string, it will NOT be split in the middle of an
3671         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3672         else bar()} is an f-expression).
3673
3674         If the string that is being split has an associated set of custom split
3675         records and those custom splits will NOT result in any line going over
3676         the configured line length, those custom splits are used. Otherwise the
3677         string is split as late as possible (from left-to-right) while still
3678         adhering to the transformation rules listed above.
3679
3680     Collaborations:
3681         StringSplitter relies on StringMerger to construct the appropriate
3682         CustomSplit objects and add them to the custom split map.
3683     """
3684
3685     MIN_SUBSTR_SIZE = 6
3686     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3687     RE_FEXPR = r"""
3688     (?<!\{) (?:\{\{)* \{ (?!\{)
3689         (?:
3690             [^\{\}]
3691             | \{\{
3692             | \}\}
3693             | (?R)
3694         )+?
3695     (?<!\}) \} (?:\}\})* (?!\})
3696     """
3697
3698     def do_splitter_match(self, line: Line) -> TMatchResult:
3699         LL = line.leaves
3700
3701         is_valid_index = is_valid_index_factory(LL)
3702
3703         idx = 0
3704
3705         # The first leaf MAY be a '+' symbol...
3706         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3707             idx += 1
3708
3709         # The next/first leaf MAY be an empty LPAR...
3710         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3711             idx += 1
3712
3713         # The next/first leaf MUST be a string...
3714         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3715             return TErr("Line does not start with a string.")
3716
3717         string_idx = idx
3718
3719         # Skip the string trailer, if one exists.
3720         string_parser = StringParser()
3721         idx = string_parser.parse(LL, string_idx)
3722
3723         # That string MAY be followed by an empty RPAR...
3724         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3725             idx += 1
3726
3727         # That string / empty RPAR leaf MAY be followed by a comma...
3728         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3729             idx += 1
3730
3731         # But no more leaves are allowed...
3732         if is_valid_index(idx):
3733             return TErr("This line does not end with a string.")
3734
3735         return Ok(string_idx)
3736
3737     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3738         LL = line.leaves
3739
3740         QUOTE = LL[string_idx].value[-1]
3741
3742         is_valid_index = is_valid_index_factory(LL)
3743         insert_str_child = insert_str_child_factory(LL[string_idx])
3744
3745         prefix = get_string_prefix(LL[string_idx].value)
3746
3747         # We MAY choose to drop the 'f' prefix from substrings that don't
3748         # contain any f-expressions, but ONLY if the original f-string
3749         # contains at least one f-expression. Otherwise, we will alter the AST
3750         # of the program.
3751         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3752             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3753         )
3754
3755         first_string_line = True
3756         starts_with_plus = LL[0].type == token.PLUS
3757
3758         def line_needs_plus() -> bool:
3759             return first_string_line and starts_with_plus
3760
3761         def maybe_append_plus(new_line: Line) -> None:
3762             """
3763             Side Effects:
3764                 If @line starts with a plus and this is the first line we are
3765                 constructing, this function appends a PLUS leaf to @new_line
3766                 and replaces the old PLUS leaf in the node structure. Otherwise
3767                 this function does nothing.
3768             """
3769             if line_needs_plus():
3770                 plus_leaf = Leaf(token.PLUS, "+")
3771                 replace_child(LL[0], plus_leaf)
3772                 new_line.append(plus_leaf)
3773
3774         ends_with_comma = (
3775             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3776         )
3777
3778         def max_last_string() -> int:
3779             """
3780             Returns:
3781                 The max allowed length of the string value used for the last
3782                 line we will construct.
3783             """
3784             result = self.line_length
3785             result -= line.depth * 4
3786             result -= 1 if ends_with_comma else 0
3787             result -= 2 if line_needs_plus() else 0
3788             return result
3789
3790         # --- Calculate Max Break Index (for string value)
3791         # We start with the line length limit
3792         max_break_idx = self.line_length
3793         # The last index of a string of length N is N-1.
3794         max_break_idx -= 1
3795         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3796         max_break_idx -= line.depth * 4
3797         if max_break_idx < 0:
3798             yield TErr(
3799                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3800                 f" {line.depth}"
3801             )
3802             return
3803
3804         # Check if StringMerger registered any custom splits.
3805         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3806         # We use them ONLY if none of them would produce lines that exceed the
3807         # line limit.
3808         use_custom_breakpoints = bool(
3809             custom_splits
3810             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3811         )
3812
3813         # Temporary storage for the remaining chunk of the string line that
3814         # can't fit onto the line currently being constructed.
3815         rest_value = LL[string_idx].value
3816
3817         def more_splits_should_be_made() -> bool:
3818             """
3819             Returns:
3820                 True iff `rest_value` (the remaining string value from the last
3821                 split), should be split again.
3822             """
3823             if use_custom_breakpoints:
3824                 return len(custom_splits) > 1
3825             else:
3826                 return len(rest_value) > max_last_string()
3827
3828         string_line_results: List[Ok[Line]] = []
3829         while more_splits_should_be_made():
3830             if use_custom_breakpoints:
3831                 # Custom User Split (manual)
3832                 csplit = custom_splits.pop(0)
3833                 break_idx = csplit.break_idx
3834             else:
3835                 # Algorithmic Split (automatic)
3836                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3837                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3838                 if maybe_break_idx is None:
3839                     # If we are unable to algorithmically determine a good split
3840                     # and this string has custom splits registered to it, we
3841                     # fall back to using them--which means we have to start
3842                     # over from the beginning.
3843                     if custom_splits:
3844                         rest_value = LL[string_idx].value
3845                         string_line_results = []
3846                         first_string_line = True
3847                         use_custom_breakpoints = True
3848                         continue
3849
3850                     # Otherwise, we stop splitting here.
3851                     break
3852
3853                 break_idx = maybe_break_idx
3854
3855             # --- Construct `next_value`
3856             next_value = rest_value[:break_idx] + QUOTE
3857             if (
3858                 # Are we allowed to try to drop a pointless 'f' prefix?
3859                 drop_pointless_f_prefix
3860                 # If we are, will we be successful?
3861                 and next_value != self.__normalize_f_string(next_value, prefix)
3862             ):
3863                 # If the current custom split did NOT originally use a prefix,
3864                 # then `csplit.break_idx` will be off by one after removing
3865                 # the 'f' prefix.
3866                 break_idx = (
3867                     break_idx + 1
3868                     if use_custom_breakpoints and not csplit.has_prefix
3869                     else break_idx
3870                 )
3871                 next_value = rest_value[:break_idx] + QUOTE
3872                 next_value = self.__normalize_f_string(next_value, prefix)
3873
3874             # --- Construct `next_leaf`
3875             next_leaf = Leaf(token.STRING, next_value)
3876             insert_str_child(next_leaf)
3877             self.__maybe_normalize_string_quotes(next_leaf)
3878
3879             # --- Construct `next_line`
3880             next_line = line.clone()
3881             maybe_append_plus(next_line)
3882             next_line.append(next_leaf)
3883             string_line_results.append(Ok(next_line))
3884
3885             rest_value = prefix + QUOTE + rest_value[break_idx:]
3886             first_string_line = False
3887
3888         yield from string_line_results
3889
3890         if drop_pointless_f_prefix:
3891             rest_value = self.__normalize_f_string(rest_value, prefix)
3892
3893         rest_leaf = Leaf(token.STRING, rest_value)
3894         insert_str_child(rest_leaf)
3895
3896         # NOTE: I could not find a test case that verifies that the following
3897         # line is actually necessary, but it seems to be. Otherwise we risk
3898         # not normalizing the last substring, right?
3899         self.__maybe_normalize_string_quotes(rest_leaf)
3900
3901         last_line = line.clone()
3902         maybe_append_plus(last_line)
3903
3904         # If there are any leaves to the right of the target string...
3905         if is_valid_index(string_idx + 1):
3906             # We use `temp_value` here to determine how long the last line
3907             # would be if we were to append all the leaves to the right of the
3908             # target string to the last string line.
3909             temp_value = rest_value
3910             for leaf in LL[string_idx + 1 :]:
3911                 temp_value += str(leaf)
3912                 if leaf.type == token.LPAR:
3913                     break
3914
3915             # Try to fit them all on the same line with the last substring...
3916             if (
3917                 len(temp_value) <= max_last_string()
3918                 or LL[string_idx + 1].type == token.COMMA
3919             ):
3920                 last_line.append(rest_leaf)
3921                 append_leaves(last_line, line, LL[string_idx + 1 :])
3922                 yield Ok(last_line)
3923             # Otherwise, place the last substring on one line and everything
3924             # else on a line below that...
3925             else:
3926                 last_line.append(rest_leaf)
3927                 yield Ok(last_line)
3928
3929                 non_string_line = line.clone()
3930                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3931                 yield Ok(non_string_line)
3932         # Else the target string was the last leaf...
3933         else:
3934             last_line.append(rest_leaf)
3935             last_line.comments = line.comments.copy()
3936             yield Ok(last_line)
3937
3938     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3939         """
3940         This method contains the algorithm that StringSplitter uses to
3941         determine which character to split each string at.
3942
3943         Args:
3944             @string: The substring that we are attempting to split.
3945             @max_break_idx: The ideal break index. We will return this value if it
3946             meets all the necessary conditions. In the likely event that it
3947             doesn't we will try to find the closest index BELOW @max_break_idx
3948             that does. If that fails, we will expand our search by also
3949             considering all valid indices ABOVE @max_break_idx.
3950
3951         Pre-Conditions:
3952             * assert_is_leaf_string(@string)
3953             * 0 <= @max_break_idx < len(@string)
3954
3955         Returns:
3956             break_idx, if an index is able to be found that meets all of the
3957             conditions listed in the 'Transformations' section of this classes'
3958             docstring.
3959                 OR
3960             None, otherwise.
3961         """
3962         is_valid_index = is_valid_index_factory(string)
3963
3964         assert is_valid_index(max_break_idx)
3965         assert_is_leaf_string(string)
3966
3967         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3968
3969         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3970             """
3971             Yields:
3972                 All ranges of @string which, if @string were to be split there,
3973                 would result in the splitting of an f-expression (which is NOT
3974                 allowed).
3975             """
3976             nonlocal _fexpr_slices
3977
3978             if _fexpr_slices is None:
3979                 _fexpr_slices = []
3980                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3981                     _fexpr_slices.append(match.span())
3982
3983             yield from _fexpr_slices
3984
3985         is_fstring = "f" in get_string_prefix(string)
3986
3987         def breaks_fstring_expression(i: Index) -> bool:
3988             """
3989             Returns:
3990                 True iff returning @i would result in the splitting of an
3991                 f-expression (which is NOT allowed).
3992             """
3993             if not is_fstring:
3994                 return False
3995
3996             for (start, end) in fexpr_slices():
3997                 if start <= i < end:
3998                     return True
3999
4000             return False
4001
4002         def passes_all_checks(i: Index) -> bool:
4003             """
4004             Returns:
4005                 True iff ALL of the conditions listed in the 'Transformations'
4006                 section of this classes' docstring would be be met by returning @i.
4007             """
4008             is_space = string[i] == " "
4009
4010             is_not_escaped = True
4011             j = i - 1
4012             while is_valid_index(j) and string[j] == "\\":
4013                 is_not_escaped = not is_not_escaped
4014                 j -= 1
4015
4016             is_big_enough = (
4017                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
4018                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
4019             )
4020             return (
4021                 is_space
4022                 and is_not_escaped
4023                 and is_big_enough
4024                 and not breaks_fstring_expression(i)
4025             )
4026
4027         # First, we check all indices BELOW @max_break_idx.
4028         break_idx = max_break_idx
4029         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
4030             break_idx -= 1
4031
4032         if not passes_all_checks(break_idx):
4033             # If that fails, we check all indices ABOVE @max_break_idx.
4034             #
4035             # If we are able to find a valid index here, the next line is going
4036             # to be longer than the specified line length, but it's probably
4037             # better than doing nothing at all.
4038             break_idx = max_break_idx + 1
4039             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
4040                 break_idx += 1
4041
4042             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
4043                 return None
4044
4045         return break_idx
4046
4047     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
4048         if self.normalize_strings:
4049             normalize_string_quotes(leaf)
4050
4051     def __normalize_f_string(self, string: str, prefix: str) -> str:
4052         """
4053         Pre-Conditions:
4054             * assert_is_leaf_string(@string)
4055
4056         Returns:
4057             * If @string is an f-string that contains no f-expressions, we
4058             return a string identical to @string except that the 'f' prefix
4059             has been stripped and all double braces (i.e. '{{' or '}}') have
4060             been normalized (i.e. turned into '{' or '}').
4061                 OR
4062             * Otherwise, we return @string.
4063         """
4064         assert_is_leaf_string(string)
4065
4066         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
4067             new_prefix = prefix.replace("f", "")
4068
4069             temp = string[len(prefix) :]
4070             temp = re.sub(r"\{\{", "{", temp)
4071             temp = re.sub(r"\}\}", "}", temp)
4072             new_string = temp
4073
4074             return f"{new_prefix}{new_string}"
4075         else:
4076             return string
4077
4078
4079 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4080     """
4081     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4082     exist on lines by themselves).
4083
4084     Requirements:
4085         All of the requirements listed in BaseStringSplitter's docstring in
4086         addition to the requirements listed below:
4087
4088         * The line is a return/yield statement, which returns/yields a string.
4089             OR
4090         * The line is part of a ternary expression (e.g. `x = y if cond else
4091         z`) such that the line starts with `else <string>`, where <string> is
4092         some string.
4093             OR
4094         * The line is an assert statement, which ends with a string.
4095             OR
4096         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4097         <string>`) such that the variable is being assigned the value of some
4098         string.
4099             OR
4100         * The line is a dictionary key assignment where some valid key is being
4101         assigned the value of some string.
4102
4103     Transformations:
4104         The chosen string is wrapped in parentheses and then split at the LPAR.
4105
4106         We then have one line which ends with an LPAR and another line that
4107         starts with the chosen string. The latter line is then split again at
4108         the RPAR. This results in the RPAR (and possibly a trailing comma)
4109         being placed on its own line.
4110
4111         NOTE: If any leaves exist to the right of the chosen string (except
4112         for a trailing comma, which would be placed after the RPAR), those
4113         leaves are placed inside the parentheses.  In effect, the chosen
4114         string is not necessarily being "wrapped" by parentheses. We can,
4115         however, count on the LPAR being placed directly before the chosen
4116         string.
4117
4118         In other words, StringParenWrapper creates "atom" strings. These
4119         can then be split again by StringSplitter, if necessary.
4120
4121     Collaborations:
4122         In the event that a string line split by StringParenWrapper is
4123         changed such that it no longer needs to be given its own line,
4124         StringParenWrapper relies on StringParenStripper to clean up the
4125         parentheses it created.
4126     """
4127
4128     def do_splitter_match(self, line: Line) -> TMatchResult:
4129         LL = line.leaves
4130
4131         string_idx = (
4132             self._return_match(LL)
4133             or self._else_match(LL)
4134             or self._assert_match(LL)
4135             or self._assign_match(LL)
4136             or self._dict_match(LL)
4137         )
4138
4139         if string_idx is not None:
4140             string_value = line.leaves[string_idx].value
4141             # If the string has no spaces...
4142             if " " not in string_value:
4143                 # And will still violate the line length limit when split...
4144                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4145                 if len(string_value) > max_string_length:
4146                     # And has no associated custom splits...
4147                     if not self.has_custom_splits(string_value):
4148                         # Then we should NOT put this string on its own line.
4149                         return TErr(
4150                             "We do not wrap long strings in parentheses when the"
4151                             " resultant line would still be over the specified line"
4152                             " length and can't be split further by StringSplitter."
4153                         )
4154             return Ok(string_idx)
4155
4156         return TErr("This line does not contain any non-atomic strings.")
4157
4158     @staticmethod
4159     def _return_match(LL: List[Leaf]) -> Optional[int]:
4160         """
4161         Returns:
4162             string_idx such that @LL[string_idx] is equal to our target (i.e.
4163             matched) string, if this line matches the return/yield statement
4164             requirements listed in the 'Requirements' section of this classes'
4165             docstring.
4166                 OR
4167             None, otherwise.
4168         """
4169         # If this line is apart of a return/yield statement and the first leaf
4170         # contains either the "return" or "yield" keywords...
4171         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4172             0
4173         ].value in ["return", "yield"]:
4174             is_valid_index = is_valid_index_factory(LL)
4175
4176             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4177             # The next visible leaf MUST contain a string...
4178             if is_valid_index(idx) and LL[idx].type == token.STRING:
4179                 return idx
4180
4181         return None
4182
4183     @staticmethod
4184     def _else_match(LL: List[Leaf]) -> Optional[int]:
4185         """
4186         Returns:
4187             string_idx such that @LL[string_idx] is equal to our target (i.e.
4188             matched) string, if this line matches the ternary expression
4189             requirements listed in the 'Requirements' section of this classes'
4190             docstring.
4191                 OR
4192             None, otherwise.
4193         """
4194         # If this line is apart of a ternary expression and the first leaf
4195         # contains the "else" keyword...
4196         if (
4197             parent_type(LL[0]) == syms.test
4198             and LL[0].type == token.NAME
4199             and LL[0].value == "else"
4200         ):
4201             is_valid_index = is_valid_index_factory(LL)
4202
4203             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4204             # The next visible leaf MUST contain a string...
4205             if is_valid_index(idx) and LL[idx].type == token.STRING:
4206                 return idx
4207
4208         return None
4209
4210     @staticmethod
4211     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4212         """
4213         Returns:
4214             string_idx such that @LL[string_idx] is equal to our target (i.e.
4215             matched) string, if this line matches the assert statement
4216             requirements listed in the 'Requirements' section of this classes'
4217             docstring.
4218                 OR
4219             None, otherwise.
4220         """
4221         # If this line is apart of an assert statement and the first leaf
4222         # contains the "assert" keyword...
4223         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4224             is_valid_index = is_valid_index_factory(LL)
4225
4226             for (i, leaf) in enumerate(LL):
4227                 # We MUST find a comma...
4228                 if leaf.type == token.COMMA:
4229                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4230
4231                     # That comma MUST be followed by a string...
4232                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4233                         string_idx = idx
4234
4235                         # Skip the string trailer, if one exists.
4236                         string_parser = StringParser()
4237                         idx = string_parser.parse(LL, string_idx)
4238
4239                         # But no more leaves are allowed...
4240                         if not is_valid_index(idx):
4241                             return string_idx
4242
4243         return None
4244
4245     @staticmethod
4246     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4247         """
4248         Returns:
4249             string_idx such that @LL[string_idx] is equal to our target (i.e.
4250             matched) string, if this line matches the assignment statement
4251             requirements listed in the 'Requirements' section of this classes'
4252             docstring.
4253                 OR
4254             None, otherwise.
4255         """
4256         # If this line is apart of an expression statement or is a function
4257         # argument AND the first leaf contains a variable name...
4258         if (
4259             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4260             and LL[0].type == token.NAME
4261         ):
4262             is_valid_index = is_valid_index_factory(LL)
4263
4264             for (i, leaf) in enumerate(LL):
4265                 # We MUST find either an '=' or '+=' symbol...
4266                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4267                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4268
4269                     # That symbol MUST be followed by a string...
4270                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4271                         string_idx = idx
4272
4273                         # Skip the string trailer, if one exists.
4274                         string_parser = StringParser()
4275                         idx = string_parser.parse(LL, string_idx)
4276
4277                         # The next leaf MAY be a comma iff this line is apart
4278                         # of a function argument...
4279                         if (
4280                             parent_type(LL[0]) == syms.argument
4281                             and is_valid_index(idx)
4282                             and LL[idx].type == token.COMMA
4283                         ):
4284                             idx += 1
4285
4286                         # But no more leaves are allowed...
4287                         if not is_valid_index(idx):
4288                             return string_idx
4289
4290         return None
4291
4292     @staticmethod
4293     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4294         """
4295         Returns:
4296             string_idx such that @LL[string_idx] is equal to our target (i.e.
4297             matched) string, if this line matches the dictionary key assignment
4298             statement requirements listed in the 'Requirements' section of this
4299             classes' docstring.
4300                 OR
4301             None, otherwise.
4302         """
4303         # If this line is apart of a dictionary key assignment...
4304         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4305             is_valid_index = is_valid_index_factory(LL)
4306
4307             for (i, leaf) in enumerate(LL):
4308                 # We MUST find a colon...
4309                 if leaf.type == token.COLON:
4310                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4311
4312                     # That colon MUST be followed by a string...
4313                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4314                         string_idx = idx
4315
4316                         # Skip the string trailer, if one exists.
4317                         string_parser = StringParser()
4318                         idx = string_parser.parse(LL, string_idx)
4319
4320                         # That string MAY be followed by a comma...
4321                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4322                             idx += 1
4323
4324                         # But no more leaves are allowed...
4325                         if not is_valid_index(idx):
4326                             return string_idx
4327
4328         return None
4329
4330     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4331         LL = line.leaves
4332
4333         is_valid_index = is_valid_index_factory(LL)
4334         insert_str_child = insert_str_child_factory(LL[string_idx])
4335
4336         comma_idx = -1
4337         ends_with_comma = False
4338         if LL[comma_idx].type == token.COMMA:
4339             ends_with_comma = True
4340
4341         leaves_to_steal_comments_from = [LL[string_idx]]
4342         if ends_with_comma:
4343             leaves_to_steal_comments_from.append(LL[comma_idx])
4344
4345         # --- First Line
4346         first_line = line.clone()
4347         left_leaves = LL[:string_idx]
4348
4349         # We have to remember to account for (possibly invisible) LPAR and RPAR
4350         # leaves that already wrapped the target string. If these leaves do
4351         # exist, we will replace them with our own LPAR and RPAR leaves.
4352         old_parens_exist = False
4353         if left_leaves and left_leaves[-1].type == token.LPAR:
4354             old_parens_exist = True
4355             leaves_to_steal_comments_from.append(left_leaves[-1])
4356             left_leaves.pop()
4357
4358         append_leaves(first_line, line, left_leaves)
4359
4360         lpar_leaf = Leaf(token.LPAR, "(")
4361         if old_parens_exist:
4362             replace_child(LL[string_idx - 1], lpar_leaf)
4363         else:
4364             insert_str_child(lpar_leaf)
4365         first_line.append(lpar_leaf)
4366
4367         # We throw inline comments that were originally to the right of the
4368         # target string to the top line. They will now be shown to the right of
4369         # the LPAR.
4370         for leaf in leaves_to_steal_comments_from:
4371             for comment_leaf in line.comments_after(leaf):
4372                 first_line.append(comment_leaf, preformatted=True)
4373
4374         yield Ok(first_line)
4375
4376         # --- Middle (String) Line
4377         # We only need to yield one (possibly too long) string line, since the
4378         # `StringSplitter` will break it down further if necessary.
4379         string_value = LL[string_idx].value
4380         string_line = Line(
4381             mode=line.mode,
4382             depth=line.depth + 1,
4383             inside_brackets=True,
4384             should_explode=line.should_explode,
4385         )
4386         string_leaf = Leaf(token.STRING, string_value)
4387         insert_str_child(string_leaf)
4388         string_line.append(string_leaf)
4389
4390         old_rpar_leaf = None
4391         if is_valid_index(string_idx + 1):
4392             right_leaves = LL[string_idx + 1 :]
4393             if ends_with_comma:
4394                 right_leaves.pop()
4395
4396             if old_parens_exist:
4397                 assert (
4398                     right_leaves and right_leaves[-1].type == token.RPAR
4399                 ), "Apparently, old parentheses do NOT exist?!"
4400                 old_rpar_leaf = right_leaves.pop()
4401
4402             append_leaves(string_line, line, right_leaves)
4403
4404         yield Ok(string_line)
4405
4406         # --- Last Line
4407         last_line = line.clone()
4408         last_line.bracket_tracker = first_line.bracket_tracker
4409
4410         new_rpar_leaf = Leaf(token.RPAR, ")")
4411         if old_rpar_leaf is not None:
4412             replace_child(old_rpar_leaf, new_rpar_leaf)
4413         else:
4414             insert_str_child(new_rpar_leaf)
4415         last_line.append(new_rpar_leaf)
4416
4417         # If the target string ended with a comma, we place this comma to the
4418         # right of the RPAR on the last line.
4419         if ends_with_comma:
4420             comma_leaf = Leaf(token.COMMA, ",")
4421             replace_child(LL[comma_idx], comma_leaf)
4422             last_line.append(comma_leaf)
4423
4424         yield Ok(last_line)
4425
4426
4427 class StringParser:
4428     """
4429     A state machine that aids in parsing a string's "trailer", which can be
4430     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4431     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4432     varY)`).
4433
4434     NOTE: A new StringParser object MUST be instantiated for each string
4435     trailer we need to parse.
4436
4437     Examples:
4438         We shall assume that `line` equals the `Line` object that corresponds
4439         to the following line of python code:
4440         ```
4441         x = "Some {}.".format("String") + some_other_string
4442         ```
4443
4444         Furthermore, we will assume that `string_idx` is some index such that:
4445         ```
4446         assert line.leaves[string_idx].value == "Some {}."
4447         ```
4448
4449         The following code snippet then holds:
4450         ```
4451         string_parser = StringParser()
4452         idx = string_parser.parse(line.leaves, string_idx)
4453         assert line.leaves[idx].type == token.PLUS
4454         ```
4455     """
4456
4457     DEFAULT_TOKEN = -1
4458
4459     # String Parser States
4460     START = 1
4461     DOT = 2
4462     NAME = 3
4463     PERCENT = 4
4464     SINGLE_FMT_ARG = 5
4465     LPAR = 6
4466     RPAR = 7
4467     DONE = 8
4468
4469     # Lookup Table for Next State
4470     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4471         # A string trailer may start with '.' OR '%'.
4472         (START, token.DOT): DOT,
4473         (START, token.PERCENT): PERCENT,
4474         (START, DEFAULT_TOKEN): DONE,
4475         # A '.' MUST be followed by an attribute or method name.
4476         (DOT, token.NAME): NAME,
4477         # A method name MUST be followed by an '(', whereas an attribute name
4478         # is the last symbol in the string trailer.
4479         (NAME, token.LPAR): LPAR,
4480         (NAME, DEFAULT_TOKEN): DONE,
4481         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4482         # string or variable name).
4483         (PERCENT, token.LPAR): LPAR,
4484         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4485         # If a '%' symbol is followed by a single argument, that argument is
4486         # the last leaf in the string trailer.
4487         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4488         # If present, a ')' symbol is the last symbol in a string trailer.
4489         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4490         # since they are treated as a special case by the parsing logic in this
4491         # classes' implementation.)
4492         (RPAR, DEFAULT_TOKEN): DONE,
4493     }
4494
4495     def __init__(self) -> None:
4496         self._state = self.START
4497         self._unmatched_lpars = 0
4498
4499     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4500         """
4501         Pre-conditions:
4502             * @leaves[@string_idx].type == token.STRING
4503
4504         Returns:
4505             The index directly after the last leaf which is apart of the string
4506             trailer, if a "trailer" exists.
4507                 OR
4508             @string_idx + 1, if no string "trailer" exists.
4509         """
4510         assert leaves[string_idx].type == token.STRING
4511
4512         idx = string_idx + 1
4513         while idx < len(leaves) and self._next_state(leaves[idx]):
4514             idx += 1
4515         return idx
4516
4517     def _next_state(self, leaf: Leaf) -> bool:
4518         """
4519         Pre-conditions:
4520             * On the first call to this function, @leaf MUST be the leaf that
4521             was directly after the string leaf in question (e.g. if our target
4522             string is `line.leaves[i]` then the first call to this method must
4523             be `line.leaves[i + 1]`).
4524             * On the next call to this function, the leaf parameter passed in
4525             MUST be the leaf directly following @leaf.
4526
4527         Returns:
4528             True iff @leaf is apart of the string's trailer.
4529         """
4530         # We ignore empty LPAR or RPAR leaves.
4531         if is_empty_par(leaf):
4532             return True
4533
4534         next_token = leaf.type
4535         if next_token == token.LPAR:
4536             self._unmatched_lpars += 1
4537
4538         current_state = self._state
4539
4540         # The LPAR parser state is a special case. We will return True until we
4541         # find the matching RPAR token.
4542         if current_state == self.LPAR:
4543             if next_token == token.RPAR:
4544                 self._unmatched_lpars -= 1
4545                 if self._unmatched_lpars == 0:
4546                     self._state = self.RPAR
4547         # Otherwise, we use a lookup table to determine the next state.
4548         else:
4549             # If the lookup table matches the current state to the next
4550             # token, we use the lookup table.
4551             if (current_state, next_token) in self._goto:
4552                 self._state = self._goto[current_state, next_token]
4553             else:
4554                 # Otherwise, we check if a the current state was assigned a
4555                 # default.
4556                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4557                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4558                 # If no default has been assigned, then this parser has a logic
4559                 # error.
4560                 else:
4561                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4562
4563             if self._state == self.DONE:
4564                 return False
4565
4566         return True
4567
4568
4569 def TErr(err_msg: str) -> Err[CannotTransform]:
4570     """(T)ransform Err
4571
4572     Convenience function used when working with the TResult type.
4573     """
4574     cant_transform = CannotTransform(err_msg)
4575     return Err(cant_transform)
4576
4577
4578 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4579     """
4580     Returns:
4581         True iff one of the comments in @comment_list is a pragma used by one
4582         of the more common static analysis tools for python (e.g. mypy, flake8,
4583         pylint).
4584     """
4585     for comment in comment_list:
4586         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4587             return True
4588
4589     return False
4590
4591
4592 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4593     """
4594     Factory for a convenience function that is used to orphan @string_leaf
4595     and then insert multiple new leaves into the same part of the node
4596     structure that @string_leaf had originally occupied.
4597
4598     Examples:
4599         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4600         string_leaf.parent`. Assume the node `N` has the following
4601         original structure:
4602
4603         Node(
4604             expr_stmt, [
4605                 Leaf(NAME, 'x'),
4606                 Leaf(EQUAL, '='),
4607                 Leaf(STRING, '"foo"'),
4608             ]
4609         )
4610
4611         We then run the code snippet shown below.
4612         ```
4613         insert_str_child = insert_str_child_factory(string_leaf)
4614
4615         lpar = Leaf(token.LPAR, '(')
4616         insert_str_child(lpar)
4617
4618         bar = Leaf(token.STRING, '"bar"')
4619         insert_str_child(bar)
4620
4621         rpar = Leaf(token.RPAR, ')')
4622         insert_str_child(rpar)
4623         ```
4624
4625         After which point, it follows that `string_leaf.parent is None` and
4626         the node `N` now has the following structure:
4627
4628         Node(
4629             expr_stmt, [
4630                 Leaf(NAME, 'x'),
4631                 Leaf(EQUAL, '='),
4632                 Leaf(LPAR, '('),
4633                 Leaf(STRING, '"bar"'),
4634                 Leaf(RPAR, ')'),
4635             ]
4636         )
4637     """
4638     string_parent = string_leaf.parent
4639     string_child_idx = string_leaf.remove()
4640
4641     def insert_str_child(child: LN) -> None:
4642         nonlocal string_child_idx
4643
4644         assert string_parent is not None
4645         assert string_child_idx is not None
4646
4647         string_parent.insert_child(string_child_idx, child)
4648         string_child_idx += 1
4649
4650     return insert_str_child
4651
4652
4653 def has_triple_quotes(string: str) -> bool:
4654     """
4655     Returns:
4656         True iff @string starts with three quotation characters.
4657     """
4658     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4659     return raw_string[:3] in {'"""', "'''"}
4660
4661
4662 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4663     """
4664     Returns:
4665         @node.parent.type, if @node is not None and has a parent.
4666             OR
4667         None, otherwise.
4668     """
4669     if node is None or node.parent is None:
4670         return None
4671
4672     return node.parent.type
4673
4674
4675 def is_empty_par(leaf: Leaf) -> bool:
4676     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4677
4678
4679 def is_empty_lpar(leaf: Leaf) -> bool:
4680     return leaf.type == token.LPAR and leaf.value == ""
4681
4682
4683 def is_empty_rpar(leaf: Leaf) -> bool:
4684     return leaf.type == token.RPAR and leaf.value == ""
4685
4686
4687 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4688     """
4689     Examples:
4690         ```
4691         my_list = [1, 2, 3]
4692
4693         is_valid_index = is_valid_index_factory(my_list)
4694
4695         assert is_valid_index(0)
4696         assert is_valid_index(2)
4697
4698         assert not is_valid_index(3)
4699         assert not is_valid_index(-1)
4700         ```
4701     """
4702
4703     def is_valid_index(idx: int) -> bool:
4704         """
4705         Returns:
4706             True iff @idx is positive AND seq[@idx] does NOT raise an
4707             IndexError.
4708         """
4709         return 0 <= idx < len(seq)
4710
4711     return is_valid_index
4712
4713
4714 def line_to_string(line: Line) -> str:
4715     """Returns the string representation of @line.
4716
4717     WARNING: This is known to be computationally expensive.
4718     """
4719     return str(line).strip("\n")
4720
4721
4722 def append_leaves(
4723     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4724 ) -> None:
4725     """
4726     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4727     underlying Node structure where appropriate.
4728
4729     All of the leaves in @leaves are duplicated. The duplicates are then
4730     appended to @new_line and used to replace their originals in the underlying
4731     Node structure. Any comments attached to the old leaves are reattached to
4732     the new leaves.
4733
4734     Pre-conditions:
4735         set(@leaves) is a subset of set(@old_line.leaves).
4736     """
4737     for old_leaf in leaves:
4738         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4739         replace_child(old_leaf, new_leaf)
4740         new_line.append(new_leaf, preformatted=preformatted)
4741
4742         for comment_leaf in old_line.comments_after(old_leaf):
4743             new_line.append(comment_leaf, preformatted=True)
4744
4745
4746 def replace_child(old_child: LN, new_child: LN) -> None:
4747     """
4748     Side Effects:
4749         * If @old_child.parent is set, replace @old_child with @new_child in
4750         @old_child's underlying Node structure.
4751             OR
4752         * Otherwise, this function does nothing.
4753     """
4754     parent = old_child.parent
4755     if not parent:
4756         return
4757
4758     child_idx = old_child.remove()
4759     if child_idx is not None:
4760         parent.insert_child(child_idx, new_child)
4761
4762
4763 def get_string_prefix(string: str) -> str:
4764     """
4765     Pre-conditions:
4766         * assert_is_leaf_string(@string)
4767
4768     Returns:
4769         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4770     """
4771     assert_is_leaf_string(string)
4772
4773     prefix = ""
4774     prefix_idx = 0
4775     while string[prefix_idx] in STRING_PREFIX_CHARS:
4776         prefix += string[prefix_idx].lower()
4777         prefix_idx += 1
4778
4779     return prefix
4780
4781
4782 def assert_is_leaf_string(string: str) -> None:
4783     """
4784     Checks the pre-condition that @string has the format that you would expect
4785     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4786     token.STRING`. A more precise description of the pre-conditions that are
4787     checked are listed below.
4788
4789     Pre-conditions:
4790         * @string starts with either ', ", <prefix>', or <prefix>" where
4791         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4792         * @string ends with a quote character (' or ").
4793
4794     Raises:
4795         AssertionError(...) if the pre-conditions listed above are not
4796         satisfied.
4797     """
4798     dquote_idx = string.find('"')
4799     squote_idx = string.find("'")
4800     if -1 in [dquote_idx, squote_idx]:
4801         quote_idx = max(dquote_idx, squote_idx)
4802     else:
4803         quote_idx = min(squote_idx, dquote_idx)
4804
4805     assert (
4806         0 <= quote_idx < len(string) - 1
4807     ), f"{string!r} is missing a starting quote character (' or \")."
4808     assert string[-1] in (
4809         "'",
4810         '"',
4811     ), f"{string!r} is missing an ending quote character (' or \")."
4812     assert set(string[:quote_idx]).issubset(
4813         set(STRING_PREFIX_CHARS)
4814     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4815
4816
4817 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4818     """Split line into many lines, starting with the first matching bracket pair.
4819
4820     Note: this usually looks weird, only use this for function definitions.
4821     Prefer RHS otherwise.  This is why this function is not symmetrical with
4822     :func:`right_hand_split` which also handles optional parentheses.
4823     """
4824     tail_leaves: List[Leaf] = []
4825     body_leaves: List[Leaf] = []
4826     head_leaves: List[Leaf] = []
4827     current_leaves = head_leaves
4828     matching_bracket: Optional[Leaf] = None
4829     for leaf in line.leaves:
4830         if (
4831             current_leaves is body_leaves
4832             and leaf.type in CLOSING_BRACKETS
4833             and leaf.opening_bracket is matching_bracket
4834         ):
4835             current_leaves = tail_leaves if body_leaves else head_leaves
4836         current_leaves.append(leaf)
4837         if current_leaves is head_leaves:
4838             if leaf.type in OPENING_BRACKETS:
4839                 matching_bracket = leaf
4840                 current_leaves = body_leaves
4841     if not matching_bracket:
4842         raise CannotSplit("No brackets found")
4843
4844     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4845     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4846     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4847     bracket_split_succeeded_or_raise(head, body, tail)
4848     for result in (head, body, tail):
4849         if result:
4850             yield result
4851
4852
4853 def right_hand_split(
4854     line: Line,
4855     line_length: int,
4856     features: Collection[Feature] = (),
4857     omit: Collection[LeafID] = (),
4858 ) -> Iterator[Line]:
4859     """Split line into many lines, starting with the last matching bracket pair.
4860
4861     If the split was by optional parentheses, attempt splitting without them, too.
4862     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4863     this split.
4864
4865     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4866     """
4867     tail_leaves: List[Leaf] = []
4868     body_leaves: List[Leaf] = []
4869     head_leaves: List[Leaf] = []
4870     current_leaves = tail_leaves
4871     opening_bracket: Optional[Leaf] = None
4872     closing_bracket: Optional[Leaf] = None
4873     for leaf in reversed(line.leaves):
4874         if current_leaves is body_leaves:
4875             if leaf is opening_bracket:
4876                 current_leaves = head_leaves if body_leaves else tail_leaves
4877         current_leaves.append(leaf)
4878         if current_leaves is tail_leaves:
4879             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4880                 opening_bracket = leaf.opening_bracket
4881                 closing_bracket = leaf
4882                 current_leaves = body_leaves
4883     if not (opening_bracket and closing_bracket and head_leaves):
4884         # If there is no opening or closing_bracket that means the split failed and
4885         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4886         # the matching `opening_bracket` wasn't available on `line` anymore.
4887         raise CannotSplit("No brackets found")
4888
4889     tail_leaves.reverse()
4890     body_leaves.reverse()
4891     head_leaves.reverse()
4892     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4893     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4894     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4895     bracket_split_succeeded_or_raise(head, body, tail)
4896     if (
4897         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4898         # the opening bracket is an optional paren
4899         and opening_bracket.type == token.LPAR
4900         and not opening_bracket.value
4901         # the closing bracket is an optional paren
4902         and closing_bracket.type == token.RPAR
4903         and not closing_bracket.value
4904         # it's not an import (optional parens are the only thing we can split on
4905         # in this case; attempting a split without them is a waste of time)
4906         and not line.is_import
4907         # there are no standalone comments in the body
4908         and not body.contains_standalone_comments(0)
4909         # and we can actually remove the parens
4910         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4911     ):
4912         omit = {id(closing_bracket), *omit}
4913         try:
4914             yield from right_hand_split(line, line_length, features=features, omit=omit)
4915             return
4916
4917         except CannotSplit:
4918             if not (
4919                 can_be_split(body)
4920                 or is_line_short_enough(body, line_length=line_length)
4921             ):
4922                 raise CannotSplit(
4923                     "Splitting failed, body is still too long and can't be split."
4924                 )
4925
4926             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4927                 raise CannotSplit(
4928                     "The current optional pair of parentheses is bound to fail to"
4929                     " satisfy the splitting algorithm because the head or the tail"
4930                     " contains multiline strings which by definition never fit one"
4931                     " line."
4932                 )
4933
4934     ensure_visible(opening_bracket)
4935     ensure_visible(closing_bracket)
4936     for result in (head, body, tail):
4937         if result:
4938             yield result
4939
4940
4941 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4942     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4943
4944     Do nothing otherwise.
4945
4946     A left- or right-hand split is based on a pair of brackets. Content before
4947     (and including) the opening bracket is left on one line, content inside the
4948     brackets is put on a separate line, and finally content starting with and
4949     following the closing bracket is put on a separate line.
4950
4951     Those are called `head`, `body`, and `tail`, respectively. If the split
4952     produced the same line (all content in `head`) or ended up with an empty `body`
4953     and the `tail` is just the closing bracket, then it's considered failed.
4954     """
4955     tail_len = len(str(tail).strip())
4956     if not body:
4957         if tail_len == 0:
4958             raise CannotSplit("Splitting brackets produced the same line")
4959
4960         elif tail_len < 3:
4961             raise CannotSplit(
4962                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4963                 " not worth it"
4964             )
4965
4966
4967 def bracket_split_build_line(
4968     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4969 ) -> Line:
4970     """Return a new line with given `leaves` and respective comments from `original`.
4971
4972     If `is_body` is True, the result line is one-indented inside brackets and as such
4973     has its first leaf's prefix normalized and a trailing comma added when expected.
4974     """
4975     result = Line(mode=original.mode, depth=original.depth)
4976     if is_body:
4977         result.inside_brackets = True
4978         result.depth += 1
4979         if leaves:
4980             # Since body is a new indent level, remove spurious leading whitespace.
4981             normalize_prefix(leaves[0], inside_brackets=True)
4982             # Ensure a trailing comma for imports and standalone function arguments, but
4983             # be careful not to add one after any comments or within type annotations.
4984             no_commas = (
4985                 original.is_def
4986                 and opening_bracket.value == "("
4987                 and not any(leaf.type == token.COMMA for leaf in leaves)
4988             )
4989
4990             if original.is_import or no_commas:
4991                 for i in range(len(leaves) - 1, -1, -1):
4992                     if leaves[i].type == STANDALONE_COMMENT:
4993                         continue
4994
4995                     if leaves[i].type != token.COMMA:
4996                         new_comma = Leaf(token.COMMA, ",")
4997                         leaves.insert(i + 1, new_comma)
4998                     break
4999
5000     # Populate the line
5001     for leaf in leaves:
5002         result.append(leaf, preformatted=True)
5003         for comment_after in original.comments_after(leaf):
5004             result.append(comment_after, preformatted=True)
5005     if is_body and should_split_body_explode(result, opening_bracket):
5006         result.should_explode = True
5007     return result
5008
5009
5010 def dont_increase_indentation(split_func: Transformer) -> Transformer:
5011     """Normalize prefix of the first leaf in every line returned by `split_func`.
5012
5013     This is a decorator over relevant split functions.
5014     """
5015
5016     @wraps(split_func)
5017     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5018         for line in split_func(line, features):
5019             normalize_prefix(line.leaves[0], inside_brackets=True)
5020             yield line
5021
5022     return split_wrapper
5023
5024
5025 @dont_increase_indentation
5026 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5027     """Split according to delimiters of the highest priority.
5028
5029     If the appropriate Features are given, the split will add trailing commas
5030     also in function signatures and calls that contain `*` and `**`.
5031     """
5032     try:
5033         last_leaf = line.leaves[-1]
5034     except IndexError:
5035         raise CannotSplit("Line empty")
5036
5037     bt = line.bracket_tracker
5038     try:
5039         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
5040     except ValueError:
5041         raise CannotSplit("No delimiters found")
5042
5043     if delimiter_priority == DOT_PRIORITY:
5044         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
5045             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
5046
5047     current_line = Line(
5048         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5049     )
5050     lowest_depth = sys.maxsize
5051     trailing_comma_safe = True
5052
5053     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5054         """Append `leaf` to current line or to new line if appending impossible."""
5055         nonlocal current_line
5056         try:
5057             current_line.append_safe(leaf, preformatted=True)
5058         except ValueError:
5059             yield current_line
5060
5061             current_line = Line(
5062                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5063             )
5064             current_line.append(leaf)
5065
5066     for leaf in line.leaves:
5067         yield from append_to_line(leaf)
5068
5069         for comment_after in line.comments_after(leaf):
5070             yield from append_to_line(comment_after)
5071
5072         lowest_depth = min(lowest_depth, leaf.bracket_depth)
5073         if leaf.bracket_depth == lowest_depth:
5074             if is_vararg(leaf, within={syms.typedargslist}):
5075                 trailing_comma_safe = (
5076                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
5077                 )
5078             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
5079                 trailing_comma_safe = (
5080                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
5081                 )
5082
5083         leaf_priority = bt.delimiters.get(id(leaf))
5084         if leaf_priority == delimiter_priority:
5085             yield current_line
5086
5087             current_line = Line(
5088                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5089             )
5090     if current_line:
5091         if (
5092             trailing_comma_safe
5093             and delimiter_priority == COMMA_PRIORITY
5094             and current_line.leaves[-1].type != token.COMMA
5095             and current_line.leaves[-1].type != STANDALONE_COMMENT
5096         ):
5097             new_comma = Leaf(token.COMMA, ",")
5098             current_line.append(new_comma)
5099         yield current_line
5100
5101
5102 @dont_increase_indentation
5103 def standalone_comment_split(
5104     line: Line, features: Collection[Feature] = ()
5105 ) -> Iterator[Line]:
5106     """Split standalone comments from the rest of the line."""
5107     if not line.contains_standalone_comments(0):
5108         raise CannotSplit("Line does not have any standalone comments")
5109
5110     current_line = Line(
5111         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5112     )
5113
5114     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5115         """Append `leaf` to current line or to new line if appending impossible."""
5116         nonlocal current_line
5117         try:
5118             current_line.append_safe(leaf, preformatted=True)
5119         except ValueError:
5120             yield current_line
5121
5122             current_line = Line(
5123                 line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5124             )
5125             current_line.append(leaf)
5126
5127     for leaf in line.leaves:
5128         yield from append_to_line(leaf)
5129
5130         for comment_after in line.comments_after(leaf):
5131             yield from append_to_line(comment_after)
5132
5133     if current_line:
5134         yield current_line
5135
5136
5137 def is_import(leaf: Leaf) -> bool:
5138     """Return True if the given leaf starts an import statement."""
5139     p = leaf.parent
5140     t = leaf.type
5141     v = leaf.value
5142     return bool(
5143         t == token.NAME
5144         and (
5145             (v == "import" and p and p.type == syms.import_name)
5146             or (v == "from" and p and p.type == syms.import_from)
5147         )
5148     )
5149
5150
5151 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5152     """Return True if the given leaf is a special comment.
5153     Only returns true for type comments for now."""
5154     t = leaf.type
5155     v = leaf.value
5156     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5157
5158
5159 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5160     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5161     else.
5162
5163     Note: don't use backslashes for formatting or you'll lose your voting rights.
5164     """
5165     if not inside_brackets:
5166         spl = leaf.prefix.split("#")
5167         if "\\" not in spl[0]:
5168             nl_count = spl[-1].count("\n")
5169             if len(spl) > 1:
5170                 nl_count -= 1
5171             leaf.prefix = "\n" * nl_count
5172             return
5173
5174     leaf.prefix = ""
5175
5176
5177 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5178     """Make all string prefixes lowercase.
5179
5180     If remove_u_prefix is given, also removes any u prefix from the string.
5181
5182     Note: Mutates its argument.
5183     """
5184     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5185     assert match is not None, f"failed to match string {leaf.value!r}"
5186     orig_prefix = match.group(1)
5187     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5188     if remove_u_prefix:
5189         new_prefix = new_prefix.replace("u", "")
5190     leaf.value = f"{new_prefix}{match.group(2)}"
5191
5192
5193 def normalize_string_quotes(leaf: Leaf) -> None:
5194     """Prefer double quotes but only if it doesn't cause more escaping.
5195
5196     Adds or removes backslashes as appropriate. Doesn't parse and fix
5197     strings nested in f-strings (yet).
5198
5199     Note: Mutates its argument.
5200     """
5201     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5202     if value[:3] == '"""':
5203         return
5204
5205     elif value[:3] == "'''":
5206         orig_quote = "'''"
5207         new_quote = '"""'
5208     elif value[0] == '"':
5209         orig_quote = '"'
5210         new_quote = "'"
5211     else:
5212         orig_quote = "'"
5213         new_quote = '"'
5214     first_quote_pos = leaf.value.find(orig_quote)
5215     if first_quote_pos == -1:
5216         return  # There's an internal error
5217
5218     prefix = leaf.value[:first_quote_pos]
5219     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5220     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5221     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5222     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5223     if "r" in prefix.casefold():
5224         if unescaped_new_quote.search(body):
5225             # There's at least one unescaped new_quote in this raw string
5226             # so converting is impossible
5227             return
5228
5229         # Do not introduce or remove backslashes in raw strings
5230         new_body = body
5231     else:
5232         # remove unnecessary escapes
5233         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5234         if body != new_body:
5235             # Consider the string without unnecessary escapes as the original
5236             body = new_body
5237             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5238         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5239         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5240     if "f" in prefix.casefold():
5241         matches = re.findall(
5242             r"""
5243             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5244                 ([^{].*?)  # contents of the brackets except if begins with {{
5245             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5246             """,
5247             new_body,
5248             re.VERBOSE,
5249         )
5250         for m in matches:
5251             if "\\" in str(m):
5252                 # Do not introduce backslashes in interpolated expressions
5253                 return
5254
5255     if new_quote == '"""' and new_body[-1:] == '"':
5256         # edge case:
5257         new_body = new_body[:-1] + '\\"'
5258     orig_escape_count = body.count("\\")
5259     new_escape_count = new_body.count("\\")
5260     if new_escape_count > orig_escape_count:
5261         return  # Do not introduce more escaping
5262
5263     if new_escape_count == orig_escape_count and orig_quote == '"':
5264         return  # Prefer double quotes
5265
5266     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5267
5268
5269 def normalize_numeric_literal(leaf: Leaf) -> None:
5270     """Normalizes numeric (float, int, and complex) literals.
5271
5272     All letters used in the representation are normalized to lowercase (except
5273     in Python 2 long literals).
5274     """
5275     text = leaf.value.lower()
5276     if text.startswith(("0o", "0b")):
5277         # Leave octal and binary literals alone.
5278         pass
5279     elif text.startswith("0x"):
5280         text = format_hex(text)
5281     elif "e" in text:
5282         text = format_scientific_notation(text)
5283     elif text.endswith(("j", "l")):
5284         text = format_long_or_complex_number(text)
5285     else:
5286         text = format_float_or_int_string(text)
5287     leaf.value = text
5288
5289
5290 def format_hex(text: str) -> str:
5291     """
5292     Formats a hexadecimal string like "0x12b3"
5293
5294     Uses lowercase because of similarity between "B" and "8", which
5295     can cause security issues.
5296     see: https://github.com/psf/black/issues/1692
5297     """
5298
5299     before, after = text[:2], text[2:]
5300     return f"{before}{after.lower()}"
5301
5302
5303 def format_scientific_notation(text: str) -> str:
5304     """Formats a numeric string utilizing scentific notation"""
5305     before, after = text.split("e")
5306     sign = ""
5307     if after.startswith("-"):
5308         after = after[1:]
5309         sign = "-"
5310     elif after.startswith("+"):
5311         after = after[1:]
5312     before = format_float_or_int_string(before)
5313     return f"{before}e{sign}{after}"
5314
5315
5316 def format_long_or_complex_number(text: str) -> str:
5317     """Formats a long or complex string like `10L` or `10j`"""
5318     number = text[:-1]
5319     suffix = text[-1]
5320     # Capitalize in "2L" because "l" looks too similar to "1".
5321     if suffix == "l":
5322         suffix = "L"
5323     return f"{format_float_or_int_string(number)}{suffix}"
5324
5325
5326 def format_float_or_int_string(text: str) -> str:
5327     """Formats a float string like "1.0"."""
5328     if "." not in text:
5329         return text
5330
5331     before, after = text.split(".")
5332     return f"{before or 0}.{after or 0}"
5333
5334
5335 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5336     """Make existing optional parentheses invisible or create new ones.
5337
5338     `parens_after` is a set of string leaf values immediately after which parens
5339     should be put.
5340
5341     Standardizes on visible parentheses for single-element tuples, and keeps
5342     existing visible parentheses for other tuples and generator expressions.
5343     """
5344     for pc in list_comments(node.prefix, is_endmarker=False):
5345         if pc.value in FMT_OFF:
5346             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5347             return
5348     check_lpar = False
5349     for index, child in enumerate(list(node.children)):
5350         # Fixes a bug where invisible parens are not properly stripped from
5351         # assignment statements that contain type annotations.
5352         if isinstance(child, Node) and child.type == syms.annassign:
5353             normalize_invisible_parens(child, parens_after=parens_after)
5354
5355         # Add parentheses around long tuple unpacking in assignments.
5356         if (
5357             index == 0
5358             and isinstance(child, Node)
5359             and child.type == syms.testlist_star_expr
5360         ):
5361             check_lpar = True
5362
5363         if check_lpar:
5364             if is_walrus_assignment(child):
5365                 pass
5366
5367             elif child.type == syms.atom:
5368                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5369                     wrap_in_parentheses(node, child, visible=False)
5370             elif is_one_tuple(child):
5371                 wrap_in_parentheses(node, child, visible=True)
5372             elif node.type == syms.import_from:
5373                 # "import from" nodes store parentheses directly as part of
5374                 # the statement
5375                 if child.type == token.LPAR:
5376                     # make parentheses invisible
5377                     child.value = ""  # type: ignore
5378                     node.children[-1].value = ""  # type: ignore
5379                 elif child.type != token.STAR:
5380                     # insert invisible parentheses
5381                     node.insert_child(index, Leaf(token.LPAR, ""))
5382                     node.append_child(Leaf(token.RPAR, ""))
5383                 break
5384
5385             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5386                 wrap_in_parentheses(node, child, visible=False)
5387
5388         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5389
5390
5391 def normalize_fmt_off(node: Node) -> None:
5392     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5393     try_again = True
5394     while try_again:
5395         try_again = convert_one_fmt_off_pair(node)
5396
5397
5398 def convert_one_fmt_off_pair(node: Node) -> bool:
5399     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5400
5401     Returns True if a pair was converted.
5402     """
5403     for leaf in node.leaves():
5404         previous_consumed = 0
5405         for comment in list_comments(leaf.prefix, is_endmarker=False):
5406             if comment.value in FMT_OFF:
5407                 # We only want standalone comments. If there's no previous leaf or
5408                 # the previous leaf is indentation, it's a standalone comment in
5409                 # disguise.
5410                 if comment.type != STANDALONE_COMMENT:
5411                     prev = preceding_leaf(leaf)
5412                     if prev and prev.type not in WHITESPACE:
5413                         continue
5414
5415                 ignored_nodes = list(generate_ignored_nodes(leaf))
5416                 if not ignored_nodes:
5417                     continue
5418
5419                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5420                 parent = first.parent
5421                 prefix = first.prefix
5422                 first.prefix = prefix[comment.consumed :]
5423                 hidden_value = (
5424                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5425                 )
5426                 if hidden_value.endswith("\n"):
5427                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5428                     # leaf (possibly followed by a DEDENT).
5429                     hidden_value = hidden_value[:-1]
5430                 first_idx: Optional[int] = None
5431                 for ignored in ignored_nodes:
5432                     index = ignored.remove()
5433                     if first_idx is None:
5434                         first_idx = index
5435                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5436                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5437                 parent.insert_child(
5438                     first_idx,
5439                     Leaf(
5440                         STANDALONE_COMMENT,
5441                         hidden_value,
5442                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5443                     ),
5444                 )
5445                 return True
5446
5447             previous_consumed = comment.consumed
5448
5449     return False
5450
5451
5452 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5453     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5454
5455     Stops at the end of the block.
5456     """
5457     container: Optional[LN] = container_of(leaf)
5458     while container is not None and container.type != token.ENDMARKER:
5459         if is_fmt_on(container):
5460             return
5461
5462         # fix for fmt: on in children
5463         if contains_fmt_on_at_column(container, leaf.column):
5464             for child in container.children:
5465                 if contains_fmt_on_at_column(child, leaf.column):
5466                     return
5467                 yield child
5468         else:
5469             yield container
5470             container = container.next_sibling
5471
5472
5473 def is_fmt_on(container: LN) -> bool:
5474     """Determine whether formatting is switched on within a container.
5475     Determined by whether the last `# fmt:` comment is `on` or `off`.
5476     """
5477     fmt_on = False
5478     for comment in list_comments(container.prefix, is_endmarker=False):
5479         if comment.value in FMT_ON:
5480             fmt_on = True
5481         elif comment.value in FMT_OFF:
5482             fmt_on = False
5483     return fmt_on
5484
5485
5486 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5487     """Determine if children at a given column have formatting switched on."""
5488     for child in container.children:
5489         if (
5490             isinstance(child, Node)
5491             and first_leaf_column(child) == column
5492             or isinstance(child, Leaf)
5493             and child.column == column
5494         ):
5495             if is_fmt_on(child):
5496                 return True
5497
5498     return False
5499
5500
5501 def first_leaf_column(node: Node) -> Optional[int]:
5502     """Returns the column of the first leaf child of a node."""
5503     for child in node.children:
5504         if isinstance(child, Leaf):
5505             return child.column
5506     return None
5507
5508
5509 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5510     """If it's safe, make the parens in the atom `node` invisible, recursively.
5511     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5512     as they are redundant.
5513
5514     Returns whether the node should itself be wrapped in invisible parentheses.
5515
5516     """
5517     if (
5518         node.type != syms.atom
5519         or is_empty_tuple(node)
5520         or is_one_tuple(node)
5521         or (is_yield(node) and parent.type != syms.expr_stmt)
5522         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5523     ):
5524         return False
5525
5526     first = node.children[0]
5527     last = node.children[-1]
5528     if first.type == token.LPAR and last.type == token.RPAR:
5529         middle = node.children[1]
5530         # make parentheses invisible
5531         first.value = ""  # type: ignore
5532         last.value = ""  # type: ignore
5533         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5534
5535         if is_atom_with_invisible_parens(middle):
5536             # Strip the invisible parens from `middle` by replacing
5537             # it with the child in-between the invisible parens
5538             middle.replace(middle.children[1])
5539
5540         return False
5541
5542     return True
5543
5544
5545 def is_atom_with_invisible_parens(node: LN) -> bool:
5546     """Given a `LN`, determines whether it's an atom `node` with invisible
5547     parens. Useful in dedupe-ing and normalizing parens.
5548     """
5549     if isinstance(node, Leaf) or node.type != syms.atom:
5550         return False
5551
5552     first, last = node.children[0], node.children[-1]
5553     return (
5554         isinstance(first, Leaf)
5555         and first.type == token.LPAR
5556         and first.value == ""
5557         and isinstance(last, Leaf)
5558         and last.type == token.RPAR
5559         and last.value == ""
5560     )
5561
5562
5563 def is_empty_tuple(node: LN) -> bool:
5564     """Return True if `node` holds an empty tuple."""
5565     return (
5566         node.type == syms.atom
5567         and len(node.children) == 2
5568         and node.children[0].type == token.LPAR
5569         and node.children[1].type == token.RPAR
5570     )
5571
5572
5573 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5574     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5575
5576     Parenthesis can be optional. Returns None otherwise"""
5577     if len(node.children) != 3:
5578         return None
5579
5580     lpar, wrapped, rpar = node.children
5581     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5582         return None
5583
5584     return wrapped
5585
5586
5587 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5588     """Wrap `child` in parentheses.
5589
5590     This replaces `child` with an atom holding the parentheses and the old
5591     child.  That requires moving the prefix.
5592
5593     If `visible` is False, the leaves will be valueless (and thus invisible).
5594     """
5595     lpar = Leaf(token.LPAR, "(" if visible else "")
5596     rpar = Leaf(token.RPAR, ")" if visible else "")
5597     prefix = child.prefix
5598     child.prefix = ""
5599     index = child.remove() or 0
5600     new_child = Node(syms.atom, [lpar, child, rpar])
5601     new_child.prefix = prefix
5602     parent.insert_child(index, new_child)
5603
5604
5605 def is_one_tuple(node: LN) -> bool:
5606     """Return True if `node` holds a tuple with one element, with or without parens."""
5607     if node.type == syms.atom:
5608         gexp = unwrap_singleton_parenthesis(node)
5609         if gexp is None or gexp.type != syms.testlist_gexp:
5610             return False
5611
5612         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5613
5614     return (
5615         node.type in IMPLICIT_TUPLE
5616         and len(node.children) == 2
5617         and node.children[1].type == token.COMMA
5618     )
5619
5620
5621 def is_walrus_assignment(node: LN) -> bool:
5622     """Return True iff `node` is of the shape ( test := test )"""
5623     inner = unwrap_singleton_parenthesis(node)
5624     return inner is not None and inner.type == syms.namedexpr_test
5625
5626
5627 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5628     """Return True iff `node` is a trailer valid in a simple decorator"""
5629     return node.type == syms.trailer and (
5630         (
5631             len(node.children) == 2
5632             and node.children[0].type == token.DOT
5633             and node.children[1].type == token.NAME
5634         )
5635         # last trailer can be arguments
5636         or (
5637             last
5638             and len(node.children) == 3
5639             and node.children[0].type == token.LPAR
5640             # and node.children[1].type == syms.argument
5641             and node.children[2].type == token.RPAR
5642         )
5643     )
5644
5645
5646 def is_simple_decorator_expression(node: LN) -> bool:
5647     """Return True iff `node` could be a 'dotted name' decorator
5648
5649     This function takes the node of the 'namedexpr_test' of the new decorator
5650     grammar and test if it would be valid under the old decorator grammar.
5651
5652     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5653     The new grammar is : decorator: @ namedexpr_test NEWLINE
5654     """
5655     if node.type == token.NAME:
5656         return True
5657     if node.type == syms.power:
5658         if node.children:
5659             return (
5660                 node.children[0].type == token.NAME
5661                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5662                 and (
5663                     len(node.children) < 2
5664                     or is_simple_decorator_trailer(node.children[-1], last=True)
5665                 )
5666             )
5667     return False
5668
5669
5670 def is_yield(node: LN) -> bool:
5671     """Return True if `node` holds a `yield` or `yield from` expression."""
5672     if node.type == syms.yield_expr:
5673         return True
5674
5675     if node.type == token.NAME and node.value == "yield":  # type: ignore
5676         return True
5677
5678     if node.type != syms.atom:
5679         return False
5680
5681     if len(node.children) != 3:
5682         return False
5683
5684     lpar, expr, rpar = node.children
5685     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5686         return is_yield(expr)
5687
5688     return False
5689
5690
5691 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5692     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5693
5694     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5695     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5696     extended iterable unpacking (PEP 3132) and additional unpacking
5697     generalizations (PEP 448).
5698     """
5699     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5700         return False
5701
5702     p = leaf.parent
5703     if p.type == syms.star_expr:
5704         # Star expressions are also used as assignment targets in extended
5705         # iterable unpacking (PEP 3132).  See what its parent is instead.
5706         if not p.parent:
5707             return False
5708
5709         p = p.parent
5710
5711     return p.type in within
5712
5713
5714 def is_multiline_string(leaf: Leaf) -> bool:
5715     """Return True if `leaf` is a multiline string that actually spans many lines."""
5716     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5717
5718
5719 def is_stub_suite(node: Node) -> bool:
5720     """Return True if `node` is a suite with a stub body."""
5721     if (
5722         len(node.children) != 4
5723         or node.children[0].type != token.NEWLINE
5724         or node.children[1].type != token.INDENT
5725         or node.children[3].type != token.DEDENT
5726     ):
5727         return False
5728
5729     return is_stub_body(node.children[2])
5730
5731
5732 def is_stub_body(node: LN) -> bool:
5733     """Return True if `node` is a simple statement containing an ellipsis."""
5734     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5735         return False
5736
5737     if len(node.children) != 2:
5738         return False
5739
5740     child = node.children[0]
5741     return (
5742         child.type == syms.atom
5743         and len(child.children) == 3
5744         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5745     )
5746
5747
5748 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5749     """Return maximum delimiter priority inside `node`.
5750
5751     This is specific to atoms with contents contained in a pair of parentheses.
5752     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5753     """
5754     if node.type != syms.atom:
5755         return 0
5756
5757     first = node.children[0]
5758     last = node.children[-1]
5759     if not (first.type == token.LPAR and last.type == token.RPAR):
5760         return 0
5761
5762     bt = BracketTracker()
5763     for c in node.children[1:-1]:
5764         if isinstance(c, Leaf):
5765             bt.mark(c)
5766         else:
5767             for leaf in c.leaves():
5768                 bt.mark(leaf)
5769     try:
5770         return bt.max_delimiter_priority()
5771
5772     except ValueError:
5773         return 0
5774
5775
5776 def ensure_visible(leaf: Leaf) -> None:
5777     """Make sure parentheses are visible.
5778
5779     They could be invisible as part of some statements (see
5780     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5781     """
5782     if leaf.type == token.LPAR:
5783         leaf.value = "("
5784     elif leaf.type == token.RPAR:
5785         leaf.value = ")"
5786
5787
5788 def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
5789     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5790
5791     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5792         return False
5793
5794     # We're essentially checking if the body is delimited by commas and there's more
5795     # than one of them (we're excluding the trailing comma and if the delimiter priority
5796     # is still commas, that means there's more).
5797     exclude = set()
5798     trailing_comma = False
5799     try:
5800         last_leaf = line.leaves[-1]
5801         if last_leaf.type == token.COMMA:
5802             trailing_comma = True
5803             exclude.add(id(last_leaf))
5804         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5805     except (IndexError, ValueError):
5806         return False
5807
5808     return max_priority == COMMA_PRIORITY and (
5809         (line.mode.magic_trailing_comma and trailing_comma)
5810         # always explode imports
5811         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5812     )
5813
5814
5815 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5816     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5817     if opening.type != token.LPAR and closing.type != token.RPAR:
5818         return False
5819
5820     depth = closing.bracket_depth + 1
5821     for _opening_index, leaf in enumerate(leaves):
5822         if leaf is opening:
5823             break
5824
5825     else:
5826         raise LookupError("Opening paren not found in `leaves`")
5827
5828     commas = 0
5829     _opening_index += 1
5830     for leaf in leaves[_opening_index:]:
5831         if leaf is closing:
5832             break
5833
5834         bracket_depth = leaf.bracket_depth
5835         if bracket_depth == depth and leaf.type == token.COMMA:
5836             commas += 1
5837             if leaf.parent and leaf.parent.type in {
5838                 syms.arglist,
5839                 syms.typedargslist,
5840             }:
5841                 commas += 1
5842                 break
5843
5844     return commas < 2
5845
5846
5847 def get_features_used(node: Node) -> Set[Feature]:
5848     """Return a set of (relatively) new Python features used in this file.
5849
5850     Currently looking for:
5851     - f-strings;
5852     - underscores in numeric literals;
5853     - trailing commas after * or ** in function signatures and calls;
5854     - positional only arguments in function signatures and lambdas;
5855     - assignment expression;
5856     - relaxed decorator syntax;
5857     """
5858     features: Set[Feature] = set()
5859     for n in node.pre_order():
5860         if n.type == token.STRING:
5861             value_head = n.value[:2]  # type: ignore
5862             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5863                 features.add(Feature.F_STRINGS)
5864
5865         elif n.type == token.NUMBER:
5866             if "_" in n.value:  # type: ignore
5867                 features.add(Feature.NUMERIC_UNDERSCORES)
5868
5869         elif n.type == token.SLASH:
5870             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5871                 features.add(Feature.POS_ONLY_ARGUMENTS)
5872
5873         elif n.type == token.COLONEQUAL:
5874             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5875
5876         elif n.type == syms.decorator:
5877             if len(n.children) > 1 and not is_simple_decorator_expression(
5878                 n.children[1]
5879             ):
5880                 features.add(Feature.RELAXED_DECORATORS)
5881
5882         elif (
5883             n.type in {syms.typedargslist, syms.arglist}
5884             and n.children
5885             and n.children[-1].type == token.COMMA
5886         ):
5887             if n.type == syms.typedargslist:
5888                 feature = Feature.TRAILING_COMMA_IN_DEF
5889             else:
5890                 feature = Feature.TRAILING_COMMA_IN_CALL
5891
5892             for ch in n.children:
5893                 if ch.type in STARS:
5894                     features.add(feature)
5895
5896                 if ch.type == syms.argument:
5897                     for argch in ch.children:
5898                         if argch.type in STARS:
5899                             features.add(feature)
5900
5901     return features
5902
5903
5904 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5905     """Detect the version to target based on the nodes used."""
5906     features = get_features_used(node)
5907     return {
5908         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5909     }
5910
5911
5912 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5913     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5914
5915     Brackets can be omitted if the entire trailer up to and including
5916     a preceding closing bracket fits in one line.
5917
5918     Yielded sets are cumulative (contain results of previous yields, too).  First
5919     set is empty, unless the line should explode, in which case bracket pairs until
5920     the one that needs to explode are omitted.
5921     """
5922
5923     omit: Set[LeafID] = set()
5924     if not line.should_explode:
5925         yield omit
5926
5927     length = 4 * line.depth
5928     opening_bracket: Optional[Leaf] = None
5929     closing_bracket: Optional[Leaf] = None
5930     inner_brackets: Set[LeafID] = set()
5931     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5932         length += leaf_length
5933         if length > line_length:
5934             break
5935
5936         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5937         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5938             break
5939
5940         if opening_bracket:
5941             if leaf is opening_bracket:
5942                 opening_bracket = None
5943             elif leaf.type in CLOSING_BRACKETS:
5944                 prev = line.leaves[index - 1] if index > 0 else None
5945                 if (
5946                     line.should_explode
5947                     and prev
5948                     and prev.type == token.COMMA
5949                     and not is_one_tuple_between(
5950                         leaf.opening_bracket, leaf, line.leaves
5951                     )
5952                 ):
5953                     # Never omit bracket pairs with trailing commas.
5954                     # We need to explode on those.
5955                     break
5956
5957                 inner_brackets.add(id(leaf))
5958         elif leaf.type in CLOSING_BRACKETS:
5959             prev = line.leaves[index - 1] if index > 0 else None
5960             if prev and prev.type in OPENING_BRACKETS:
5961                 # Empty brackets would fail a split so treat them as "inner"
5962                 # brackets (e.g. only add them to the `omit` set if another
5963                 # pair of brackets was good enough.
5964                 inner_brackets.add(id(leaf))
5965                 continue
5966
5967             if closing_bracket:
5968                 omit.add(id(closing_bracket))
5969                 omit.update(inner_brackets)
5970                 inner_brackets.clear()
5971                 yield omit
5972
5973             if (
5974                 line.should_explode
5975                 and prev
5976                 and prev.type == token.COMMA
5977                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
5978             ):
5979                 # Never omit bracket pairs with trailing commas.
5980                 # We need to explode on those.
5981                 break
5982
5983             if leaf.value:
5984                 opening_bracket = leaf.opening_bracket
5985                 closing_bracket = leaf
5986
5987
5988 def get_future_imports(node: Node) -> Set[str]:
5989     """Return a set of __future__ imports in the file."""
5990     imports: Set[str] = set()
5991
5992     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5993         for child in children:
5994             if isinstance(child, Leaf):
5995                 if child.type == token.NAME:
5996                     yield child.value
5997
5998             elif child.type == syms.import_as_name:
5999                 orig_name = child.children[0]
6000                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
6001                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
6002                 yield orig_name.value
6003
6004             elif child.type == syms.import_as_names:
6005                 yield from get_imports_from_children(child.children)
6006
6007             else:
6008                 raise AssertionError("Invalid syntax parsing imports")
6009
6010     for child in node.children:
6011         if child.type != syms.simple_stmt:
6012             break
6013
6014         first_child = child.children[0]
6015         if isinstance(first_child, Leaf):
6016             # Continue looking if we see a docstring; otherwise stop.
6017             if (
6018                 len(child.children) == 2
6019                 and first_child.type == token.STRING
6020                 and child.children[1].type == token.NEWLINE
6021             ):
6022                 continue
6023
6024             break
6025
6026         elif first_child.type == syms.import_from:
6027             module_name = first_child.children[1]
6028             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
6029                 break
6030
6031             imports |= set(get_imports_from_children(first_child.children[3:]))
6032         else:
6033             break
6034
6035     return imports
6036
6037
6038 @lru_cache()
6039 def get_gitignore(root: Path) -> PathSpec:
6040     """ Return a PathSpec matching gitignore content if present."""
6041     gitignore = root / ".gitignore"
6042     lines: List[str] = []
6043     if gitignore.is_file():
6044         with gitignore.open() as gf:
6045             lines = gf.readlines()
6046     return PathSpec.from_lines("gitwildmatch", lines)
6047
6048
6049 def normalize_path_maybe_ignore(
6050     path: Path, root: Path, report: "Report"
6051 ) -> Optional[str]:
6052     """Normalize `path`. May return `None` if `path` was ignored.
6053
6054     `report` is where "path ignored" output goes.
6055     """
6056     try:
6057         abspath = path if path.is_absolute() else Path.cwd() / path
6058         normalized_path = abspath.resolve().relative_to(root).as_posix()
6059     except OSError as e:
6060         report.path_ignored(path, f"cannot be read because {e}")
6061         return None
6062
6063     except ValueError:
6064         if path.is_symlink():
6065             report.path_ignored(path, f"is a symbolic link that points outside {root}")
6066             return None
6067
6068         raise
6069
6070     return normalized_path
6071
6072
6073 def gen_python_files(
6074     paths: Iterable[Path],
6075     root: Path,
6076     include: Optional[Pattern[str]],
6077     exclude: Pattern[str],
6078     force_exclude: Optional[Pattern[str]],
6079     report: "Report",
6080     gitignore: PathSpec,
6081 ) -> Iterator[Path]:
6082     """Generate all files under `path` whose paths are not excluded by the
6083     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
6084
6085     Symbolic links pointing outside of the `root` directory are ignored.
6086
6087     `report` is where output about exclusions goes.
6088     """
6089     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
6090     for child in paths:
6091         normalized_path = normalize_path_maybe_ignore(child, root, report)
6092         if normalized_path is None:
6093             continue
6094
6095         # First ignore files matching .gitignore
6096         if gitignore.match_file(normalized_path):
6097             report.path_ignored(child, "matches the .gitignore file content")
6098             continue
6099
6100         # Then ignore with `--exclude` and `--force-exclude` options.
6101         normalized_path = "/" + normalized_path
6102         if child.is_dir():
6103             normalized_path += "/"
6104
6105         exclude_match = exclude.search(normalized_path) if exclude else None
6106         if exclude_match and exclude_match.group(0):
6107             report.path_ignored(child, "matches the --exclude regular expression")
6108             continue
6109
6110         force_exclude_match = (
6111             force_exclude.search(normalized_path) if force_exclude else None
6112         )
6113         if force_exclude_match and force_exclude_match.group(0):
6114             report.path_ignored(child, "matches the --force-exclude regular expression")
6115             continue
6116
6117         if child.is_dir():
6118             yield from gen_python_files(
6119                 child.iterdir(),
6120                 root,
6121                 include,
6122                 exclude,
6123                 force_exclude,
6124                 report,
6125                 gitignore,
6126             )
6127
6128         elif child.is_file():
6129             include_match = include.search(normalized_path) if include else True
6130             if include_match:
6131                 yield child
6132
6133
6134 @lru_cache()
6135 def find_project_root(srcs: Iterable[str]) -> Path:
6136     """Return a directory containing .git, .hg, or pyproject.toml.
6137
6138     That directory will be a common parent of all files and directories
6139     passed in `srcs`.
6140
6141     If no directory in the tree contains a marker that would specify it's the
6142     project root, the root of the file system is returned.
6143     """
6144     if not srcs:
6145         return Path("/").resolve()
6146
6147     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6148
6149     # A list of lists of parents for each 'src'. 'src' is included as a
6150     # "parent" of itself if it is a directory
6151     src_parents = [
6152         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6153     ]
6154
6155     common_base = max(
6156         set.intersection(*(set(parents) for parents in src_parents)),
6157         key=lambda path: path.parts,
6158     )
6159
6160     for directory in (common_base, *common_base.parents):
6161         if (directory / ".git").exists():
6162             return directory
6163
6164         if (directory / ".hg").is_dir():
6165             return directory
6166
6167         if (directory / "pyproject.toml").is_file():
6168             return directory
6169
6170     return directory
6171
6172
6173 @dataclass
6174 class Report:
6175     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6176
6177     check: bool = False
6178     diff: bool = False
6179     quiet: bool = False
6180     verbose: bool = False
6181     change_count: int = 0
6182     same_count: int = 0
6183     failure_count: int = 0
6184
6185     def done(self, src: Path, changed: Changed) -> None:
6186         """Increment the counter for successful reformatting. Write out a message."""
6187         if changed is Changed.YES:
6188             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6189             if self.verbose or not self.quiet:
6190                 out(f"{reformatted} {src}")
6191             self.change_count += 1
6192         else:
6193             if self.verbose:
6194                 if changed is Changed.NO:
6195                     msg = f"{src} already well formatted, good job."
6196                 else:
6197                     msg = f"{src} wasn't modified on disk since last run."
6198                 out(msg, bold=False)
6199             self.same_count += 1
6200
6201     def failed(self, src: Path, message: str) -> None:
6202         """Increment the counter for failed reformatting. Write out a message."""
6203         err(f"error: cannot format {src}: {message}")
6204         self.failure_count += 1
6205
6206     def path_ignored(self, path: Path, message: str) -> None:
6207         if self.verbose:
6208             out(f"{path} ignored: {message}", bold=False)
6209
6210     @property
6211     def return_code(self) -> int:
6212         """Return the exit code that the app should use.
6213
6214         This considers the current state of changed files and failures:
6215         - if there were any failures, return 123;
6216         - if any files were changed and --check is being used, return 1;
6217         - otherwise return 0.
6218         """
6219         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6220         # 126 we have special return codes reserved by the shell.
6221         if self.failure_count:
6222             return 123
6223
6224         elif self.change_count and self.check:
6225             return 1
6226
6227         return 0
6228
6229     def __str__(self) -> str:
6230         """Render a color report of the current state.
6231
6232         Use `click.unstyle` to remove colors.
6233         """
6234         if self.check or self.diff:
6235             reformatted = "would be reformatted"
6236             unchanged = "would be left unchanged"
6237             failed = "would fail to reformat"
6238         else:
6239             reformatted = "reformatted"
6240             unchanged = "left unchanged"
6241             failed = "failed to reformat"
6242         report = []
6243         if self.change_count:
6244             s = "s" if self.change_count > 1 else ""
6245             report.append(
6246                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6247             )
6248         if self.same_count:
6249             s = "s" if self.same_count > 1 else ""
6250             report.append(f"{self.same_count} file{s} {unchanged}")
6251         if self.failure_count:
6252             s = "s" if self.failure_count > 1 else ""
6253             report.append(
6254                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6255             )
6256         return ", ".join(report) + "."
6257
6258
6259 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6260     filename = "<unknown>"
6261     if sys.version_info >= (3, 8):
6262         # TODO: support Python 4+ ;)
6263         for minor_version in range(sys.version_info[1], 4, -1):
6264             try:
6265                 return ast.parse(src, filename, feature_version=(3, minor_version))
6266             except SyntaxError:
6267                 continue
6268     else:
6269         for feature_version in (7, 6):
6270             try:
6271                 return ast3.parse(src, filename, feature_version=feature_version)
6272             except SyntaxError:
6273                 continue
6274
6275     return ast27.parse(src)
6276
6277
6278 def _fixup_ast_constants(
6279     node: Union[ast.AST, ast3.AST, ast27.AST]
6280 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6281     """Map ast nodes deprecated in 3.8 to Constant."""
6282     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6283         return ast.Constant(value=node.s)
6284
6285     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6286         return ast.Constant(value=node.n)
6287
6288     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6289         return ast.Constant(value=node.value)
6290
6291     return node
6292
6293
6294 def _stringify_ast(
6295     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6296 ) -> Iterator[str]:
6297     """Simple visitor generating strings to compare ASTs by content."""
6298
6299     node = _fixup_ast_constants(node)
6300
6301     yield f"{'  ' * depth}{node.__class__.__name__}("
6302
6303     for field in sorted(node._fields):  # noqa: F402
6304         # TypeIgnore has only one field 'lineno' which breaks this comparison
6305         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6306         if sys.version_info >= (3, 8):
6307             type_ignore_classes += (ast.TypeIgnore,)
6308         if isinstance(node, type_ignore_classes):
6309             break
6310
6311         try:
6312             value = getattr(node, field)
6313         except AttributeError:
6314             continue
6315
6316         yield f"{'  ' * (depth+1)}{field}="
6317
6318         if isinstance(value, list):
6319             for item in value:
6320                 # Ignore nested tuples within del statements, because we may insert
6321                 # parentheses and they change the AST.
6322                 if (
6323                     field == "targets"
6324                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6325                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6326                 ):
6327                     for item in item.elts:
6328                         yield from _stringify_ast(item, depth + 2)
6329
6330                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6331                     yield from _stringify_ast(item, depth + 2)
6332
6333         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6334             yield from _stringify_ast(value, depth + 2)
6335
6336         else:
6337             # Constant strings may be indented across newlines, if they are
6338             # docstrings; fold spaces after newlines when comparing. Similarly,
6339             # trailing and leading space may be removed.
6340             if (
6341                 isinstance(node, ast.Constant)
6342                 and field == "value"
6343                 and isinstance(value, str)
6344             ):
6345                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6346             else:
6347                 normalized = value
6348             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6349
6350     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6351
6352
6353 def assert_equivalent(src: str, dst: str) -> None:
6354     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6355     try:
6356         src_ast = parse_ast(src)
6357     except Exception as exc:
6358         raise AssertionError(
6359             "cannot use --safe with this file; failed to parse source file.  AST"
6360             f" error message: {exc}"
6361         )
6362
6363     try:
6364         dst_ast = parse_ast(dst)
6365     except Exception as exc:
6366         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6367         raise AssertionError(
6368             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6369             " on https://github.com/psf/black/issues.  This invalid output might be"
6370             f" helpful: {log}"
6371         ) from None
6372
6373     src_ast_str = "\n".join(_stringify_ast(src_ast))
6374     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6375     if src_ast_str != dst_ast_str:
6376         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6377         raise AssertionError(
6378             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6379             " source.  Please report a bug on https://github.com/psf/black/issues. "
6380             f" This diff might be helpful: {log}"
6381         ) from None
6382
6383
6384 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6385     """Raise AssertionError if `dst` reformats differently the second time."""
6386     newdst = format_str(dst, mode=mode)
6387     if dst != newdst:
6388         log = dump_to_file(
6389             str(mode),
6390             diff(src, dst, "source", "first pass"),
6391             diff(dst, newdst, "first pass", "second pass"),
6392         )
6393         raise AssertionError(
6394             "INTERNAL ERROR: Black produced different code on the second pass of the"
6395             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6396             f"  This diff might be helpful: {log}"
6397         ) from None
6398
6399
6400 @mypyc_attr(patchable=True)
6401 def dump_to_file(*output: str) -> str:
6402     """Dump `output` to a temporary file. Return path to the file."""
6403     with tempfile.NamedTemporaryFile(
6404         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6405     ) as f:
6406         for lines in output:
6407             f.write(lines)
6408             if lines and lines[-1] != "\n":
6409                 f.write("\n")
6410     return f.name
6411
6412
6413 @contextmanager
6414 def nullcontext() -> Iterator[None]:
6415     """Return an empty context manager.
6416
6417     To be used like `nullcontext` in Python 3.7.
6418     """
6419     yield
6420
6421
6422 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6423     """Return a unified diff string between strings `a` and `b`."""
6424     import difflib
6425
6426     a_lines = [line + "\n" for line in a.splitlines()]
6427     b_lines = [line + "\n" for line in b.splitlines()]
6428     return "".join(
6429         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6430     )
6431
6432
6433 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6434     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6435     err("Aborted!")
6436     for task in tasks:
6437         task.cancel()
6438
6439
6440 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6441     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6442     try:
6443         if sys.version_info[:2] >= (3, 7):
6444             all_tasks = asyncio.all_tasks
6445         else:
6446             all_tasks = asyncio.Task.all_tasks
6447         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6448         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6449         if not to_cancel:
6450             return
6451
6452         for task in to_cancel:
6453             task.cancel()
6454         loop.run_until_complete(
6455             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6456         )
6457     finally:
6458         # `concurrent.futures.Future` objects cannot be cancelled once they
6459         # are already running. There might be some when the `shutdown()` happened.
6460         # Silence their logger's spew about the event loop being closed.
6461         cf_logger = logging.getLogger("concurrent.futures")
6462         cf_logger.setLevel(logging.CRITICAL)
6463         loop.close()
6464
6465
6466 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6467     """Replace `regex` with `replacement` twice on `original`.
6468
6469     This is used by string normalization to perform replaces on
6470     overlapping matches.
6471     """
6472     return regex.sub(replacement, regex.sub(replacement, original))
6473
6474
6475 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6476     """Compile a regular expression string in `regex`.
6477
6478     If it contains newlines, use verbose mode.
6479     """
6480     if "\n" in regex:
6481         regex = "(?x)" + regex
6482     compiled: Pattern[str] = re.compile(regex)
6483     return compiled
6484
6485
6486 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6487     """Like `reversed(enumerate(sequence))` if that were possible."""
6488     index = len(sequence) - 1
6489     for element in reversed(sequence):
6490         yield (index, element)
6491         index -= 1
6492
6493
6494 def enumerate_with_length(
6495     line: Line, reversed: bool = False
6496 ) -> Iterator[Tuple[Index, Leaf, int]]:
6497     """Return an enumeration of leaves with their length.
6498
6499     Stops prematurely on multiline strings and standalone comments.
6500     """
6501     op = cast(
6502         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6503         enumerate_reversed if reversed else enumerate,
6504     )
6505     for index, leaf in op(line.leaves):
6506         length = len(leaf.prefix) + len(leaf.value)
6507         if "\n" in leaf.value:
6508             return  # Multiline strings, we can't continue.
6509
6510         for comment in line.comments_after(leaf):
6511             length += len(comment.value)
6512
6513         yield index, leaf, length
6514
6515
6516 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6517     """Return True if `line` is no longer than `line_length`.
6518
6519     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6520     """
6521     if not line_str:
6522         line_str = line_to_string(line)
6523     return (
6524         len(line_str) <= line_length
6525         and "\n" not in line_str  # multiline strings
6526         and not line.contains_standalone_comments()
6527     )
6528
6529
6530 def can_be_split(line: Line) -> bool:
6531     """Return False if the line cannot be split *for sure*.
6532
6533     This is not an exhaustive search but a cheap heuristic that we can use to
6534     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6535     in unnecessary parentheses).
6536     """
6537     leaves = line.leaves
6538     if len(leaves) < 2:
6539         return False
6540
6541     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6542         call_count = 0
6543         dot_count = 0
6544         next = leaves[-1]
6545         for leaf in leaves[-2::-1]:
6546             if leaf.type in OPENING_BRACKETS:
6547                 if next.type not in CLOSING_BRACKETS:
6548                     return False
6549
6550                 call_count += 1
6551             elif leaf.type == token.DOT:
6552                 dot_count += 1
6553             elif leaf.type == token.NAME:
6554                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6555                     return False
6556
6557             elif leaf.type not in CLOSING_BRACKETS:
6558                 return False
6559
6560             if dot_count > 1 and call_count > 1:
6561                 return False
6562
6563     return True
6564
6565
6566 def can_omit_invisible_parens(
6567     line: Line,
6568     line_length: int,
6569     omit_on_explode: Collection[LeafID] = (),
6570 ) -> bool:
6571     """Does `line` have a shape safe to reformat without optional parens around it?
6572
6573     Returns True for only a subset of potentially nice looking formattings but
6574     the point is to not return false positives that end up producing lines that
6575     are too long.
6576     """
6577     bt = line.bracket_tracker
6578     if not bt.delimiters:
6579         # Without delimiters the optional parentheses are useless.
6580         return True
6581
6582     max_priority = bt.max_delimiter_priority()
6583     if bt.delimiter_count_with_priority(max_priority) > 1:
6584         # With more than one delimiter of a kind the optional parentheses read better.
6585         return False
6586
6587     if max_priority == DOT_PRIORITY:
6588         # A single stranded method call doesn't require optional parentheses.
6589         return True
6590
6591     assert len(line.leaves) >= 2, "Stranded delimiter"
6592
6593     # With a single delimiter, omit if the expression starts or ends with
6594     # a bracket.
6595     first = line.leaves[0]
6596     second = line.leaves[1]
6597     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6598         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6599             return True
6600
6601         # Note: we are not returning False here because a line might have *both*
6602         # a leading opening bracket and a trailing closing bracket.  If the
6603         # opening bracket doesn't match our rule, maybe the closing will.
6604
6605     penultimate = line.leaves[-2]
6606     last = line.leaves[-1]
6607     if line.should_explode:
6608         try:
6609             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6610         except LookupError:
6611             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6612             return False
6613
6614     if (
6615         last.type == token.RPAR
6616         or last.type == token.RBRACE
6617         or (
6618             # don't use indexing for omitting optional parentheses;
6619             # it looks weird
6620             last.type == token.RSQB
6621             and last.parent
6622             and last.parent.type != syms.trailer
6623         )
6624     ):
6625         if penultimate.type in OPENING_BRACKETS:
6626             # Empty brackets don't help.
6627             return False
6628
6629         if is_multiline_string(first):
6630             # Additional wrapping of a multiline string in this situation is
6631             # unnecessary.
6632             return True
6633
6634         if line.should_explode and penultimate.type == token.COMMA:
6635             # The rightmost non-omitted bracket pair is the one we want to explode on.
6636             return True
6637
6638         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6639             return True
6640
6641     return False
6642
6643
6644 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6645     """See `can_omit_invisible_parens`."""
6646     remainder = False
6647     length = 4 * line.depth
6648     _index = -1
6649     for _index, leaf, leaf_length in enumerate_with_length(line):
6650         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6651             remainder = True
6652         if remainder:
6653             length += leaf_length
6654             if length > line_length:
6655                 break
6656
6657             if leaf.type in OPENING_BRACKETS:
6658                 # There are brackets we can further split on.
6659                 remainder = False
6660
6661     else:
6662         # checked the entire string and line length wasn't exceeded
6663         if len(line.leaves) == _index + 1:
6664             return True
6665
6666     return False
6667
6668
6669 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6670     """See `can_omit_invisible_parens`."""
6671     length = 4 * line.depth
6672     seen_other_brackets = False
6673     for _index, leaf, leaf_length in enumerate_with_length(line):
6674         length += leaf_length
6675         if leaf is last.opening_bracket:
6676             if seen_other_brackets or length <= line_length:
6677                 return True
6678
6679         elif leaf.type in OPENING_BRACKETS:
6680             # There are brackets we can further split on.
6681             seen_other_brackets = True
6682
6683     return False
6684
6685
6686 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6687     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6688     stop_after = None
6689     last = None
6690     for leaf in reversed(leaves):
6691         if stop_after:
6692             if leaf is stop_after:
6693                 stop_after = None
6694             continue
6695
6696         if last:
6697             return leaf, last
6698
6699         if id(leaf) in omit:
6700             stop_after = leaf.opening_bracket
6701         else:
6702             last = leaf
6703     else:
6704         raise LookupError("Last two leaves were also skipped")
6705
6706
6707 def run_transformer(
6708     line: Line,
6709     transform: Transformer,
6710     mode: Mode,
6711     features: Collection[Feature],
6712     *,
6713     line_str: str = "",
6714 ) -> List[Line]:
6715     if not line_str:
6716         line_str = line_to_string(line)
6717     result: List[Line] = []
6718     for transformed_line in transform(line, features):
6719         if str(transformed_line).strip("\n") == line_str:
6720             raise CannotTransform("Line transformer returned an unchanged result")
6721
6722         result.extend(transform_line(transformed_line, mode=mode, features=features))
6723
6724     if not (
6725         transform.__name__ == "rhs"
6726         and line.bracket_tracker.invisible
6727         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6728         and not line.contains_multiline_strings()
6729         and not result[0].contains_uncollapsable_type_comments()
6730         and not result[0].contains_unsplittable_type_ignore()
6731         and not is_line_short_enough(result[0], line_length=mode.line_length)
6732     ):
6733         return result
6734
6735     line_copy = line.clone()
6736     append_leaves(line_copy, line, line.leaves)
6737     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6738     second_opinion = run_transformer(
6739         line_copy, transform, mode, features_fop, line_str=line_str
6740     )
6741     if all(
6742         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6743     ):
6744         result = second_opinion
6745     return result
6746
6747
6748 def get_cache_file(mode: Mode) -> Path:
6749     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6750
6751
6752 def read_cache(mode: Mode) -> Cache:
6753     """Read the cache if it exists and is well formed.
6754
6755     If it is not well formed, the call to write_cache later should resolve the issue.
6756     """
6757     cache_file = get_cache_file(mode)
6758     if not cache_file.exists():
6759         return {}
6760
6761     with cache_file.open("rb") as fobj:
6762         try:
6763             cache: Cache = pickle.load(fobj)
6764         except (pickle.UnpicklingError, ValueError):
6765             return {}
6766
6767     return cache
6768
6769
6770 def get_cache_info(path: Path) -> CacheInfo:
6771     """Return the information used to check if a file is already formatted or not."""
6772     stat = path.stat()
6773     return stat.st_mtime, stat.st_size
6774
6775
6776 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6777     """Split an iterable of paths in `sources` into two sets.
6778
6779     The first contains paths of files that modified on disk or are not in the
6780     cache. The other contains paths to non-modified files.
6781     """
6782     todo, done = set(), set()
6783     for src in sources:
6784         src = src.resolve()
6785         if cache.get(src) != get_cache_info(src):
6786             todo.add(src)
6787         else:
6788             done.add(src)
6789     return todo, done
6790
6791
6792 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6793     """Update the cache file."""
6794     cache_file = get_cache_file(mode)
6795     try:
6796         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6797         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6798         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6799             pickle.dump(new_cache, f, protocol=4)
6800         os.replace(f.name, cache_file)
6801     except OSError:
6802         pass
6803
6804
6805 def patch_click() -> None:
6806     """Make Click not crash.
6807
6808     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6809     default which restricts paths that it can access during the lifetime of the
6810     application.  Click refuses to work in this scenario by raising a RuntimeError.
6811
6812     In case of Black the likelihood that non-ASCII characters are going to be used in
6813     file paths is minimal since it's Python source code.  Moreover, this crash was
6814     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6815     """
6816     try:
6817         from click import core
6818         from click import _unicodefun  # type: ignore
6819     except ModuleNotFoundError:
6820         return
6821
6822     for module in (core, _unicodefun):
6823         if hasattr(module, "_verify_python3_env"):
6824             module._verify_python3_env = lambda: None
6825
6826
6827 def patched_main() -> None:
6828     freeze_support()
6829     patch_click()
6830     main()
6831
6832
6833 def is_docstring(leaf: Leaf) -> bool:
6834     if not is_multiline_string(leaf):
6835         # For the purposes of docstring re-indentation, we don't need to do anything
6836         # with single-line docstrings.
6837         return False
6838
6839     if prev_siblings_are(
6840         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6841     ):
6842         return True
6843
6844     # Multiline docstring on the same line as the `def`.
6845     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6846         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6847         # grammar. We're safe to return True without further checks.
6848         return True
6849
6850     return False
6851
6852
6853 def lines_with_leading_tabs_expanded(s: str) -> List[str]:
6854     """
6855     Splits string into lines and expands only leading tabs (following the normal
6856     Python rules)
6857     """
6858     lines = []
6859     for line in s.splitlines():
6860         # Find the index of the first non-whitespace character after a string of
6861         # whitespace that includes at least one tab
6862         match = re.match(r"\s*\t+\s*(\S)", line)
6863         if match:
6864             first_non_whitespace_idx = match.start(1)
6865
6866             lines.append(
6867                 line[:first_non_whitespace_idx].expandtabs()
6868                 + line[first_non_whitespace_idx:]
6869             )
6870         else:
6871             lines.append(line)
6872     return lines
6873
6874
6875 def fix_docstring(docstring: str, prefix: str) -> str:
6876     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6877     if not docstring:
6878         return ""
6879     lines = lines_with_leading_tabs_expanded(docstring)
6880     # Determine minimum indentation (first line doesn't count):
6881     indent = sys.maxsize
6882     for line in lines[1:]:
6883         stripped = line.lstrip()
6884         if stripped:
6885             indent = min(indent, len(line) - len(stripped))
6886     # Remove indentation (first line is special):
6887     trimmed = [lines[0].strip()]
6888     if indent < sys.maxsize:
6889         last_line_idx = len(lines) - 2
6890         for i, line in enumerate(lines[1:]):
6891             stripped_line = line[indent:].rstrip()
6892             if stripped_line or i == last_line_idx:
6893                 trimmed.append(prefix + stripped_line)
6894             else:
6895                 trimmed.append("")
6896     return "\n".join(trimmed)
6897
6898
6899 if __name__ == "__main__":
6900     patched_main()