]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

52a57695aefdb0dc2202eda8b745b866187eba0e
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from mypy_extensions import mypyc_attr
46
47 from appdirs import user_cache_dir
48 from dataclasses import dataclass, field, replace
49 import click
50 import toml
51
52 try:
53     from typed_ast import ast3, ast27
54 except ImportError:
55     if sys.version_info < (3, 8):
56         print(
57             "The typed_ast package is not installed.\n"
58             "You can install it with `python3 -m pip install typed-ast`.",
59             file=sys.stderr,
60         )
61         sys.exit(1)
62     else:
63         ast3 = ast27 = ast
64
65 from pathspec import PathSpec
66
67 # lib2to3 fork
68 from blib2to3.pytree import Node, Leaf, type_repr
69 from blib2to3 import pygram, pytree
70 from blib2to3.pgen2 import driver, token
71 from blib2to3.pgen2.grammar import Grammar
72 from blib2to3.pgen2.parse import ParseError
73
74 from _black_version import version as __version__
75
76 if sys.version_info < (3, 8):
77     from typing_extensions import Final
78 else:
79     from typing import Final
80
81 if TYPE_CHECKING:
82     import colorama  # noqa: F401
83
84 DEFAULT_LINE_LENGTH = 88
85 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
86 DEFAULT_INCLUDES = r"\.pyi?$"
87 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
88 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
89
90 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
91
92
93 # types
94 FileContent = str
95 Encoding = str
96 NewLine = str
97 Depth = int
98 NodeType = int
99 ParserState = int
100 LeafID = int
101 StringID = int
102 Priority = int
103 Index = int
104 LN = Union[Leaf, Node]
105 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
106 Timestamp = float
107 FileSize = int
108 CacheInfo = Tuple[Timestamp, FileSize]
109 Cache = Dict[str, CacheInfo]
110 out = partial(click.secho, bold=True, err=True)
111 err = partial(click.secho, fg="red", err=True)
112
113 pygram.initialize(CACHE_DIR)
114 syms = pygram.python_symbols
115
116
117 class NothingChanged(UserWarning):
118     """Raised when reformatted code is the same as source."""
119
120
121 class CannotTransform(Exception):
122     """Base class for errors raised by Transformers."""
123
124
125 class CannotSplit(CannotTransform):
126     """A readable split that fits the allotted line length is impossible."""
127
128
129 class InvalidInput(ValueError):
130     """Raised when input source code fails all parse attempts."""
131
132
133 class BracketMatchError(KeyError):
134     """Raised when an opening bracket is unable to be matched to a closing bracket."""
135
136
137 T = TypeVar("T")
138 E = TypeVar("E", bound=Exception)
139
140
141 class Ok(Generic[T]):
142     def __init__(self, value: T) -> None:
143         self._value = value
144
145     def ok(self) -> T:
146         return self._value
147
148
149 class Err(Generic[E]):
150     def __init__(self, e: E) -> None:
151         self._e = e
152
153     def err(self) -> E:
154         return self._e
155
156
157 # The 'Result' return type is used to implement an error-handling model heavily
158 # influenced by that used by the Rust programming language
159 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
160 Result = Union[Ok[T], Err[E]]
161 TResult = Result[T, CannotTransform]  # (T)ransform Result
162 TMatchResult = TResult[Index]
163
164
165 class WriteBack(Enum):
166     NO = 0
167     YES = 1
168     DIFF = 2
169     CHECK = 3
170     COLOR_DIFF = 4
171
172     @classmethod
173     def from_configuration(
174         cls, *, check: bool, diff: bool, color: bool = False
175     ) -> "WriteBack":
176         if check and not diff:
177             return cls.CHECK
178
179         if diff and color:
180             return cls.COLOR_DIFF
181
182         return cls.DIFF if diff else cls.YES
183
184
185 class Changed(Enum):
186     NO = 0
187     CACHED = 1
188     YES = 2
189
190
191 class TargetVersion(Enum):
192     PY27 = 2
193     PY33 = 3
194     PY34 = 4
195     PY35 = 5
196     PY36 = 6
197     PY37 = 7
198     PY38 = 8
199     PY39 = 9
200
201     def is_python2(self) -> bool:
202         return self is TargetVersion.PY27
203
204
205 class Feature(Enum):
206     # All string literals are unicode
207     UNICODE_LITERALS = 1
208     F_STRINGS = 2
209     NUMERIC_UNDERSCORES = 3
210     TRAILING_COMMA_IN_CALL = 4
211     TRAILING_COMMA_IN_DEF = 5
212     # The following two feature-flags are mutually exclusive, and exactly one should be
213     # set for every version of python.
214     ASYNC_IDENTIFIERS = 6
215     ASYNC_KEYWORDS = 7
216     ASSIGNMENT_EXPRESSIONS = 8
217     POS_ONLY_ARGUMENTS = 9
218     RELAXED_DECORATORS = 10
219     FORCE_OPTIONAL_PARENTHESES = 50
220
221
222 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
223     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
224     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
225     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
226     TargetVersion.PY35: {
227         Feature.UNICODE_LITERALS,
228         Feature.TRAILING_COMMA_IN_CALL,
229         Feature.ASYNC_IDENTIFIERS,
230     },
231     TargetVersion.PY36: {
232         Feature.UNICODE_LITERALS,
233         Feature.F_STRINGS,
234         Feature.NUMERIC_UNDERSCORES,
235         Feature.TRAILING_COMMA_IN_CALL,
236         Feature.TRAILING_COMMA_IN_DEF,
237         Feature.ASYNC_IDENTIFIERS,
238     },
239     TargetVersion.PY37: {
240         Feature.UNICODE_LITERALS,
241         Feature.F_STRINGS,
242         Feature.NUMERIC_UNDERSCORES,
243         Feature.TRAILING_COMMA_IN_CALL,
244         Feature.TRAILING_COMMA_IN_DEF,
245         Feature.ASYNC_KEYWORDS,
246     },
247     TargetVersion.PY38: {
248         Feature.UNICODE_LITERALS,
249         Feature.F_STRINGS,
250         Feature.NUMERIC_UNDERSCORES,
251         Feature.TRAILING_COMMA_IN_CALL,
252         Feature.TRAILING_COMMA_IN_DEF,
253         Feature.ASYNC_KEYWORDS,
254         Feature.ASSIGNMENT_EXPRESSIONS,
255         Feature.POS_ONLY_ARGUMENTS,
256     },
257     TargetVersion.PY39: {
258         Feature.UNICODE_LITERALS,
259         Feature.F_STRINGS,
260         Feature.NUMERIC_UNDERSCORES,
261         Feature.TRAILING_COMMA_IN_CALL,
262         Feature.TRAILING_COMMA_IN_DEF,
263         Feature.ASYNC_KEYWORDS,
264         Feature.ASSIGNMENT_EXPRESSIONS,
265         Feature.RELAXED_DECORATORS,
266         Feature.POS_ONLY_ARGUMENTS,
267     },
268 }
269
270
271 @dataclass
272 class Mode:
273     target_versions: Set[TargetVersion] = field(default_factory=set)
274     line_length: int = DEFAULT_LINE_LENGTH
275     string_normalization: bool = True
276     magic_trailing_comma: bool = True
277     experimental_string_processing: bool = False
278     is_pyi: bool = False
279
280     def get_cache_key(self) -> str:
281         if self.target_versions:
282             version_str = ",".join(
283                 str(version.value)
284                 for version in sorted(self.target_versions, key=lambda v: v.value)
285             )
286         else:
287             version_str = "-"
288         parts = [
289             version_str,
290             str(self.line_length),
291             str(int(self.string_normalization)),
292             str(int(self.is_pyi)),
293         ]
294         return ".".join(parts)
295
296
297 # Legacy name, left for integrations.
298 FileMode = Mode
299
300
301 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
302     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
303
304
305 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
306     """Find the absolute filepath to a pyproject.toml if it exists"""
307     path_project_root = find_project_root(path_search_start)
308     path_pyproject_toml = path_project_root / "pyproject.toml"
309     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
310
311
312 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
313     """Parse a pyproject toml file, pulling out relevant parts for Black
314
315     If parsing fails, will raise a toml.TomlDecodeError
316     """
317     pyproject_toml = toml.load(path_config)
318     config = pyproject_toml.get("tool", {}).get("black", {})
319     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
320
321
322 def read_pyproject_toml(
323     ctx: click.Context, param: click.Parameter, value: Optional[str]
324 ) -> Optional[str]:
325     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
326
327     Returns the path to a successfully found and read configuration file, None
328     otherwise.
329     """
330     if not value:
331         value = find_pyproject_toml(ctx.params.get("src", ()))
332         if value is None:
333             return None
334
335     try:
336         config = parse_pyproject_toml(value)
337     except (toml.TomlDecodeError, OSError) as e:
338         raise click.FileError(
339             filename=value, hint=f"Error reading configuration file: {e}"
340         )
341
342     if not config:
343         return None
344     else:
345         # Sanitize the values to be Click friendly. For more information please see:
346         # https://github.com/psf/black/issues/1458
347         # https://github.com/pallets/click/issues/1567
348         config = {
349             k: str(v) if not isinstance(v, (list, dict)) else v
350             for k, v in config.items()
351         }
352
353     target_version = config.get("target_version")
354     if target_version is not None and not isinstance(target_version, list):
355         raise click.BadOptionUsage(
356             "target-version", "Config key target-version must be a list"
357         )
358
359     default_map: Dict[str, Any] = {}
360     if ctx.default_map:
361         default_map.update(ctx.default_map)
362     default_map.update(config)
363
364     ctx.default_map = default_map
365     return value
366
367
368 def target_version_option_callback(
369     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
370 ) -> List[TargetVersion]:
371     """Compute the target versions from a --target-version flag.
372
373     This is its own function because mypy couldn't infer the type correctly
374     when it was a lambda, causing mypyc trouble.
375     """
376     return [TargetVersion[val.upper()] for val in v]
377
378
379 def validate_regex(
380     ctx: click.Context,
381     param: click.Parameter,
382     value: Optional[str],
383 ) -> Optional[Pattern]:
384     try:
385         return re_compile_maybe_verbose(value) if value is not None else None
386     except re.error:
387         raise click.BadParameter("Not a valid regular expression")
388
389
390 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
391 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
392 @click.option(
393     "-l",
394     "--line-length",
395     type=int,
396     default=DEFAULT_LINE_LENGTH,
397     help="How many characters per line to allow.",
398     show_default=True,
399 )
400 @click.option(
401     "-t",
402     "--target-version",
403     type=click.Choice([v.name.lower() for v in TargetVersion]),
404     callback=target_version_option_callback,
405     multiple=True,
406     help=(
407         "Python versions that should be supported by Black's output. [default: per-file"
408         " auto-detection]"
409     ),
410 )
411 @click.option(
412     "--pyi",
413     is_flag=True,
414     help=(
415         "Format all input files like typing stubs regardless of file extension (useful"
416         " when piping source on standard input)."
417     ),
418 )
419 @click.option(
420     "-S",
421     "--skip-string-normalization",
422     is_flag=True,
423     help="Don't normalize string quotes or prefixes.",
424 )
425 @click.option(
426     "-C",
427     "--skip-magic-trailing-comma",
428     is_flag=True,
429     help="Don't use trailing commas as a reason to split lines.",
430 )
431 @click.option(
432     "--experimental-string-processing",
433     is_flag=True,
434     hidden=True,
435     help=(
436         "Experimental option that performs more normalization on string literals."
437         " Currently disabled because it leads to some crashes."
438     ),
439 )
440 @click.option(
441     "--check",
442     is_flag=True,
443     help=(
444         "Don't write the files back, just return the status.  Return code 0 means"
445         " nothing would change.  Return code 1 means some files would be reformatted."
446         " Return code 123 means there was an internal error."
447     ),
448 )
449 @click.option(
450     "--diff",
451     is_flag=True,
452     help="Don't write the files back, just output a diff for each file on stdout.",
453 )
454 @click.option(
455     "--color/--no-color",
456     is_flag=True,
457     help="Show colored diff. Only applies when `--diff` is given.",
458 )
459 @click.option(
460     "--fast/--safe",
461     is_flag=True,
462     help="If --fast given, skip temporary sanity checks. [default: --safe]",
463 )
464 @click.option(
465     "--include",
466     type=str,
467     default=DEFAULT_INCLUDES,
468     callback=validate_regex,
469     help=(
470         "A regular expression that matches files and directories that should be"
471         " included on recursive searches.  An empty value means all files are included"
472         " regardless of the name.  Use forward slashes for directories on all platforms"
473         " (Windows, too).  Exclusions are calculated first, inclusions later."
474     ),
475     show_default=True,
476 )
477 @click.option(
478     "--exclude",
479     type=str,
480     default=DEFAULT_EXCLUDES,
481     callback=validate_regex,
482     help=(
483         "A regular expression that matches files and directories that should be"
484         " excluded on recursive searches.  An empty value means no paths are excluded."
485         " Use forward slashes for directories on all platforms (Windows, too). "
486         " Exclusions are calculated first, inclusions later."
487     ),
488     show_default=True,
489 )
490 @click.option(
491     "--extend-exclude",
492     type=str,
493     callback=validate_regex,
494     help=(
495         "Like --exclude, but adds additional files and directories on top of the"
496         " excluded ones. (Useful if you simply want to add to the default)"
497     ),
498 )
499 @click.option(
500     "--force-exclude",
501     type=str,
502     callback=validate_regex,
503     help=(
504         "Like --exclude, but files and directories matching this regex will be "
505         "excluded even when they are passed explicitly as arguments."
506     ),
507 )
508 @click.option(
509     "--stdin-filename",
510     type=str,
511     help=(
512         "The name of the file when passing it through stdin. Useful to make "
513         "sure Black will respect --force-exclude option on some "
514         "editors that rely on using stdin."
515     ),
516 )
517 @click.option(
518     "-q",
519     "--quiet",
520     is_flag=True,
521     help=(
522         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
523         " those with 2>/dev/null."
524     ),
525 )
526 @click.option(
527     "-v",
528     "--verbose",
529     is_flag=True,
530     help=(
531         "Also emit messages to stderr about files that were not changed or were ignored"
532         " due to exclusion patterns."
533     ),
534 )
535 @click.version_option(version=__version__)
536 @click.argument(
537     "src",
538     nargs=-1,
539     type=click.Path(
540         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
541     ),
542     is_eager=True,
543 )
544 @click.option(
545     "--config",
546     type=click.Path(
547         exists=True,
548         file_okay=True,
549         dir_okay=False,
550         readable=True,
551         allow_dash=False,
552         path_type=str,
553     ),
554     is_eager=True,
555     callback=read_pyproject_toml,
556     help="Read configuration from FILE path.",
557 )
558 @click.pass_context
559 def main(
560     ctx: click.Context,
561     code: Optional[str],
562     line_length: int,
563     target_version: List[TargetVersion],
564     check: bool,
565     diff: bool,
566     color: bool,
567     fast: bool,
568     pyi: bool,
569     skip_string_normalization: bool,
570     skip_magic_trailing_comma: bool,
571     experimental_string_processing: bool,
572     quiet: bool,
573     verbose: bool,
574     include: Pattern,
575     exclude: Pattern,
576     extend_exclude: Optional[Pattern],
577     force_exclude: Optional[Pattern],
578     stdin_filename: Optional[str],
579     src: Tuple[str, ...],
580     config: Optional[str],
581 ) -> None:
582     """The uncompromising code formatter."""
583     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
584     if target_version:
585         versions = set(target_version)
586     else:
587         # We'll autodetect later.
588         versions = set()
589     mode = Mode(
590         target_versions=versions,
591         line_length=line_length,
592         is_pyi=pyi,
593         string_normalization=not skip_string_normalization,
594         magic_trailing_comma=not skip_magic_trailing_comma,
595         experimental_string_processing=experimental_string_processing,
596     )
597     if config and verbose:
598         out(f"Using configuration from {config}.", bold=False, fg="blue")
599     if code is not None:
600         print(format_str(code, mode=mode))
601         ctx.exit(0)
602     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
603     sources = get_sources(
604         ctx=ctx,
605         src=src,
606         quiet=quiet,
607         verbose=verbose,
608         include=include,
609         exclude=exclude,
610         extend_exclude=extend_exclude,
611         force_exclude=force_exclude,
612         report=report,
613         stdin_filename=stdin_filename,
614     )
615
616     path_empty(
617         sources,
618         "No Python files are present to be formatted. Nothing to do 😴",
619         quiet,
620         verbose,
621         ctx,
622     )
623
624     if len(sources) == 1:
625         reformat_one(
626             src=sources.pop(),
627             fast=fast,
628             write_back=write_back,
629             mode=mode,
630             report=report,
631         )
632     else:
633         reformat_many(
634             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
635         )
636
637     if verbose or not quiet:
638         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
639         click.secho(str(report), err=True)
640     ctx.exit(report.return_code)
641
642
643 def get_sources(
644     *,
645     ctx: click.Context,
646     src: Tuple[str, ...],
647     quiet: bool,
648     verbose: bool,
649     include: Pattern[str],
650     exclude: Pattern[str],
651     extend_exclude: Optional[Pattern[str]],
652     force_exclude: Optional[Pattern[str]],
653     report: "Report",
654     stdin_filename: Optional[str],
655 ) -> Set[Path]:
656     """Compute the set of files to be formatted."""
657
658     root = find_project_root(src)
659     sources: Set[Path] = set()
660     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
661     gitignore = get_gitignore(root)
662
663     for s in src:
664         if s == "-" and stdin_filename:
665             p = Path(stdin_filename)
666             is_stdin = True
667         else:
668             p = Path(s)
669             is_stdin = False
670
671         if is_stdin or p.is_file():
672             normalized_path = normalize_path_maybe_ignore(p, root, report)
673             if normalized_path is None:
674                 continue
675
676             normalized_path = "/" + normalized_path
677             # Hard-exclude any files that matches the `--force-exclude` regex.
678             if force_exclude:
679                 force_exclude_match = force_exclude.search(normalized_path)
680             else:
681                 force_exclude_match = None
682             if force_exclude_match and force_exclude_match.group(0):
683                 report.path_ignored(p, "matches the --force-exclude regular expression")
684                 continue
685
686             if is_stdin:
687                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
688
689             sources.add(p)
690         elif p.is_dir():
691             sources.update(
692                 gen_python_files(
693                     p.iterdir(),
694                     root,
695                     include,
696                     exclude,
697                     extend_exclude,
698                     force_exclude,
699                     report,
700                     gitignore,
701                 )
702             )
703         elif s == "-":
704             sources.add(p)
705         else:
706             err(f"invalid path: {s}")
707     return sources
708
709
710 def path_empty(
711     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
712 ) -> None:
713     """
714     Exit if there is no `src` provided for formatting
715     """
716     if not src and (verbose or not quiet):
717         out(msg)
718         ctx.exit(0)
719
720
721 def reformat_one(
722     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
723 ) -> None:
724     """Reformat a single file under `src` without spawning child processes.
725
726     `fast`, `write_back`, and `mode` options are passed to
727     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
728     """
729     try:
730         changed = Changed.NO
731
732         if str(src) == "-":
733             is_stdin = True
734         elif str(src).startswith(STDIN_PLACEHOLDER):
735             is_stdin = True
736             # Use the original name again in case we want to print something
737             # to the user
738             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
739         else:
740             is_stdin = False
741
742         if is_stdin:
743             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
744                 changed = Changed.YES
745         else:
746             cache: Cache = {}
747             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
748                 cache = read_cache(mode)
749                 res_src = src.resolve()
750                 res_src_s = str(res_src)
751                 if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
752                     changed = Changed.CACHED
753             if changed is not Changed.CACHED and format_file_in_place(
754                 src, fast=fast, write_back=write_back, mode=mode
755             ):
756                 changed = Changed.YES
757             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
758                 write_back is WriteBack.CHECK and changed is Changed.NO
759             ):
760                 write_cache(cache, [src], mode)
761         report.done(src, changed)
762     except Exception as exc:
763         if report.verbose:
764             traceback.print_exc()
765         report.failed(src, str(exc))
766
767
768 def reformat_many(
769     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
770 ) -> None:
771     """Reformat multiple files using a ProcessPoolExecutor."""
772     executor: Executor
773     loop = asyncio.get_event_loop()
774     worker_count = os.cpu_count()
775     if sys.platform == "win32":
776         # Work around https://bugs.python.org/issue26903
777         worker_count = min(worker_count, 60)
778     try:
779         executor = ProcessPoolExecutor(max_workers=worker_count)
780     except (ImportError, OSError):
781         # we arrive here if the underlying system does not support multi-processing
782         # like in AWS Lambda or Termux, in which case we gracefully fallback to
783         # a ThreadPollExecutor with just a single worker (more workers would not do us
784         # any good due to the Global Interpreter Lock)
785         executor = ThreadPoolExecutor(max_workers=1)
786
787     try:
788         loop.run_until_complete(
789             schedule_formatting(
790                 sources=sources,
791                 fast=fast,
792                 write_back=write_back,
793                 mode=mode,
794                 report=report,
795                 loop=loop,
796                 executor=executor,
797             )
798         )
799     finally:
800         shutdown(loop)
801         if executor is not None:
802             executor.shutdown()
803
804
805 async def schedule_formatting(
806     sources: Set[Path],
807     fast: bool,
808     write_back: WriteBack,
809     mode: Mode,
810     report: "Report",
811     loop: asyncio.AbstractEventLoop,
812     executor: Executor,
813 ) -> None:
814     """Run formatting of `sources` in parallel using the provided `executor`.
815
816     (Use ProcessPoolExecutors for actual parallelism.)
817
818     `write_back`, `fast`, and `mode` options are passed to
819     :func:`format_file_in_place`.
820     """
821     cache: Cache = {}
822     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
823         cache = read_cache(mode)
824         sources, cached = filter_cached(cache, sources)
825         for src in sorted(cached):
826             report.done(src, Changed.CACHED)
827     if not sources:
828         return
829
830     cancelled = []
831     sources_to_cache = []
832     lock = None
833     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
834         # For diff output, we need locks to ensure we don't interleave output
835         # from different processes.
836         manager = Manager()
837         lock = manager.Lock()
838     tasks = {
839         asyncio.ensure_future(
840             loop.run_in_executor(
841                 executor, format_file_in_place, src, fast, mode, write_back, lock
842             )
843         ): src
844         for src in sorted(sources)
845     }
846     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
847     try:
848         loop.add_signal_handler(signal.SIGINT, cancel, pending)
849         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
850     except NotImplementedError:
851         # There are no good alternatives for these on Windows.
852         pass
853     while pending:
854         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
855         for task in done:
856             src = tasks.pop(task)
857             if task.cancelled():
858                 cancelled.append(task)
859             elif task.exception():
860                 report.failed(src, str(task.exception()))
861             else:
862                 changed = Changed.YES if task.result() else Changed.NO
863                 # If the file was written back or was successfully checked as
864                 # well-formatted, store this information in the cache.
865                 if write_back is WriteBack.YES or (
866                     write_back is WriteBack.CHECK and changed is Changed.NO
867                 ):
868                     sources_to_cache.append(src)
869                 report.done(src, changed)
870     if cancelled:
871         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
872     if sources_to_cache:
873         write_cache(cache, sources_to_cache, mode)
874
875
876 def format_file_in_place(
877     src: Path,
878     fast: bool,
879     mode: Mode,
880     write_back: WriteBack = WriteBack.NO,
881     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
882 ) -> bool:
883     """Format file under `src` path. Return True if changed.
884
885     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
886     code to the file.
887     `mode` and `fast` options are passed to :func:`format_file_contents`.
888     """
889     if src.suffix == ".pyi":
890         mode = replace(mode, is_pyi=True)
891
892     then = datetime.utcfromtimestamp(src.stat().st_mtime)
893     with open(src, "rb") as buf:
894         src_contents, encoding, newline = decode_bytes(buf.read())
895     try:
896         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
897     except NothingChanged:
898         return False
899
900     if write_back == WriteBack.YES:
901         with open(src, "w", encoding=encoding, newline=newline) as f:
902             f.write(dst_contents)
903     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
904         now = datetime.utcnow()
905         src_name = f"{src}\t{then} +0000"
906         dst_name = f"{src}\t{now} +0000"
907         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
908
909         if write_back == WriteBack.COLOR_DIFF:
910             diff_contents = color_diff(diff_contents)
911
912         with lock or nullcontext():
913             f = io.TextIOWrapper(
914                 sys.stdout.buffer,
915                 encoding=encoding,
916                 newline=newline,
917                 write_through=True,
918             )
919             f = wrap_stream_for_windows(f)
920             f.write(diff_contents)
921             f.detach()
922
923     return True
924
925
926 def color_diff(contents: str) -> str:
927     """Inject the ANSI color codes to the diff."""
928     lines = contents.split("\n")
929     for i, line in enumerate(lines):
930         if line.startswith("+++") or line.startswith("---"):
931             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
932         elif line.startswith("@@"):
933             line = "\033[36m" + line + "\033[0m"  # cyan, reset
934         elif line.startswith("+"):
935             line = "\033[32m" + line + "\033[0m"  # green, reset
936         elif line.startswith("-"):
937             line = "\033[31m" + line + "\033[0m"  # red, reset
938         lines[i] = line
939     return "\n".join(lines)
940
941
942 def wrap_stream_for_windows(
943     f: io.TextIOWrapper,
944 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
945     """
946     Wrap stream with colorama's wrap_stream so colors are shown on Windows.
947
948     If `colorama` is unavailable, the original stream is returned unmodified.
949     Otherwise, the `wrap_stream()` function determines whether the stream needs
950     to be wrapped for a Windows environment and will accordingly either return
951     an `AnsiToWin32` wrapper or the original stream.
952     """
953     try:
954         from colorama.initialise import wrap_stream
955     except ImportError:
956         return f
957     else:
958         # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
959         return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
960
961
962 def format_stdin_to_stdout(
963     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
964 ) -> bool:
965     """Format file on stdin. Return True if changed.
966
967     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
968     write a diff to stdout. The `mode` argument is passed to
969     :func:`format_file_contents`.
970     """
971     then = datetime.utcnow()
972     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
973     dst = src
974     try:
975         dst = format_file_contents(src, fast=fast, mode=mode)
976         return True
977
978     except NothingChanged:
979         return False
980
981     finally:
982         f = io.TextIOWrapper(
983             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
984         )
985         if write_back == WriteBack.YES:
986             f.write(dst)
987         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
988             now = datetime.utcnow()
989             src_name = f"STDIN\t{then} +0000"
990             dst_name = f"STDOUT\t{now} +0000"
991             d = diff(src, dst, src_name, dst_name)
992             if write_back == WriteBack.COLOR_DIFF:
993                 d = color_diff(d)
994                 f = wrap_stream_for_windows(f)
995             f.write(d)
996         f.detach()
997
998
999 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
1000     """Reformat contents of a file and return new contents.
1001
1002     If `fast` is False, additionally confirm that the reformatted code is
1003     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
1004     `mode` is passed to :func:`format_str`.
1005     """
1006     if not src_contents.strip():
1007         raise NothingChanged
1008
1009     dst_contents = format_str(src_contents, mode=mode)
1010     if src_contents == dst_contents:
1011         raise NothingChanged
1012
1013     if not fast:
1014         assert_equivalent(src_contents, dst_contents)
1015         assert_stable(src_contents, dst_contents, mode=mode)
1016     return dst_contents
1017
1018
1019 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
1020     """Reformat a string and return new contents.
1021
1022     `mode` determines formatting options, such as how many characters per line are
1023     allowed.  Example:
1024
1025     >>> import black
1026     >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
1027     def f(arg: str = "") -> None:
1028         ...
1029
1030     A more complex example:
1031
1032     >>> print(
1033     ...   black.format_str(
1034     ...     "def f(arg:str='')->None: hey",
1035     ...     mode=black.Mode(
1036     ...       target_versions={black.TargetVersion.PY36},
1037     ...       line_length=10,
1038     ...       string_normalization=False,
1039     ...       is_pyi=False,
1040     ...     ),
1041     ...   ),
1042     ... )
1043     def f(
1044         arg: str = '',
1045     ) -> None:
1046         hey
1047
1048     """
1049     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
1050     dst_contents = []
1051     future_imports = get_future_imports(src_node)
1052     if mode.target_versions:
1053         versions = mode.target_versions
1054     else:
1055         versions = detect_target_versions(src_node)
1056     normalize_fmt_off(src_node)
1057     lines = LineGenerator(
1058         mode=mode,
1059         remove_u_prefix="unicode_literals" in future_imports
1060         or supports_feature(versions, Feature.UNICODE_LITERALS),
1061     )
1062     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1063     empty_line = Line(mode=mode)
1064     after = 0
1065     split_line_features = {
1066         feature
1067         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1068         if supports_feature(versions, feature)
1069     }
1070     for current_line in lines.visit(src_node):
1071         dst_contents.append(str(empty_line) * after)
1072         before, after = elt.maybe_empty_lines(current_line)
1073         dst_contents.append(str(empty_line) * before)
1074         for line in transform_line(
1075             current_line, mode=mode, features=split_line_features
1076         ):
1077             dst_contents.append(str(line))
1078     return "".join(dst_contents)
1079
1080
1081 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1082     """Return a tuple of (decoded_contents, encoding, newline).
1083
1084     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1085     universal newlines (i.e. only contains LF).
1086     """
1087     srcbuf = io.BytesIO(src)
1088     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1089     if not lines:
1090         return "", encoding, "\n"
1091
1092     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1093     srcbuf.seek(0)
1094     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1095         return tiow.read(), encoding, newline
1096
1097
1098 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1099     if not target_versions:
1100         # No target_version specified, so try all grammars.
1101         return [
1102             # Python 3.7+
1103             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1104             # Python 3.0-3.6
1105             pygram.python_grammar_no_print_statement_no_exec_statement,
1106             # Python 2.7 with future print_function import
1107             pygram.python_grammar_no_print_statement,
1108             # Python 2.7
1109             pygram.python_grammar,
1110         ]
1111
1112     if all(version.is_python2() for version in target_versions):
1113         # Python 2-only code, so try Python 2 grammars.
1114         return [
1115             # Python 2.7 with future print_function import
1116             pygram.python_grammar_no_print_statement,
1117             # Python 2.7
1118             pygram.python_grammar,
1119         ]
1120
1121     # Python 3-compatible code, so only try Python 3 grammar.
1122     grammars = []
1123     # If we have to parse both, try to parse async as a keyword first
1124     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1125         # Python 3.7+
1126         grammars.append(
1127             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1128         )
1129     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1130         # Python 3.0-3.6
1131         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1132     # At least one of the above branches must have been taken, because every Python
1133     # version has exactly one of the two 'ASYNC_*' flags
1134     return grammars
1135
1136
1137 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1138     """Given a string with source, return the lib2to3 Node."""
1139     if not src_txt.endswith("\n"):
1140         src_txt += "\n"
1141
1142     for grammar in get_grammars(set(target_versions)):
1143         drv = driver.Driver(grammar, pytree.convert)
1144         try:
1145             result = drv.parse_string(src_txt, True)
1146             break
1147
1148         except ParseError as pe:
1149             lineno, column = pe.context[1]
1150             lines = src_txt.splitlines()
1151             try:
1152                 faulty_line = lines[lineno - 1]
1153             except IndexError:
1154                 faulty_line = "<line number missing in source>"
1155             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1156     else:
1157         raise exc from None
1158
1159     if isinstance(result, Leaf):
1160         result = Node(syms.file_input, [result])
1161     return result
1162
1163
1164 def lib2to3_unparse(node: Node) -> str:
1165     """Given a lib2to3 node, return its string representation."""
1166     code = str(node)
1167     return code
1168
1169
1170 class Visitor(Generic[T]):
1171     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1172
1173     def visit(self, node: LN) -> Iterator[T]:
1174         """Main method to visit `node` and its children.
1175
1176         It tries to find a `visit_*()` method for the given `node.type`, like
1177         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1178         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1179         instead.
1180
1181         Then yields objects of type `T` from the selected visitor.
1182         """
1183         if node.type < 256:
1184             name = token.tok_name[node.type]
1185         else:
1186             name = str(type_repr(node.type))
1187         # We explicitly branch on whether a visitor exists (instead of
1188         # using self.visit_default as the default arg to getattr) in order
1189         # to save needing to create a bound method object and so mypyc can
1190         # generate a native call to visit_default.
1191         visitf = getattr(self, f"visit_{name}", None)
1192         if visitf:
1193             yield from visitf(node)
1194         else:
1195             yield from self.visit_default(node)
1196
1197     def visit_default(self, node: LN) -> Iterator[T]:
1198         """Default `visit_*()` implementation. Recurses to children of `node`."""
1199         if isinstance(node, Node):
1200             for child in node.children:
1201                 yield from self.visit(child)
1202
1203
1204 @dataclass
1205 class DebugVisitor(Visitor[T]):
1206     tree_depth: int = 0
1207
1208     def visit_default(self, node: LN) -> Iterator[T]:
1209         indent = " " * (2 * self.tree_depth)
1210         if isinstance(node, Node):
1211             _type = type_repr(node.type)
1212             out(f"{indent}{_type}", fg="yellow")
1213             self.tree_depth += 1
1214             for child in node.children:
1215                 yield from self.visit(child)
1216
1217             self.tree_depth -= 1
1218             out(f"{indent}/{_type}", fg="yellow", bold=False)
1219         else:
1220             _type = token.tok_name.get(node.type, str(node.type))
1221             out(f"{indent}{_type}", fg="blue", nl=False)
1222             if node.prefix:
1223                 # We don't have to handle prefixes for `Node` objects since
1224                 # that delegates to the first child anyway.
1225                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1226             out(f" {node.value!r}", fg="blue", bold=False)
1227
1228     @classmethod
1229     def show(cls, code: Union[str, Leaf, Node]) -> None:
1230         """Pretty-print the lib2to3 AST of a given string of `code`.
1231
1232         Convenience method for debugging.
1233         """
1234         v: DebugVisitor[None] = DebugVisitor()
1235         if isinstance(code, str):
1236             code = lib2to3_parse(code)
1237         list(v.visit(code))
1238
1239
1240 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1241 STATEMENT: Final = {
1242     syms.if_stmt,
1243     syms.while_stmt,
1244     syms.for_stmt,
1245     syms.try_stmt,
1246     syms.except_clause,
1247     syms.with_stmt,
1248     syms.funcdef,
1249     syms.classdef,
1250 }
1251 STANDALONE_COMMENT: Final = 153
1252 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1253 LOGIC_OPERATORS: Final = {"and", "or"}
1254 COMPARATORS: Final = {
1255     token.LESS,
1256     token.GREATER,
1257     token.EQEQUAL,
1258     token.NOTEQUAL,
1259     token.LESSEQUAL,
1260     token.GREATEREQUAL,
1261 }
1262 MATH_OPERATORS: Final = {
1263     token.VBAR,
1264     token.CIRCUMFLEX,
1265     token.AMPER,
1266     token.LEFTSHIFT,
1267     token.RIGHTSHIFT,
1268     token.PLUS,
1269     token.MINUS,
1270     token.STAR,
1271     token.SLASH,
1272     token.DOUBLESLASH,
1273     token.PERCENT,
1274     token.AT,
1275     token.TILDE,
1276     token.DOUBLESTAR,
1277 }
1278 STARS: Final = {token.STAR, token.DOUBLESTAR}
1279 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1280 VARARGS_PARENTS: Final = {
1281     syms.arglist,
1282     syms.argument,  # double star in arglist
1283     syms.trailer,  # single argument to call
1284     syms.typedargslist,
1285     syms.varargslist,  # lambdas
1286 }
1287 UNPACKING_PARENTS: Final = {
1288     syms.atom,  # single element of a list or set literal
1289     syms.dictsetmaker,
1290     syms.listmaker,
1291     syms.testlist_gexp,
1292     syms.testlist_star_expr,
1293 }
1294 TEST_DESCENDANTS: Final = {
1295     syms.test,
1296     syms.lambdef,
1297     syms.or_test,
1298     syms.and_test,
1299     syms.not_test,
1300     syms.comparison,
1301     syms.star_expr,
1302     syms.expr,
1303     syms.xor_expr,
1304     syms.and_expr,
1305     syms.shift_expr,
1306     syms.arith_expr,
1307     syms.trailer,
1308     syms.term,
1309     syms.power,
1310 }
1311 ASSIGNMENTS: Final = {
1312     "=",
1313     "+=",
1314     "-=",
1315     "*=",
1316     "@=",
1317     "/=",
1318     "%=",
1319     "&=",
1320     "|=",
1321     "^=",
1322     "<<=",
1323     ">>=",
1324     "**=",
1325     "//=",
1326 }
1327 COMPREHENSION_PRIORITY: Final = 20
1328 COMMA_PRIORITY: Final = 18
1329 TERNARY_PRIORITY: Final = 16
1330 LOGIC_PRIORITY: Final = 14
1331 STRING_PRIORITY: Final = 12
1332 COMPARATOR_PRIORITY: Final = 10
1333 MATH_PRIORITIES: Final = {
1334     token.VBAR: 9,
1335     token.CIRCUMFLEX: 8,
1336     token.AMPER: 7,
1337     token.LEFTSHIFT: 6,
1338     token.RIGHTSHIFT: 6,
1339     token.PLUS: 5,
1340     token.MINUS: 5,
1341     token.STAR: 4,
1342     token.SLASH: 4,
1343     token.DOUBLESLASH: 4,
1344     token.PERCENT: 4,
1345     token.AT: 4,
1346     token.TILDE: 3,
1347     token.DOUBLESTAR: 2,
1348 }
1349 DOT_PRIORITY: Final = 1
1350
1351
1352 @dataclass
1353 class BracketTracker:
1354     """Keeps track of brackets on a line."""
1355
1356     depth: int = 0
1357     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1358     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1359     previous: Optional[Leaf] = None
1360     _for_loop_depths: List[int] = field(default_factory=list)
1361     _lambda_argument_depths: List[int] = field(default_factory=list)
1362     invisible: List[Leaf] = field(default_factory=list)
1363
1364     def mark(self, leaf: Leaf) -> None:
1365         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1366
1367         All leaves receive an int `bracket_depth` field that stores how deep
1368         within brackets a given leaf is. 0 means there are no enclosing brackets
1369         that started on this line.
1370
1371         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1372         field that it forms a pair with. This is a one-directional link to
1373         avoid reference cycles.
1374
1375         If a leaf is a delimiter (a token on which Black can split the line if
1376         needed) and it's on depth 0, its `id()` is stored in the tracker's
1377         `delimiters` field.
1378         """
1379         if leaf.type == token.COMMENT:
1380             return
1381
1382         self.maybe_decrement_after_for_loop_variable(leaf)
1383         self.maybe_decrement_after_lambda_arguments(leaf)
1384         if leaf.type in CLOSING_BRACKETS:
1385             self.depth -= 1
1386             try:
1387                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1388             except KeyError as e:
1389                 raise BracketMatchError(
1390                     "Unable to match a closing bracket to the following opening"
1391                     f" bracket: {leaf}"
1392                 ) from e
1393             leaf.opening_bracket = opening_bracket
1394             if not leaf.value:
1395                 self.invisible.append(leaf)
1396         leaf.bracket_depth = self.depth
1397         if self.depth == 0:
1398             delim = is_split_before_delimiter(leaf, self.previous)
1399             if delim and self.previous is not None:
1400                 self.delimiters[id(self.previous)] = delim
1401             else:
1402                 delim = is_split_after_delimiter(leaf, self.previous)
1403                 if delim:
1404                     self.delimiters[id(leaf)] = delim
1405         if leaf.type in OPENING_BRACKETS:
1406             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1407             self.depth += 1
1408             if not leaf.value:
1409                 self.invisible.append(leaf)
1410         self.previous = leaf
1411         self.maybe_increment_lambda_arguments(leaf)
1412         self.maybe_increment_for_loop_variable(leaf)
1413
1414     def any_open_brackets(self) -> bool:
1415         """Return True if there is an yet unmatched open bracket on the line."""
1416         return bool(self.bracket_match)
1417
1418     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1419         """Return the highest priority of a delimiter found on the line.
1420
1421         Values are consistent with what `is_split_*_delimiter()` return.
1422         Raises ValueError on no delimiters.
1423         """
1424         return max(v for k, v in self.delimiters.items() if k not in exclude)
1425
1426     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1427         """Return the number of delimiters with the given `priority`.
1428
1429         If no `priority` is passed, defaults to max priority on the line.
1430         """
1431         if not self.delimiters:
1432             return 0
1433
1434         priority = priority or self.max_delimiter_priority()
1435         return sum(1 for p in self.delimiters.values() if p == priority)
1436
1437     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1438         """In a for loop, or comprehension, the variables are often unpacks.
1439
1440         To avoid splitting on the comma in this situation, increase the depth of
1441         tokens between `for` and `in`.
1442         """
1443         if leaf.type == token.NAME and leaf.value == "for":
1444             self.depth += 1
1445             self._for_loop_depths.append(self.depth)
1446             return True
1447
1448         return False
1449
1450     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1451         """See `maybe_increment_for_loop_variable` above for explanation."""
1452         if (
1453             self._for_loop_depths
1454             and self._for_loop_depths[-1] == self.depth
1455             and leaf.type == token.NAME
1456             and leaf.value == "in"
1457         ):
1458             self.depth -= 1
1459             self._for_loop_depths.pop()
1460             return True
1461
1462         return False
1463
1464     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1465         """In a lambda expression, there might be more than one argument.
1466
1467         To avoid splitting on the comma in this situation, increase the depth of
1468         tokens between `lambda` and `:`.
1469         """
1470         if leaf.type == token.NAME and leaf.value == "lambda":
1471             self.depth += 1
1472             self._lambda_argument_depths.append(self.depth)
1473             return True
1474
1475         return False
1476
1477     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1478         """See `maybe_increment_lambda_arguments` above for explanation."""
1479         if (
1480             self._lambda_argument_depths
1481             and self._lambda_argument_depths[-1] == self.depth
1482             and leaf.type == token.COLON
1483         ):
1484             self.depth -= 1
1485             self._lambda_argument_depths.pop()
1486             return True
1487
1488         return False
1489
1490     def get_open_lsqb(self) -> Optional[Leaf]:
1491         """Return the most recent opening square bracket (if any)."""
1492         return self.bracket_match.get((self.depth - 1, token.RSQB))
1493
1494
1495 @dataclass
1496 class Line:
1497     """Holds leaves and comments. Can be printed with `str(line)`."""
1498
1499     mode: Mode
1500     depth: int = 0
1501     leaves: List[Leaf] = field(default_factory=list)
1502     # keys ordered like `leaves`
1503     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1504     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1505     inside_brackets: bool = False
1506     should_split_rhs: bool = False
1507     magic_trailing_comma: Optional[Leaf] = None
1508
1509     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1510         """Add a new `leaf` to the end of the line.
1511
1512         Unless `preformatted` is True, the `leaf` will receive a new consistent
1513         whitespace prefix and metadata applied by :class:`BracketTracker`.
1514         Trailing commas are maybe removed, unpacked for loop variables are
1515         demoted from being delimiters.
1516
1517         Inline comments are put aside.
1518         """
1519         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1520         if not has_value:
1521             return
1522
1523         if token.COLON == leaf.type and self.is_class_paren_empty:
1524             del self.leaves[-2:]
1525         if self.leaves and not preformatted:
1526             # Note: at this point leaf.prefix should be empty except for
1527             # imports, for which we only preserve newlines.
1528             leaf.prefix += whitespace(
1529                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1530             )
1531         if self.inside_brackets or not preformatted:
1532             self.bracket_tracker.mark(leaf)
1533             if self.mode.magic_trailing_comma:
1534                 if self.has_magic_trailing_comma(leaf):
1535                     self.magic_trailing_comma = leaf
1536             elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
1537                 self.remove_trailing_comma()
1538         if not self.append_comment(leaf):
1539             self.leaves.append(leaf)
1540
1541     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1542         """Like :func:`append()` but disallow invalid standalone comment structure.
1543
1544         Raises ValueError when any `leaf` is appended after a standalone comment
1545         or when a standalone comment is not the first leaf on the line.
1546         """
1547         if self.bracket_tracker.depth == 0:
1548             if self.is_comment:
1549                 raise ValueError("cannot append to standalone comments")
1550
1551             if self.leaves and leaf.type == STANDALONE_COMMENT:
1552                 raise ValueError(
1553                     "cannot append standalone comments to a populated line"
1554                 )
1555
1556         self.append(leaf, preformatted=preformatted)
1557
1558     @property
1559     def is_comment(self) -> bool:
1560         """Is this line a standalone comment?"""
1561         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1562
1563     @property
1564     def is_decorator(self) -> bool:
1565         """Is this line a decorator?"""
1566         return bool(self) and self.leaves[0].type == token.AT
1567
1568     @property
1569     def is_import(self) -> bool:
1570         """Is this an import line?"""
1571         return bool(self) and is_import(self.leaves[0])
1572
1573     @property
1574     def is_class(self) -> bool:
1575         """Is this line a class definition?"""
1576         return (
1577             bool(self)
1578             and self.leaves[0].type == token.NAME
1579             and self.leaves[0].value == "class"
1580         )
1581
1582     @property
1583     def is_stub_class(self) -> bool:
1584         """Is this line a class definition with a body consisting only of "..."?"""
1585         return self.is_class and self.leaves[-3:] == [
1586             Leaf(token.DOT, ".") for _ in range(3)
1587         ]
1588
1589     @property
1590     def is_def(self) -> bool:
1591         """Is this a function definition? (Also returns True for async defs.)"""
1592         try:
1593             first_leaf = self.leaves[0]
1594         except IndexError:
1595             return False
1596
1597         try:
1598             second_leaf: Optional[Leaf] = self.leaves[1]
1599         except IndexError:
1600             second_leaf = None
1601         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1602             first_leaf.type == token.ASYNC
1603             and second_leaf is not None
1604             and second_leaf.type == token.NAME
1605             and second_leaf.value == "def"
1606         )
1607
1608     @property
1609     def is_class_paren_empty(self) -> bool:
1610         """Is this a class with no base classes but using parentheses?
1611
1612         Those are unnecessary and should be removed.
1613         """
1614         return (
1615             bool(self)
1616             and len(self.leaves) == 4
1617             and self.is_class
1618             and self.leaves[2].type == token.LPAR
1619             and self.leaves[2].value == "("
1620             and self.leaves[3].type == token.RPAR
1621             and self.leaves[3].value == ")"
1622         )
1623
1624     @property
1625     def is_triple_quoted_string(self) -> bool:
1626         """Is the line a triple quoted string?"""
1627         return (
1628             bool(self)
1629             and self.leaves[0].type == token.STRING
1630             and self.leaves[0].value.startswith(('"""', "'''"))
1631         )
1632
1633     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1634         """If so, needs to be split before emitting."""
1635         for leaf in self.leaves:
1636             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1637                 return True
1638
1639         return False
1640
1641     def contains_uncollapsable_type_comments(self) -> bool:
1642         ignored_ids = set()
1643         try:
1644             last_leaf = self.leaves[-1]
1645             ignored_ids.add(id(last_leaf))
1646             if last_leaf.type == token.COMMA or (
1647                 last_leaf.type == token.RPAR and not last_leaf.value
1648             ):
1649                 # When trailing commas or optional parens are inserted by Black for
1650                 # consistency, comments after the previous last element are not moved
1651                 # (they don't have to, rendering will still be correct).  So we ignore
1652                 # trailing commas and invisible.
1653                 last_leaf = self.leaves[-2]
1654                 ignored_ids.add(id(last_leaf))
1655         except IndexError:
1656             return False
1657
1658         # A type comment is uncollapsable if it is attached to a leaf
1659         # that isn't at the end of the line (since that could cause it
1660         # to get associated to a different argument) or if there are
1661         # comments before it (since that could cause it to get hidden
1662         # behind a comment.
1663         comment_seen = False
1664         for leaf_id, comments in self.comments.items():
1665             for comment in comments:
1666                 if is_type_comment(comment):
1667                     if comment_seen or (
1668                         not is_type_comment(comment, " ignore")
1669                         and leaf_id not in ignored_ids
1670                     ):
1671                         return True
1672
1673                 comment_seen = True
1674
1675         return False
1676
1677     def contains_unsplittable_type_ignore(self) -> bool:
1678         if not self.leaves:
1679             return False
1680
1681         # If a 'type: ignore' is attached to the end of a line, we
1682         # can't split the line, because we can't know which of the
1683         # subexpressions the ignore was meant to apply to.
1684         #
1685         # We only want this to apply to actual physical lines from the
1686         # original source, though: we don't want the presence of a
1687         # 'type: ignore' at the end of a multiline expression to
1688         # justify pushing it all onto one line. Thus we
1689         # (unfortunately) need to check the actual source lines and
1690         # only report an unsplittable 'type: ignore' if this line was
1691         # one line in the original code.
1692
1693         # Grab the first and last line numbers, skipping generated leaves
1694         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1695         last_line = next(
1696             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1697         )
1698
1699         if first_line == last_line:
1700             # We look at the last two leaves since a comma or an
1701             # invisible paren could have been added at the end of the
1702             # line.
1703             for node in self.leaves[-2:]:
1704                 for comment in self.comments.get(id(node), []):
1705                     if is_type_comment(comment, " ignore"):
1706                         return True
1707
1708         return False
1709
1710     def contains_multiline_strings(self) -> bool:
1711         return any(is_multiline_string(leaf) for leaf in self.leaves)
1712
1713     def has_magic_trailing_comma(
1714         self, closing: Leaf, ensure_removable: bool = False
1715     ) -> bool:
1716         """Return True if we have a magic trailing comma, that is when:
1717         - there's a trailing comma here
1718         - it's not a one-tuple
1719         Additionally, if ensure_removable:
1720         - it's not from square bracket indexing
1721         """
1722         if not (
1723             closing.type in CLOSING_BRACKETS
1724             and self.leaves
1725             and self.leaves[-1].type == token.COMMA
1726         ):
1727             return False
1728
1729         if closing.type == token.RBRACE:
1730             return True
1731
1732         if closing.type == token.RSQB:
1733             if not ensure_removable:
1734                 return True
1735             comma = self.leaves[-1]
1736             return bool(comma.parent and comma.parent.type == syms.listmaker)
1737
1738         if self.is_import:
1739             return True
1740
1741         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1742             return True
1743
1744         return False
1745
1746     def append_comment(self, comment: Leaf) -> bool:
1747         """Add an inline or standalone comment to the line."""
1748         if (
1749             comment.type == STANDALONE_COMMENT
1750             and self.bracket_tracker.any_open_brackets()
1751         ):
1752             comment.prefix = ""
1753             return False
1754
1755         if comment.type != token.COMMENT:
1756             return False
1757
1758         if not self.leaves:
1759             comment.type = STANDALONE_COMMENT
1760             comment.prefix = ""
1761             return False
1762
1763         last_leaf = self.leaves[-1]
1764         if (
1765             last_leaf.type == token.RPAR
1766             and not last_leaf.value
1767             and last_leaf.parent
1768             and len(list(last_leaf.parent.leaves())) <= 3
1769             and not is_type_comment(comment)
1770         ):
1771             # Comments on an optional parens wrapping a single leaf should belong to
1772             # the wrapped node except if it's a type comment. Pinning the comment like
1773             # this avoids unstable formatting caused by comment migration.
1774             if len(self.leaves) < 2:
1775                 comment.type = STANDALONE_COMMENT
1776                 comment.prefix = ""
1777                 return False
1778
1779             last_leaf = self.leaves[-2]
1780         self.comments.setdefault(id(last_leaf), []).append(comment)
1781         return True
1782
1783     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1784         """Generate comments that should appear directly after `leaf`."""
1785         return self.comments.get(id(leaf), [])
1786
1787     def remove_trailing_comma(self) -> None:
1788         """Remove the trailing comma and moves the comments attached to it."""
1789         trailing_comma = self.leaves.pop()
1790         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1791         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1792             trailing_comma_comments
1793         )
1794
1795     def is_complex_subscript(self, leaf: Leaf) -> bool:
1796         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1797         open_lsqb = self.bracket_tracker.get_open_lsqb()
1798         if open_lsqb is None:
1799             return False
1800
1801         subscript_start = open_lsqb.next_sibling
1802
1803         if isinstance(subscript_start, Node):
1804             if subscript_start.type == syms.listmaker:
1805                 return False
1806
1807             if subscript_start.type == syms.subscriptlist:
1808                 subscript_start = child_towards(subscript_start, leaf)
1809         return subscript_start is not None and any(
1810             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1811         )
1812
1813     def clone(self) -> "Line":
1814         return Line(
1815             mode=self.mode,
1816             depth=self.depth,
1817             inside_brackets=self.inside_brackets,
1818             should_split_rhs=self.should_split_rhs,
1819             magic_trailing_comma=self.magic_trailing_comma,
1820         )
1821
1822     def __str__(self) -> str:
1823         """Render the line."""
1824         if not self:
1825             return "\n"
1826
1827         indent = "    " * self.depth
1828         leaves = iter(self.leaves)
1829         first = next(leaves)
1830         res = f"{first.prefix}{indent}{first.value}"
1831         for leaf in leaves:
1832             res += str(leaf)
1833         for comment in itertools.chain.from_iterable(self.comments.values()):
1834             res += str(comment)
1835
1836         return res + "\n"
1837
1838     def __bool__(self) -> bool:
1839         """Return True if the line has leaves or comments."""
1840         return bool(self.leaves or self.comments)
1841
1842
1843 @dataclass
1844 class EmptyLineTracker:
1845     """Provides a stateful method that returns the number of potential extra
1846     empty lines needed before and after the currently processed line.
1847
1848     Note: this tracker works on lines that haven't been split yet.  It assumes
1849     the prefix of the first leaf consists of optional newlines.  Those newlines
1850     are consumed by `maybe_empty_lines()` and included in the computation.
1851     """
1852
1853     is_pyi: bool = False
1854     previous_line: Optional[Line] = None
1855     previous_after: int = 0
1856     previous_defs: List[int] = field(default_factory=list)
1857
1858     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1859         """Return the number of extra empty lines before and after the `current_line`.
1860
1861         This is for separating `def`, `async def` and `class` with extra empty
1862         lines (two on module-level).
1863         """
1864         before, after = self._maybe_empty_lines(current_line)
1865         before = (
1866             # Black should not insert empty lines at the beginning
1867             # of the file
1868             0
1869             if self.previous_line is None
1870             else before - self.previous_after
1871         )
1872         self.previous_after = after
1873         self.previous_line = current_line
1874         return before, after
1875
1876     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1877         max_allowed = 1
1878         if current_line.depth == 0:
1879             max_allowed = 1 if self.is_pyi else 2
1880         if current_line.leaves:
1881             # Consume the first leaf's extra newlines.
1882             first_leaf = current_line.leaves[0]
1883             before = first_leaf.prefix.count("\n")
1884             before = min(before, max_allowed)
1885             first_leaf.prefix = ""
1886         else:
1887             before = 0
1888         depth = current_line.depth
1889         while self.previous_defs and self.previous_defs[-1] >= depth:
1890             self.previous_defs.pop()
1891             if self.is_pyi:
1892                 before = 0 if depth else 1
1893             else:
1894                 before = 1 if depth else 2
1895         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1896             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1897
1898         if (
1899             self.previous_line
1900             and self.previous_line.is_import
1901             and not current_line.is_import
1902             and depth == self.previous_line.depth
1903         ):
1904             return (before or 1), 0
1905
1906         if (
1907             self.previous_line
1908             and self.previous_line.is_class
1909             and current_line.is_triple_quoted_string
1910         ):
1911             return before, 1
1912
1913         return before, 0
1914
1915     def _maybe_empty_lines_for_class_or_def(
1916         self, current_line: Line, before: int
1917     ) -> Tuple[int, int]:
1918         if not current_line.is_decorator:
1919             self.previous_defs.append(current_line.depth)
1920         if self.previous_line is None:
1921             # Don't insert empty lines before the first line in the file.
1922             return 0, 0
1923
1924         if self.previous_line.is_decorator:
1925             if self.is_pyi and current_line.is_stub_class:
1926                 # Insert an empty line after a decorated stub class
1927                 return 0, 1
1928
1929             return 0, 0
1930
1931         if self.previous_line.depth < current_line.depth and (
1932             self.previous_line.is_class or self.previous_line.is_def
1933         ):
1934             return 0, 0
1935
1936         if (
1937             self.previous_line.is_comment
1938             and self.previous_line.depth == current_line.depth
1939             and before == 0
1940         ):
1941             return 0, 0
1942
1943         if self.is_pyi:
1944             if self.previous_line.depth > current_line.depth:
1945                 newlines = 1
1946             elif current_line.is_class or self.previous_line.is_class:
1947                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1948                     # No blank line between classes with an empty body
1949                     newlines = 0
1950                 else:
1951                     newlines = 1
1952             elif (
1953                 current_line.is_def or current_line.is_decorator
1954             ) and not self.previous_line.is_def:
1955                 # Blank line between a block of functions (maybe with preceding
1956                 # decorators) and a block of non-functions
1957                 newlines = 1
1958             else:
1959                 newlines = 0
1960         else:
1961             newlines = 2
1962         if current_line.depth and newlines:
1963             newlines -= 1
1964         return newlines, 0
1965
1966
1967 @dataclass
1968 class LineGenerator(Visitor[Line]):
1969     """Generates reformatted Line objects.  Empty lines are not emitted.
1970
1971     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1972     in ways that will no longer stringify to valid Python code on the tree.
1973     """
1974
1975     mode: Mode
1976     remove_u_prefix: bool = False
1977     current_line: Line = field(init=False)
1978
1979     def line(self, indent: int = 0) -> Iterator[Line]:
1980         """Generate a line.
1981
1982         If the line is empty, only emit if it makes sense.
1983         If the line is too long, split it first and then generate.
1984
1985         If any lines were generated, set up a new current_line.
1986         """
1987         if not self.current_line:
1988             self.current_line.depth += indent
1989             return  # Line is empty, don't emit. Creating a new one unnecessary.
1990
1991         complete_line = self.current_line
1992         self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
1993         yield complete_line
1994
1995     def visit_default(self, node: LN) -> Iterator[Line]:
1996         """Default `visit_*()` implementation. Recurses to children of `node`."""
1997         if isinstance(node, Leaf):
1998             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1999             for comment in generate_comments(node):
2000                 if any_open_brackets:
2001                     # any comment within brackets is subject to splitting
2002                     self.current_line.append(comment)
2003                 elif comment.type == token.COMMENT:
2004                     # regular trailing comment
2005                     self.current_line.append(comment)
2006                     yield from self.line()
2007
2008                 else:
2009                     # regular standalone comment
2010                     yield from self.line()
2011
2012                     self.current_line.append(comment)
2013                     yield from self.line()
2014
2015             normalize_prefix(node, inside_brackets=any_open_brackets)
2016             if self.mode.string_normalization and node.type == token.STRING:
2017                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
2018                 normalize_string_quotes(node)
2019             if node.type == token.NUMBER:
2020                 normalize_numeric_literal(node)
2021             if node.type not in WHITESPACE:
2022                 self.current_line.append(node)
2023         yield from super().visit_default(node)
2024
2025     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
2026         """Increase indentation level, maybe yield a line."""
2027         # In blib2to3 INDENT never holds comments.
2028         yield from self.line(+1)
2029         yield from self.visit_default(node)
2030
2031     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
2032         """Decrease indentation level, maybe yield a line."""
2033         # The current line might still wait for trailing comments.  At DEDENT time
2034         # there won't be any (they would be prefixes on the preceding NEWLINE).
2035         # Emit the line then.
2036         yield from self.line()
2037
2038         # While DEDENT has no value, its prefix may contain standalone comments
2039         # that belong to the current indentation level.  Get 'em.
2040         yield from self.visit_default(node)
2041
2042         # Finally, emit the dedent.
2043         yield from self.line(-1)
2044
2045     def visit_stmt(
2046         self, node: Node, keywords: Set[str], parens: Set[str]
2047     ) -> Iterator[Line]:
2048         """Visit a statement.
2049
2050         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2051         `def`, `with`, `class`, `assert` and assignments.
2052
2053         The relevant Python language `keywords` for a given statement will be
2054         NAME leaves within it. This methods puts those on a separate line.
2055
2056         `parens` holds a set of string leaf values immediately after which
2057         invisible parens should be put.
2058         """
2059         normalize_invisible_parens(node, parens_after=parens)
2060         for child in node.children:
2061             if child.type == token.NAME and child.value in keywords:  # type: ignore
2062                 yield from self.line()
2063
2064             yield from self.visit(child)
2065
2066     def visit_suite(self, node: Node) -> Iterator[Line]:
2067         """Visit a suite."""
2068         if self.mode.is_pyi and is_stub_suite(node):
2069             yield from self.visit(node.children[2])
2070         else:
2071             yield from self.visit_default(node)
2072
2073     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2074         """Visit a statement without nested statements."""
2075         if first_child_is_arith(node):
2076             wrap_in_parentheses(node, node.children[0], visible=False)
2077         is_suite_like = node.parent and node.parent.type in STATEMENT
2078         if is_suite_like:
2079             if self.mode.is_pyi and is_stub_body(node):
2080                 yield from self.visit_default(node)
2081             else:
2082                 yield from self.line(+1)
2083                 yield from self.visit_default(node)
2084                 yield from self.line(-1)
2085
2086         else:
2087             if (
2088                 not self.mode.is_pyi
2089                 or not node.parent
2090                 or not is_stub_suite(node.parent)
2091             ):
2092                 yield from self.line()
2093             yield from self.visit_default(node)
2094
2095     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2096         """Visit `async def`, `async for`, `async with`."""
2097         yield from self.line()
2098
2099         children = iter(node.children)
2100         for child in children:
2101             yield from self.visit(child)
2102
2103             if child.type == token.ASYNC:
2104                 break
2105
2106         internal_stmt = next(children)
2107         for child in internal_stmt.children:
2108             yield from self.visit(child)
2109
2110     def visit_decorators(self, node: Node) -> Iterator[Line]:
2111         """Visit decorators."""
2112         for child in node.children:
2113             yield from self.line()
2114             yield from self.visit(child)
2115
2116     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2117         """Remove a semicolon and put the other statement on a separate line."""
2118         yield from self.line()
2119
2120     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2121         """End of file. Process outstanding comments and end with a newline."""
2122         yield from self.visit_default(leaf)
2123         yield from self.line()
2124
2125     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2126         if not self.current_line.bracket_tracker.any_open_brackets():
2127             yield from self.line()
2128         yield from self.visit_default(leaf)
2129
2130     def visit_factor(self, node: Node) -> Iterator[Line]:
2131         """Force parentheses between a unary op and a binary power:
2132
2133         -2 ** 8 -> -(2 ** 8)
2134         """
2135         _operator, operand = node.children
2136         if (
2137             operand.type == syms.power
2138             and len(operand.children) == 3
2139             and operand.children[1].type == token.DOUBLESTAR
2140         ):
2141             lpar = Leaf(token.LPAR, "(")
2142             rpar = Leaf(token.RPAR, ")")
2143             index = operand.remove() or 0
2144             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2145         yield from self.visit_default(node)
2146
2147     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2148         if is_docstring(leaf) and "\\\n" not in leaf.value:
2149             # We're ignoring docstrings with backslash newline escapes because changing
2150             # indentation of those changes the AST representation of the code.
2151             prefix = get_string_prefix(leaf.value)
2152             lead_len = len(prefix) + 3
2153             tail_len = -3
2154             indent = " " * 4 * self.current_line.depth
2155             docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
2156             if docstring:
2157                 if leaf.value[lead_len - 1] == docstring[0]:
2158                     docstring = " " + docstring
2159                 if leaf.value[tail_len + 1] == docstring[-1]:
2160                     docstring = docstring + " "
2161             leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
2162
2163         yield from self.visit_default(leaf)
2164
2165     def __post_init__(self) -> None:
2166         """You are in a twisty little maze of passages."""
2167         self.current_line = Line(mode=self.mode)
2168
2169         v = self.visit_stmt
2170         Ø: Set[str] = set()
2171         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2172         self.visit_if_stmt = partial(
2173             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2174         )
2175         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2176         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2177         self.visit_try_stmt = partial(
2178             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2179         )
2180         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2181         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2182         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2183         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2184         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2185         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2186         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2187         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2188         self.visit_async_funcdef = self.visit_async_stmt
2189         self.visit_decorated = self.visit_decorators
2190
2191
2192 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2193 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2194 OPENING_BRACKETS = set(BRACKET.keys())
2195 CLOSING_BRACKETS = set(BRACKET.values())
2196 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2197 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2198
2199
2200 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2201     """Return whitespace prefix if needed for the given `leaf`.
2202
2203     `complex_subscript` signals whether the given leaf is part of a subscription
2204     which has non-trivial arguments, like arithmetic expressions or function calls.
2205     """
2206     NO = ""
2207     SPACE = " "
2208     DOUBLESPACE = "  "
2209     t = leaf.type
2210     p = leaf.parent
2211     v = leaf.value
2212     if t in ALWAYS_NO_SPACE:
2213         return NO
2214
2215     if t == token.COMMENT:
2216         return DOUBLESPACE
2217
2218     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2219     if t == token.COLON and p.type not in {
2220         syms.subscript,
2221         syms.subscriptlist,
2222         syms.sliceop,
2223     }:
2224         return NO
2225
2226     prev = leaf.prev_sibling
2227     if not prev:
2228         prevp = preceding_leaf(p)
2229         if not prevp or prevp.type in OPENING_BRACKETS:
2230             return NO
2231
2232         if t == token.COLON:
2233             if prevp.type == token.COLON:
2234                 return NO
2235
2236             elif prevp.type != token.COMMA and not complex_subscript:
2237                 return NO
2238
2239             return SPACE
2240
2241         if prevp.type == token.EQUAL:
2242             if prevp.parent:
2243                 if prevp.parent.type in {
2244                     syms.arglist,
2245                     syms.argument,
2246                     syms.parameters,
2247                     syms.varargslist,
2248                 }:
2249                     return NO
2250
2251                 elif prevp.parent.type == syms.typedargslist:
2252                     # A bit hacky: if the equal sign has whitespace, it means we
2253                     # previously found it's a typed argument.  So, we're using
2254                     # that, too.
2255                     return prevp.prefix
2256
2257         elif prevp.type in VARARGS_SPECIALS:
2258             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2259                 return NO
2260
2261         elif prevp.type == token.COLON:
2262             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2263                 return SPACE if complex_subscript else NO
2264
2265         elif (
2266             prevp.parent
2267             and prevp.parent.type == syms.factor
2268             and prevp.type in MATH_OPERATORS
2269         ):
2270             return NO
2271
2272         elif (
2273             prevp.type == token.RIGHTSHIFT
2274             and prevp.parent
2275             and prevp.parent.type == syms.shift_expr
2276             and prevp.prev_sibling
2277             and prevp.prev_sibling.type == token.NAME
2278             and prevp.prev_sibling.value == "print"  # type: ignore
2279         ):
2280             # Python 2 print chevron
2281             return NO
2282         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2283             # no space in decorators
2284             return NO
2285
2286     elif prev.type in OPENING_BRACKETS:
2287         return NO
2288
2289     if p.type in {syms.parameters, syms.arglist}:
2290         # untyped function signatures or calls
2291         if not prev or prev.type != token.COMMA:
2292             return NO
2293
2294     elif p.type == syms.varargslist:
2295         # lambdas
2296         if prev and prev.type != token.COMMA:
2297             return NO
2298
2299     elif p.type == syms.typedargslist:
2300         # typed function signatures
2301         if not prev:
2302             return NO
2303
2304         if t == token.EQUAL:
2305             if prev.type != syms.tname:
2306                 return NO
2307
2308         elif prev.type == token.EQUAL:
2309             # A bit hacky: if the equal sign has whitespace, it means we
2310             # previously found it's a typed argument.  So, we're using that, too.
2311             return prev.prefix
2312
2313         elif prev.type != token.COMMA:
2314             return NO
2315
2316     elif p.type == syms.tname:
2317         # type names
2318         if not prev:
2319             prevp = preceding_leaf(p)
2320             if not prevp or prevp.type != token.COMMA:
2321                 return NO
2322
2323     elif p.type == syms.trailer:
2324         # attributes and calls
2325         if t == token.LPAR or t == token.RPAR:
2326             return NO
2327
2328         if not prev:
2329             if t == token.DOT:
2330                 prevp = preceding_leaf(p)
2331                 if not prevp or prevp.type != token.NUMBER:
2332                     return NO
2333
2334             elif t == token.LSQB:
2335                 return NO
2336
2337         elif prev.type != token.COMMA:
2338             return NO
2339
2340     elif p.type == syms.argument:
2341         # single argument
2342         if t == token.EQUAL:
2343             return NO
2344
2345         if not prev:
2346             prevp = preceding_leaf(p)
2347             if not prevp or prevp.type == token.LPAR:
2348                 return NO
2349
2350         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2351             return NO
2352
2353     elif p.type == syms.decorator:
2354         # decorators
2355         return NO
2356
2357     elif p.type == syms.dotted_name:
2358         if prev:
2359             return NO
2360
2361         prevp = preceding_leaf(p)
2362         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2363             return NO
2364
2365     elif p.type == syms.classdef:
2366         if t == token.LPAR:
2367             return NO
2368
2369         if prev and prev.type == token.LPAR:
2370             return NO
2371
2372     elif p.type in {syms.subscript, syms.sliceop}:
2373         # indexing
2374         if not prev:
2375             assert p.parent is not None, "subscripts are always parented"
2376             if p.parent.type == syms.subscriptlist:
2377                 return SPACE
2378
2379             return NO
2380
2381         elif not complex_subscript:
2382             return NO
2383
2384     elif p.type == syms.atom:
2385         if prev and t == token.DOT:
2386             # dots, but not the first one.
2387             return NO
2388
2389     elif p.type == syms.dictsetmaker:
2390         # dict unpacking
2391         if prev and prev.type == token.DOUBLESTAR:
2392             return NO
2393
2394     elif p.type in {syms.factor, syms.star_expr}:
2395         # unary ops
2396         if not prev:
2397             prevp = preceding_leaf(p)
2398             if not prevp or prevp.type in OPENING_BRACKETS:
2399                 return NO
2400
2401             prevp_parent = prevp.parent
2402             assert prevp_parent is not None
2403             if prevp.type == token.COLON and prevp_parent.type in {
2404                 syms.subscript,
2405                 syms.sliceop,
2406             }:
2407                 return NO
2408
2409             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2410                 return NO
2411
2412         elif t in {token.NAME, token.NUMBER, token.STRING}:
2413             return NO
2414
2415     elif p.type == syms.import_from:
2416         if t == token.DOT:
2417             if prev and prev.type == token.DOT:
2418                 return NO
2419
2420         elif t == token.NAME:
2421             if v == "import":
2422                 return SPACE
2423
2424             if prev and prev.type == token.DOT:
2425                 return NO
2426
2427     elif p.type == syms.sliceop:
2428         return NO
2429
2430     return SPACE
2431
2432
2433 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2434     """Return the first leaf that precedes `node`, if any."""
2435     while node:
2436         res = node.prev_sibling
2437         if res:
2438             if isinstance(res, Leaf):
2439                 return res
2440
2441             try:
2442                 return list(res.leaves())[-1]
2443
2444             except IndexError:
2445                 return None
2446
2447         node = node.parent
2448     return None
2449
2450
2451 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2452     """Return if the `node` and its previous siblings match types against the provided
2453     list of tokens; the provided `node`has its type matched against the last element in
2454     the list.  `None` can be used as the first element to declare that the start of the
2455     list is anchored at the start of its parent's children."""
2456     if not tokens:
2457         return True
2458     if tokens[-1] is None:
2459         return node is None
2460     if not node:
2461         return False
2462     if node.type != tokens[-1]:
2463         return False
2464     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2465
2466
2467 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2468     """Return the child of `ancestor` that contains `descendant`."""
2469     node: Optional[LN] = descendant
2470     while node and node.parent != ancestor:
2471         node = node.parent
2472     return node
2473
2474
2475 def container_of(leaf: Leaf) -> LN:
2476     """Return `leaf` or one of its ancestors that is the topmost container of it.
2477
2478     By "container" we mean a node where `leaf` is the very first child.
2479     """
2480     same_prefix = leaf.prefix
2481     container: LN = leaf
2482     while container:
2483         parent = container.parent
2484         if parent is None:
2485             break
2486
2487         if parent.children[0].prefix != same_prefix:
2488             break
2489
2490         if parent.type == syms.file_input:
2491             break
2492
2493         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2494             break
2495
2496         container = parent
2497     return container
2498
2499
2500 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2501     """Return the priority of the `leaf` delimiter, given a line break after it.
2502
2503     The delimiter priorities returned here are from those delimiters that would
2504     cause a line break after themselves.
2505
2506     Higher numbers are higher priority.
2507     """
2508     if leaf.type == token.COMMA:
2509         return COMMA_PRIORITY
2510
2511     return 0
2512
2513
2514 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2515     """Return the priority of the `leaf` delimiter, given a line break before it.
2516
2517     The delimiter priorities returned here are from those delimiters that would
2518     cause a line break before themselves.
2519
2520     Higher numbers are higher priority.
2521     """
2522     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2523         # * and ** might also be MATH_OPERATORS but in this case they are not.
2524         # Don't treat them as a delimiter.
2525         return 0
2526
2527     if (
2528         leaf.type == token.DOT
2529         and leaf.parent
2530         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2531         and (previous is None or previous.type in CLOSING_BRACKETS)
2532     ):
2533         return DOT_PRIORITY
2534
2535     if (
2536         leaf.type in MATH_OPERATORS
2537         and leaf.parent
2538         and leaf.parent.type not in {syms.factor, syms.star_expr}
2539     ):
2540         return MATH_PRIORITIES[leaf.type]
2541
2542     if leaf.type in COMPARATORS:
2543         return COMPARATOR_PRIORITY
2544
2545     if (
2546         leaf.type == token.STRING
2547         and previous is not None
2548         and previous.type == token.STRING
2549     ):
2550         return STRING_PRIORITY
2551
2552     if leaf.type not in {token.NAME, token.ASYNC}:
2553         return 0
2554
2555     if (
2556         leaf.value == "for"
2557         and leaf.parent
2558         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2559         or leaf.type == token.ASYNC
2560     ):
2561         if (
2562             not isinstance(leaf.prev_sibling, Leaf)
2563             or leaf.prev_sibling.value != "async"
2564         ):
2565             return COMPREHENSION_PRIORITY
2566
2567     if (
2568         leaf.value == "if"
2569         and leaf.parent
2570         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2571     ):
2572         return COMPREHENSION_PRIORITY
2573
2574     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2575         return TERNARY_PRIORITY
2576
2577     if leaf.value == "is":
2578         return COMPARATOR_PRIORITY
2579
2580     if (
2581         leaf.value == "in"
2582         and leaf.parent
2583         and leaf.parent.type in {syms.comp_op, syms.comparison}
2584         and not (
2585             previous is not None
2586             and previous.type == token.NAME
2587             and previous.value == "not"
2588         )
2589     ):
2590         return COMPARATOR_PRIORITY
2591
2592     if (
2593         leaf.value == "not"
2594         and leaf.parent
2595         and leaf.parent.type == syms.comp_op
2596         and not (
2597             previous is not None
2598             and previous.type == token.NAME
2599             and previous.value == "is"
2600         )
2601     ):
2602         return COMPARATOR_PRIORITY
2603
2604     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2605         return LOGIC_PRIORITY
2606
2607     return 0
2608
2609
2610 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2611 FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
2612 FMT_PASS = {*FMT_OFF, *FMT_SKIP}
2613 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2614
2615
2616 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2617     """Clean the prefix of the `leaf` and generate comments from it, if any.
2618
2619     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2620     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2621     move because it does away with modifying the grammar to include all the
2622     possible places in which comments can be placed.
2623
2624     The sad consequence for us though is that comments don't "belong" anywhere.
2625     This is why this function generates simple parentless Leaf objects for
2626     comments.  We simply don't know what the correct parent should be.
2627
2628     No matter though, we can live without this.  We really only need to
2629     differentiate between inline and standalone comments.  The latter don't
2630     share the line with any code.
2631
2632     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2633     are emitted with a fake STANDALONE_COMMENT token identifier.
2634     """
2635     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2636         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2637
2638
2639 @dataclass
2640 class ProtoComment:
2641     """Describes a piece of syntax that is a comment.
2642
2643     It's not a :class:`blib2to3.pytree.Leaf` so that:
2644
2645     * it can be cached (`Leaf` objects should not be reused more than once as
2646       they store their lineno, column, prefix, and parent information);
2647     * `newlines` and `consumed` fields are kept separate from the `value`. This
2648       simplifies handling of special marker comments like ``# fmt: off/on``.
2649     """
2650
2651     type: int  # token.COMMENT or STANDALONE_COMMENT
2652     value: str  # content of the comment
2653     newlines: int  # how many newlines before the comment
2654     consumed: int  # how many characters of the original leaf's prefix did we consume
2655
2656
2657 @lru_cache(maxsize=4096)
2658 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2659     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2660     result: List[ProtoComment] = []
2661     if not prefix or "#" not in prefix:
2662         return result
2663
2664     consumed = 0
2665     nlines = 0
2666     ignored_lines = 0
2667     for index, line in enumerate(re.split("\r?\n", prefix)):
2668         consumed += len(line) + 1  # adding the length of the split '\n'
2669         line = line.lstrip()
2670         if not line:
2671             nlines += 1
2672         if not line.startswith("#"):
2673             # Escaped newlines outside of a comment are not really newlines at
2674             # all. We treat a single-line comment following an escaped newline
2675             # as a simple trailing comment.
2676             if line.endswith("\\"):
2677                 ignored_lines += 1
2678             continue
2679
2680         if index == ignored_lines and not is_endmarker:
2681             comment_type = token.COMMENT  # simple trailing comment
2682         else:
2683             comment_type = STANDALONE_COMMENT
2684         comment = make_comment(line)
2685         result.append(
2686             ProtoComment(
2687                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2688             )
2689         )
2690         nlines = 0
2691     return result
2692
2693
2694 def make_comment(content: str) -> str:
2695     """Return a consistently formatted comment from the given `content` string.
2696
2697     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2698     space between the hash sign and the content.
2699
2700     If `content` didn't start with a hash sign, one is provided.
2701     """
2702     content = content.rstrip()
2703     if not content:
2704         return "#"
2705
2706     if content[0] == "#":
2707         content = content[1:]
2708     if content and content[0] not in " !:#'%":
2709         content = " " + content
2710     return "#" + content
2711
2712
2713 def transform_line(
2714     line: Line, mode: Mode, features: Collection[Feature] = ()
2715 ) -> Iterator[Line]:
2716     """Transform a `line`, potentially splitting it into many lines.
2717
2718     They should fit in the allotted `line_length` but might not be able to.
2719
2720     `features` are syntactical features that may be used in the output.
2721     """
2722     if line.is_comment:
2723         yield line
2724         return
2725
2726     line_str = line_to_string(line)
2727
2728     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2729         """Initialize StringTransformer"""
2730         return ST(mode.line_length, mode.string_normalization)
2731
2732     string_merge = init_st(StringMerger)
2733     string_paren_strip = init_st(StringParenStripper)
2734     string_split = init_st(StringSplitter)
2735     string_paren_wrap = init_st(StringParenWrapper)
2736
2737     transformers: List[Transformer]
2738     if (
2739         not line.contains_uncollapsable_type_comments()
2740         and not line.should_split_rhs
2741         and not line.magic_trailing_comma
2742         and (
2743             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2744             or line.contains_unsplittable_type_ignore()
2745         )
2746         and not (line.inside_brackets and line.contains_standalone_comments())
2747     ):
2748         # Only apply basic string preprocessing, since lines shouldn't be split here.
2749         if mode.experimental_string_processing:
2750             transformers = [string_merge, string_paren_strip]
2751         else:
2752             transformers = []
2753     elif line.is_def:
2754         transformers = [left_hand_split]
2755     else:
2756
2757         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2758             """Wraps calls to `right_hand_split`.
2759
2760             The calls increasingly `omit` right-hand trailers (bracket pairs with
2761             content), meaning the trailers get glued together to split on another
2762             bracket pair instead.
2763             """
2764             for omit in generate_trailers_to_omit(line, mode.line_length):
2765                 lines = list(
2766                     right_hand_split(line, mode.line_length, features, omit=omit)
2767                 )
2768                 # Note: this check is only able to figure out if the first line of the
2769                 # *current* transformation fits in the line length.  This is true only
2770                 # for simple cases.  All others require running more transforms via
2771                 # `transform_line()`.  This check doesn't know if those would succeed.
2772                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2773                     yield from lines
2774                     return
2775
2776             # All splits failed, best effort split with no omits.
2777             # This mostly happens to multiline strings that are by definition
2778             # reported as not fitting a single line, as well as lines that contain
2779             # trailing commas (those have to be exploded).
2780             yield from right_hand_split(
2781                 line, line_length=mode.line_length, features=features
2782             )
2783
2784         if mode.experimental_string_processing:
2785             if line.inside_brackets:
2786                 transformers = [
2787                     string_merge,
2788                     string_paren_strip,
2789                     string_split,
2790                     delimiter_split,
2791                     standalone_comment_split,
2792                     string_paren_wrap,
2793                     rhs,
2794                 ]
2795             else:
2796                 transformers = [
2797                     string_merge,
2798                     string_paren_strip,
2799                     string_split,
2800                     string_paren_wrap,
2801                     rhs,
2802                 ]
2803         else:
2804             if line.inside_brackets:
2805                 transformers = [delimiter_split, standalone_comment_split, rhs]
2806             else:
2807                 transformers = [rhs]
2808
2809     for transform in transformers:
2810         # We are accumulating lines in `result` because we might want to abort
2811         # mission and return the original line in the end, or attempt a different
2812         # split altogether.
2813         try:
2814             result = run_transformer(line, transform, mode, features, line_str=line_str)
2815         except CannotTransform:
2816             continue
2817         else:
2818             yield from result
2819             break
2820
2821     else:
2822         yield line
2823
2824
2825 @dataclass  # type: ignore
2826 class StringTransformer(ABC):
2827     """
2828     An implementation of the Transformer protocol that relies on its
2829     subclasses overriding the template methods `do_match(...)` and
2830     `do_transform(...)`.
2831
2832     This Transformer works exclusively on strings (for example, by merging
2833     or splitting them).
2834
2835     The following sections can be found among the docstrings of each concrete
2836     StringTransformer subclass.
2837
2838     Requirements:
2839         Which requirements must be met of the given Line for this
2840         StringTransformer to be applied?
2841
2842     Transformations:
2843         If the given Line meets all of the above requirements, which string
2844         transformations can you expect to be applied to it by this
2845         StringTransformer?
2846
2847     Collaborations:
2848         What contractual agreements does this StringTransformer have with other
2849         StringTransfomers? Such collaborations should be eliminated/minimized
2850         as much as possible.
2851     """
2852
2853     line_length: int
2854     normalize_strings: bool
2855     __name__ = "StringTransformer"
2856
2857     @abstractmethod
2858     def do_match(self, line: Line) -> TMatchResult:
2859         """
2860         Returns:
2861             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2862             string, if a match was able to be made.
2863                 OR
2864             * Err(CannotTransform), if a match was not able to be made.
2865         """
2866
2867     @abstractmethod
2868     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2869         """
2870         Yields:
2871             * Ok(new_line) where new_line is the new transformed line.
2872                 OR
2873             * Err(CannotTransform) if the transformation failed for some reason. The
2874             `do_match(...)` template method should usually be used to reject
2875             the form of the given Line, but in some cases it is difficult to
2876             know whether or not a Line meets the StringTransformer's
2877             requirements until the transformation is already midway.
2878
2879         Side Effects:
2880             This method should NOT mutate @line directly, but it MAY mutate the
2881             Line's underlying Node structure. (WARNING: If the underlying Node
2882             structure IS altered, then this method should NOT be allowed to
2883             yield an CannotTransform after that point.)
2884         """
2885
2886     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2887         """
2888         StringTransformer instances have a call signature that mirrors that of
2889         the Transformer type.
2890
2891         Raises:
2892             CannotTransform(...) if the concrete StringTransformer class is unable
2893             to transform @line.
2894         """
2895         # Optimization to avoid calling `self.do_match(...)` when the line does
2896         # not contain any string.
2897         if not any(leaf.type == token.STRING for leaf in line.leaves):
2898             raise CannotTransform("There are no strings in this line.")
2899
2900         match_result = self.do_match(line)
2901
2902         if isinstance(match_result, Err):
2903             cant_transform = match_result.err()
2904             raise CannotTransform(
2905                 f"The string transformer {self.__class__.__name__} does not recognize"
2906                 " this line as one that it can transform."
2907             ) from cant_transform
2908
2909         string_idx = match_result.ok()
2910
2911         for line_result in self.do_transform(line, string_idx):
2912             if isinstance(line_result, Err):
2913                 cant_transform = line_result.err()
2914                 raise CannotTransform(
2915                     "StringTransformer failed while attempting to transform string."
2916                 ) from cant_transform
2917             line = line_result.ok()
2918             yield line
2919
2920
2921 @dataclass
2922 class CustomSplit:
2923     """A custom (i.e. manual) string split.
2924
2925     A single CustomSplit instance represents a single substring.
2926
2927     Examples:
2928         Consider the following string:
2929         ```
2930         "Hi there friend."
2931         " This is a custom"
2932         f" string {split}."
2933         ```
2934
2935         This string will correspond to the following three CustomSplit instances:
2936         ```
2937         CustomSplit(False, 16)
2938         CustomSplit(False, 17)
2939         CustomSplit(True, 16)
2940         ```
2941     """
2942
2943     has_prefix: bool
2944     break_idx: int
2945
2946
2947 class CustomSplitMapMixin:
2948     """
2949     This mixin class is used to map merged strings to a sequence of
2950     CustomSplits, which will then be used to re-split the strings iff none of
2951     the resultant substrings go over the configured max line length.
2952     """
2953
2954     _Key = Tuple[StringID, str]
2955     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2956
2957     @staticmethod
2958     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2959         """
2960         Returns:
2961             A unique identifier that is used internally to map @string to a
2962             group of custom splits.
2963         """
2964         return (id(string), string)
2965
2966     def add_custom_splits(
2967         self, string: str, custom_splits: Iterable[CustomSplit]
2968     ) -> None:
2969         """Custom Split Map Setter Method
2970
2971         Side Effects:
2972             Adds a mapping from @string to the custom splits @custom_splits.
2973         """
2974         key = self._get_key(string)
2975         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2976
2977     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2978         """Custom Split Map Getter Method
2979
2980         Returns:
2981             * A list of the custom splits that are mapped to @string, if any
2982             exist.
2983                 OR
2984             * [], otherwise.
2985
2986         Side Effects:
2987             Deletes the mapping between @string and its associated custom
2988             splits (which are returned to the caller).
2989         """
2990         key = self._get_key(string)
2991
2992         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2993         del self._CUSTOM_SPLIT_MAP[key]
2994
2995         return list(custom_splits)
2996
2997     def has_custom_splits(self, string: str) -> bool:
2998         """
2999         Returns:
3000             True iff @string is associated with a set of custom splits.
3001         """
3002         key = self._get_key(string)
3003         return key in self._CUSTOM_SPLIT_MAP
3004
3005
3006 class StringMerger(CustomSplitMapMixin, StringTransformer):
3007     """StringTransformer that merges strings together.
3008
3009     Requirements:
3010         (A) The line contains adjacent strings such that ALL of the validation checks
3011         listed in StringMerger.__validate_msg(...)'s docstring pass.
3012             OR
3013         (B) The line contains a string which uses line continuation backslashes.
3014
3015     Transformations:
3016         Depending on which of the two requirements above where met, either:
3017
3018         (A) The string group associated with the target string is merged.
3019             OR
3020         (B) All line-continuation backslashes are removed from the target string.
3021
3022     Collaborations:
3023         StringMerger provides custom split information to StringSplitter.
3024     """
3025
3026     def do_match(self, line: Line) -> TMatchResult:
3027         LL = line.leaves
3028
3029         is_valid_index = is_valid_index_factory(LL)
3030
3031         for (i, leaf) in enumerate(LL):
3032             if (
3033                 leaf.type == token.STRING
3034                 and is_valid_index(i + 1)
3035                 and LL[i + 1].type == token.STRING
3036             ):
3037                 return Ok(i)
3038
3039             if leaf.type == token.STRING and "\\\n" in leaf.value:
3040                 return Ok(i)
3041
3042         return TErr("This line has no strings that need merging.")
3043
3044     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3045         new_line = line
3046         rblc_result = self.__remove_backslash_line_continuation_chars(
3047             new_line, string_idx
3048         )
3049         if isinstance(rblc_result, Ok):
3050             new_line = rblc_result.ok()
3051
3052         msg_result = self.__merge_string_group(new_line, string_idx)
3053         if isinstance(msg_result, Ok):
3054             new_line = msg_result.ok()
3055
3056         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3057             msg_cant_transform = msg_result.err()
3058             rblc_cant_transform = rblc_result.err()
3059             cant_transform = CannotTransform(
3060                 "StringMerger failed to merge any strings in this line."
3061             )
3062
3063             # Chain the errors together using `__cause__`.
3064             msg_cant_transform.__cause__ = rblc_cant_transform
3065             cant_transform.__cause__ = msg_cant_transform
3066
3067             yield Err(cant_transform)
3068         else:
3069             yield Ok(new_line)
3070
3071     @staticmethod
3072     def __remove_backslash_line_continuation_chars(
3073         line: Line, string_idx: int
3074     ) -> TResult[Line]:
3075         """
3076         Merge strings that were split across multiple lines using
3077         line-continuation backslashes.
3078
3079         Returns:
3080             Ok(new_line), if @line contains backslash line-continuation
3081             characters.
3082                 OR
3083             Err(CannotTransform), otherwise.
3084         """
3085         LL = line.leaves
3086
3087         string_leaf = LL[string_idx]
3088         if not (
3089             string_leaf.type == token.STRING
3090             and "\\\n" in string_leaf.value
3091             and not has_triple_quotes(string_leaf.value)
3092         ):
3093             return TErr(
3094                 f"String leaf {string_leaf} does not contain any backslash line"
3095                 " continuation characters."
3096             )
3097
3098         new_line = line.clone()
3099         new_line.comments = line.comments.copy()
3100         append_leaves(new_line, line, LL)
3101
3102         new_string_leaf = new_line.leaves[string_idx]
3103         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3104
3105         return Ok(new_line)
3106
3107     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3108         """
3109         Merges string group (i.e. set of adjacent strings) where the first
3110         string in the group is `line.leaves[string_idx]`.
3111
3112         Returns:
3113             Ok(new_line), if ALL of the validation checks found in
3114             __validate_msg(...) pass.
3115                 OR
3116             Err(CannotTransform), otherwise.
3117         """
3118         LL = line.leaves
3119
3120         is_valid_index = is_valid_index_factory(LL)
3121
3122         vresult = self.__validate_msg(line, string_idx)
3123         if isinstance(vresult, Err):
3124             return vresult
3125
3126         # If the string group is wrapped inside an Atom node, we must make sure
3127         # to later replace that Atom with our new (merged) string leaf.
3128         atom_node = LL[string_idx].parent
3129
3130         # We will place BREAK_MARK in between every two substrings that we
3131         # merge. We will then later go through our final result and use the
3132         # various instances of BREAK_MARK we find to add the right values to
3133         # the custom split map.
3134         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3135
3136         QUOTE = LL[string_idx].value[-1]
3137
3138         def make_naked(string: str, string_prefix: str) -> str:
3139             """Strip @string (i.e. make it a "naked" string)
3140
3141             Pre-conditions:
3142                 * assert_is_leaf_string(@string)
3143
3144             Returns:
3145                 A string that is identical to @string except that
3146                 @string_prefix has been stripped, the surrounding QUOTE
3147                 characters have been removed, and any remaining QUOTE
3148                 characters have been escaped.
3149             """
3150             assert_is_leaf_string(string)
3151
3152             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3153             naked_string = string[len(string_prefix) + 1 : -1]
3154             naked_string = re.sub(
3155                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3156             )
3157             return naked_string
3158
3159         # Holds the CustomSplit objects that will later be added to the custom
3160         # split map.
3161         custom_splits = []
3162
3163         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3164         prefix_tracker = []
3165
3166         # Sets the 'prefix' variable. This is the prefix that the final merged
3167         # string will have.
3168         next_str_idx = string_idx
3169         prefix = ""
3170         while (
3171             not prefix
3172             and is_valid_index(next_str_idx)
3173             and LL[next_str_idx].type == token.STRING
3174         ):
3175             prefix = get_string_prefix(LL[next_str_idx].value)
3176             next_str_idx += 1
3177
3178         # The next loop merges the string group. The final string will be
3179         # contained in 'S'.
3180         #
3181         # The following convenience variables are used:
3182         #
3183         #   S: string
3184         #   NS: naked string
3185         #   SS: next string
3186         #   NSS: naked next string
3187         S = ""
3188         NS = ""
3189         num_of_strings = 0
3190         next_str_idx = string_idx
3191         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3192             num_of_strings += 1
3193
3194             SS = LL[next_str_idx].value
3195             next_prefix = get_string_prefix(SS)
3196
3197             # If this is an f-string group but this substring is not prefixed
3198             # with 'f'...
3199             if "f" in prefix and "f" not in next_prefix:
3200                 # Then we must escape any braces contained in this substring.
3201                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3202
3203             NSS = make_naked(SS, next_prefix)
3204
3205             has_prefix = bool(next_prefix)
3206             prefix_tracker.append(has_prefix)
3207
3208             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3209             NS = make_naked(S, prefix)
3210
3211             next_str_idx += 1
3212
3213         S_leaf = Leaf(token.STRING, S)
3214         if self.normalize_strings:
3215             normalize_string_quotes(S_leaf)
3216
3217         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3218         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3219         for has_prefix in prefix_tracker:
3220             mark_idx = temp_string.find(BREAK_MARK)
3221             assert (
3222                 mark_idx >= 0
3223             ), "Logic error while filling the custom string breakpoint cache."
3224
3225             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3226             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3227             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3228
3229         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3230
3231         if atom_node is not None:
3232             replace_child(atom_node, string_leaf)
3233
3234         # Build the final line ('new_line') that this method will later return.
3235         new_line = line.clone()
3236         for (i, leaf) in enumerate(LL):
3237             if i == string_idx:
3238                 new_line.append(string_leaf)
3239
3240             if string_idx <= i < string_idx + num_of_strings:
3241                 for comment_leaf in line.comments_after(LL[i]):
3242                     new_line.append(comment_leaf, preformatted=True)
3243                 continue
3244
3245             append_leaves(new_line, line, [leaf])
3246
3247         self.add_custom_splits(string_leaf.value, custom_splits)
3248         return Ok(new_line)
3249
3250     @staticmethod
3251     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3252         """Validate (M)erge (S)tring (G)roup
3253
3254         Transform-time string validation logic for __merge_string_group(...).
3255
3256         Returns:
3257             * Ok(None), if ALL validation checks (listed below) pass.
3258                 OR
3259             * Err(CannotTransform), if any of the following are true:
3260                 - The target string group does not contain ANY stand-alone comments.
3261                 - The target string is not in a string group (i.e. it has no
3262                   adjacent strings).
3263                 - The string group has more than one inline comment.
3264                 - The string group has an inline comment that appears to be a pragma.
3265                 - The set of all string prefixes in the string group is of
3266                   length greater than one and is not equal to {"", "f"}.
3267                 - The string group consists of raw strings.
3268         """
3269         # We first check for "inner" stand-alone comments (i.e. stand-alone
3270         # comments that have a string leaf before them AND after them).
3271         for inc in [1, -1]:
3272             i = string_idx
3273             found_sa_comment = False
3274             is_valid_index = is_valid_index_factory(line.leaves)
3275             while is_valid_index(i) and line.leaves[i].type in [
3276                 token.STRING,
3277                 STANDALONE_COMMENT,
3278             ]:
3279                 if line.leaves[i].type == STANDALONE_COMMENT:
3280                     found_sa_comment = True
3281                 elif found_sa_comment:
3282                     return TErr(
3283                         "StringMerger does NOT merge string groups which contain "
3284                         "stand-alone comments."
3285                     )
3286
3287                 i += inc
3288
3289         num_of_inline_string_comments = 0
3290         set_of_prefixes = set()
3291         num_of_strings = 0
3292         for leaf in line.leaves[string_idx:]:
3293             if leaf.type != token.STRING:
3294                 # If the string group is trailed by a comma, we count the
3295                 # comments trailing the comma to be one of the string group's
3296                 # comments.
3297                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3298                     num_of_inline_string_comments += 1
3299                 break
3300
3301             if has_triple_quotes(leaf.value):
3302                 return TErr("StringMerger does NOT merge multiline strings.")
3303
3304             num_of_strings += 1
3305             prefix = get_string_prefix(leaf.value)
3306             if "r" in prefix:
3307                 return TErr("StringMerger does NOT merge raw strings.")
3308
3309             set_of_prefixes.add(prefix)
3310
3311             if id(leaf) in line.comments:
3312                 num_of_inline_string_comments += 1
3313                 if contains_pragma_comment(line.comments[id(leaf)]):
3314                     return TErr("Cannot merge strings which have pragma comments.")
3315
3316         if num_of_strings < 2:
3317             return TErr(
3318                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3319             )
3320
3321         if num_of_inline_string_comments > 1:
3322             return TErr(
3323                 f"Too many inline string comments ({num_of_inline_string_comments})."
3324             )
3325
3326         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3327             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3328
3329         return Ok(None)
3330
3331
3332 class StringParenStripper(StringTransformer):
3333     """StringTransformer that strips surrounding parentheses from strings.
3334
3335     Requirements:
3336         The line contains a string which is surrounded by parentheses and:
3337             - The target string is NOT the only argument to a function call.
3338             - The target string is NOT a "pointless" string.
3339             - If the target string contains a PERCENT, the brackets are not
3340               preceeded or followed by an operator with higher precedence than
3341               PERCENT.
3342
3343     Transformations:
3344         The parentheses mentioned in the 'Requirements' section are stripped.
3345
3346     Collaborations:
3347         StringParenStripper has its own inherent usefulness, but it is also
3348         relied on to clean up the parentheses created by StringParenWrapper (in
3349         the event that they are no longer needed).
3350     """
3351
3352     def do_match(self, line: Line) -> TMatchResult:
3353         LL = line.leaves
3354
3355         is_valid_index = is_valid_index_factory(LL)
3356
3357         for (idx, leaf) in enumerate(LL):
3358             # Should be a string...
3359             if leaf.type != token.STRING:
3360                 continue
3361
3362             # If this is a "pointless" string...
3363             if (
3364                 leaf.parent
3365                 and leaf.parent.parent
3366                 and leaf.parent.parent.type == syms.simple_stmt
3367             ):
3368                 continue
3369
3370             # Should be preceded by a non-empty LPAR...
3371             if (
3372                 not is_valid_index(idx - 1)
3373                 or LL[idx - 1].type != token.LPAR
3374                 or is_empty_lpar(LL[idx - 1])
3375             ):
3376                 continue
3377
3378             # That LPAR should NOT be preceded by a function name or a closing
3379             # bracket (which could be a function which returns a function or a
3380             # list/dictionary that contains a function)...
3381             if is_valid_index(idx - 2) and (
3382                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3383             ):
3384                 continue
3385
3386             string_idx = idx
3387
3388             # Skip the string trailer, if one exists.
3389             string_parser = StringParser()
3390             next_idx = string_parser.parse(LL, string_idx)
3391
3392             # if the leaves in the parsed string include a PERCENT, we need to
3393             # make sure the initial LPAR is NOT preceded by an operator with
3394             # higher or equal precedence to PERCENT
3395             if is_valid_index(idx - 2):
3396                 # mypy can't quite follow unless we name this
3397                 before_lpar = LL[idx - 2]
3398                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3399                     (
3400                         before_lpar.type
3401                         in {
3402                             token.STAR,
3403                             token.AT,
3404                             token.SLASH,
3405                             token.DOUBLESLASH,
3406                             token.PERCENT,
3407                             token.TILDE,
3408                             token.DOUBLESTAR,
3409                             token.AWAIT,
3410                             token.LSQB,
3411                             token.LPAR,
3412                         }
3413                     )
3414                     or (
3415                         # only unary PLUS/MINUS
3416                         before_lpar.parent
3417                         and before_lpar.parent.type == syms.factor
3418                         and (before_lpar.type in {token.PLUS, token.MINUS})
3419                     )
3420                 ):
3421                     continue
3422
3423             # Should be followed by a non-empty RPAR...
3424             if (
3425                 is_valid_index(next_idx)
3426                 and LL[next_idx].type == token.RPAR
3427                 and not is_empty_rpar(LL[next_idx])
3428             ):
3429                 # That RPAR should NOT be followed by anything with higher
3430                 # precedence than PERCENT
3431                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3432                     token.DOUBLESTAR,
3433                     token.LSQB,
3434                     token.LPAR,
3435                     token.DOT,
3436                 }:
3437                     continue
3438
3439                 return Ok(string_idx)
3440
3441         return TErr("This line has no strings wrapped in parens.")
3442
3443     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3444         LL = line.leaves
3445
3446         string_parser = StringParser()
3447         rpar_idx = string_parser.parse(LL, string_idx)
3448
3449         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3450             if line.comments_after(leaf):
3451                 yield TErr(
3452                     "Will not strip parentheses which have comments attached to them."
3453                 )
3454                 return
3455
3456         new_line = line.clone()
3457         new_line.comments = line.comments.copy()
3458         try:
3459             append_leaves(new_line, line, LL[: string_idx - 1])
3460         except BracketMatchError:
3461             # HACK: I believe there is currently a bug somewhere in
3462             # right_hand_split() that is causing brackets to not be tracked
3463             # properly by a shared BracketTracker.
3464             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3465
3466         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3467         LL[string_idx - 1].remove()
3468         replace_child(LL[string_idx], string_leaf)
3469         new_line.append(string_leaf)
3470
3471         append_leaves(
3472             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3473         )
3474
3475         LL[rpar_idx].remove()
3476
3477         yield Ok(new_line)
3478
3479
3480 class BaseStringSplitter(StringTransformer):
3481     """
3482     Abstract class for StringTransformers which transform a Line's strings by splitting
3483     them or placing them on their own lines where necessary to avoid going over
3484     the configured line length.
3485
3486     Requirements:
3487         * The target string value is responsible for the line going over the
3488         line length limit. It follows that after all of black's other line
3489         split methods have been exhausted, this line (or one of the resulting
3490         lines after all line splits are performed) would still be over the
3491         line_length limit unless we split this string.
3492             AND
3493         * The target string is NOT a "pointless" string (i.e. a string that has
3494         no parent or siblings).
3495             AND
3496         * The target string is not followed by an inline comment that appears
3497         to be a pragma.
3498             AND
3499         * The target string is not a multiline (i.e. triple-quote) string.
3500     """
3501
3502     @abstractmethod
3503     def do_splitter_match(self, line: Line) -> TMatchResult:
3504         """
3505         BaseStringSplitter asks its clients to override this method instead of
3506         `StringTransformer.do_match(...)`.
3507
3508         Follows the same protocol as `StringTransformer.do_match(...)`.
3509
3510         Refer to `help(StringTransformer.do_match)` for more information.
3511         """
3512
3513     def do_match(self, line: Line) -> TMatchResult:
3514         match_result = self.do_splitter_match(line)
3515         if isinstance(match_result, Err):
3516             return match_result
3517
3518         string_idx = match_result.ok()
3519         vresult = self.__validate(line, string_idx)
3520         if isinstance(vresult, Err):
3521             return vresult
3522
3523         return match_result
3524
3525     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3526         """
3527         Checks that @line meets all of the requirements listed in this classes'
3528         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3529         description of those requirements.
3530
3531         Returns:
3532             * Ok(None), if ALL of the requirements are met.
3533                 OR
3534             * Err(CannotTransform), if ANY of the requirements are NOT met.
3535         """
3536         LL = line.leaves
3537
3538         string_leaf = LL[string_idx]
3539
3540         max_string_length = self.__get_max_string_length(line, string_idx)
3541         if len(string_leaf.value) <= max_string_length:
3542             return TErr(
3543                 "The string itself is not what is causing this line to be too long."
3544             )
3545
3546         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3547             token.STRING,
3548             token.NEWLINE,
3549         ]:
3550             return TErr(
3551                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3552                 " no parent)."
3553             )
3554
3555         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3556             line.comments[id(line.leaves[string_idx])]
3557         ):
3558             return TErr(
3559                 "Line appears to end with an inline pragma comment. Splitting the line"
3560                 " could modify the pragma's behavior."
3561             )
3562
3563         if has_triple_quotes(string_leaf.value):
3564             return TErr("We cannot split multiline strings.")
3565
3566         return Ok(None)
3567
3568     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3569         """
3570         Calculates the max string length used when attempting to determine
3571         whether or not the target string is responsible for causing the line to
3572         go over the line length limit.
3573
3574         WARNING: This method is tightly coupled to both StringSplitter and
3575         (especially) StringParenWrapper. There is probably a better way to
3576         accomplish what is being done here.
3577
3578         Returns:
3579             max_string_length: such that `line.leaves[string_idx].value >
3580             max_string_length` implies that the target string IS responsible
3581             for causing this line to exceed the line length limit.
3582         """
3583         LL = line.leaves
3584
3585         is_valid_index = is_valid_index_factory(LL)
3586
3587         # We use the shorthand "WMA4" in comments to abbreviate "We must
3588         # account for". When giving examples, we use STRING to mean some/any
3589         # valid string.
3590         #
3591         # Finally, we use the following convenience variables:
3592         #
3593         #   P:  The leaf that is before the target string leaf.
3594         #   N:  The leaf that is after the target string leaf.
3595         #   NN: The leaf that is after N.
3596
3597         # WMA4 the whitespace at the beginning of the line.
3598         offset = line.depth * 4
3599
3600         if is_valid_index(string_idx - 1):
3601             p_idx = string_idx - 1
3602             if (
3603                 LL[string_idx - 1].type == token.LPAR
3604                 and LL[string_idx - 1].value == ""
3605                 and string_idx >= 2
3606             ):
3607                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3608                 p_idx -= 1
3609
3610             P = LL[p_idx]
3611             if P.type == token.PLUS:
3612                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3613                 offset += 2
3614
3615             if P.type == token.COMMA:
3616                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3617                 offset += 3
3618
3619             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3620                 # This conditional branch is meant to handle dictionary keys,
3621                 # variable assignments, 'return STRING' statement lines, and
3622                 # 'else STRING' ternary expression lines.
3623
3624                 # WMA4 a single space.
3625                 offset += 1
3626
3627                 # WMA4 the lengths of any leaves that came before that space,
3628                 # but after any closing bracket before that space.
3629                 for leaf in reversed(LL[: p_idx + 1]):
3630                     offset += len(str(leaf))
3631                     if leaf.type in CLOSING_BRACKETS:
3632                         break
3633
3634         if is_valid_index(string_idx + 1):
3635             N = LL[string_idx + 1]
3636             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3637                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3638                 N = LL[string_idx + 2]
3639
3640             if N.type == token.COMMA:
3641                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3642                 offset += 1
3643
3644             if is_valid_index(string_idx + 2):
3645                 NN = LL[string_idx + 2]
3646
3647                 if N.type == token.DOT and NN.type == token.NAME:
3648                     # This conditional branch is meant to handle method calls invoked
3649                     # off of a string literal up to and including the LPAR character.
3650
3651                     # WMA4 the '.' character.
3652                     offset += 1
3653
3654                     if (
3655                         is_valid_index(string_idx + 3)
3656                         and LL[string_idx + 3].type == token.LPAR
3657                     ):
3658                         # WMA4 the left parenthesis character.
3659                         offset += 1
3660
3661                     # WMA4 the length of the method's name.
3662                     offset += len(NN.value)
3663
3664         has_comments = False
3665         for comment_leaf in line.comments_after(LL[string_idx]):
3666             if not has_comments:
3667                 has_comments = True
3668                 # WMA4 two spaces before the '#' character.
3669                 offset += 2
3670
3671             # WMA4 the length of the inline comment.
3672             offset += len(comment_leaf.value)
3673
3674         max_string_length = self.line_length - offset
3675         return max_string_length
3676
3677
3678 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3679     """
3680     StringTransformer that splits "atom" strings (i.e. strings which exist on
3681     lines by themselves).
3682
3683     Requirements:
3684         * The line consists ONLY of a single string (with the exception of a
3685         '+' symbol which MAY exist at the start of the line), MAYBE a string
3686         trailer, and MAYBE a trailing comma.
3687             AND
3688         * All of the requirements listed in BaseStringSplitter's docstring.
3689
3690     Transformations:
3691         The string mentioned in the 'Requirements' section is split into as
3692         many substrings as necessary to adhere to the configured line length.
3693
3694         In the final set of substrings, no substring should be smaller than
3695         MIN_SUBSTR_SIZE characters.
3696
3697         The string will ONLY be split on spaces (i.e. each new substring should
3698         start with a space). Note that the string will NOT be split on a space
3699         which is escaped with a backslash.
3700
3701         If the string is an f-string, it will NOT be split in the middle of an
3702         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3703         else bar()} is an f-expression).
3704
3705         If the string that is being split has an associated set of custom split
3706         records and those custom splits will NOT result in any line going over
3707         the configured line length, those custom splits are used. Otherwise the
3708         string is split as late as possible (from left-to-right) while still
3709         adhering to the transformation rules listed above.
3710
3711     Collaborations:
3712         StringSplitter relies on StringMerger to construct the appropriate
3713         CustomSplit objects and add them to the custom split map.
3714     """
3715
3716     MIN_SUBSTR_SIZE = 6
3717     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3718     RE_FEXPR = r"""
3719     (?<!\{) (?:\{\{)* \{ (?!\{)
3720         (?:
3721             [^\{\}]
3722             | \{\{
3723             | \}\}
3724             | (?R)
3725         )+?
3726     (?<!\}) \} (?:\}\})* (?!\})
3727     """
3728
3729     def do_splitter_match(self, line: Line) -> TMatchResult:
3730         LL = line.leaves
3731
3732         is_valid_index = is_valid_index_factory(LL)
3733
3734         idx = 0
3735
3736         # The first leaf MAY be a '+' symbol...
3737         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3738             idx += 1
3739
3740         # The next/first leaf MAY be an empty LPAR...
3741         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3742             idx += 1
3743
3744         # The next/first leaf MUST be a string...
3745         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3746             return TErr("Line does not start with a string.")
3747
3748         string_idx = idx
3749
3750         # Skip the string trailer, if one exists.
3751         string_parser = StringParser()
3752         idx = string_parser.parse(LL, string_idx)
3753
3754         # That string MAY be followed by an empty RPAR...
3755         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3756             idx += 1
3757
3758         # That string / empty RPAR leaf MAY be followed by a comma...
3759         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3760             idx += 1
3761
3762         # But no more leaves are allowed...
3763         if is_valid_index(idx):
3764             return TErr("This line does not end with a string.")
3765
3766         return Ok(string_idx)
3767
3768     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3769         LL = line.leaves
3770
3771         QUOTE = LL[string_idx].value[-1]
3772
3773         is_valid_index = is_valid_index_factory(LL)
3774         insert_str_child = insert_str_child_factory(LL[string_idx])
3775
3776         prefix = get_string_prefix(LL[string_idx].value)
3777
3778         # We MAY choose to drop the 'f' prefix from substrings that don't
3779         # contain any f-expressions, but ONLY if the original f-string
3780         # contains at least one f-expression. Otherwise, we will alter the AST
3781         # of the program.
3782         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3783             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3784         )
3785
3786         first_string_line = True
3787         starts_with_plus = LL[0].type == token.PLUS
3788
3789         def line_needs_plus() -> bool:
3790             return first_string_line and starts_with_plus
3791
3792         def maybe_append_plus(new_line: Line) -> None:
3793             """
3794             Side Effects:
3795                 If @line starts with a plus and this is the first line we are
3796                 constructing, this function appends a PLUS leaf to @new_line
3797                 and replaces the old PLUS leaf in the node structure. Otherwise
3798                 this function does nothing.
3799             """
3800             if line_needs_plus():
3801                 plus_leaf = Leaf(token.PLUS, "+")
3802                 replace_child(LL[0], plus_leaf)
3803                 new_line.append(plus_leaf)
3804
3805         ends_with_comma = (
3806             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3807         )
3808
3809         def max_last_string() -> int:
3810             """
3811             Returns:
3812                 The max allowed length of the string value used for the last
3813                 line we will construct.
3814             """
3815             result = self.line_length
3816             result -= line.depth * 4
3817             result -= 1 if ends_with_comma else 0
3818             result -= 2 if line_needs_plus() else 0
3819             return result
3820
3821         # --- Calculate Max Break Index (for string value)
3822         # We start with the line length limit
3823         max_break_idx = self.line_length
3824         # The last index of a string of length N is N-1.
3825         max_break_idx -= 1
3826         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3827         max_break_idx -= line.depth * 4
3828         if max_break_idx < 0:
3829             yield TErr(
3830                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3831                 f" {line.depth}"
3832             )
3833             return
3834
3835         # Check if StringMerger registered any custom splits.
3836         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3837         # We use them ONLY if none of them would produce lines that exceed the
3838         # line limit.
3839         use_custom_breakpoints = bool(
3840             custom_splits
3841             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3842         )
3843
3844         # Temporary storage for the remaining chunk of the string line that
3845         # can't fit onto the line currently being constructed.
3846         rest_value = LL[string_idx].value
3847
3848         def more_splits_should_be_made() -> bool:
3849             """
3850             Returns:
3851                 True iff `rest_value` (the remaining string value from the last
3852                 split), should be split again.
3853             """
3854             if use_custom_breakpoints:
3855                 return len(custom_splits) > 1
3856             else:
3857                 return len(rest_value) > max_last_string()
3858
3859         string_line_results: List[Ok[Line]] = []
3860         while more_splits_should_be_made():
3861             if use_custom_breakpoints:
3862                 # Custom User Split (manual)
3863                 csplit = custom_splits.pop(0)
3864                 break_idx = csplit.break_idx
3865             else:
3866                 # Algorithmic Split (automatic)
3867                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3868                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3869                 if maybe_break_idx is None:
3870                     # If we are unable to algorithmically determine a good split
3871                     # and this string has custom splits registered to it, we
3872                     # fall back to using them--which means we have to start
3873                     # over from the beginning.
3874                     if custom_splits:
3875                         rest_value = LL[string_idx].value
3876                         string_line_results = []
3877                         first_string_line = True
3878                         use_custom_breakpoints = True
3879                         continue
3880
3881                     # Otherwise, we stop splitting here.
3882                     break
3883
3884                 break_idx = maybe_break_idx
3885
3886             # --- Construct `next_value`
3887             next_value = rest_value[:break_idx] + QUOTE
3888             if (
3889                 # Are we allowed to try to drop a pointless 'f' prefix?
3890                 drop_pointless_f_prefix
3891                 # If we are, will we be successful?
3892                 and next_value != self.__normalize_f_string(next_value, prefix)
3893             ):
3894                 # If the current custom split did NOT originally use a prefix,
3895                 # then `csplit.break_idx` will be off by one after removing
3896                 # the 'f' prefix.
3897                 break_idx = (
3898                     break_idx + 1
3899                     if use_custom_breakpoints and not csplit.has_prefix
3900                     else break_idx
3901                 )
3902                 next_value = rest_value[:break_idx] + QUOTE
3903                 next_value = self.__normalize_f_string(next_value, prefix)
3904
3905             # --- Construct `next_leaf`
3906             next_leaf = Leaf(token.STRING, next_value)
3907             insert_str_child(next_leaf)
3908             self.__maybe_normalize_string_quotes(next_leaf)
3909
3910             # --- Construct `next_line`
3911             next_line = line.clone()
3912             maybe_append_plus(next_line)
3913             next_line.append(next_leaf)
3914             string_line_results.append(Ok(next_line))
3915
3916             rest_value = prefix + QUOTE + rest_value[break_idx:]
3917             first_string_line = False
3918
3919         yield from string_line_results
3920
3921         if drop_pointless_f_prefix:
3922             rest_value = self.__normalize_f_string(rest_value, prefix)
3923
3924         rest_leaf = Leaf(token.STRING, rest_value)
3925         insert_str_child(rest_leaf)
3926
3927         # NOTE: I could not find a test case that verifies that the following
3928         # line is actually necessary, but it seems to be. Otherwise we risk
3929         # not normalizing the last substring, right?
3930         self.__maybe_normalize_string_quotes(rest_leaf)
3931
3932         last_line = line.clone()
3933         maybe_append_plus(last_line)
3934
3935         # If there are any leaves to the right of the target string...
3936         if is_valid_index(string_idx + 1):
3937             # We use `temp_value` here to determine how long the last line
3938             # would be if we were to append all the leaves to the right of the
3939             # target string to the last string line.
3940             temp_value = rest_value
3941             for leaf in LL[string_idx + 1 :]:
3942                 temp_value += str(leaf)
3943                 if leaf.type == token.LPAR:
3944                     break
3945
3946             # Try to fit them all on the same line with the last substring...
3947             if (
3948                 len(temp_value) <= max_last_string()
3949                 or LL[string_idx + 1].type == token.COMMA
3950             ):
3951                 last_line.append(rest_leaf)
3952                 append_leaves(last_line, line, LL[string_idx + 1 :])
3953                 yield Ok(last_line)
3954             # Otherwise, place the last substring on one line and everything
3955             # else on a line below that...
3956             else:
3957                 last_line.append(rest_leaf)
3958                 yield Ok(last_line)
3959
3960                 non_string_line = line.clone()
3961                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3962                 yield Ok(non_string_line)
3963         # Else the target string was the last leaf...
3964         else:
3965             last_line.append(rest_leaf)
3966             last_line.comments = line.comments.copy()
3967             yield Ok(last_line)
3968
3969     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3970         """
3971         This method contains the algorithm that StringSplitter uses to
3972         determine which character to split each string at.
3973
3974         Args:
3975             @string: The substring that we are attempting to split.
3976             @max_break_idx: The ideal break index. We will return this value if it
3977             meets all the necessary conditions. In the likely event that it
3978             doesn't we will try to find the closest index BELOW @max_break_idx
3979             that does. If that fails, we will expand our search by also
3980             considering all valid indices ABOVE @max_break_idx.
3981
3982         Pre-Conditions:
3983             * assert_is_leaf_string(@string)
3984             * 0 <= @max_break_idx < len(@string)
3985
3986         Returns:
3987             break_idx, if an index is able to be found that meets all of the
3988             conditions listed in the 'Transformations' section of this classes'
3989             docstring.
3990                 OR
3991             None, otherwise.
3992         """
3993         is_valid_index = is_valid_index_factory(string)
3994
3995         assert is_valid_index(max_break_idx)
3996         assert_is_leaf_string(string)
3997
3998         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3999
4000         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
4001             """
4002             Yields:
4003                 All ranges of @string which, if @string were to be split there,
4004                 would result in the splitting of an f-expression (which is NOT
4005                 allowed).
4006             """
4007             nonlocal _fexpr_slices
4008
4009             if _fexpr_slices is None:
4010                 _fexpr_slices = []
4011                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
4012                     _fexpr_slices.append(match.span())
4013
4014             yield from _fexpr_slices
4015
4016         is_fstring = "f" in get_string_prefix(string)
4017
4018         def breaks_fstring_expression(i: Index) -> bool:
4019             """
4020             Returns:
4021                 True iff returning @i would result in the splitting of an
4022                 f-expression (which is NOT allowed).
4023             """
4024             if not is_fstring:
4025                 return False
4026
4027             for (start, end) in fexpr_slices():
4028                 if start <= i < end:
4029                     return True
4030
4031             return False
4032
4033         def passes_all_checks(i: Index) -> bool:
4034             """
4035             Returns:
4036                 True iff ALL of the conditions listed in the 'Transformations'
4037                 section of this classes' docstring would be be met by returning @i.
4038             """
4039             is_space = string[i] == " "
4040
4041             is_not_escaped = True
4042             j = i - 1
4043             while is_valid_index(j) and string[j] == "\\":
4044                 is_not_escaped = not is_not_escaped
4045                 j -= 1
4046
4047             is_big_enough = (
4048                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
4049                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
4050             )
4051             return (
4052                 is_space
4053                 and is_not_escaped
4054                 and is_big_enough
4055                 and not breaks_fstring_expression(i)
4056             )
4057
4058         # First, we check all indices BELOW @max_break_idx.
4059         break_idx = max_break_idx
4060         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
4061             break_idx -= 1
4062
4063         if not passes_all_checks(break_idx):
4064             # If that fails, we check all indices ABOVE @max_break_idx.
4065             #
4066             # If we are able to find a valid index here, the next line is going
4067             # to be longer than the specified line length, but it's probably
4068             # better than doing nothing at all.
4069             break_idx = max_break_idx + 1
4070             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
4071                 break_idx += 1
4072
4073             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
4074                 return None
4075
4076         return break_idx
4077
4078     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
4079         if self.normalize_strings:
4080             normalize_string_quotes(leaf)
4081
4082     def __normalize_f_string(self, string: str, prefix: str) -> str:
4083         """
4084         Pre-Conditions:
4085             * assert_is_leaf_string(@string)
4086
4087         Returns:
4088             * If @string is an f-string that contains no f-expressions, we
4089             return a string identical to @string except that the 'f' prefix
4090             has been stripped and all double braces (i.e. '{{' or '}}') have
4091             been normalized (i.e. turned into '{' or '}').
4092                 OR
4093             * Otherwise, we return @string.
4094         """
4095         assert_is_leaf_string(string)
4096
4097         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
4098             new_prefix = prefix.replace("f", "")
4099
4100             temp = string[len(prefix) :]
4101             temp = re.sub(r"\{\{", "{", temp)
4102             temp = re.sub(r"\}\}", "}", temp)
4103             new_string = temp
4104
4105             return f"{new_prefix}{new_string}"
4106         else:
4107             return string
4108
4109
4110 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4111     """
4112     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4113     exist on lines by themselves).
4114
4115     Requirements:
4116         All of the requirements listed in BaseStringSplitter's docstring in
4117         addition to the requirements listed below:
4118
4119         * The line is a return/yield statement, which returns/yields a string.
4120             OR
4121         * The line is part of a ternary expression (e.g. `x = y if cond else
4122         z`) such that the line starts with `else <string>`, where <string> is
4123         some string.
4124             OR
4125         * The line is an assert statement, which ends with a string.
4126             OR
4127         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4128         <string>`) such that the variable is being assigned the value of some
4129         string.
4130             OR
4131         * The line is a dictionary key assignment where some valid key is being
4132         assigned the value of some string.
4133
4134     Transformations:
4135         The chosen string is wrapped in parentheses and then split at the LPAR.
4136
4137         We then have one line which ends with an LPAR and another line that
4138         starts with the chosen string. The latter line is then split again at
4139         the RPAR. This results in the RPAR (and possibly a trailing comma)
4140         being placed on its own line.
4141
4142         NOTE: If any leaves exist to the right of the chosen string (except
4143         for a trailing comma, which would be placed after the RPAR), those
4144         leaves are placed inside the parentheses.  In effect, the chosen
4145         string is not necessarily being "wrapped" by parentheses. We can,
4146         however, count on the LPAR being placed directly before the chosen
4147         string.
4148
4149         In other words, StringParenWrapper creates "atom" strings. These
4150         can then be split again by StringSplitter, if necessary.
4151
4152     Collaborations:
4153         In the event that a string line split by StringParenWrapper is
4154         changed such that it no longer needs to be given its own line,
4155         StringParenWrapper relies on StringParenStripper to clean up the
4156         parentheses it created.
4157     """
4158
4159     def do_splitter_match(self, line: Line) -> TMatchResult:
4160         LL = line.leaves
4161
4162         string_idx = (
4163             self._return_match(LL)
4164             or self._else_match(LL)
4165             or self._assert_match(LL)
4166             or self._assign_match(LL)
4167             or self._dict_match(LL)
4168         )
4169
4170         if string_idx is not None:
4171             string_value = line.leaves[string_idx].value
4172             # If the string has no spaces...
4173             if " " not in string_value:
4174                 # And will still violate the line length limit when split...
4175                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4176                 if len(string_value) > max_string_length:
4177                     # And has no associated custom splits...
4178                     if not self.has_custom_splits(string_value):
4179                         # Then we should NOT put this string on its own line.
4180                         return TErr(
4181                             "We do not wrap long strings in parentheses when the"
4182                             " resultant line would still be over the specified line"
4183                             " length and can't be split further by StringSplitter."
4184                         )
4185             return Ok(string_idx)
4186
4187         return TErr("This line does not contain any non-atomic strings.")
4188
4189     @staticmethod
4190     def _return_match(LL: List[Leaf]) -> Optional[int]:
4191         """
4192         Returns:
4193             string_idx such that @LL[string_idx] is equal to our target (i.e.
4194             matched) string, if this line matches the return/yield statement
4195             requirements listed in the 'Requirements' section of this classes'
4196             docstring.
4197                 OR
4198             None, otherwise.
4199         """
4200         # If this line is apart of a return/yield statement and the first leaf
4201         # contains either the "return" or "yield" keywords...
4202         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4203             0
4204         ].value in ["return", "yield"]:
4205             is_valid_index = is_valid_index_factory(LL)
4206
4207             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4208             # The next visible leaf MUST contain a string...
4209             if is_valid_index(idx) and LL[idx].type == token.STRING:
4210                 return idx
4211
4212         return None
4213
4214     @staticmethod
4215     def _else_match(LL: List[Leaf]) -> Optional[int]:
4216         """
4217         Returns:
4218             string_idx such that @LL[string_idx] is equal to our target (i.e.
4219             matched) string, if this line matches the ternary expression
4220             requirements listed in the 'Requirements' section of this classes'
4221             docstring.
4222                 OR
4223             None, otherwise.
4224         """
4225         # If this line is apart of a ternary expression and the first leaf
4226         # contains the "else" keyword...
4227         if (
4228             parent_type(LL[0]) == syms.test
4229             and LL[0].type == token.NAME
4230             and LL[0].value == "else"
4231         ):
4232             is_valid_index = is_valid_index_factory(LL)
4233
4234             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4235             # The next visible leaf MUST contain a string...
4236             if is_valid_index(idx) and LL[idx].type == token.STRING:
4237                 return idx
4238
4239         return None
4240
4241     @staticmethod
4242     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4243         """
4244         Returns:
4245             string_idx such that @LL[string_idx] is equal to our target (i.e.
4246             matched) string, if this line matches the assert statement
4247             requirements listed in the 'Requirements' section of this classes'
4248             docstring.
4249                 OR
4250             None, otherwise.
4251         """
4252         # If this line is apart of an assert statement and the first leaf
4253         # contains the "assert" keyword...
4254         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4255             is_valid_index = is_valid_index_factory(LL)
4256
4257             for (i, leaf) in enumerate(LL):
4258                 # We MUST find a comma...
4259                 if leaf.type == token.COMMA:
4260                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4261
4262                     # That comma MUST be followed by a string...
4263                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4264                         string_idx = idx
4265
4266                         # Skip the string trailer, if one exists.
4267                         string_parser = StringParser()
4268                         idx = string_parser.parse(LL, string_idx)
4269
4270                         # But no more leaves are allowed...
4271                         if not is_valid_index(idx):
4272                             return string_idx
4273
4274         return None
4275
4276     @staticmethod
4277     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4278         """
4279         Returns:
4280             string_idx such that @LL[string_idx] is equal to our target (i.e.
4281             matched) string, if this line matches the assignment statement
4282             requirements listed in the 'Requirements' section of this classes'
4283             docstring.
4284                 OR
4285             None, otherwise.
4286         """
4287         # If this line is apart of an expression statement or is a function
4288         # argument AND the first leaf contains a variable name...
4289         if (
4290             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4291             and LL[0].type == token.NAME
4292         ):
4293             is_valid_index = is_valid_index_factory(LL)
4294
4295             for (i, leaf) in enumerate(LL):
4296                 # We MUST find either an '=' or '+=' symbol...
4297                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4298                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4299
4300                     # That symbol MUST be followed by a string...
4301                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4302                         string_idx = idx
4303
4304                         # Skip the string trailer, if one exists.
4305                         string_parser = StringParser()
4306                         idx = string_parser.parse(LL, string_idx)
4307
4308                         # The next leaf MAY be a comma iff this line is apart
4309                         # of a function argument...
4310                         if (
4311                             parent_type(LL[0]) == syms.argument
4312                             and is_valid_index(idx)
4313                             and LL[idx].type == token.COMMA
4314                         ):
4315                             idx += 1
4316
4317                         # But no more leaves are allowed...
4318                         if not is_valid_index(idx):
4319                             return string_idx
4320
4321         return None
4322
4323     @staticmethod
4324     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4325         """
4326         Returns:
4327             string_idx such that @LL[string_idx] is equal to our target (i.e.
4328             matched) string, if this line matches the dictionary key assignment
4329             statement requirements listed in the 'Requirements' section of this
4330             classes' docstring.
4331                 OR
4332             None, otherwise.
4333         """
4334         # If this line is apart of a dictionary key assignment...
4335         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4336             is_valid_index = is_valid_index_factory(LL)
4337
4338             for (i, leaf) in enumerate(LL):
4339                 # We MUST find a colon...
4340                 if leaf.type == token.COLON:
4341                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4342
4343                     # That colon MUST be followed by a string...
4344                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4345                         string_idx = idx
4346
4347                         # Skip the string trailer, if one exists.
4348                         string_parser = StringParser()
4349                         idx = string_parser.parse(LL, string_idx)
4350
4351                         # That string MAY be followed by a comma...
4352                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4353                             idx += 1
4354
4355                         # But no more leaves are allowed...
4356                         if not is_valid_index(idx):
4357                             return string_idx
4358
4359         return None
4360
4361     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4362         LL = line.leaves
4363
4364         is_valid_index = is_valid_index_factory(LL)
4365         insert_str_child = insert_str_child_factory(LL[string_idx])
4366
4367         comma_idx = -1
4368         ends_with_comma = False
4369         if LL[comma_idx].type == token.COMMA:
4370             ends_with_comma = True
4371
4372         leaves_to_steal_comments_from = [LL[string_idx]]
4373         if ends_with_comma:
4374             leaves_to_steal_comments_from.append(LL[comma_idx])
4375
4376         # --- First Line
4377         first_line = line.clone()
4378         left_leaves = LL[:string_idx]
4379
4380         # We have to remember to account for (possibly invisible) LPAR and RPAR
4381         # leaves that already wrapped the target string. If these leaves do
4382         # exist, we will replace them with our own LPAR and RPAR leaves.
4383         old_parens_exist = False
4384         if left_leaves and left_leaves[-1].type == token.LPAR:
4385             old_parens_exist = True
4386             leaves_to_steal_comments_from.append(left_leaves[-1])
4387             left_leaves.pop()
4388
4389         append_leaves(first_line, line, left_leaves)
4390
4391         lpar_leaf = Leaf(token.LPAR, "(")
4392         if old_parens_exist:
4393             replace_child(LL[string_idx - 1], lpar_leaf)
4394         else:
4395             insert_str_child(lpar_leaf)
4396         first_line.append(lpar_leaf)
4397
4398         # We throw inline comments that were originally to the right of the
4399         # target string to the top line. They will now be shown to the right of
4400         # the LPAR.
4401         for leaf in leaves_to_steal_comments_from:
4402             for comment_leaf in line.comments_after(leaf):
4403                 first_line.append(comment_leaf, preformatted=True)
4404
4405         yield Ok(first_line)
4406
4407         # --- Middle (String) Line
4408         # We only need to yield one (possibly too long) string line, since the
4409         # `StringSplitter` will break it down further if necessary.
4410         string_value = LL[string_idx].value
4411         string_line = Line(
4412             mode=line.mode,
4413             depth=line.depth + 1,
4414             inside_brackets=True,
4415             should_split_rhs=line.should_split_rhs,
4416             magic_trailing_comma=line.magic_trailing_comma,
4417         )
4418         string_leaf = Leaf(token.STRING, string_value)
4419         insert_str_child(string_leaf)
4420         string_line.append(string_leaf)
4421
4422         old_rpar_leaf = None
4423         if is_valid_index(string_idx + 1):
4424             right_leaves = LL[string_idx + 1 :]
4425             if ends_with_comma:
4426                 right_leaves.pop()
4427
4428             if old_parens_exist:
4429                 assert (
4430                     right_leaves and right_leaves[-1].type == token.RPAR
4431                 ), "Apparently, old parentheses do NOT exist?!"
4432                 old_rpar_leaf = right_leaves.pop()
4433
4434             append_leaves(string_line, line, right_leaves)
4435
4436         yield Ok(string_line)
4437
4438         # --- Last Line
4439         last_line = line.clone()
4440         last_line.bracket_tracker = first_line.bracket_tracker
4441
4442         new_rpar_leaf = Leaf(token.RPAR, ")")
4443         if old_rpar_leaf is not None:
4444             replace_child(old_rpar_leaf, new_rpar_leaf)
4445         else:
4446             insert_str_child(new_rpar_leaf)
4447         last_line.append(new_rpar_leaf)
4448
4449         # If the target string ended with a comma, we place this comma to the
4450         # right of the RPAR on the last line.
4451         if ends_with_comma:
4452             comma_leaf = Leaf(token.COMMA, ",")
4453             replace_child(LL[comma_idx], comma_leaf)
4454             last_line.append(comma_leaf)
4455
4456         yield Ok(last_line)
4457
4458
4459 class StringParser:
4460     """
4461     A state machine that aids in parsing a string's "trailer", which can be
4462     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4463     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4464     varY)`).
4465
4466     NOTE: A new StringParser object MUST be instantiated for each string
4467     trailer we need to parse.
4468
4469     Examples:
4470         We shall assume that `line` equals the `Line` object that corresponds
4471         to the following line of python code:
4472         ```
4473         x = "Some {}.".format("String") + some_other_string
4474         ```
4475
4476         Furthermore, we will assume that `string_idx` is some index such that:
4477         ```
4478         assert line.leaves[string_idx].value == "Some {}."
4479         ```
4480
4481         The following code snippet then holds:
4482         ```
4483         string_parser = StringParser()
4484         idx = string_parser.parse(line.leaves, string_idx)
4485         assert line.leaves[idx].type == token.PLUS
4486         ```
4487     """
4488
4489     DEFAULT_TOKEN = -1
4490
4491     # String Parser States
4492     START = 1
4493     DOT = 2
4494     NAME = 3
4495     PERCENT = 4
4496     SINGLE_FMT_ARG = 5
4497     LPAR = 6
4498     RPAR = 7
4499     DONE = 8
4500
4501     # Lookup Table for Next State
4502     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4503         # A string trailer may start with '.' OR '%'.
4504         (START, token.DOT): DOT,
4505         (START, token.PERCENT): PERCENT,
4506         (START, DEFAULT_TOKEN): DONE,
4507         # A '.' MUST be followed by an attribute or method name.
4508         (DOT, token.NAME): NAME,
4509         # A method name MUST be followed by an '(', whereas an attribute name
4510         # is the last symbol in the string trailer.
4511         (NAME, token.LPAR): LPAR,
4512         (NAME, DEFAULT_TOKEN): DONE,
4513         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4514         # string or variable name).
4515         (PERCENT, token.LPAR): LPAR,
4516         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4517         # If a '%' symbol is followed by a single argument, that argument is
4518         # the last leaf in the string trailer.
4519         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4520         # If present, a ')' symbol is the last symbol in a string trailer.
4521         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4522         # since they are treated as a special case by the parsing logic in this
4523         # classes' implementation.)
4524         (RPAR, DEFAULT_TOKEN): DONE,
4525     }
4526
4527     def __init__(self) -> None:
4528         self._state = self.START
4529         self._unmatched_lpars = 0
4530
4531     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4532         """
4533         Pre-conditions:
4534             * @leaves[@string_idx].type == token.STRING
4535
4536         Returns:
4537             The index directly after the last leaf which is apart of the string
4538             trailer, if a "trailer" exists.
4539                 OR
4540             @string_idx + 1, if no string "trailer" exists.
4541         """
4542         assert leaves[string_idx].type == token.STRING
4543
4544         idx = string_idx + 1
4545         while idx < len(leaves) and self._next_state(leaves[idx]):
4546             idx += 1
4547         return idx
4548
4549     def _next_state(self, leaf: Leaf) -> bool:
4550         """
4551         Pre-conditions:
4552             * On the first call to this function, @leaf MUST be the leaf that
4553             was directly after the string leaf in question (e.g. if our target
4554             string is `line.leaves[i]` then the first call to this method must
4555             be `line.leaves[i + 1]`).
4556             * On the next call to this function, the leaf parameter passed in
4557             MUST be the leaf directly following @leaf.
4558
4559         Returns:
4560             True iff @leaf is apart of the string's trailer.
4561         """
4562         # We ignore empty LPAR or RPAR leaves.
4563         if is_empty_par(leaf):
4564             return True
4565
4566         next_token = leaf.type
4567         if next_token == token.LPAR:
4568             self._unmatched_lpars += 1
4569
4570         current_state = self._state
4571
4572         # The LPAR parser state is a special case. We will return True until we
4573         # find the matching RPAR token.
4574         if current_state == self.LPAR:
4575             if next_token == token.RPAR:
4576                 self._unmatched_lpars -= 1
4577                 if self._unmatched_lpars == 0:
4578                     self._state = self.RPAR
4579         # Otherwise, we use a lookup table to determine the next state.
4580         else:
4581             # If the lookup table matches the current state to the next
4582             # token, we use the lookup table.
4583             if (current_state, next_token) in self._goto:
4584                 self._state = self._goto[current_state, next_token]
4585             else:
4586                 # Otherwise, we check if a the current state was assigned a
4587                 # default.
4588                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4589                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4590                 # If no default has been assigned, then this parser has a logic
4591                 # error.
4592                 else:
4593                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4594
4595             if self._state == self.DONE:
4596                 return False
4597
4598         return True
4599
4600
4601 def TErr(err_msg: str) -> Err[CannotTransform]:
4602     """(T)ransform Err
4603
4604     Convenience function used when working with the TResult type.
4605     """
4606     cant_transform = CannotTransform(err_msg)
4607     return Err(cant_transform)
4608
4609
4610 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4611     """
4612     Returns:
4613         True iff one of the comments in @comment_list is a pragma used by one
4614         of the more common static analysis tools for python (e.g. mypy, flake8,
4615         pylint).
4616     """
4617     for comment in comment_list:
4618         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4619             return True
4620
4621     return False
4622
4623
4624 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4625     """
4626     Factory for a convenience function that is used to orphan @string_leaf
4627     and then insert multiple new leaves into the same part of the node
4628     structure that @string_leaf had originally occupied.
4629
4630     Examples:
4631         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4632         string_leaf.parent`. Assume the node `N` has the following
4633         original structure:
4634
4635         Node(
4636             expr_stmt, [
4637                 Leaf(NAME, 'x'),
4638                 Leaf(EQUAL, '='),
4639                 Leaf(STRING, '"foo"'),
4640             ]
4641         )
4642
4643         We then run the code snippet shown below.
4644         ```
4645         insert_str_child = insert_str_child_factory(string_leaf)
4646
4647         lpar = Leaf(token.LPAR, '(')
4648         insert_str_child(lpar)
4649
4650         bar = Leaf(token.STRING, '"bar"')
4651         insert_str_child(bar)
4652
4653         rpar = Leaf(token.RPAR, ')')
4654         insert_str_child(rpar)
4655         ```
4656
4657         After which point, it follows that `string_leaf.parent is None` and
4658         the node `N` now has the following structure:
4659
4660         Node(
4661             expr_stmt, [
4662                 Leaf(NAME, 'x'),
4663                 Leaf(EQUAL, '='),
4664                 Leaf(LPAR, '('),
4665                 Leaf(STRING, '"bar"'),
4666                 Leaf(RPAR, ')'),
4667             ]
4668         )
4669     """
4670     string_parent = string_leaf.parent
4671     string_child_idx = string_leaf.remove()
4672
4673     def insert_str_child(child: LN) -> None:
4674         nonlocal string_child_idx
4675
4676         assert string_parent is not None
4677         assert string_child_idx is not None
4678
4679         string_parent.insert_child(string_child_idx, child)
4680         string_child_idx += 1
4681
4682     return insert_str_child
4683
4684
4685 def has_triple_quotes(string: str) -> bool:
4686     """
4687     Returns:
4688         True iff @string starts with three quotation characters.
4689     """
4690     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4691     return raw_string[:3] in {'"""', "'''"}
4692
4693
4694 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4695     """
4696     Returns:
4697         @node.parent.type, if @node is not None and has a parent.
4698             OR
4699         None, otherwise.
4700     """
4701     if node is None or node.parent is None:
4702         return None
4703
4704     return node.parent.type
4705
4706
4707 def is_empty_par(leaf: Leaf) -> bool:
4708     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4709
4710
4711 def is_empty_lpar(leaf: Leaf) -> bool:
4712     return leaf.type == token.LPAR and leaf.value == ""
4713
4714
4715 def is_empty_rpar(leaf: Leaf) -> bool:
4716     return leaf.type == token.RPAR and leaf.value == ""
4717
4718
4719 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4720     """
4721     Examples:
4722         ```
4723         my_list = [1, 2, 3]
4724
4725         is_valid_index = is_valid_index_factory(my_list)
4726
4727         assert is_valid_index(0)
4728         assert is_valid_index(2)
4729
4730         assert not is_valid_index(3)
4731         assert not is_valid_index(-1)
4732         ```
4733     """
4734
4735     def is_valid_index(idx: int) -> bool:
4736         """
4737         Returns:
4738             True iff @idx is positive AND seq[@idx] does NOT raise an
4739             IndexError.
4740         """
4741         return 0 <= idx < len(seq)
4742
4743     return is_valid_index
4744
4745
4746 def line_to_string(line: Line) -> str:
4747     """Returns the string representation of @line.
4748
4749     WARNING: This is known to be computationally expensive.
4750     """
4751     return str(line).strip("\n")
4752
4753
4754 def append_leaves(
4755     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4756 ) -> None:
4757     """
4758     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4759     underlying Node structure where appropriate.
4760
4761     All of the leaves in @leaves are duplicated. The duplicates are then
4762     appended to @new_line and used to replace their originals in the underlying
4763     Node structure. Any comments attached to the old leaves are reattached to
4764     the new leaves.
4765
4766     Pre-conditions:
4767         set(@leaves) is a subset of set(@old_line.leaves).
4768     """
4769     for old_leaf in leaves:
4770         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4771         replace_child(old_leaf, new_leaf)
4772         new_line.append(new_leaf, preformatted=preformatted)
4773
4774         for comment_leaf in old_line.comments_after(old_leaf):
4775             new_line.append(comment_leaf, preformatted=True)
4776
4777
4778 def replace_child(old_child: LN, new_child: LN) -> None:
4779     """
4780     Side Effects:
4781         * If @old_child.parent is set, replace @old_child with @new_child in
4782         @old_child's underlying Node structure.
4783             OR
4784         * Otherwise, this function does nothing.
4785     """
4786     parent = old_child.parent
4787     if not parent:
4788         return
4789
4790     child_idx = old_child.remove()
4791     if child_idx is not None:
4792         parent.insert_child(child_idx, new_child)
4793
4794
4795 def get_string_prefix(string: str) -> str:
4796     """
4797     Pre-conditions:
4798         * assert_is_leaf_string(@string)
4799
4800     Returns:
4801         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4802     """
4803     assert_is_leaf_string(string)
4804
4805     prefix = ""
4806     prefix_idx = 0
4807     while string[prefix_idx] in STRING_PREFIX_CHARS:
4808         prefix += string[prefix_idx].lower()
4809         prefix_idx += 1
4810
4811     return prefix
4812
4813
4814 def assert_is_leaf_string(string: str) -> None:
4815     """
4816     Checks the pre-condition that @string has the format that you would expect
4817     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4818     token.STRING`. A more precise description of the pre-conditions that are
4819     checked are listed below.
4820
4821     Pre-conditions:
4822         * @string starts with either ', ", <prefix>', or <prefix>" where
4823         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4824         * @string ends with a quote character (' or ").
4825
4826     Raises:
4827         AssertionError(...) if the pre-conditions listed above are not
4828         satisfied.
4829     """
4830     dquote_idx = string.find('"')
4831     squote_idx = string.find("'")
4832     if -1 in [dquote_idx, squote_idx]:
4833         quote_idx = max(dquote_idx, squote_idx)
4834     else:
4835         quote_idx = min(squote_idx, dquote_idx)
4836
4837     assert (
4838         0 <= quote_idx < len(string) - 1
4839     ), f"{string!r} is missing a starting quote character (' or \")."
4840     assert string[-1] in (
4841         "'",
4842         '"',
4843     ), f"{string!r} is missing an ending quote character (' or \")."
4844     assert set(string[:quote_idx]).issubset(
4845         set(STRING_PREFIX_CHARS)
4846     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4847
4848
4849 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4850     """Split line into many lines, starting with the first matching bracket pair.
4851
4852     Note: this usually looks weird, only use this for function definitions.
4853     Prefer RHS otherwise.  This is why this function is not symmetrical with
4854     :func:`right_hand_split` which also handles optional parentheses.
4855     """
4856     tail_leaves: List[Leaf] = []
4857     body_leaves: List[Leaf] = []
4858     head_leaves: List[Leaf] = []
4859     current_leaves = head_leaves
4860     matching_bracket: Optional[Leaf] = None
4861     for leaf in line.leaves:
4862         if (
4863             current_leaves is body_leaves
4864             and leaf.type in CLOSING_BRACKETS
4865             and leaf.opening_bracket is matching_bracket
4866         ):
4867             current_leaves = tail_leaves if body_leaves else head_leaves
4868         current_leaves.append(leaf)
4869         if current_leaves is head_leaves:
4870             if leaf.type in OPENING_BRACKETS:
4871                 matching_bracket = leaf
4872                 current_leaves = body_leaves
4873     if not matching_bracket:
4874         raise CannotSplit("No brackets found")
4875
4876     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4877     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4878     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4879     bracket_split_succeeded_or_raise(head, body, tail)
4880     for result in (head, body, tail):
4881         if result:
4882             yield result
4883
4884
4885 def right_hand_split(
4886     line: Line,
4887     line_length: int,
4888     features: Collection[Feature] = (),
4889     omit: Collection[LeafID] = (),
4890 ) -> Iterator[Line]:
4891     """Split line into many lines, starting with the last matching bracket pair.
4892
4893     If the split was by optional parentheses, attempt splitting without them, too.
4894     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4895     this split.
4896
4897     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4898     """
4899     tail_leaves: List[Leaf] = []
4900     body_leaves: List[Leaf] = []
4901     head_leaves: List[Leaf] = []
4902     current_leaves = tail_leaves
4903     opening_bracket: Optional[Leaf] = None
4904     closing_bracket: Optional[Leaf] = None
4905     for leaf in reversed(line.leaves):
4906         if current_leaves is body_leaves:
4907             if leaf is opening_bracket:
4908                 current_leaves = head_leaves if body_leaves else tail_leaves
4909         current_leaves.append(leaf)
4910         if current_leaves is tail_leaves:
4911             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4912                 opening_bracket = leaf.opening_bracket
4913                 closing_bracket = leaf
4914                 current_leaves = body_leaves
4915     if not (opening_bracket and closing_bracket and head_leaves):
4916         # If there is no opening or closing_bracket that means the split failed and
4917         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4918         # the matching `opening_bracket` wasn't available on `line` anymore.
4919         raise CannotSplit("No brackets found")
4920
4921     tail_leaves.reverse()
4922     body_leaves.reverse()
4923     head_leaves.reverse()
4924     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4925     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4926     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4927     bracket_split_succeeded_or_raise(head, body, tail)
4928     if (
4929         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4930         # the opening bracket is an optional paren
4931         and opening_bracket.type == token.LPAR
4932         and not opening_bracket.value
4933         # the closing bracket is an optional paren
4934         and closing_bracket.type == token.RPAR
4935         and not closing_bracket.value
4936         # it's not an import (optional parens are the only thing we can split on
4937         # in this case; attempting a split without them is a waste of time)
4938         and not line.is_import
4939         # there are no standalone comments in the body
4940         and not body.contains_standalone_comments(0)
4941         # and we can actually remove the parens
4942         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4943     ):
4944         omit = {id(closing_bracket), *omit}
4945         try:
4946             yield from right_hand_split(line, line_length, features=features, omit=omit)
4947             return
4948
4949         except CannotSplit:
4950             if not (
4951                 can_be_split(body)
4952                 or is_line_short_enough(body, line_length=line_length)
4953             ):
4954                 raise CannotSplit(
4955                     "Splitting failed, body is still too long and can't be split."
4956                 )
4957
4958             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4959                 raise CannotSplit(
4960                     "The current optional pair of parentheses is bound to fail to"
4961                     " satisfy the splitting algorithm because the head or the tail"
4962                     " contains multiline strings which by definition never fit one"
4963                     " line."
4964                 )
4965
4966     ensure_visible(opening_bracket)
4967     ensure_visible(closing_bracket)
4968     for result in (head, body, tail):
4969         if result:
4970             yield result
4971
4972
4973 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4974     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4975
4976     Do nothing otherwise.
4977
4978     A left- or right-hand split is based on a pair of brackets. Content before
4979     (and including) the opening bracket is left on one line, content inside the
4980     brackets is put on a separate line, and finally content starting with and
4981     following the closing bracket is put on a separate line.
4982
4983     Those are called `head`, `body`, and `tail`, respectively. If the split
4984     produced the same line (all content in `head`) or ended up with an empty `body`
4985     and the `tail` is just the closing bracket, then it's considered failed.
4986     """
4987     tail_len = len(str(tail).strip())
4988     if not body:
4989         if tail_len == 0:
4990             raise CannotSplit("Splitting brackets produced the same line")
4991
4992         elif tail_len < 3:
4993             raise CannotSplit(
4994                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4995                 " not worth it"
4996             )
4997
4998
4999 def bracket_split_build_line(
5000     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
5001 ) -> Line:
5002     """Return a new line with given `leaves` and respective comments from `original`.
5003
5004     If `is_body` is True, the result line is one-indented inside brackets and as such
5005     has its first leaf's prefix normalized and a trailing comma added when expected.
5006     """
5007     result = Line(mode=original.mode, depth=original.depth)
5008     if is_body:
5009         result.inside_brackets = True
5010         result.depth += 1
5011         if leaves:
5012             # Since body is a new indent level, remove spurious leading whitespace.
5013             normalize_prefix(leaves[0], inside_brackets=True)
5014             # Ensure a trailing comma for imports and standalone function arguments, but
5015             # be careful not to add one after any comments or within type annotations.
5016             no_commas = (
5017                 original.is_def
5018                 and opening_bracket.value == "("
5019                 and not any(leaf.type == token.COMMA for leaf in leaves)
5020             )
5021
5022             if original.is_import or no_commas:
5023                 for i in range(len(leaves) - 1, -1, -1):
5024                     if leaves[i].type == STANDALONE_COMMENT:
5025                         continue
5026
5027                     if leaves[i].type != token.COMMA:
5028                         new_comma = Leaf(token.COMMA, ",")
5029                         leaves.insert(i + 1, new_comma)
5030                     break
5031
5032     # Populate the line
5033     for leaf in leaves:
5034         result.append(leaf, preformatted=True)
5035         for comment_after in original.comments_after(leaf):
5036             result.append(comment_after, preformatted=True)
5037     if is_body and should_split_line(result, opening_bracket):
5038         result.should_split_rhs = True
5039     return result
5040
5041
5042 def dont_increase_indentation(split_func: Transformer) -> Transformer:
5043     """Normalize prefix of the first leaf in every line returned by `split_func`.
5044
5045     This is a decorator over relevant split functions.
5046     """
5047
5048     @wraps(split_func)
5049     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5050         for line in split_func(line, features):
5051             normalize_prefix(line.leaves[0], inside_brackets=True)
5052             yield line
5053
5054     return split_wrapper
5055
5056
5057 @dont_increase_indentation
5058 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5059     """Split according to delimiters of the highest priority.
5060
5061     If the appropriate Features are given, the split will add trailing commas
5062     also in function signatures and calls that contain `*` and `**`.
5063     """
5064     try:
5065         last_leaf = line.leaves[-1]
5066     except IndexError:
5067         raise CannotSplit("Line empty")
5068
5069     bt = line.bracket_tracker
5070     try:
5071         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
5072     except ValueError:
5073         raise CannotSplit("No delimiters found")
5074
5075     if delimiter_priority == DOT_PRIORITY:
5076         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
5077             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
5078
5079     current_line = Line(
5080         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5081     )
5082     lowest_depth = sys.maxsize
5083     trailing_comma_safe = True
5084
5085     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5086         """Append `leaf` to current line or to new line if appending impossible."""
5087         nonlocal current_line
5088         try:
5089             current_line.append_safe(leaf, preformatted=True)
5090         except ValueError:
5091             yield current_line
5092
5093             current_line = Line(
5094                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5095             )
5096             current_line.append(leaf)
5097
5098     for leaf in line.leaves:
5099         yield from append_to_line(leaf)
5100
5101         for comment_after in line.comments_after(leaf):
5102             yield from append_to_line(comment_after)
5103
5104         lowest_depth = min(lowest_depth, leaf.bracket_depth)
5105         if leaf.bracket_depth == lowest_depth:
5106             if is_vararg(leaf, within={syms.typedargslist}):
5107                 trailing_comma_safe = (
5108                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
5109                 )
5110             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
5111                 trailing_comma_safe = (
5112                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
5113                 )
5114
5115         leaf_priority = bt.delimiters.get(id(leaf))
5116         if leaf_priority == delimiter_priority:
5117             yield current_line
5118
5119             current_line = Line(
5120                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5121             )
5122     if current_line:
5123         if (
5124             trailing_comma_safe
5125             and delimiter_priority == COMMA_PRIORITY
5126             and current_line.leaves[-1].type != token.COMMA
5127             and current_line.leaves[-1].type != STANDALONE_COMMENT
5128         ):
5129             new_comma = Leaf(token.COMMA, ",")
5130             current_line.append(new_comma)
5131         yield current_line
5132
5133
5134 @dont_increase_indentation
5135 def standalone_comment_split(
5136     line: Line, features: Collection[Feature] = ()
5137 ) -> Iterator[Line]:
5138     """Split standalone comments from the rest of the line."""
5139     if not line.contains_standalone_comments(0):
5140         raise CannotSplit("Line does not have any standalone comments")
5141
5142     current_line = Line(
5143         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5144     )
5145
5146     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5147         """Append `leaf` to current line or to new line if appending impossible."""
5148         nonlocal current_line
5149         try:
5150             current_line.append_safe(leaf, preformatted=True)
5151         except ValueError:
5152             yield current_line
5153
5154             current_line = Line(
5155                 line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5156             )
5157             current_line.append(leaf)
5158
5159     for leaf in line.leaves:
5160         yield from append_to_line(leaf)
5161
5162         for comment_after in line.comments_after(leaf):
5163             yield from append_to_line(comment_after)
5164
5165     if current_line:
5166         yield current_line
5167
5168
5169 def is_import(leaf: Leaf) -> bool:
5170     """Return True if the given leaf starts an import statement."""
5171     p = leaf.parent
5172     t = leaf.type
5173     v = leaf.value
5174     return bool(
5175         t == token.NAME
5176         and (
5177             (v == "import" and p and p.type == syms.import_name)
5178             or (v == "from" and p and p.type == syms.import_from)
5179         )
5180     )
5181
5182
5183 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5184     """Return True if the given leaf is a special comment.
5185     Only returns true for type comments for now."""
5186     t = leaf.type
5187     v = leaf.value
5188     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5189
5190
5191 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5192     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5193     else.
5194
5195     Note: don't use backslashes for formatting or you'll lose your voting rights.
5196     """
5197     if not inside_brackets:
5198         spl = leaf.prefix.split("#")
5199         if "\\" not in spl[0]:
5200             nl_count = spl[-1].count("\n")
5201             if len(spl) > 1:
5202                 nl_count -= 1
5203             leaf.prefix = "\n" * nl_count
5204             return
5205
5206     leaf.prefix = ""
5207
5208
5209 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5210     """Make all string prefixes lowercase.
5211
5212     If remove_u_prefix is given, also removes any u prefix from the string.
5213
5214     Note: Mutates its argument.
5215     """
5216     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5217     assert match is not None, f"failed to match string {leaf.value!r}"
5218     orig_prefix = match.group(1)
5219     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5220     if remove_u_prefix:
5221         new_prefix = new_prefix.replace("u", "")
5222     leaf.value = f"{new_prefix}{match.group(2)}"
5223
5224
5225 def normalize_string_quotes(leaf: Leaf) -> None:
5226     """Prefer double quotes but only if it doesn't cause more escaping.
5227
5228     Adds or removes backslashes as appropriate. Doesn't parse and fix
5229     strings nested in f-strings (yet).
5230
5231     Note: Mutates its argument.
5232     """
5233     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5234     if value[:3] == '"""':
5235         return
5236
5237     elif value[:3] == "'''":
5238         orig_quote = "'''"
5239         new_quote = '"""'
5240     elif value[0] == '"':
5241         orig_quote = '"'
5242         new_quote = "'"
5243     else:
5244         orig_quote = "'"
5245         new_quote = '"'
5246     first_quote_pos = leaf.value.find(orig_quote)
5247     if first_quote_pos == -1:
5248         return  # There's an internal error
5249
5250     prefix = leaf.value[:first_quote_pos]
5251     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5252     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5253     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5254     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5255     if "r" in prefix.casefold():
5256         if unescaped_new_quote.search(body):
5257             # There's at least one unescaped new_quote in this raw string
5258             # so converting is impossible
5259             return
5260
5261         # Do not introduce or remove backslashes in raw strings
5262         new_body = body
5263     else:
5264         # remove unnecessary escapes
5265         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5266         if body != new_body:
5267             # Consider the string without unnecessary escapes as the original
5268             body = new_body
5269             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5270         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5271         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5272     if "f" in prefix.casefold():
5273         matches = re.findall(
5274             r"""
5275             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5276                 ([^{].*?)  # contents of the brackets except if begins with {{
5277             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5278             """,
5279             new_body,
5280             re.VERBOSE,
5281         )
5282         for m in matches:
5283             if "\\" in str(m):
5284                 # Do not introduce backslashes in interpolated expressions
5285                 return
5286
5287     if new_quote == '"""' and new_body[-1:] == '"':
5288         # edge case:
5289         new_body = new_body[:-1] + '\\"'
5290     orig_escape_count = body.count("\\")
5291     new_escape_count = new_body.count("\\")
5292     if new_escape_count > orig_escape_count:
5293         return  # Do not introduce more escaping
5294
5295     if new_escape_count == orig_escape_count and orig_quote == '"':
5296         return  # Prefer double quotes
5297
5298     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5299
5300
5301 def normalize_numeric_literal(leaf: Leaf) -> None:
5302     """Normalizes numeric (float, int, and complex) literals.
5303
5304     All letters used in the representation are normalized to lowercase (except
5305     in Python 2 long literals).
5306     """
5307     text = leaf.value.lower()
5308     if text.startswith(("0o", "0b")):
5309         # Leave octal and binary literals alone.
5310         pass
5311     elif text.startswith("0x"):
5312         text = format_hex(text)
5313     elif "e" in text:
5314         text = format_scientific_notation(text)
5315     elif text.endswith(("j", "l")):
5316         text = format_long_or_complex_number(text)
5317     else:
5318         text = format_float_or_int_string(text)
5319     leaf.value = text
5320
5321
5322 def format_hex(text: str) -> str:
5323     """
5324     Formats a hexadecimal string like "0x12b3"
5325
5326     Uses lowercase because of similarity between "B" and "8", which
5327     can cause security issues.
5328     see: https://github.com/psf/black/issues/1692
5329     """
5330
5331     before, after = text[:2], text[2:]
5332     return f"{before}{after.lower()}"
5333
5334
5335 def format_scientific_notation(text: str) -> str:
5336     """Formats a numeric string utilizing scentific notation"""
5337     before, after = text.split("e")
5338     sign = ""
5339     if after.startswith("-"):
5340         after = after[1:]
5341         sign = "-"
5342     elif after.startswith("+"):
5343         after = after[1:]
5344     before = format_float_or_int_string(before)
5345     return f"{before}e{sign}{after}"
5346
5347
5348 def format_long_or_complex_number(text: str) -> str:
5349     """Formats a long or complex string like `10L` or `10j`"""
5350     number = text[:-1]
5351     suffix = text[-1]
5352     # Capitalize in "2L" because "l" looks too similar to "1".
5353     if suffix == "l":
5354         suffix = "L"
5355     return f"{format_float_or_int_string(number)}{suffix}"
5356
5357
5358 def format_float_or_int_string(text: str) -> str:
5359     """Formats a float string like "1.0"."""
5360     if "." not in text:
5361         return text
5362
5363     before, after = text.split(".")
5364     return f"{before or 0}.{after or 0}"
5365
5366
5367 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5368     """Make existing optional parentheses invisible or create new ones.
5369
5370     `parens_after` is a set of string leaf values immediately after which parens
5371     should be put.
5372
5373     Standardizes on visible parentheses for single-element tuples, and keeps
5374     existing visible parentheses for other tuples and generator expressions.
5375     """
5376     for pc in list_comments(node.prefix, is_endmarker=False):
5377         if pc.value in FMT_OFF:
5378             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5379             return
5380     check_lpar = False
5381     for index, child in enumerate(list(node.children)):
5382         # Fixes a bug where invisible parens are not properly stripped from
5383         # assignment statements that contain type annotations.
5384         if isinstance(child, Node) and child.type == syms.annassign:
5385             normalize_invisible_parens(child, parens_after=parens_after)
5386
5387         # Add parentheses around long tuple unpacking in assignments.
5388         if (
5389             index == 0
5390             and isinstance(child, Node)
5391             and child.type == syms.testlist_star_expr
5392         ):
5393             check_lpar = True
5394
5395         if check_lpar:
5396             if child.type == syms.atom:
5397                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5398                     wrap_in_parentheses(node, child, visible=False)
5399             elif is_one_tuple(child):
5400                 wrap_in_parentheses(node, child, visible=True)
5401             elif node.type == syms.import_from:
5402                 # "import from" nodes store parentheses directly as part of
5403                 # the statement
5404                 if child.type == token.LPAR:
5405                     # make parentheses invisible
5406                     child.value = ""  # type: ignore
5407                     node.children[-1].value = ""  # type: ignore
5408                 elif child.type != token.STAR:
5409                     # insert invisible parentheses
5410                     node.insert_child(index, Leaf(token.LPAR, ""))
5411                     node.append_child(Leaf(token.RPAR, ""))
5412                 break
5413
5414             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5415                 wrap_in_parentheses(node, child, visible=False)
5416
5417         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5418
5419
5420 def normalize_fmt_off(node: Node) -> None:
5421     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5422     try_again = True
5423     while try_again:
5424         try_again = convert_one_fmt_off_pair(node)
5425
5426
5427 def convert_one_fmt_off_pair(node: Node) -> bool:
5428     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5429
5430     Returns True if a pair was converted.
5431     """
5432     for leaf in node.leaves():
5433         previous_consumed = 0
5434         for comment in list_comments(leaf.prefix, is_endmarker=False):
5435             if comment.value not in FMT_PASS:
5436                 previous_consumed = comment.consumed
5437                 continue
5438             # We only want standalone comments. If there's no previous leaf or
5439             # the previous leaf is indentation, it's a standalone comment in
5440             # disguise.
5441             if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
5442                 prev = preceding_leaf(leaf)
5443                 if prev:
5444                     if comment.value in FMT_OFF and prev.type not in WHITESPACE:
5445                         continue
5446                     if comment.value in FMT_SKIP and prev.type in WHITESPACE:
5447                         continue
5448
5449             ignored_nodes = list(generate_ignored_nodes(leaf, comment))
5450             if not ignored_nodes:
5451                 continue
5452
5453             first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5454             parent = first.parent
5455             prefix = first.prefix
5456             first.prefix = prefix[comment.consumed :]
5457             hidden_value = "".join(str(n) for n in ignored_nodes)
5458             if comment.value in FMT_OFF:
5459                 hidden_value = comment.value + "\n" + hidden_value
5460             if comment.value in FMT_SKIP:
5461                 hidden_value += "  " + comment.value
5462             if hidden_value.endswith("\n"):
5463                 # That happens when one of the `ignored_nodes` ended with a NEWLINE
5464                 # leaf (possibly followed by a DEDENT).
5465                 hidden_value = hidden_value[:-1]
5466             first_idx: Optional[int] = None
5467             for ignored in ignored_nodes:
5468                 index = ignored.remove()
5469                 if first_idx is None:
5470                     first_idx = index
5471             assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5472             assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5473             parent.insert_child(
5474                 first_idx,
5475                 Leaf(
5476                     STANDALONE_COMMENT,
5477                     hidden_value,
5478                     prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5479                 ),
5480             )
5481             return True
5482
5483     return False
5484
5485
5486 def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
5487     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5488
5489     If comment is skip, returns leaf only.
5490     Stops at the end of the block.
5491     """
5492     container: Optional[LN] = container_of(leaf)
5493     if comment.value in FMT_SKIP:
5494         prev_sibling = leaf.prev_sibling
5495         if comment.value in leaf.prefix and prev_sibling is not None:
5496             leaf.prefix = leaf.prefix.replace(comment.value, "")
5497             siblings = [prev_sibling]
5498             while (
5499                 "\n" not in prev_sibling.prefix
5500                 and prev_sibling.prev_sibling is not None
5501             ):
5502                 prev_sibling = prev_sibling.prev_sibling
5503                 siblings.insert(0, prev_sibling)
5504             for sibling in siblings:
5505                 yield sibling
5506         elif leaf.parent is not None:
5507             yield leaf.parent
5508         return
5509     while container is not None and container.type != token.ENDMARKER:
5510         if is_fmt_on(container):
5511             return
5512
5513         # fix for fmt: on in children
5514         if contains_fmt_on_at_column(container, leaf.column):
5515             for child in container.children:
5516                 if contains_fmt_on_at_column(child, leaf.column):
5517                     return
5518                 yield child
5519         else:
5520             yield container
5521             container = container.next_sibling
5522
5523
5524 def is_fmt_on(container: LN) -> bool:
5525     """Determine whether formatting is switched on within a container.
5526     Determined by whether the last `# fmt:` comment is `on` or `off`.
5527     """
5528     fmt_on = False
5529     for comment in list_comments(container.prefix, is_endmarker=False):
5530         if comment.value in FMT_ON:
5531             fmt_on = True
5532         elif comment.value in FMT_OFF:
5533             fmt_on = False
5534     return fmt_on
5535
5536
5537 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5538     """Determine if children at a given column have formatting switched on."""
5539     for child in container.children:
5540         if (
5541             isinstance(child, Node)
5542             and first_leaf_column(child) == column
5543             or isinstance(child, Leaf)
5544             and child.column == column
5545         ):
5546             if is_fmt_on(child):
5547                 return True
5548
5549     return False
5550
5551
5552 def first_leaf_column(node: Node) -> Optional[int]:
5553     """Returns the column of the first leaf child of a node."""
5554     for child in node.children:
5555         if isinstance(child, Leaf):
5556             return child.column
5557     return None
5558
5559
5560 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5561     """If it's safe, make the parens in the atom `node` invisible, recursively.
5562     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5563     as they are redundant.
5564
5565     Returns whether the node should itself be wrapped in invisible parentheses.
5566
5567     """
5568
5569     if (
5570         node.type != syms.atom
5571         or is_empty_tuple(node)
5572         or is_one_tuple(node)
5573         or (is_yield(node) and parent.type != syms.expr_stmt)
5574         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5575     ):
5576         return False
5577
5578     if is_walrus_assignment(node):
5579         if parent.type in [syms.annassign, syms.expr_stmt]:
5580             return False
5581
5582     first = node.children[0]
5583     last = node.children[-1]
5584     if first.type == token.LPAR and last.type == token.RPAR:
5585         middle = node.children[1]
5586         # make parentheses invisible
5587         first.value = ""  # type: ignore
5588         last.value = ""  # type: ignore
5589         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5590
5591         if is_atom_with_invisible_parens(middle):
5592             # Strip the invisible parens from `middle` by replacing
5593             # it with the child in-between the invisible parens
5594             middle.replace(middle.children[1])
5595
5596         return False
5597
5598     return True
5599
5600
5601 def is_atom_with_invisible_parens(node: LN) -> bool:
5602     """Given a `LN`, determines whether it's an atom `node` with invisible
5603     parens. Useful in dedupe-ing and normalizing parens.
5604     """
5605     if isinstance(node, Leaf) or node.type != syms.atom:
5606         return False
5607
5608     first, last = node.children[0], node.children[-1]
5609     return (
5610         isinstance(first, Leaf)
5611         and first.type == token.LPAR
5612         and first.value == ""
5613         and isinstance(last, Leaf)
5614         and last.type == token.RPAR
5615         and last.value == ""
5616     )
5617
5618
5619 def is_empty_tuple(node: LN) -> bool:
5620     """Return True if `node` holds an empty tuple."""
5621     return (
5622         node.type == syms.atom
5623         and len(node.children) == 2
5624         and node.children[0].type == token.LPAR
5625         and node.children[1].type == token.RPAR
5626     )
5627
5628
5629 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5630     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5631
5632     Parenthesis can be optional. Returns None otherwise"""
5633     if len(node.children) != 3:
5634         return None
5635
5636     lpar, wrapped, rpar = node.children
5637     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5638         return None
5639
5640     return wrapped
5641
5642
5643 def first_child_is_arith(node: Node) -> bool:
5644     """Whether first child is an arithmetic or a binary arithmetic expression"""
5645     expr_types = {
5646         syms.arith_expr,
5647         syms.shift_expr,
5648         syms.xor_expr,
5649         syms.and_expr,
5650     }
5651     return bool(node.children and node.children[0].type in expr_types)
5652
5653
5654 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5655     """Wrap `child` in parentheses.
5656
5657     This replaces `child` with an atom holding the parentheses and the old
5658     child.  That requires moving the prefix.
5659
5660     If `visible` is False, the leaves will be valueless (and thus invisible).
5661     """
5662     lpar = Leaf(token.LPAR, "(" if visible else "")
5663     rpar = Leaf(token.RPAR, ")" if visible else "")
5664     prefix = child.prefix
5665     child.prefix = ""
5666     index = child.remove() or 0
5667     new_child = Node(syms.atom, [lpar, child, rpar])
5668     new_child.prefix = prefix
5669     parent.insert_child(index, new_child)
5670
5671
5672 def is_one_tuple(node: LN) -> bool:
5673     """Return True if `node` holds a tuple with one element, with or without parens."""
5674     if node.type == syms.atom:
5675         gexp = unwrap_singleton_parenthesis(node)
5676         if gexp is None or gexp.type != syms.testlist_gexp:
5677             return False
5678
5679         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5680
5681     return (
5682         node.type in IMPLICIT_TUPLE
5683         and len(node.children) == 2
5684         and node.children[1].type == token.COMMA
5685     )
5686
5687
5688 def is_walrus_assignment(node: LN) -> bool:
5689     """Return True iff `node` is of the shape ( test := test )"""
5690     inner = unwrap_singleton_parenthesis(node)
5691     return inner is not None and inner.type == syms.namedexpr_test
5692
5693
5694 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5695     """Return True iff `node` is a trailer valid in a simple decorator"""
5696     return node.type == syms.trailer and (
5697         (
5698             len(node.children) == 2
5699             and node.children[0].type == token.DOT
5700             and node.children[1].type == token.NAME
5701         )
5702         # last trailer can be arguments
5703         or (
5704             last
5705             and len(node.children) == 3
5706             and node.children[0].type == token.LPAR
5707             # and node.children[1].type == syms.argument
5708             and node.children[2].type == token.RPAR
5709         )
5710     )
5711
5712
5713 def is_simple_decorator_expression(node: LN) -> bool:
5714     """Return True iff `node` could be a 'dotted name' decorator
5715
5716     This function takes the node of the 'namedexpr_test' of the new decorator
5717     grammar and test if it would be valid under the old decorator grammar.
5718
5719     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5720     The new grammar is : decorator: @ namedexpr_test NEWLINE
5721     """
5722     if node.type == token.NAME:
5723         return True
5724     if node.type == syms.power:
5725         if node.children:
5726             return (
5727                 node.children[0].type == token.NAME
5728                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5729                 and (
5730                     len(node.children) < 2
5731                     or is_simple_decorator_trailer(node.children[-1], last=True)
5732                 )
5733             )
5734     return False
5735
5736
5737 def is_yield(node: LN) -> bool:
5738     """Return True if `node` holds a `yield` or `yield from` expression."""
5739     if node.type == syms.yield_expr:
5740         return True
5741
5742     if node.type == token.NAME and node.value == "yield":  # type: ignore
5743         return True
5744
5745     if node.type != syms.atom:
5746         return False
5747
5748     if len(node.children) != 3:
5749         return False
5750
5751     lpar, expr, rpar = node.children
5752     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5753         return is_yield(expr)
5754
5755     return False
5756
5757
5758 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5759     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5760
5761     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5762     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5763     extended iterable unpacking (PEP 3132) and additional unpacking
5764     generalizations (PEP 448).
5765     """
5766     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5767         return False
5768
5769     p = leaf.parent
5770     if p.type == syms.star_expr:
5771         # Star expressions are also used as assignment targets in extended
5772         # iterable unpacking (PEP 3132).  See what its parent is instead.
5773         if not p.parent:
5774             return False
5775
5776         p = p.parent
5777
5778     return p.type in within
5779
5780
5781 def is_multiline_string(leaf: Leaf) -> bool:
5782     """Return True if `leaf` is a multiline string that actually spans many lines."""
5783     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5784
5785
5786 def is_stub_suite(node: Node) -> bool:
5787     """Return True if `node` is a suite with a stub body."""
5788     if (
5789         len(node.children) != 4
5790         or node.children[0].type != token.NEWLINE
5791         or node.children[1].type != token.INDENT
5792         or node.children[3].type != token.DEDENT
5793     ):
5794         return False
5795
5796     return is_stub_body(node.children[2])
5797
5798
5799 def is_stub_body(node: LN) -> bool:
5800     """Return True if `node` is a simple statement containing an ellipsis."""
5801     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5802         return False
5803
5804     if len(node.children) != 2:
5805         return False
5806
5807     child = node.children[0]
5808     return (
5809         child.type == syms.atom
5810         and len(child.children) == 3
5811         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5812     )
5813
5814
5815 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5816     """Return maximum delimiter priority inside `node`.
5817
5818     This is specific to atoms with contents contained in a pair of parentheses.
5819     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5820     """
5821     if node.type != syms.atom:
5822         return 0
5823
5824     first = node.children[0]
5825     last = node.children[-1]
5826     if not (first.type == token.LPAR and last.type == token.RPAR):
5827         return 0
5828
5829     bt = BracketTracker()
5830     for c in node.children[1:-1]:
5831         if isinstance(c, Leaf):
5832             bt.mark(c)
5833         else:
5834             for leaf in c.leaves():
5835                 bt.mark(leaf)
5836     try:
5837         return bt.max_delimiter_priority()
5838
5839     except ValueError:
5840         return 0
5841
5842
5843 def ensure_visible(leaf: Leaf) -> None:
5844     """Make sure parentheses are visible.
5845
5846     They could be invisible as part of some statements (see
5847     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5848     """
5849     if leaf.type == token.LPAR:
5850         leaf.value = "("
5851     elif leaf.type == token.RPAR:
5852         leaf.value = ")"
5853
5854
5855 def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
5856     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5857
5858     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5859         return False
5860
5861     # We're essentially checking if the body is delimited by commas and there's more
5862     # than one of them (we're excluding the trailing comma and if the delimiter priority
5863     # is still commas, that means there's more).
5864     exclude = set()
5865     trailing_comma = False
5866     try:
5867         last_leaf = line.leaves[-1]
5868         if last_leaf.type == token.COMMA:
5869             trailing_comma = True
5870             exclude.add(id(last_leaf))
5871         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5872     except (IndexError, ValueError):
5873         return False
5874
5875     return max_priority == COMMA_PRIORITY and (
5876         (line.mode.magic_trailing_comma and trailing_comma)
5877         # always explode imports
5878         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5879     )
5880
5881
5882 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5883     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5884     if opening.type != token.LPAR and closing.type != token.RPAR:
5885         return False
5886
5887     depth = closing.bracket_depth + 1
5888     for _opening_index, leaf in enumerate(leaves):
5889         if leaf is opening:
5890             break
5891
5892     else:
5893         raise LookupError("Opening paren not found in `leaves`")
5894
5895     commas = 0
5896     _opening_index += 1
5897     for leaf in leaves[_opening_index:]:
5898         if leaf is closing:
5899             break
5900
5901         bracket_depth = leaf.bracket_depth
5902         if bracket_depth == depth and leaf.type == token.COMMA:
5903             commas += 1
5904             if leaf.parent and leaf.parent.type in {
5905                 syms.arglist,
5906                 syms.typedargslist,
5907             }:
5908                 commas += 1
5909                 break
5910
5911     return commas < 2
5912
5913
5914 def get_features_used(node: Node) -> Set[Feature]:
5915     """Return a set of (relatively) new Python features used in this file.
5916
5917     Currently looking for:
5918     - f-strings;
5919     - underscores in numeric literals;
5920     - trailing commas after * or ** in function signatures and calls;
5921     - positional only arguments in function signatures and lambdas;
5922     - assignment expression;
5923     - relaxed decorator syntax;
5924     """
5925     features: Set[Feature] = set()
5926     for n in node.pre_order():
5927         if n.type == token.STRING:
5928             value_head = n.value[:2]  # type: ignore
5929             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5930                 features.add(Feature.F_STRINGS)
5931
5932         elif n.type == token.NUMBER:
5933             if "_" in n.value:  # type: ignore
5934                 features.add(Feature.NUMERIC_UNDERSCORES)
5935
5936         elif n.type == token.SLASH:
5937             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5938                 features.add(Feature.POS_ONLY_ARGUMENTS)
5939
5940         elif n.type == token.COLONEQUAL:
5941             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5942
5943         elif n.type == syms.decorator:
5944             if len(n.children) > 1 and not is_simple_decorator_expression(
5945                 n.children[1]
5946             ):
5947                 features.add(Feature.RELAXED_DECORATORS)
5948
5949         elif (
5950             n.type in {syms.typedargslist, syms.arglist}
5951             and n.children
5952             and n.children[-1].type == token.COMMA
5953         ):
5954             if n.type == syms.typedargslist:
5955                 feature = Feature.TRAILING_COMMA_IN_DEF
5956             else:
5957                 feature = Feature.TRAILING_COMMA_IN_CALL
5958
5959             for ch in n.children:
5960                 if ch.type in STARS:
5961                     features.add(feature)
5962
5963                 if ch.type == syms.argument:
5964                     for argch in ch.children:
5965                         if argch.type in STARS:
5966                             features.add(feature)
5967
5968     return features
5969
5970
5971 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5972     """Detect the version to target based on the nodes used."""
5973     features = get_features_used(node)
5974     return {
5975         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5976     }
5977
5978
5979 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5980     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5981
5982     Brackets can be omitted if the entire trailer up to and including
5983     a preceding closing bracket fits in one line.
5984
5985     Yielded sets are cumulative (contain results of previous yields, too).  First
5986     set is empty, unless the line should explode, in which case bracket pairs until
5987     the one that needs to explode are omitted.
5988     """
5989
5990     omit: Set[LeafID] = set()
5991     if not line.magic_trailing_comma:
5992         yield omit
5993
5994     length = 4 * line.depth
5995     opening_bracket: Optional[Leaf] = None
5996     closing_bracket: Optional[Leaf] = None
5997     inner_brackets: Set[LeafID] = set()
5998     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5999         length += leaf_length
6000         if length > line_length:
6001             break
6002
6003         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
6004         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
6005             break
6006
6007         if opening_bracket:
6008             if leaf is opening_bracket:
6009                 opening_bracket = None
6010             elif leaf.type in CLOSING_BRACKETS:
6011                 prev = line.leaves[index - 1] if index > 0 else None
6012                 if (
6013                     prev
6014                     and prev.type == token.COMMA
6015                     and not is_one_tuple_between(
6016                         leaf.opening_bracket, leaf, line.leaves
6017                     )
6018                 ):
6019                     # Never omit bracket pairs with trailing commas.
6020                     # We need to explode on those.
6021                     break
6022
6023                 inner_brackets.add(id(leaf))
6024         elif leaf.type in CLOSING_BRACKETS:
6025             prev = line.leaves[index - 1] if index > 0 else None
6026             if prev and prev.type in OPENING_BRACKETS:
6027                 # Empty brackets would fail a split so treat them as "inner"
6028                 # brackets (e.g. only add them to the `omit` set if another
6029                 # pair of brackets was good enough.
6030                 inner_brackets.add(id(leaf))
6031                 continue
6032
6033             if closing_bracket:
6034                 omit.add(id(closing_bracket))
6035                 omit.update(inner_brackets)
6036                 inner_brackets.clear()
6037                 yield omit
6038
6039             if (
6040                 prev
6041                 and prev.type == token.COMMA
6042                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
6043             ):
6044                 # Never omit bracket pairs with trailing commas.
6045                 # We need to explode on those.
6046                 break
6047
6048             if leaf.value:
6049                 opening_bracket = leaf.opening_bracket
6050                 closing_bracket = leaf
6051
6052
6053 def get_future_imports(node: Node) -> Set[str]:
6054     """Return a set of __future__ imports in the file."""
6055     imports: Set[str] = set()
6056
6057     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
6058         for child in children:
6059             if isinstance(child, Leaf):
6060                 if child.type == token.NAME:
6061                     yield child.value
6062
6063             elif child.type == syms.import_as_name:
6064                 orig_name = child.children[0]
6065                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
6066                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
6067                 yield orig_name.value
6068
6069             elif child.type == syms.import_as_names:
6070                 yield from get_imports_from_children(child.children)
6071
6072             else:
6073                 raise AssertionError("Invalid syntax parsing imports")
6074
6075     for child in node.children:
6076         if child.type != syms.simple_stmt:
6077             break
6078
6079         first_child = child.children[0]
6080         if isinstance(first_child, Leaf):
6081             # Continue looking if we see a docstring; otherwise stop.
6082             if (
6083                 len(child.children) == 2
6084                 and first_child.type == token.STRING
6085                 and child.children[1].type == token.NEWLINE
6086             ):
6087                 continue
6088
6089             break
6090
6091         elif first_child.type == syms.import_from:
6092             module_name = first_child.children[1]
6093             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
6094                 break
6095
6096             imports |= set(get_imports_from_children(first_child.children[3:]))
6097         else:
6098             break
6099
6100     return imports
6101
6102
6103 @lru_cache()
6104 def get_gitignore(root: Path) -> PathSpec:
6105     """ Return a PathSpec matching gitignore content if present."""
6106     gitignore = root / ".gitignore"
6107     lines: List[str] = []
6108     if gitignore.is_file():
6109         with gitignore.open() as gf:
6110             lines = gf.readlines()
6111     return PathSpec.from_lines("gitwildmatch", lines)
6112
6113
6114 def normalize_path_maybe_ignore(
6115     path: Path, root: Path, report: "Report"
6116 ) -> Optional[str]:
6117     """Normalize `path`. May return `None` if `path` was ignored.
6118
6119     `report` is where "path ignored" output goes.
6120     """
6121     try:
6122         abspath = path if path.is_absolute() else Path.cwd() / path
6123         normalized_path = abspath.resolve().relative_to(root).as_posix()
6124     except OSError as e:
6125         report.path_ignored(path, f"cannot be read because {e}")
6126         return None
6127
6128     except ValueError:
6129         if path.is_symlink():
6130             report.path_ignored(path, f"is a symbolic link that points outside {root}")
6131             return None
6132
6133         raise
6134
6135     return normalized_path
6136
6137
6138 def path_is_excluded(
6139     normalized_path: str,
6140     pattern: Optional[Pattern[str]],
6141 ) -> bool:
6142     match = pattern.search(normalized_path) if pattern else None
6143     return bool(match and match.group(0))
6144
6145
6146 def gen_python_files(
6147     paths: Iterable[Path],
6148     root: Path,
6149     include: Optional[Pattern[str]],
6150     exclude: Pattern[str],
6151     extend_exclude: Optional[Pattern[str]],
6152     force_exclude: Optional[Pattern[str]],
6153     report: "Report",
6154     gitignore: PathSpec,
6155 ) -> Iterator[Path]:
6156     """Generate all files under `path` whose paths are not excluded by the
6157     `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
6158     but are included by the `include` regex.
6159
6160     Symbolic links pointing outside of the `root` directory are ignored.
6161
6162     `report` is where output about exclusions goes.
6163     """
6164     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
6165     for child in paths:
6166         normalized_path = normalize_path_maybe_ignore(child, root, report)
6167         if normalized_path is None:
6168             continue
6169
6170         # First ignore files matching .gitignore
6171         if gitignore.match_file(normalized_path):
6172             report.path_ignored(child, "matches the .gitignore file content")
6173             continue
6174
6175         # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
6176         normalized_path = "/" + normalized_path
6177         if child.is_dir():
6178             normalized_path += "/"
6179
6180         if path_is_excluded(normalized_path, exclude):
6181             report.path_ignored(child, "matches the --exclude regular expression")
6182             continue
6183
6184         if path_is_excluded(normalized_path, extend_exclude):
6185             report.path_ignored(
6186                 child, "matches the --extend-exclude regular expression"
6187             )
6188             continue
6189
6190         if path_is_excluded(normalized_path, force_exclude):
6191             report.path_ignored(child, "matches the --force-exclude regular expression")
6192             continue
6193
6194         if child.is_dir():
6195             yield from gen_python_files(
6196                 child.iterdir(),
6197                 root,
6198                 include,
6199                 exclude,
6200                 extend_exclude,
6201                 force_exclude,
6202                 report,
6203                 gitignore,
6204             )
6205
6206         elif child.is_file():
6207             include_match = include.search(normalized_path) if include else True
6208             if include_match:
6209                 yield child
6210
6211
6212 @lru_cache()
6213 def find_project_root(srcs: Iterable[str]) -> Path:
6214     """Return a directory containing .git, .hg, or pyproject.toml.
6215
6216     That directory will be a common parent of all files and directories
6217     passed in `srcs`.
6218
6219     If no directory in the tree contains a marker that would specify it's the
6220     project root, the root of the file system is returned.
6221     """
6222     if not srcs:
6223         return Path("/").resolve()
6224
6225     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6226
6227     # A list of lists of parents for each 'src'. 'src' is included as a
6228     # "parent" of itself if it is a directory
6229     src_parents = [
6230         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6231     ]
6232
6233     common_base = max(
6234         set.intersection(*(set(parents) for parents in src_parents)),
6235         key=lambda path: path.parts,
6236     )
6237
6238     for directory in (common_base, *common_base.parents):
6239         if (directory / ".git").exists():
6240             return directory
6241
6242         if (directory / ".hg").is_dir():
6243             return directory
6244
6245         if (directory / "pyproject.toml").is_file():
6246             return directory
6247
6248     return directory
6249
6250
6251 @dataclass
6252 class Report:
6253     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6254
6255     check: bool = False
6256     diff: bool = False
6257     quiet: bool = False
6258     verbose: bool = False
6259     change_count: int = 0
6260     same_count: int = 0
6261     failure_count: int = 0
6262
6263     def done(self, src: Path, changed: Changed) -> None:
6264         """Increment the counter for successful reformatting. Write out a message."""
6265         if changed is Changed.YES:
6266             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6267             if self.verbose or not self.quiet:
6268                 out(f"{reformatted} {src}")
6269             self.change_count += 1
6270         else:
6271             if self.verbose:
6272                 if changed is Changed.NO:
6273                     msg = f"{src} already well formatted, good job."
6274                 else:
6275                     msg = f"{src} wasn't modified on disk since last run."
6276                 out(msg, bold=False)
6277             self.same_count += 1
6278
6279     def failed(self, src: Path, message: str) -> None:
6280         """Increment the counter for failed reformatting. Write out a message."""
6281         err(f"error: cannot format {src}: {message}")
6282         self.failure_count += 1
6283
6284     def path_ignored(self, path: Path, message: str) -> None:
6285         if self.verbose:
6286             out(f"{path} ignored: {message}", bold=False)
6287
6288     @property
6289     def return_code(self) -> int:
6290         """Return the exit code that the app should use.
6291
6292         This considers the current state of changed files and failures:
6293         - if there were any failures, return 123;
6294         - if any files were changed and --check is being used, return 1;
6295         - otherwise return 0.
6296         """
6297         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6298         # 126 we have special return codes reserved by the shell.
6299         if self.failure_count:
6300             return 123
6301
6302         elif self.change_count and self.check:
6303             return 1
6304
6305         return 0
6306
6307     def __str__(self) -> str:
6308         """Render a color report of the current state.
6309
6310         Use `click.unstyle` to remove colors.
6311         """
6312         if self.check or self.diff:
6313             reformatted = "would be reformatted"
6314             unchanged = "would be left unchanged"
6315             failed = "would fail to reformat"
6316         else:
6317             reformatted = "reformatted"
6318             unchanged = "left unchanged"
6319             failed = "failed to reformat"
6320         report = []
6321         if self.change_count:
6322             s = "s" if self.change_count > 1 else ""
6323             report.append(
6324                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6325             )
6326         if self.same_count:
6327             s = "s" if self.same_count > 1 else ""
6328             report.append(f"{self.same_count} file{s} {unchanged}")
6329         if self.failure_count:
6330             s = "s" if self.failure_count > 1 else ""
6331             report.append(
6332                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6333             )
6334         return ", ".join(report) + "."
6335
6336
6337 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6338     filename = "<unknown>"
6339     if sys.version_info >= (3, 8):
6340         # TODO: support Python 4+ ;)
6341         for minor_version in range(sys.version_info[1], 4, -1):
6342             try:
6343                 return ast.parse(src, filename, feature_version=(3, minor_version))
6344             except SyntaxError:
6345                 continue
6346     else:
6347         for feature_version in (7, 6):
6348             try:
6349                 return ast3.parse(src, filename, feature_version=feature_version)
6350             except SyntaxError:
6351                 continue
6352     if ast27.__name__ == "ast":
6353         raise SyntaxError(
6354             "The requested source code has invalid Python 3 syntax.\n"
6355             "If you are trying to format Python 2 files please reinstall Black"
6356             " with the 'python2' extra: `python3 -m pip install black[python2]`."
6357         )
6358     return ast27.parse(src)
6359
6360
6361 def _fixup_ast_constants(
6362     node: Union[ast.AST, ast3.AST, ast27.AST]
6363 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6364     """Map ast nodes deprecated in 3.8 to Constant."""
6365     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6366         return ast.Constant(value=node.s)
6367
6368     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6369         return ast.Constant(value=node.n)
6370
6371     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6372         return ast.Constant(value=node.value)
6373
6374     return node
6375
6376
6377 def _stringify_ast(
6378     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6379 ) -> Iterator[str]:
6380     """Simple visitor generating strings to compare ASTs by content."""
6381
6382     node = _fixup_ast_constants(node)
6383
6384     yield f"{'  ' * depth}{node.__class__.__name__}("
6385
6386     for field in sorted(node._fields):  # noqa: F402
6387         # TypeIgnore has only one field 'lineno' which breaks this comparison
6388         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6389         if sys.version_info >= (3, 8):
6390             type_ignore_classes += (ast.TypeIgnore,)
6391         if isinstance(node, type_ignore_classes):
6392             break
6393
6394         try:
6395             value = getattr(node, field)
6396         except AttributeError:
6397             continue
6398
6399         yield f"{'  ' * (depth+1)}{field}="
6400
6401         if isinstance(value, list):
6402             for item in value:
6403                 # Ignore nested tuples within del statements, because we may insert
6404                 # parentheses and they change the AST.
6405                 if (
6406                     field == "targets"
6407                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6408                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6409                 ):
6410                     for item in item.elts:
6411                         yield from _stringify_ast(item, depth + 2)
6412
6413                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6414                     yield from _stringify_ast(item, depth + 2)
6415
6416         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6417             yield from _stringify_ast(value, depth + 2)
6418
6419         else:
6420             # Constant strings may be indented across newlines, if they are
6421             # docstrings; fold spaces after newlines when comparing. Similarly,
6422             # trailing and leading space may be removed.
6423             if (
6424                 isinstance(node, ast.Constant)
6425                 and field == "value"
6426                 and isinstance(value, str)
6427             ):
6428                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6429             else:
6430                 normalized = value
6431             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6432
6433     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6434
6435
6436 def assert_equivalent(src: str, dst: str) -> None:
6437     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6438     try:
6439         src_ast = parse_ast(src)
6440     except Exception as exc:
6441         raise AssertionError(
6442             "cannot use --safe with this file; failed to parse source file.  AST"
6443             f" error message: {exc}"
6444         )
6445
6446     try:
6447         dst_ast = parse_ast(dst)
6448     except Exception as exc:
6449         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6450         raise AssertionError(
6451             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6452             " on https://github.com/psf/black/issues.  This invalid output might be"
6453             f" helpful: {log}"
6454         ) from None
6455
6456     src_ast_str = "\n".join(_stringify_ast(src_ast))
6457     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6458     if src_ast_str != dst_ast_str:
6459         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6460         raise AssertionError(
6461             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6462             " source.  Please report a bug on https://github.com/psf/black/issues. "
6463             f" This diff might be helpful: {log}"
6464         ) from None
6465
6466
6467 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6468     """Raise AssertionError if `dst` reformats differently the second time."""
6469     newdst = format_str(dst, mode=mode)
6470     if dst != newdst:
6471         log = dump_to_file(
6472             str(mode),
6473             diff(src, dst, "source", "first pass"),
6474             diff(dst, newdst, "first pass", "second pass"),
6475         )
6476         raise AssertionError(
6477             "INTERNAL ERROR: Black produced different code on the second pass of the"
6478             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6479             f"  This diff might be helpful: {log}"
6480         ) from None
6481
6482
6483 @mypyc_attr(patchable=True)
6484 def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
6485     """Dump `output` to a temporary file. Return path to the file."""
6486     with tempfile.NamedTemporaryFile(
6487         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6488     ) as f:
6489         for lines in output:
6490             f.write(lines)
6491             if ensure_final_newline and lines and lines[-1] != "\n":
6492                 f.write("\n")
6493     return f.name
6494
6495
6496 @contextmanager
6497 def nullcontext() -> Iterator[None]:
6498     """Return an empty context manager.
6499
6500     To be used like `nullcontext` in Python 3.7.
6501     """
6502     yield
6503
6504
6505 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6506     """Return a unified diff string between strings `a` and `b`."""
6507     import difflib
6508
6509     a_lines = [line for line in a.splitlines(keepends=True)]
6510     b_lines = [line for line in b.splitlines(keepends=True)]
6511     diff_lines = []
6512     for line in difflib.unified_diff(
6513         a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
6514     ):
6515         # Work around https://bugs.python.org/issue2142
6516         # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
6517         if line[-1] == "\n":
6518             diff_lines.append(line)
6519         else:
6520             diff_lines.append(line + "\n")
6521             diff_lines.append("\\ No newline at end of file\n")
6522     return "".join(diff_lines)
6523
6524
6525 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6526     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6527     err("Aborted!")
6528     for task in tasks:
6529         task.cancel()
6530
6531
6532 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6533     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6534     try:
6535         if sys.version_info[:2] >= (3, 7):
6536             all_tasks = asyncio.all_tasks
6537         else:
6538             all_tasks = asyncio.Task.all_tasks
6539         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6540         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6541         if not to_cancel:
6542             return
6543
6544         for task in to_cancel:
6545             task.cancel()
6546         loop.run_until_complete(
6547             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6548         )
6549     finally:
6550         # `concurrent.futures.Future` objects cannot be cancelled once they
6551         # are already running. There might be some when the `shutdown()` happened.
6552         # Silence their logger's spew about the event loop being closed.
6553         cf_logger = logging.getLogger("concurrent.futures")
6554         cf_logger.setLevel(logging.CRITICAL)
6555         loop.close()
6556
6557
6558 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6559     """Replace `regex` with `replacement` twice on `original`.
6560
6561     This is used by string normalization to perform replaces on
6562     overlapping matches.
6563     """
6564     return regex.sub(replacement, regex.sub(replacement, original))
6565
6566
6567 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6568     """Compile a regular expression string in `regex`.
6569
6570     If it contains newlines, use verbose mode.
6571     """
6572     if "\n" in regex:
6573         regex = "(?x)" + regex
6574     compiled: Pattern[str] = re.compile(regex)
6575     return compiled
6576
6577
6578 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6579     """Like `reversed(enumerate(sequence))` if that were possible."""
6580     index = len(sequence) - 1
6581     for element in reversed(sequence):
6582         yield (index, element)
6583         index -= 1
6584
6585
6586 def enumerate_with_length(
6587     line: Line, reversed: bool = False
6588 ) -> Iterator[Tuple[Index, Leaf, int]]:
6589     """Return an enumeration of leaves with their length.
6590
6591     Stops prematurely on multiline strings and standalone comments.
6592     """
6593     op = cast(
6594         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6595         enumerate_reversed if reversed else enumerate,
6596     )
6597     for index, leaf in op(line.leaves):
6598         length = len(leaf.prefix) + len(leaf.value)
6599         if "\n" in leaf.value:
6600             return  # Multiline strings, we can't continue.
6601
6602         for comment in line.comments_after(leaf):
6603             length += len(comment.value)
6604
6605         yield index, leaf, length
6606
6607
6608 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6609     """Return True if `line` is no longer than `line_length`.
6610
6611     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6612     """
6613     if not line_str:
6614         line_str = line_to_string(line)
6615     return (
6616         len(line_str) <= line_length
6617         and "\n" not in line_str  # multiline strings
6618         and not line.contains_standalone_comments()
6619     )
6620
6621
6622 def can_be_split(line: Line) -> bool:
6623     """Return False if the line cannot be split *for sure*.
6624
6625     This is not an exhaustive search but a cheap heuristic that we can use to
6626     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6627     in unnecessary parentheses).
6628     """
6629     leaves = line.leaves
6630     if len(leaves) < 2:
6631         return False
6632
6633     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6634         call_count = 0
6635         dot_count = 0
6636         next = leaves[-1]
6637         for leaf in leaves[-2::-1]:
6638             if leaf.type in OPENING_BRACKETS:
6639                 if next.type not in CLOSING_BRACKETS:
6640                     return False
6641
6642                 call_count += 1
6643             elif leaf.type == token.DOT:
6644                 dot_count += 1
6645             elif leaf.type == token.NAME:
6646                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6647                     return False
6648
6649             elif leaf.type not in CLOSING_BRACKETS:
6650                 return False
6651
6652             if dot_count > 1 and call_count > 1:
6653                 return False
6654
6655     return True
6656
6657
6658 def can_omit_invisible_parens(
6659     line: Line,
6660     line_length: int,
6661     omit_on_explode: Collection[LeafID] = (),
6662 ) -> bool:
6663     """Does `line` have a shape safe to reformat without optional parens around it?
6664
6665     Returns True for only a subset of potentially nice looking formattings but
6666     the point is to not return false positives that end up producing lines that
6667     are too long.
6668     """
6669     bt = line.bracket_tracker
6670     if not bt.delimiters:
6671         # Without delimiters the optional parentheses are useless.
6672         return True
6673
6674     max_priority = bt.max_delimiter_priority()
6675     if bt.delimiter_count_with_priority(max_priority) > 1:
6676         # With more than one delimiter of a kind the optional parentheses read better.
6677         return False
6678
6679     if max_priority == DOT_PRIORITY:
6680         # A single stranded method call doesn't require optional parentheses.
6681         return True
6682
6683     assert len(line.leaves) >= 2, "Stranded delimiter"
6684
6685     # With a single delimiter, omit if the expression starts or ends with
6686     # a bracket.
6687     first = line.leaves[0]
6688     second = line.leaves[1]
6689     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6690         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6691             return True
6692
6693         # Note: we are not returning False here because a line might have *both*
6694         # a leading opening bracket and a trailing closing bracket.  If the
6695         # opening bracket doesn't match our rule, maybe the closing will.
6696
6697     penultimate = line.leaves[-2]
6698     last = line.leaves[-1]
6699     if line.magic_trailing_comma:
6700         try:
6701             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6702         except LookupError:
6703             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6704             return False
6705
6706     if (
6707         last.type == token.RPAR
6708         or last.type == token.RBRACE
6709         or (
6710             # don't use indexing for omitting optional parentheses;
6711             # it looks weird
6712             last.type == token.RSQB
6713             and last.parent
6714             and last.parent.type != syms.trailer
6715         )
6716     ):
6717         if penultimate.type in OPENING_BRACKETS:
6718             # Empty brackets don't help.
6719             return False
6720
6721         if is_multiline_string(first):
6722             # Additional wrapping of a multiline string in this situation is
6723             # unnecessary.
6724             return True
6725
6726         if line.magic_trailing_comma and penultimate.type == token.COMMA:
6727             # The rightmost non-omitted bracket pair is the one we want to explode on.
6728             return True
6729
6730         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6731             return True
6732
6733     return False
6734
6735
6736 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6737     """See `can_omit_invisible_parens`."""
6738     remainder = False
6739     length = 4 * line.depth
6740     _index = -1
6741     for _index, leaf, leaf_length in enumerate_with_length(line):
6742         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6743             remainder = True
6744         if remainder:
6745             length += leaf_length
6746             if length > line_length:
6747                 break
6748
6749             if leaf.type in OPENING_BRACKETS:
6750                 # There are brackets we can further split on.
6751                 remainder = False
6752
6753     else:
6754         # checked the entire string and line length wasn't exceeded
6755         if len(line.leaves) == _index + 1:
6756             return True
6757
6758     return False
6759
6760
6761 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6762     """See `can_omit_invisible_parens`."""
6763     length = 4 * line.depth
6764     seen_other_brackets = False
6765     for _index, leaf, leaf_length in enumerate_with_length(line):
6766         length += leaf_length
6767         if leaf is last.opening_bracket:
6768             if seen_other_brackets or length <= line_length:
6769                 return True
6770
6771         elif leaf.type in OPENING_BRACKETS:
6772             # There are brackets we can further split on.
6773             seen_other_brackets = True
6774
6775     return False
6776
6777
6778 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6779     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6780     stop_after = None
6781     last = None
6782     for leaf in reversed(leaves):
6783         if stop_after:
6784             if leaf is stop_after:
6785                 stop_after = None
6786             continue
6787
6788         if last:
6789             return leaf, last
6790
6791         if id(leaf) in omit:
6792             stop_after = leaf.opening_bracket
6793         else:
6794             last = leaf
6795     else:
6796         raise LookupError("Last two leaves were also skipped")
6797
6798
6799 def run_transformer(
6800     line: Line,
6801     transform: Transformer,
6802     mode: Mode,
6803     features: Collection[Feature],
6804     *,
6805     line_str: str = "",
6806 ) -> List[Line]:
6807     if not line_str:
6808         line_str = line_to_string(line)
6809     result: List[Line] = []
6810     for transformed_line in transform(line, features):
6811         if str(transformed_line).strip("\n") == line_str:
6812             raise CannotTransform("Line transformer returned an unchanged result")
6813
6814         result.extend(transform_line(transformed_line, mode=mode, features=features))
6815
6816     if not (
6817         transform.__name__ == "rhs"
6818         and line.bracket_tracker.invisible
6819         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6820         and not line.contains_multiline_strings()
6821         and not result[0].contains_uncollapsable_type_comments()
6822         and not result[0].contains_unsplittable_type_ignore()
6823         and not is_line_short_enough(result[0], line_length=mode.line_length)
6824     ):
6825         return result
6826
6827     line_copy = line.clone()
6828     append_leaves(line_copy, line, line.leaves)
6829     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6830     second_opinion = run_transformer(
6831         line_copy, transform, mode, features_fop, line_str=line_str
6832     )
6833     if all(
6834         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6835     ):
6836         result = second_opinion
6837     return result
6838
6839
6840 def get_cache_file(mode: Mode) -> Path:
6841     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6842
6843
6844 def read_cache(mode: Mode) -> Cache:
6845     """Read the cache if it exists and is well formed.
6846
6847     If it is not well formed, the call to write_cache later should resolve the issue.
6848     """
6849     cache_file = get_cache_file(mode)
6850     if not cache_file.exists():
6851         return {}
6852
6853     with cache_file.open("rb") as fobj:
6854         try:
6855             cache: Cache = pickle.load(fobj)
6856         except (pickle.UnpicklingError, ValueError):
6857             return {}
6858
6859     return cache
6860
6861
6862 def get_cache_info(path: Path) -> CacheInfo:
6863     """Return the information used to check if a file is already formatted or not."""
6864     stat = path.stat()
6865     return stat.st_mtime, stat.st_size
6866
6867
6868 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6869     """Split an iterable of paths in `sources` into two sets.
6870
6871     The first contains paths of files that modified on disk or are not in the
6872     cache. The other contains paths to non-modified files.
6873     """
6874     todo, done = set(), set()
6875     for src in sources:
6876         res_src = src.resolve()
6877         if cache.get(str(res_src)) != get_cache_info(res_src):
6878             todo.add(src)
6879         else:
6880             done.add(src)
6881     return todo, done
6882
6883
6884 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6885     """Update the cache file."""
6886     cache_file = get_cache_file(mode)
6887     try:
6888         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6889         new_cache = {
6890             **cache,
6891             **{str(src.resolve()): get_cache_info(src) for src in sources},
6892         }
6893         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6894             pickle.dump(new_cache, f, protocol=4)
6895         os.replace(f.name, cache_file)
6896     except OSError:
6897         pass
6898
6899
6900 def patch_click() -> None:
6901     """Make Click not crash.
6902
6903     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6904     default which restricts paths that it can access during the lifetime of the
6905     application.  Click refuses to work in this scenario by raising a RuntimeError.
6906
6907     In case of Black the likelihood that non-ASCII characters are going to be used in
6908     file paths is minimal since it's Python source code.  Moreover, this crash was
6909     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6910     """
6911     try:
6912         from click import core
6913         from click import _unicodefun  # type: ignore
6914     except ModuleNotFoundError:
6915         return
6916
6917     for module in (core, _unicodefun):
6918         if hasattr(module, "_verify_python3_env"):
6919             module._verify_python3_env = lambda: None
6920
6921
6922 def patched_main() -> None:
6923     freeze_support()
6924     patch_click()
6925     main()
6926
6927
6928 def is_docstring(leaf: Leaf) -> bool:
6929     if not is_multiline_string(leaf):
6930         # For the purposes of docstring re-indentation, we don't need to do anything
6931         # with single-line docstrings.
6932         return False
6933
6934     if prev_siblings_are(
6935         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6936     ):
6937         return True
6938
6939     # Multiline docstring on the same line as the `def`.
6940     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6941         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6942         # grammar. We're safe to return True without further checks.
6943         return True
6944
6945     return False
6946
6947
6948 def lines_with_leading_tabs_expanded(s: str) -> List[str]:
6949     """
6950     Splits string into lines and expands only leading tabs (following the normal
6951     Python rules)
6952     """
6953     lines = []
6954     for line in s.splitlines():
6955         # Find the index of the first non-whitespace character after a string of
6956         # whitespace that includes at least one tab
6957         match = re.match(r"\s*\t+\s*(\S)", line)
6958         if match:
6959             first_non_whitespace_idx = match.start(1)
6960
6961             lines.append(
6962                 line[:first_non_whitespace_idx].expandtabs()
6963                 + line[first_non_whitespace_idx:]
6964             )
6965         else:
6966             lines.append(line)
6967     return lines
6968
6969
6970 def fix_docstring(docstring: str, prefix: str) -> str:
6971     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6972     if not docstring:
6973         return ""
6974     lines = lines_with_leading_tabs_expanded(docstring)
6975     # Determine minimum indentation (first line doesn't count):
6976     indent = sys.maxsize
6977     for line in lines[1:]:
6978         stripped = line.lstrip()
6979         if stripped:
6980             indent = min(indent, len(line) - len(stripped))
6981     # Remove indentation (first line is special):
6982     trimmed = [lines[0].strip()]
6983     if indent < sys.maxsize:
6984         last_line_idx = len(lines) - 2
6985         for i, line in enumerate(lines[1:]):
6986             stripped_line = line[indent:].rstrip()
6987             if stripped_line or i == last_line_idx:
6988                 trimmed.append(prefix + stripped_line)
6989             else:
6990                 trimmed.append("")
6991     return "\n".join(trimmed)
6992
6993
6994 if __name__ == "__main__":
6995     patched_main()