]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

e21e2af5bd1394ca0441d1ccf133a0d68f6c0515
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from mypy_extensions import mypyc_attr
46
47 from appdirs import user_cache_dir
48 from dataclasses import dataclass, field, replace
49 import click
50 import toml
51 from typed_ast import ast3, ast27
52 from pathspec import PathSpec
53
54 # lib2to3 fork
55 from blib2to3.pytree import Node, Leaf, type_repr
56 from blib2to3 import pygram, pytree
57 from blib2to3.pgen2 import driver, token
58 from blib2to3.pgen2.grammar import Grammar
59 from blib2to3.pgen2.parse import ParseError
60
61 from _black_version import version as __version__
62
63 if sys.version_info < (3, 8):
64     from typing_extensions import Final
65 else:
66     from typing import Final
67
68 if TYPE_CHECKING:
69     import colorama  # noqa: F401
70
71 DEFAULT_LINE_LENGTH = 88
72 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
73 DEFAULT_INCLUDES = r"\.pyi?$"
74 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
75 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
76
77 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
78
79
80 # types
81 FileContent = str
82 Encoding = str
83 NewLine = str
84 Depth = int
85 NodeType = int
86 ParserState = int
87 LeafID = int
88 StringID = int
89 Priority = int
90 Index = int
91 LN = Union[Leaf, Node]
92 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
93 Timestamp = float
94 FileSize = int
95 CacheInfo = Tuple[Timestamp, FileSize]
96 Cache = Dict[str, CacheInfo]
97 out = partial(click.secho, bold=True, err=True)
98 err = partial(click.secho, fg="red", err=True)
99
100 pygram.initialize(CACHE_DIR)
101 syms = pygram.python_symbols
102
103
104 class NothingChanged(UserWarning):
105     """Raised when reformatted code is the same as source."""
106
107
108 class CannotTransform(Exception):
109     """Base class for errors raised by Transformers."""
110
111
112 class CannotSplit(CannotTransform):
113     """A readable split that fits the allotted line length is impossible."""
114
115
116 class InvalidInput(ValueError):
117     """Raised when input source code fails all parse attempts."""
118
119
120 class BracketMatchError(KeyError):
121     """Raised when an opening bracket is unable to be matched to a closing bracket."""
122
123
124 T = TypeVar("T")
125 E = TypeVar("E", bound=Exception)
126
127
128 class Ok(Generic[T]):
129     def __init__(self, value: T) -> None:
130         self._value = value
131
132     def ok(self) -> T:
133         return self._value
134
135
136 class Err(Generic[E]):
137     def __init__(self, e: E) -> None:
138         self._e = e
139
140     def err(self) -> E:
141         return self._e
142
143
144 # The 'Result' return type is used to implement an error-handling model heavily
145 # influenced by that used by the Rust programming language
146 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
147 Result = Union[Ok[T], Err[E]]
148 TResult = Result[T, CannotTransform]  # (T)ransform Result
149 TMatchResult = TResult[Index]
150
151
152 class WriteBack(Enum):
153     NO = 0
154     YES = 1
155     DIFF = 2
156     CHECK = 3
157     COLOR_DIFF = 4
158
159     @classmethod
160     def from_configuration(
161         cls, *, check: bool, diff: bool, color: bool = False
162     ) -> "WriteBack":
163         if check and not diff:
164             return cls.CHECK
165
166         if diff and color:
167             return cls.COLOR_DIFF
168
169         return cls.DIFF if diff else cls.YES
170
171
172 class Changed(Enum):
173     NO = 0
174     CACHED = 1
175     YES = 2
176
177
178 class TargetVersion(Enum):
179     PY27 = 2
180     PY33 = 3
181     PY34 = 4
182     PY35 = 5
183     PY36 = 6
184     PY37 = 7
185     PY38 = 8
186     PY39 = 9
187
188     def is_python2(self) -> bool:
189         return self is TargetVersion.PY27
190
191
192 class Feature(Enum):
193     # All string literals are unicode
194     UNICODE_LITERALS = 1
195     F_STRINGS = 2
196     NUMERIC_UNDERSCORES = 3
197     TRAILING_COMMA_IN_CALL = 4
198     TRAILING_COMMA_IN_DEF = 5
199     # The following two feature-flags are mutually exclusive, and exactly one should be
200     # set for every version of python.
201     ASYNC_IDENTIFIERS = 6
202     ASYNC_KEYWORDS = 7
203     ASSIGNMENT_EXPRESSIONS = 8
204     POS_ONLY_ARGUMENTS = 9
205     RELAXED_DECORATORS = 10
206     FORCE_OPTIONAL_PARENTHESES = 50
207
208
209 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
210     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
211     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
212     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
213     TargetVersion.PY35: {
214         Feature.UNICODE_LITERALS,
215         Feature.TRAILING_COMMA_IN_CALL,
216         Feature.ASYNC_IDENTIFIERS,
217     },
218     TargetVersion.PY36: {
219         Feature.UNICODE_LITERALS,
220         Feature.F_STRINGS,
221         Feature.NUMERIC_UNDERSCORES,
222         Feature.TRAILING_COMMA_IN_CALL,
223         Feature.TRAILING_COMMA_IN_DEF,
224         Feature.ASYNC_IDENTIFIERS,
225     },
226     TargetVersion.PY37: {
227         Feature.UNICODE_LITERALS,
228         Feature.F_STRINGS,
229         Feature.NUMERIC_UNDERSCORES,
230         Feature.TRAILING_COMMA_IN_CALL,
231         Feature.TRAILING_COMMA_IN_DEF,
232         Feature.ASYNC_KEYWORDS,
233     },
234     TargetVersion.PY38: {
235         Feature.UNICODE_LITERALS,
236         Feature.F_STRINGS,
237         Feature.NUMERIC_UNDERSCORES,
238         Feature.TRAILING_COMMA_IN_CALL,
239         Feature.TRAILING_COMMA_IN_DEF,
240         Feature.ASYNC_KEYWORDS,
241         Feature.ASSIGNMENT_EXPRESSIONS,
242         Feature.POS_ONLY_ARGUMENTS,
243     },
244     TargetVersion.PY39: {
245         Feature.UNICODE_LITERALS,
246         Feature.F_STRINGS,
247         Feature.NUMERIC_UNDERSCORES,
248         Feature.TRAILING_COMMA_IN_CALL,
249         Feature.TRAILING_COMMA_IN_DEF,
250         Feature.ASYNC_KEYWORDS,
251         Feature.ASSIGNMENT_EXPRESSIONS,
252         Feature.RELAXED_DECORATORS,
253         Feature.POS_ONLY_ARGUMENTS,
254     },
255 }
256
257
258 @dataclass
259 class Mode:
260     target_versions: Set[TargetVersion] = field(default_factory=set)
261     line_length: int = DEFAULT_LINE_LENGTH
262     string_normalization: bool = True
263     magic_trailing_comma: bool = True
264     experimental_string_processing: bool = False
265     is_pyi: bool = False
266
267     def get_cache_key(self) -> str:
268         if self.target_versions:
269             version_str = ",".join(
270                 str(version.value)
271                 for version in sorted(self.target_versions, key=lambda v: v.value)
272             )
273         else:
274             version_str = "-"
275         parts = [
276             version_str,
277             str(self.line_length),
278             str(int(self.string_normalization)),
279             str(int(self.is_pyi)),
280         ]
281         return ".".join(parts)
282
283
284 # Legacy name, left for integrations.
285 FileMode = Mode
286
287
288 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
289     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
290
291
292 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
293     """Find the absolute filepath to a pyproject.toml if it exists"""
294     path_project_root = find_project_root(path_search_start)
295     path_pyproject_toml = path_project_root / "pyproject.toml"
296     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
297
298
299 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
300     """Parse a pyproject toml file, pulling out relevant parts for Black
301
302     If parsing fails, will raise a toml.TomlDecodeError
303     """
304     pyproject_toml = toml.load(path_config)
305     config = pyproject_toml.get("tool", {}).get("black", {})
306     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
307
308
309 def read_pyproject_toml(
310     ctx: click.Context, param: click.Parameter, value: Optional[str]
311 ) -> Optional[str]:
312     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
313
314     Returns the path to a successfully found and read configuration file, None
315     otherwise.
316     """
317     if not value:
318         value = find_pyproject_toml(ctx.params.get("src", ()))
319         if value is None:
320             return None
321
322     try:
323         config = parse_pyproject_toml(value)
324     except (toml.TomlDecodeError, OSError) as e:
325         raise click.FileError(
326             filename=value, hint=f"Error reading configuration file: {e}"
327         )
328
329     if not config:
330         return None
331     else:
332         # Sanitize the values to be Click friendly. For more information please see:
333         # https://github.com/psf/black/issues/1458
334         # https://github.com/pallets/click/issues/1567
335         config = {
336             k: str(v) if not isinstance(v, (list, dict)) else v
337             for k, v in config.items()
338         }
339
340     target_version = config.get("target_version")
341     if target_version is not None and not isinstance(target_version, list):
342         raise click.BadOptionUsage(
343             "target-version", "Config key target-version must be a list"
344         )
345
346     default_map: Dict[str, Any] = {}
347     if ctx.default_map:
348         default_map.update(ctx.default_map)
349     default_map.update(config)
350
351     ctx.default_map = default_map
352     return value
353
354
355 def target_version_option_callback(
356     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
357 ) -> List[TargetVersion]:
358     """Compute the target versions from a --target-version flag.
359
360     This is its own function because mypy couldn't infer the type correctly
361     when it was a lambda, causing mypyc trouble.
362     """
363     return [TargetVersion[val.upper()] for val in v]
364
365
366 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
367 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
368 @click.option(
369     "-l",
370     "--line-length",
371     type=int,
372     default=DEFAULT_LINE_LENGTH,
373     help="How many characters per line to allow.",
374     show_default=True,
375 )
376 @click.option(
377     "-t",
378     "--target-version",
379     type=click.Choice([v.name.lower() for v in TargetVersion]),
380     callback=target_version_option_callback,
381     multiple=True,
382     help=(
383         "Python versions that should be supported by Black's output. [default: per-file"
384         " auto-detection]"
385     ),
386 )
387 @click.option(
388     "--pyi",
389     is_flag=True,
390     help=(
391         "Format all input files like typing stubs regardless of file extension (useful"
392         " when piping source on standard input)."
393     ),
394 )
395 @click.option(
396     "-S",
397     "--skip-string-normalization",
398     is_flag=True,
399     help="Don't normalize string quotes or prefixes.",
400 )
401 @click.option(
402     "-C",
403     "--skip-magic-trailing-comma",
404     is_flag=True,
405     help="Don't use trailing commas as a reason to split lines.",
406 )
407 @click.option(
408     "--experimental-string-processing",
409     is_flag=True,
410     hidden=True,
411     help=(
412         "Experimental option that performs more normalization on string literals."
413         " Currently disabled because it leads to some crashes."
414     ),
415 )
416 @click.option(
417     "--check",
418     is_flag=True,
419     help=(
420         "Don't write the files back, just return the status.  Return code 0 means"
421         " nothing would change.  Return code 1 means some files would be reformatted."
422         " Return code 123 means there was an internal error."
423     ),
424 )
425 @click.option(
426     "--diff",
427     is_flag=True,
428     help="Don't write the files back, just output a diff for each file on stdout.",
429 )
430 @click.option(
431     "--color/--no-color",
432     is_flag=True,
433     help="Show colored diff. Only applies when `--diff` is given.",
434 )
435 @click.option(
436     "--fast/--safe",
437     is_flag=True,
438     help="If --fast given, skip temporary sanity checks. [default: --safe]",
439 )
440 @click.option(
441     "--include",
442     type=str,
443     default=DEFAULT_INCLUDES,
444     help=(
445         "A regular expression that matches files and directories that should be"
446         " included on recursive searches.  An empty value means all files are included"
447         " regardless of the name.  Use forward slashes for directories on all platforms"
448         " (Windows, too).  Exclusions are calculated first, inclusions later."
449     ),
450     show_default=True,
451 )
452 @click.option(
453     "--exclude",
454     type=str,
455     default=DEFAULT_EXCLUDES,
456     help=(
457         "A regular expression that matches files and directories that should be"
458         " excluded on recursive searches.  An empty value means no paths are excluded."
459         " Use forward slashes for directories on all platforms (Windows, too). "
460         " Exclusions are calculated first, inclusions later."
461     ),
462     show_default=True,
463 )
464 @click.option(
465     "--extend-exclude",
466     type=str,
467     help=(
468         "Like --exclude, but adds additional files and directories on top of the"
469         " excluded ones. (Useful if you simply want to add to the default)"
470     ),
471 )
472 @click.option(
473     "--force-exclude",
474     type=str,
475     help=(
476         "Like --exclude, but files and directories matching this regex will be "
477         "excluded even when they are passed explicitly as arguments."
478     ),
479 )
480 @click.option(
481     "--stdin-filename",
482     type=str,
483     help=(
484         "The name of the file when passing it through stdin. Useful to make "
485         "sure Black will respect --force-exclude option on some "
486         "editors that rely on using stdin."
487     ),
488 )
489 @click.option(
490     "-q",
491     "--quiet",
492     is_flag=True,
493     help=(
494         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
495         " those with 2>/dev/null."
496     ),
497 )
498 @click.option(
499     "-v",
500     "--verbose",
501     is_flag=True,
502     help=(
503         "Also emit messages to stderr about files that were not changed or were ignored"
504         " due to exclusion patterns."
505     ),
506 )
507 @click.version_option(version=__version__)
508 @click.argument(
509     "src",
510     nargs=-1,
511     type=click.Path(
512         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
513     ),
514     is_eager=True,
515 )
516 @click.option(
517     "--config",
518     type=click.Path(
519         exists=True,
520         file_okay=True,
521         dir_okay=False,
522         readable=True,
523         allow_dash=False,
524         path_type=str,
525     ),
526     is_eager=True,
527     callback=read_pyproject_toml,
528     help="Read configuration from FILE path.",
529 )
530 @click.pass_context
531 def main(
532     ctx: click.Context,
533     code: Optional[str],
534     line_length: int,
535     target_version: List[TargetVersion],
536     check: bool,
537     diff: bool,
538     color: bool,
539     fast: bool,
540     pyi: bool,
541     skip_string_normalization: bool,
542     skip_magic_trailing_comma: bool,
543     experimental_string_processing: bool,
544     quiet: bool,
545     verbose: bool,
546     include: str,
547     exclude: str,
548     extend_exclude: Optional[str],
549     force_exclude: Optional[str],
550     stdin_filename: Optional[str],
551     src: Tuple[str, ...],
552     config: Optional[str],
553 ) -> None:
554     """The uncompromising code formatter."""
555     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
556     if target_version:
557         versions = set(target_version)
558     else:
559         # We'll autodetect later.
560         versions = set()
561     mode = Mode(
562         target_versions=versions,
563         line_length=line_length,
564         is_pyi=pyi,
565         string_normalization=not skip_string_normalization,
566         magic_trailing_comma=not skip_magic_trailing_comma,
567         experimental_string_processing=experimental_string_processing,
568     )
569     if config and verbose:
570         out(f"Using configuration from {config}.", bold=False, fg="blue")
571     if code is not None:
572         print(format_str(code, mode=mode))
573         ctx.exit(0)
574     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
575     sources = get_sources(
576         ctx=ctx,
577         src=src,
578         quiet=quiet,
579         verbose=verbose,
580         include=include,
581         exclude=exclude,
582         extend_exclude=extend_exclude,
583         force_exclude=force_exclude,
584         report=report,
585         stdin_filename=stdin_filename,
586     )
587
588     path_empty(
589         sources,
590         "No Python files are present to be formatted. Nothing to do 😴",
591         quiet,
592         verbose,
593         ctx,
594     )
595
596     if len(sources) == 1:
597         reformat_one(
598             src=sources.pop(),
599             fast=fast,
600             write_back=write_back,
601             mode=mode,
602             report=report,
603         )
604     else:
605         reformat_many(
606             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
607         )
608
609     if verbose or not quiet:
610         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
611         click.secho(str(report), err=True)
612     ctx.exit(report.return_code)
613
614
615 def test_regex(
616     ctx: click.Context,
617     regex_name: str,
618     regex: Optional[str],
619 ) -> Optional[Pattern]:
620     try:
621         return re_compile_maybe_verbose(regex) if regex is not None else None
622     except re.error:
623         err(f"Invalid regular expression for {regex_name} given: {regex!r}")
624         ctx.exit(2)
625
626
627 def get_sources(
628     *,
629     ctx: click.Context,
630     src: Tuple[str, ...],
631     quiet: bool,
632     verbose: bool,
633     include: str,
634     exclude: str,
635     extend_exclude: Optional[str],
636     force_exclude: Optional[str],
637     report: "Report",
638     stdin_filename: Optional[str],
639 ) -> Set[Path]:
640     """Compute the set of files to be formatted."""
641
642     include_regex = test_regex(ctx, "include", include)
643     exclude_regex = test_regex(ctx, "exclude", exclude)
644     assert exclude_regex is not None
645     extend_exclude_regex = test_regex(ctx, "extend_exclude", extend_exclude)
646     force_exclude_regex = test_regex(ctx, "force_exclude", force_exclude)
647
648     root = find_project_root(src)
649     sources: Set[Path] = set()
650     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
651     gitignore = get_gitignore(root)
652
653     for s in src:
654         if s == "-" and stdin_filename:
655             p = Path(stdin_filename)
656             is_stdin = True
657         else:
658             p = Path(s)
659             is_stdin = False
660
661         if is_stdin or p.is_file():
662             normalized_path = normalize_path_maybe_ignore(p, root, report)
663             if normalized_path is None:
664                 continue
665
666             normalized_path = "/" + normalized_path
667             # Hard-exclude any files that matches the `--force-exclude` regex.
668             if force_exclude_regex:
669                 force_exclude_match = force_exclude_regex.search(normalized_path)
670             else:
671                 force_exclude_match = None
672             if force_exclude_match and force_exclude_match.group(0):
673                 report.path_ignored(p, "matches the --force-exclude regular expression")
674                 continue
675
676             if is_stdin:
677                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
678
679             sources.add(p)
680         elif p.is_dir():
681             sources.update(
682                 gen_python_files(
683                     p.iterdir(),
684                     root,
685                     include_regex,
686                     exclude_regex,
687                     extend_exclude_regex,
688                     force_exclude_regex,
689                     report,
690                     gitignore,
691                 )
692             )
693         elif s == "-":
694             sources.add(p)
695         else:
696             err(f"invalid path: {s}")
697     return sources
698
699
700 def path_empty(
701     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
702 ) -> None:
703     """
704     Exit if there is no `src` provided for formatting
705     """
706     if not src and (verbose or not quiet):
707         out(msg)
708         ctx.exit(0)
709
710
711 def reformat_one(
712     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
713 ) -> None:
714     """Reformat a single file under `src` without spawning child processes.
715
716     `fast`, `write_back`, and `mode` options are passed to
717     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
718     """
719     try:
720         changed = Changed.NO
721
722         if str(src) == "-":
723             is_stdin = True
724         elif str(src).startswith(STDIN_PLACEHOLDER):
725             is_stdin = True
726             # Use the original name again in case we want to print something
727             # to the user
728             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
729         else:
730             is_stdin = False
731
732         if is_stdin:
733             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
734                 changed = Changed.YES
735         else:
736             cache: Cache = {}
737             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
738                 cache = read_cache(mode)
739                 res_src = src.resolve()
740                 res_src_s = str(res_src)
741                 if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
742                     changed = Changed.CACHED
743             if changed is not Changed.CACHED and format_file_in_place(
744                 src, fast=fast, write_back=write_back, mode=mode
745             ):
746                 changed = Changed.YES
747             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
748                 write_back is WriteBack.CHECK and changed is Changed.NO
749             ):
750                 write_cache(cache, [src], mode)
751         report.done(src, changed)
752     except Exception as exc:
753         if report.verbose:
754             traceback.print_exc()
755         report.failed(src, str(exc))
756
757
758 def reformat_many(
759     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
760 ) -> None:
761     """Reformat multiple files using a ProcessPoolExecutor."""
762     executor: Executor
763     loop = asyncio.get_event_loop()
764     worker_count = os.cpu_count()
765     if sys.platform == "win32":
766         # Work around https://bugs.python.org/issue26903
767         worker_count = min(worker_count, 60)
768     try:
769         executor = ProcessPoolExecutor(max_workers=worker_count)
770     except (ImportError, OSError):
771         # we arrive here if the underlying system does not support multi-processing
772         # like in AWS Lambda or Termux, in which case we gracefully fallback to
773         # a ThreadPollExecutor with just a single worker (more workers would not do us
774         # any good due to the Global Interpreter Lock)
775         executor = ThreadPoolExecutor(max_workers=1)
776
777     try:
778         loop.run_until_complete(
779             schedule_formatting(
780                 sources=sources,
781                 fast=fast,
782                 write_back=write_back,
783                 mode=mode,
784                 report=report,
785                 loop=loop,
786                 executor=executor,
787             )
788         )
789     finally:
790         shutdown(loop)
791         if executor is not None:
792             executor.shutdown()
793
794
795 async def schedule_formatting(
796     sources: Set[Path],
797     fast: bool,
798     write_back: WriteBack,
799     mode: Mode,
800     report: "Report",
801     loop: asyncio.AbstractEventLoop,
802     executor: Executor,
803 ) -> None:
804     """Run formatting of `sources` in parallel using the provided `executor`.
805
806     (Use ProcessPoolExecutors for actual parallelism.)
807
808     `write_back`, `fast`, and `mode` options are passed to
809     :func:`format_file_in_place`.
810     """
811     cache: Cache = {}
812     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
813         cache = read_cache(mode)
814         sources, cached = filter_cached(cache, sources)
815         for src in sorted(cached):
816             report.done(src, Changed.CACHED)
817     if not sources:
818         return
819
820     cancelled = []
821     sources_to_cache = []
822     lock = None
823     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
824         # For diff output, we need locks to ensure we don't interleave output
825         # from different processes.
826         manager = Manager()
827         lock = manager.Lock()
828     tasks = {
829         asyncio.ensure_future(
830             loop.run_in_executor(
831                 executor, format_file_in_place, src, fast, mode, write_back, lock
832             )
833         ): src
834         for src in sorted(sources)
835     }
836     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
837     try:
838         loop.add_signal_handler(signal.SIGINT, cancel, pending)
839         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
840     except NotImplementedError:
841         # There are no good alternatives for these on Windows.
842         pass
843     while pending:
844         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
845         for task in done:
846             src = tasks.pop(task)
847             if task.cancelled():
848                 cancelled.append(task)
849             elif task.exception():
850                 report.failed(src, str(task.exception()))
851             else:
852                 changed = Changed.YES if task.result() else Changed.NO
853                 # If the file was written back or was successfully checked as
854                 # well-formatted, store this information in the cache.
855                 if write_back is WriteBack.YES or (
856                     write_back is WriteBack.CHECK and changed is Changed.NO
857                 ):
858                     sources_to_cache.append(src)
859                 report.done(src, changed)
860     if cancelled:
861         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
862     if sources_to_cache:
863         write_cache(cache, sources_to_cache, mode)
864
865
866 def format_file_in_place(
867     src: Path,
868     fast: bool,
869     mode: Mode,
870     write_back: WriteBack = WriteBack.NO,
871     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
872 ) -> bool:
873     """Format file under `src` path. Return True if changed.
874
875     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
876     code to the file.
877     `mode` and `fast` options are passed to :func:`format_file_contents`.
878     """
879     if src.suffix == ".pyi":
880         mode = replace(mode, is_pyi=True)
881
882     then = datetime.utcfromtimestamp(src.stat().st_mtime)
883     with open(src, "rb") as buf:
884         src_contents, encoding, newline = decode_bytes(buf.read())
885     try:
886         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
887     except NothingChanged:
888         return False
889
890     if write_back == WriteBack.YES:
891         with open(src, "w", encoding=encoding, newline=newline) as f:
892             f.write(dst_contents)
893     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
894         now = datetime.utcnow()
895         src_name = f"{src}\t{then} +0000"
896         dst_name = f"{src}\t{now} +0000"
897         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
898
899         if write_back == WriteBack.COLOR_DIFF:
900             diff_contents = color_diff(diff_contents)
901
902         with lock or nullcontext():
903             f = io.TextIOWrapper(
904                 sys.stdout.buffer,
905                 encoding=encoding,
906                 newline=newline,
907                 write_through=True,
908             )
909             f = wrap_stream_for_windows(f)
910             f.write(diff_contents)
911             f.detach()
912
913     return True
914
915
916 def color_diff(contents: str) -> str:
917     """Inject the ANSI color codes to the diff."""
918     lines = contents.split("\n")
919     for i, line in enumerate(lines):
920         if line.startswith("+++") or line.startswith("---"):
921             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
922         elif line.startswith("@@"):
923             line = "\033[36m" + line + "\033[0m"  # cyan, reset
924         elif line.startswith("+"):
925             line = "\033[32m" + line + "\033[0m"  # green, reset
926         elif line.startswith("-"):
927             line = "\033[31m" + line + "\033[0m"  # red, reset
928         lines[i] = line
929     return "\n".join(lines)
930
931
932 def wrap_stream_for_windows(
933     f: io.TextIOWrapper,
934 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
935     """
936     Wrap stream with colorama's wrap_stream so colors are shown on Windows.
937
938     If `colorama` is unavailable, the original stream is returned unmodified.
939     Otherwise, the `wrap_stream()` function determines whether the stream needs
940     to be wrapped for a Windows environment and will accordingly either return
941     an `AnsiToWin32` wrapper or the original stream.
942     """
943     try:
944         from colorama.initialise import wrap_stream
945     except ImportError:
946         return f
947     else:
948         # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
949         return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
950
951
952 def format_stdin_to_stdout(
953     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
954 ) -> bool:
955     """Format file on stdin. Return True if changed.
956
957     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
958     write a diff to stdout. The `mode` argument is passed to
959     :func:`format_file_contents`.
960     """
961     then = datetime.utcnow()
962     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
963     dst = src
964     try:
965         dst = format_file_contents(src, fast=fast, mode=mode)
966         return True
967
968     except NothingChanged:
969         return False
970
971     finally:
972         f = io.TextIOWrapper(
973             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
974         )
975         if write_back == WriteBack.YES:
976             f.write(dst)
977         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
978             now = datetime.utcnow()
979             src_name = f"STDIN\t{then} +0000"
980             dst_name = f"STDOUT\t{now} +0000"
981             d = diff(src, dst, src_name, dst_name)
982             if write_back == WriteBack.COLOR_DIFF:
983                 d = color_diff(d)
984                 f = wrap_stream_for_windows(f)
985             f.write(d)
986         f.detach()
987
988
989 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
990     """Reformat contents of a file and return new contents.
991
992     If `fast` is False, additionally confirm that the reformatted code is
993     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
994     `mode` is passed to :func:`format_str`.
995     """
996     if not src_contents.strip():
997         raise NothingChanged
998
999     dst_contents = format_str(src_contents, mode=mode)
1000     if src_contents == dst_contents:
1001         raise NothingChanged
1002
1003     if not fast:
1004         assert_equivalent(src_contents, dst_contents)
1005         assert_stable(src_contents, dst_contents, mode=mode)
1006     return dst_contents
1007
1008
1009 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
1010     """Reformat a string and return new contents.
1011
1012     `mode` determines formatting options, such as how many characters per line are
1013     allowed.  Example:
1014
1015     >>> import black
1016     >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
1017     def f(arg: str = "") -> None:
1018         ...
1019
1020     A more complex example:
1021
1022     >>> print(
1023     ...   black.format_str(
1024     ...     "def f(arg:str='')->None: hey",
1025     ...     mode=black.Mode(
1026     ...       target_versions={black.TargetVersion.PY36},
1027     ...       line_length=10,
1028     ...       string_normalization=False,
1029     ...       is_pyi=False,
1030     ...     ),
1031     ...   ),
1032     ... )
1033     def f(
1034         arg: str = '',
1035     ) -> None:
1036         hey
1037
1038     """
1039     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
1040     dst_contents = []
1041     future_imports = get_future_imports(src_node)
1042     if mode.target_versions:
1043         versions = mode.target_versions
1044     else:
1045         versions = detect_target_versions(src_node)
1046     normalize_fmt_off(src_node)
1047     lines = LineGenerator(
1048         mode=mode,
1049         remove_u_prefix="unicode_literals" in future_imports
1050         or supports_feature(versions, Feature.UNICODE_LITERALS),
1051     )
1052     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1053     empty_line = Line(mode=mode)
1054     after = 0
1055     split_line_features = {
1056         feature
1057         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1058         if supports_feature(versions, feature)
1059     }
1060     for current_line in lines.visit(src_node):
1061         dst_contents.append(str(empty_line) * after)
1062         before, after = elt.maybe_empty_lines(current_line)
1063         dst_contents.append(str(empty_line) * before)
1064         for line in transform_line(
1065             current_line, mode=mode, features=split_line_features
1066         ):
1067             dst_contents.append(str(line))
1068     return "".join(dst_contents)
1069
1070
1071 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1072     """Return a tuple of (decoded_contents, encoding, newline).
1073
1074     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1075     universal newlines (i.e. only contains LF).
1076     """
1077     srcbuf = io.BytesIO(src)
1078     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1079     if not lines:
1080         return "", encoding, "\n"
1081
1082     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1083     srcbuf.seek(0)
1084     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1085         return tiow.read(), encoding, newline
1086
1087
1088 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1089     if not target_versions:
1090         # No target_version specified, so try all grammars.
1091         return [
1092             # Python 3.7+
1093             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1094             # Python 3.0-3.6
1095             pygram.python_grammar_no_print_statement_no_exec_statement,
1096             # Python 2.7 with future print_function import
1097             pygram.python_grammar_no_print_statement,
1098             # Python 2.7
1099             pygram.python_grammar,
1100         ]
1101
1102     if all(version.is_python2() for version in target_versions):
1103         # Python 2-only code, so try Python 2 grammars.
1104         return [
1105             # Python 2.7 with future print_function import
1106             pygram.python_grammar_no_print_statement,
1107             # Python 2.7
1108             pygram.python_grammar,
1109         ]
1110
1111     # Python 3-compatible code, so only try Python 3 grammar.
1112     grammars = []
1113     # If we have to parse both, try to parse async as a keyword first
1114     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1115         # Python 3.7+
1116         grammars.append(
1117             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1118         )
1119     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1120         # Python 3.0-3.6
1121         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1122     # At least one of the above branches must have been taken, because every Python
1123     # version has exactly one of the two 'ASYNC_*' flags
1124     return grammars
1125
1126
1127 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1128     """Given a string with source, return the lib2to3 Node."""
1129     if not src_txt.endswith("\n"):
1130         src_txt += "\n"
1131
1132     for grammar in get_grammars(set(target_versions)):
1133         drv = driver.Driver(grammar, pytree.convert)
1134         try:
1135             result = drv.parse_string(src_txt, True)
1136             break
1137
1138         except ParseError as pe:
1139             lineno, column = pe.context[1]
1140             lines = src_txt.splitlines()
1141             try:
1142                 faulty_line = lines[lineno - 1]
1143             except IndexError:
1144                 faulty_line = "<line number missing in source>"
1145             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1146     else:
1147         raise exc from None
1148
1149     if isinstance(result, Leaf):
1150         result = Node(syms.file_input, [result])
1151     return result
1152
1153
1154 def lib2to3_unparse(node: Node) -> str:
1155     """Given a lib2to3 node, return its string representation."""
1156     code = str(node)
1157     return code
1158
1159
1160 class Visitor(Generic[T]):
1161     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1162
1163     def visit(self, node: LN) -> Iterator[T]:
1164         """Main method to visit `node` and its children.
1165
1166         It tries to find a `visit_*()` method for the given `node.type`, like
1167         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1168         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1169         instead.
1170
1171         Then yields objects of type `T` from the selected visitor.
1172         """
1173         if node.type < 256:
1174             name = token.tok_name[node.type]
1175         else:
1176             name = str(type_repr(node.type))
1177         # We explicitly branch on whether a visitor exists (instead of
1178         # using self.visit_default as the default arg to getattr) in order
1179         # to save needing to create a bound method object and so mypyc can
1180         # generate a native call to visit_default.
1181         visitf = getattr(self, f"visit_{name}", None)
1182         if visitf:
1183             yield from visitf(node)
1184         else:
1185             yield from self.visit_default(node)
1186
1187     def visit_default(self, node: LN) -> Iterator[T]:
1188         """Default `visit_*()` implementation. Recurses to children of `node`."""
1189         if isinstance(node, Node):
1190             for child in node.children:
1191                 yield from self.visit(child)
1192
1193
1194 @dataclass
1195 class DebugVisitor(Visitor[T]):
1196     tree_depth: int = 0
1197
1198     def visit_default(self, node: LN) -> Iterator[T]:
1199         indent = " " * (2 * self.tree_depth)
1200         if isinstance(node, Node):
1201             _type = type_repr(node.type)
1202             out(f"{indent}{_type}", fg="yellow")
1203             self.tree_depth += 1
1204             for child in node.children:
1205                 yield from self.visit(child)
1206
1207             self.tree_depth -= 1
1208             out(f"{indent}/{_type}", fg="yellow", bold=False)
1209         else:
1210             _type = token.tok_name.get(node.type, str(node.type))
1211             out(f"{indent}{_type}", fg="blue", nl=False)
1212             if node.prefix:
1213                 # We don't have to handle prefixes for `Node` objects since
1214                 # that delegates to the first child anyway.
1215                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1216             out(f" {node.value!r}", fg="blue", bold=False)
1217
1218     @classmethod
1219     def show(cls, code: Union[str, Leaf, Node]) -> None:
1220         """Pretty-print the lib2to3 AST of a given string of `code`.
1221
1222         Convenience method for debugging.
1223         """
1224         v: DebugVisitor[None] = DebugVisitor()
1225         if isinstance(code, str):
1226             code = lib2to3_parse(code)
1227         list(v.visit(code))
1228
1229
1230 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1231 STATEMENT: Final = {
1232     syms.if_stmt,
1233     syms.while_stmt,
1234     syms.for_stmt,
1235     syms.try_stmt,
1236     syms.except_clause,
1237     syms.with_stmt,
1238     syms.funcdef,
1239     syms.classdef,
1240 }
1241 STANDALONE_COMMENT: Final = 153
1242 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1243 LOGIC_OPERATORS: Final = {"and", "or"}
1244 COMPARATORS: Final = {
1245     token.LESS,
1246     token.GREATER,
1247     token.EQEQUAL,
1248     token.NOTEQUAL,
1249     token.LESSEQUAL,
1250     token.GREATEREQUAL,
1251 }
1252 MATH_OPERATORS: Final = {
1253     token.VBAR,
1254     token.CIRCUMFLEX,
1255     token.AMPER,
1256     token.LEFTSHIFT,
1257     token.RIGHTSHIFT,
1258     token.PLUS,
1259     token.MINUS,
1260     token.STAR,
1261     token.SLASH,
1262     token.DOUBLESLASH,
1263     token.PERCENT,
1264     token.AT,
1265     token.TILDE,
1266     token.DOUBLESTAR,
1267 }
1268 STARS: Final = {token.STAR, token.DOUBLESTAR}
1269 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1270 VARARGS_PARENTS: Final = {
1271     syms.arglist,
1272     syms.argument,  # double star in arglist
1273     syms.trailer,  # single argument to call
1274     syms.typedargslist,
1275     syms.varargslist,  # lambdas
1276 }
1277 UNPACKING_PARENTS: Final = {
1278     syms.atom,  # single element of a list or set literal
1279     syms.dictsetmaker,
1280     syms.listmaker,
1281     syms.testlist_gexp,
1282     syms.testlist_star_expr,
1283 }
1284 TEST_DESCENDANTS: Final = {
1285     syms.test,
1286     syms.lambdef,
1287     syms.or_test,
1288     syms.and_test,
1289     syms.not_test,
1290     syms.comparison,
1291     syms.star_expr,
1292     syms.expr,
1293     syms.xor_expr,
1294     syms.and_expr,
1295     syms.shift_expr,
1296     syms.arith_expr,
1297     syms.trailer,
1298     syms.term,
1299     syms.power,
1300 }
1301 ASSIGNMENTS: Final = {
1302     "=",
1303     "+=",
1304     "-=",
1305     "*=",
1306     "@=",
1307     "/=",
1308     "%=",
1309     "&=",
1310     "|=",
1311     "^=",
1312     "<<=",
1313     ">>=",
1314     "**=",
1315     "//=",
1316 }
1317 COMPREHENSION_PRIORITY: Final = 20
1318 COMMA_PRIORITY: Final = 18
1319 TERNARY_PRIORITY: Final = 16
1320 LOGIC_PRIORITY: Final = 14
1321 STRING_PRIORITY: Final = 12
1322 COMPARATOR_PRIORITY: Final = 10
1323 MATH_PRIORITIES: Final = {
1324     token.VBAR: 9,
1325     token.CIRCUMFLEX: 8,
1326     token.AMPER: 7,
1327     token.LEFTSHIFT: 6,
1328     token.RIGHTSHIFT: 6,
1329     token.PLUS: 5,
1330     token.MINUS: 5,
1331     token.STAR: 4,
1332     token.SLASH: 4,
1333     token.DOUBLESLASH: 4,
1334     token.PERCENT: 4,
1335     token.AT: 4,
1336     token.TILDE: 3,
1337     token.DOUBLESTAR: 2,
1338 }
1339 DOT_PRIORITY: Final = 1
1340
1341
1342 @dataclass
1343 class BracketTracker:
1344     """Keeps track of brackets on a line."""
1345
1346     depth: int = 0
1347     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1348     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1349     previous: Optional[Leaf] = None
1350     _for_loop_depths: List[int] = field(default_factory=list)
1351     _lambda_argument_depths: List[int] = field(default_factory=list)
1352     invisible: List[Leaf] = field(default_factory=list)
1353
1354     def mark(self, leaf: Leaf) -> None:
1355         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1356
1357         All leaves receive an int `bracket_depth` field that stores how deep
1358         within brackets a given leaf is. 0 means there are no enclosing brackets
1359         that started on this line.
1360
1361         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1362         field that it forms a pair with. This is a one-directional link to
1363         avoid reference cycles.
1364
1365         If a leaf is a delimiter (a token on which Black can split the line if
1366         needed) and it's on depth 0, its `id()` is stored in the tracker's
1367         `delimiters` field.
1368         """
1369         if leaf.type == token.COMMENT:
1370             return
1371
1372         self.maybe_decrement_after_for_loop_variable(leaf)
1373         self.maybe_decrement_after_lambda_arguments(leaf)
1374         if leaf.type in CLOSING_BRACKETS:
1375             self.depth -= 1
1376             try:
1377                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1378             except KeyError as e:
1379                 raise BracketMatchError(
1380                     "Unable to match a closing bracket to the following opening"
1381                     f" bracket: {leaf}"
1382                 ) from e
1383             leaf.opening_bracket = opening_bracket
1384             if not leaf.value:
1385                 self.invisible.append(leaf)
1386         leaf.bracket_depth = self.depth
1387         if self.depth == 0:
1388             delim = is_split_before_delimiter(leaf, self.previous)
1389             if delim and self.previous is not None:
1390                 self.delimiters[id(self.previous)] = delim
1391             else:
1392                 delim = is_split_after_delimiter(leaf, self.previous)
1393                 if delim:
1394                     self.delimiters[id(leaf)] = delim
1395         if leaf.type in OPENING_BRACKETS:
1396             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1397             self.depth += 1
1398             if not leaf.value:
1399                 self.invisible.append(leaf)
1400         self.previous = leaf
1401         self.maybe_increment_lambda_arguments(leaf)
1402         self.maybe_increment_for_loop_variable(leaf)
1403
1404     def any_open_brackets(self) -> bool:
1405         """Return True if there is an yet unmatched open bracket on the line."""
1406         return bool(self.bracket_match)
1407
1408     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1409         """Return the highest priority of a delimiter found on the line.
1410
1411         Values are consistent with what `is_split_*_delimiter()` return.
1412         Raises ValueError on no delimiters.
1413         """
1414         return max(v for k, v in self.delimiters.items() if k not in exclude)
1415
1416     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1417         """Return the number of delimiters with the given `priority`.
1418
1419         If no `priority` is passed, defaults to max priority on the line.
1420         """
1421         if not self.delimiters:
1422             return 0
1423
1424         priority = priority or self.max_delimiter_priority()
1425         return sum(1 for p in self.delimiters.values() if p == priority)
1426
1427     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1428         """In a for loop, or comprehension, the variables are often unpacks.
1429
1430         To avoid splitting on the comma in this situation, increase the depth of
1431         tokens between `for` and `in`.
1432         """
1433         if leaf.type == token.NAME and leaf.value == "for":
1434             self.depth += 1
1435             self._for_loop_depths.append(self.depth)
1436             return True
1437
1438         return False
1439
1440     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1441         """See `maybe_increment_for_loop_variable` above for explanation."""
1442         if (
1443             self._for_loop_depths
1444             and self._for_loop_depths[-1] == self.depth
1445             and leaf.type == token.NAME
1446             and leaf.value == "in"
1447         ):
1448             self.depth -= 1
1449             self._for_loop_depths.pop()
1450             return True
1451
1452         return False
1453
1454     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1455         """In a lambda expression, there might be more than one argument.
1456
1457         To avoid splitting on the comma in this situation, increase the depth of
1458         tokens between `lambda` and `:`.
1459         """
1460         if leaf.type == token.NAME and leaf.value == "lambda":
1461             self.depth += 1
1462             self._lambda_argument_depths.append(self.depth)
1463             return True
1464
1465         return False
1466
1467     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1468         """See `maybe_increment_lambda_arguments` above for explanation."""
1469         if (
1470             self._lambda_argument_depths
1471             and self._lambda_argument_depths[-1] == self.depth
1472             and leaf.type == token.COLON
1473         ):
1474             self.depth -= 1
1475             self._lambda_argument_depths.pop()
1476             return True
1477
1478         return False
1479
1480     def get_open_lsqb(self) -> Optional[Leaf]:
1481         """Return the most recent opening square bracket (if any)."""
1482         return self.bracket_match.get((self.depth - 1, token.RSQB))
1483
1484
1485 @dataclass
1486 class Line:
1487     """Holds leaves and comments. Can be printed with `str(line)`."""
1488
1489     mode: Mode
1490     depth: int = 0
1491     leaves: List[Leaf] = field(default_factory=list)
1492     # keys ordered like `leaves`
1493     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1494     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1495     inside_brackets: bool = False
1496     should_split_rhs: bool = False
1497     magic_trailing_comma: Optional[Leaf] = None
1498
1499     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1500         """Add a new `leaf` to the end of the line.
1501
1502         Unless `preformatted` is True, the `leaf` will receive a new consistent
1503         whitespace prefix and metadata applied by :class:`BracketTracker`.
1504         Trailing commas are maybe removed, unpacked for loop variables are
1505         demoted from being delimiters.
1506
1507         Inline comments are put aside.
1508         """
1509         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1510         if not has_value:
1511             return
1512
1513         if token.COLON == leaf.type and self.is_class_paren_empty:
1514             del self.leaves[-2:]
1515         if self.leaves and not preformatted:
1516             # Note: at this point leaf.prefix should be empty except for
1517             # imports, for which we only preserve newlines.
1518             leaf.prefix += whitespace(
1519                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1520             )
1521         if self.inside_brackets or not preformatted:
1522             self.bracket_tracker.mark(leaf)
1523             if self.mode.magic_trailing_comma:
1524                 if self.has_magic_trailing_comma(leaf):
1525                     self.magic_trailing_comma = leaf
1526             elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
1527                 self.remove_trailing_comma()
1528         if not self.append_comment(leaf):
1529             self.leaves.append(leaf)
1530
1531     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1532         """Like :func:`append()` but disallow invalid standalone comment structure.
1533
1534         Raises ValueError when any `leaf` is appended after a standalone comment
1535         or when a standalone comment is not the first leaf on the line.
1536         """
1537         if self.bracket_tracker.depth == 0:
1538             if self.is_comment:
1539                 raise ValueError("cannot append to standalone comments")
1540
1541             if self.leaves and leaf.type == STANDALONE_COMMENT:
1542                 raise ValueError(
1543                     "cannot append standalone comments to a populated line"
1544                 )
1545
1546         self.append(leaf, preformatted=preformatted)
1547
1548     @property
1549     def is_comment(self) -> bool:
1550         """Is this line a standalone comment?"""
1551         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1552
1553     @property
1554     def is_decorator(self) -> bool:
1555         """Is this line a decorator?"""
1556         return bool(self) and self.leaves[0].type == token.AT
1557
1558     @property
1559     def is_import(self) -> bool:
1560         """Is this an import line?"""
1561         return bool(self) and is_import(self.leaves[0])
1562
1563     @property
1564     def is_class(self) -> bool:
1565         """Is this line a class definition?"""
1566         return (
1567             bool(self)
1568             and self.leaves[0].type == token.NAME
1569             and self.leaves[0].value == "class"
1570         )
1571
1572     @property
1573     def is_stub_class(self) -> bool:
1574         """Is this line a class definition with a body consisting only of "..."?"""
1575         return self.is_class and self.leaves[-3:] == [
1576             Leaf(token.DOT, ".") for _ in range(3)
1577         ]
1578
1579     @property
1580     def is_def(self) -> bool:
1581         """Is this a function definition? (Also returns True for async defs.)"""
1582         try:
1583             first_leaf = self.leaves[0]
1584         except IndexError:
1585             return False
1586
1587         try:
1588             second_leaf: Optional[Leaf] = self.leaves[1]
1589         except IndexError:
1590             second_leaf = None
1591         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1592             first_leaf.type == token.ASYNC
1593             and second_leaf is not None
1594             and second_leaf.type == token.NAME
1595             and second_leaf.value == "def"
1596         )
1597
1598     @property
1599     def is_class_paren_empty(self) -> bool:
1600         """Is this a class with no base classes but using parentheses?
1601
1602         Those are unnecessary and should be removed.
1603         """
1604         return (
1605             bool(self)
1606             and len(self.leaves) == 4
1607             and self.is_class
1608             and self.leaves[2].type == token.LPAR
1609             and self.leaves[2].value == "("
1610             and self.leaves[3].type == token.RPAR
1611             and self.leaves[3].value == ")"
1612         )
1613
1614     @property
1615     def is_triple_quoted_string(self) -> bool:
1616         """Is the line a triple quoted string?"""
1617         return (
1618             bool(self)
1619             and self.leaves[0].type == token.STRING
1620             and self.leaves[0].value.startswith(('"""', "'''"))
1621         )
1622
1623     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1624         """If so, needs to be split before emitting."""
1625         for leaf in self.leaves:
1626             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1627                 return True
1628
1629         return False
1630
1631     def contains_uncollapsable_type_comments(self) -> bool:
1632         ignored_ids = set()
1633         try:
1634             last_leaf = self.leaves[-1]
1635             ignored_ids.add(id(last_leaf))
1636             if last_leaf.type == token.COMMA or (
1637                 last_leaf.type == token.RPAR and not last_leaf.value
1638             ):
1639                 # When trailing commas or optional parens are inserted by Black for
1640                 # consistency, comments after the previous last element are not moved
1641                 # (they don't have to, rendering will still be correct).  So we ignore
1642                 # trailing commas and invisible.
1643                 last_leaf = self.leaves[-2]
1644                 ignored_ids.add(id(last_leaf))
1645         except IndexError:
1646             return False
1647
1648         # A type comment is uncollapsable if it is attached to a leaf
1649         # that isn't at the end of the line (since that could cause it
1650         # to get associated to a different argument) or if there are
1651         # comments before it (since that could cause it to get hidden
1652         # behind a comment.
1653         comment_seen = False
1654         for leaf_id, comments in self.comments.items():
1655             for comment in comments:
1656                 if is_type_comment(comment):
1657                     if comment_seen or (
1658                         not is_type_comment(comment, " ignore")
1659                         and leaf_id not in ignored_ids
1660                     ):
1661                         return True
1662
1663                 comment_seen = True
1664
1665         return False
1666
1667     def contains_unsplittable_type_ignore(self) -> bool:
1668         if not self.leaves:
1669             return False
1670
1671         # If a 'type: ignore' is attached to the end of a line, we
1672         # can't split the line, because we can't know which of the
1673         # subexpressions the ignore was meant to apply to.
1674         #
1675         # We only want this to apply to actual physical lines from the
1676         # original source, though: we don't want the presence of a
1677         # 'type: ignore' at the end of a multiline expression to
1678         # justify pushing it all onto one line. Thus we
1679         # (unfortunately) need to check the actual source lines and
1680         # only report an unsplittable 'type: ignore' if this line was
1681         # one line in the original code.
1682
1683         # Grab the first and last line numbers, skipping generated leaves
1684         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1685         last_line = next(
1686             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1687         )
1688
1689         if first_line == last_line:
1690             # We look at the last two leaves since a comma or an
1691             # invisible paren could have been added at the end of the
1692             # line.
1693             for node in self.leaves[-2:]:
1694                 for comment in self.comments.get(id(node), []):
1695                     if is_type_comment(comment, " ignore"):
1696                         return True
1697
1698         return False
1699
1700     def contains_multiline_strings(self) -> bool:
1701         return any(is_multiline_string(leaf) for leaf in self.leaves)
1702
1703     def has_magic_trailing_comma(
1704         self, closing: Leaf, ensure_removable: bool = False
1705     ) -> bool:
1706         """Return True if we have a magic trailing comma, that is when:
1707         - there's a trailing comma here
1708         - it's not a one-tuple
1709         Additionally, if ensure_removable:
1710         - it's not from square bracket indexing
1711         """
1712         if not (
1713             closing.type in CLOSING_BRACKETS
1714             and self.leaves
1715             and self.leaves[-1].type == token.COMMA
1716         ):
1717             return False
1718
1719         if closing.type == token.RBRACE:
1720             return True
1721
1722         if closing.type == token.RSQB:
1723             if not ensure_removable:
1724                 return True
1725             comma = self.leaves[-1]
1726             return bool(comma.parent and comma.parent.type == syms.listmaker)
1727
1728         if self.is_import:
1729             return True
1730
1731         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1732             return True
1733
1734         return False
1735
1736     def append_comment(self, comment: Leaf) -> bool:
1737         """Add an inline or standalone comment to the line."""
1738         if (
1739             comment.type == STANDALONE_COMMENT
1740             and self.bracket_tracker.any_open_brackets()
1741         ):
1742             comment.prefix = ""
1743             return False
1744
1745         if comment.type != token.COMMENT:
1746             return False
1747
1748         if not self.leaves:
1749             comment.type = STANDALONE_COMMENT
1750             comment.prefix = ""
1751             return False
1752
1753         last_leaf = self.leaves[-1]
1754         if (
1755             last_leaf.type == token.RPAR
1756             and not last_leaf.value
1757             and last_leaf.parent
1758             and len(list(last_leaf.parent.leaves())) <= 3
1759             and not is_type_comment(comment)
1760         ):
1761             # Comments on an optional parens wrapping a single leaf should belong to
1762             # the wrapped node except if it's a type comment. Pinning the comment like
1763             # this avoids unstable formatting caused by comment migration.
1764             if len(self.leaves) < 2:
1765                 comment.type = STANDALONE_COMMENT
1766                 comment.prefix = ""
1767                 return False
1768
1769             last_leaf = self.leaves[-2]
1770         self.comments.setdefault(id(last_leaf), []).append(comment)
1771         return True
1772
1773     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1774         """Generate comments that should appear directly after `leaf`."""
1775         return self.comments.get(id(leaf), [])
1776
1777     def remove_trailing_comma(self) -> None:
1778         """Remove the trailing comma and moves the comments attached to it."""
1779         trailing_comma = self.leaves.pop()
1780         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1781         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1782             trailing_comma_comments
1783         )
1784
1785     def is_complex_subscript(self, leaf: Leaf) -> bool:
1786         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1787         open_lsqb = self.bracket_tracker.get_open_lsqb()
1788         if open_lsqb is None:
1789             return False
1790
1791         subscript_start = open_lsqb.next_sibling
1792
1793         if isinstance(subscript_start, Node):
1794             if subscript_start.type == syms.listmaker:
1795                 return False
1796
1797             if subscript_start.type == syms.subscriptlist:
1798                 subscript_start = child_towards(subscript_start, leaf)
1799         return subscript_start is not None and any(
1800             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1801         )
1802
1803     def clone(self) -> "Line":
1804         return Line(
1805             mode=self.mode,
1806             depth=self.depth,
1807             inside_brackets=self.inside_brackets,
1808             should_split_rhs=self.should_split_rhs,
1809             magic_trailing_comma=self.magic_trailing_comma,
1810         )
1811
1812     def __str__(self) -> str:
1813         """Render the line."""
1814         if not self:
1815             return "\n"
1816
1817         indent = "    " * self.depth
1818         leaves = iter(self.leaves)
1819         first = next(leaves)
1820         res = f"{first.prefix}{indent}{first.value}"
1821         for leaf in leaves:
1822             res += str(leaf)
1823         for comment in itertools.chain.from_iterable(self.comments.values()):
1824             res += str(comment)
1825
1826         return res + "\n"
1827
1828     def __bool__(self) -> bool:
1829         """Return True if the line has leaves or comments."""
1830         return bool(self.leaves or self.comments)
1831
1832
1833 @dataclass
1834 class EmptyLineTracker:
1835     """Provides a stateful method that returns the number of potential extra
1836     empty lines needed before and after the currently processed line.
1837
1838     Note: this tracker works on lines that haven't been split yet.  It assumes
1839     the prefix of the first leaf consists of optional newlines.  Those newlines
1840     are consumed by `maybe_empty_lines()` and included in the computation.
1841     """
1842
1843     is_pyi: bool = False
1844     previous_line: Optional[Line] = None
1845     previous_after: int = 0
1846     previous_defs: List[int] = field(default_factory=list)
1847
1848     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1849         """Return the number of extra empty lines before and after the `current_line`.
1850
1851         This is for separating `def`, `async def` and `class` with extra empty
1852         lines (two on module-level).
1853         """
1854         before, after = self._maybe_empty_lines(current_line)
1855         before = (
1856             # Black should not insert empty lines at the beginning
1857             # of the file
1858             0
1859             if self.previous_line is None
1860             else before - self.previous_after
1861         )
1862         self.previous_after = after
1863         self.previous_line = current_line
1864         return before, after
1865
1866     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1867         max_allowed = 1
1868         if current_line.depth == 0:
1869             max_allowed = 1 if self.is_pyi else 2
1870         if current_line.leaves:
1871             # Consume the first leaf's extra newlines.
1872             first_leaf = current_line.leaves[0]
1873             before = first_leaf.prefix.count("\n")
1874             before = min(before, max_allowed)
1875             first_leaf.prefix = ""
1876         else:
1877             before = 0
1878         depth = current_line.depth
1879         while self.previous_defs and self.previous_defs[-1] >= depth:
1880             self.previous_defs.pop()
1881             if self.is_pyi:
1882                 before = 0 if depth else 1
1883             else:
1884                 before = 1 if depth else 2
1885         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1886             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1887
1888         if (
1889             self.previous_line
1890             and self.previous_line.is_import
1891             and not current_line.is_import
1892             and depth == self.previous_line.depth
1893         ):
1894             return (before or 1), 0
1895
1896         if (
1897             self.previous_line
1898             and self.previous_line.is_class
1899             and current_line.is_triple_quoted_string
1900         ):
1901             return before, 1
1902
1903         return before, 0
1904
1905     def _maybe_empty_lines_for_class_or_def(
1906         self, current_line: Line, before: int
1907     ) -> Tuple[int, int]:
1908         if not current_line.is_decorator:
1909             self.previous_defs.append(current_line.depth)
1910         if self.previous_line is None:
1911             # Don't insert empty lines before the first line in the file.
1912             return 0, 0
1913
1914         if self.previous_line.is_decorator:
1915             if self.is_pyi and current_line.is_stub_class:
1916                 # Insert an empty line after a decorated stub class
1917                 return 0, 1
1918
1919             return 0, 0
1920
1921         if self.previous_line.depth < current_line.depth and (
1922             self.previous_line.is_class or self.previous_line.is_def
1923         ):
1924             return 0, 0
1925
1926         if (
1927             self.previous_line.is_comment
1928             and self.previous_line.depth == current_line.depth
1929             and before == 0
1930         ):
1931             return 0, 0
1932
1933         if self.is_pyi:
1934             if self.previous_line.depth > current_line.depth:
1935                 newlines = 1
1936             elif current_line.is_class or self.previous_line.is_class:
1937                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1938                     # No blank line between classes with an empty body
1939                     newlines = 0
1940                 else:
1941                     newlines = 1
1942             elif (
1943                 current_line.is_def or current_line.is_decorator
1944             ) and not self.previous_line.is_def:
1945                 # Blank line between a block of functions (maybe with preceding
1946                 # decorators) and a block of non-functions
1947                 newlines = 1
1948             else:
1949                 newlines = 0
1950         else:
1951             newlines = 2
1952         if current_line.depth and newlines:
1953             newlines -= 1
1954         return newlines, 0
1955
1956
1957 @dataclass
1958 class LineGenerator(Visitor[Line]):
1959     """Generates reformatted Line objects.  Empty lines are not emitted.
1960
1961     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1962     in ways that will no longer stringify to valid Python code on the tree.
1963     """
1964
1965     mode: Mode
1966     remove_u_prefix: bool = False
1967     current_line: Line = field(init=False)
1968
1969     def line(self, indent: int = 0) -> Iterator[Line]:
1970         """Generate a line.
1971
1972         If the line is empty, only emit if it makes sense.
1973         If the line is too long, split it first and then generate.
1974
1975         If any lines were generated, set up a new current_line.
1976         """
1977         if not self.current_line:
1978             self.current_line.depth += indent
1979             return  # Line is empty, don't emit. Creating a new one unnecessary.
1980
1981         complete_line = self.current_line
1982         self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
1983         yield complete_line
1984
1985     def visit_default(self, node: LN) -> Iterator[Line]:
1986         """Default `visit_*()` implementation. Recurses to children of `node`."""
1987         if isinstance(node, Leaf):
1988             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1989             for comment in generate_comments(node):
1990                 if any_open_brackets:
1991                     # any comment within brackets is subject to splitting
1992                     self.current_line.append(comment)
1993                 elif comment.type == token.COMMENT:
1994                     # regular trailing comment
1995                     self.current_line.append(comment)
1996                     yield from self.line()
1997
1998                 else:
1999                     # regular standalone comment
2000                     yield from self.line()
2001
2002                     self.current_line.append(comment)
2003                     yield from self.line()
2004
2005             normalize_prefix(node, inside_brackets=any_open_brackets)
2006             if self.mode.string_normalization and node.type == token.STRING:
2007                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
2008                 normalize_string_quotes(node)
2009             if node.type == token.NUMBER:
2010                 normalize_numeric_literal(node)
2011             if node.type not in WHITESPACE:
2012                 self.current_line.append(node)
2013         yield from super().visit_default(node)
2014
2015     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
2016         """Increase indentation level, maybe yield a line."""
2017         # In blib2to3 INDENT never holds comments.
2018         yield from self.line(+1)
2019         yield from self.visit_default(node)
2020
2021     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
2022         """Decrease indentation level, maybe yield a line."""
2023         # The current line might still wait for trailing comments.  At DEDENT time
2024         # there won't be any (they would be prefixes on the preceding NEWLINE).
2025         # Emit the line then.
2026         yield from self.line()
2027
2028         # While DEDENT has no value, its prefix may contain standalone comments
2029         # that belong to the current indentation level.  Get 'em.
2030         yield from self.visit_default(node)
2031
2032         # Finally, emit the dedent.
2033         yield from self.line(-1)
2034
2035     def visit_stmt(
2036         self, node: Node, keywords: Set[str], parens: Set[str]
2037     ) -> Iterator[Line]:
2038         """Visit a statement.
2039
2040         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2041         `def`, `with`, `class`, `assert` and assignments.
2042
2043         The relevant Python language `keywords` for a given statement will be
2044         NAME leaves within it. This methods puts those on a separate line.
2045
2046         `parens` holds a set of string leaf values immediately after which
2047         invisible parens should be put.
2048         """
2049         normalize_invisible_parens(node, parens_after=parens)
2050         for child in node.children:
2051             if child.type == token.NAME and child.value in keywords:  # type: ignore
2052                 yield from self.line()
2053
2054             yield from self.visit(child)
2055
2056     def visit_suite(self, node: Node) -> Iterator[Line]:
2057         """Visit a suite."""
2058         if self.mode.is_pyi and is_stub_suite(node):
2059             yield from self.visit(node.children[2])
2060         else:
2061             yield from self.visit_default(node)
2062
2063     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2064         """Visit a statement without nested statements."""
2065         if first_child_is_arith(node):
2066             wrap_in_parentheses(node, node.children[0], visible=False)
2067         is_suite_like = node.parent and node.parent.type in STATEMENT
2068         if is_suite_like:
2069             if self.mode.is_pyi and is_stub_body(node):
2070                 yield from self.visit_default(node)
2071             else:
2072                 yield from self.line(+1)
2073                 yield from self.visit_default(node)
2074                 yield from self.line(-1)
2075
2076         else:
2077             if (
2078                 not self.mode.is_pyi
2079                 or not node.parent
2080                 or not is_stub_suite(node.parent)
2081             ):
2082                 yield from self.line()
2083             yield from self.visit_default(node)
2084
2085     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2086         """Visit `async def`, `async for`, `async with`."""
2087         yield from self.line()
2088
2089         children = iter(node.children)
2090         for child in children:
2091             yield from self.visit(child)
2092
2093             if child.type == token.ASYNC:
2094                 break
2095
2096         internal_stmt = next(children)
2097         for child in internal_stmt.children:
2098             yield from self.visit(child)
2099
2100     def visit_decorators(self, node: Node) -> Iterator[Line]:
2101         """Visit decorators."""
2102         for child in node.children:
2103             yield from self.line()
2104             yield from self.visit(child)
2105
2106     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2107         """Remove a semicolon and put the other statement on a separate line."""
2108         yield from self.line()
2109
2110     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2111         """End of file. Process outstanding comments and end with a newline."""
2112         yield from self.visit_default(leaf)
2113         yield from self.line()
2114
2115     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2116         if not self.current_line.bracket_tracker.any_open_brackets():
2117             yield from self.line()
2118         yield from self.visit_default(leaf)
2119
2120     def visit_factor(self, node: Node) -> Iterator[Line]:
2121         """Force parentheses between a unary op and a binary power:
2122
2123         -2 ** 8 -> -(2 ** 8)
2124         """
2125         _operator, operand = node.children
2126         if (
2127             operand.type == syms.power
2128             and len(operand.children) == 3
2129             and operand.children[1].type == token.DOUBLESTAR
2130         ):
2131             lpar = Leaf(token.LPAR, "(")
2132             rpar = Leaf(token.RPAR, ")")
2133             index = operand.remove() or 0
2134             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2135         yield from self.visit_default(node)
2136
2137     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2138         if is_docstring(leaf) and "\\\n" not in leaf.value:
2139             # We're ignoring docstrings with backslash newline escapes because changing
2140             # indentation of those changes the AST representation of the code.
2141             prefix = get_string_prefix(leaf.value)
2142             lead_len = len(prefix) + 3
2143             tail_len = -3
2144             indent = " " * 4 * self.current_line.depth
2145             docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
2146             if docstring:
2147                 if leaf.value[lead_len - 1] == docstring[0]:
2148                     docstring = " " + docstring
2149                 if leaf.value[tail_len + 1] == docstring[-1]:
2150                     docstring = docstring + " "
2151             leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
2152
2153         yield from self.visit_default(leaf)
2154
2155     def __post_init__(self) -> None:
2156         """You are in a twisty little maze of passages."""
2157         self.current_line = Line(mode=self.mode)
2158
2159         v = self.visit_stmt
2160         Ø: Set[str] = set()
2161         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2162         self.visit_if_stmt = partial(
2163             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2164         )
2165         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2166         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2167         self.visit_try_stmt = partial(
2168             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2169         )
2170         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2171         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2172         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2173         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2174         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2175         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2176         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2177         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2178         self.visit_async_funcdef = self.visit_async_stmt
2179         self.visit_decorated = self.visit_decorators
2180
2181
2182 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2183 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2184 OPENING_BRACKETS = set(BRACKET.keys())
2185 CLOSING_BRACKETS = set(BRACKET.values())
2186 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2187 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2188
2189
2190 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2191     """Return whitespace prefix if needed for the given `leaf`.
2192
2193     `complex_subscript` signals whether the given leaf is part of a subscription
2194     which has non-trivial arguments, like arithmetic expressions or function calls.
2195     """
2196     NO = ""
2197     SPACE = " "
2198     DOUBLESPACE = "  "
2199     t = leaf.type
2200     p = leaf.parent
2201     v = leaf.value
2202     if t in ALWAYS_NO_SPACE:
2203         return NO
2204
2205     if t == token.COMMENT:
2206         return DOUBLESPACE
2207
2208     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2209     if t == token.COLON and p.type not in {
2210         syms.subscript,
2211         syms.subscriptlist,
2212         syms.sliceop,
2213     }:
2214         return NO
2215
2216     prev = leaf.prev_sibling
2217     if not prev:
2218         prevp = preceding_leaf(p)
2219         if not prevp or prevp.type in OPENING_BRACKETS:
2220             return NO
2221
2222         if t == token.COLON:
2223             if prevp.type == token.COLON:
2224                 return NO
2225
2226             elif prevp.type != token.COMMA and not complex_subscript:
2227                 return NO
2228
2229             return SPACE
2230
2231         if prevp.type == token.EQUAL:
2232             if prevp.parent:
2233                 if prevp.parent.type in {
2234                     syms.arglist,
2235                     syms.argument,
2236                     syms.parameters,
2237                     syms.varargslist,
2238                 }:
2239                     return NO
2240
2241                 elif prevp.parent.type == syms.typedargslist:
2242                     # A bit hacky: if the equal sign has whitespace, it means we
2243                     # previously found it's a typed argument.  So, we're using
2244                     # that, too.
2245                     return prevp.prefix
2246
2247         elif prevp.type in VARARGS_SPECIALS:
2248             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2249                 return NO
2250
2251         elif prevp.type == token.COLON:
2252             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2253                 return SPACE if complex_subscript else NO
2254
2255         elif (
2256             prevp.parent
2257             and prevp.parent.type == syms.factor
2258             and prevp.type in MATH_OPERATORS
2259         ):
2260             return NO
2261
2262         elif (
2263             prevp.type == token.RIGHTSHIFT
2264             and prevp.parent
2265             and prevp.parent.type == syms.shift_expr
2266             and prevp.prev_sibling
2267             and prevp.prev_sibling.type == token.NAME
2268             and prevp.prev_sibling.value == "print"  # type: ignore
2269         ):
2270             # Python 2 print chevron
2271             return NO
2272         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2273             # no space in decorators
2274             return NO
2275
2276     elif prev.type in OPENING_BRACKETS:
2277         return NO
2278
2279     if p.type in {syms.parameters, syms.arglist}:
2280         # untyped function signatures or calls
2281         if not prev or prev.type != token.COMMA:
2282             return NO
2283
2284     elif p.type == syms.varargslist:
2285         # lambdas
2286         if prev and prev.type != token.COMMA:
2287             return NO
2288
2289     elif p.type == syms.typedargslist:
2290         # typed function signatures
2291         if not prev:
2292             return NO
2293
2294         if t == token.EQUAL:
2295             if prev.type != syms.tname:
2296                 return NO
2297
2298         elif prev.type == token.EQUAL:
2299             # A bit hacky: if the equal sign has whitespace, it means we
2300             # previously found it's a typed argument.  So, we're using that, too.
2301             return prev.prefix
2302
2303         elif prev.type != token.COMMA:
2304             return NO
2305
2306     elif p.type == syms.tname:
2307         # type names
2308         if not prev:
2309             prevp = preceding_leaf(p)
2310             if not prevp or prevp.type != token.COMMA:
2311                 return NO
2312
2313     elif p.type == syms.trailer:
2314         # attributes and calls
2315         if t == token.LPAR or t == token.RPAR:
2316             return NO
2317
2318         if not prev:
2319             if t == token.DOT:
2320                 prevp = preceding_leaf(p)
2321                 if not prevp or prevp.type != token.NUMBER:
2322                     return NO
2323
2324             elif t == token.LSQB:
2325                 return NO
2326
2327         elif prev.type != token.COMMA:
2328             return NO
2329
2330     elif p.type == syms.argument:
2331         # single argument
2332         if t == token.EQUAL:
2333             return NO
2334
2335         if not prev:
2336             prevp = preceding_leaf(p)
2337             if not prevp or prevp.type == token.LPAR:
2338                 return NO
2339
2340         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2341             return NO
2342
2343     elif p.type == syms.decorator:
2344         # decorators
2345         return NO
2346
2347     elif p.type == syms.dotted_name:
2348         if prev:
2349             return NO
2350
2351         prevp = preceding_leaf(p)
2352         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2353             return NO
2354
2355     elif p.type == syms.classdef:
2356         if t == token.LPAR:
2357             return NO
2358
2359         if prev and prev.type == token.LPAR:
2360             return NO
2361
2362     elif p.type in {syms.subscript, syms.sliceop}:
2363         # indexing
2364         if not prev:
2365             assert p.parent is not None, "subscripts are always parented"
2366             if p.parent.type == syms.subscriptlist:
2367                 return SPACE
2368
2369             return NO
2370
2371         elif not complex_subscript:
2372             return NO
2373
2374     elif p.type == syms.atom:
2375         if prev and t == token.DOT:
2376             # dots, but not the first one.
2377             return NO
2378
2379     elif p.type == syms.dictsetmaker:
2380         # dict unpacking
2381         if prev and prev.type == token.DOUBLESTAR:
2382             return NO
2383
2384     elif p.type in {syms.factor, syms.star_expr}:
2385         # unary ops
2386         if not prev:
2387             prevp = preceding_leaf(p)
2388             if not prevp or prevp.type in OPENING_BRACKETS:
2389                 return NO
2390
2391             prevp_parent = prevp.parent
2392             assert prevp_parent is not None
2393             if prevp.type == token.COLON and prevp_parent.type in {
2394                 syms.subscript,
2395                 syms.sliceop,
2396             }:
2397                 return NO
2398
2399             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2400                 return NO
2401
2402         elif t in {token.NAME, token.NUMBER, token.STRING}:
2403             return NO
2404
2405     elif p.type == syms.import_from:
2406         if t == token.DOT:
2407             if prev and prev.type == token.DOT:
2408                 return NO
2409
2410         elif t == token.NAME:
2411             if v == "import":
2412                 return SPACE
2413
2414             if prev and prev.type == token.DOT:
2415                 return NO
2416
2417     elif p.type == syms.sliceop:
2418         return NO
2419
2420     return SPACE
2421
2422
2423 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2424     """Return the first leaf that precedes `node`, if any."""
2425     while node:
2426         res = node.prev_sibling
2427         if res:
2428             if isinstance(res, Leaf):
2429                 return res
2430
2431             try:
2432                 return list(res.leaves())[-1]
2433
2434             except IndexError:
2435                 return None
2436
2437         node = node.parent
2438     return None
2439
2440
2441 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2442     """Return if the `node` and its previous siblings match types against the provided
2443     list of tokens; the provided `node`has its type matched against the last element in
2444     the list.  `None` can be used as the first element to declare that the start of the
2445     list is anchored at the start of its parent's children."""
2446     if not tokens:
2447         return True
2448     if tokens[-1] is None:
2449         return node is None
2450     if not node:
2451         return False
2452     if node.type != tokens[-1]:
2453         return False
2454     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2455
2456
2457 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2458     """Return the child of `ancestor` that contains `descendant`."""
2459     node: Optional[LN] = descendant
2460     while node and node.parent != ancestor:
2461         node = node.parent
2462     return node
2463
2464
2465 def container_of(leaf: Leaf) -> LN:
2466     """Return `leaf` or one of its ancestors that is the topmost container of it.
2467
2468     By "container" we mean a node where `leaf` is the very first child.
2469     """
2470     same_prefix = leaf.prefix
2471     container: LN = leaf
2472     while container:
2473         parent = container.parent
2474         if parent is None:
2475             break
2476
2477         if parent.children[0].prefix != same_prefix:
2478             break
2479
2480         if parent.type == syms.file_input:
2481             break
2482
2483         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2484             break
2485
2486         container = parent
2487     return container
2488
2489
2490 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2491     """Return the priority of the `leaf` delimiter, given a line break after it.
2492
2493     The delimiter priorities returned here are from those delimiters that would
2494     cause a line break after themselves.
2495
2496     Higher numbers are higher priority.
2497     """
2498     if leaf.type == token.COMMA:
2499         return COMMA_PRIORITY
2500
2501     return 0
2502
2503
2504 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2505     """Return the priority of the `leaf` delimiter, given a line break before it.
2506
2507     The delimiter priorities returned here are from those delimiters that would
2508     cause a line break before themselves.
2509
2510     Higher numbers are higher priority.
2511     """
2512     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2513         # * and ** might also be MATH_OPERATORS but in this case they are not.
2514         # Don't treat them as a delimiter.
2515         return 0
2516
2517     if (
2518         leaf.type == token.DOT
2519         and leaf.parent
2520         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2521         and (previous is None or previous.type in CLOSING_BRACKETS)
2522     ):
2523         return DOT_PRIORITY
2524
2525     if (
2526         leaf.type in MATH_OPERATORS
2527         and leaf.parent
2528         and leaf.parent.type not in {syms.factor, syms.star_expr}
2529     ):
2530         return MATH_PRIORITIES[leaf.type]
2531
2532     if leaf.type in COMPARATORS:
2533         return COMPARATOR_PRIORITY
2534
2535     if (
2536         leaf.type == token.STRING
2537         and previous is not None
2538         and previous.type == token.STRING
2539     ):
2540         return STRING_PRIORITY
2541
2542     if leaf.type not in {token.NAME, token.ASYNC}:
2543         return 0
2544
2545     if (
2546         leaf.value == "for"
2547         and leaf.parent
2548         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2549         or leaf.type == token.ASYNC
2550     ):
2551         if (
2552             not isinstance(leaf.prev_sibling, Leaf)
2553             or leaf.prev_sibling.value != "async"
2554         ):
2555             return COMPREHENSION_PRIORITY
2556
2557     if (
2558         leaf.value == "if"
2559         and leaf.parent
2560         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2561     ):
2562         return COMPREHENSION_PRIORITY
2563
2564     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2565         return TERNARY_PRIORITY
2566
2567     if leaf.value == "is":
2568         return COMPARATOR_PRIORITY
2569
2570     if (
2571         leaf.value == "in"
2572         and leaf.parent
2573         and leaf.parent.type in {syms.comp_op, syms.comparison}
2574         and not (
2575             previous is not None
2576             and previous.type == token.NAME
2577             and previous.value == "not"
2578         )
2579     ):
2580         return COMPARATOR_PRIORITY
2581
2582     if (
2583         leaf.value == "not"
2584         and leaf.parent
2585         and leaf.parent.type == syms.comp_op
2586         and not (
2587             previous is not None
2588             and previous.type == token.NAME
2589             and previous.value == "is"
2590         )
2591     ):
2592         return COMPARATOR_PRIORITY
2593
2594     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2595         return LOGIC_PRIORITY
2596
2597     return 0
2598
2599
2600 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2601 FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
2602 FMT_PASS = {*FMT_OFF, *FMT_SKIP}
2603 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2604
2605
2606 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2607     """Clean the prefix of the `leaf` and generate comments from it, if any.
2608
2609     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2610     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2611     move because it does away with modifying the grammar to include all the
2612     possible places in which comments can be placed.
2613
2614     The sad consequence for us though is that comments don't "belong" anywhere.
2615     This is why this function generates simple parentless Leaf objects for
2616     comments.  We simply don't know what the correct parent should be.
2617
2618     No matter though, we can live without this.  We really only need to
2619     differentiate between inline and standalone comments.  The latter don't
2620     share the line with any code.
2621
2622     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2623     are emitted with a fake STANDALONE_COMMENT token identifier.
2624     """
2625     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2626         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2627
2628
2629 @dataclass
2630 class ProtoComment:
2631     """Describes a piece of syntax that is a comment.
2632
2633     It's not a :class:`blib2to3.pytree.Leaf` so that:
2634
2635     * it can be cached (`Leaf` objects should not be reused more than once as
2636       they store their lineno, column, prefix, and parent information);
2637     * `newlines` and `consumed` fields are kept separate from the `value`. This
2638       simplifies handling of special marker comments like ``# fmt: off/on``.
2639     """
2640
2641     type: int  # token.COMMENT or STANDALONE_COMMENT
2642     value: str  # content of the comment
2643     newlines: int  # how many newlines before the comment
2644     consumed: int  # how many characters of the original leaf's prefix did we consume
2645
2646
2647 @lru_cache(maxsize=4096)
2648 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2649     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2650     result: List[ProtoComment] = []
2651     if not prefix or "#" not in prefix:
2652         return result
2653
2654     consumed = 0
2655     nlines = 0
2656     ignored_lines = 0
2657     for index, line in enumerate(re.split("\r?\n", prefix)):
2658         consumed += len(line) + 1  # adding the length of the split '\n'
2659         line = line.lstrip()
2660         if not line:
2661             nlines += 1
2662         if not line.startswith("#"):
2663             # Escaped newlines outside of a comment are not really newlines at
2664             # all. We treat a single-line comment following an escaped newline
2665             # as a simple trailing comment.
2666             if line.endswith("\\"):
2667                 ignored_lines += 1
2668             continue
2669
2670         if index == ignored_lines and not is_endmarker:
2671             comment_type = token.COMMENT  # simple trailing comment
2672         else:
2673             comment_type = STANDALONE_COMMENT
2674         comment = make_comment(line)
2675         result.append(
2676             ProtoComment(
2677                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2678             )
2679         )
2680         nlines = 0
2681     return result
2682
2683
2684 def make_comment(content: str) -> str:
2685     """Return a consistently formatted comment from the given `content` string.
2686
2687     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2688     space between the hash sign and the content.
2689
2690     If `content` didn't start with a hash sign, one is provided.
2691     """
2692     content = content.rstrip()
2693     if not content:
2694         return "#"
2695
2696     if content[0] == "#":
2697         content = content[1:]
2698     if content and content[0] not in " !:#'%":
2699         content = " " + content
2700     return "#" + content
2701
2702
2703 def transform_line(
2704     line: Line, mode: Mode, features: Collection[Feature] = ()
2705 ) -> Iterator[Line]:
2706     """Transform a `line`, potentially splitting it into many lines.
2707
2708     They should fit in the allotted `line_length` but might not be able to.
2709
2710     `features` are syntactical features that may be used in the output.
2711     """
2712     if line.is_comment:
2713         yield line
2714         return
2715
2716     line_str = line_to_string(line)
2717
2718     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2719         """Initialize StringTransformer"""
2720         return ST(mode.line_length, mode.string_normalization)
2721
2722     string_merge = init_st(StringMerger)
2723     string_paren_strip = init_st(StringParenStripper)
2724     string_split = init_st(StringSplitter)
2725     string_paren_wrap = init_st(StringParenWrapper)
2726
2727     transformers: List[Transformer]
2728     if (
2729         not line.contains_uncollapsable_type_comments()
2730         and not line.should_split_rhs
2731         and not line.magic_trailing_comma
2732         and (
2733             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2734             or line.contains_unsplittable_type_ignore()
2735         )
2736         and not (line.inside_brackets and line.contains_standalone_comments())
2737     ):
2738         # Only apply basic string preprocessing, since lines shouldn't be split here.
2739         if mode.experimental_string_processing:
2740             transformers = [string_merge, string_paren_strip]
2741         else:
2742             transformers = []
2743     elif line.is_def:
2744         transformers = [left_hand_split]
2745     else:
2746
2747         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2748             """Wraps calls to `right_hand_split`.
2749
2750             The calls increasingly `omit` right-hand trailers (bracket pairs with
2751             content), meaning the trailers get glued together to split on another
2752             bracket pair instead.
2753             """
2754             for omit in generate_trailers_to_omit(line, mode.line_length):
2755                 lines = list(
2756                     right_hand_split(line, mode.line_length, features, omit=omit)
2757                 )
2758                 # Note: this check is only able to figure out if the first line of the
2759                 # *current* transformation fits in the line length.  This is true only
2760                 # for simple cases.  All others require running more transforms via
2761                 # `transform_line()`.  This check doesn't know if those would succeed.
2762                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2763                     yield from lines
2764                     return
2765
2766             # All splits failed, best effort split with no omits.
2767             # This mostly happens to multiline strings that are by definition
2768             # reported as not fitting a single line, as well as lines that contain
2769             # trailing commas (those have to be exploded).
2770             yield from right_hand_split(
2771                 line, line_length=mode.line_length, features=features
2772             )
2773
2774         if mode.experimental_string_processing:
2775             if line.inside_brackets:
2776                 transformers = [
2777                     string_merge,
2778                     string_paren_strip,
2779                     string_split,
2780                     delimiter_split,
2781                     standalone_comment_split,
2782                     string_paren_wrap,
2783                     rhs,
2784                 ]
2785             else:
2786                 transformers = [
2787                     string_merge,
2788                     string_paren_strip,
2789                     string_split,
2790                     string_paren_wrap,
2791                     rhs,
2792                 ]
2793         else:
2794             if line.inside_brackets:
2795                 transformers = [delimiter_split, standalone_comment_split, rhs]
2796             else:
2797                 transformers = [rhs]
2798
2799     for transform in transformers:
2800         # We are accumulating lines in `result` because we might want to abort
2801         # mission and return the original line in the end, or attempt a different
2802         # split altogether.
2803         try:
2804             result = run_transformer(line, transform, mode, features, line_str=line_str)
2805         except CannotTransform:
2806             continue
2807         else:
2808             yield from result
2809             break
2810
2811     else:
2812         yield line
2813
2814
2815 @dataclass  # type: ignore
2816 class StringTransformer(ABC):
2817     """
2818     An implementation of the Transformer protocol that relies on its
2819     subclasses overriding the template methods `do_match(...)` and
2820     `do_transform(...)`.
2821
2822     This Transformer works exclusively on strings (for example, by merging
2823     or splitting them).
2824
2825     The following sections can be found among the docstrings of each concrete
2826     StringTransformer subclass.
2827
2828     Requirements:
2829         Which requirements must be met of the given Line for this
2830         StringTransformer to be applied?
2831
2832     Transformations:
2833         If the given Line meets all of the above requirements, which string
2834         transformations can you expect to be applied to it by this
2835         StringTransformer?
2836
2837     Collaborations:
2838         What contractual agreements does this StringTransformer have with other
2839         StringTransfomers? Such collaborations should be eliminated/minimized
2840         as much as possible.
2841     """
2842
2843     line_length: int
2844     normalize_strings: bool
2845     __name__ = "StringTransformer"
2846
2847     @abstractmethod
2848     def do_match(self, line: Line) -> TMatchResult:
2849         """
2850         Returns:
2851             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2852             string, if a match was able to be made.
2853                 OR
2854             * Err(CannotTransform), if a match was not able to be made.
2855         """
2856
2857     @abstractmethod
2858     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2859         """
2860         Yields:
2861             * Ok(new_line) where new_line is the new transformed line.
2862                 OR
2863             * Err(CannotTransform) if the transformation failed for some reason. The
2864             `do_match(...)` template method should usually be used to reject
2865             the form of the given Line, but in some cases it is difficult to
2866             know whether or not a Line meets the StringTransformer's
2867             requirements until the transformation is already midway.
2868
2869         Side Effects:
2870             This method should NOT mutate @line directly, but it MAY mutate the
2871             Line's underlying Node structure. (WARNING: If the underlying Node
2872             structure IS altered, then this method should NOT be allowed to
2873             yield an CannotTransform after that point.)
2874         """
2875
2876     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2877         """
2878         StringTransformer instances have a call signature that mirrors that of
2879         the Transformer type.
2880
2881         Raises:
2882             CannotTransform(...) if the concrete StringTransformer class is unable
2883             to transform @line.
2884         """
2885         # Optimization to avoid calling `self.do_match(...)` when the line does
2886         # not contain any string.
2887         if not any(leaf.type == token.STRING for leaf in line.leaves):
2888             raise CannotTransform("There are no strings in this line.")
2889
2890         match_result = self.do_match(line)
2891
2892         if isinstance(match_result, Err):
2893             cant_transform = match_result.err()
2894             raise CannotTransform(
2895                 f"The string transformer {self.__class__.__name__} does not recognize"
2896                 " this line as one that it can transform."
2897             ) from cant_transform
2898
2899         string_idx = match_result.ok()
2900
2901         for line_result in self.do_transform(line, string_idx):
2902             if isinstance(line_result, Err):
2903                 cant_transform = line_result.err()
2904                 raise CannotTransform(
2905                     "StringTransformer failed while attempting to transform string."
2906                 ) from cant_transform
2907             line = line_result.ok()
2908             yield line
2909
2910
2911 @dataclass
2912 class CustomSplit:
2913     """A custom (i.e. manual) string split.
2914
2915     A single CustomSplit instance represents a single substring.
2916
2917     Examples:
2918         Consider the following string:
2919         ```
2920         "Hi there friend."
2921         " This is a custom"
2922         f" string {split}."
2923         ```
2924
2925         This string will correspond to the following three CustomSplit instances:
2926         ```
2927         CustomSplit(False, 16)
2928         CustomSplit(False, 17)
2929         CustomSplit(True, 16)
2930         ```
2931     """
2932
2933     has_prefix: bool
2934     break_idx: int
2935
2936
2937 class CustomSplitMapMixin:
2938     """
2939     This mixin class is used to map merged strings to a sequence of
2940     CustomSplits, which will then be used to re-split the strings iff none of
2941     the resultant substrings go over the configured max line length.
2942     """
2943
2944     _Key = Tuple[StringID, str]
2945     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2946
2947     @staticmethod
2948     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2949         """
2950         Returns:
2951             A unique identifier that is used internally to map @string to a
2952             group of custom splits.
2953         """
2954         return (id(string), string)
2955
2956     def add_custom_splits(
2957         self, string: str, custom_splits: Iterable[CustomSplit]
2958     ) -> None:
2959         """Custom Split Map Setter Method
2960
2961         Side Effects:
2962             Adds a mapping from @string to the custom splits @custom_splits.
2963         """
2964         key = self._get_key(string)
2965         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2966
2967     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2968         """Custom Split Map Getter Method
2969
2970         Returns:
2971             * A list of the custom splits that are mapped to @string, if any
2972             exist.
2973                 OR
2974             * [], otherwise.
2975
2976         Side Effects:
2977             Deletes the mapping between @string and its associated custom
2978             splits (which are returned to the caller).
2979         """
2980         key = self._get_key(string)
2981
2982         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2983         del self._CUSTOM_SPLIT_MAP[key]
2984
2985         return list(custom_splits)
2986
2987     def has_custom_splits(self, string: str) -> bool:
2988         """
2989         Returns:
2990             True iff @string is associated with a set of custom splits.
2991         """
2992         key = self._get_key(string)
2993         return key in self._CUSTOM_SPLIT_MAP
2994
2995
2996 class StringMerger(CustomSplitMapMixin, StringTransformer):
2997     """StringTransformer that merges strings together.
2998
2999     Requirements:
3000         (A) The line contains adjacent strings such that ALL of the validation checks
3001         listed in StringMerger.__validate_msg(...)'s docstring pass.
3002             OR
3003         (B) The line contains a string which uses line continuation backslashes.
3004
3005     Transformations:
3006         Depending on which of the two requirements above where met, either:
3007
3008         (A) The string group associated with the target string is merged.
3009             OR
3010         (B) All line-continuation backslashes are removed from the target string.
3011
3012     Collaborations:
3013         StringMerger provides custom split information to StringSplitter.
3014     """
3015
3016     def do_match(self, line: Line) -> TMatchResult:
3017         LL = line.leaves
3018
3019         is_valid_index = is_valid_index_factory(LL)
3020
3021         for (i, leaf) in enumerate(LL):
3022             if (
3023                 leaf.type == token.STRING
3024                 and is_valid_index(i + 1)
3025                 and LL[i + 1].type == token.STRING
3026             ):
3027                 return Ok(i)
3028
3029             if leaf.type == token.STRING and "\\\n" in leaf.value:
3030                 return Ok(i)
3031
3032         return TErr("This line has no strings that need merging.")
3033
3034     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3035         new_line = line
3036         rblc_result = self.__remove_backslash_line_continuation_chars(
3037             new_line, string_idx
3038         )
3039         if isinstance(rblc_result, Ok):
3040             new_line = rblc_result.ok()
3041
3042         msg_result = self.__merge_string_group(new_line, string_idx)
3043         if isinstance(msg_result, Ok):
3044             new_line = msg_result.ok()
3045
3046         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3047             msg_cant_transform = msg_result.err()
3048             rblc_cant_transform = rblc_result.err()
3049             cant_transform = CannotTransform(
3050                 "StringMerger failed to merge any strings in this line."
3051             )
3052
3053             # Chain the errors together using `__cause__`.
3054             msg_cant_transform.__cause__ = rblc_cant_transform
3055             cant_transform.__cause__ = msg_cant_transform
3056
3057             yield Err(cant_transform)
3058         else:
3059             yield Ok(new_line)
3060
3061     @staticmethod
3062     def __remove_backslash_line_continuation_chars(
3063         line: Line, string_idx: int
3064     ) -> TResult[Line]:
3065         """
3066         Merge strings that were split across multiple lines using
3067         line-continuation backslashes.
3068
3069         Returns:
3070             Ok(new_line), if @line contains backslash line-continuation
3071             characters.
3072                 OR
3073             Err(CannotTransform), otherwise.
3074         """
3075         LL = line.leaves
3076
3077         string_leaf = LL[string_idx]
3078         if not (
3079             string_leaf.type == token.STRING
3080             and "\\\n" in string_leaf.value
3081             and not has_triple_quotes(string_leaf.value)
3082         ):
3083             return TErr(
3084                 f"String leaf {string_leaf} does not contain any backslash line"
3085                 " continuation characters."
3086             )
3087
3088         new_line = line.clone()
3089         new_line.comments = line.comments.copy()
3090         append_leaves(new_line, line, LL)
3091
3092         new_string_leaf = new_line.leaves[string_idx]
3093         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3094
3095         return Ok(new_line)
3096
3097     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3098         """
3099         Merges string group (i.e. set of adjacent strings) where the first
3100         string in the group is `line.leaves[string_idx]`.
3101
3102         Returns:
3103             Ok(new_line), if ALL of the validation checks found in
3104             __validate_msg(...) pass.
3105                 OR
3106             Err(CannotTransform), otherwise.
3107         """
3108         LL = line.leaves
3109
3110         is_valid_index = is_valid_index_factory(LL)
3111
3112         vresult = self.__validate_msg(line, string_idx)
3113         if isinstance(vresult, Err):
3114             return vresult
3115
3116         # If the string group is wrapped inside an Atom node, we must make sure
3117         # to later replace that Atom with our new (merged) string leaf.
3118         atom_node = LL[string_idx].parent
3119
3120         # We will place BREAK_MARK in between every two substrings that we
3121         # merge. We will then later go through our final result and use the
3122         # various instances of BREAK_MARK we find to add the right values to
3123         # the custom split map.
3124         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3125
3126         QUOTE = LL[string_idx].value[-1]
3127
3128         def make_naked(string: str, string_prefix: str) -> str:
3129             """Strip @string (i.e. make it a "naked" string)
3130
3131             Pre-conditions:
3132                 * assert_is_leaf_string(@string)
3133
3134             Returns:
3135                 A string that is identical to @string except that
3136                 @string_prefix has been stripped, the surrounding QUOTE
3137                 characters have been removed, and any remaining QUOTE
3138                 characters have been escaped.
3139             """
3140             assert_is_leaf_string(string)
3141
3142             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3143             naked_string = string[len(string_prefix) + 1 : -1]
3144             naked_string = re.sub(
3145                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3146             )
3147             return naked_string
3148
3149         # Holds the CustomSplit objects that will later be added to the custom
3150         # split map.
3151         custom_splits = []
3152
3153         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3154         prefix_tracker = []
3155
3156         # Sets the 'prefix' variable. This is the prefix that the final merged
3157         # string will have.
3158         next_str_idx = string_idx
3159         prefix = ""
3160         while (
3161             not prefix
3162             and is_valid_index(next_str_idx)
3163             and LL[next_str_idx].type == token.STRING
3164         ):
3165             prefix = get_string_prefix(LL[next_str_idx].value)
3166             next_str_idx += 1
3167
3168         # The next loop merges the string group. The final string will be
3169         # contained in 'S'.
3170         #
3171         # The following convenience variables are used:
3172         #
3173         #   S: string
3174         #   NS: naked string
3175         #   SS: next string
3176         #   NSS: naked next string
3177         S = ""
3178         NS = ""
3179         num_of_strings = 0
3180         next_str_idx = string_idx
3181         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3182             num_of_strings += 1
3183
3184             SS = LL[next_str_idx].value
3185             next_prefix = get_string_prefix(SS)
3186
3187             # If this is an f-string group but this substring is not prefixed
3188             # with 'f'...
3189             if "f" in prefix and "f" not in next_prefix:
3190                 # Then we must escape any braces contained in this substring.
3191                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3192
3193             NSS = make_naked(SS, next_prefix)
3194
3195             has_prefix = bool(next_prefix)
3196             prefix_tracker.append(has_prefix)
3197
3198             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3199             NS = make_naked(S, prefix)
3200
3201             next_str_idx += 1
3202
3203         S_leaf = Leaf(token.STRING, S)
3204         if self.normalize_strings:
3205             normalize_string_quotes(S_leaf)
3206
3207         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3208         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3209         for has_prefix in prefix_tracker:
3210             mark_idx = temp_string.find(BREAK_MARK)
3211             assert (
3212                 mark_idx >= 0
3213             ), "Logic error while filling the custom string breakpoint cache."
3214
3215             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3216             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3217             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3218
3219         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3220
3221         if atom_node is not None:
3222             replace_child(atom_node, string_leaf)
3223
3224         # Build the final line ('new_line') that this method will later return.
3225         new_line = line.clone()
3226         for (i, leaf) in enumerate(LL):
3227             if i == string_idx:
3228                 new_line.append(string_leaf)
3229
3230             if string_idx <= i < string_idx + num_of_strings:
3231                 for comment_leaf in line.comments_after(LL[i]):
3232                     new_line.append(comment_leaf, preformatted=True)
3233                 continue
3234
3235             append_leaves(new_line, line, [leaf])
3236
3237         self.add_custom_splits(string_leaf.value, custom_splits)
3238         return Ok(new_line)
3239
3240     @staticmethod
3241     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3242         """Validate (M)erge (S)tring (G)roup
3243
3244         Transform-time string validation logic for __merge_string_group(...).
3245
3246         Returns:
3247             * Ok(None), if ALL validation checks (listed below) pass.
3248                 OR
3249             * Err(CannotTransform), if any of the following are true:
3250                 - The target string group does not contain ANY stand-alone comments.
3251                 - The target string is not in a string group (i.e. it has no
3252                   adjacent strings).
3253                 - The string group has more than one inline comment.
3254                 - The string group has an inline comment that appears to be a pragma.
3255                 - The set of all string prefixes in the string group is of
3256                   length greater than one and is not equal to {"", "f"}.
3257                 - The string group consists of raw strings.
3258         """
3259         # We first check for "inner" stand-alone comments (i.e. stand-alone
3260         # comments that have a string leaf before them AND after them).
3261         for inc in [1, -1]:
3262             i = string_idx
3263             found_sa_comment = False
3264             is_valid_index = is_valid_index_factory(line.leaves)
3265             while is_valid_index(i) and line.leaves[i].type in [
3266                 token.STRING,
3267                 STANDALONE_COMMENT,
3268             ]:
3269                 if line.leaves[i].type == STANDALONE_COMMENT:
3270                     found_sa_comment = True
3271                 elif found_sa_comment:
3272                     return TErr(
3273                         "StringMerger does NOT merge string groups which contain "
3274                         "stand-alone comments."
3275                     )
3276
3277                 i += inc
3278
3279         num_of_inline_string_comments = 0
3280         set_of_prefixes = set()
3281         num_of_strings = 0
3282         for leaf in line.leaves[string_idx:]:
3283             if leaf.type != token.STRING:
3284                 # If the string group is trailed by a comma, we count the
3285                 # comments trailing the comma to be one of the string group's
3286                 # comments.
3287                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3288                     num_of_inline_string_comments += 1
3289                 break
3290
3291             if has_triple_quotes(leaf.value):
3292                 return TErr("StringMerger does NOT merge multiline strings.")
3293
3294             num_of_strings += 1
3295             prefix = get_string_prefix(leaf.value)
3296             if "r" in prefix:
3297                 return TErr("StringMerger does NOT merge raw strings.")
3298
3299             set_of_prefixes.add(prefix)
3300
3301             if id(leaf) in line.comments:
3302                 num_of_inline_string_comments += 1
3303                 if contains_pragma_comment(line.comments[id(leaf)]):
3304                     return TErr("Cannot merge strings which have pragma comments.")
3305
3306         if num_of_strings < 2:
3307             return TErr(
3308                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3309             )
3310
3311         if num_of_inline_string_comments > 1:
3312             return TErr(
3313                 f"Too many inline string comments ({num_of_inline_string_comments})."
3314             )
3315
3316         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3317             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3318
3319         return Ok(None)
3320
3321
3322 class StringParenStripper(StringTransformer):
3323     """StringTransformer that strips surrounding parentheses from strings.
3324
3325     Requirements:
3326         The line contains a string which is surrounded by parentheses and:
3327             - The target string is NOT the only argument to a function call.
3328             - The target string is NOT a "pointless" string.
3329             - If the target string contains a PERCENT, the brackets are not
3330               preceeded or followed by an operator with higher precedence than
3331               PERCENT.
3332
3333     Transformations:
3334         The parentheses mentioned in the 'Requirements' section are stripped.
3335
3336     Collaborations:
3337         StringParenStripper has its own inherent usefulness, but it is also
3338         relied on to clean up the parentheses created by StringParenWrapper (in
3339         the event that they are no longer needed).
3340     """
3341
3342     def do_match(self, line: Line) -> TMatchResult:
3343         LL = line.leaves
3344
3345         is_valid_index = is_valid_index_factory(LL)
3346
3347         for (idx, leaf) in enumerate(LL):
3348             # Should be a string...
3349             if leaf.type != token.STRING:
3350                 continue
3351
3352             # If this is a "pointless" string...
3353             if (
3354                 leaf.parent
3355                 and leaf.parent.parent
3356                 and leaf.parent.parent.type == syms.simple_stmt
3357             ):
3358                 continue
3359
3360             # Should be preceded by a non-empty LPAR...
3361             if (
3362                 not is_valid_index(idx - 1)
3363                 or LL[idx - 1].type != token.LPAR
3364                 or is_empty_lpar(LL[idx - 1])
3365             ):
3366                 continue
3367
3368             # That LPAR should NOT be preceded by a function name or a closing
3369             # bracket (which could be a function which returns a function or a
3370             # list/dictionary that contains a function)...
3371             if is_valid_index(idx - 2) and (
3372                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3373             ):
3374                 continue
3375
3376             string_idx = idx
3377
3378             # Skip the string trailer, if one exists.
3379             string_parser = StringParser()
3380             next_idx = string_parser.parse(LL, string_idx)
3381
3382             # if the leaves in the parsed string include a PERCENT, we need to
3383             # make sure the initial LPAR is NOT preceded by an operator with
3384             # higher or equal precedence to PERCENT
3385             if is_valid_index(idx - 2):
3386                 # mypy can't quite follow unless we name this
3387                 before_lpar = LL[idx - 2]
3388                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3389                     (
3390                         before_lpar.type
3391                         in {
3392                             token.STAR,
3393                             token.AT,
3394                             token.SLASH,
3395                             token.DOUBLESLASH,
3396                             token.PERCENT,
3397                             token.TILDE,
3398                             token.DOUBLESTAR,
3399                             token.AWAIT,
3400                             token.LSQB,
3401                             token.LPAR,
3402                         }
3403                     )
3404                     or (
3405                         # only unary PLUS/MINUS
3406                         before_lpar.parent
3407                         and before_lpar.parent.type == syms.factor
3408                         and (before_lpar.type in {token.PLUS, token.MINUS})
3409                     )
3410                 ):
3411                     continue
3412
3413             # Should be followed by a non-empty RPAR...
3414             if (
3415                 is_valid_index(next_idx)
3416                 and LL[next_idx].type == token.RPAR
3417                 and not is_empty_rpar(LL[next_idx])
3418             ):
3419                 # That RPAR should NOT be followed by anything with higher
3420                 # precedence than PERCENT
3421                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3422                     token.DOUBLESTAR,
3423                     token.LSQB,
3424                     token.LPAR,
3425                     token.DOT,
3426                 }:
3427                     continue
3428
3429                 return Ok(string_idx)
3430
3431         return TErr("This line has no strings wrapped in parens.")
3432
3433     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3434         LL = line.leaves
3435
3436         string_parser = StringParser()
3437         rpar_idx = string_parser.parse(LL, string_idx)
3438
3439         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3440             if line.comments_after(leaf):
3441                 yield TErr(
3442                     "Will not strip parentheses which have comments attached to them."
3443                 )
3444                 return
3445
3446         new_line = line.clone()
3447         new_line.comments = line.comments.copy()
3448         try:
3449             append_leaves(new_line, line, LL[: string_idx - 1])
3450         except BracketMatchError:
3451             # HACK: I believe there is currently a bug somewhere in
3452             # right_hand_split() that is causing brackets to not be tracked
3453             # properly by a shared BracketTracker.
3454             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3455
3456         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3457         LL[string_idx - 1].remove()
3458         replace_child(LL[string_idx], string_leaf)
3459         new_line.append(string_leaf)
3460
3461         append_leaves(
3462             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3463         )
3464
3465         LL[rpar_idx].remove()
3466
3467         yield Ok(new_line)
3468
3469
3470 class BaseStringSplitter(StringTransformer):
3471     """
3472     Abstract class for StringTransformers which transform a Line's strings by splitting
3473     them or placing them on their own lines where necessary to avoid going over
3474     the configured line length.
3475
3476     Requirements:
3477         * The target string value is responsible for the line going over the
3478         line length limit. It follows that after all of black's other line
3479         split methods have been exhausted, this line (or one of the resulting
3480         lines after all line splits are performed) would still be over the
3481         line_length limit unless we split this string.
3482             AND
3483         * The target string is NOT a "pointless" string (i.e. a string that has
3484         no parent or siblings).
3485             AND
3486         * The target string is not followed by an inline comment that appears
3487         to be a pragma.
3488             AND
3489         * The target string is not a multiline (i.e. triple-quote) string.
3490     """
3491
3492     @abstractmethod
3493     def do_splitter_match(self, line: Line) -> TMatchResult:
3494         """
3495         BaseStringSplitter asks its clients to override this method instead of
3496         `StringTransformer.do_match(...)`.
3497
3498         Follows the same protocol as `StringTransformer.do_match(...)`.
3499
3500         Refer to `help(StringTransformer.do_match)` for more information.
3501         """
3502
3503     def do_match(self, line: Line) -> TMatchResult:
3504         match_result = self.do_splitter_match(line)
3505         if isinstance(match_result, Err):
3506             return match_result
3507
3508         string_idx = match_result.ok()
3509         vresult = self.__validate(line, string_idx)
3510         if isinstance(vresult, Err):
3511             return vresult
3512
3513         return match_result
3514
3515     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3516         """
3517         Checks that @line meets all of the requirements listed in this classes'
3518         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3519         description of those requirements.
3520
3521         Returns:
3522             * Ok(None), if ALL of the requirements are met.
3523                 OR
3524             * Err(CannotTransform), if ANY of the requirements are NOT met.
3525         """
3526         LL = line.leaves
3527
3528         string_leaf = LL[string_idx]
3529
3530         max_string_length = self.__get_max_string_length(line, string_idx)
3531         if len(string_leaf.value) <= max_string_length:
3532             return TErr(
3533                 "The string itself is not what is causing this line to be too long."
3534             )
3535
3536         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3537             token.STRING,
3538             token.NEWLINE,
3539         ]:
3540             return TErr(
3541                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3542                 " no parent)."
3543             )
3544
3545         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3546             line.comments[id(line.leaves[string_idx])]
3547         ):
3548             return TErr(
3549                 "Line appears to end with an inline pragma comment. Splitting the line"
3550                 " could modify the pragma's behavior."
3551             )
3552
3553         if has_triple_quotes(string_leaf.value):
3554             return TErr("We cannot split multiline strings.")
3555
3556         return Ok(None)
3557
3558     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3559         """
3560         Calculates the max string length used when attempting to determine
3561         whether or not the target string is responsible for causing the line to
3562         go over the line length limit.
3563
3564         WARNING: This method is tightly coupled to both StringSplitter and
3565         (especially) StringParenWrapper. There is probably a better way to
3566         accomplish what is being done here.
3567
3568         Returns:
3569             max_string_length: such that `line.leaves[string_idx].value >
3570             max_string_length` implies that the target string IS responsible
3571             for causing this line to exceed the line length limit.
3572         """
3573         LL = line.leaves
3574
3575         is_valid_index = is_valid_index_factory(LL)
3576
3577         # We use the shorthand "WMA4" in comments to abbreviate "We must
3578         # account for". When giving examples, we use STRING to mean some/any
3579         # valid string.
3580         #
3581         # Finally, we use the following convenience variables:
3582         #
3583         #   P:  The leaf that is before the target string leaf.
3584         #   N:  The leaf that is after the target string leaf.
3585         #   NN: The leaf that is after N.
3586
3587         # WMA4 the whitespace at the beginning of the line.
3588         offset = line.depth * 4
3589
3590         if is_valid_index(string_idx - 1):
3591             p_idx = string_idx - 1
3592             if (
3593                 LL[string_idx - 1].type == token.LPAR
3594                 and LL[string_idx - 1].value == ""
3595                 and string_idx >= 2
3596             ):
3597                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3598                 p_idx -= 1
3599
3600             P = LL[p_idx]
3601             if P.type == token.PLUS:
3602                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3603                 offset += 2
3604
3605             if P.type == token.COMMA:
3606                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3607                 offset += 3
3608
3609             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3610                 # This conditional branch is meant to handle dictionary keys,
3611                 # variable assignments, 'return STRING' statement lines, and
3612                 # 'else STRING' ternary expression lines.
3613
3614                 # WMA4 a single space.
3615                 offset += 1
3616
3617                 # WMA4 the lengths of any leaves that came before that space,
3618                 # but after any closing bracket before that space.
3619                 for leaf in reversed(LL[: p_idx + 1]):
3620                     offset += len(str(leaf))
3621                     if leaf.type in CLOSING_BRACKETS:
3622                         break
3623
3624         if is_valid_index(string_idx + 1):
3625             N = LL[string_idx + 1]
3626             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3627                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3628                 N = LL[string_idx + 2]
3629
3630             if N.type == token.COMMA:
3631                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3632                 offset += 1
3633
3634             if is_valid_index(string_idx + 2):
3635                 NN = LL[string_idx + 2]
3636
3637                 if N.type == token.DOT and NN.type == token.NAME:
3638                     # This conditional branch is meant to handle method calls invoked
3639                     # off of a string literal up to and including the LPAR character.
3640
3641                     # WMA4 the '.' character.
3642                     offset += 1
3643
3644                     if (
3645                         is_valid_index(string_idx + 3)
3646                         and LL[string_idx + 3].type == token.LPAR
3647                     ):
3648                         # WMA4 the left parenthesis character.
3649                         offset += 1
3650
3651                     # WMA4 the length of the method's name.
3652                     offset += len(NN.value)
3653
3654         has_comments = False
3655         for comment_leaf in line.comments_after(LL[string_idx]):
3656             if not has_comments:
3657                 has_comments = True
3658                 # WMA4 two spaces before the '#' character.
3659                 offset += 2
3660
3661             # WMA4 the length of the inline comment.
3662             offset += len(comment_leaf.value)
3663
3664         max_string_length = self.line_length - offset
3665         return max_string_length
3666
3667
3668 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3669     """
3670     StringTransformer that splits "atom" strings (i.e. strings which exist on
3671     lines by themselves).
3672
3673     Requirements:
3674         * The line consists ONLY of a single string (with the exception of a
3675         '+' symbol which MAY exist at the start of the line), MAYBE a string
3676         trailer, and MAYBE a trailing comma.
3677             AND
3678         * All of the requirements listed in BaseStringSplitter's docstring.
3679
3680     Transformations:
3681         The string mentioned in the 'Requirements' section is split into as
3682         many substrings as necessary to adhere to the configured line length.
3683
3684         In the final set of substrings, no substring should be smaller than
3685         MIN_SUBSTR_SIZE characters.
3686
3687         The string will ONLY be split on spaces (i.e. each new substring should
3688         start with a space). Note that the string will NOT be split on a space
3689         which is escaped with a backslash.
3690
3691         If the string is an f-string, it will NOT be split in the middle of an
3692         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3693         else bar()} is an f-expression).
3694
3695         If the string that is being split has an associated set of custom split
3696         records and those custom splits will NOT result in any line going over
3697         the configured line length, those custom splits are used. Otherwise the
3698         string is split as late as possible (from left-to-right) while still
3699         adhering to the transformation rules listed above.
3700
3701     Collaborations:
3702         StringSplitter relies on StringMerger to construct the appropriate
3703         CustomSplit objects and add them to the custom split map.
3704     """
3705
3706     MIN_SUBSTR_SIZE = 6
3707     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3708     RE_FEXPR = r"""
3709     (?<!\{) (?:\{\{)* \{ (?!\{)
3710         (?:
3711             [^\{\}]
3712             | \{\{
3713             | \}\}
3714             | (?R)
3715         )+?
3716     (?<!\}) \} (?:\}\})* (?!\})
3717     """
3718
3719     def do_splitter_match(self, line: Line) -> TMatchResult:
3720         LL = line.leaves
3721
3722         is_valid_index = is_valid_index_factory(LL)
3723
3724         idx = 0
3725
3726         # The first leaf MAY be a '+' symbol...
3727         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3728             idx += 1
3729
3730         # The next/first leaf MAY be an empty LPAR...
3731         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3732             idx += 1
3733
3734         # The next/first leaf MUST be a string...
3735         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3736             return TErr("Line does not start with a string.")
3737
3738         string_idx = idx
3739
3740         # Skip the string trailer, if one exists.
3741         string_parser = StringParser()
3742         idx = string_parser.parse(LL, string_idx)
3743
3744         # That string MAY be followed by an empty RPAR...
3745         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3746             idx += 1
3747
3748         # That string / empty RPAR leaf MAY be followed by a comma...
3749         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3750             idx += 1
3751
3752         # But no more leaves are allowed...
3753         if is_valid_index(idx):
3754             return TErr("This line does not end with a string.")
3755
3756         return Ok(string_idx)
3757
3758     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3759         LL = line.leaves
3760
3761         QUOTE = LL[string_idx].value[-1]
3762
3763         is_valid_index = is_valid_index_factory(LL)
3764         insert_str_child = insert_str_child_factory(LL[string_idx])
3765
3766         prefix = get_string_prefix(LL[string_idx].value)
3767
3768         # We MAY choose to drop the 'f' prefix from substrings that don't
3769         # contain any f-expressions, but ONLY if the original f-string
3770         # contains at least one f-expression. Otherwise, we will alter the AST
3771         # of the program.
3772         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3773             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3774         )
3775
3776         first_string_line = True
3777         starts_with_plus = LL[0].type == token.PLUS
3778
3779         def line_needs_plus() -> bool:
3780             return first_string_line and starts_with_plus
3781
3782         def maybe_append_plus(new_line: Line) -> None:
3783             """
3784             Side Effects:
3785                 If @line starts with a plus and this is the first line we are
3786                 constructing, this function appends a PLUS leaf to @new_line
3787                 and replaces the old PLUS leaf in the node structure. Otherwise
3788                 this function does nothing.
3789             """
3790             if line_needs_plus():
3791                 plus_leaf = Leaf(token.PLUS, "+")
3792                 replace_child(LL[0], plus_leaf)
3793                 new_line.append(plus_leaf)
3794
3795         ends_with_comma = (
3796             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3797         )
3798
3799         def max_last_string() -> int:
3800             """
3801             Returns:
3802                 The max allowed length of the string value used for the last
3803                 line we will construct.
3804             """
3805             result = self.line_length
3806             result -= line.depth * 4
3807             result -= 1 if ends_with_comma else 0
3808             result -= 2 if line_needs_plus() else 0
3809             return result
3810
3811         # --- Calculate Max Break Index (for string value)
3812         # We start with the line length limit
3813         max_break_idx = self.line_length
3814         # The last index of a string of length N is N-1.
3815         max_break_idx -= 1
3816         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3817         max_break_idx -= line.depth * 4
3818         if max_break_idx < 0:
3819             yield TErr(
3820                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3821                 f" {line.depth}"
3822             )
3823             return
3824
3825         # Check if StringMerger registered any custom splits.
3826         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3827         # We use them ONLY if none of them would produce lines that exceed the
3828         # line limit.
3829         use_custom_breakpoints = bool(
3830             custom_splits
3831             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3832         )
3833
3834         # Temporary storage for the remaining chunk of the string line that
3835         # can't fit onto the line currently being constructed.
3836         rest_value = LL[string_idx].value
3837
3838         def more_splits_should_be_made() -> bool:
3839             """
3840             Returns:
3841                 True iff `rest_value` (the remaining string value from the last
3842                 split), should be split again.
3843             """
3844             if use_custom_breakpoints:
3845                 return len(custom_splits) > 1
3846             else:
3847                 return len(rest_value) > max_last_string()
3848
3849         string_line_results: List[Ok[Line]] = []
3850         while more_splits_should_be_made():
3851             if use_custom_breakpoints:
3852                 # Custom User Split (manual)
3853                 csplit = custom_splits.pop(0)
3854                 break_idx = csplit.break_idx
3855             else:
3856                 # Algorithmic Split (automatic)
3857                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3858                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3859                 if maybe_break_idx is None:
3860                     # If we are unable to algorithmically determine a good split
3861                     # and this string has custom splits registered to it, we
3862                     # fall back to using them--which means we have to start
3863                     # over from the beginning.
3864                     if custom_splits:
3865                         rest_value = LL[string_idx].value
3866                         string_line_results = []
3867                         first_string_line = True
3868                         use_custom_breakpoints = True
3869                         continue
3870
3871                     # Otherwise, we stop splitting here.
3872                     break
3873
3874                 break_idx = maybe_break_idx
3875
3876             # --- Construct `next_value`
3877             next_value = rest_value[:break_idx] + QUOTE
3878             if (
3879                 # Are we allowed to try to drop a pointless 'f' prefix?
3880                 drop_pointless_f_prefix
3881                 # If we are, will we be successful?
3882                 and next_value != self.__normalize_f_string(next_value, prefix)
3883             ):
3884                 # If the current custom split did NOT originally use a prefix,
3885                 # then `csplit.break_idx` will be off by one after removing
3886                 # the 'f' prefix.
3887                 break_idx = (
3888                     break_idx + 1
3889                     if use_custom_breakpoints and not csplit.has_prefix
3890                     else break_idx
3891                 )
3892                 next_value = rest_value[:break_idx] + QUOTE
3893                 next_value = self.__normalize_f_string(next_value, prefix)
3894
3895             # --- Construct `next_leaf`
3896             next_leaf = Leaf(token.STRING, next_value)
3897             insert_str_child(next_leaf)
3898             self.__maybe_normalize_string_quotes(next_leaf)
3899
3900             # --- Construct `next_line`
3901             next_line = line.clone()
3902             maybe_append_plus(next_line)
3903             next_line.append(next_leaf)
3904             string_line_results.append(Ok(next_line))
3905
3906             rest_value = prefix + QUOTE + rest_value[break_idx:]
3907             first_string_line = False
3908
3909         yield from string_line_results
3910
3911         if drop_pointless_f_prefix:
3912             rest_value = self.__normalize_f_string(rest_value, prefix)
3913
3914         rest_leaf = Leaf(token.STRING, rest_value)
3915         insert_str_child(rest_leaf)
3916
3917         # NOTE: I could not find a test case that verifies that the following
3918         # line is actually necessary, but it seems to be. Otherwise we risk
3919         # not normalizing the last substring, right?
3920         self.__maybe_normalize_string_quotes(rest_leaf)
3921
3922         last_line = line.clone()
3923         maybe_append_plus(last_line)
3924
3925         # If there are any leaves to the right of the target string...
3926         if is_valid_index(string_idx + 1):
3927             # We use `temp_value` here to determine how long the last line
3928             # would be if we were to append all the leaves to the right of the
3929             # target string to the last string line.
3930             temp_value = rest_value
3931             for leaf in LL[string_idx + 1 :]:
3932                 temp_value += str(leaf)
3933                 if leaf.type == token.LPAR:
3934                     break
3935
3936             # Try to fit them all on the same line with the last substring...
3937             if (
3938                 len(temp_value) <= max_last_string()
3939                 or LL[string_idx + 1].type == token.COMMA
3940             ):
3941                 last_line.append(rest_leaf)
3942                 append_leaves(last_line, line, LL[string_idx + 1 :])
3943                 yield Ok(last_line)
3944             # Otherwise, place the last substring on one line and everything
3945             # else on a line below that...
3946             else:
3947                 last_line.append(rest_leaf)
3948                 yield Ok(last_line)
3949
3950                 non_string_line = line.clone()
3951                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3952                 yield Ok(non_string_line)
3953         # Else the target string was the last leaf...
3954         else:
3955             last_line.append(rest_leaf)
3956             last_line.comments = line.comments.copy()
3957             yield Ok(last_line)
3958
3959     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3960         """
3961         This method contains the algorithm that StringSplitter uses to
3962         determine which character to split each string at.
3963
3964         Args:
3965             @string: The substring that we are attempting to split.
3966             @max_break_idx: The ideal break index. We will return this value if it
3967             meets all the necessary conditions. In the likely event that it
3968             doesn't we will try to find the closest index BELOW @max_break_idx
3969             that does. If that fails, we will expand our search by also
3970             considering all valid indices ABOVE @max_break_idx.
3971
3972         Pre-Conditions:
3973             * assert_is_leaf_string(@string)
3974             * 0 <= @max_break_idx < len(@string)
3975
3976         Returns:
3977             break_idx, if an index is able to be found that meets all of the
3978             conditions listed in the 'Transformations' section of this classes'
3979             docstring.
3980                 OR
3981             None, otherwise.
3982         """
3983         is_valid_index = is_valid_index_factory(string)
3984
3985         assert is_valid_index(max_break_idx)
3986         assert_is_leaf_string(string)
3987
3988         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3989
3990         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3991             """
3992             Yields:
3993                 All ranges of @string which, if @string were to be split there,
3994                 would result in the splitting of an f-expression (which is NOT
3995                 allowed).
3996             """
3997             nonlocal _fexpr_slices
3998
3999             if _fexpr_slices is None:
4000                 _fexpr_slices = []
4001                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
4002                     _fexpr_slices.append(match.span())
4003
4004             yield from _fexpr_slices
4005
4006         is_fstring = "f" in get_string_prefix(string)
4007
4008         def breaks_fstring_expression(i: Index) -> bool:
4009             """
4010             Returns:
4011                 True iff returning @i would result in the splitting of an
4012                 f-expression (which is NOT allowed).
4013             """
4014             if not is_fstring:
4015                 return False
4016
4017             for (start, end) in fexpr_slices():
4018                 if start <= i < end:
4019                     return True
4020
4021             return False
4022
4023         def passes_all_checks(i: Index) -> bool:
4024             """
4025             Returns:
4026                 True iff ALL of the conditions listed in the 'Transformations'
4027                 section of this classes' docstring would be be met by returning @i.
4028             """
4029             is_space = string[i] == " "
4030
4031             is_not_escaped = True
4032             j = i - 1
4033             while is_valid_index(j) and string[j] == "\\":
4034                 is_not_escaped = not is_not_escaped
4035                 j -= 1
4036
4037             is_big_enough = (
4038                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
4039                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
4040             )
4041             return (
4042                 is_space
4043                 and is_not_escaped
4044                 and is_big_enough
4045                 and not breaks_fstring_expression(i)
4046             )
4047
4048         # First, we check all indices BELOW @max_break_idx.
4049         break_idx = max_break_idx
4050         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
4051             break_idx -= 1
4052
4053         if not passes_all_checks(break_idx):
4054             # If that fails, we check all indices ABOVE @max_break_idx.
4055             #
4056             # If we are able to find a valid index here, the next line is going
4057             # to be longer than the specified line length, but it's probably
4058             # better than doing nothing at all.
4059             break_idx = max_break_idx + 1
4060             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
4061                 break_idx += 1
4062
4063             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
4064                 return None
4065
4066         return break_idx
4067
4068     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
4069         if self.normalize_strings:
4070             normalize_string_quotes(leaf)
4071
4072     def __normalize_f_string(self, string: str, prefix: str) -> str:
4073         """
4074         Pre-Conditions:
4075             * assert_is_leaf_string(@string)
4076
4077         Returns:
4078             * If @string is an f-string that contains no f-expressions, we
4079             return a string identical to @string except that the 'f' prefix
4080             has been stripped and all double braces (i.e. '{{' or '}}') have
4081             been normalized (i.e. turned into '{' or '}').
4082                 OR
4083             * Otherwise, we return @string.
4084         """
4085         assert_is_leaf_string(string)
4086
4087         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
4088             new_prefix = prefix.replace("f", "")
4089
4090             temp = string[len(prefix) :]
4091             temp = re.sub(r"\{\{", "{", temp)
4092             temp = re.sub(r"\}\}", "}", temp)
4093             new_string = temp
4094
4095             return f"{new_prefix}{new_string}"
4096         else:
4097             return string
4098
4099
4100 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4101     """
4102     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4103     exist on lines by themselves).
4104
4105     Requirements:
4106         All of the requirements listed in BaseStringSplitter's docstring in
4107         addition to the requirements listed below:
4108
4109         * The line is a return/yield statement, which returns/yields a string.
4110             OR
4111         * The line is part of a ternary expression (e.g. `x = y if cond else
4112         z`) such that the line starts with `else <string>`, where <string> is
4113         some string.
4114             OR
4115         * The line is an assert statement, which ends with a string.
4116             OR
4117         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4118         <string>`) such that the variable is being assigned the value of some
4119         string.
4120             OR
4121         * The line is a dictionary key assignment where some valid key is being
4122         assigned the value of some string.
4123
4124     Transformations:
4125         The chosen string is wrapped in parentheses and then split at the LPAR.
4126
4127         We then have one line which ends with an LPAR and another line that
4128         starts with the chosen string. The latter line is then split again at
4129         the RPAR. This results in the RPAR (and possibly a trailing comma)
4130         being placed on its own line.
4131
4132         NOTE: If any leaves exist to the right of the chosen string (except
4133         for a trailing comma, which would be placed after the RPAR), those
4134         leaves are placed inside the parentheses.  In effect, the chosen
4135         string is not necessarily being "wrapped" by parentheses. We can,
4136         however, count on the LPAR being placed directly before the chosen
4137         string.
4138
4139         In other words, StringParenWrapper creates "atom" strings. These
4140         can then be split again by StringSplitter, if necessary.
4141
4142     Collaborations:
4143         In the event that a string line split by StringParenWrapper is
4144         changed such that it no longer needs to be given its own line,
4145         StringParenWrapper relies on StringParenStripper to clean up the
4146         parentheses it created.
4147     """
4148
4149     def do_splitter_match(self, line: Line) -> TMatchResult:
4150         LL = line.leaves
4151
4152         string_idx = (
4153             self._return_match(LL)
4154             or self._else_match(LL)
4155             or self._assert_match(LL)
4156             or self._assign_match(LL)
4157             or self._dict_match(LL)
4158         )
4159
4160         if string_idx is not None:
4161             string_value = line.leaves[string_idx].value
4162             # If the string has no spaces...
4163             if " " not in string_value:
4164                 # And will still violate the line length limit when split...
4165                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4166                 if len(string_value) > max_string_length:
4167                     # And has no associated custom splits...
4168                     if not self.has_custom_splits(string_value):
4169                         # Then we should NOT put this string on its own line.
4170                         return TErr(
4171                             "We do not wrap long strings in parentheses when the"
4172                             " resultant line would still be over the specified line"
4173                             " length and can't be split further by StringSplitter."
4174                         )
4175             return Ok(string_idx)
4176
4177         return TErr("This line does not contain any non-atomic strings.")
4178
4179     @staticmethod
4180     def _return_match(LL: List[Leaf]) -> Optional[int]:
4181         """
4182         Returns:
4183             string_idx such that @LL[string_idx] is equal to our target (i.e.
4184             matched) string, if this line matches the return/yield statement
4185             requirements listed in the 'Requirements' section of this classes'
4186             docstring.
4187                 OR
4188             None, otherwise.
4189         """
4190         # If this line is apart of a return/yield statement and the first leaf
4191         # contains either the "return" or "yield" keywords...
4192         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4193             0
4194         ].value in ["return", "yield"]:
4195             is_valid_index = is_valid_index_factory(LL)
4196
4197             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4198             # The next visible leaf MUST contain a string...
4199             if is_valid_index(idx) and LL[idx].type == token.STRING:
4200                 return idx
4201
4202         return None
4203
4204     @staticmethod
4205     def _else_match(LL: List[Leaf]) -> Optional[int]:
4206         """
4207         Returns:
4208             string_idx such that @LL[string_idx] is equal to our target (i.e.
4209             matched) string, if this line matches the ternary expression
4210             requirements listed in the 'Requirements' section of this classes'
4211             docstring.
4212                 OR
4213             None, otherwise.
4214         """
4215         # If this line is apart of a ternary expression and the first leaf
4216         # contains the "else" keyword...
4217         if (
4218             parent_type(LL[0]) == syms.test
4219             and LL[0].type == token.NAME
4220             and LL[0].value == "else"
4221         ):
4222             is_valid_index = is_valid_index_factory(LL)
4223
4224             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4225             # The next visible leaf MUST contain a string...
4226             if is_valid_index(idx) and LL[idx].type == token.STRING:
4227                 return idx
4228
4229         return None
4230
4231     @staticmethod
4232     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4233         """
4234         Returns:
4235             string_idx such that @LL[string_idx] is equal to our target (i.e.
4236             matched) string, if this line matches the assert statement
4237             requirements listed in the 'Requirements' section of this classes'
4238             docstring.
4239                 OR
4240             None, otherwise.
4241         """
4242         # If this line is apart of an assert statement and the first leaf
4243         # contains the "assert" keyword...
4244         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4245             is_valid_index = is_valid_index_factory(LL)
4246
4247             for (i, leaf) in enumerate(LL):
4248                 # We MUST find a comma...
4249                 if leaf.type == token.COMMA:
4250                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4251
4252                     # That comma MUST be followed by a string...
4253                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4254                         string_idx = idx
4255
4256                         # Skip the string trailer, if one exists.
4257                         string_parser = StringParser()
4258                         idx = string_parser.parse(LL, string_idx)
4259
4260                         # But no more leaves are allowed...
4261                         if not is_valid_index(idx):
4262                             return string_idx
4263
4264         return None
4265
4266     @staticmethod
4267     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4268         """
4269         Returns:
4270             string_idx such that @LL[string_idx] is equal to our target (i.e.
4271             matched) string, if this line matches the assignment statement
4272             requirements listed in the 'Requirements' section of this classes'
4273             docstring.
4274                 OR
4275             None, otherwise.
4276         """
4277         # If this line is apart of an expression statement or is a function
4278         # argument AND the first leaf contains a variable name...
4279         if (
4280             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4281             and LL[0].type == token.NAME
4282         ):
4283             is_valid_index = is_valid_index_factory(LL)
4284
4285             for (i, leaf) in enumerate(LL):
4286                 # We MUST find either an '=' or '+=' symbol...
4287                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4288                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4289
4290                     # That symbol MUST be followed by a string...
4291                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4292                         string_idx = idx
4293
4294                         # Skip the string trailer, if one exists.
4295                         string_parser = StringParser()
4296                         idx = string_parser.parse(LL, string_idx)
4297
4298                         # The next leaf MAY be a comma iff this line is apart
4299                         # of a function argument...
4300                         if (
4301                             parent_type(LL[0]) == syms.argument
4302                             and is_valid_index(idx)
4303                             and LL[idx].type == token.COMMA
4304                         ):
4305                             idx += 1
4306
4307                         # But no more leaves are allowed...
4308                         if not is_valid_index(idx):
4309                             return string_idx
4310
4311         return None
4312
4313     @staticmethod
4314     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4315         """
4316         Returns:
4317             string_idx such that @LL[string_idx] is equal to our target (i.e.
4318             matched) string, if this line matches the dictionary key assignment
4319             statement requirements listed in the 'Requirements' section of this
4320             classes' docstring.
4321                 OR
4322             None, otherwise.
4323         """
4324         # If this line is apart of a dictionary key assignment...
4325         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4326             is_valid_index = is_valid_index_factory(LL)
4327
4328             for (i, leaf) in enumerate(LL):
4329                 # We MUST find a colon...
4330                 if leaf.type == token.COLON:
4331                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4332
4333                     # That colon MUST be followed by a string...
4334                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4335                         string_idx = idx
4336
4337                         # Skip the string trailer, if one exists.
4338                         string_parser = StringParser()
4339                         idx = string_parser.parse(LL, string_idx)
4340
4341                         # That string MAY be followed by a comma...
4342                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4343                             idx += 1
4344
4345                         # But no more leaves are allowed...
4346                         if not is_valid_index(idx):
4347                             return string_idx
4348
4349         return None
4350
4351     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4352         LL = line.leaves
4353
4354         is_valid_index = is_valid_index_factory(LL)
4355         insert_str_child = insert_str_child_factory(LL[string_idx])
4356
4357         comma_idx = -1
4358         ends_with_comma = False
4359         if LL[comma_idx].type == token.COMMA:
4360             ends_with_comma = True
4361
4362         leaves_to_steal_comments_from = [LL[string_idx]]
4363         if ends_with_comma:
4364             leaves_to_steal_comments_from.append(LL[comma_idx])
4365
4366         # --- First Line
4367         first_line = line.clone()
4368         left_leaves = LL[:string_idx]
4369
4370         # We have to remember to account for (possibly invisible) LPAR and RPAR
4371         # leaves that already wrapped the target string. If these leaves do
4372         # exist, we will replace them with our own LPAR and RPAR leaves.
4373         old_parens_exist = False
4374         if left_leaves and left_leaves[-1].type == token.LPAR:
4375             old_parens_exist = True
4376             leaves_to_steal_comments_from.append(left_leaves[-1])
4377             left_leaves.pop()
4378
4379         append_leaves(first_line, line, left_leaves)
4380
4381         lpar_leaf = Leaf(token.LPAR, "(")
4382         if old_parens_exist:
4383             replace_child(LL[string_idx - 1], lpar_leaf)
4384         else:
4385             insert_str_child(lpar_leaf)
4386         first_line.append(lpar_leaf)
4387
4388         # We throw inline comments that were originally to the right of the
4389         # target string to the top line. They will now be shown to the right of
4390         # the LPAR.
4391         for leaf in leaves_to_steal_comments_from:
4392             for comment_leaf in line.comments_after(leaf):
4393                 first_line.append(comment_leaf, preformatted=True)
4394
4395         yield Ok(first_line)
4396
4397         # --- Middle (String) Line
4398         # We only need to yield one (possibly too long) string line, since the
4399         # `StringSplitter` will break it down further if necessary.
4400         string_value = LL[string_idx].value
4401         string_line = Line(
4402             mode=line.mode,
4403             depth=line.depth + 1,
4404             inside_brackets=True,
4405             should_split_rhs=line.should_split_rhs,
4406             magic_trailing_comma=line.magic_trailing_comma,
4407         )
4408         string_leaf = Leaf(token.STRING, string_value)
4409         insert_str_child(string_leaf)
4410         string_line.append(string_leaf)
4411
4412         old_rpar_leaf = None
4413         if is_valid_index(string_idx + 1):
4414             right_leaves = LL[string_idx + 1 :]
4415             if ends_with_comma:
4416                 right_leaves.pop()
4417
4418             if old_parens_exist:
4419                 assert (
4420                     right_leaves and right_leaves[-1].type == token.RPAR
4421                 ), "Apparently, old parentheses do NOT exist?!"
4422                 old_rpar_leaf = right_leaves.pop()
4423
4424             append_leaves(string_line, line, right_leaves)
4425
4426         yield Ok(string_line)
4427
4428         # --- Last Line
4429         last_line = line.clone()
4430         last_line.bracket_tracker = first_line.bracket_tracker
4431
4432         new_rpar_leaf = Leaf(token.RPAR, ")")
4433         if old_rpar_leaf is not None:
4434             replace_child(old_rpar_leaf, new_rpar_leaf)
4435         else:
4436             insert_str_child(new_rpar_leaf)
4437         last_line.append(new_rpar_leaf)
4438
4439         # If the target string ended with a comma, we place this comma to the
4440         # right of the RPAR on the last line.
4441         if ends_with_comma:
4442             comma_leaf = Leaf(token.COMMA, ",")
4443             replace_child(LL[comma_idx], comma_leaf)
4444             last_line.append(comma_leaf)
4445
4446         yield Ok(last_line)
4447
4448
4449 class StringParser:
4450     """
4451     A state machine that aids in parsing a string's "trailer", which can be
4452     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4453     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4454     varY)`).
4455
4456     NOTE: A new StringParser object MUST be instantiated for each string
4457     trailer we need to parse.
4458
4459     Examples:
4460         We shall assume that `line` equals the `Line` object that corresponds
4461         to the following line of python code:
4462         ```
4463         x = "Some {}.".format("String") + some_other_string
4464         ```
4465
4466         Furthermore, we will assume that `string_idx` is some index such that:
4467         ```
4468         assert line.leaves[string_idx].value == "Some {}."
4469         ```
4470
4471         The following code snippet then holds:
4472         ```
4473         string_parser = StringParser()
4474         idx = string_parser.parse(line.leaves, string_idx)
4475         assert line.leaves[idx].type == token.PLUS
4476         ```
4477     """
4478
4479     DEFAULT_TOKEN = -1
4480
4481     # String Parser States
4482     START = 1
4483     DOT = 2
4484     NAME = 3
4485     PERCENT = 4
4486     SINGLE_FMT_ARG = 5
4487     LPAR = 6
4488     RPAR = 7
4489     DONE = 8
4490
4491     # Lookup Table for Next State
4492     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4493         # A string trailer may start with '.' OR '%'.
4494         (START, token.DOT): DOT,
4495         (START, token.PERCENT): PERCENT,
4496         (START, DEFAULT_TOKEN): DONE,
4497         # A '.' MUST be followed by an attribute or method name.
4498         (DOT, token.NAME): NAME,
4499         # A method name MUST be followed by an '(', whereas an attribute name
4500         # is the last symbol in the string trailer.
4501         (NAME, token.LPAR): LPAR,
4502         (NAME, DEFAULT_TOKEN): DONE,
4503         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4504         # string or variable name).
4505         (PERCENT, token.LPAR): LPAR,
4506         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4507         # If a '%' symbol is followed by a single argument, that argument is
4508         # the last leaf in the string trailer.
4509         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4510         # If present, a ')' symbol is the last symbol in a string trailer.
4511         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4512         # since they are treated as a special case by the parsing logic in this
4513         # classes' implementation.)
4514         (RPAR, DEFAULT_TOKEN): DONE,
4515     }
4516
4517     def __init__(self) -> None:
4518         self._state = self.START
4519         self._unmatched_lpars = 0
4520
4521     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4522         """
4523         Pre-conditions:
4524             * @leaves[@string_idx].type == token.STRING
4525
4526         Returns:
4527             The index directly after the last leaf which is apart of the string
4528             trailer, if a "trailer" exists.
4529                 OR
4530             @string_idx + 1, if no string "trailer" exists.
4531         """
4532         assert leaves[string_idx].type == token.STRING
4533
4534         idx = string_idx + 1
4535         while idx < len(leaves) and self._next_state(leaves[idx]):
4536             idx += 1
4537         return idx
4538
4539     def _next_state(self, leaf: Leaf) -> bool:
4540         """
4541         Pre-conditions:
4542             * On the first call to this function, @leaf MUST be the leaf that
4543             was directly after the string leaf in question (e.g. if our target
4544             string is `line.leaves[i]` then the first call to this method must
4545             be `line.leaves[i + 1]`).
4546             * On the next call to this function, the leaf parameter passed in
4547             MUST be the leaf directly following @leaf.
4548
4549         Returns:
4550             True iff @leaf is apart of the string's trailer.
4551         """
4552         # We ignore empty LPAR or RPAR leaves.
4553         if is_empty_par(leaf):
4554             return True
4555
4556         next_token = leaf.type
4557         if next_token == token.LPAR:
4558             self._unmatched_lpars += 1
4559
4560         current_state = self._state
4561
4562         # The LPAR parser state is a special case. We will return True until we
4563         # find the matching RPAR token.
4564         if current_state == self.LPAR:
4565             if next_token == token.RPAR:
4566                 self._unmatched_lpars -= 1
4567                 if self._unmatched_lpars == 0:
4568                     self._state = self.RPAR
4569         # Otherwise, we use a lookup table to determine the next state.
4570         else:
4571             # If the lookup table matches the current state to the next
4572             # token, we use the lookup table.
4573             if (current_state, next_token) in self._goto:
4574                 self._state = self._goto[current_state, next_token]
4575             else:
4576                 # Otherwise, we check if a the current state was assigned a
4577                 # default.
4578                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4579                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4580                 # If no default has been assigned, then this parser has a logic
4581                 # error.
4582                 else:
4583                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4584
4585             if self._state == self.DONE:
4586                 return False
4587
4588         return True
4589
4590
4591 def TErr(err_msg: str) -> Err[CannotTransform]:
4592     """(T)ransform Err
4593
4594     Convenience function used when working with the TResult type.
4595     """
4596     cant_transform = CannotTransform(err_msg)
4597     return Err(cant_transform)
4598
4599
4600 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4601     """
4602     Returns:
4603         True iff one of the comments in @comment_list is a pragma used by one
4604         of the more common static analysis tools for python (e.g. mypy, flake8,
4605         pylint).
4606     """
4607     for comment in comment_list:
4608         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4609             return True
4610
4611     return False
4612
4613
4614 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4615     """
4616     Factory for a convenience function that is used to orphan @string_leaf
4617     and then insert multiple new leaves into the same part of the node
4618     structure that @string_leaf had originally occupied.
4619
4620     Examples:
4621         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4622         string_leaf.parent`. Assume the node `N` has the following
4623         original structure:
4624
4625         Node(
4626             expr_stmt, [
4627                 Leaf(NAME, 'x'),
4628                 Leaf(EQUAL, '='),
4629                 Leaf(STRING, '"foo"'),
4630             ]
4631         )
4632
4633         We then run the code snippet shown below.
4634         ```
4635         insert_str_child = insert_str_child_factory(string_leaf)
4636
4637         lpar = Leaf(token.LPAR, '(')
4638         insert_str_child(lpar)
4639
4640         bar = Leaf(token.STRING, '"bar"')
4641         insert_str_child(bar)
4642
4643         rpar = Leaf(token.RPAR, ')')
4644         insert_str_child(rpar)
4645         ```
4646
4647         After which point, it follows that `string_leaf.parent is None` and
4648         the node `N` now has the following structure:
4649
4650         Node(
4651             expr_stmt, [
4652                 Leaf(NAME, 'x'),
4653                 Leaf(EQUAL, '='),
4654                 Leaf(LPAR, '('),
4655                 Leaf(STRING, '"bar"'),
4656                 Leaf(RPAR, ')'),
4657             ]
4658         )
4659     """
4660     string_parent = string_leaf.parent
4661     string_child_idx = string_leaf.remove()
4662
4663     def insert_str_child(child: LN) -> None:
4664         nonlocal string_child_idx
4665
4666         assert string_parent is not None
4667         assert string_child_idx is not None
4668
4669         string_parent.insert_child(string_child_idx, child)
4670         string_child_idx += 1
4671
4672     return insert_str_child
4673
4674
4675 def has_triple_quotes(string: str) -> bool:
4676     """
4677     Returns:
4678         True iff @string starts with three quotation characters.
4679     """
4680     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4681     return raw_string[:3] in {'"""', "'''"}
4682
4683
4684 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4685     """
4686     Returns:
4687         @node.parent.type, if @node is not None and has a parent.
4688             OR
4689         None, otherwise.
4690     """
4691     if node is None or node.parent is None:
4692         return None
4693
4694     return node.parent.type
4695
4696
4697 def is_empty_par(leaf: Leaf) -> bool:
4698     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4699
4700
4701 def is_empty_lpar(leaf: Leaf) -> bool:
4702     return leaf.type == token.LPAR and leaf.value == ""
4703
4704
4705 def is_empty_rpar(leaf: Leaf) -> bool:
4706     return leaf.type == token.RPAR and leaf.value == ""
4707
4708
4709 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4710     """
4711     Examples:
4712         ```
4713         my_list = [1, 2, 3]
4714
4715         is_valid_index = is_valid_index_factory(my_list)
4716
4717         assert is_valid_index(0)
4718         assert is_valid_index(2)
4719
4720         assert not is_valid_index(3)
4721         assert not is_valid_index(-1)
4722         ```
4723     """
4724
4725     def is_valid_index(idx: int) -> bool:
4726         """
4727         Returns:
4728             True iff @idx is positive AND seq[@idx] does NOT raise an
4729             IndexError.
4730         """
4731         return 0 <= idx < len(seq)
4732
4733     return is_valid_index
4734
4735
4736 def line_to_string(line: Line) -> str:
4737     """Returns the string representation of @line.
4738
4739     WARNING: This is known to be computationally expensive.
4740     """
4741     return str(line).strip("\n")
4742
4743
4744 def append_leaves(
4745     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4746 ) -> None:
4747     """
4748     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4749     underlying Node structure where appropriate.
4750
4751     All of the leaves in @leaves are duplicated. The duplicates are then
4752     appended to @new_line and used to replace their originals in the underlying
4753     Node structure. Any comments attached to the old leaves are reattached to
4754     the new leaves.
4755
4756     Pre-conditions:
4757         set(@leaves) is a subset of set(@old_line.leaves).
4758     """
4759     for old_leaf in leaves:
4760         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4761         replace_child(old_leaf, new_leaf)
4762         new_line.append(new_leaf, preformatted=preformatted)
4763
4764         for comment_leaf in old_line.comments_after(old_leaf):
4765             new_line.append(comment_leaf, preformatted=True)
4766
4767
4768 def replace_child(old_child: LN, new_child: LN) -> None:
4769     """
4770     Side Effects:
4771         * If @old_child.parent is set, replace @old_child with @new_child in
4772         @old_child's underlying Node structure.
4773             OR
4774         * Otherwise, this function does nothing.
4775     """
4776     parent = old_child.parent
4777     if not parent:
4778         return
4779
4780     child_idx = old_child.remove()
4781     if child_idx is not None:
4782         parent.insert_child(child_idx, new_child)
4783
4784
4785 def get_string_prefix(string: str) -> str:
4786     """
4787     Pre-conditions:
4788         * assert_is_leaf_string(@string)
4789
4790     Returns:
4791         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4792     """
4793     assert_is_leaf_string(string)
4794
4795     prefix = ""
4796     prefix_idx = 0
4797     while string[prefix_idx] in STRING_PREFIX_CHARS:
4798         prefix += string[prefix_idx].lower()
4799         prefix_idx += 1
4800
4801     return prefix
4802
4803
4804 def assert_is_leaf_string(string: str) -> None:
4805     """
4806     Checks the pre-condition that @string has the format that you would expect
4807     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4808     token.STRING`. A more precise description of the pre-conditions that are
4809     checked are listed below.
4810
4811     Pre-conditions:
4812         * @string starts with either ', ", <prefix>', or <prefix>" where
4813         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4814         * @string ends with a quote character (' or ").
4815
4816     Raises:
4817         AssertionError(...) if the pre-conditions listed above are not
4818         satisfied.
4819     """
4820     dquote_idx = string.find('"')
4821     squote_idx = string.find("'")
4822     if -1 in [dquote_idx, squote_idx]:
4823         quote_idx = max(dquote_idx, squote_idx)
4824     else:
4825         quote_idx = min(squote_idx, dquote_idx)
4826
4827     assert (
4828         0 <= quote_idx < len(string) - 1
4829     ), f"{string!r} is missing a starting quote character (' or \")."
4830     assert string[-1] in (
4831         "'",
4832         '"',
4833     ), f"{string!r} is missing an ending quote character (' or \")."
4834     assert set(string[:quote_idx]).issubset(
4835         set(STRING_PREFIX_CHARS)
4836     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4837
4838
4839 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4840     """Split line into many lines, starting with the first matching bracket pair.
4841
4842     Note: this usually looks weird, only use this for function definitions.
4843     Prefer RHS otherwise.  This is why this function is not symmetrical with
4844     :func:`right_hand_split` which also handles optional parentheses.
4845     """
4846     tail_leaves: List[Leaf] = []
4847     body_leaves: List[Leaf] = []
4848     head_leaves: List[Leaf] = []
4849     current_leaves = head_leaves
4850     matching_bracket: Optional[Leaf] = None
4851     for leaf in line.leaves:
4852         if (
4853             current_leaves is body_leaves
4854             and leaf.type in CLOSING_BRACKETS
4855             and leaf.opening_bracket is matching_bracket
4856         ):
4857             current_leaves = tail_leaves if body_leaves else head_leaves
4858         current_leaves.append(leaf)
4859         if current_leaves is head_leaves:
4860             if leaf.type in OPENING_BRACKETS:
4861                 matching_bracket = leaf
4862                 current_leaves = body_leaves
4863     if not matching_bracket:
4864         raise CannotSplit("No brackets found")
4865
4866     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4867     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4868     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4869     bracket_split_succeeded_or_raise(head, body, tail)
4870     for result in (head, body, tail):
4871         if result:
4872             yield result
4873
4874
4875 def right_hand_split(
4876     line: Line,
4877     line_length: int,
4878     features: Collection[Feature] = (),
4879     omit: Collection[LeafID] = (),
4880 ) -> Iterator[Line]:
4881     """Split line into many lines, starting with the last matching bracket pair.
4882
4883     If the split was by optional parentheses, attempt splitting without them, too.
4884     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4885     this split.
4886
4887     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4888     """
4889     tail_leaves: List[Leaf] = []
4890     body_leaves: List[Leaf] = []
4891     head_leaves: List[Leaf] = []
4892     current_leaves = tail_leaves
4893     opening_bracket: Optional[Leaf] = None
4894     closing_bracket: Optional[Leaf] = None
4895     for leaf in reversed(line.leaves):
4896         if current_leaves is body_leaves:
4897             if leaf is opening_bracket:
4898                 current_leaves = head_leaves if body_leaves else tail_leaves
4899         current_leaves.append(leaf)
4900         if current_leaves is tail_leaves:
4901             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4902                 opening_bracket = leaf.opening_bracket
4903                 closing_bracket = leaf
4904                 current_leaves = body_leaves
4905     if not (opening_bracket and closing_bracket and head_leaves):
4906         # If there is no opening or closing_bracket that means the split failed and
4907         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4908         # the matching `opening_bracket` wasn't available on `line` anymore.
4909         raise CannotSplit("No brackets found")
4910
4911     tail_leaves.reverse()
4912     body_leaves.reverse()
4913     head_leaves.reverse()
4914     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4915     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4916     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4917     bracket_split_succeeded_or_raise(head, body, tail)
4918     if (
4919         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4920         # the opening bracket is an optional paren
4921         and opening_bracket.type == token.LPAR
4922         and not opening_bracket.value
4923         # the closing bracket is an optional paren
4924         and closing_bracket.type == token.RPAR
4925         and not closing_bracket.value
4926         # it's not an import (optional parens are the only thing we can split on
4927         # in this case; attempting a split without them is a waste of time)
4928         and not line.is_import
4929         # there are no standalone comments in the body
4930         and not body.contains_standalone_comments(0)
4931         # and we can actually remove the parens
4932         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4933     ):
4934         omit = {id(closing_bracket), *omit}
4935         try:
4936             yield from right_hand_split(line, line_length, features=features, omit=omit)
4937             return
4938
4939         except CannotSplit:
4940             if not (
4941                 can_be_split(body)
4942                 or is_line_short_enough(body, line_length=line_length)
4943             ):
4944                 raise CannotSplit(
4945                     "Splitting failed, body is still too long and can't be split."
4946                 )
4947
4948             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4949                 raise CannotSplit(
4950                     "The current optional pair of parentheses is bound to fail to"
4951                     " satisfy the splitting algorithm because the head or the tail"
4952                     " contains multiline strings which by definition never fit one"
4953                     " line."
4954                 )
4955
4956     ensure_visible(opening_bracket)
4957     ensure_visible(closing_bracket)
4958     for result in (head, body, tail):
4959         if result:
4960             yield result
4961
4962
4963 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4964     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4965
4966     Do nothing otherwise.
4967
4968     A left- or right-hand split is based on a pair of brackets. Content before
4969     (and including) the opening bracket is left on one line, content inside the
4970     brackets is put on a separate line, and finally content starting with and
4971     following the closing bracket is put on a separate line.
4972
4973     Those are called `head`, `body`, and `tail`, respectively. If the split
4974     produced the same line (all content in `head`) or ended up with an empty `body`
4975     and the `tail` is just the closing bracket, then it's considered failed.
4976     """
4977     tail_len = len(str(tail).strip())
4978     if not body:
4979         if tail_len == 0:
4980             raise CannotSplit("Splitting brackets produced the same line")
4981
4982         elif tail_len < 3:
4983             raise CannotSplit(
4984                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4985                 " not worth it"
4986             )
4987
4988
4989 def bracket_split_build_line(
4990     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4991 ) -> Line:
4992     """Return a new line with given `leaves` and respective comments from `original`.
4993
4994     If `is_body` is True, the result line is one-indented inside brackets and as such
4995     has its first leaf's prefix normalized and a trailing comma added when expected.
4996     """
4997     result = Line(mode=original.mode, depth=original.depth)
4998     if is_body:
4999         result.inside_brackets = True
5000         result.depth += 1
5001         if leaves:
5002             # Since body is a new indent level, remove spurious leading whitespace.
5003             normalize_prefix(leaves[0], inside_brackets=True)
5004             # Ensure a trailing comma for imports and standalone function arguments, but
5005             # be careful not to add one after any comments or within type annotations.
5006             no_commas = (
5007                 original.is_def
5008                 and opening_bracket.value == "("
5009                 and not any(leaf.type == token.COMMA for leaf in leaves)
5010             )
5011
5012             if original.is_import or no_commas:
5013                 for i in range(len(leaves) - 1, -1, -1):
5014                     if leaves[i].type == STANDALONE_COMMENT:
5015                         continue
5016
5017                     if leaves[i].type != token.COMMA:
5018                         new_comma = Leaf(token.COMMA, ",")
5019                         leaves.insert(i + 1, new_comma)
5020                     break
5021
5022     # Populate the line
5023     for leaf in leaves:
5024         result.append(leaf, preformatted=True)
5025         for comment_after in original.comments_after(leaf):
5026             result.append(comment_after, preformatted=True)
5027     if is_body and should_split_line(result, opening_bracket):
5028         result.should_split_rhs = True
5029     return result
5030
5031
5032 def dont_increase_indentation(split_func: Transformer) -> Transformer:
5033     """Normalize prefix of the first leaf in every line returned by `split_func`.
5034
5035     This is a decorator over relevant split functions.
5036     """
5037
5038     @wraps(split_func)
5039     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5040         for line in split_func(line, features):
5041             normalize_prefix(line.leaves[0], inside_brackets=True)
5042             yield line
5043
5044     return split_wrapper
5045
5046
5047 @dont_increase_indentation
5048 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5049     """Split according to delimiters of the highest priority.
5050
5051     If the appropriate Features are given, the split will add trailing commas
5052     also in function signatures and calls that contain `*` and `**`.
5053     """
5054     try:
5055         last_leaf = line.leaves[-1]
5056     except IndexError:
5057         raise CannotSplit("Line empty")
5058
5059     bt = line.bracket_tracker
5060     try:
5061         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
5062     except ValueError:
5063         raise CannotSplit("No delimiters found")
5064
5065     if delimiter_priority == DOT_PRIORITY:
5066         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
5067             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
5068
5069     current_line = Line(
5070         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5071     )
5072     lowest_depth = sys.maxsize
5073     trailing_comma_safe = True
5074
5075     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5076         """Append `leaf` to current line or to new line if appending impossible."""
5077         nonlocal current_line
5078         try:
5079             current_line.append_safe(leaf, preformatted=True)
5080         except ValueError:
5081             yield current_line
5082
5083             current_line = Line(
5084                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5085             )
5086             current_line.append(leaf)
5087
5088     for leaf in line.leaves:
5089         yield from append_to_line(leaf)
5090
5091         for comment_after in line.comments_after(leaf):
5092             yield from append_to_line(comment_after)
5093
5094         lowest_depth = min(lowest_depth, leaf.bracket_depth)
5095         if leaf.bracket_depth == lowest_depth:
5096             if is_vararg(leaf, within={syms.typedargslist}):
5097                 trailing_comma_safe = (
5098                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
5099                 )
5100             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
5101                 trailing_comma_safe = (
5102                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
5103                 )
5104
5105         leaf_priority = bt.delimiters.get(id(leaf))
5106         if leaf_priority == delimiter_priority:
5107             yield current_line
5108
5109             current_line = Line(
5110                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5111             )
5112     if current_line:
5113         if (
5114             trailing_comma_safe
5115             and delimiter_priority == COMMA_PRIORITY
5116             and current_line.leaves[-1].type != token.COMMA
5117             and current_line.leaves[-1].type != STANDALONE_COMMENT
5118         ):
5119             new_comma = Leaf(token.COMMA, ",")
5120             current_line.append(new_comma)
5121         yield current_line
5122
5123
5124 @dont_increase_indentation
5125 def standalone_comment_split(
5126     line: Line, features: Collection[Feature] = ()
5127 ) -> Iterator[Line]:
5128     """Split standalone comments from the rest of the line."""
5129     if not line.contains_standalone_comments(0):
5130         raise CannotSplit("Line does not have any standalone comments")
5131
5132     current_line = Line(
5133         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5134     )
5135
5136     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5137         """Append `leaf` to current line or to new line if appending impossible."""
5138         nonlocal current_line
5139         try:
5140             current_line.append_safe(leaf, preformatted=True)
5141         except ValueError:
5142             yield current_line
5143
5144             current_line = Line(
5145                 line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5146             )
5147             current_line.append(leaf)
5148
5149     for leaf in line.leaves:
5150         yield from append_to_line(leaf)
5151
5152         for comment_after in line.comments_after(leaf):
5153             yield from append_to_line(comment_after)
5154
5155     if current_line:
5156         yield current_line
5157
5158
5159 def is_import(leaf: Leaf) -> bool:
5160     """Return True if the given leaf starts an import statement."""
5161     p = leaf.parent
5162     t = leaf.type
5163     v = leaf.value
5164     return bool(
5165         t == token.NAME
5166         and (
5167             (v == "import" and p and p.type == syms.import_name)
5168             or (v == "from" and p and p.type == syms.import_from)
5169         )
5170     )
5171
5172
5173 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5174     """Return True if the given leaf is a special comment.
5175     Only returns true for type comments for now."""
5176     t = leaf.type
5177     v = leaf.value
5178     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5179
5180
5181 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5182     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5183     else.
5184
5185     Note: don't use backslashes for formatting or you'll lose your voting rights.
5186     """
5187     if not inside_brackets:
5188         spl = leaf.prefix.split("#")
5189         if "\\" not in spl[0]:
5190             nl_count = spl[-1].count("\n")
5191             if len(spl) > 1:
5192                 nl_count -= 1
5193             leaf.prefix = "\n" * nl_count
5194             return
5195
5196     leaf.prefix = ""
5197
5198
5199 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5200     """Make all string prefixes lowercase.
5201
5202     If remove_u_prefix is given, also removes any u prefix from the string.
5203
5204     Note: Mutates its argument.
5205     """
5206     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5207     assert match is not None, f"failed to match string {leaf.value!r}"
5208     orig_prefix = match.group(1)
5209     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5210     if remove_u_prefix:
5211         new_prefix = new_prefix.replace("u", "")
5212     leaf.value = f"{new_prefix}{match.group(2)}"
5213
5214
5215 def normalize_string_quotes(leaf: Leaf) -> None:
5216     """Prefer double quotes but only if it doesn't cause more escaping.
5217
5218     Adds or removes backslashes as appropriate. Doesn't parse and fix
5219     strings nested in f-strings (yet).
5220
5221     Note: Mutates its argument.
5222     """
5223     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5224     if value[:3] == '"""':
5225         return
5226
5227     elif value[:3] == "'''":
5228         orig_quote = "'''"
5229         new_quote = '"""'
5230     elif value[0] == '"':
5231         orig_quote = '"'
5232         new_quote = "'"
5233     else:
5234         orig_quote = "'"
5235         new_quote = '"'
5236     first_quote_pos = leaf.value.find(orig_quote)
5237     if first_quote_pos == -1:
5238         return  # There's an internal error
5239
5240     prefix = leaf.value[:first_quote_pos]
5241     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5242     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5243     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5244     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5245     if "r" in prefix.casefold():
5246         if unescaped_new_quote.search(body):
5247             # There's at least one unescaped new_quote in this raw string
5248             # so converting is impossible
5249             return
5250
5251         # Do not introduce or remove backslashes in raw strings
5252         new_body = body
5253     else:
5254         # remove unnecessary escapes
5255         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5256         if body != new_body:
5257             # Consider the string without unnecessary escapes as the original
5258             body = new_body
5259             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5260         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5261         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5262     if "f" in prefix.casefold():
5263         matches = re.findall(
5264             r"""
5265             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5266                 ([^{].*?)  # contents of the brackets except if begins with {{
5267             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5268             """,
5269             new_body,
5270             re.VERBOSE,
5271         )
5272         for m in matches:
5273             if "\\" in str(m):
5274                 # Do not introduce backslashes in interpolated expressions
5275                 return
5276
5277     if new_quote == '"""' and new_body[-1:] == '"':
5278         # edge case:
5279         new_body = new_body[:-1] + '\\"'
5280     orig_escape_count = body.count("\\")
5281     new_escape_count = new_body.count("\\")
5282     if new_escape_count > orig_escape_count:
5283         return  # Do not introduce more escaping
5284
5285     if new_escape_count == orig_escape_count and orig_quote == '"':
5286         return  # Prefer double quotes
5287
5288     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5289
5290
5291 def normalize_numeric_literal(leaf: Leaf) -> None:
5292     """Normalizes numeric (float, int, and complex) literals.
5293
5294     All letters used in the representation are normalized to lowercase (except
5295     in Python 2 long literals).
5296     """
5297     text = leaf.value.lower()
5298     if text.startswith(("0o", "0b")):
5299         # Leave octal and binary literals alone.
5300         pass
5301     elif text.startswith("0x"):
5302         text = format_hex(text)
5303     elif "e" in text:
5304         text = format_scientific_notation(text)
5305     elif text.endswith(("j", "l")):
5306         text = format_long_or_complex_number(text)
5307     else:
5308         text = format_float_or_int_string(text)
5309     leaf.value = text
5310
5311
5312 def format_hex(text: str) -> str:
5313     """
5314     Formats a hexadecimal string like "0x12b3"
5315
5316     Uses lowercase because of similarity between "B" and "8", which
5317     can cause security issues.
5318     see: https://github.com/psf/black/issues/1692
5319     """
5320
5321     before, after = text[:2], text[2:]
5322     return f"{before}{after.lower()}"
5323
5324
5325 def format_scientific_notation(text: str) -> str:
5326     """Formats a numeric string utilizing scentific notation"""
5327     before, after = text.split("e")
5328     sign = ""
5329     if after.startswith("-"):
5330         after = after[1:]
5331         sign = "-"
5332     elif after.startswith("+"):
5333         after = after[1:]
5334     before = format_float_or_int_string(before)
5335     return f"{before}e{sign}{after}"
5336
5337
5338 def format_long_or_complex_number(text: str) -> str:
5339     """Formats a long or complex string like `10L` or `10j`"""
5340     number = text[:-1]
5341     suffix = text[-1]
5342     # Capitalize in "2L" because "l" looks too similar to "1".
5343     if suffix == "l":
5344         suffix = "L"
5345     return f"{format_float_or_int_string(number)}{suffix}"
5346
5347
5348 def format_float_or_int_string(text: str) -> str:
5349     """Formats a float string like "1.0"."""
5350     if "." not in text:
5351         return text
5352
5353     before, after = text.split(".")
5354     return f"{before or 0}.{after or 0}"
5355
5356
5357 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5358     """Make existing optional parentheses invisible or create new ones.
5359
5360     `parens_after` is a set of string leaf values immediately after which parens
5361     should be put.
5362
5363     Standardizes on visible parentheses for single-element tuples, and keeps
5364     existing visible parentheses for other tuples and generator expressions.
5365     """
5366     for pc in list_comments(node.prefix, is_endmarker=False):
5367         if pc.value in FMT_OFF:
5368             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5369             return
5370     check_lpar = False
5371     for index, child in enumerate(list(node.children)):
5372         # Fixes a bug where invisible parens are not properly stripped from
5373         # assignment statements that contain type annotations.
5374         if isinstance(child, Node) and child.type == syms.annassign:
5375             normalize_invisible_parens(child, parens_after=parens_after)
5376
5377         # Add parentheses around long tuple unpacking in assignments.
5378         if (
5379             index == 0
5380             and isinstance(child, Node)
5381             and child.type == syms.testlist_star_expr
5382         ):
5383             check_lpar = True
5384
5385         if check_lpar:
5386             if child.type == syms.atom:
5387                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5388                     wrap_in_parentheses(node, child, visible=False)
5389             elif is_one_tuple(child):
5390                 wrap_in_parentheses(node, child, visible=True)
5391             elif node.type == syms.import_from:
5392                 # "import from" nodes store parentheses directly as part of
5393                 # the statement
5394                 if child.type == token.LPAR:
5395                     # make parentheses invisible
5396                     child.value = ""  # type: ignore
5397                     node.children[-1].value = ""  # type: ignore
5398                 elif child.type != token.STAR:
5399                     # insert invisible parentheses
5400                     node.insert_child(index, Leaf(token.LPAR, ""))
5401                     node.append_child(Leaf(token.RPAR, ""))
5402                 break
5403
5404             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5405                 wrap_in_parentheses(node, child, visible=False)
5406
5407         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5408
5409
5410 def normalize_fmt_off(node: Node) -> None:
5411     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5412     try_again = True
5413     while try_again:
5414         try_again = convert_one_fmt_off_pair(node)
5415
5416
5417 def convert_one_fmt_off_pair(node: Node) -> bool:
5418     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5419
5420     Returns True if a pair was converted.
5421     """
5422     for leaf in node.leaves():
5423         previous_consumed = 0
5424         for comment in list_comments(leaf.prefix, is_endmarker=False):
5425             if comment.value not in FMT_PASS:
5426                 previous_consumed = comment.consumed
5427                 continue
5428             # We only want standalone comments. If there's no previous leaf or
5429             # the previous leaf is indentation, it's a standalone comment in
5430             # disguise.
5431             if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
5432                 prev = preceding_leaf(leaf)
5433                 if prev:
5434                     if comment.value in FMT_OFF and prev.type not in WHITESPACE:
5435                         continue
5436                     if comment.value in FMT_SKIP and prev.type in WHITESPACE:
5437                         continue
5438
5439             ignored_nodes = list(generate_ignored_nodes(leaf, comment))
5440             if not ignored_nodes:
5441                 continue
5442
5443             first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5444             parent = first.parent
5445             prefix = first.prefix
5446             first.prefix = prefix[comment.consumed :]
5447             hidden_value = "".join(str(n) for n in ignored_nodes)
5448             if comment.value in FMT_OFF:
5449                 hidden_value = comment.value + "\n" + hidden_value
5450             if comment.value in FMT_SKIP:
5451                 hidden_value += "  " + comment.value
5452             if hidden_value.endswith("\n"):
5453                 # That happens when one of the `ignored_nodes` ended with a NEWLINE
5454                 # leaf (possibly followed by a DEDENT).
5455                 hidden_value = hidden_value[:-1]
5456             first_idx: Optional[int] = None
5457             for ignored in ignored_nodes:
5458                 index = ignored.remove()
5459                 if first_idx is None:
5460                     first_idx = index
5461             assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5462             assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5463             parent.insert_child(
5464                 first_idx,
5465                 Leaf(
5466                     STANDALONE_COMMENT,
5467                     hidden_value,
5468                     prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5469                 ),
5470             )
5471             return True
5472
5473     return False
5474
5475
5476 def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
5477     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5478
5479     If comment is skip, returns leaf only.
5480     Stops at the end of the block.
5481     """
5482     container: Optional[LN] = container_of(leaf)
5483     if comment.value in FMT_SKIP:
5484         prev_sibling = leaf.prev_sibling
5485         if comment.value in leaf.prefix and prev_sibling is not None:
5486             leaf.prefix = leaf.prefix.replace(comment.value, "")
5487             siblings = [prev_sibling]
5488             while (
5489                 "\n" not in prev_sibling.prefix
5490                 and prev_sibling.prev_sibling is not None
5491             ):
5492                 prev_sibling = prev_sibling.prev_sibling
5493                 siblings.insert(0, prev_sibling)
5494             for sibling in siblings:
5495                 yield sibling
5496         elif leaf.parent is not None:
5497             yield leaf.parent
5498         return
5499     while container is not None and container.type != token.ENDMARKER:
5500         if is_fmt_on(container):
5501             return
5502
5503         # fix for fmt: on in children
5504         if contains_fmt_on_at_column(container, leaf.column):
5505             for child in container.children:
5506                 if contains_fmt_on_at_column(child, leaf.column):
5507                     return
5508                 yield child
5509         else:
5510             yield container
5511             container = container.next_sibling
5512
5513
5514 def is_fmt_on(container: LN) -> bool:
5515     """Determine whether formatting is switched on within a container.
5516     Determined by whether the last `# fmt:` comment is `on` or `off`.
5517     """
5518     fmt_on = False
5519     for comment in list_comments(container.prefix, is_endmarker=False):
5520         if comment.value in FMT_ON:
5521             fmt_on = True
5522         elif comment.value in FMT_OFF:
5523             fmt_on = False
5524     return fmt_on
5525
5526
5527 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5528     """Determine if children at a given column have formatting switched on."""
5529     for child in container.children:
5530         if (
5531             isinstance(child, Node)
5532             and first_leaf_column(child) == column
5533             or isinstance(child, Leaf)
5534             and child.column == column
5535         ):
5536             if is_fmt_on(child):
5537                 return True
5538
5539     return False
5540
5541
5542 def first_leaf_column(node: Node) -> Optional[int]:
5543     """Returns the column of the first leaf child of a node."""
5544     for child in node.children:
5545         if isinstance(child, Leaf):
5546             return child.column
5547     return None
5548
5549
5550 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5551     """If it's safe, make the parens in the atom `node` invisible, recursively.
5552     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5553     as they are redundant.
5554
5555     Returns whether the node should itself be wrapped in invisible parentheses.
5556
5557     """
5558
5559     if (
5560         node.type != syms.atom
5561         or is_empty_tuple(node)
5562         or is_one_tuple(node)
5563         or (is_yield(node) and parent.type != syms.expr_stmt)
5564         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5565     ):
5566         return False
5567
5568     if is_walrus_assignment(node):
5569         if parent.type in [syms.annassign, syms.expr_stmt]:
5570             return False
5571
5572     first = node.children[0]
5573     last = node.children[-1]
5574     if first.type == token.LPAR and last.type == token.RPAR:
5575         middle = node.children[1]
5576         # make parentheses invisible
5577         first.value = ""  # type: ignore
5578         last.value = ""  # type: ignore
5579         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5580
5581         if is_atom_with_invisible_parens(middle):
5582             # Strip the invisible parens from `middle` by replacing
5583             # it with the child in-between the invisible parens
5584             middle.replace(middle.children[1])
5585
5586         return False
5587
5588     return True
5589
5590
5591 def is_atom_with_invisible_parens(node: LN) -> bool:
5592     """Given a `LN`, determines whether it's an atom `node` with invisible
5593     parens. Useful in dedupe-ing and normalizing parens.
5594     """
5595     if isinstance(node, Leaf) or node.type != syms.atom:
5596         return False
5597
5598     first, last = node.children[0], node.children[-1]
5599     return (
5600         isinstance(first, Leaf)
5601         and first.type == token.LPAR
5602         and first.value == ""
5603         and isinstance(last, Leaf)
5604         and last.type == token.RPAR
5605         and last.value == ""
5606     )
5607
5608
5609 def is_empty_tuple(node: LN) -> bool:
5610     """Return True if `node` holds an empty tuple."""
5611     return (
5612         node.type == syms.atom
5613         and len(node.children) == 2
5614         and node.children[0].type == token.LPAR
5615         and node.children[1].type == token.RPAR
5616     )
5617
5618
5619 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5620     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5621
5622     Parenthesis can be optional. Returns None otherwise"""
5623     if len(node.children) != 3:
5624         return None
5625
5626     lpar, wrapped, rpar = node.children
5627     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5628         return None
5629
5630     return wrapped
5631
5632
5633 def first_child_is_arith(node: Node) -> bool:
5634     """Whether first child is an arithmetic or a binary arithmetic expression"""
5635     expr_types = {
5636         syms.arith_expr,
5637         syms.shift_expr,
5638         syms.xor_expr,
5639         syms.and_expr,
5640     }
5641     return bool(node.children and node.children[0].type in expr_types)
5642
5643
5644 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5645     """Wrap `child` in parentheses.
5646
5647     This replaces `child` with an atom holding the parentheses and the old
5648     child.  That requires moving the prefix.
5649
5650     If `visible` is False, the leaves will be valueless (and thus invisible).
5651     """
5652     lpar = Leaf(token.LPAR, "(" if visible else "")
5653     rpar = Leaf(token.RPAR, ")" if visible else "")
5654     prefix = child.prefix
5655     child.prefix = ""
5656     index = child.remove() or 0
5657     new_child = Node(syms.atom, [lpar, child, rpar])
5658     new_child.prefix = prefix
5659     parent.insert_child(index, new_child)
5660
5661
5662 def is_one_tuple(node: LN) -> bool:
5663     """Return True if `node` holds a tuple with one element, with or without parens."""
5664     if node.type == syms.atom:
5665         gexp = unwrap_singleton_parenthesis(node)
5666         if gexp is None or gexp.type != syms.testlist_gexp:
5667             return False
5668
5669         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5670
5671     return (
5672         node.type in IMPLICIT_TUPLE
5673         and len(node.children) == 2
5674         and node.children[1].type == token.COMMA
5675     )
5676
5677
5678 def is_walrus_assignment(node: LN) -> bool:
5679     """Return True iff `node` is of the shape ( test := test )"""
5680     inner = unwrap_singleton_parenthesis(node)
5681     return inner is not None and inner.type == syms.namedexpr_test
5682
5683
5684 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5685     """Return True iff `node` is a trailer valid in a simple decorator"""
5686     return node.type == syms.trailer and (
5687         (
5688             len(node.children) == 2
5689             and node.children[0].type == token.DOT
5690             and node.children[1].type == token.NAME
5691         )
5692         # last trailer can be arguments
5693         or (
5694             last
5695             and len(node.children) == 3
5696             and node.children[0].type == token.LPAR
5697             # and node.children[1].type == syms.argument
5698             and node.children[2].type == token.RPAR
5699         )
5700     )
5701
5702
5703 def is_simple_decorator_expression(node: LN) -> bool:
5704     """Return True iff `node` could be a 'dotted name' decorator
5705
5706     This function takes the node of the 'namedexpr_test' of the new decorator
5707     grammar and test if it would be valid under the old decorator grammar.
5708
5709     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5710     The new grammar is : decorator: @ namedexpr_test NEWLINE
5711     """
5712     if node.type == token.NAME:
5713         return True
5714     if node.type == syms.power:
5715         if node.children:
5716             return (
5717                 node.children[0].type == token.NAME
5718                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5719                 and (
5720                     len(node.children) < 2
5721                     or is_simple_decorator_trailer(node.children[-1], last=True)
5722                 )
5723             )
5724     return False
5725
5726
5727 def is_yield(node: LN) -> bool:
5728     """Return True if `node` holds a `yield` or `yield from` expression."""
5729     if node.type == syms.yield_expr:
5730         return True
5731
5732     if node.type == token.NAME and node.value == "yield":  # type: ignore
5733         return True
5734
5735     if node.type != syms.atom:
5736         return False
5737
5738     if len(node.children) != 3:
5739         return False
5740
5741     lpar, expr, rpar = node.children
5742     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5743         return is_yield(expr)
5744
5745     return False
5746
5747
5748 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5749     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5750
5751     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5752     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5753     extended iterable unpacking (PEP 3132) and additional unpacking
5754     generalizations (PEP 448).
5755     """
5756     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5757         return False
5758
5759     p = leaf.parent
5760     if p.type == syms.star_expr:
5761         # Star expressions are also used as assignment targets in extended
5762         # iterable unpacking (PEP 3132).  See what its parent is instead.
5763         if not p.parent:
5764             return False
5765
5766         p = p.parent
5767
5768     return p.type in within
5769
5770
5771 def is_multiline_string(leaf: Leaf) -> bool:
5772     """Return True if `leaf` is a multiline string that actually spans many lines."""
5773     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5774
5775
5776 def is_stub_suite(node: Node) -> bool:
5777     """Return True if `node` is a suite with a stub body."""
5778     if (
5779         len(node.children) != 4
5780         or node.children[0].type != token.NEWLINE
5781         or node.children[1].type != token.INDENT
5782         or node.children[3].type != token.DEDENT
5783     ):
5784         return False
5785
5786     return is_stub_body(node.children[2])
5787
5788
5789 def is_stub_body(node: LN) -> bool:
5790     """Return True if `node` is a simple statement containing an ellipsis."""
5791     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5792         return False
5793
5794     if len(node.children) != 2:
5795         return False
5796
5797     child = node.children[0]
5798     return (
5799         child.type == syms.atom
5800         and len(child.children) == 3
5801         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5802     )
5803
5804
5805 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5806     """Return maximum delimiter priority inside `node`.
5807
5808     This is specific to atoms with contents contained in a pair of parentheses.
5809     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5810     """
5811     if node.type != syms.atom:
5812         return 0
5813
5814     first = node.children[0]
5815     last = node.children[-1]
5816     if not (first.type == token.LPAR and last.type == token.RPAR):
5817         return 0
5818
5819     bt = BracketTracker()
5820     for c in node.children[1:-1]:
5821         if isinstance(c, Leaf):
5822             bt.mark(c)
5823         else:
5824             for leaf in c.leaves():
5825                 bt.mark(leaf)
5826     try:
5827         return bt.max_delimiter_priority()
5828
5829     except ValueError:
5830         return 0
5831
5832
5833 def ensure_visible(leaf: Leaf) -> None:
5834     """Make sure parentheses are visible.
5835
5836     They could be invisible as part of some statements (see
5837     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5838     """
5839     if leaf.type == token.LPAR:
5840         leaf.value = "("
5841     elif leaf.type == token.RPAR:
5842         leaf.value = ")"
5843
5844
5845 def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
5846     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5847
5848     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5849         return False
5850
5851     # We're essentially checking if the body is delimited by commas and there's more
5852     # than one of them (we're excluding the trailing comma and if the delimiter priority
5853     # is still commas, that means there's more).
5854     exclude = set()
5855     trailing_comma = False
5856     try:
5857         last_leaf = line.leaves[-1]
5858         if last_leaf.type == token.COMMA:
5859             trailing_comma = True
5860             exclude.add(id(last_leaf))
5861         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5862     except (IndexError, ValueError):
5863         return False
5864
5865     return max_priority == COMMA_PRIORITY and (
5866         (line.mode.magic_trailing_comma and trailing_comma)
5867         # always explode imports
5868         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5869     )
5870
5871
5872 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5873     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5874     if opening.type != token.LPAR and closing.type != token.RPAR:
5875         return False
5876
5877     depth = closing.bracket_depth + 1
5878     for _opening_index, leaf in enumerate(leaves):
5879         if leaf is opening:
5880             break
5881
5882     else:
5883         raise LookupError("Opening paren not found in `leaves`")
5884
5885     commas = 0
5886     _opening_index += 1
5887     for leaf in leaves[_opening_index:]:
5888         if leaf is closing:
5889             break
5890
5891         bracket_depth = leaf.bracket_depth
5892         if bracket_depth == depth and leaf.type == token.COMMA:
5893             commas += 1
5894             if leaf.parent and leaf.parent.type in {
5895                 syms.arglist,
5896                 syms.typedargslist,
5897             }:
5898                 commas += 1
5899                 break
5900
5901     return commas < 2
5902
5903
5904 def get_features_used(node: Node) -> Set[Feature]:
5905     """Return a set of (relatively) new Python features used in this file.
5906
5907     Currently looking for:
5908     - f-strings;
5909     - underscores in numeric literals;
5910     - trailing commas after * or ** in function signatures and calls;
5911     - positional only arguments in function signatures and lambdas;
5912     - assignment expression;
5913     - relaxed decorator syntax;
5914     """
5915     features: Set[Feature] = set()
5916     for n in node.pre_order():
5917         if n.type == token.STRING:
5918             value_head = n.value[:2]  # type: ignore
5919             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5920                 features.add(Feature.F_STRINGS)
5921
5922         elif n.type == token.NUMBER:
5923             if "_" in n.value:  # type: ignore
5924                 features.add(Feature.NUMERIC_UNDERSCORES)
5925
5926         elif n.type == token.SLASH:
5927             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5928                 features.add(Feature.POS_ONLY_ARGUMENTS)
5929
5930         elif n.type == token.COLONEQUAL:
5931             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5932
5933         elif n.type == syms.decorator:
5934             if len(n.children) > 1 and not is_simple_decorator_expression(
5935                 n.children[1]
5936             ):
5937                 features.add(Feature.RELAXED_DECORATORS)
5938
5939         elif (
5940             n.type in {syms.typedargslist, syms.arglist}
5941             and n.children
5942             and n.children[-1].type == token.COMMA
5943         ):
5944             if n.type == syms.typedargslist:
5945                 feature = Feature.TRAILING_COMMA_IN_DEF
5946             else:
5947                 feature = Feature.TRAILING_COMMA_IN_CALL
5948
5949             for ch in n.children:
5950                 if ch.type in STARS:
5951                     features.add(feature)
5952
5953                 if ch.type == syms.argument:
5954                     for argch in ch.children:
5955                         if argch.type in STARS:
5956                             features.add(feature)
5957
5958     return features
5959
5960
5961 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5962     """Detect the version to target based on the nodes used."""
5963     features = get_features_used(node)
5964     return {
5965         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5966     }
5967
5968
5969 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5970     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5971
5972     Brackets can be omitted if the entire trailer up to and including
5973     a preceding closing bracket fits in one line.
5974
5975     Yielded sets are cumulative (contain results of previous yields, too).  First
5976     set is empty, unless the line should explode, in which case bracket pairs until
5977     the one that needs to explode are omitted.
5978     """
5979
5980     omit: Set[LeafID] = set()
5981     if not line.magic_trailing_comma:
5982         yield omit
5983
5984     length = 4 * line.depth
5985     opening_bracket: Optional[Leaf] = None
5986     closing_bracket: Optional[Leaf] = None
5987     inner_brackets: Set[LeafID] = set()
5988     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5989         length += leaf_length
5990         if length > line_length:
5991             break
5992
5993         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5994         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5995             break
5996
5997         if opening_bracket:
5998             if leaf is opening_bracket:
5999                 opening_bracket = None
6000             elif leaf.type in CLOSING_BRACKETS:
6001                 prev = line.leaves[index - 1] if index > 0 else None
6002                 if (
6003                     prev
6004                     and prev.type == token.COMMA
6005                     and not is_one_tuple_between(
6006                         leaf.opening_bracket, leaf, line.leaves
6007                     )
6008                 ):
6009                     # Never omit bracket pairs with trailing commas.
6010                     # We need to explode on those.
6011                     break
6012
6013                 inner_brackets.add(id(leaf))
6014         elif leaf.type in CLOSING_BRACKETS:
6015             prev = line.leaves[index - 1] if index > 0 else None
6016             if prev and prev.type in OPENING_BRACKETS:
6017                 # Empty brackets would fail a split so treat them as "inner"
6018                 # brackets (e.g. only add them to the `omit` set if another
6019                 # pair of brackets was good enough.
6020                 inner_brackets.add(id(leaf))
6021                 continue
6022
6023             if closing_bracket:
6024                 omit.add(id(closing_bracket))
6025                 omit.update(inner_brackets)
6026                 inner_brackets.clear()
6027                 yield omit
6028
6029             if (
6030                 prev
6031                 and prev.type == token.COMMA
6032                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
6033             ):
6034                 # Never omit bracket pairs with trailing commas.
6035                 # We need to explode on those.
6036                 break
6037
6038             if leaf.value:
6039                 opening_bracket = leaf.opening_bracket
6040                 closing_bracket = leaf
6041
6042
6043 def get_future_imports(node: Node) -> Set[str]:
6044     """Return a set of __future__ imports in the file."""
6045     imports: Set[str] = set()
6046
6047     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
6048         for child in children:
6049             if isinstance(child, Leaf):
6050                 if child.type == token.NAME:
6051                     yield child.value
6052
6053             elif child.type == syms.import_as_name:
6054                 orig_name = child.children[0]
6055                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
6056                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
6057                 yield orig_name.value
6058
6059             elif child.type == syms.import_as_names:
6060                 yield from get_imports_from_children(child.children)
6061
6062             else:
6063                 raise AssertionError("Invalid syntax parsing imports")
6064
6065     for child in node.children:
6066         if child.type != syms.simple_stmt:
6067             break
6068
6069         first_child = child.children[0]
6070         if isinstance(first_child, Leaf):
6071             # Continue looking if we see a docstring; otherwise stop.
6072             if (
6073                 len(child.children) == 2
6074                 and first_child.type == token.STRING
6075                 and child.children[1].type == token.NEWLINE
6076             ):
6077                 continue
6078
6079             break
6080
6081         elif first_child.type == syms.import_from:
6082             module_name = first_child.children[1]
6083             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
6084                 break
6085
6086             imports |= set(get_imports_from_children(first_child.children[3:]))
6087         else:
6088             break
6089
6090     return imports
6091
6092
6093 @lru_cache()
6094 def get_gitignore(root: Path) -> PathSpec:
6095     """ Return a PathSpec matching gitignore content if present."""
6096     gitignore = root / ".gitignore"
6097     lines: List[str] = []
6098     if gitignore.is_file():
6099         with gitignore.open() as gf:
6100             lines = gf.readlines()
6101     return PathSpec.from_lines("gitwildmatch", lines)
6102
6103
6104 def normalize_path_maybe_ignore(
6105     path: Path, root: Path, report: "Report"
6106 ) -> Optional[str]:
6107     """Normalize `path`. May return `None` if `path` was ignored.
6108
6109     `report` is where "path ignored" output goes.
6110     """
6111     try:
6112         abspath = path if path.is_absolute() else Path.cwd() / path
6113         normalized_path = abspath.resolve().relative_to(root).as_posix()
6114     except OSError as e:
6115         report.path_ignored(path, f"cannot be read because {e}")
6116         return None
6117
6118     except ValueError:
6119         if path.is_symlink():
6120             report.path_ignored(path, f"is a symbolic link that points outside {root}")
6121             return None
6122
6123         raise
6124
6125     return normalized_path
6126
6127
6128 def path_is_excluded(
6129     normalized_path: str,
6130     pattern: Optional[Pattern[str]],
6131 ) -> bool:
6132     match = pattern.search(normalized_path) if pattern else None
6133     return bool(match and match.group(0))
6134
6135
6136 def gen_python_files(
6137     paths: Iterable[Path],
6138     root: Path,
6139     include: Optional[Pattern[str]],
6140     exclude: Pattern[str],
6141     extend_exclude: Optional[Pattern[str]],
6142     force_exclude: Optional[Pattern[str]],
6143     report: "Report",
6144     gitignore: PathSpec,
6145 ) -> Iterator[Path]:
6146     """Generate all files under `path` whose paths are not excluded by the
6147     `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
6148     but are included by the `include` regex.
6149
6150     Symbolic links pointing outside of the `root` directory are ignored.
6151
6152     `report` is where output about exclusions goes.
6153     """
6154     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
6155     for child in paths:
6156         normalized_path = normalize_path_maybe_ignore(child, root, report)
6157         if normalized_path is None:
6158             continue
6159
6160         # First ignore files matching .gitignore
6161         if gitignore.match_file(normalized_path):
6162             report.path_ignored(child, "matches the .gitignore file content")
6163             continue
6164
6165         # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
6166         normalized_path = "/" + normalized_path
6167         if child.is_dir():
6168             normalized_path += "/"
6169
6170         if path_is_excluded(normalized_path, exclude):
6171             report.path_ignored(child, "matches the --exclude regular expression")
6172             continue
6173
6174         if path_is_excluded(normalized_path, extend_exclude):
6175             report.path_ignored(
6176                 child, "matches the --extend-exclude regular expression"
6177             )
6178             continue
6179
6180         if path_is_excluded(normalized_path, force_exclude):
6181             report.path_ignored(child, "matches the --force-exclude regular expression")
6182             continue
6183
6184         if child.is_dir():
6185             yield from gen_python_files(
6186                 child.iterdir(),
6187                 root,
6188                 include,
6189                 exclude,
6190                 extend_exclude,
6191                 force_exclude,
6192                 report,
6193                 gitignore,
6194             )
6195
6196         elif child.is_file():
6197             include_match = include.search(normalized_path) if include else True
6198             if include_match:
6199                 yield child
6200
6201
6202 @lru_cache()
6203 def find_project_root(srcs: Iterable[str]) -> Path:
6204     """Return a directory containing .git, .hg, or pyproject.toml.
6205
6206     That directory will be a common parent of all files and directories
6207     passed in `srcs`.
6208
6209     If no directory in the tree contains a marker that would specify it's the
6210     project root, the root of the file system is returned.
6211     """
6212     if not srcs:
6213         return Path("/").resolve()
6214
6215     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6216
6217     # A list of lists of parents for each 'src'. 'src' is included as a
6218     # "parent" of itself if it is a directory
6219     src_parents = [
6220         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6221     ]
6222
6223     common_base = max(
6224         set.intersection(*(set(parents) for parents in src_parents)),
6225         key=lambda path: path.parts,
6226     )
6227
6228     for directory in (common_base, *common_base.parents):
6229         if (directory / ".git").exists():
6230             return directory
6231
6232         if (directory / ".hg").is_dir():
6233             return directory
6234
6235         if (directory / "pyproject.toml").is_file():
6236             return directory
6237
6238     return directory
6239
6240
6241 @dataclass
6242 class Report:
6243     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6244
6245     check: bool = False
6246     diff: bool = False
6247     quiet: bool = False
6248     verbose: bool = False
6249     change_count: int = 0
6250     same_count: int = 0
6251     failure_count: int = 0
6252
6253     def done(self, src: Path, changed: Changed) -> None:
6254         """Increment the counter for successful reformatting. Write out a message."""
6255         if changed is Changed.YES:
6256             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6257             if self.verbose or not self.quiet:
6258                 out(f"{reformatted} {src}")
6259             self.change_count += 1
6260         else:
6261             if self.verbose:
6262                 if changed is Changed.NO:
6263                     msg = f"{src} already well formatted, good job."
6264                 else:
6265                     msg = f"{src} wasn't modified on disk since last run."
6266                 out(msg, bold=False)
6267             self.same_count += 1
6268
6269     def failed(self, src: Path, message: str) -> None:
6270         """Increment the counter for failed reformatting. Write out a message."""
6271         err(f"error: cannot format {src}: {message}")
6272         self.failure_count += 1
6273
6274     def path_ignored(self, path: Path, message: str) -> None:
6275         if self.verbose:
6276             out(f"{path} ignored: {message}", bold=False)
6277
6278     @property
6279     def return_code(self) -> int:
6280         """Return the exit code that the app should use.
6281
6282         This considers the current state of changed files and failures:
6283         - if there were any failures, return 123;
6284         - if any files were changed and --check is being used, return 1;
6285         - otherwise return 0.
6286         """
6287         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6288         # 126 we have special return codes reserved by the shell.
6289         if self.failure_count:
6290             return 123
6291
6292         elif self.change_count and self.check:
6293             return 1
6294
6295         return 0
6296
6297     def __str__(self) -> str:
6298         """Render a color report of the current state.
6299
6300         Use `click.unstyle` to remove colors.
6301         """
6302         if self.check or self.diff:
6303             reformatted = "would be reformatted"
6304             unchanged = "would be left unchanged"
6305             failed = "would fail to reformat"
6306         else:
6307             reformatted = "reformatted"
6308             unchanged = "left unchanged"
6309             failed = "failed to reformat"
6310         report = []
6311         if self.change_count:
6312             s = "s" if self.change_count > 1 else ""
6313             report.append(
6314                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6315             )
6316         if self.same_count:
6317             s = "s" if self.same_count > 1 else ""
6318             report.append(f"{self.same_count} file{s} {unchanged}")
6319         if self.failure_count:
6320             s = "s" if self.failure_count > 1 else ""
6321             report.append(
6322                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6323             )
6324         return ", ".join(report) + "."
6325
6326
6327 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6328     filename = "<unknown>"
6329     if sys.version_info >= (3, 8):
6330         # TODO: support Python 4+ ;)
6331         for minor_version in range(sys.version_info[1], 4, -1):
6332             try:
6333                 return ast.parse(src, filename, feature_version=(3, minor_version))
6334             except SyntaxError:
6335                 continue
6336     else:
6337         for feature_version in (7, 6):
6338             try:
6339                 return ast3.parse(src, filename, feature_version=feature_version)
6340             except SyntaxError:
6341                 continue
6342
6343     return ast27.parse(src)
6344
6345
6346 def _fixup_ast_constants(
6347     node: Union[ast.AST, ast3.AST, ast27.AST]
6348 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6349     """Map ast nodes deprecated in 3.8 to Constant."""
6350     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6351         return ast.Constant(value=node.s)
6352
6353     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6354         return ast.Constant(value=node.n)
6355
6356     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6357         return ast.Constant(value=node.value)
6358
6359     return node
6360
6361
6362 def _stringify_ast(
6363     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6364 ) -> Iterator[str]:
6365     """Simple visitor generating strings to compare ASTs by content."""
6366
6367     node = _fixup_ast_constants(node)
6368
6369     yield f"{'  ' * depth}{node.__class__.__name__}("
6370
6371     for field in sorted(node._fields):  # noqa: F402
6372         # TypeIgnore has only one field 'lineno' which breaks this comparison
6373         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6374         if sys.version_info >= (3, 8):
6375             type_ignore_classes += (ast.TypeIgnore,)
6376         if isinstance(node, type_ignore_classes):
6377             break
6378
6379         try:
6380             value = getattr(node, field)
6381         except AttributeError:
6382             continue
6383
6384         yield f"{'  ' * (depth+1)}{field}="
6385
6386         if isinstance(value, list):
6387             for item in value:
6388                 # Ignore nested tuples within del statements, because we may insert
6389                 # parentheses and they change the AST.
6390                 if (
6391                     field == "targets"
6392                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6393                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6394                 ):
6395                     for item in item.elts:
6396                         yield from _stringify_ast(item, depth + 2)
6397
6398                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6399                     yield from _stringify_ast(item, depth + 2)
6400
6401         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6402             yield from _stringify_ast(value, depth + 2)
6403
6404         else:
6405             # Constant strings may be indented across newlines, if they are
6406             # docstrings; fold spaces after newlines when comparing. Similarly,
6407             # trailing and leading space may be removed.
6408             if (
6409                 isinstance(node, ast.Constant)
6410                 and field == "value"
6411                 and isinstance(value, str)
6412             ):
6413                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6414             else:
6415                 normalized = value
6416             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6417
6418     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6419
6420
6421 def assert_equivalent(src: str, dst: str) -> None:
6422     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6423     try:
6424         src_ast = parse_ast(src)
6425     except Exception as exc:
6426         raise AssertionError(
6427             "cannot use --safe with this file; failed to parse source file.  AST"
6428             f" error message: {exc}"
6429         )
6430
6431     try:
6432         dst_ast = parse_ast(dst)
6433     except Exception as exc:
6434         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6435         raise AssertionError(
6436             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6437             " on https://github.com/psf/black/issues.  This invalid output might be"
6438             f" helpful: {log}"
6439         ) from None
6440
6441     src_ast_str = "\n".join(_stringify_ast(src_ast))
6442     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6443     if src_ast_str != dst_ast_str:
6444         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6445         raise AssertionError(
6446             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6447             " source.  Please report a bug on https://github.com/psf/black/issues. "
6448             f" This diff might be helpful: {log}"
6449         ) from None
6450
6451
6452 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6453     """Raise AssertionError if `dst` reformats differently the second time."""
6454     newdst = format_str(dst, mode=mode)
6455     if dst != newdst:
6456         log = dump_to_file(
6457             str(mode),
6458             diff(src, dst, "source", "first pass"),
6459             diff(dst, newdst, "first pass", "second pass"),
6460         )
6461         raise AssertionError(
6462             "INTERNAL ERROR: Black produced different code on the second pass of the"
6463             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6464             f"  This diff might be helpful: {log}"
6465         ) from None
6466
6467
6468 @mypyc_attr(patchable=True)
6469 def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
6470     """Dump `output` to a temporary file. Return path to the file."""
6471     with tempfile.NamedTemporaryFile(
6472         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6473     ) as f:
6474         for lines in output:
6475             f.write(lines)
6476             if ensure_final_newline and lines and lines[-1] != "\n":
6477                 f.write("\n")
6478     return f.name
6479
6480
6481 @contextmanager
6482 def nullcontext() -> Iterator[None]:
6483     """Return an empty context manager.
6484
6485     To be used like `nullcontext` in Python 3.7.
6486     """
6487     yield
6488
6489
6490 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6491     """Return a unified diff string between strings `a` and `b`."""
6492     import difflib
6493
6494     a_lines = [line for line in a.splitlines(keepends=True)]
6495     b_lines = [line for line in b.splitlines(keepends=True)]
6496     diff_lines = []
6497     for line in difflib.unified_diff(
6498         a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
6499     ):
6500         # Work around https://bugs.python.org/issue2142
6501         # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
6502         if line[-1] == "\n":
6503             diff_lines.append(line)
6504         else:
6505             diff_lines.append(line + "\n")
6506             diff_lines.append("\\ No newline at end of file\n")
6507     return "".join(diff_lines)
6508
6509
6510 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6511     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6512     err("Aborted!")
6513     for task in tasks:
6514         task.cancel()
6515
6516
6517 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6518     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6519     try:
6520         if sys.version_info[:2] >= (3, 7):
6521             all_tasks = asyncio.all_tasks
6522         else:
6523             all_tasks = asyncio.Task.all_tasks
6524         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6525         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6526         if not to_cancel:
6527             return
6528
6529         for task in to_cancel:
6530             task.cancel()
6531         loop.run_until_complete(
6532             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6533         )
6534     finally:
6535         # `concurrent.futures.Future` objects cannot be cancelled once they
6536         # are already running. There might be some when the `shutdown()` happened.
6537         # Silence their logger's spew about the event loop being closed.
6538         cf_logger = logging.getLogger("concurrent.futures")
6539         cf_logger.setLevel(logging.CRITICAL)
6540         loop.close()
6541
6542
6543 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6544     """Replace `regex` with `replacement` twice on `original`.
6545
6546     This is used by string normalization to perform replaces on
6547     overlapping matches.
6548     """
6549     return regex.sub(replacement, regex.sub(replacement, original))
6550
6551
6552 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6553     """Compile a regular expression string in `regex`.
6554
6555     If it contains newlines, use verbose mode.
6556     """
6557     if "\n" in regex:
6558         regex = "(?x)" + regex
6559     compiled: Pattern[str] = re.compile(regex)
6560     return compiled
6561
6562
6563 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6564     """Like `reversed(enumerate(sequence))` if that were possible."""
6565     index = len(sequence) - 1
6566     for element in reversed(sequence):
6567         yield (index, element)
6568         index -= 1
6569
6570
6571 def enumerate_with_length(
6572     line: Line, reversed: bool = False
6573 ) -> Iterator[Tuple[Index, Leaf, int]]:
6574     """Return an enumeration of leaves with their length.
6575
6576     Stops prematurely on multiline strings and standalone comments.
6577     """
6578     op = cast(
6579         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6580         enumerate_reversed if reversed else enumerate,
6581     )
6582     for index, leaf in op(line.leaves):
6583         length = len(leaf.prefix) + len(leaf.value)
6584         if "\n" in leaf.value:
6585             return  # Multiline strings, we can't continue.
6586
6587         for comment in line.comments_after(leaf):
6588             length += len(comment.value)
6589
6590         yield index, leaf, length
6591
6592
6593 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6594     """Return True if `line` is no longer than `line_length`.
6595
6596     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6597     """
6598     if not line_str:
6599         line_str = line_to_string(line)
6600     return (
6601         len(line_str) <= line_length
6602         and "\n" not in line_str  # multiline strings
6603         and not line.contains_standalone_comments()
6604     )
6605
6606
6607 def can_be_split(line: Line) -> bool:
6608     """Return False if the line cannot be split *for sure*.
6609
6610     This is not an exhaustive search but a cheap heuristic that we can use to
6611     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6612     in unnecessary parentheses).
6613     """
6614     leaves = line.leaves
6615     if len(leaves) < 2:
6616         return False
6617
6618     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6619         call_count = 0
6620         dot_count = 0
6621         next = leaves[-1]
6622         for leaf in leaves[-2::-1]:
6623             if leaf.type in OPENING_BRACKETS:
6624                 if next.type not in CLOSING_BRACKETS:
6625                     return False
6626
6627                 call_count += 1
6628             elif leaf.type == token.DOT:
6629                 dot_count += 1
6630             elif leaf.type == token.NAME:
6631                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6632                     return False
6633
6634             elif leaf.type not in CLOSING_BRACKETS:
6635                 return False
6636
6637             if dot_count > 1 and call_count > 1:
6638                 return False
6639
6640     return True
6641
6642
6643 def can_omit_invisible_parens(
6644     line: Line,
6645     line_length: int,
6646     omit_on_explode: Collection[LeafID] = (),
6647 ) -> bool:
6648     """Does `line` have a shape safe to reformat without optional parens around it?
6649
6650     Returns True for only a subset of potentially nice looking formattings but
6651     the point is to not return false positives that end up producing lines that
6652     are too long.
6653     """
6654     bt = line.bracket_tracker
6655     if not bt.delimiters:
6656         # Without delimiters the optional parentheses are useless.
6657         return True
6658
6659     max_priority = bt.max_delimiter_priority()
6660     if bt.delimiter_count_with_priority(max_priority) > 1:
6661         # With more than one delimiter of a kind the optional parentheses read better.
6662         return False
6663
6664     if max_priority == DOT_PRIORITY:
6665         # A single stranded method call doesn't require optional parentheses.
6666         return True
6667
6668     assert len(line.leaves) >= 2, "Stranded delimiter"
6669
6670     # With a single delimiter, omit if the expression starts or ends with
6671     # a bracket.
6672     first = line.leaves[0]
6673     second = line.leaves[1]
6674     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6675         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6676             return True
6677
6678         # Note: we are not returning False here because a line might have *both*
6679         # a leading opening bracket and a trailing closing bracket.  If the
6680         # opening bracket doesn't match our rule, maybe the closing will.
6681
6682     penultimate = line.leaves[-2]
6683     last = line.leaves[-1]
6684     if line.magic_trailing_comma:
6685         try:
6686             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6687         except LookupError:
6688             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6689             return False
6690
6691     if (
6692         last.type == token.RPAR
6693         or last.type == token.RBRACE
6694         or (
6695             # don't use indexing for omitting optional parentheses;
6696             # it looks weird
6697             last.type == token.RSQB
6698             and last.parent
6699             and last.parent.type != syms.trailer
6700         )
6701     ):
6702         if penultimate.type in OPENING_BRACKETS:
6703             # Empty brackets don't help.
6704             return False
6705
6706         if is_multiline_string(first):
6707             # Additional wrapping of a multiline string in this situation is
6708             # unnecessary.
6709             return True
6710
6711         if line.magic_trailing_comma and penultimate.type == token.COMMA:
6712             # The rightmost non-omitted bracket pair is the one we want to explode on.
6713             return True
6714
6715         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6716             return True
6717
6718     return False
6719
6720
6721 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6722     """See `can_omit_invisible_parens`."""
6723     remainder = False
6724     length = 4 * line.depth
6725     _index = -1
6726     for _index, leaf, leaf_length in enumerate_with_length(line):
6727         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6728             remainder = True
6729         if remainder:
6730             length += leaf_length
6731             if length > line_length:
6732                 break
6733
6734             if leaf.type in OPENING_BRACKETS:
6735                 # There are brackets we can further split on.
6736                 remainder = False
6737
6738     else:
6739         # checked the entire string and line length wasn't exceeded
6740         if len(line.leaves) == _index + 1:
6741             return True
6742
6743     return False
6744
6745
6746 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6747     """See `can_omit_invisible_parens`."""
6748     length = 4 * line.depth
6749     seen_other_brackets = False
6750     for _index, leaf, leaf_length in enumerate_with_length(line):
6751         length += leaf_length
6752         if leaf is last.opening_bracket:
6753             if seen_other_brackets or length <= line_length:
6754                 return True
6755
6756         elif leaf.type in OPENING_BRACKETS:
6757             # There are brackets we can further split on.
6758             seen_other_brackets = True
6759
6760     return False
6761
6762
6763 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6764     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6765     stop_after = None
6766     last = None
6767     for leaf in reversed(leaves):
6768         if stop_after:
6769             if leaf is stop_after:
6770                 stop_after = None
6771             continue
6772
6773         if last:
6774             return leaf, last
6775
6776         if id(leaf) in omit:
6777             stop_after = leaf.opening_bracket
6778         else:
6779             last = leaf
6780     else:
6781         raise LookupError("Last two leaves were also skipped")
6782
6783
6784 def run_transformer(
6785     line: Line,
6786     transform: Transformer,
6787     mode: Mode,
6788     features: Collection[Feature],
6789     *,
6790     line_str: str = "",
6791 ) -> List[Line]:
6792     if not line_str:
6793         line_str = line_to_string(line)
6794     result: List[Line] = []
6795     for transformed_line in transform(line, features):
6796         if str(transformed_line).strip("\n") == line_str:
6797             raise CannotTransform("Line transformer returned an unchanged result")
6798
6799         result.extend(transform_line(transformed_line, mode=mode, features=features))
6800
6801     if not (
6802         transform.__name__ == "rhs"
6803         and line.bracket_tracker.invisible
6804         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6805         and not line.contains_multiline_strings()
6806         and not result[0].contains_uncollapsable_type_comments()
6807         and not result[0].contains_unsplittable_type_ignore()
6808         and not is_line_short_enough(result[0], line_length=mode.line_length)
6809     ):
6810         return result
6811
6812     line_copy = line.clone()
6813     append_leaves(line_copy, line, line.leaves)
6814     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6815     second_opinion = run_transformer(
6816         line_copy, transform, mode, features_fop, line_str=line_str
6817     )
6818     if all(
6819         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6820     ):
6821         result = second_opinion
6822     return result
6823
6824
6825 def get_cache_file(mode: Mode) -> Path:
6826     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6827
6828
6829 def read_cache(mode: Mode) -> Cache:
6830     """Read the cache if it exists and is well formed.
6831
6832     If it is not well formed, the call to write_cache later should resolve the issue.
6833     """
6834     cache_file = get_cache_file(mode)
6835     if not cache_file.exists():
6836         return {}
6837
6838     with cache_file.open("rb") as fobj:
6839         try:
6840             cache: Cache = pickle.load(fobj)
6841         except (pickle.UnpicklingError, ValueError):
6842             return {}
6843
6844     return cache
6845
6846
6847 def get_cache_info(path: Path) -> CacheInfo:
6848     """Return the information used to check if a file is already formatted or not."""
6849     stat = path.stat()
6850     return stat.st_mtime, stat.st_size
6851
6852
6853 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6854     """Split an iterable of paths in `sources` into two sets.
6855
6856     The first contains paths of files that modified on disk or are not in the
6857     cache. The other contains paths to non-modified files.
6858     """
6859     todo, done = set(), set()
6860     for src in sources:
6861         res_src = src.resolve()
6862         if cache.get(str(res_src)) != get_cache_info(res_src):
6863             todo.add(src)
6864         else:
6865             done.add(src)
6866     return todo, done
6867
6868
6869 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6870     """Update the cache file."""
6871     cache_file = get_cache_file(mode)
6872     try:
6873         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6874         new_cache = {
6875             **cache,
6876             **{str(src.resolve()): get_cache_info(src) for src in sources},
6877         }
6878         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6879             pickle.dump(new_cache, f, protocol=4)
6880         os.replace(f.name, cache_file)
6881     except OSError:
6882         pass
6883
6884
6885 def patch_click() -> None:
6886     """Make Click not crash.
6887
6888     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6889     default which restricts paths that it can access during the lifetime of the
6890     application.  Click refuses to work in this scenario by raising a RuntimeError.
6891
6892     In case of Black the likelihood that non-ASCII characters are going to be used in
6893     file paths is minimal since it's Python source code.  Moreover, this crash was
6894     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6895     """
6896     try:
6897         from click import core
6898         from click import _unicodefun  # type: ignore
6899     except ModuleNotFoundError:
6900         return
6901
6902     for module in (core, _unicodefun):
6903         if hasattr(module, "_verify_python3_env"):
6904             module._verify_python3_env = lambda: None
6905
6906
6907 def patched_main() -> None:
6908     freeze_support()
6909     patch_click()
6910     main()
6911
6912
6913 def is_docstring(leaf: Leaf) -> bool:
6914     if not is_multiline_string(leaf):
6915         # For the purposes of docstring re-indentation, we don't need to do anything
6916         # with single-line docstrings.
6917         return False
6918
6919     if prev_siblings_are(
6920         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6921     ):
6922         return True
6923
6924     # Multiline docstring on the same line as the `def`.
6925     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6926         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6927         # grammar. We're safe to return True without further checks.
6928         return True
6929
6930     return False
6931
6932
6933 def lines_with_leading_tabs_expanded(s: str) -> List[str]:
6934     """
6935     Splits string into lines and expands only leading tabs (following the normal
6936     Python rules)
6937     """
6938     lines = []
6939     for line in s.splitlines():
6940         # Find the index of the first non-whitespace character after a string of
6941         # whitespace that includes at least one tab
6942         match = re.match(r"\s*\t+\s*(\S)", line)
6943         if match:
6944             first_non_whitespace_idx = match.start(1)
6945
6946             lines.append(
6947                 line[:first_non_whitespace_idx].expandtabs()
6948                 + line[first_non_whitespace_idx:]
6949             )
6950         else:
6951             lines.append(line)
6952     return lines
6953
6954
6955 def fix_docstring(docstring: str, prefix: str) -> str:
6956     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6957     if not docstring:
6958         return ""
6959     lines = lines_with_leading_tabs_expanded(docstring)
6960     # Determine minimum indentation (first line doesn't count):
6961     indent = sys.maxsize
6962     for line in lines[1:]:
6963         stripped = line.lstrip()
6964         if stripped:
6965             indent = min(indent, len(line) - len(stripped))
6966     # Remove indentation (first line is special):
6967     trimmed = [lines[0].strip()]
6968     if indent < sys.maxsize:
6969         last_line_idx = len(lines) - 2
6970         for i, line in enumerate(lines[1:]):
6971             stripped_line = line[indent:].rstrip()
6972             if stripped_line or i == last_line_idx:
6973                 trimmed.append(prefix + stripped_line)
6974             else:
6975                 trimmed.append("")
6976     return "\n".join(trimmed)
6977
6978
6979 if __name__ == "__main__":
6980     patched_main()