]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

614fa8ae5ceee9eb7f596a26a38f0d6b334965a1
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 T = TypeVar("T")
116 E = TypeVar("E", bound=Exception)
117
118
119 class Ok(Generic[T]):
120     def __init__(self, value: T) -> None:
121         self._value = value
122
123     def ok(self) -> T:
124         return self._value
125
126
127 class Err(Generic[E]):
128     def __init__(self, e: E) -> None:
129         self._e = e
130
131     def err(self) -> E:
132         return self._e
133
134
135 # The 'Result' return type is used to implement an error-handling model heavily
136 # influenced by that used by the Rust programming language
137 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
138 Result = Union[Ok[T], Err[E]]
139 TResult = Result[T, CannotTransform]  # (T)ransform Result
140 TMatchResult = TResult[Index]
141
142
143 class WriteBack(Enum):
144     NO = 0
145     YES = 1
146     DIFF = 2
147     CHECK = 3
148     COLOR_DIFF = 4
149
150     @classmethod
151     def from_configuration(
152         cls, *, check: bool, diff: bool, color: bool = False
153     ) -> "WriteBack":
154         if check and not diff:
155             return cls.CHECK
156
157         if diff and color:
158             return cls.COLOR_DIFF
159
160         return cls.DIFF if diff else cls.YES
161
162
163 class Changed(Enum):
164     NO = 0
165     CACHED = 1
166     YES = 2
167
168
169 class TargetVersion(Enum):
170     PY27 = 2
171     PY33 = 3
172     PY34 = 4
173     PY35 = 5
174     PY36 = 6
175     PY37 = 7
176     PY38 = 8
177
178     def is_python2(self) -> bool:
179         return self is TargetVersion.PY27
180
181
182 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
183
184
185 class Feature(Enum):
186     # All string literals are unicode
187     UNICODE_LITERALS = 1
188     F_STRINGS = 2
189     NUMERIC_UNDERSCORES = 3
190     TRAILING_COMMA_IN_CALL = 4
191     TRAILING_COMMA_IN_DEF = 5
192     # The following two feature-flags are mutually exclusive, and exactly one should be
193     # set for every version of python.
194     ASYNC_IDENTIFIERS = 6
195     ASYNC_KEYWORDS = 7
196     ASSIGNMENT_EXPRESSIONS = 8
197     POS_ONLY_ARGUMENTS = 9
198
199
200 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
201     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
202     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
203     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
204     TargetVersion.PY35: {
205         Feature.UNICODE_LITERALS,
206         Feature.TRAILING_COMMA_IN_CALL,
207         Feature.ASYNC_IDENTIFIERS,
208     },
209     TargetVersion.PY36: {
210         Feature.UNICODE_LITERALS,
211         Feature.F_STRINGS,
212         Feature.NUMERIC_UNDERSCORES,
213         Feature.TRAILING_COMMA_IN_CALL,
214         Feature.TRAILING_COMMA_IN_DEF,
215         Feature.ASYNC_IDENTIFIERS,
216     },
217     TargetVersion.PY37: {
218         Feature.UNICODE_LITERALS,
219         Feature.F_STRINGS,
220         Feature.NUMERIC_UNDERSCORES,
221         Feature.TRAILING_COMMA_IN_CALL,
222         Feature.TRAILING_COMMA_IN_DEF,
223         Feature.ASYNC_KEYWORDS,
224     },
225     TargetVersion.PY38: {
226         Feature.UNICODE_LITERALS,
227         Feature.F_STRINGS,
228         Feature.NUMERIC_UNDERSCORES,
229         Feature.TRAILING_COMMA_IN_CALL,
230         Feature.TRAILING_COMMA_IN_DEF,
231         Feature.ASYNC_KEYWORDS,
232         Feature.ASSIGNMENT_EXPRESSIONS,
233         Feature.POS_ONLY_ARGUMENTS,
234     },
235 }
236
237
238 @dataclass
239 class Mode:
240     target_versions: Set[TargetVersion] = field(default_factory=set)
241     line_length: int = DEFAULT_LINE_LENGTH
242     string_normalization: bool = True
243     is_pyi: bool = False
244
245     def get_cache_key(self) -> str:
246         if self.target_versions:
247             version_str = ",".join(
248                 str(version.value)
249                 for version in sorted(self.target_versions, key=lambda v: v.value)
250             )
251         else:
252             version_str = "-"
253         parts = [
254             version_str,
255             str(self.line_length),
256             str(int(self.string_normalization)),
257             str(int(self.is_pyi)),
258         ]
259         return ".".join(parts)
260
261
262 # Legacy name, left for integrations.
263 FileMode = Mode
264
265
266 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
267     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
268
269
270 def find_pyproject_toml(path_search_start: str) -> Optional[str]:
271     """Find the absolute filepath to a pyproject.toml if it exists"""
272     path_project_root = find_project_root(path_search_start)
273     path_pyproject_toml = path_project_root / "pyproject.toml"
274     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
275
276
277 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
278     """Parse a pyproject toml file, pulling out relevant parts for Black
279
280     If parsing fails, will raise a toml.TomlDecodeError
281     """
282     pyproject_toml = toml.load(path_config)
283     config = pyproject_toml.get("tool", {}).get("black", {})
284     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
285
286
287 def read_pyproject_toml(
288     ctx: click.Context, param: click.Parameter, value: Optional[str]
289 ) -> Optional[str]:
290     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
291
292     Returns the path to a successfully found and read configuration file, None
293     otherwise.
294     """
295     if not value:
296         value = find_pyproject_toml(ctx.params.get("src", ()))
297         if value is None:
298             return None
299
300     try:
301         config = parse_pyproject_toml(value)
302     except (toml.TomlDecodeError, OSError) as e:
303         raise click.FileError(
304             filename=value, hint=f"Error reading configuration file: {e}"
305         )
306
307     if not config:
308         return None
309
310     target_version = config.get("target_version")
311     if target_version is not None and not isinstance(target_version, list):
312         raise click.BadOptionUsage(
313             "target-version", f"Config key target-version must be a list"
314         )
315
316     default_map: Dict[str, Any] = {}
317     if ctx.default_map:
318         default_map.update(ctx.default_map)
319     default_map.update(config)
320
321     ctx.default_map = default_map
322     return value
323
324
325 def target_version_option_callback(
326     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
327 ) -> List[TargetVersion]:
328     """Compute the target versions from a --target-version flag.
329
330     This is its own function because mypy couldn't infer the type correctly
331     when it was a lambda, causing mypyc trouble.
332     """
333     return [TargetVersion[val.upper()] for val in v]
334
335
336 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
337 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
338 @click.option(
339     "-l",
340     "--line-length",
341     type=int,
342     default=DEFAULT_LINE_LENGTH,
343     help="How many characters per line to allow.",
344     show_default=True,
345 )
346 @click.option(
347     "-t",
348     "--target-version",
349     type=click.Choice([v.name.lower() for v in TargetVersion]),
350     callback=target_version_option_callback,
351     multiple=True,
352     help=(
353         "Python versions that should be supported by Black's output. [default: per-file"
354         " auto-detection]"
355     ),
356 )
357 @click.option(
358     "--py36",
359     is_flag=True,
360     help=(
361         "Allow using Python 3.6-only syntax on all input files.  This will put trailing"
362         " commas in function signatures and calls also after *args and **kwargs."
363         " Deprecated; use --target-version instead. [default: per-file auto-detection]"
364     ),
365 )
366 @click.option(
367     "--pyi",
368     is_flag=True,
369     help=(
370         "Format all input files like typing stubs regardless of file extension (useful"
371         " when piping source on standard input)."
372     ),
373 )
374 @click.option(
375     "-S",
376     "--skip-string-normalization",
377     is_flag=True,
378     help="Don't normalize string quotes or prefixes.",
379 )
380 @click.option(
381     "--check",
382     is_flag=True,
383     help=(
384         "Don't write the files back, just return the status.  Return code 0 means"
385         " nothing would change.  Return code 1 means some files would be reformatted."
386         " Return code 123 means there was an internal error."
387     ),
388 )
389 @click.option(
390     "--diff",
391     is_flag=True,
392     help="Don't write the files back, just output a diff for each file on stdout.",
393 )
394 @click.option(
395     "--color/--no-color",
396     is_flag=True,
397     help="Show colored diff. Only applies when `--diff` is given.",
398 )
399 @click.option(
400     "--fast/--safe",
401     is_flag=True,
402     help="If --fast given, skip temporary sanity checks. [default: --safe]",
403 )
404 @click.option(
405     "--include",
406     type=str,
407     default=DEFAULT_INCLUDES,
408     help=(
409         "A regular expression that matches files and directories that should be"
410         " included on recursive searches.  An empty value means all files are included"
411         " regardless of the name.  Use forward slashes for directories on all platforms"
412         " (Windows, too).  Exclusions are calculated first, inclusions later."
413     ),
414     show_default=True,
415 )
416 @click.option(
417     "--exclude",
418     type=str,
419     default=DEFAULT_EXCLUDES,
420     help=(
421         "A regular expression that matches files and directories that should be"
422         " excluded on recursive searches.  An empty value means no paths are excluded."
423         " Use forward slashes for directories on all platforms (Windows, too). "
424         " Exclusions are calculated first, inclusions later."
425     ),
426     show_default=True,
427 )
428 @click.option(
429     "--force-exclude",
430     type=str,
431     help=(
432         "Like --exclude, but files and directories matching this regex will be "
433         "excluded even when they are passed explicitly as arguments"
434     ),
435 )
436 @click.option(
437     "-q",
438     "--quiet",
439     is_flag=True,
440     help=(
441         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
442         " those with 2>/dev/null."
443     ),
444 )
445 @click.option(
446     "-v",
447     "--verbose",
448     is_flag=True,
449     help=(
450         "Also emit messages to stderr about files that were not changed or were ignored"
451         " due to --exclude=."
452     ),
453 )
454 @click.version_option(version=__version__)
455 @click.argument(
456     "src",
457     nargs=-1,
458     type=click.Path(
459         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
460     ),
461     is_eager=True,
462 )
463 @click.option(
464     "--config",
465     type=click.Path(
466         exists=True,
467         file_okay=True,
468         dir_okay=False,
469         readable=True,
470         allow_dash=False,
471         path_type=str,
472     ),
473     is_eager=True,
474     callback=read_pyproject_toml,
475     help="Read configuration from PATH.",
476 )
477 @click.pass_context
478 def main(
479     ctx: click.Context,
480     code: Optional[str],
481     line_length: int,
482     target_version: List[TargetVersion],
483     check: bool,
484     diff: bool,
485     color: bool,
486     fast: bool,
487     pyi: bool,
488     py36: bool,
489     skip_string_normalization: bool,
490     quiet: bool,
491     verbose: bool,
492     include: str,
493     exclude: str,
494     force_exclude: Optional[str],
495     src: Tuple[str, ...],
496     config: Optional[str],
497 ) -> None:
498     """The uncompromising code formatter."""
499     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
500     if target_version:
501         if py36:
502             err("Cannot use both --target-version and --py36")
503             ctx.exit(2)
504         else:
505             versions = set(target_version)
506     elif py36:
507         err(
508             "--py36 is deprecated and will be removed in a future version. Use"
509             " --target-version py36 instead."
510         )
511         versions = PY36_VERSIONS
512     else:
513         # We'll autodetect later.
514         versions = set()
515     mode = Mode(
516         target_versions=versions,
517         line_length=line_length,
518         is_pyi=pyi,
519         string_normalization=not skip_string_normalization,
520     )
521     if config and verbose:
522         out(f"Using configuration from {config}.", bold=False, fg="blue")
523     if code is not None:
524         print(format_str(code, mode=mode))
525         ctx.exit(0)
526     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
527     sources = get_sources(
528         ctx=ctx,
529         src=src,
530         quiet=quiet,
531         verbose=verbose,
532         include=include,
533         exclude=exclude,
534         force_exclude=force_exclude,
535         report=report,
536     )
537
538     path_empty(
539         sources,
540         "No Python files are present to be formatted. Nothing to do 😴",
541         quiet,
542         verbose,
543         ctx,
544     )
545
546     if len(sources) == 1:
547         reformat_one(
548             src=sources.pop(),
549             fast=fast,
550             write_back=write_back,
551             mode=mode,
552             report=report,
553         )
554     else:
555         reformat_many(
556             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
557         )
558
559     if verbose or not quiet:
560         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
561         click.secho(str(report), err=True)
562     ctx.exit(report.return_code)
563
564
565 def get_sources(
566     *,
567     ctx: click.Context,
568     src: Tuple[str, ...],
569     quiet: bool,
570     verbose: bool,
571     include: str,
572     exclude: str,
573     force_exclude: Optional[str],
574     report: "Report",
575 ) -> Set[Path]:
576     """Compute the set of files to be formatted."""
577     try:
578         include_regex = re_compile_maybe_verbose(include)
579     except re.error:
580         err(f"Invalid regular expression for include given: {include!r}")
581         ctx.exit(2)
582     try:
583         exclude_regex = re_compile_maybe_verbose(exclude)
584     except re.error:
585         err(f"Invalid regular expression for exclude given: {exclude!r}")
586         ctx.exit(2)
587     try:
588         force_exclude_regex = (
589             re_compile_maybe_verbose(force_exclude) if force_exclude else None
590         )
591     except re.error:
592         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
593         ctx.exit(2)
594
595     root = find_project_root(src)
596     sources: Set[Path] = set()
597     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
598     exclude_regexes = [exclude_regex]
599     if force_exclude_regex is not None:
600         exclude_regexes.append(force_exclude_regex)
601
602     for s in src:
603         p = Path(s)
604         if p.is_dir():
605             sources.update(
606                 gen_python_files(
607                     p.iterdir(),
608                     root,
609                     include_regex,
610                     exclude_regexes,
611                     report,
612                     get_gitignore(root),
613                 )
614             )
615         elif s == "-":
616             sources.add(p)
617         elif p.is_file():
618             sources.update(
619                 gen_python_files(
620                     [p], root, None, exclude_regexes, report, get_gitignore(root)
621                 )
622             )
623         else:
624             err(f"invalid path: {s}")
625     return sources
626
627
628 def path_empty(
629     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
630 ) -> None:
631     """
632     Exit if there is no `src` provided for formatting
633     """
634     if len(src) == 0:
635         if verbose or not quiet:
636             out(msg)
637             ctx.exit(0)
638
639
640 def reformat_one(
641     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
642 ) -> None:
643     """Reformat a single file under `src` without spawning child processes.
644
645     `fast`, `write_back`, and `mode` options are passed to
646     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
647     """
648     try:
649         changed = Changed.NO
650         if not src.is_file() and str(src) == "-":
651             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
652                 changed = Changed.YES
653         else:
654             cache: Cache = {}
655             if write_back != WriteBack.DIFF:
656                 cache = read_cache(mode)
657                 res_src = src.resolve()
658                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
659                     changed = Changed.CACHED
660             if changed is not Changed.CACHED and format_file_in_place(
661                 src, fast=fast, write_back=write_back, mode=mode
662             ):
663                 changed = Changed.YES
664             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
665                 write_back is WriteBack.CHECK and changed is Changed.NO
666             ):
667                 write_cache(cache, [src], mode)
668         report.done(src, changed)
669     except Exception as exc:
670         report.failed(src, str(exc))
671
672
673 def reformat_many(
674     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
675 ) -> None:
676     """Reformat multiple files using a ProcessPoolExecutor."""
677     executor: Executor
678     loop = asyncio.get_event_loop()
679     worker_count = os.cpu_count()
680     if sys.platform == "win32":
681         # Work around https://bugs.python.org/issue26903
682         worker_count = min(worker_count, 61)
683     try:
684         executor = ProcessPoolExecutor(max_workers=worker_count)
685     except OSError:
686         # we arrive here if the underlying system does not support multi-processing
687         # like in AWS Lambda, in which case we gracefully fallback to
688         # a ThreadPollExecutor with just a single worker (more workers would not do us
689         # any good due to the Global Interpreter Lock)
690         executor = ThreadPoolExecutor(max_workers=1)
691
692     try:
693         loop.run_until_complete(
694             schedule_formatting(
695                 sources=sources,
696                 fast=fast,
697                 write_back=write_back,
698                 mode=mode,
699                 report=report,
700                 loop=loop,
701                 executor=executor,
702             )
703         )
704     finally:
705         shutdown(loop)
706         if executor is not None:
707             executor.shutdown()
708
709
710 async def schedule_formatting(
711     sources: Set[Path],
712     fast: bool,
713     write_back: WriteBack,
714     mode: Mode,
715     report: "Report",
716     loop: asyncio.AbstractEventLoop,
717     executor: Executor,
718 ) -> None:
719     """Run formatting of `sources` in parallel using the provided `executor`.
720
721     (Use ProcessPoolExecutors for actual parallelism.)
722
723     `write_back`, `fast`, and `mode` options are passed to
724     :func:`format_file_in_place`.
725     """
726     cache: Cache = {}
727     if write_back != WriteBack.DIFF:
728         cache = read_cache(mode)
729         sources, cached = filter_cached(cache, sources)
730         for src in sorted(cached):
731             report.done(src, Changed.CACHED)
732     if not sources:
733         return
734
735     cancelled = []
736     sources_to_cache = []
737     lock = None
738     if write_back == WriteBack.DIFF:
739         # For diff output, we need locks to ensure we don't interleave output
740         # from different processes.
741         manager = Manager()
742         lock = manager.Lock()
743     tasks = {
744         asyncio.ensure_future(
745             loop.run_in_executor(
746                 executor, format_file_in_place, src, fast, mode, write_back, lock
747             )
748         ): src
749         for src in sorted(sources)
750     }
751     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
752     try:
753         loop.add_signal_handler(signal.SIGINT, cancel, pending)
754         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
755     except NotImplementedError:
756         # There are no good alternatives for these on Windows.
757         pass
758     while pending:
759         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
760         for task in done:
761             src = tasks.pop(task)
762             if task.cancelled():
763                 cancelled.append(task)
764             elif task.exception():
765                 report.failed(src, str(task.exception()))
766             else:
767                 changed = Changed.YES if task.result() else Changed.NO
768                 # If the file was written back or was successfully checked as
769                 # well-formatted, store this information in the cache.
770                 if write_back is WriteBack.YES or (
771                     write_back is WriteBack.CHECK and changed is Changed.NO
772                 ):
773                     sources_to_cache.append(src)
774                 report.done(src, changed)
775     if cancelled:
776         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
777     if sources_to_cache:
778         write_cache(cache, sources_to_cache, mode)
779
780
781 def format_file_in_place(
782     src: Path,
783     fast: bool,
784     mode: Mode,
785     write_back: WriteBack = WriteBack.NO,
786     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
787 ) -> bool:
788     """Format file under `src` path. Return True if changed.
789
790     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
791     code to the file.
792     `mode` and `fast` options are passed to :func:`format_file_contents`.
793     """
794     if src.suffix == ".pyi":
795         mode = replace(mode, is_pyi=True)
796
797     then = datetime.utcfromtimestamp(src.stat().st_mtime)
798     with open(src, "rb") as buf:
799         src_contents, encoding, newline = decode_bytes(buf.read())
800     try:
801         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
802     except NothingChanged:
803         return False
804
805     if write_back == WriteBack.YES:
806         with open(src, "w", encoding=encoding, newline=newline) as f:
807             f.write(dst_contents)
808     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
809         now = datetime.utcnow()
810         src_name = f"{src}\t{then} +0000"
811         dst_name = f"{src}\t{now} +0000"
812         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
813
814         if write_back == write_back.COLOR_DIFF:
815             diff_contents = color_diff(diff_contents)
816
817         with lock or nullcontext():
818             f = io.TextIOWrapper(
819                 sys.stdout.buffer,
820                 encoding=encoding,
821                 newline=newline,
822                 write_through=True,
823             )
824             f = wrap_stream_for_windows(f)
825             f.write(diff_contents)
826             f.detach()
827
828     return True
829
830
831 def color_diff(contents: str) -> str:
832     """Inject the ANSI color codes to the diff."""
833     lines = contents.split("\n")
834     for i, line in enumerate(lines):
835         if line.startswith("+++") or line.startswith("---"):
836             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
837         if line.startswith("@@"):
838             line = "\033[36m" + line + "\033[0m"  # cyan, reset
839         if line.startswith("+"):
840             line = "\033[32m" + line + "\033[0m"  # green, reset
841         elif line.startswith("-"):
842             line = "\033[31m" + line + "\033[0m"  # red, reset
843         lines[i] = line
844     return "\n".join(lines)
845
846
847 def wrap_stream_for_windows(
848     f: io.TextIOWrapper,
849 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
850     """
851     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
852
853     If `colorama` is not found, then no change is made. If `colorama` does
854     exist, then it handles the logic to determine whether or not to change
855     things.
856     """
857     try:
858         from colorama import initialise
859
860         # We set `strip=False` so that we can don't have to modify
861         # test_express_diff_with_color.
862         f = initialise.wrap_stream(
863             f, convert=None, strip=False, autoreset=False, wrap=True
864         )
865
866         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
867         # which does not have a `detach()` method. So we fake one.
868         f.detach = lambda *args, **kwargs: None  # type: ignore
869     except ImportError:
870         pass
871
872     return f
873
874
875 def format_stdin_to_stdout(
876     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
877 ) -> bool:
878     """Format file on stdin. Return True if changed.
879
880     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
881     write a diff to stdout. The `mode` argument is passed to
882     :func:`format_file_contents`.
883     """
884     then = datetime.utcnow()
885     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
886     dst = src
887     try:
888         dst = format_file_contents(src, fast=fast, mode=mode)
889         return True
890
891     except NothingChanged:
892         return False
893
894     finally:
895         f = io.TextIOWrapper(
896             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
897         )
898         if write_back == WriteBack.YES:
899             f.write(dst)
900         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
901             now = datetime.utcnow()
902             src_name = f"STDIN\t{then} +0000"
903             dst_name = f"STDOUT\t{now} +0000"
904             d = diff(src, dst, src_name, dst_name)
905             if write_back == WriteBack.COLOR_DIFF:
906                 d = color_diff(d)
907                 f = wrap_stream_for_windows(f)
908             f.write(d)
909         f.detach()
910
911
912 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
913     """Reformat contents a file and return new contents.
914
915     If `fast` is False, additionally confirm that the reformatted code is
916     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
917     `mode` is passed to :func:`format_str`.
918     """
919     if src_contents.strip() == "":
920         raise NothingChanged
921
922     dst_contents = format_str(src_contents, mode=mode)
923     if src_contents == dst_contents:
924         raise NothingChanged
925
926     if not fast:
927         assert_equivalent(src_contents, dst_contents)
928         assert_stable(src_contents, dst_contents, mode=mode)
929     return dst_contents
930
931
932 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
933     """Reformat a string and return new contents.
934
935     `mode` determines formatting options, such as how many characters per line are
936     allowed.  Example:
937
938     >>> import black
939     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
940     def f(arg: str = "") -> None:
941         ...
942
943     A more complex example:
944     >>> print(
945     ...   black.format_str(
946     ...     "def f(arg:str='')->None: hey",
947     ...     mode=black.Mode(
948     ...       target_versions={black.TargetVersion.PY36},
949     ...       line_length=10,
950     ...       string_normalization=False,
951     ...       is_pyi=False,
952     ...     ),
953     ...   ),
954     ... )
955     def f(
956         arg: str = '',
957     ) -> None:
958         hey
959
960     """
961     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
962     dst_contents = []
963     future_imports = get_future_imports(src_node)
964     if mode.target_versions:
965         versions = mode.target_versions
966     else:
967         versions = detect_target_versions(src_node)
968     normalize_fmt_off(src_node)
969     lines = LineGenerator(
970         remove_u_prefix="unicode_literals" in future_imports
971         or supports_feature(versions, Feature.UNICODE_LITERALS),
972         is_pyi=mode.is_pyi,
973         normalize_strings=mode.string_normalization,
974     )
975     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
976     empty_line = Line()
977     after = 0
978     split_line_features = {
979         feature
980         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
981         if supports_feature(versions, feature)
982     }
983     for current_line in lines.visit(src_node):
984         dst_contents.append(str(empty_line) * after)
985         before, after = elt.maybe_empty_lines(current_line)
986         dst_contents.append(str(empty_line) * before)
987         for line in transform_line(
988             current_line,
989             line_length=mode.line_length,
990             normalize_strings=mode.string_normalization,
991             features=split_line_features,
992         ):
993             dst_contents.append(str(line))
994     return "".join(dst_contents)
995
996
997 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
998     """Return a tuple of (decoded_contents, encoding, newline).
999
1000     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1001     universal newlines (i.e. only contains LF).
1002     """
1003     srcbuf = io.BytesIO(src)
1004     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1005     if not lines:
1006         return "", encoding, "\n"
1007
1008     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1009     srcbuf.seek(0)
1010     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1011         return tiow.read(), encoding, newline
1012
1013
1014 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1015     if not target_versions:
1016         # No target_version specified, so try all grammars.
1017         return [
1018             # Python 3.7+
1019             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1020             # Python 3.0-3.6
1021             pygram.python_grammar_no_print_statement_no_exec_statement,
1022             # Python 2.7 with future print_function import
1023             pygram.python_grammar_no_print_statement,
1024             # Python 2.7
1025             pygram.python_grammar,
1026         ]
1027
1028     if all(version.is_python2() for version in target_versions):
1029         # Python 2-only code, so try Python 2 grammars.
1030         return [
1031             # Python 2.7 with future print_function import
1032             pygram.python_grammar_no_print_statement,
1033             # Python 2.7
1034             pygram.python_grammar,
1035         ]
1036
1037     # Python 3-compatible code, so only try Python 3 grammar.
1038     grammars = []
1039     # If we have to parse both, try to parse async as a keyword first
1040     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1041         # Python 3.7+
1042         grammars.append(
1043             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1044         )
1045     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1046         # Python 3.0-3.6
1047         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1048     # At least one of the above branches must have been taken, because every Python
1049     # version has exactly one of the two 'ASYNC_*' flags
1050     return grammars
1051
1052
1053 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1054     """Given a string with source, return the lib2to3 Node."""
1055     if src_txt[-1:] != "\n":
1056         src_txt += "\n"
1057
1058     for grammar in get_grammars(set(target_versions)):
1059         drv = driver.Driver(grammar, pytree.convert)
1060         try:
1061             result = drv.parse_string(src_txt, True)
1062             break
1063
1064         except ParseError as pe:
1065             lineno, column = pe.context[1]
1066             lines = src_txt.splitlines()
1067             try:
1068                 faulty_line = lines[lineno - 1]
1069             except IndexError:
1070                 faulty_line = "<line number missing in source>"
1071             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1072     else:
1073         raise exc from None
1074
1075     if isinstance(result, Leaf):
1076         result = Node(syms.file_input, [result])
1077     return result
1078
1079
1080 def lib2to3_unparse(node: Node) -> str:
1081     """Given a lib2to3 node, return its string representation."""
1082     code = str(node)
1083     return code
1084
1085
1086 class Visitor(Generic[T]):
1087     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1088
1089     def visit(self, node: LN) -> Iterator[T]:
1090         """Main method to visit `node` and its children.
1091
1092         It tries to find a `visit_*()` method for the given `node.type`, like
1093         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1094         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1095         instead.
1096
1097         Then yields objects of type `T` from the selected visitor.
1098         """
1099         if node.type < 256:
1100             name = token.tok_name[node.type]
1101         else:
1102             name = str(type_repr(node.type))
1103         # We explicitly branch on whether a visitor exists (instead of
1104         # using self.visit_default as the default arg to getattr) in order
1105         # to save needing to create a bound method object and so mypyc can
1106         # generate a native call to visit_default.
1107         visitf = getattr(self, f"visit_{name}", None)
1108         if visitf:
1109             yield from visitf(node)
1110         else:
1111             yield from self.visit_default(node)
1112
1113     def visit_default(self, node: LN) -> Iterator[T]:
1114         """Default `visit_*()` implementation. Recurses to children of `node`."""
1115         if isinstance(node, Node):
1116             for child in node.children:
1117                 yield from self.visit(child)
1118
1119
1120 @dataclass
1121 class DebugVisitor(Visitor[T]):
1122     tree_depth: int = 0
1123
1124     def visit_default(self, node: LN) -> Iterator[T]:
1125         indent = " " * (2 * self.tree_depth)
1126         if isinstance(node, Node):
1127             _type = type_repr(node.type)
1128             out(f"{indent}{_type}", fg="yellow")
1129             self.tree_depth += 1
1130             for child in node.children:
1131                 yield from self.visit(child)
1132
1133             self.tree_depth -= 1
1134             out(f"{indent}/{_type}", fg="yellow", bold=False)
1135         else:
1136             _type = token.tok_name.get(node.type, str(node.type))
1137             out(f"{indent}{_type}", fg="blue", nl=False)
1138             if node.prefix:
1139                 # We don't have to handle prefixes for `Node` objects since
1140                 # that delegates to the first child anyway.
1141                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1142             out(f" {node.value!r}", fg="blue", bold=False)
1143
1144     @classmethod
1145     def show(cls, code: Union[str, Leaf, Node]) -> None:
1146         """Pretty-print the lib2to3 AST of a given string of `code`.
1147
1148         Convenience method for debugging.
1149         """
1150         v: DebugVisitor[None] = DebugVisitor()
1151         if isinstance(code, str):
1152             code = lib2to3_parse(code)
1153         list(v.visit(code))
1154
1155
1156 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1157 STATEMENT: Final = {
1158     syms.if_stmt,
1159     syms.while_stmt,
1160     syms.for_stmt,
1161     syms.try_stmt,
1162     syms.except_clause,
1163     syms.with_stmt,
1164     syms.funcdef,
1165     syms.classdef,
1166 }
1167 STANDALONE_COMMENT: Final = 153
1168 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1169 LOGIC_OPERATORS: Final = {"and", "or"}
1170 COMPARATORS: Final = {
1171     token.LESS,
1172     token.GREATER,
1173     token.EQEQUAL,
1174     token.NOTEQUAL,
1175     token.LESSEQUAL,
1176     token.GREATEREQUAL,
1177 }
1178 MATH_OPERATORS: Final = {
1179     token.VBAR,
1180     token.CIRCUMFLEX,
1181     token.AMPER,
1182     token.LEFTSHIFT,
1183     token.RIGHTSHIFT,
1184     token.PLUS,
1185     token.MINUS,
1186     token.STAR,
1187     token.SLASH,
1188     token.DOUBLESLASH,
1189     token.PERCENT,
1190     token.AT,
1191     token.TILDE,
1192     token.DOUBLESTAR,
1193 }
1194 STARS: Final = {token.STAR, token.DOUBLESTAR}
1195 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1196 VARARGS_PARENTS: Final = {
1197     syms.arglist,
1198     syms.argument,  # double star in arglist
1199     syms.trailer,  # single argument to call
1200     syms.typedargslist,
1201     syms.varargslist,  # lambdas
1202 }
1203 UNPACKING_PARENTS: Final = {
1204     syms.atom,  # single element of a list or set literal
1205     syms.dictsetmaker,
1206     syms.listmaker,
1207     syms.testlist_gexp,
1208     syms.testlist_star_expr,
1209 }
1210 TEST_DESCENDANTS: Final = {
1211     syms.test,
1212     syms.lambdef,
1213     syms.or_test,
1214     syms.and_test,
1215     syms.not_test,
1216     syms.comparison,
1217     syms.star_expr,
1218     syms.expr,
1219     syms.xor_expr,
1220     syms.and_expr,
1221     syms.shift_expr,
1222     syms.arith_expr,
1223     syms.trailer,
1224     syms.term,
1225     syms.power,
1226 }
1227 ASSIGNMENTS: Final = {
1228     "=",
1229     "+=",
1230     "-=",
1231     "*=",
1232     "@=",
1233     "/=",
1234     "%=",
1235     "&=",
1236     "|=",
1237     "^=",
1238     "<<=",
1239     ">>=",
1240     "**=",
1241     "//=",
1242 }
1243 COMPREHENSION_PRIORITY: Final = 20
1244 COMMA_PRIORITY: Final = 18
1245 TERNARY_PRIORITY: Final = 16
1246 LOGIC_PRIORITY: Final = 14
1247 STRING_PRIORITY: Final = 12
1248 COMPARATOR_PRIORITY: Final = 10
1249 MATH_PRIORITIES: Final = {
1250     token.VBAR: 9,
1251     token.CIRCUMFLEX: 8,
1252     token.AMPER: 7,
1253     token.LEFTSHIFT: 6,
1254     token.RIGHTSHIFT: 6,
1255     token.PLUS: 5,
1256     token.MINUS: 5,
1257     token.STAR: 4,
1258     token.SLASH: 4,
1259     token.DOUBLESLASH: 4,
1260     token.PERCENT: 4,
1261     token.AT: 4,
1262     token.TILDE: 3,
1263     token.DOUBLESTAR: 2,
1264 }
1265 DOT_PRIORITY: Final = 1
1266
1267
1268 @dataclass
1269 class BracketTracker:
1270     """Keeps track of brackets on a line."""
1271
1272     depth: int = 0
1273     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1274     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1275     previous: Optional[Leaf] = None
1276     _for_loop_depths: List[int] = field(default_factory=list)
1277     _lambda_argument_depths: List[int] = field(default_factory=list)
1278
1279     def mark(self, leaf: Leaf) -> None:
1280         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1281
1282         All leaves receive an int `bracket_depth` field that stores how deep
1283         within brackets a given leaf is. 0 means there are no enclosing brackets
1284         that started on this line.
1285
1286         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1287         field that it forms a pair with. This is a one-directional link to
1288         avoid reference cycles.
1289
1290         If a leaf is a delimiter (a token on which Black can split the line if
1291         needed) and it's on depth 0, its `id()` is stored in the tracker's
1292         `delimiters` field.
1293         """
1294         if leaf.type == token.COMMENT:
1295             return
1296
1297         self.maybe_decrement_after_for_loop_variable(leaf)
1298         self.maybe_decrement_after_lambda_arguments(leaf)
1299         if leaf.type in CLOSING_BRACKETS:
1300             self.depth -= 1
1301             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1302             leaf.opening_bracket = opening_bracket
1303         leaf.bracket_depth = self.depth
1304         if self.depth == 0:
1305             delim = is_split_before_delimiter(leaf, self.previous)
1306             if delim and self.previous is not None:
1307                 self.delimiters[id(self.previous)] = delim
1308             else:
1309                 delim = is_split_after_delimiter(leaf, self.previous)
1310                 if delim:
1311                     self.delimiters[id(leaf)] = delim
1312         if leaf.type in OPENING_BRACKETS:
1313             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1314             self.depth += 1
1315         self.previous = leaf
1316         self.maybe_increment_lambda_arguments(leaf)
1317         self.maybe_increment_for_loop_variable(leaf)
1318
1319     def any_open_brackets(self) -> bool:
1320         """Return True if there is an yet unmatched open bracket on the line."""
1321         return bool(self.bracket_match)
1322
1323     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1324         """Return the highest priority of a delimiter found on the line.
1325
1326         Values are consistent with what `is_split_*_delimiter()` return.
1327         Raises ValueError on no delimiters.
1328         """
1329         return max(v for k, v in self.delimiters.items() if k not in exclude)
1330
1331     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1332         """Return the number of delimiters with the given `priority`.
1333
1334         If no `priority` is passed, defaults to max priority on the line.
1335         """
1336         if not self.delimiters:
1337             return 0
1338
1339         priority = priority or self.max_delimiter_priority()
1340         return sum(1 for p in self.delimiters.values() if p == priority)
1341
1342     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1343         """In a for loop, or comprehension, the variables are often unpacks.
1344
1345         To avoid splitting on the comma in this situation, increase the depth of
1346         tokens between `for` and `in`.
1347         """
1348         if leaf.type == token.NAME and leaf.value == "for":
1349             self.depth += 1
1350             self._for_loop_depths.append(self.depth)
1351             return True
1352
1353         return False
1354
1355     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1356         """See `maybe_increment_for_loop_variable` above for explanation."""
1357         if (
1358             self._for_loop_depths
1359             and self._for_loop_depths[-1] == self.depth
1360             and leaf.type == token.NAME
1361             and leaf.value == "in"
1362         ):
1363             self.depth -= 1
1364             self._for_loop_depths.pop()
1365             return True
1366
1367         return False
1368
1369     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1370         """In a lambda expression, there might be more than one argument.
1371
1372         To avoid splitting on the comma in this situation, increase the depth of
1373         tokens between `lambda` and `:`.
1374         """
1375         if leaf.type == token.NAME and leaf.value == "lambda":
1376             self.depth += 1
1377             self._lambda_argument_depths.append(self.depth)
1378             return True
1379
1380         return False
1381
1382     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1383         """See `maybe_increment_lambda_arguments` above for explanation."""
1384         if (
1385             self._lambda_argument_depths
1386             and self._lambda_argument_depths[-1] == self.depth
1387             and leaf.type == token.COLON
1388         ):
1389             self.depth -= 1
1390             self._lambda_argument_depths.pop()
1391             return True
1392
1393         return False
1394
1395     def get_open_lsqb(self) -> Optional[Leaf]:
1396         """Return the most recent opening square bracket (if any)."""
1397         return self.bracket_match.get((self.depth - 1, token.RSQB))
1398
1399
1400 @dataclass
1401 class Line:
1402     """Holds leaves and comments. Can be printed with `str(line)`."""
1403
1404     depth: int = 0
1405     leaves: List[Leaf] = field(default_factory=list)
1406     # keys ordered like `leaves`
1407     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1408     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1409     inside_brackets: bool = False
1410     should_explode: bool = False
1411
1412     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1413         """Add a new `leaf` to the end of the line.
1414
1415         Unless `preformatted` is True, the `leaf` will receive a new consistent
1416         whitespace prefix and metadata applied by :class:`BracketTracker`.
1417         Trailing commas are maybe removed, unpacked for loop variables are
1418         demoted from being delimiters.
1419
1420         Inline comments are put aside.
1421         """
1422         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1423         if not has_value:
1424             return
1425
1426         if token.COLON == leaf.type and self.is_class_paren_empty:
1427             del self.leaves[-2:]
1428         if self.leaves and not preformatted:
1429             # Note: at this point leaf.prefix should be empty except for
1430             # imports, for which we only preserve newlines.
1431             leaf.prefix += whitespace(
1432                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1433             )
1434         if self.inside_brackets or not preformatted:
1435             self.bracket_tracker.mark(leaf)
1436             self.maybe_remove_trailing_comma(leaf)
1437         if not self.append_comment(leaf):
1438             self.leaves.append(leaf)
1439
1440     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1441         """Like :func:`append()` but disallow invalid standalone comment structure.
1442
1443         Raises ValueError when any `leaf` is appended after a standalone comment
1444         or when a standalone comment is not the first leaf on the line.
1445         """
1446         if self.bracket_tracker.depth == 0:
1447             if self.is_comment:
1448                 raise ValueError("cannot append to standalone comments")
1449
1450             if self.leaves and leaf.type == STANDALONE_COMMENT:
1451                 raise ValueError(
1452                     "cannot append standalone comments to a populated line"
1453                 )
1454
1455         self.append(leaf, preformatted=preformatted)
1456
1457     @property
1458     def is_comment(self) -> bool:
1459         """Is this line a standalone comment?"""
1460         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1461
1462     @property
1463     def is_decorator(self) -> bool:
1464         """Is this line a decorator?"""
1465         return bool(self) and self.leaves[0].type == token.AT
1466
1467     @property
1468     def is_import(self) -> bool:
1469         """Is this an import line?"""
1470         return bool(self) and is_import(self.leaves[0])
1471
1472     @property
1473     def is_class(self) -> bool:
1474         """Is this line a class definition?"""
1475         return (
1476             bool(self)
1477             and self.leaves[0].type == token.NAME
1478             and self.leaves[0].value == "class"
1479         )
1480
1481     @property
1482     def is_stub_class(self) -> bool:
1483         """Is this line a class definition with a body consisting only of "..."?"""
1484         return self.is_class and self.leaves[-3:] == [
1485             Leaf(token.DOT, ".") for _ in range(3)
1486         ]
1487
1488     @property
1489     def is_collection_with_optional_trailing_comma(self) -> bool:
1490         """Is this line a collection literal with a trailing comma that's optional?
1491
1492         Note that the trailing comma in a 1-tuple is not optional.
1493         """
1494         if not self.leaves or len(self.leaves) < 4:
1495             return False
1496
1497         # Look for and address a trailing colon.
1498         if self.leaves[-1].type == token.COLON:
1499             closer = self.leaves[-2]
1500             close_index = -2
1501         else:
1502             closer = self.leaves[-1]
1503             close_index = -1
1504         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1505             return False
1506
1507         if closer.type == token.RPAR:
1508             # Tuples require an extra check, because if there's only
1509             # one element in the tuple removing the comma unmakes the
1510             # tuple.
1511             #
1512             # We also check for parens before looking for the trailing
1513             # comma because in some cases (eg assigning a dict
1514             # literal) the literal gets wrapped in temporary parens
1515             # during parsing. This case is covered by the
1516             # collections.py test data.
1517             opener = closer.opening_bracket
1518             for _open_index, leaf in enumerate(self.leaves):
1519                 if leaf is opener:
1520                     break
1521
1522             else:
1523                 # Couldn't find the matching opening paren, play it safe.
1524                 return False
1525
1526             commas = 0
1527             comma_depth = self.leaves[close_index - 1].bracket_depth
1528             for leaf in self.leaves[_open_index + 1 : close_index]:
1529                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1530                     commas += 1
1531             if commas > 1:
1532                 # We haven't looked yet for the trailing comma because
1533                 # we might also have caught noop parens.
1534                 return self.leaves[close_index - 1].type == token.COMMA
1535
1536             elif commas == 1:
1537                 return False  # it's either a one-tuple or didn't have a trailing comma
1538
1539             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1540                 close_index -= 1
1541                 closer = self.leaves[close_index]
1542                 if closer.type == token.RPAR:
1543                     # TODO: this is a gut feeling. Will we ever see this?
1544                     return False
1545
1546         if self.leaves[close_index - 1].type != token.COMMA:
1547             return False
1548
1549         return True
1550
1551     @property
1552     def is_def(self) -> bool:
1553         """Is this a function definition? (Also returns True for async defs.)"""
1554         try:
1555             first_leaf = self.leaves[0]
1556         except IndexError:
1557             return False
1558
1559         try:
1560             second_leaf: Optional[Leaf] = self.leaves[1]
1561         except IndexError:
1562             second_leaf = None
1563         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1564             first_leaf.type == token.ASYNC
1565             and second_leaf is not None
1566             and second_leaf.type == token.NAME
1567             and second_leaf.value == "def"
1568         )
1569
1570     @property
1571     def is_class_paren_empty(self) -> bool:
1572         """Is this a class with no base classes but using parentheses?
1573
1574         Those are unnecessary and should be removed.
1575         """
1576         return (
1577             bool(self)
1578             and len(self.leaves) == 4
1579             and self.is_class
1580             and self.leaves[2].type == token.LPAR
1581             and self.leaves[2].value == "("
1582             and self.leaves[3].type == token.RPAR
1583             and self.leaves[3].value == ")"
1584         )
1585
1586     @property
1587     def is_triple_quoted_string(self) -> bool:
1588         """Is the line a triple quoted string?"""
1589         return (
1590             bool(self)
1591             and self.leaves[0].type == token.STRING
1592             and self.leaves[0].value.startswith(('"""', "'''"))
1593         )
1594
1595     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1596         """If so, needs to be split before emitting."""
1597         for leaf in self.leaves:
1598             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1599                 return True
1600
1601         return False
1602
1603     def contains_uncollapsable_type_comments(self) -> bool:
1604         ignored_ids = set()
1605         try:
1606             last_leaf = self.leaves[-1]
1607             ignored_ids.add(id(last_leaf))
1608             if last_leaf.type == token.COMMA or (
1609                 last_leaf.type == token.RPAR and not last_leaf.value
1610             ):
1611                 # When trailing commas or optional parens are inserted by Black for
1612                 # consistency, comments after the previous last element are not moved
1613                 # (they don't have to, rendering will still be correct).  So we ignore
1614                 # trailing commas and invisible.
1615                 last_leaf = self.leaves[-2]
1616                 ignored_ids.add(id(last_leaf))
1617         except IndexError:
1618             return False
1619
1620         # A type comment is uncollapsable if it is attached to a leaf
1621         # that isn't at the end of the line (since that could cause it
1622         # to get associated to a different argument) or if there are
1623         # comments before it (since that could cause it to get hidden
1624         # behind a comment.
1625         comment_seen = False
1626         for leaf_id, comments in self.comments.items():
1627             for comment in comments:
1628                 if is_type_comment(comment):
1629                     if comment_seen or (
1630                         not is_type_comment(comment, " ignore")
1631                         and leaf_id not in ignored_ids
1632                     ):
1633                         return True
1634
1635                 comment_seen = True
1636
1637         return False
1638
1639     def contains_unsplittable_type_ignore(self) -> bool:
1640         if not self.leaves:
1641             return False
1642
1643         # If a 'type: ignore' is attached to the end of a line, we
1644         # can't split the line, because we can't know which of the
1645         # subexpressions the ignore was meant to apply to.
1646         #
1647         # We only want this to apply to actual physical lines from the
1648         # original source, though: we don't want the presence of a
1649         # 'type: ignore' at the end of a multiline expression to
1650         # justify pushing it all onto one line. Thus we
1651         # (unfortunately) need to check the actual source lines and
1652         # only report an unsplittable 'type: ignore' if this line was
1653         # one line in the original code.
1654
1655         # Grab the first and last line numbers, skipping generated leaves
1656         first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0)
1657         last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0)
1658
1659         if first_line == last_line:
1660             # We look at the last two leaves since a comma or an
1661             # invisible paren could have been added at the end of the
1662             # line.
1663             for node in self.leaves[-2:]:
1664                 for comment in self.comments.get(id(node), []):
1665                     if is_type_comment(comment, " ignore"):
1666                         return True
1667
1668         return False
1669
1670     def contains_multiline_strings(self) -> bool:
1671         return any(is_multiline_string(leaf) for leaf in self.leaves)
1672
1673     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1674         """Remove trailing comma if there is one and it's safe."""
1675         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1676             return False
1677
1678         # We remove trailing commas only in the case of importing a
1679         # single name from a module.
1680         if not (
1681             self.leaves
1682             and self.is_import
1683             and len(self.leaves) > 4
1684             and self.leaves[-1].type == token.COMMA
1685             and closing.type in CLOSING_BRACKETS
1686             and self.leaves[-4].type == token.NAME
1687             and (
1688                 # regular `from foo import bar,`
1689                 self.leaves[-4].value == "import"
1690                 # `from foo import (bar as baz,)
1691                 or (
1692                     len(self.leaves) > 6
1693                     and self.leaves[-6].value == "import"
1694                     and self.leaves[-3].value == "as"
1695                 )
1696                 # `from foo import bar as baz,`
1697                 or (
1698                     len(self.leaves) > 5
1699                     and self.leaves[-5].value == "import"
1700                     and self.leaves[-3].value == "as"
1701                 )
1702             )
1703             and closing.type == token.RPAR
1704         ):
1705             return False
1706
1707         self.remove_trailing_comma()
1708         return True
1709
1710     def append_comment(self, comment: Leaf) -> bool:
1711         """Add an inline or standalone comment to the line."""
1712         if (
1713             comment.type == STANDALONE_COMMENT
1714             and self.bracket_tracker.any_open_brackets()
1715         ):
1716             comment.prefix = ""
1717             return False
1718
1719         if comment.type != token.COMMENT:
1720             return False
1721
1722         if not self.leaves:
1723             comment.type = STANDALONE_COMMENT
1724             comment.prefix = ""
1725             return False
1726
1727         last_leaf = self.leaves[-1]
1728         if (
1729             last_leaf.type == token.RPAR
1730             and not last_leaf.value
1731             and last_leaf.parent
1732             and len(list(last_leaf.parent.leaves())) <= 3
1733             and not is_type_comment(comment)
1734         ):
1735             # Comments on an optional parens wrapping a single leaf should belong to
1736             # the wrapped node except if it's a type comment. Pinning the comment like
1737             # this avoids unstable formatting caused by comment migration.
1738             if len(self.leaves) < 2:
1739                 comment.type = STANDALONE_COMMENT
1740                 comment.prefix = ""
1741                 return False
1742
1743             last_leaf = self.leaves[-2]
1744         self.comments.setdefault(id(last_leaf), []).append(comment)
1745         return True
1746
1747     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1748         """Generate comments that should appear directly after `leaf`."""
1749         return self.comments.get(id(leaf), [])
1750
1751     def remove_trailing_comma(self) -> None:
1752         """Remove the trailing comma and moves the comments attached to it."""
1753         trailing_comma = self.leaves.pop()
1754         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1755         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1756             trailing_comma_comments
1757         )
1758
1759     def is_complex_subscript(self, leaf: Leaf) -> bool:
1760         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1761         open_lsqb = self.bracket_tracker.get_open_lsqb()
1762         if open_lsqb is None:
1763             return False
1764
1765         subscript_start = open_lsqb.next_sibling
1766
1767         if isinstance(subscript_start, Node):
1768             if subscript_start.type == syms.listmaker:
1769                 return False
1770
1771             if subscript_start.type == syms.subscriptlist:
1772                 subscript_start = child_towards(subscript_start, leaf)
1773         return subscript_start is not None and any(
1774             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1775         )
1776
1777     def clone(self) -> "Line":
1778         return Line(
1779             depth=self.depth,
1780             inside_brackets=self.inside_brackets,
1781             should_explode=self.should_explode,
1782         )
1783
1784     def __str__(self) -> str:
1785         """Render the line."""
1786         if not self:
1787             return "\n"
1788
1789         indent = "    " * self.depth
1790         leaves = iter(self.leaves)
1791         first = next(leaves)
1792         res = f"{first.prefix}{indent}{first.value}"
1793         for leaf in leaves:
1794             res += str(leaf)
1795         for comment in itertools.chain.from_iterable(self.comments.values()):
1796             res += str(comment)
1797
1798         return res + "\n"
1799
1800     def __bool__(self) -> bool:
1801         """Return True if the line has leaves or comments."""
1802         return bool(self.leaves or self.comments)
1803
1804
1805 @dataclass
1806 class EmptyLineTracker:
1807     """Provides a stateful method that returns the number of potential extra
1808     empty lines needed before and after the currently processed line.
1809
1810     Note: this tracker works on lines that haven't been split yet.  It assumes
1811     the prefix of the first leaf consists of optional newlines.  Those newlines
1812     are consumed by `maybe_empty_lines()` and included in the computation.
1813     """
1814
1815     is_pyi: bool = False
1816     previous_line: Optional[Line] = None
1817     previous_after: int = 0
1818     previous_defs: List[int] = field(default_factory=list)
1819
1820     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1821         """Return the number of extra empty lines before and after the `current_line`.
1822
1823         This is for separating `def`, `async def` and `class` with extra empty
1824         lines (two on module-level).
1825         """
1826         before, after = self._maybe_empty_lines(current_line)
1827         before = (
1828             # Black should not insert empty lines at the beginning
1829             # of the file
1830             0
1831             if self.previous_line is None
1832             else before - self.previous_after
1833         )
1834         self.previous_after = after
1835         self.previous_line = current_line
1836         return before, after
1837
1838     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1839         max_allowed = 1
1840         if current_line.depth == 0:
1841             max_allowed = 1 if self.is_pyi else 2
1842         if current_line.leaves:
1843             # Consume the first leaf's extra newlines.
1844             first_leaf = current_line.leaves[0]
1845             before = first_leaf.prefix.count("\n")
1846             before = min(before, max_allowed)
1847             first_leaf.prefix = ""
1848         else:
1849             before = 0
1850         depth = current_line.depth
1851         while self.previous_defs and self.previous_defs[-1] >= depth:
1852             self.previous_defs.pop()
1853             if self.is_pyi:
1854                 before = 0 if depth else 1
1855             else:
1856                 before = 1 if depth else 2
1857         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1858             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1859
1860         if (
1861             self.previous_line
1862             and self.previous_line.is_import
1863             and not current_line.is_import
1864             and depth == self.previous_line.depth
1865         ):
1866             return (before or 1), 0
1867
1868         if (
1869             self.previous_line
1870             and self.previous_line.is_class
1871             and current_line.is_triple_quoted_string
1872         ):
1873             return before, 1
1874
1875         return before, 0
1876
1877     def _maybe_empty_lines_for_class_or_def(
1878         self, current_line: Line, before: int
1879     ) -> Tuple[int, int]:
1880         if not current_line.is_decorator:
1881             self.previous_defs.append(current_line.depth)
1882         if self.previous_line is None:
1883             # Don't insert empty lines before the first line in the file.
1884             return 0, 0
1885
1886         if self.previous_line.is_decorator:
1887             return 0, 0
1888
1889         if self.previous_line.depth < current_line.depth and (
1890             self.previous_line.is_class or self.previous_line.is_def
1891         ):
1892             return 0, 0
1893
1894         if (
1895             self.previous_line.is_comment
1896             and self.previous_line.depth == current_line.depth
1897             and before == 0
1898         ):
1899             return 0, 0
1900
1901         if self.is_pyi:
1902             if self.previous_line.depth > current_line.depth:
1903                 newlines = 1
1904             elif current_line.is_class or self.previous_line.is_class:
1905                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1906                     # No blank line between classes with an empty body
1907                     newlines = 0
1908                 else:
1909                     newlines = 1
1910             elif current_line.is_def and not self.previous_line.is_def:
1911                 # Blank line between a block of functions and a block of non-functions
1912                 newlines = 1
1913             else:
1914                 newlines = 0
1915         else:
1916             newlines = 2
1917         if current_line.depth and newlines:
1918             newlines -= 1
1919         return newlines, 0
1920
1921
1922 @dataclass
1923 class LineGenerator(Visitor[Line]):
1924     """Generates reformatted Line objects.  Empty lines are not emitted.
1925
1926     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1927     in ways that will no longer stringify to valid Python code on the tree.
1928     """
1929
1930     is_pyi: bool = False
1931     normalize_strings: bool = True
1932     current_line: Line = field(default_factory=Line)
1933     remove_u_prefix: bool = False
1934
1935     def line(self, indent: int = 0) -> Iterator[Line]:
1936         """Generate a line.
1937
1938         If the line is empty, only emit if it makes sense.
1939         If the line is too long, split it first and then generate.
1940
1941         If any lines were generated, set up a new current_line.
1942         """
1943         if not self.current_line:
1944             self.current_line.depth += indent
1945             return  # Line is empty, don't emit. Creating a new one unnecessary.
1946
1947         complete_line = self.current_line
1948         self.current_line = Line(depth=complete_line.depth + indent)
1949         yield complete_line
1950
1951     def visit_default(self, node: LN) -> Iterator[Line]:
1952         """Default `visit_*()` implementation. Recurses to children of `node`."""
1953         if isinstance(node, Leaf):
1954             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1955             for comment in generate_comments(node):
1956                 if any_open_brackets:
1957                     # any comment within brackets is subject to splitting
1958                     self.current_line.append(comment)
1959                 elif comment.type == token.COMMENT:
1960                     # regular trailing comment
1961                     self.current_line.append(comment)
1962                     yield from self.line()
1963
1964                 else:
1965                     # regular standalone comment
1966                     yield from self.line()
1967
1968                     self.current_line.append(comment)
1969                     yield from self.line()
1970
1971             normalize_prefix(node, inside_brackets=any_open_brackets)
1972             if self.normalize_strings and node.type == token.STRING:
1973                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1974                 normalize_string_quotes(node)
1975             if node.type == token.NUMBER:
1976                 normalize_numeric_literal(node)
1977             if node.type not in WHITESPACE:
1978                 self.current_line.append(node)
1979         yield from super().visit_default(node)
1980
1981     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1982         """Increase indentation level, maybe yield a line."""
1983         # In blib2to3 INDENT never holds comments.
1984         yield from self.line(+1)
1985         yield from self.visit_default(node)
1986
1987     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1988         """Decrease indentation level, maybe yield a line."""
1989         # The current line might still wait for trailing comments.  At DEDENT time
1990         # there won't be any (they would be prefixes on the preceding NEWLINE).
1991         # Emit the line then.
1992         yield from self.line()
1993
1994         # While DEDENT has no value, its prefix may contain standalone comments
1995         # that belong to the current indentation level.  Get 'em.
1996         yield from self.visit_default(node)
1997
1998         # Finally, emit the dedent.
1999         yield from self.line(-1)
2000
2001     def visit_stmt(
2002         self, node: Node, keywords: Set[str], parens: Set[str]
2003     ) -> Iterator[Line]:
2004         """Visit a statement.
2005
2006         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2007         `def`, `with`, `class`, `assert` and assignments.
2008
2009         The relevant Python language `keywords` for a given statement will be
2010         NAME leaves within it. This methods puts those on a separate line.
2011
2012         `parens` holds a set of string leaf values immediately after which
2013         invisible parens should be put.
2014         """
2015         normalize_invisible_parens(node, parens_after=parens)
2016         for child in node.children:
2017             if child.type == token.NAME and child.value in keywords:  # type: ignore
2018                 yield from self.line()
2019
2020             yield from self.visit(child)
2021
2022     def visit_suite(self, node: Node) -> Iterator[Line]:
2023         """Visit a suite."""
2024         if self.is_pyi and is_stub_suite(node):
2025             yield from self.visit(node.children[2])
2026         else:
2027             yield from self.visit_default(node)
2028
2029     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2030         """Visit a statement without nested statements."""
2031         is_suite_like = node.parent and node.parent.type in STATEMENT
2032         if is_suite_like:
2033             if self.is_pyi and is_stub_body(node):
2034                 yield from self.visit_default(node)
2035             else:
2036                 yield from self.line(+1)
2037                 yield from self.visit_default(node)
2038                 yield from self.line(-1)
2039
2040         else:
2041             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2042                 yield from self.line()
2043             yield from self.visit_default(node)
2044
2045     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2046         """Visit `async def`, `async for`, `async with`."""
2047         yield from self.line()
2048
2049         children = iter(node.children)
2050         for child in children:
2051             yield from self.visit(child)
2052
2053             if child.type == token.ASYNC:
2054                 break
2055
2056         internal_stmt = next(children)
2057         for child in internal_stmt.children:
2058             yield from self.visit(child)
2059
2060     def visit_decorators(self, node: Node) -> Iterator[Line]:
2061         """Visit decorators."""
2062         for child in node.children:
2063             yield from self.line()
2064             yield from self.visit(child)
2065
2066     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2067         """Remove a semicolon and put the other statement on a separate line."""
2068         yield from self.line()
2069
2070     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2071         """End of file. Process outstanding comments and end with a newline."""
2072         yield from self.visit_default(leaf)
2073         yield from self.line()
2074
2075     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2076         if not self.current_line.bracket_tracker.any_open_brackets():
2077             yield from self.line()
2078         yield from self.visit_default(leaf)
2079
2080     def visit_factor(self, node: Node) -> Iterator[Line]:
2081         """Force parentheses between a unary op and a binary power:
2082
2083         -2 ** 8 -> -(2 ** 8)
2084         """
2085         _operator, operand = node.children
2086         if (
2087             operand.type == syms.power
2088             and len(operand.children) == 3
2089             and operand.children[1].type == token.DOUBLESTAR
2090         ):
2091             lpar = Leaf(token.LPAR, "(")
2092             rpar = Leaf(token.RPAR, ")")
2093             index = operand.remove() or 0
2094             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2095         yield from self.visit_default(node)
2096
2097     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2098         # Check if it's a docstring
2099         if prev_siblings_are(
2100             leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
2101         ) and is_multiline_string(leaf):
2102             prefix = "    " * self.current_line.depth
2103             docstring = fix_docstring(leaf.value[3:-3], prefix)
2104             leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
2105             normalize_string_quotes(leaf)
2106
2107         yield from self.visit_default(leaf)
2108
2109     def __post_init__(self) -> None:
2110         """You are in a twisty little maze of passages."""
2111         v = self.visit_stmt
2112         Ø: Set[str] = set()
2113         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2114         self.visit_if_stmt = partial(
2115             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2116         )
2117         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2118         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2119         self.visit_try_stmt = partial(
2120             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2121         )
2122         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2123         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2124         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2125         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2126         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2127         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2128         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2129         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2130         self.visit_async_funcdef = self.visit_async_stmt
2131         self.visit_decorated = self.visit_decorators
2132
2133
2134 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2135 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2136 OPENING_BRACKETS = set(BRACKET.keys())
2137 CLOSING_BRACKETS = set(BRACKET.values())
2138 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2139 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2140
2141
2142 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2143     """Return whitespace prefix if needed for the given `leaf`.
2144
2145     `complex_subscript` signals whether the given leaf is part of a subscription
2146     which has non-trivial arguments, like arithmetic expressions or function calls.
2147     """
2148     NO = ""
2149     SPACE = " "
2150     DOUBLESPACE = "  "
2151     t = leaf.type
2152     p = leaf.parent
2153     v = leaf.value
2154     if t in ALWAYS_NO_SPACE:
2155         return NO
2156
2157     if t == token.COMMENT:
2158         return DOUBLESPACE
2159
2160     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2161     if t == token.COLON and p.type not in {
2162         syms.subscript,
2163         syms.subscriptlist,
2164         syms.sliceop,
2165     }:
2166         return NO
2167
2168     prev = leaf.prev_sibling
2169     if not prev:
2170         prevp = preceding_leaf(p)
2171         if not prevp or prevp.type in OPENING_BRACKETS:
2172             return NO
2173
2174         if t == token.COLON:
2175             if prevp.type == token.COLON:
2176                 return NO
2177
2178             elif prevp.type != token.COMMA and not complex_subscript:
2179                 return NO
2180
2181             return SPACE
2182
2183         if prevp.type == token.EQUAL:
2184             if prevp.parent:
2185                 if prevp.parent.type in {
2186                     syms.arglist,
2187                     syms.argument,
2188                     syms.parameters,
2189                     syms.varargslist,
2190                 }:
2191                     return NO
2192
2193                 elif prevp.parent.type == syms.typedargslist:
2194                     # A bit hacky: if the equal sign has whitespace, it means we
2195                     # previously found it's a typed argument.  So, we're using
2196                     # that, too.
2197                     return prevp.prefix
2198
2199         elif prevp.type in VARARGS_SPECIALS:
2200             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2201                 return NO
2202
2203         elif prevp.type == token.COLON:
2204             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2205                 return SPACE if complex_subscript else NO
2206
2207         elif (
2208             prevp.parent
2209             and prevp.parent.type == syms.factor
2210             and prevp.type in MATH_OPERATORS
2211         ):
2212             return NO
2213
2214         elif (
2215             prevp.type == token.RIGHTSHIFT
2216             and prevp.parent
2217             and prevp.parent.type == syms.shift_expr
2218             and prevp.prev_sibling
2219             and prevp.prev_sibling.type == token.NAME
2220             and prevp.prev_sibling.value == "print"  # type: ignore
2221         ):
2222             # Python 2 print chevron
2223             return NO
2224
2225     elif prev.type in OPENING_BRACKETS:
2226         return NO
2227
2228     if p.type in {syms.parameters, syms.arglist}:
2229         # untyped function signatures or calls
2230         if not prev or prev.type != token.COMMA:
2231             return NO
2232
2233     elif p.type == syms.varargslist:
2234         # lambdas
2235         if prev and prev.type != token.COMMA:
2236             return NO
2237
2238     elif p.type == syms.typedargslist:
2239         # typed function signatures
2240         if not prev:
2241             return NO
2242
2243         if t == token.EQUAL:
2244             if prev.type != syms.tname:
2245                 return NO
2246
2247         elif prev.type == token.EQUAL:
2248             # A bit hacky: if the equal sign has whitespace, it means we
2249             # previously found it's a typed argument.  So, we're using that, too.
2250             return prev.prefix
2251
2252         elif prev.type != token.COMMA:
2253             return NO
2254
2255     elif p.type == syms.tname:
2256         # type names
2257         if not prev:
2258             prevp = preceding_leaf(p)
2259             if not prevp or prevp.type != token.COMMA:
2260                 return NO
2261
2262     elif p.type == syms.trailer:
2263         # attributes and calls
2264         if t == token.LPAR or t == token.RPAR:
2265             return NO
2266
2267         if not prev:
2268             if t == token.DOT:
2269                 prevp = preceding_leaf(p)
2270                 if not prevp or prevp.type != token.NUMBER:
2271                     return NO
2272
2273             elif t == token.LSQB:
2274                 return NO
2275
2276         elif prev.type != token.COMMA:
2277             return NO
2278
2279     elif p.type == syms.argument:
2280         # single argument
2281         if t == token.EQUAL:
2282             return NO
2283
2284         if not prev:
2285             prevp = preceding_leaf(p)
2286             if not prevp or prevp.type == token.LPAR:
2287                 return NO
2288
2289         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2290             return NO
2291
2292     elif p.type == syms.decorator:
2293         # decorators
2294         return NO
2295
2296     elif p.type == syms.dotted_name:
2297         if prev:
2298             return NO
2299
2300         prevp = preceding_leaf(p)
2301         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2302             return NO
2303
2304     elif p.type == syms.classdef:
2305         if t == token.LPAR:
2306             return NO
2307
2308         if prev and prev.type == token.LPAR:
2309             return NO
2310
2311     elif p.type in {syms.subscript, syms.sliceop}:
2312         # indexing
2313         if not prev:
2314             assert p.parent is not None, "subscripts are always parented"
2315             if p.parent.type == syms.subscriptlist:
2316                 return SPACE
2317
2318             return NO
2319
2320         elif not complex_subscript:
2321             return NO
2322
2323     elif p.type == syms.atom:
2324         if prev and t == token.DOT:
2325             # dots, but not the first one.
2326             return NO
2327
2328     elif p.type == syms.dictsetmaker:
2329         # dict unpacking
2330         if prev and prev.type == token.DOUBLESTAR:
2331             return NO
2332
2333     elif p.type in {syms.factor, syms.star_expr}:
2334         # unary ops
2335         if not prev:
2336             prevp = preceding_leaf(p)
2337             if not prevp or prevp.type in OPENING_BRACKETS:
2338                 return NO
2339
2340             prevp_parent = prevp.parent
2341             assert prevp_parent is not None
2342             if prevp.type == token.COLON and prevp_parent.type in {
2343                 syms.subscript,
2344                 syms.sliceop,
2345             }:
2346                 return NO
2347
2348             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2349                 return NO
2350
2351         elif t in {token.NAME, token.NUMBER, token.STRING}:
2352             return NO
2353
2354     elif p.type == syms.import_from:
2355         if t == token.DOT:
2356             if prev and prev.type == token.DOT:
2357                 return NO
2358
2359         elif t == token.NAME:
2360             if v == "import":
2361                 return SPACE
2362
2363             if prev and prev.type == token.DOT:
2364                 return NO
2365
2366     elif p.type == syms.sliceop:
2367         return NO
2368
2369     return SPACE
2370
2371
2372 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2373     """Return the first leaf that precedes `node`, if any."""
2374     while node:
2375         res = node.prev_sibling
2376         if res:
2377             if isinstance(res, Leaf):
2378                 return res
2379
2380             try:
2381                 return list(res.leaves())[-1]
2382
2383             except IndexError:
2384                 return None
2385
2386         node = node.parent
2387     return None
2388
2389
2390 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2391     """Return if the `node` and its previous siblings match types against the provided
2392     list of tokens; the provided `node`has its type matched against the last element in
2393     the list.  `None` can be used as the first element to declare that the start of the
2394     list is anchored at the start of its parent's children."""
2395     if not tokens:
2396         return True
2397     if tokens[-1] is None:
2398         return node is None
2399     if not node:
2400         return False
2401     if node.type != tokens[-1]:
2402         return False
2403     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2404
2405
2406 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2407     """Return the child of `ancestor` that contains `descendant`."""
2408     node: Optional[LN] = descendant
2409     while node and node.parent != ancestor:
2410         node = node.parent
2411     return node
2412
2413
2414 def container_of(leaf: Leaf) -> LN:
2415     """Return `leaf` or one of its ancestors that is the topmost container of it.
2416
2417     By "container" we mean a node where `leaf` is the very first child.
2418     """
2419     same_prefix = leaf.prefix
2420     container: LN = leaf
2421     while container:
2422         parent = container.parent
2423         if parent is None:
2424             break
2425
2426         if parent.children[0].prefix != same_prefix:
2427             break
2428
2429         if parent.type == syms.file_input:
2430             break
2431
2432         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2433             break
2434
2435         container = parent
2436     return container
2437
2438
2439 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2440     """Return the priority of the `leaf` delimiter, given a line break after it.
2441
2442     The delimiter priorities returned here are from those delimiters that would
2443     cause a line break after themselves.
2444
2445     Higher numbers are higher priority.
2446     """
2447     if leaf.type == token.COMMA:
2448         return COMMA_PRIORITY
2449
2450     return 0
2451
2452
2453 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2454     """Return the priority of the `leaf` delimiter, given a line break before it.
2455
2456     The delimiter priorities returned here are from those delimiters that would
2457     cause a line break before themselves.
2458
2459     Higher numbers are higher priority.
2460     """
2461     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2462         # * and ** might also be MATH_OPERATORS but in this case they are not.
2463         # Don't treat them as a delimiter.
2464         return 0
2465
2466     if (
2467         leaf.type == token.DOT
2468         and leaf.parent
2469         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2470         and (previous is None or previous.type in CLOSING_BRACKETS)
2471     ):
2472         return DOT_PRIORITY
2473
2474     if (
2475         leaf.type in MATH_OPERATORS
2476         and leaf.parent
2477         and leaf.parent.type not in {syms.factor, syms.star_expr}
2478     ):
2479         return MATH_PRIORITIES[leaf.type]
2480
2481     if leaf.type in COMPARATORS:
2482         return COMPARATOR_PRIORITY
2483
2484     if (
2485         leaf.type == token.STRING
2486         and previous is not None
2487         and previous.type == token.STRING
2488     ):
2489         return STRING_PRIORITY
2490
2491     if leaf.type not in {token.NAME, token.ASYNC}:
2492         return 0
2493
2494     if (
2495         leaf.value == "for"
2496         and leaf.parent
2497         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2498         or leaf.type == token.ASYNC
2499     ):
2500         if (
2501             not isinstance(leaf.prev_sibling, Leaf)
2502             or leaf.prev_sibling.value != "async"
2503         ):
2504             return COMPREHENSION_PRIORITY
2505
2506     if (
2507         leaf.value == "if"
2508         and leaf.parent
2509         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2510     ):
2511         return COMPREHENSION_PRIORITY
2512
2513     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2514         return TERNARY_PRIORITY
2515
2516     if leaf.value == "is":
2517         return COMPARATOR_PRIORITY
2518
2519     if (
2520         leaf.value == "in"
2521         and leaf.parent
2522         and leaf.parent.type in {syms.comp_op, syms.comparison}
2523         and not (
2524             previous is not None
2525             and previous.type == token.NAME
2526             and previous.value == "not"
2527         )
2528     ):
2529         return COMPARATOR_PRIORITY
2530
2531     if (
2532         leaf.value == "not"
2533         and leaf.parent
2534         and leaf.parent.type == syms.comp_op
2535         and not (
2536             previous is not None
2537             and previous.type == token.NAME
2538             and previous.value == "is"
2539         )
2540     ):
2541         return COMPARATOR_PRIORITY
2542
2543     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2544         return LOGIC_PRIORITY
2545
2546     return 0
2547
2548
2549 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2550 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2551
2552
2553 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2554     """Clean the prefix of the `leaf` and generate comments from it, if any.
2555
2556     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2557     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2558     move because it does away with modifying the grammar to include all the
2559     possible places in which comments can be placed.
2560
2561     The sad consequence for us though is that comments don't "belong" anywhere.
2562     This is why this function generates simple parentless Leaf objects for
2563     comments.  We simply don't know what the correct parent should be.
2564
2565     No matter though, we can live without this.  We really only need to
2566     differentiate between inline and standalone comments.  The latter don't
2567     share the line with any code.
2568
2569     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2570     are emitted with a fake STANDALONE_COMMENT token identifier.
2571     """
2572     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2573         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2574
2575
2576 @dataclass
2577 class ProtoComment:
2578     """Describes a piece of syntax that is a comment.
2579
2580     It's not a :class:`blib2to3.pytree.Leaf` so that:
2581
2582     * it can be cached (`Leaf` objects should not be reused more than once as
2583       they store their lineno, column, prefix, and parent information);
2584     * `newlines` and `consumed` fields are kept separate from the `value`. This
2585       simplifies handling of special marker comments like ``# fmt: off/on``.
2586     """
2587
2588     type: int  # token.COMMENT or STANDALONE_COMMENT
2589     value: str  # content of the comment
2590     newlines: int  # how many newlines before the comment
2591     consumed: int  # how many characters of the original leaf's prefix did we consume
2592
2593
2594 @lru_cache(maxsize=4096)
2595 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2596     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2597     result: List[ProtoComment] = []
2598     if not prefix or "#" not in prefix:
2599         return result
2600
2601     consumed = 0
2602     nlines = 0
2603     ignored_lines = 0
2604     for index, line in enumerate(prefix.split("\n")):
2605         consumed += len(line) + 1  # adding the length of the split '\n'
2606         line = line.lstrip()
2607         if not line:
2608             nlines += 1
2609         if not line.startswith("#"):
2610             # Escaped newlines outside of a comment are not really newlines at
2611             # all. We treat a single-line comment following an escaped newline
2612             # as a simple trailing comment.
2613             if line.endswith("\\"):
2614                 ignored_lines += 1
2615             continue
2616
2617         if index == ignored_lines and not is_endmarker:
2618             comment_type = token.COMMENT  # simple trailing comment
2619         else:
2620             comment_type = STANDALONE_COMMENT
2621         comment = make_comment(line)
2622         result.append(
2623             ProtoComment(
2624                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2625             )
2626         )
2627         nlines = 0
2628     return result
2629
2630
2631 def make_comment(content: str) -> str:
2632     """Return a consistently formatted comment from the given `content` string.
2633
2634     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2635     space between the hash sign and the content.
2636
2637     If `content` didn't start with a hash sign, one is provided.
2638     """
2639     content = content.rstrip()
2640     if not content:
2641         return "#"
2642
2643     if content[0] == "#":
2644         content = content[1:]
2645     if content and content[0] not in " !:#'%":
2646         content = " " + content
2647     return "#" + content
2648
2649
2650 def transform_line(
2651     line: Line,
2652     line_length: int,
2653     normalize_strings: bool,
2654     features: Collection[Feature] = (),
2655 ) -> Iterator[Line]:
2656     """Transform a `line`, potentially splitting it into many lines.
2657
2658     They should fit in the allotted `line_length` but might not be able to.
2659
2660     `features` are syntactical features that may be used in the output.
2661     """
2662     if line.is_comment:
2663         yield line
2664         return
2665
2666     line_str = line_to_string(line)
2667
2668     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2669         """Initialize StringTransformer"""
2670         return ST(line_length, normalize_strings)
2671
2672     string_merge = init_st(StringMerger)
2673     string_paren_strip = init_st(StringParenStripper)
2674     string_split = init_st(StringSplitter)
2675     string_paren_wrap = init_st(StringParenWrapper)
2676
2677     transformers: List[Transformer]
2678     if (
2679         not line.contains_uncollapsable_type_comments()
2680         and not line.should_explode
2681         and not line.is_collection_with_optional_trailing_comma
2682         and (
2683             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2684             or line.contains_unsplittable_type_ignore()
2685         )
2686         and not (line.contains_standalone_comments() and line.inside_brackets)
2687     ):
2688         # Only apply basic string preprocessing, since lines shouldn't be split here.
2689         transformers = [string_merge, string_paren_strip]
2690     elif line.is_def:
2691         transformers = [left_hand_split]
2692     else:
2693
2694         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2695             for omit in generate_trailers_to_omit(line, line_length):
2696                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2697                 if is_line_short_enough(lines[0], line_length=line_length):
2698                     yield from lines
2699                     return
2700
2701             # All splits failed, best effort split with no omits.
2702             # This mostly happens to multiline strings that are by definition
2703             # reported as not fitting a single line.
2704             # line_length=1 here was historically a bug that somehow became a feature.
2705             # See #762 and #781 for the full story.
2706             yield from right_hand_split(line, line_length=1, features=features)
2707
2708         if line.inside_brackets:
2709             transformers = [
2710                 string_merge,
2711                 string_paren_strip,
2712                 delimiter_split,
2713                 standalone_comment_split,
2714                 string_split,
2715                 string_paren_wrap,
2716                 rhs,
2717             ]
2718         else:
2719             transformers = [
2720                 string_merge,
2721                 string_paren_strip,
2722                 string_split,
2723                 string_paren_wrap,
2724                 rhs,
2725             ]
2726
2727     for transform in transformers:
2728         # We are accumulating lines in `result` because we might want to abort
2729         # mission and return the original line in the end, or attempt a different
2730         # split altogether.
2731         result: List[Line] = []
2732         try:
2733             for l in transform(line, features):
2734                 if str(l).strip("\n") == line_str:
2735                     raise CannotTransform(
2736                         "Line transformer returned an unchanged result"
2737                     )
2738
2739                 result.extend(
2740                     transform_line(
2741                         l,
2742                         line_length=line_length,
2743                         normalize_strings=normalize_strings,
2744                         features=features,
2745                     )
2746                 )
2747         except CannotTransform:
2748             continue
2749         else:
2750             yield from result
2751             break
2752
2753     else:
2754         yield line
2755
2756
2757 @dataclass  # type: ignore
2758 class StringTransformer(ABC):
2759     """
2760     An implementation of the Transformer protocol that relies on its
2761     subclasses overriding the template methods `do_match(...)` and
2762     `do_transform(...)`.
2763
2764     This Transformer works exclusively on strings (for example, by merging
2765     or splitting them).
2766
2767     The following sections can be found among the docstrings of each concrete
2768     StringTransformer subclass.
2769
2770     Requirements:
2771         Which requirements must be met of the given Line for this
2772         StringTransformer to be applied?
2773
2774     Transformations:
2775         If the given Line meets all of the above requirments, which string
2776         transformations can you expect to be applied to it by this
2777         StringTransformer?
2778
2779     Collaborations:
2780         What contractual agreements does this StringTransformer have with other
2781         StringTransfomers? Such collaborations should be eliminated/minimized
2782         as much as possible.
2783     """
2784
2785     line_length: int
2786     normalize_strings: bool
2787
2788     @abstractmethod
2789     def do_match(self, line: Line) -> TMatchResult:
2790         """
2791         Returns:
2792             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2793             string, if a match was able to be made.
2794                 OR
2795             * Err(CannotTransform), if a match was not able to be made.
2796         """
2797
2798     @abstractmethod
2799     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2800         """
2801         Yields:
2802             * Ok(new_line) where new_line is the new transformed line.
2803                 OR
2804             * Err(CannotTransform) if the transformation failed for some reason. The
2805             `do_match(...)` template method should usually be used to reject
2806             the form of the given Line, but in some cases it is difficult to
2807             know whether or not a Line meets the StringTransformer's
2808             requirements until the transformation is already midway.
2809
2810         Side Effects:
2811             This method should NOT mutate @line directly, but it MAY mutate the
2812             Line's underlying Node structure. (WARNING: If the underlying Node
2813             structure IS altered, then this method should NOT be allowed to
2814             yield an CannotTransform after that point.)
2815         """
2816
2817     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2818         """
2819         StringTransformer instances have a call signature that mirrors that of
2820         the Transformer type.
2821
2822         Raises:
2823             CannotTransform(...) if the concrete StringTransformer class is unable
2824             to transform @line.
2825         """
2826         # Optimization to avoid calling `self.do_match(...)` when the line does
2827         # not contain any string.
2828         if not any(leaf.type == token.STRING for leaf in line.leaves):
2829             raise CannotTransform("There are no strings in this line.")
2830
2831         match_result = self.do_match(line)
2832
2833         if isinstance(match_result, Err):
2834             cant_transform = match_result.err()
2835             raise CannotTransform(
2836                 f"The string transformer {self.__class__.__name__} does not recognize"
2837                 " this line as one that it can transform."
2838             ) from cant_transform
2839
2840         string_idx = match_result.ok()
2841
2842         for line_result in self.do_transform(line, string_idx):
2843             if isinstance(line_result, Err):
2844                 cant_transform = line_result.err()
2845                 raise CannotTransform(
2846                     "StringTransformer failed while attempting to transform string."
2847                 ) from cant_transform
2848             line = line_result.ok()
2849             yield line
2850
2851
2852 @dataclass
2853 class CustomSplit:
2854     """A custom (i.e. manual) string split.
2855
2856     A single CustomSplit instance represents a single substring.
2857
2858     Examples:
2859         Consider the following string:
2860         ```
2861         "Hi there friend."
2862         " This is a custom"
2863         f" string {split}."
2864         ```
2865
2866         This string will correspond to the following three CustomSplit instances:
2867         ```
2868         CustomSplit(False, 16)
2869         CustomSplit(False, 17)
2870         CustomSplit(True, 16)
2871         ```
2872     """
2873
2874     has_prefix: bool
2875     break_idx: int
2876
2877
2878 class CustomSplitMapMixin:
2879     """
2880     This mixin class is used to map merged strings to a sequence of
2881     CustomSplits, which will then be used to re-split the strings iff none of
2882     the resultant substrings go over the configured max line length.
2883     """
2884
2885     _Key = Tuple[StringID, str]
2886     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2887
2888     @staticmethod
2889     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2890         """
2891         Returns:
2892             A unique identifier that is used internally to map @string to a
2893             group of custom splits.
2894         """
2895         return (id(string), string)
2896
2897     def add_custom_splits(
2898         self, string: str, custom_splits: Iterable[CustomSplit]
2899     ) -> None:
2900         """Custom Split Map Setter Method
2901
2902         Side Effects:
2903             Adds a mapping from @string to the custom splits @custom_splits.
2904         """
2905         key = self._get_key(string)
2906         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2907
2908     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2909         """Custom Split Map Getter Method
2910
2911         Returns:
2912             * A list of the custom splits that are mapped to @string, if any
2913             exist.
2914                 OR
2915             * [], otherwise.
2916
2917         Side Effects:
2918             Deletes the mapping between @string and its associated custom
2919             splits (which are returned to the caller).
2920         """
2921         key = self._get_key(string)
2922
2923         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2924         del self._CUSTOM_SPLIT_MAP[key]
2925
2926         return list(custom_splits)
2927
2928     def has_custom_splits(self, string: str) -> bool:
2929         """
2930         Returns:
2931             True iff @string is associated with a set of custom splits.
2932         """
2933         key = self._get_key(string)
2934         return key in self._CUSTOM_SPLIT_MAP
2935
2936
2937 class StringMerger(CustomSplitMapMixin, StringTransformer):
2938     """StringTransformer that merges strings together.
2939
2940     Requirements:
2941         (A) The line contains adjacent strings such that at most one substring
2942         has inline comments AND none of those inline comments are pragmas AND
2943         the set of all substring prefixes is either of length 1 or equal to
2944         {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
2945         with 'r').
2946             OR
2947         (B) The line contains a string which uses line continuation backslashes.
2948
2949     Transformations:
2950         Depending on which of the two requirements above where met, either:
2951
2952         (A) The string group associated with the target string is merged.
2953             OR
2954         (B) All line-continuation backslashes are removed from the target string.
2955
2956     Collaborations:
2957         StringMerger provides custom split information to StringSplitter.
2958     """
2959
2960     def do_match(self, line: Line) -> TMatchResult:
2961         LL = line.leaves
2962
2963         is_valid_index = is_valid_index_factory(LL)
2964
2965         for (i, leaf) in enumerate(LL):
2966             if (
2967                 leaf.type == token.STRING
2968                 and is_valid_index(i + 1)
2969                 and LL[i + 1].type == token.STRING
2970             ):
2971                 return Ok(i)
2972
2973             if leaf.type == token.STRING and "\\\n" in leaf.value:
2974                 return Ok(i)
2975
2976         return TErr("This line has no strings that need merging.")
2977
2978     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2979         new_line = line
2980         rblc_result = self.__remove_backslash_line_continuation_chars(
2981             new_line, string_idx
2982         )
2983         if isinstance(rblc_result, Ok):
2984             new_line = rblc_result.ok()
2985
2986         msg_result = self.__merge_string_group(new_line, string_idx)
2987         if isinstance(msg_result, Ok):
2988             new_line = msg_result.ok()
2989
2990         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
2991             msg_cant_transform = msg_result.err()
2992             rblc_cant_transform = rblc_result.err()
2993             cant_transform = CannotTransform(
2994                 "StringMerger failed to merge any strings in this line."
2995             )
2996
2997             # Chain the errors together using `__cause__`.
2998             msg_cant_transform.__cause__ = rblc_cant_transform
2999             cant_transform.__cause__ = msg_cant_transform
3000
3001             yield Err(cant_transform)
3002         else:
3003             yield Ok(new_line)
3004
3005     @staticmethod
3006     def __remove_backslash_line_continuation_chars(
3007         line: Line, string_idx: int
3008     ) -> TResult[Line]:
3009         """
3010         Merge strings that were split across multiple lines using
3011         line-continuation backslashes.
3012
3013         Returns:
3014             Ok(new_line), if @line contains backslash line-continuation
3015             characters.
3016                 OR
3017             Err(CannotTransform), otherwise.
3018         """
3019         LL = line.leaves
3020
3021         string_leaf = LL[string_idx]
3022         if not (
3023             string_leaf.type == token.STRING
3024             and "\\\n" in string_leaf.value
3025             and not has_triple_quotes(string_leaf.value)
3026         ):
3027             return TErr(
3028                 f"String leaf {string_leaf} does not contain any backslash line"
3029                 " continuation characters."
3030             )
3031
3032         new_line = line.clone()
3033         new_line.comments = line.comments
3034         append_leaves(new_line, line, LL)
3035
3036         new_string_leaf = new_line.leaves[string_idx]
3037         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3038
3039         return Ok(new_line)
3040
3041     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3042         """
3043         Merges string group (i.e. set of adjacent strings) where the first
3044         string in the group is `line.leaves[string_idx]`.
3045
3046         Returns:
3047             Ok(new_line), if ALL of the validation checks found in
3048             __validate_msg(...) pass.
3049                 OR
3050             Err(CannotTransform), otherwise.
3051         """
3052         LL = line.leaves
3053
3054         is_valid_index = is_valid_index_factory(LL)
3055
3056         vresult = self.__validate_msg(line, string_idx)
3057         if isinstance(vresult, Err):
3058             return vresult
3059
3060         # If the string group is wrapped inside an Atom node, we must make sure
3061         # to later replace that Atom with our new (merged) string leaf.
3062         atom_node = LL[string_idx].parent
3063
3064         # We will place BREAK_MARK in between every two substrings that we
3065         # merge. We will then later go through our final result and use the
3066         # various instances of BREAK_MARK we find to add the right values to
3067         # the custom split map.
3068         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3069
3070         QUOTE = LL[string_idx].value[-1]
3071
3072         def make_naked(string: str, string_prefix: str) -> str:
3073             """Strip @string (i.e. make it a "naked" string)
3074
3075             Pre-conditions:
3076                 * assert_is_leaf_string(@string)
3077
3078             Returns:
3079                 A string that is identical to @string except that
3080                 @string_prefix has been stripped, the surrounding QUOTE
3081                 characters have been removed, and any remaining QUOTE
3082                 characters have been escaped.
3083             """
3084             assert_is_leaf_string(string)
3085
3086             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3087             naked_string = string[len(string_prefix) + 1 : -1]
3088             naked_string = re.sub(
3089                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3090             )
3091             return naked_string
3092
3093         # Holds the CustomSplit objects that will later be added to the custom
3094         # split map.
3095         custom_splits = []
3096
3097         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3098         prefix_tracker = []
3099
3100         # Sets the 'prefix' variable. This is the prefix that the final merged
3101         # string will have.
3102         next_str_idx = string_idx
3103         prefix = ""
3104         while (
3105             not prefix
3106             and is_valid_index(next_str_idx)
3107             and LL[next_str_idx].type == token.STRING
3108         ):
3109             prefix = get_string_prefix(LL[next_str_idx].value)
3110             next_str_idx += 1
3111
3112         # The next loop merges the string group. The final string will be
3113         # contained in 'S'.
3114         #
3115         # The following convenience variables are used:
3116         #
3117         #   S: string
3118         #   NS: naked string
3119         #   SS: next string
3120         #   NSS: naked next string
3121         S = ""
3122         NS = ""
3123         num_of_strings = 0
3124         next_str_idx = string_idx
3125         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3126             num_of_strings += 1
3127
3128             SS = LL[next_str_idx].value
3129             next_prefix = get_string_prefix(SS)
3130
3131             # If this is an f-string group but this substring is not prefixed
3132             # with 'f'...
3133             if "f" in prefix and "f" not in next_prefix:
3134                 # Then we must escape any braces contained in this substring.
3135                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3136
3137             NSS = make_naked(SS, next_prefix)
3138
3139             has_prefix = bool(next_prefix)
3140             prefix_tracker.append(has_prefix)
3141
3142             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3143             NS = make_naked(S, prefix)
3144
3145             next_str_idx += 1
3146
3147         S_leaf = Leaf(token.STRING, S)
3148         if self.normalize_strings:
3149             normalize_string_quotes(S_leaf)
3150
3151         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3152         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3153         for has_prefix in prefix_tracker:
3154             mark_idx = temp_string.find(BREAK_MARK)
3155             assert (
3156                 mark_idx >= 0
3157             ), "Logic error while filling the custom string breakpoint cache."
3158
3159             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3160             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3161             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3162
3163         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3164
3165         if atom_node is not None:
3166             replace_child(atom_node, string_leaf)
3167
3168         # Build the final line ('new_line') that this method will later return.
3169         new_line = line.clone()
3170         for (i, leaf) in enumerate(LL):
3171             if i == string_idx:
3172                 new_line.append(string_leaf)
3173
3174             if string_idx <= i < string_idx + num_of_strings:
3175                 for comment_leaf in line.comments_after(LL[i]):
3176                     new_line.append(comment_leaf, preformatted=True)
3177                 continue
3178
3179             append_leaves(new_line, line, [leaf])
3180
3181         self.add_custom_splits(string_leaf.value, custom_splits)
3182         return Ok(new_line)
3183
3184     @staticmethod
3185     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3186         """Validate (M)erge (S)tring (G)roup
3187
3188         Transform-time string validation logic for __merge_string_group(...).
3189
3190         Returns:
3191             * Ok(None), if ALL validation checks (listed below) pass.
3192                 OR
3193             * Err(CannotTransform), if any of the following are true:
3194                 - The target string is not in a string group (i.e. it has no
3195                   adjacent strings).
3196                 - The string group has more than one inline comment.
3197                 - The string group has an inline comment that appears to be a pragma.
3198                 - The set of all string prefixes in the string group is of
3199                   length greater than one and is not equal to {"", "f"}.
3200                 - The string group consists of raw strings.
3201         """
3202         num_of_inline_string_comments = 0
3203         set_of_prefixes = set()
3204         num_of_strings = 0
3205         for leaf in line.leaves[string_idx:]:
3206             if leaf.type != token.STRING:
3207                 # If the string group is trailed by a comma, we count the
3208                 # comments trailing the comma to be one of the string group's
3209                 # comments.
3210                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3211                     num_of_inline_string_comments += 1
3212                 break
3213
3214             if has_triple_quotes(leaf.value):
3215                 return TErr("StringMerger does NOT merge multiline strings.")
3216
3217             num_of_strings += 1
3218             prefix = get_string_prefix(leaf.value)
3219             if "r" in prefix:
3220                 return TErr("StringMerger does NOT merge raw strings.")
3221
3222             set_of_prefixes.add(prefix)
3223
3224             if id(leaf) in line.comments:
3225                 num_of_inline_string_comments += 1
3226                 if contains_pragma_comment(line.comments[id(leaf)]):
3227                     return TErr("Cannot merge strings which have pragma comments.")
3228
3229         if num_of_strings < 2:
3230             return TErr(
3231                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3232             )
3233
3234         if num_of_inline_string_comments > 1:
3235             return TErr(
3236                 f"Too many inline string comments ({num_of_inline_string_comments})."
3237             )
3238
3239         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3240             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3241
3242         return Ok(None)
3243
3244
3245 class StringParenStripper(StringTransformer):
3246     """StringTransformer that strips surrounding parentheses from strings.
3247
3248     Requirements:
3249         The line contains a string which is surrounded by parentheses and:
3250             - The target string is NOT the only argument to a function call).
3251             - The RPAR is NOT followed by an attribute access (i.e. a dot).
3252
3253     Transformations:
3254         The parentheses mentioned in the 'Requirements' section are stripped.
3255
3256     Collaborations:
3257         StringParenStripper has its own inherent usefulness, but it is also
3258         relied on to clean up the parentheses created by StringParenWrapper (in
3259         the event that they are no longer needed).
3260     """
3261
3262     def do_match(self, line: Line) -> TMatchResult:
3263         LL = line.leaves
3264
3265         is_valid_index = is_valid_index_factory(LL)
3266
3267         for (idx, leaf) in enumerate(LL):
3268             # Should be a string...
3269             if leaf.type != token.STRING:
3270                 continue
3271
3272             # Should be preceded by a non-empty LPAR...
3273             if (
3274                 not is_valid_index(idx - 1)
3275                 or LL[idx - 1].type != token.LPAR
3276                 or is_empty_lpar(LL[idx - 1])
3277             ):
3278                 continue
3279
3280             # That LPAR should NOT be preceded by a function name or a closing
3281             # bracket (which could be a function which returns a function or a
3282             # list/dictionary that contains a function)...
3283             if is_valid_index(idx - 2) and (
3284                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3285             ):
3286                 continue
3287
3288             string_idx = idx
3289
3290             # Skip the string trailer, if one exists.
3291             string_parser = StringParser()
3292             next_idx = string_parser.parse(LL, string_idx)
3293
3294             # Should be followed by a non-empty RPAR...
3295             if (
3296                 is_valid_index(next_idx)
3297                 and LL[next_idx].type == token.RPAR
3298                 and not is_empty_rpar(LL[next_idx])
3299             ):
3300                 # That RPAR should NOT be followed by a '.' symbol.
3301                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type == token.DOT:
3302                     continue
3303
3304                 return Ok(string_idx)
3305
3306         return TErr("This line has no strings wrapped in parens.")
3307
3308     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3309         LL = line.leaves
3310
3311         string_parser = StringParser()
3312         rpar_idx = string_parser.parse(LL, string_idx)
3313
3314         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3315             if line.comments_after(leaf):
3316                 yield TErr(
3317                     "Will not strip parentheses which have comments attached to them."
3318                 )
3319
3320         new_line = line.clone()
3321         new_line.comments = line.comments.copy()
3322
3323         append_leaves(new_line, line, LL[: string_idx - 1])
3324
3325         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3326         LL[string_idx - 1].remove()
3327         replace_child(LL[string_idx], string_leaf)
3328         new_line.append(string_leaf)
3329
3330         append_leaves(
3331             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :],
3332         )
3333
3334         LL[rpar_idx].remove()
3335
3336         yield Ok(new_line)
3337
3338
3339 class BaseStringSplitter(StringTransformer):
3340     """
3341     Abstract class for StringTransformers which transform a Line's strings by splitting
3342     them or placing them on their own lines where necessary to avoid going over
3343     the configured line length.
3344
3345     Requirements:
3346         * The target string value is responsible for the line going over the
3347         line length limit. It follows that after all of black's other line
3348         split methods have been exhausted, this line (or one of the resulting
3349         lines after all line splits are performed) would still be over the
3350         line_length limit unless we split this string.
3351             AND
3352         * The target string is NOT a "pointless" string (i.e. a string that has
3353         no parent or siblings).
3354             AND
3355         * The target string is not followed by an inline comment that appears
3356         to be a pragma.
3357             AND
3358         * The target string is not a multiline (i.e. triple-quote) string.
3359     """
3360
3361     @abstractmethod
3362     def do_splitter_match(self, line: Line) -> TMatchResult:
3363         """
3364         BaseStringSplitter asks its clients to override this method instead of
3365         `StringTransformer.do_match(...)`.
3366
3367         Follows the same protocol as `StringTransformer.do_match(...)`.
3368
3369         Refer to `help(StringTransformer.do_match)` for more information.
3370         """
3371
3372     def do_match(self, line: Line) -> TMatchResult:
3373         match_result = self.do_splitter_match(line)
3374         if isinstance(match_result, Err):
3375             return match_result
3376
3377         string_idx = match_result.ok()
3378         vresult = self.__validate(line, string_idx)
3379         if isinstance(vresult, Err):
3380             return vresult
3381
3382         return match_result
3383
3384     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3385         """
3386         Checks that @line meets all of the requirements listed in this classes'
3387         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3388         description of those requirements.
3389
3390         Returns:
3391             * Ok(None), if ALL of the requirements are met.
3392                 OR
3393             * Err(CannotTransform), if ANY of the requirements are NOT met.
3394         """
3395         LL = line.leaves
3396
3397         string_leaf = LL[string_idx]
3398
3399         max_string_length = self.__get_max_string_length(line, string_idx)
3400         if len(string_leaf.value) <= max_string_length:
3401             return TErr(
3402                 "The string itself is not what is causing this line to be too long."
3403             )
3404
3405         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3406             token.STRING,
3407             token.NEWLINE,
3408         ]:
3409             return TErr(
3410                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3411                 " no parent)."
3412             )
3413
3414         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3415             line.comments[id(line.leaves[string_idx])]
3416         ):
3417             return TErr(
3418                 "Line appears to end with an inline pragma comment. Splitting the line"
3419                 " could modify the pragma's behavior."
3420             )
3421
3422         if has_triple_quotes(string_leaf.value):
3423             return TErr("We cannot split multiline strings.")
3424
3425         return Ok(None)
3426
3427     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3428         """
3429         Calculates the max string length used when attempting to determine
3430         whether or not the target string is responsible for causing the line to
3431         go over the line length limit.
3432
3433         WARNING: This method is tightly coupled to both StringSplitter and
3434         (especially) StringParenWrapper. There is probably a better way to
3435         accomplish what is being done here.
3436
3437         Returns:
3438             max_string_length: such that `line.leaves[string_idx].value >
3439             max_string_length` implies that the target string IS responsible
3440             for causing this line to exceed the line length limit.
3441         """
3442         LL = line.leaves
3443
3444         is_valid_index = is_valid_index_factory(LL)
3445
3446         # We use the shorthand "WMA4" in comments to abbreviate "We must
3447         # account for". When giving examples, we use STRING to mean some/any
3448         # valid string.
3449         #
3450         # Finally, we use the following convenience variables:
3451         #
3452         #   P:  The leaf that is before the target string leaf.
3453         #   N:  The leaf that is after the target string leaf.
3454         #   NN: The leaf that is after N.
3455
3456         # WMA4 the whitespace at the beginning of the line.
3457         offset = line.depth * 4
3458
3459         if is_valid_index(string_idx - 1):
3460             p_idx = string_idx - 1
3461             if (
3462                 LL[string_idx - 1].type == token.LPAR
3463                 and LL[string_idx - 1].value == ""
3464                 and string_idx >= 2
3465             ):
3466                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3467                 p_idx -= 1
3468
3469             P = LL[p_idx]
3470             if P.type == token.PLUS:
3471                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3472                 offset += 2
3473
3474             if P.type == token.COMMA:
3475                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3476                 offset += 3
3477
3478             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3479                 # This conditional branch is meant to handle dictionary keys,
3480                 # variable assignments, 'return STRING' statement lines, and
3481                 # 'else STRING' ternary expression lines.
3482
3483                 # WMA4 a single space.
3484                 offset += 1
3485
3486                 # WMA4 the lengths of any leaves that came before that space.
3487                 for leaf in LL[: p_idx + 1]:
3488                     offset += len(str(leaf))
3489
3490         if is_valid_index(string_idx + 1):
3491             N = LL[string_idx + 1]
3492             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3493                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3494                 N = LL[string_idx + 2]
3495
3496             if N.type == token.COMMA:
3497                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3498                 offset += 1
3499
3500             if is_valid_index(string_idx + 2):
3501                 NN = LL[string_idx + 2]
3502
3503                 if N.type == token.DOT and NN.type == token.NAME:
3504                     # This conditional branch is meant to handle method calls invoked
3505                     # off of a string literal up to and including the LPAR character.
3506
3507                     # WMA4 the '.' character.
3508                     offset += 1
3509
3510                     if (
3511                         is_valid_index(string_idx + 3)
3512                         and LL[string_idx + 3].type == token.LPAR
3513                     ):
3514                         # WMA4 the left parenthesis character.
3515                         offset += 1
3516
3517                     # WMA4 the length of the method's name.
3518                     offset += len(NN.value)
3519
3520         has_comments = False
3521         for comment_leaf in line.comments_after(LL[string_idx]):
3522             if not has_comments:
3523                 has_comments = True
3524                 # WMA4 two spaces before the '#' character.
3525                 offset += 2
3526
3527             # WMA4 the length of the inline comment.
3528             offset += len(comment_leaf.value)
3529
3530         max_string_length = self.line_length - offset
3531         return max_string_length
3532
3533
3534 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3535     """
3536     StringTransformer that splits "atom" strings (i.e. strings which exist on
3537     lines by themselves).
3538
3539     Requirements:
3540         * The line consists ONLY of a single string (with the exception of a
3541         '+' symbol which MAY exist at the start of the line), MAYBE a string
3542         trailer, and MAYBE a trailing comma.
3543             AND
3544         * All of the requirements listed in BaseStringSplitter's docstring.
3545
3546     Transformations:
3547         The string mentioned in the 'Requirements' section is split into as
3548         many substrings as necessary to adhere to the configured line length.
3549
3550         In the final set of substrings, no substring should be smaller than
3551         MIN_SUBSTR_SIZE characters.
3552
3553         The string will ONLY be split on spaces (i.e. each new substring should
3554         start with a space).
3555
3556         If the string is an f-string, it will NOT be split in the middle of an
3557         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3558         else bar()} is an f-expression).
3559
3560         If the string that is being split has an associated set of custom split
3561         records and those custom splits will NOT result in any line going over
3562         the configured line length, those custom splits are used. Otherwise the
3563         string is split as late as possible (from left-to-right) while still
3564         adhering to the transformation rules listed above.
3565
3566     Collaborations:
3567         StringSplitter relies on StringMerger to construct the appropriate
3568         CustomSplit objects and add them to the custom split map.
3569     """
3570
3571     MIN_SUBSTR_SIZE = 6
3572     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3573     RE_FEXPR = r"""
3574     (?<!\{)\{
3575         (?:
3576             [^\{\}]
3577             | \{\{
3578             | \}\}
3579         )+?
3580     (?<!\})(?:\}\})*\}(?!\})
3581     """
3582
3583     def do_splitter_match(self, line: Line) -> TMatchResult:
3584         LL = line.leaves
3585
3586         is_valid_index = is_valid_index_factory(LL)
3587
3588         idx = 0
3589
3590         # The first leaf MAY be a '+' symbol...
3591         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3592             idx += 1
3593
3594         # The next/first leaf MAY be an empty LPAR...
3595         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3596             idx += 1
3597
3598         # The next/first leaf MUST be a string...
3599         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3600             return TErr("Line does not start with a string.")
3601
3602         string_idx = idx
3603
3604         # Skip the string trailer, if one exists.
3605         string_parser = StringParser()
3606         idx = string_parser.parse(LL, string_idx)
3607
3608         # That string MAY be followed by an empty RPAR...
3609         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3610             idx += 1
3611
3612         # That string / empty RPAR leaf MAY be followed by a comma...
3613         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3614             idx += 1
3615
3616         # But no more leaves are allowed...
3617         if is_valid_index(idx):
3618             return TErr("This line does not end with a string.")
3619
3620         return Ok(string_idx)
3621
3622     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3623         LL = line.leaves
3624
3625         QUOTE = LL[string_idx].value[-1]
3626
3627         is_valid_index = is_valid_index_factory(LL)
3628         insert_str_child = insert_str_child_factory(LL[string_idx])
3629
3630         prefix = get_string_prefix(LL[string_idx].value)
3631
3632         # We MAY choose to drop the 'f' prefix from substrings that don't
3633         # contain any f-expressions, but ONLY if the original f-string
3634         # containes at least one f-expression. Otherwise, we will alter the AST
3635         # of the program.
3636         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3637             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3638         )
3639
3640         first_string_line = True
3641         starts_with_plus = LL[0].type == token.PLUS
3642
3643         def line_needs_plus() -> bool:
3644             return first_string_line and starts_with_plus
3645
3646         def maybe_append_plus(new_line: Line) -> None:
3647             """
3648             Side Effects:
3649                 If @line starts with a plus and this is the first line we are
3650                 constructing, this function appends a PLUS leaf to @new_line
3651                 and replaces the old PLUS leaf in the node structure. Otherwise
3652                 this function does nothing.
3653             """
3654             if line_needs_plus():
3655                 plus_leaf = Leaf(token.PLUS, "+")
3656                 replace_child(LL[0], plus_leaf)
3657                 new_line.append(plus_leaf)
3658
3659         ends_with_comma = (
3660             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3661         )
3662
3663         def max_last_string() -> int:
3664             """
3665             Returns:
3666                 The max allowed length of the string value used for the last
3667                 line we will construct.
3668             """
3669             result = self.line_length
3670             result -= line.depth * 4
3671             result -= 1 if ends_with_comma else 0
3672             result -= 2 if line_needs_plus() else 0
3673             return result
3674
3675         # --- Calculate Max Break Index (for string value)
3676         # We start with the line length limit
3677         max_break_idx = self.line_length
3678         # The last index of a string of length N is N-1.
3679         max_break_idx -= 1
3680         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3681         max_break_idx -= line.depth * 4
3682         if max_break_idx < 0:
3683             yield TErr(
3684                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3685                 f" {line.depth}"
3686             )
3687             return
3688
3689         # Check if StringMerger registered any custom splits.
3690         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3691         # We use them ONLY if none of them would produce lines that exceed the
3692         # line limit.
3693         use_custom_breakpoints = bool(
3694             custom_splits
3695             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3696         )
3697
3698         # Temporary storage for the remaining chunk of the string line that
3699         # can't fit onto the line currently being constructed.
3700         rest_value = LL[string_idx].value
3701
3702         def more_splits_should_be_made() -> bool:
3703             """
3704             Returns:
3705                 True iff `rest_value` (the remaining string value from the last
3706                 split), should be split again.
3707             """
3708             if use_custom_breakpoints:
3709                 return len(custom_splits) > 1
3710             else:
3711                 return len(rest_value) > max_last_string()
3712
3713         string_line_results: List[Ok[Line]] = []
3714         while more_splits_should_be_made():
3715             if use_custom_breakpoints:
3716                 # Custom User Split (manual)
3717                 csplit = custom_splits.pop(0)
3718                 break_idx = csplit.break_idx
3719             else:
3720                 # Algorithmic Split (automatic)
3721                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3722                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3723                 if maybe_break_idx is None:
3724                     # If we are unable to algorthmically determine a good split
3725                     # and this string has custom splits registered to it, we
3726                     # fall back to using them--which means we have to start
3727                     # over from the beginning.
3728                     if custom_splits:
3729                         rest_value = LL[string_idx].value
3730                         string_line_results = []
3731                         first_string_line = True
3732                         use_custom_breakpoints = True
3733                         continue
3734
3735                     # Otherwise, we stop splitting here.
3736                     break
3737
3738                 break_idx = maybe_break_idx
3739
3740             # --- Construct `next_value`
3741             next_value = rest_value[:break_idx] + QUOTE
3742             if (
3743                 # Are we allowed to try to drop a pointless 'f' prefix?
3744                 drop_pointless_f_prefix
3745                 # If we are, will we be successful?
3746                 and next_value != self.__normalize_f_string(next_value, prefix)
3747             ):
3748                 # If the current custom split did NOT originally use a prefix,
3749                 # then `csplit.break_idx` will be off by one after removing
3750                 # the 'f' prefix.
3751                 break_idx = (
3752                     break_idx + 1
3753                     if use_custom_breakpoints and not csplit.has_prefix
3754                     else break_idx
3755                 )
3756                 next_value = rest_value[:break_idx] + QUOTE
3757                 next_value = self.__normalize_f_string(next_value, prefix)
3758
3759             # --- Construct `next_leaf`
3760             next_leaf = Leaf(token.STRING, next_value)
3761             insert_str_child(next_leaf)
3762             self.__maybe_normalize_string_quotes(next_leaf)
3763
3764             # --- Construct `next_line`
3765             next_line = line.clone()
3766             maybe_append_plus(next_line)
3767             next_line.append(next_leaf)
3768             string_line_results.append(Ok(next_line))
3769
3770             rest_value = prefix + QUOTE + rest_value[break_idx:]
3771             first_string_line = False
3772
3773         yield from string_line_results
3774
3775         if drop_pointless_f_prefix:
3776             rest_value = self.__normalize_f_string(rest_value, prefix)
3777
3778         rest_leaf = Leaf(token.STRING, rest_value)
3779         insert_str_child(rest_leaf)
3780
3781         # NOTE: I could not find a test case that verifies that the following
3782         # line is actually necessary, but it seems to be. Otherwise we risk
3783         # not normalizing the last substring, right?
3784         self.__maybe_normalize_string_quotes(rest_leaf)
3785
3786         last_line = line.clone()
3787         maybe_append_plus(last_line)
3788
3789         # If there are any leaves to the right of the target string...
3790         if is_valid_index(string_idx + 1):
3791             # We use `temp_value` here to determine how long the last line
3792             # would be if we were to append all the leaves to the right of the
3793             # target string to the last string line.
3794             temp_value = rest_value
3795             for leaf in LL[string_idx + 1 :]:
3796                 temp_value += str(leaf)
3797                 if leaf.type == token.LPAR:
3798                     break
3799
3800             # Try to fit them all on the same line with the last substring...
3801             if (
3802                 len(temp_value) <= max_last_string()
3803                 or LL[string_idx + 1].type == token.COMMA
3804             ):
3805                 last_line.append(rest_leaf)
3806                 append_leaves(last_line, line, LL[string_idx + 1 :])
3807                 yield Ok(last_line)
3808             # Otherwise, place the last substring on one line and everything
3809             # else on a line below that...
3810             else:
3811                 last_line.append(rest_leaf)
3812                 yield Ok(last_line)
3813
3814                 non_string_line = line.clone()
3815                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3816                 yield Ok(non_string_line)
3817         # Else the target string was the last leaf...
3818         else:
3819             last_line.append(rest_leaf)
3820             last_line.comments = line.comments.copy()
3821             yield Ok(last_line)
3822
3823     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3824         """
3825         This method contains the algorithm that StringSplitter uses to
3826         determine which character to split each string at.
3827
3828         Args:
3829             @string: The substring that we are attempting to split.
3830             @max_break_idx: The ideal break index. We will return this value if it
3831             meets all the necessary conditions. In the likely event that it
3832             doesn't we will try to find the closest index BELOW @max_break_idx
3833             that does. If that fails, we will expand our search by also
3834             considering all valid indices ABOVE @max_break_idx.
3835
3836         Pre-Conditions:
3837             * assert_is_leaf_string(@string)
3838             * 0 <= @max_break_idx < len(@string)
3839
3840         Returns:
3841             break_idx, if an index is able to be found that meets all of the
3842             conditions listed in the 'Transformations' section of this classes'
3843             docstring.
3844                 OR
3845             None, otherwise.
3846         """
3847         is_valid_index = is_valid_index_factory(string)
3848
3849         assert is_valid_index(max_break_idx)
3850         assert_is_leaf_string(string)
3851
3852         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3853
3854         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3855             """
3856             Yields:
3857                 All ranges of @string which, if @string were to be split there,
3858                 would result in the splitting of an f-expression (which is NOT
3859                 allowed).
3860             """
3861             nonlocal _fexpr_slices
3862
3863             if _fexpr_slices is None:
3864                 _fexpr_slices = []
3865                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3866                     _fexpr_slices.append(match.span())
3867
3868             yield from _fexpr_slices
3869
3870         is_fstring = "f" in get_string_prefix(string)
3871
3872         def breaks_fstring_expression(i: Index) -> bool:
3873             """
3874             Returns:
3875                 True iff returning @i would result in the splitting of an
3876                 f-expression (which is NOT allowed).
3877             """
3878             if not is_fstring:
3879                 return False
3880
3881             for (start, end) in fexpr_slices():
3882                 if start <= i < end:
3883                     return True
3884
3885             return False
3886
3887         def passes_all_checks(i: Index) -> bool:
3888             """
3889             Returns:
3890                 True iff ALL of the conditions listed in the 'Transformations'
3891                 section of this classes' docstring would be be met by returning @i.
3892             """
3893             is_space = string[i] == " "
3894             is_big_enough = (
3895                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3896                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3897             )
3898             return is_space and is_big_enough and not breaks_fstring_expression(i)
3899
3900         # First, we check all indices BELOW @max_break_idx.
3901         break_idx = max_break_idx
3902         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3903             break_idx -= 1
3904
3905         if not passes_all_checks(break_idx):
3906             # If that fails, we check all indices ABOVE @max_break_idx.
3907             #
3908             # If we are able to find a valid index here, the next line is going
3909             # to be longer than the specified line length, but it's probably
3910             # better than doing nothing at all.
3911             break_idx = max_break_idx + 1
3912             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3913                 break_idx += 1
3914
3915             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3916                 return None
3917
3918         return break_idx
3919
3920     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3921         if self.normalize_strings:
3922             normalize_string_quotes(leaf)
3923
3924     def __normalize_f_string(self, string: str, prefix: str) -> str:
3925         """
3926         Pre-Conditions:
3927             * assert_is_leaf_string(@string)
3928
3929         Returns:
3930             * If @string is an f-string that contains no f-expressions, we
3931             return a string identical to @string except that the 'f' prefix
3932             has been stripped and all double braces (i.e. '{{' or '}}') have
3933             been normalized (i.e. turned into '{' or '}').
3934                 OR
3935             * Otherwise, we return @string.
3936         """
3937         assert_is_leaf_string(string)
3938
3939         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3940             new_prefix = prefix.replace("f", "")
3941
3942             temp = string[len(prefix) :]
3943             temp = re.sub(r"\{\{", "{", temp)
3944             temp = re.sub(r"\}\}", "}", temp)
3945             new_string = temp
3946
3947             return f"{new_prefix}{new_string}"
3948         else:
3949             return string
3950
3951
3952 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
3953     """
3954     StringTransformer that splits non-"atom" strings (i.e. strings that do not
3955     exist on lines by themselves).
3956
3957     Requirements:
3958         All of the requirements listed in BaseStringSplitter's docstring in
3959         addition to the requirements listed below:
3960
3961         * The line is a return/yield statement, which returns/yields a string.
3962             OR
3963         * The line is part of a ternary expression (e.g. `x = y if cond else
3964         z`) such that the line starts with `else <string>`, where <string> is
3965         some string.
3966             OR
3967         * The line is an assert statement, which ends with a string.
3968             OR
3969         * The line is an assignment statement (e.g. `x = <string>` or `x +=
3970         <string>`) such that the variable is being assigned the value of some
3971         string.
3972             OR
3973         * The line is a dictionary key assignment where some valid key is being
3974         assigned the value of some string.
3975
3976     Transformations:
3977         The chosen string is wrapped in parentheses and then split at the LPAR.
3978
3979         We then have one line which ends with an LPAR and another line that
3980         starts with the chosen string. The latter line is then split again at
3981         the RPAR. This results in the RPAR (and possibly a trailing comma)
3982         being placed on its own line.
3983
3984         NOTE: If any leaves exist to the right of the chosen string (except
3985         for a trailing comma, which would be placed after the RPAR), those
3986         leaves are placed inside the parentheses.  In effect, the chosen
3987         string is not necessarily being "wrapped" by parentheses. We can,
3988         however, count on the LPAR being placed directly before the chosen
3989         string.
3990
3991         In other words, StringParenWrapper creates "atom" strings. These
3992         can then be split again by StringSplitter, if necessary.
3993
3994     Collaborations:
3995         In the event that a string line split by StringParenWrapper is
3996         changed such that it no longer needs to be given its own line,
3997         StringParenWrapper relies on StringParenStripper to clean up the
3998         parentheses it created.
3999     """
4000
4001     def do_splitter_match(self, line: Line) -> TMatchResult:
4002         LL = line.leaves
4003
4004         string_idx = None
4005         string_idx = string_idx or self._return_match(LL)
4006         string_idx = string_idx or self._else_match(LL)
4007         string_idx = string_idx or self._assert_match(LL)
4008         string_idx = string_idx or self._assign_match(LL)
4009         string_idx = string_idx or self._dict_match(LL)
4010
4011         if string_idx is not None:
4012             string_value = line.leaves[string_idx].value
4013             # If the string has no spaces...
4014             if " " not in string_value:
4015                 # And will still violate the line length limit when split...
4016                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4017                 if len(string_value) > max_string_length:
4018                     # And has no associated custom splits...
4019                     if not self.has_custom_splits(string_value):
4020                         # Then we should NOT put this string on its own line.
4021                         return TErr(
4022                             "We do not wrap long strings in parentheses when the"
4023                             " resultant line would still be over the specified line"
4024                             " length and can't be split further by StringSplitter."
4025                         )
4026             return Ok(string_idx)
4027
4028         return TErr("This line does not contain any non-atomic strings.")
4029
4030     @staticmethod
4031     def _return_match(LL: List[Leaf]) -> Optional[int]:
4032         """
4033         Returns:
4034             string_idx such that @LL[string_idx] is equal to our target (i.e.
4035             matched) string, if this line matches the return/yield statement
4036             requirements listed in the 'Requirements' section of this classes'
4037             docstring.
4038                 OR
4039             None, otherwise.
4040         """
4041         # If this line is apart of a return/yield statement and the first leaf
4042         # contains either the "return" or "yield" keywords...
4043         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4044             0
4045         ].value in ["return", "yield"]:
4046             is_valid_index = is_valid_index_factory(LL)
4047
4048             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4049             # The next visible leaf MUST contain a string...
4050             if is_valid_index(idx) and LL[idx].type == token.STRING:
4051                 return idx
4052
4053         return None
4054
4055     @staticmethod
4056     def _else_match(LL: List[Leaf]) -> Optional[int]:
4057         """
4058         Returns:
4059             string_idx such that @LL[string_idx] is equal to our target (i.e.
4060             matched) string, if this line matches the ternary expression
4061             requirements listed in the 'Requirements' section of this classes'
4062             docstring.
4063                 OR
4064             None, otherwise.
4065         """
4066         # If this line is apart of a ternary expression and the first leaf
4067         # contains the "else" keyword...
4068         if (
4069             parent_type(LL[0]) == syms.test
4070             and LL[0].type == token.NAME
4071             and LL[0].value == "else"
4072         ):
4073             is_valid_index = is_valid_index_factory(LL)
4074
4075             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4076             # The next visible leaf MUST contain a string...
4077             if is_valid_index(idx) and LL[idx].type == token.STRING:
4078                 return idx
4079
4080         return None
4081
4082     @staticmethod
4083     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4084         """
4085         Returns:
4086             string_idx such that @LL[string_idx] is equal to our target (i.e.
4087             matched) string, if this line matches the assert statement
4088             requirements listed in the 'Requirements' section of this classes'
4089             docstring.
4090                 OR
4091             None, otherwise.
4092         """
4093         # If this line is apart of an assert statement and the first leaf
4094         # contains the "assert" keyword...
4095         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4096             is_valid_index = is_valid_index_factory(LL)
4097
4098             for (i, leaf) in enumerate(LL):
4099                 # We MUST find a comma...
4100                 if leaf.type == token.COMMA:
4101                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4102
4103                     # That comma MUST be followed by a string...
4104                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4105                         string_idx = idx
4106
4107                         # Skip the string trailer, if one exists.
4108                         string_parser = StringParser()
4109                         idx = string_parser.parse(LL, string_idx)
4110
4111                         # But no more leaves are allowed...
4112                         if not is_valid_index(idx):
4113                             return string_idx
4114
4115         return None
4116
4117     @staticmethod
4118     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4119         """
4120         Returns:
4121             string_idx such that @LL[string_idx] is equal to our target (i.e.
4122             matched) string, if this line matches the assignment statement
4123             requirements listed in the 'Requirements' section of this classes'
4124             docstring.
4125                 OR
4126             None, otherwise.
4127         """
4128         # If this line is apart of an expression statement or is a function
4129         # argument AND the first leaf contains a variable name...
4130         if (
4131             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4132             and LL[0].type == token.NAME
4133         ):
4134             is_valid_index = is_valid_index_factory(LL)
4135
4136             for (i, leaf) in enumerate(LL):
4137                 # We MUST find either an '=' or '+=' symbol...
4138                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4139                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4140
4141                     # That symbol MUST be followed by a string...
4142                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4143                         string_idx = idx
4144
4145                         # Skip the string trailer, if one exists.
4146                         string_parser = StringParser()
4147                         idx = string_parser.parse(LL, string_idx)
4148
4149                         # The next leaf MAY be a comma iff this line is apart
4150                         # of a function argument...
4151                         if (
4152                             parent_type(LL[0]) == syms.argument
4153                             and is_valid_index(idx)
4154                             and LL[idx].type == token.COMMA
4155                         ):
4156                             idx += 1
4157
4158                         # But no more leaves are allowed...
4159                         if not is_valid_index(idx):
4160                             return string_idx
4161
4162         return None
4163
4164     @staticmethod
4165     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4166         """
4167         Returns:
4168             string_idx such that @LL[string_idx] is equal to our target (i.e.
4169             matched) string, if this line matches the dictionary key assignment
4170             statement requirements listed in the 'Requirements' section of this
4171             classes' docstring.
4172                 OR
4173             None, otherwise.
4174         """
4175         # If this line is apart of a dictionary key assignment...
4176         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4177             is_valid_index = is_valid_index_factory(LL)
4178
4179             for (i, leaf) in enumerate(LL):
4180                 # We MUST find a colon...
4181                 if leaf.type == token.COLON:
4182                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4183
4184                     # That colon MUST be followed by a string...
4185                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4186                         string_idx = idx
4187
4188                         # Skip the string trailer, if one exists.
4189                         string_parser = StringParser()
4190                         idx = string_parser.parse(LL, string_idx)
4191
4192                         # That string MAY be followed by a comma...
4193                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4194                             idx += 1
4195
4196                         # But no more leaves are allowed...
4197                         if not is_valid_index(idx):
4198                             return string_idx
4199
4200         return None
4201
4202     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4203         LL = line.leaves
4204
4205         is_valid_index = is_valid_index_factory(LL)
4206         insert_str_child = insert_str_child_factory(LL[string_idx])
4207
4208         comma_idx = len(LL) - 1
4209         ends_with_comma = False
4210         if LL[comma_idx].type == token.COMMA:
4211             ends_with_comma = True
4212
4213         leaves_to_steal_comments_from = [LL[string_idx]]
4214         if ends_with_comma:
4215             leaves_to_steal_comments_from.append(LL[comma_idx])
4216
4217         # --- First Line
4218         first_line = line.clone()
4219         left_leaves = LL[:string_idx]
4220
4221         # We have to remember to account for (possibly invisible) LPAR and RPAR
4222         # leaves that already wrapped the target string. If these leaves do
4223         # exist, we will replace them with our own LPAR and RPAR leaves.
4224         old_parens_exist = False
4225         if left_leaves and left_leaves[-1].type == token.LPAR:
4226             old_parens_exist = True
4227             leaves_to_steal_comments_from.append(left_leaves[-1])
4228             left_leaves.pop()
4229
4230         append_leaves(first_line, line, left_leaves)
4231
4232         lpar_leaf = Leaf(token.LPAR, "(")
4233         if old_parens_exist:
4234             replace_child(LL[string_idx - 1], lpar_leaf)
4235         else:
4236             insert_str_child(lpar_leaf)
4237         first_line.append(lpar_leaf)
4238
4239         # We throw inline comments that were originally to the right of the
4240         # target string to the top line. They will now be shown to the right of
4241         # the LPAR.
4242         for leaf in leaves_to_steal_comments_from:
4243             for comment_leaf in line.comments_after(leaf):
4244                 first_line.append(comment_leaf, preformatted=True)
4245
4246         yield Ok(first_line)
4247
4248         # --- Middle (String) Line
4249         # We only need to yield one (possibly too long) string line, since the
4250         # `StringSplitter` will break it down further if necessary.
4251         string_value = LL[string_idx].value
4252         string_line = Line(
4253             depth=line.depth + 1,
4254             inside_brackets=True,
4255             should_explode=line.should_explode,
4256         )
4257         string_leaf = Leaf(token.STRING, string_value)
4258         insert_str_child(string_leaf)
4259         string_line.append(string_leaf)
4260
4261         old_rpar_leaf = None
4262         if is_valid_index(string_idx + 1):
4263             right_leaves = LL[string_idx + 1 :]
4264             if ends_with_comma:
4265                 right_leaves.pop()
4266
4267             if old_parens_exist:
4268                 assert (
4269                     right_leaves and right_leaves[-1].type == token.RPAR
4270                 ), "Apparently, old parentheses do NOT exist?!"
4271                 old_rpar_leaf = right_leaves.pop()
4272
4273             append_leaves(string_line, line, right_leaves)
4274
4275         yield Ok(string_line)
4276
4277         # --- Last Line
4278         last_line = line.clone()
4279         last_line.bracket_tracker = first_line.bracket_tracker
4280
4281         new_rpar_leaf = Leaf(token.RPAR, ")")
4282         if old_rpar_leaf is not None:
4283             replace_child(old_rpar_leaf, new_rpar_leaf)
4284         else:
4285             insert_str_child(new_rpar_leaf)
4286         last_line.append(new_rpar_leaf)
4287
4288         # If the target string ended with a comma, we place this comma to the
4289         # right of the RPAR on the last line.
4290         if ends_with_comma:
4291             comma_leaf = Leaf(token.COMMA, ",")
4292             replace_child(LL[comma_idx], comma_leaf)
4293             last_line.append(comma_leaf)
4294
4295         yield Ok(last_line)
4296
4297
4298 class StringParser:
4299     """
4300     A state machine that aids in parsing a string's "trailer", which can be
4301     either non-existant, an old-style formatting sequence (e.g. `% varX` or `%
4302     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4303     varY)`).
4304
4305     NOTE: A new StringParser object MUST be instantiated for each string
4306     trailer we need to parse.
4307
4308     Examples:
4309         We shall assume that `line` equals the `Line` object that corresponds
4310         to the following line of python code:
4311         ```
4312         x = "Some {}.".format("String") + some_other_string
4313         ```
4314
4315         Furthermore, we will assume that `string_idx` is some index such that:
4316         ```
4317         assert line.leaves[string_idx].value == "Some {}."
4318         ```
4319
4320         The following code snippet then holds:
4321         ```
4322         string_parser = StringParser()
4323         idx = string_parser.parse(line.leaves, string_idx)
4324         assert line.leaves[idx].type == token.PLUS
4325         ```
4326     """
4327
4328     DEFAULT_TOKEN = -1
4329
4330     # String Parser States
4331     START = 1
4332     DOT = 2
4333     NAME = 3
4334     PERCENT = 4
4335     SINGLE_FMT_ARG = 5
4336     LPAR = 6
4337     RPAR = 7
4338     DONE = 8
4339
4340     # Lookup Table for Next State
4341     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4342         # A string trailer may start with '.' OR '%'.
4343         (START, token.DOT): DOT,
4344         (START, token.PERCENT): PERCENT,
4345         (START, DEFAULT_TOKEN): DONE,
4346         # A '.' MUST be followed by an attribute or method name.
4347         (DOT, token.NAME): NAME,
4348         # A method name MUST be followed by an '(', whereas an attribute name
4349         # is the last symbol in the string trailer.
4350         (NAME, token.LPAR): LPAR,
4351         (NAME, DEFAULT_TOKEN): DONE,
4352         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4353         # string or variable name).
4354         (PERCENT, token.LPAR): LPAR,
4355         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4356         # If a '%' symbol is followed by a single argument, that argument is
4357         # the last leaf in the string trailer.
4358         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4359         # If present, a ')' symbol is the last symbol in a string trailer.
4360         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4361         # since they are treated as a special case by the parsing logic in this
4362         # classes' implementation.)
4363         (RPAR, DEFAULT_TOKEN): DONE,
4364     }
4365
4366     def __init__(self) -> None:
4367         self._state = self.START
4368         self._unmatched_lpars = 0
4369
4370     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4371         """
4372         Pre-conditions:
4373             * @leaves[@string_idx].type == token.STRING
4374
4375         Returns:
4376             The index directly after the last leaf which is apart of the string
4377             trailer, if a "trailer" exists.
4378                 OR
4379             @string_idx + 1, if no string "trailer" exists.
4380         """
4381         assert leaves[string_idx].type == token.STRING
4382
4383         idx = string_idx + 1
4384         while idx < len(leaves) and self._next_state(leaves[idx]):
4385             idx += 1
4386         return idx
4387
4388     def _next_state(self, leaf: Leaf) -> bool:
4389         """
4390         Pre-conditions:
4391             * On the first call to this function, @leaf MUST be the leaf that
4392             was directly after the string leaf in question (e.g. if our target
4393             string is `line.leaves[i]` then the first call to this method must
4394             be `line.leaves[i + 1]`).
4395             * On the next call to this function, the leaf paramater passed in
4396             MUST be the leaf directly following @leaf.
4397
4398         Returns:
4399             True iff @leaf is apart of the string's trailer.
4400         """
4401         # We ignore empty LPAR or RPAR leaves.
4402         if is_empty_par(leaf):
4403             return True
4404
4405         next_token = leaf.type
4406         if next_token == token.LPAR:
4407             self._unmatched_lpars += 1
4408
4409         current_state = self._state
4410
4411         # The LPAR parser state is a special case. We will return True until we
4412         # find the matching RPAR token.
4413         if current_state == self.LPAR:
4414             if next_token == token.RPAR:
4415                 self._unmatched_lpars -= 1
4416                 if self._unmatched_lpars == 0:
4417                     self._state = self.RPAR
4418         # Otherwise, we use a lookup table to determine the next state.
4419         else:
4420             # If the lookup table matches the current state to the next
4421             # token, we use the lookup table.
4422             if (current_state, next_token) in self._goto:
4423                 self._state = self._goto[current_state, next_token]
4424             else:
4425                 # Otherwise, we check if a the current state was assigned a
4426                 # default.
4427                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4428                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4429                 # If no default has been assigned, then this parser has a logic
4430                 # error.
4431                 else:
4432                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4433
4434             if self._state == self.DONE:
4435                 return False
4436
4437         return True
4438
4439
4440 def TErr(err_msg: str) -> Err[CannotTransform]:
4441     """(T)ransform Err
4442
4443     Convenience function used when working with the TResult type.
4444     """
4445     cant_transform = CannotTransform(err_msg)
4446     return Err(cant_transform)
4447
4448
4449 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4450     """
4451     Returns:
4452         True iff one of the comments in @comment_list is a pragma used by one
4453         of the more common static analysis tools for python (e.g. mypy, flake8,
4454         pylint).
4455     """
4456     for comment in comment_list:
4457         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4458             return True
4459
4460     return False
4461
4462
4463 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4464     """
4465     Factory for a convenience function that is used to orphan @string_leaf
4466     and then insert multiple new leaves into the same part of the node
4467     structure that @string_leaf had originally occupied.
4468
4469     Examples:
4470         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4471         string_leaf.parent`. Assume the node `N` has the following
4472         original structure:
4473
4474         Node(
4475             expr_stmt, [
4476                 Leaf(NAME, 'x'),
4477                 Leaf(EQUAL, '='),
4478                 Leaf(STRING, '"foo"'),
4479             ]
4480         )
4481
4482         We then run the code snippet shown below.
4483         ```
4484         insert_str_child = insert_str_child_factory(string_leaf)
4485
4486         lpar = Leaf(token.LPAR, '(')
4487         insert_str_child(lpar)
4488
4489         bar = Leaf(token.STRING, '"bar"')
4490         insert_str_child(bar)
4491
4492         rpar = Leaf(token.RPAR, ')')
4493         insert_str_child(rpar)
4494         ```
4495
4496         After which point, it follows that `string_leaf.parent is None` and
4497         the node `N` now has the following structure:
4498
4499         Node(
4500             expr_stmt, [
4501                 Leaf(NAME, 'x'),
4502                 Leaf(EQUAL, '='),
4503                 Leaf(LPAR, '('),
4504                 Leaf(STRING, '"bar"'),
4505                 Leaf(RPAR, ')'),
4506             ]
4507         )
4508     """
4509     string_parent = string_leaf.parent
4510     string_child_idx = string_leaf.remove()
4511
4512     def insert_str_child(child: LN) -> None:
4513         nonlocal string_child_idx
4514
4515         assert string_parent is not None
4516         assert string_child_idx is not None
4517
4518         string_parent.insert_child(string_child_idx, child)
4519         string_child_idx += 1
4520
4521     return insert_str_child
4522
4523
4524 def has_triple_quotes(string: str) -> bool:
4525     """
4526     Returns:
4527         True iff @string starts with three quotation characters.
4528     """
4529     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4530     return raw_string[:3] in {'"""', "'''"}
4531
4532
4533 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4534     """
4535     Returns:
4536         @node.parent.type, if @node is not None and has a parent.
4537             OR
4538         None, otherwise.
4539     """
4540     if node is None or node.parent is None:
4541         return None
4542
4543     return node.parent.type
4544
4545
4546 def is_empty_par(leaf: Leaf) -> bool:
4547     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4548
4549
4550 def is_empty_lpar(leaf: Leaf) -> bool:
4551     return leaf.type == token.LPAR and leaf.value == ""
4552
4553
4554 def is_empty_rpar(leaf: Leaf) -> bool:
4555     return leaf.type == token.RPAR and leaf.value == ""
4556
4557
4558 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4559     """
4560     Examples:
4561         ```
4562         my_list = [1, 2, 3]
4563
4564         is_valid_index = is_valid_index_factory(my_list)
4565
4566         assert is_valid_index(0)
4567         assert is_valid_index(2)
4568
4569         assert not is_valid_index(3)
4570         assert not is_valid_index(-1)
4571         ```
4572     """
4573
4574     def is_valid_index(idx: int) -> bool:
4575         """
4576         Returns:
4577             True iff @idx is positive AND seq[@idx] does NOT raise an
4578             IndexError.
4579         """
4580         return 0 <= idx < len(seq)
4581
4582     return is_valid_index
4583
4584
4585 def line_to_string(line: Line) -> str:
4586     """Returns the string representation of @line.
4587
4588     WARNING: This is known to be computationally expensive.
4589     """
4590     return str(line).strip("\n")
4591
4592
4593 def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
4594     """
4595     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4596     underlying Node structure where appropriate.
4597
4598     All of the leaves in @leaves are duplicated. The duplicates are then
4599     appended to @new_line and used to replace their originals in the underlying
4600     Node structure. Any comments attatched to the old leaves are reattached to
4601     the new leaves.
4602
4603     Pre-conditions:
4604         set(@leaves) is a subset of set(@old_line.leaves).
4605     """
4606     for old_leaf in leaves:
4607         assert old_leaf in old_line.leaves
4608
4609         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4610         replace_child(old_leaf, new_leaf)
4611         new_line.append(new_leaf)
4612
4613         for comment_leaf in old_line.comments_after(old_leaf):
4614             new_line.append(comment_leaf, preformatted=True)
4615
4616
4617 def replace_child(old_child: LN, new_child: LN) -> None:
4618     """
4619     Side Effects:
4620         * If @old_child.parent is set, replace @old_child with @new_child in
4621         @old_child's underlying Node structure.
4622             OR
4623         * Otherwise, this function does nothing.
4624     """
4625     parent = old_child.parent
4626     if not parent:
4627         return
4628
4629     child_idx = old_child.remove()
4630     if child_idx is not None:
4631         parent.insert_child(child_idx, new_child)
4632
4633
4634 def get_string_prefix(string: str) -> str:
4635     """
4636     Pre-conditions:
4637         * assert_is_leaf_string(@string)
4638
4639     Returns:
4640         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4641     """
4642     assert_is_leaf_string(string)
4643
4644     prefix = ""
4645     prefix_idx = 0
4646     while string[prefix_idx] in STRING_PREFIX_CHARS:
4647         prefix += string[prefix_idx].lower()
4648         prefix_idx += 1
4649
4650     return prefix
4651
4652
4653 def assert_is_leaf_string(string: str) -> None:
4654     """
4655     Checks the pre-condition that @string has the format that you would expect
4656     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4657     token.STRING`. A more precise description of the pre-conditions that are
4658     checked are listed below.
4659
4660     Pre-conditions:
4661         * @string starts with either ', ", <prefix>', or <prefix>" where
4662         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4663         * @string ends with a quote character (' or ").
4664
4665     Raises:
4666         AssertionError(...) if the pre-conditions listed above are not
4667         satisfied.
4668     """
4669     dquote_idx = string.find('"')
4670     squote_idx = string.find("'")
4671     if -1 in [dquote_idx, squote_idx]:
4672         quote_idx = max(dquote_idx, squote_idx)
4673     else:
4674         quote_idx = min(squote_idx, dquote_idx)
4675
4676     assert (
4677         0 <= quote_idx < len(string) - 1
4678     ), f"{string!r} is missing a starting quote character (' or \")."
4679     assert string[-1] in (
4680         "'",
4681         '"',
4682     ), f"{string!r} is missing an ending quote character (' or \")."
4683     assert set(string[:quote_idx]).issubset(
4684         set(STRING_PREFIX_CHARS)
4685     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4686
4687
4688 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4689     """Split line into many lines, starting with the first matching bracket pair.
4690
4691     Note: this usually looks weird, only use this for function definitions.
4692     Prefer RHS otherwise.  This is why this function is not symmetrical with
4693     :func:`right_hand_split` which also handles optional parentheses.
4694     """
4695     tail_leaves: List[Leaf] = []
4696     body_leaves: List[Leaf] = []
4697     head_leaves: List[Leaf] = []
4698     current_leaves = head_leaves
4699     matching_bracket: Optional[Leaf] = None
4700     for leaf in line.leaves:
4701         if (
4702             current_leaves is body_leaves
4703             and leaf.type in CLOSING_BRACKETS
4704             and leaf.opening_bracket is matching_bracket
4705         ):
4706             current_leaves = tail_leaves if body_leaves else head_leaves
4707         current_leaves.append(leaf)
4708         if current_leaves is head_leaves:
4709             if leaf.type in OPENING_BRACKETS:
4710                 matching_bracket = leaf
4711                 current_leaves = body_leaves
4712     if not matching_bracket:
4713         raise CannotSplit("No brackets found")
4714
4715     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4716     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4717     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4718     bracket_split_succeeded_or_raise(head, body, tail)
4719     for result in (head, body, tail):
4720         if result:
4721             yield result
4722
4723
4724 def right_hand_split(
4725     line: Line,
4726     line_length: int,
4727     features: Collection[Feature] = (),
4728     omit: Collection[LeafID] = (),
4729 ) -> Iterator[Line]:
4730     """Split line into many lines, starting with the last matching bracket pair.
4731
4732     If the split was by optional parentheses, attempt splitting without them, too.
4733     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4734     this split.
4735
4736     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4737     """
4738     tail_leaves: List[Leaf] = []
4739     body_leaves: List[Leaf] = []
4740     head_leaves: List[Leaf] = []
4741     current_leaves = tail_leaves
4742     opening_bracket: Optional[Leaf] = None
4743     closing_bracket: Optional[Leaf] = None
4744     for leaf in reversed(line.leaves):
4745         if current_leaves is body_leaves:
4746             if leaf is opening_bracket:
4747                 current_leaves = head_leaves if body_leaves else tail_leaves
4748         current_leaves.append(leaf)
4749         if current_leaves is tail_leaves:
4750             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4751                 opening_bracket = leaf.opening_bracket
4752                 closing_bracket = leaf
4753                 current_leaves = body_leaves
4754     if not (opening_bracket and closing_bracket and head_leaves):
4755         # If there is no opening or closing_bracket that means the split failed and
4756         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4757         # the matching `opening_bracket` wasn't available on `line` anymore.
4758         raise CannotSplit("No brackets found")
4759
4760     tail_leaves.reverse()
4761     body_leaves.reverse()
4762     head_leaves.reverse()
4763     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4764     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4765     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4766     bracket_split_succeeded_or_raise(head, body, tail)
4767     if (
4768         # the body shouldn't be exploded
4769         not body.should_explode
4770         # the opening bracket is an optional paren
4771         and opening_bracket.type == token.LPAR
4772         and not opening_bracket.value
4773         # the closing bracket is an optional paren
4774         and closing_bracket.type == token.RPAR
4775         and not closing_bracket.value
4776         # it's not an import (optional parens are the only thing we can split on
4777         # in this case; attempting a split without them is a waste of time)
4778         and not line.is_import
4779         # there are no standalone comments in the body
4780         and not body.contains_standalone_comments(0)
4781         # and we can actually remove the parens
4782         and can_omit_invisible_parens(body, line_length)
4783     ):
4784         omit = {id(closing_bracket), *omit}
4785         try:
4786             yield from right_hand_split(line, line_length, features=features, omit=omit)
4787             return
4788
4789         except CannotSplit:
4790             if not (
4791                 can_be_split(body)
4792                 or is_line_short_enough(body, line_length=line_length)
4793             ):
4794                 raise CannotSplit(
4795                     "Splitting failed, body is still too long and can't be split."
4796                 )
4797
4798             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4799                 raise CannotSplit(
4800                     "The current optional pair of parentheses is bound to fail to"
4801                     " satisfy the splitting algorithm because the head or the tail"
4802                     " contains multiline strings which by definition never fit one"
4803                     " line."
4804                 )
4805
4806     ensure_visible(opening_bracket)
4807     ensure_visible(closing_bracket)
4808     for result in (head, body, tail):
4809         if result:
4810             yield result
4811
4812
4813 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4814     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4815
4816     Do nothing otherwise.
4817
4818     A left- or right-hand split is based on a pair of brackets. Content before
4819     (and including) the opening bracket is left on one line, content inside the
4820     brackets is put on a separate line, and finally content starting with and
4821     following the closing bracket is put on a separate line.
4822
4823     Those are called `head`, `body`, and `tail`, respectively. If the split
4824     produced the same line (all content in `head`) or ended up with an empty `body`
4825     and the `tail` is just the closing bracket, then it's considered failed.
4826     """
4827     tail_len = len(str(tail).strip())
4828     if not body:
4829         if tail_len == 0:
4830             raise CannotSplit("Splitting brackets produced the same line")
4831
4832         elif tail_len < 3:
4833             raise CannotSplit(
4834                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4835                 " not worth it"
4836             )
4837
4838
4839 def bracket_split_build_line(
4840     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4841 ) -> Line:
4842     """Return a new line with given `leaves` and respective comments from `original`.
4843
4844     If `is_body` is True, the result line is one-indented inside brackets and as such
4845     has its first leaf's prefix normalized and a trailing comma added when expected.
4846     """
4847     result = Line(depth=original.depth)
4848     if is_body:
4849         result.inside_brackets = True
4850         result.depth += 1
4851         if leaves:
4852             # Since body is a new indent level, remove spurious leading whitespace.
4853             normalize_prefix(leaves[0], inside_brackets=True)
4854             # Ensure a trailing comma for imports and standalone function arguments, but
4855             # be careful not to add one after any comments or within type annotations.
4856             no_commas = (
4857                 original.is_def
4858                 and opening_bracket.value == "("
4859                 and not any(l.type == token.COMMA for l in leaves)
4860             )
4861
4862             if original.is_import or no_commas:
4863                 for i in range(len(leaves) - 1, -1, -1):
4864                     if leaves[i].type == STANDALONE_COMMENT:
4865                         continue
4866
4867                     if leaves[i].type != token.COMMA:
4868                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
4869                     break
4870
4871     # Populate the line
4872     for leaf in leaves:
4873         result.append(leaf, preformatted=True)
4874         for comment_after in original.comments_after(leaf):
4875             result.append(comment_after, preformatted=True)
4876     if is_body:
4877         result.should_explode = should_explode(result, opening_bracket)
4878     return result
4879
4880
4881 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4882     """Normalize prefix of the first leaf in every line returned by `split_func`.
4883
4884     This is a decorator over relevant split functions.
4885     """
4886
4887     @wraps(split_func)
4888     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4889         for l in split_func(line, features):
4890             normalize_prefix(l.leaves[0], inside_brackets=True)
4891             yield l
4892
4893     return split_wrapper
4894
4895
4896 @dont_increase_indentation
4897 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4898     """Split according to delimiters of the highest priority.
4899
4900     If the appropriate Features are given, the split will add trailing commas
4901     also in function signatures and calls that contain `*` and `**`.
4902     """
4903     try:
4904         last_leaf = line.leaves[-1]
4905     except IndexError:
4906         raise CannotSplit("Line empty")
4907
4908     bt = line.bracket_tracker
4909     try:
4910         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4911     except ValueError:
4912         raise CannotSplit("No delimiters found")
4913
4914     if delimiter_priority == DOT_PRIORITY:
4915         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4916             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4917
4918     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4919     lowest_depth = sys.maxsize
4920     trailing_comma_safe = True
4921
4922     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4923         """Append `leaf` to current line or to new line if appending impossible."""
4924         nonlocal current_line
4925         try:
4926             current_line.append_safe(leaf, preformatted=True)
4927         except ValueError:
4928             yield current_line
4929
4930             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4931             current_line.append(leaf)
4932
4933     for leaf in line.leaves:
4934         yield from append_to_line(leaf)
4935
4936         for comment_after in line.comments_after(leaf):
4937             yield from append_to_line(comment_after)
4938
4939         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4940         if leaf.bracket_depth == lowest_depth:
4941             if is_vararg(leaf, within={syms.typedargslist}):
4942                 trailing_comma_safe = (
4943                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4944                 )
4945             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4946                 trailing_comma_safe = (
4947                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4948                 )
4949
4950         leaf_priority = bt.delimiters.get(id(leaf))
4951         if leaf_priority == delimiter_priority:
4952             yield current_line
4953
4954             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4955     if current_line:
4956         if (
4957             trailing_comma_safe
4958             and delimiter_priority == COMMA_PRIORITY
4959             and current_line.leaves[-1].type != token.COMMA
4960             and current_line.leaves[-1].type != STANDALONE_COMMENT
4961         ):
4962             current_line.append(Leaf(token.COMMA, ","))
4963         yield current_line
4964
4965
4966 @dont_increase_indentation
4967 def standalone_comment_split(
4968     line: Line, features: Collection[Feature] = ()
4969 ) -> Iterator[Line]:
4970     """Split standalone comments from the rest of the line."""
4971     if not line.contains_standalone_comments(0):
4972         raise CannotSplit("Line does not have any standalone comments")
4973
4974     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4975
4976     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4977         """Append `leaf` to current line or to new line if appending impossible."""
4978         nonlocal current_line
4979         try:
4980             current_line.append_safe(leaf, preformatted=True)
4981         except ValueError:
4982             yield current_line
4983
4984             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4985             current_line.append(leaf)
4986
4987     for leaf in line.leaves:
4988         yield from append_to_line(leaf)
4989
4990         for comment_after in line.comments_after(leaf):
4991             yield from append_to_line(comment_after)
4992
4993     if current_line:
4994         yield current_line
4995
4996
4997 def is_import(leaf: Leaf) -> bool:
4998     """Return True if the given leaf starts an import statement."""
4999     p = leaf.parent
5000     t = leaf.type
5001     v = leaf.value
5002     return bool(
5003         t == token.NAME
5004         and (
5005             (v == "import" and p and p.type == syms.import_name)
5006             or (v == "from" and p and p.type == syms.import_from)
5007         )
5008     )
5009
5010
5011 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5012     """Return True if the given leaf is a special comment.
5013     Only returns true for type comments for now."""
5014     t = leaf.type
5015     v = leaf.value
5016     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5017
5018
5019 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5020     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5021     else.
5022
5023     Note: don't use backslashes for formatting or you'll lose your voting rights.
5024     """
5025     if not inside_brackets:
5026         spl = leaf.prefix.split("#")
5027         if "\\" not in spl[0]:
5028             nl_count = spl[-1].count("\n")
5029             if len(spl) > 1:
5030                 nl_count -= 1
5031             leaf.prefix = "\n" * nl_count
5032             return
5033
5034     leaf.prefix = ""
5035
5036
5037 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5038     """Make all string prefixes lowercase.
5039
5040     If remove_u_prefix is given, also removes any u prefix from the string.
5041
5042     Note: Mutates its argument.
5043     """
5044     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5045     assert match is not None, f"failed to match string {leaf.value!r}"
5046     orig_prefix = match.group(1)
5047     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5048     if remove_u_prefix:
5049         new_prefix = new_prefix.replace("u", "")
5050     leaf.value = f"{new_prefix}{match.group(2)}"
5051
5052
5053 def normalize_string_quotes(leaf: Leaf) -> None:
5054     """Prefer double quotes but only if it doesn't cause more escaping.
5055
5056     Adds or removes backslashes as appropriate. Doesn't parse and fix
5057     strings nested in f-strings (yet).
5058
5059     Note: Mutates its argument.
5060     """
5061     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5062     if value[:3] == '"""':
5063         return
5064
5065     elif value[:3] == "'''":
5066         orig_quote = "'''"
5067         new_quote = '"""'
5068     elif value[0] == '"':
5069         orig_quote = '"'
5070         new_quote = "'"
5071     else:
5072         orig_quote = "'"
5073         new_quote = '"'
5074     first_quote_pos = leaf.value.find(orig_quote)
5075     if first_quote_pos == -1:
5076         return  # There's an internal error
5077
5078     prefix = leaf.value[:first_quote_pos]
5079     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5080     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5081     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5082     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5083     if "r" in prefix.casefold():
5084         if unescaped_new_quote.search(body):
5085             # There's at least one unescaped new_quote in this raw string
5086             # so converting is impossible
5087             return
5088
5089         # Do not introduce or remove backslashes in raw strings
5090         new_body = body
5091     else:
5092         # remove unnecessary escapes
5093         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5094         if body != new_body:
5095             # Consider the string without unnecessary escapes as the original
5096             body = new_body
5097             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5098         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5099         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5100     if "f" in prefix.casefold():
5101         matches = re.findall(
5102             r"""
5103             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5104                 ([^{].*?)  # contents of the brackets except if begins with {{
5105             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5106             """,
5107             new_body,
5108             re.VERBOSE,
5109         )
5110         for m in matches:
5111             if "\\" in str(m):
5112                 # Do not introduce backslashes in interpolated expressions
5113                 return
5114
5115     if new_quote == '"""' and new_body[-1:] == '"':
5116         # edge case:
5117         new_body = new_body[:-1] + '\\"'
5118     orig_escape_count = body.count("\\")
5119     new_escape_count = new_body.count("\\")
5120     if new_escape_count > orig_escape_count:
5121         return  # Do not introduce more escaping
5122
5123     if new_escape_count == orig_escape_count and orig_quote == '"':
5124         return  # Prefer double quotes
5125
5126     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5127
5128
5129 def normalize_numeric_literal(leaf: Leaf) -> None:
5130     """Normalizes numeric (float, int, and complex) literals.
5131
5132     All letters used in the representation are normalized to lowercase (except
5133     in Python 2 long literals).
5134     """
5135     text = leaf.value.lower()
5136     if text.startswith(("0o", "0b")):
5137         # Leave octal and binary literals alone.
5138         pass
5139     elif text.startswith("0x"):
5140         # Change hex literals to upper case.
5141         before, after = text[:2], text[2:]
5142         text = f"{before}{after.upper()}"
5143     elif "e" in text:
5144         before, after = text.split("e")
5145         sign = ""
5146         if after.startswith("-"):
5147             after = after[1:]
5148             sign = "-"
5149         elif after.startswith("+"):
5150             after = after[1:]
5151         before = format_float_or_int_string(before)
5152         text = f"{before}e{sign}{after}"
5153     elif text.endswith(("j", "l")):
5154         number = text[:-1]
5155         suffix = text[-1]
5156         # Capitalize in "2L" because "l" looks too similar to "1".
5157         if suffix == "l":
5158             suffix = "L"
5159         text = f"{format_float_or_int_string(number)}{suffix}"
5160     else:
5161         text = format_float_or_int_string(text)
5162     leaf.value = text
5163
5164
5165 def format_float_or_int_string(text: str) -> str:
5166     """Formats a float string like "1.0"."""
5167     if "." not in text:
5168         return text
5169
5170     before, after = text.split(".")
5171     return f"{before or 0}.{after or 0}"
5172
5173
5174 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5175     """Make existing optional parentheses invisible or create new ones.
5176
5177     `parens_after` is a set of string leaf values immediately after which parens
5178     should be put.
5179
5180     Standardizes on visible parentheses for single-element tuples, and keeps
5181     existing visible parentheses for other tuples and generator expressions.
5182     """
5183     for pc in list_comments(node.prefix, is_endmarker=False):
5184         if pc.value in FMT_OFF:
5185             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5186             return
5187     check_lpar = False
5188     for index, child in enumerate(list(node.children)):
5189         # Fixes a bug where invisible parens are not properly stripped from
5190         # assignment statements that contain type annotations.
5191         if isinstance(child, Node) and child.type == syms.annassign:
5192             normalize_invisible_parens(child, parens_after=parens_after)
5193
5194         # Add parentheses around long tuple unpacking in assignments.
5195         if (
5196             index == 0
5197             and isinstance(child, Node)
5198             and child.type == syms.testlist_star_expr
5199         ):
5200             check_lpar = True
5201
5202         if check_lpar:
5203             if is_walrus_assignment(child):
5204                 continue
5205
5206             if child.type == syms.atom:
5207                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5208                     wrap_in_parentheses(node, child, visible=False)
5209             elif is_one_tuple(child):
5210                 wrap_in_parentheses(node, child, visible=True)
5211             elif node.type == syms.import_from:
5212                 # "import from" nodes store parentheses directly as part of
5213                 # the statement
5214                 if child.type == token.LPAR:
5215                     # make parentheses invisible
5216                     child.value = ""  # type: ignore
5217                     node.children[-1].value = ""  # type: ignore
5218                 elif child.type != token.STAR:
5219                     # insert invisible parentheses
5220                     node.insert_child(index, Leaf(token.LPAR, ""))
5221                     node.append_child(Leaf(token.RPAR, ""))
5222                 break
5223
5224             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5225                 wrap_in_parentheses(node, child, visible=False)
5226
5227         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5228
5229
5230 def normalize_fmt_off(node: Node) -> None:
5231     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5232     try_again = True
5233     while try_again:
5234         try_again = convert_one_fmt_off_pair(node)
5235
5236
5237 def convert_one_fmt_off_pair(node: Node) -> bool:
5238     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5239
5240     Returns True if a pair was converted.
5241     """
5242     for leaf in node.leaves():
5243         previous_consumed = 0
5244         for comment in list_comments(leaf.prefix, is_endmarker=False):
5245             if comment.value in FMT_OFF:
5246                 # We only want standalone comments. If there's no previous leaf or
5247                 # the previous leaf is indentation, it's a standalone comment in
5248                 # disguise.
5249                 if comment.type != STANDALONE_COMMENT:
5250                     prev = preceding_leaf(leaf)
5251                     if prev and prev.type not in WHITESPACE:
5252                         continue
5253
5254                 ignored_nodes = list(generate_ignored_nodes(leaf))
5255                 if not ignored_nodes:
5256                     continue
5257
5258                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5259                 parent = first.parent
5260                 prefix = first.prefix
5261                 first.prefix = prefix[comment.consumed :]
5262                 hidden_value = (
5263                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5264                 )
5265                 if hidden_value.endswith("\n"):
5266                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5267                     # leaf (possibly followed by a DEDENT).
5268                     hidden_value = hidden_value[:-1]
5269                 first_idx: Optional[int] = None
5270                 for ignored in ignored_nodes:
5271                     index = ignored.remove()
5272                     if first_idx is None:
5273                         first_idx = index
5274                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5275                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5276                 parent.insert_child(
5277                     first_idx,
5278                     Leaf(
5279                         STANDALONE_COMMENT,
5280                         hidden_value,
5281                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5282                     ),
5283                 )
5284                 return True
5285
5286             previous_consumed = comment.consumed
5287
5288     return False
5289
5290
5291 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5292     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5293
5294     Stops at the end of the block.
5295     """
5296     container: Optional[LN] = container_of(leaf)
5297     while container is not None and container.type != token.ENDMARKER:
5298         if is_fmt_on(container):
5299             return
5300
5301         # fix for fmt: on in children
5302         if contains_fmt_on_at_column(container, leaf.column):
5303             for child in container.children:
5304                 if contains_fmt_on_at_column(child, leaf.column):
5305                     return
5306                 yield child
5307         else:
5308             yield container
5309             container = container.next_sibling
5310
5311
5312 def is_fmt_on(container: LN) -> bool:
5313     """Determine whether formatting is switched on within a container.
5314     Determined by whether the last `# fmt:` comment is `on` or `off`.
5315     """
5316     fmt_on = False
5317     for comment in list_comments(container.prefix, is_endmarker=False):
5318         if comment.value in FMT_ON:
5319             fmt_on = True
5320         elif comment.value in FMT_OFF:
5321             fmt_on = False
5322     return fmt_on
5323
5324
5325 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5326     """Determine if children at a given column have formatting switched on."""
5327     for child in container.children:
5328         if (
5329             isinstance(child, Node)
5330             and first_leaf_column(child) == column
5331             or isinstance(child, Leaf)
5332             and child.column == column
5333         ):
5334             if is_fmt_on(child):
5335                 return True
5336
5337     return False
5338
5339
5340 def first_leaf_column(node: Node) -> Optional[int]:
5341     """Returns the column of the first leaf child of a node."""
5342     for child in node.children:
5343         if isinstance(child, Leaf):
5344             return child.column
5345     return None
5346
5347
5348 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5349     """If it's safe, make the parens in the atom `node` invisible, recursively.
5350     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5351     as they are redundant.
5352
5353     Returns whether the node should itself be wrapped in invisible parentheses.
5354
5355     """
5356     if (
5357         node.type != syms.atom
5358         or is_empty_tuple(node)
5359         or is_one_tuple(node)
5360         or (is_yield(node) and parent.type != syms.expr_stmt)
5361         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5362     ):
5363         return False
5364
5365     first = node.children[0]
5366     last = node.children[-1]
5367     if first.type == token.LPAR and last.type == token.RPAR:
5368         middle = node.children[1]
5369         # make parentheses invisible
5370         first.value = ""  # type: ignore
5371         last.value = ""  # type: ignore
5372         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5373
5374         if is_atom_with_invisible_parens(middle):
5375             # Strip the invisible parens from `middle` by replacing
5376             # it with the child in-between the invisible parens
5377             middle.replace(middle.children[1])
5378
5379         return False
5380
5381     return True
5382
5383
5384 def is_atom_with_invisible_parens(node: LN) -> bool:
5385     """Given a `LN`, determines whether it's an atom `node` with invisible
5386     parens. Useful in dedupe-ing and normalizing parens.
5387     """
5388     if isinstance(node, Leaf) or node.type != syms.atom:
5389         return False
5390
5391     first, last = node.children[0], node.children[-1]
5392     return (
5393         isinstance(first, Leaf)
5394         and first.type == token.LPAR
5395         and first.value == ""
5396         and isinstance(last, Leaf)
5397         and last.type == token.RPAR
5398         and last.value == ""
5399     )
5400
5401
5402 def is_empty_tuple(node: LN) -> bool:
5403     """Return True if `node` holds an empty tuple."""
5404     return (
5405         node.type == syms.atom
5406         and len(node.children) == 2
5407         and node.children[0].type == token.LPAR
5408         and node.children[1].type == token.RPAR
5409     )
5410
5411
5412 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5413     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5414
5415     Parenthesis can be optional. Returns None otherwise"""
5416     if len(node.children) != 3:
5417         return None
5418
5419     lpar, wrapped, rpar = node.children
5420     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5421         return None
5422
5423     return wrapped
5424
5425
5426 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5427     """Wrap `child` in parentheses.
5428
5429     This replaces `child` with an atom holding the parentheses and the old
5430     child.  That requires moving the prefix.
5431
5432     If `visible` is False, the leaves will be valueless (and thus invisible).
5433     """
5434     lpar = Leaf(token.LPAR, "(" if visible else "")
5435     rpar = Leaf(token.RPAR, ")" if visible else "")
5436     prefix = child.prefix
5437     child.prefix = ""
5438     index = child.remove() or 0
5439     new_child = Node(syms.atom, [lpar, child, rpar])
5440     new_child.prefix = prefix
5441     parent.insert_child(index, new_child)
5442
5443
5444 def is_one_tuple(node: LN) -> bool:
5445     """Return True if `node` holds a tuple with one element, with or without parens."""
5446     if node.type == syms.atom:
5447         gexp = unwrap_singleton_parenthesis(node)
5448         if gexp is None or gexp.type != syms.testlist_gexp:
5449             return False
5450
5451         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5452
5453     return (
5454         node.type in IMPLICIT_TUPLE
5455         and len(node.children) == 2
5456         and node.children[1].type == token.COMMA
5457     )
5458
5459
5460 def is_walrus_assignment(node: LN) -> bool:
5461     """Return True iff `node` is of the shape ( test := test )"""
5462     inner = unwrap_singleton_parenthesis(node)
5463     return inner is not None and inner.type == syms.namedexpr_test
5464
5465
5466 def is_yield(node: LN) -> bool:
5467     """Return True if `node` holds a `yield` or `yield from` expression."""
5468     if node.type == syms.yield_expr:
5469         return True
5470
5471     if node.type == token.NAME and node.value == "yield":  # type: ignore
5472         return True
5473
5474     if node.type != syms.atom:
5475         return False
5476
5477     if len(node.children) != 3:
5478         return False
5479
5480     lpar, expr, rpar = node.children
5481     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5482         return is_yield(expr)
5483
5484     return False
5485
5486
5487 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5488     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5489
5490     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5491     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5492     extended iterable unpacking (PEP 3132) and additional unpacking
5493     generalizations (PEP 448).
5494     """
5495     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5496         return False
5497
5498     p = leaf.parent
5499     if p.type == syms.star_expr:
5500         # Star expressions are also used as assignment targets in extended
5501         # iterable unpacking (PEP 3132).  See what its parent is instead.
5502         if not p.parent:
5503             return False
5504
5505         p = p.parent
5506
5507     return p.type in within
5508
5509
5510 def is_multiline_string(leaf: Leaf) -> bool:
5511     """Return True if `leaf` is a multiline string that actually spans many lines."""
5512     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5513
5514
5515 def is_stub_suite(node: Node) -> bool:
5516     """Return True if `node` is a suite with a stub body."""
5517     if (
5518         len(node.children) != 4
5519         or node.children[0].type != token.NEWLINE
5520         or node.children[1].type != token.INDENT
5521         or node.children[3].type != token.DEDENT
5522     ):
5523         return False
5524
5525     return is_stub_body(node.children[2])
5526
5527
5528 def is_stub_body(node: LN) -> bool:
5529     """Return True if `node` is a simple statement containing an ellipsis."""
5530     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5531         return False
5532
5533     if len(node.children) != 2:
5534         return False
5535
5536     child = node.children[0]
5537     return (
5538         child.type == syms.atom
5539         and len(child.children) == 3
5540         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5541     )
5542
5543
5544 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5545     """Return maximum delimiter priority inside `node`.
5546
5547     This is specific to atoms with contents contained in a pair of parentheses.
5548     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5549     """
5550     if node.type != syms.atom:
5551         return 0
5552
5553     first = node.children[0]
5554     last = node.children[-1]
5555     if not (first.type == token.LPAR and last.type == token.RPAR):
5556         return 0
5557
5558     bt = BracketTracker()
5559     for c in node.children[1:-1]:
5560         if isinstance(c, Leaf):
5561             bt.mark(c)
5562         else:
5563             for leaf in c.leaves():
5564                 bt.mark(leaf)
5565     try:
5566         return bt.max_delimiter_priority()
5567
5568     except ValueError:
5569         return 0
5570
5571
5572 def ensure_visible(leaf: Leaf) -> None:
5573     """Make sure parentheses are visible.
5574
5575     They could be invisible as part of some statements (see
5576     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5577     """
5578     if leaf.type == token.LPAR:
5579         leaf.value = "("
5580     elif leaf.type == token.RPAR:
5581         leaf.value = ")"
5582
5583
5584 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
5585     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
5586
5587     if not (
5588         opening_bracket.parent
5589         and opening_bracket.parent.type in {syms.atom, syms.import_from}
5590         and opening_bracket.value in "[{("
5591     ):
5592         return False
5593
5594     try:
5595         last_leaf = line.leaves[-1]
5596         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
5597         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5598     except (IndexError, ValueError):
5599         return False
5600
5601     return max_priority == COMMA_PRIORITY
5602
5603
5604 def get_features_used(node: Node) -> Set[Feature]:
5605     """Return a set of (relatively) new Python features used in this file.
5606
5607     Currently looking for:
5608     - f-strings;
5609     - underscores in numeric literals;
5610     - trailing commas after * or ** in function signatures and calls;
5611     - positional only arguments in function signatures and lambdas;
5612     """
5613     features: Set[Feature] = set()
5614     for n in node.pre_order():
5615         if n.type == token.STRING:
5616             value_head = n.value[:2]  # type: ignore
5617             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5618                 features.add(Feature.F_STRINGS)
5619
5620         elif n.type == token.NUMBER:
5621             if "_" in n.value:  # type: ignore
5622                 features.add(Feature.NUMERIC_UNDERSCORES)
5623
5624         elif n.type == token.SLASH:
5625             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5626                 features.add(Feature.POS_ONLY_ARGUMENTS)
5627
5628         elif n.type == token.COLONEQUAL:
5629             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5630
5631         elif (
5632             n.type in {syms.typedargslist, syms.arglist}
5633             and n.children
5634             and n.children[-1].type == token.COMMA
5635         ):
5636             if n.type == syms.typedargslist:
5637                 feature = Feature.TRAILING_COMMA_IN_DEF
5638             else:
5639                 feature = Feature.TRAILING_COMMA_IN_CALL
5640
5641             for ch in n.children:
5642                 if ch.type in STARS:
5643                     features.add(feature)
5644
5645                 if ch.type == syms.argument:
5646                     for argch in ch.children:
5647                         if argch.type in STARS:
5648                             features.add(feature)
5649
5650     return features
5651
5652
5653 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5654     """Detect the version to target based on the nodes used."""
5655     features = get_features_used(node)
5656     return {
5657         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5658     }
5659
5660
5661 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5662     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5663
5664     Brackets can be omitted if the entire trailer up to and including
5665     a preceding closing bracket fits in one line.
5666
5667     Yielded sets are cumulative (contain results of previous yields, too).  First
5668     set is empty.
5669     """
5670
5671     omit: Set[LeafID] = set()
5672     yield omit
5673
5674     length = 4 * line.depth
5675     opening_bracket: Optional[Leaf] = None
5676     closing_bracket: Optional[Leaf] = None
5677     inner_brackets: Set[LeafID] = set()
5678     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5679         length += leaf_length
5680         if length > line_length:
5681             break
5682
5683         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5684         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5685             break
5686
5687         if opening_bracket:
5688             if leaf is opening_bracket:
5689                 opening_bracket = None
5690             elif leaf.type in CLOSING_BRACKETS:
5691                 inner_brackets.add(id(leaf))
5692         elif leaf.type in CLOSING_BRACKETS:
5693             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
5694                 # Empty brackets would fail a split so treat them as "inner"
5695                 # brackets (e.g. only add them to the `omit` set if another
5696                 # pair of brackets was good enough.
5697                 inner_brackets.add(id(leaf))
5698                 continue
5699
5700             if closing_bracket:
5701                 omit.add(id(closing_bracket))
5702                 omit.update(inner_brackets)
5703                 inner_brackets.clear()
5704                 yield omit
5705
5706             if leaf.value:
5707                 opening_bracket = leaf.opening_bracket
5708                 closing_bracket = leaf
5709
5710
5711 def get_future_imports(node: Node) -> Set[str]:
5712     """Return a set of __future__ imports in the file."""
5713     imports: Set[str] = set()
5714
5715     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5716         for child in children:
5717             if isinstance(child, Leaf):
5718                 if child.type == token.NAME:
5719                     yield child.value
5720
5721             elif child.type == syms.import_as_name:
5722                 orig_name = child.children[0]
5723                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5724                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5725                 yield orig_name.value
5726
5727             elif child.type == syms.import_as_names:
5728                 yield from get_imports_from_children(child.children)
5729
5730             else:
5731                 raise AssertionError("Invalid syntax parsing imports")
5732
5733     for child in node.children:
5734         if child.type != syms.simple_stmt:
5735             break
5736
5737         first_child = child.children[0]
5738         if isinstance(first_child, Leaf):
5739             # Continue looking if we see a docstring; otherwise stop.
5740             if (
5741                 len(child.children) == 2
5742                 and first_child.type == token.STRING
5743                 and child.children[1].type == token.NEWLINE
5744             ):
5745                 continue
5746
5747             break
5748
5749         elif first_child.type == syms.import_from:
5750             module_name = first_child.children[1]
5751             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5752                 break
5753
5754             imports |= set(get_imports_from_children(first_child.children[3:]))
5755         else:
5756             break
5757
5758     return imports
5759
5760
5761 @lru_cache()
5762 def get_gitignore(root: Path) -> PathSpec:
5763     """ Return a PathSpec matching gitignore content if present."""
5764     gitignore = root / ".gitignore"
5765     lines: List[str] = []
5766     if gitignore.is_file():
5767         with gitignore.open() as gf:
5768             lines = gf.readlines()
5769     return PathSpec.from_lines("gitwildmatch", lines)
5770
5771
5772 def gen_python_files(
5773     paths: Iterable[Path],
5774     root: Path,
5775     include: Optional[Pattern[str]],
5776     exclude_regexes: Iterable[Pattern[str]],
5777     report: "Report",
5778     gitignore: PathSpec,
5779 ) -> Iterator[Path]:
5780     """Generate all files under `path` whose paths are not excluded by the
5781     `exclude` regex, but are included by the `include` regex.
5782
5783     Symbolic links pointing outside of the `root` directory are ignored.
5784
5785     `report` is where output about exclusions goes.
5786     """
5787     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5788     for child in paths:
5789         # Then ignore with `exclude` option.
5790         try:
5791             normalized_path = child.resolve().relative_to(root).as_posix()
5792         except OSError as e:
5793             report.path_ignored(child, f"cannot be read because {e}")
5794             continue
5795         except ValueError:
5796             if child.is_symlink():
5797                 report.path_ignored(
5798                     child, f"is a symbolic link that points outside {root}"
5799                 )
5800                 continue
5801
5802             raise
5803
5804         # First ignore files matching .gitignore
5805         if gitignore.match_file(normalized_path):
5806             report.path_ignored(child, "matches the .gitignore file content")
5807             continue
5808
5809         normalized_path = "/" + normalized_path
5810         if child.is_dir():
5811             normalized_path += "/"
5812
5813         is_excluded = False
5814         for exclude in exclude_regexes:
5815             exclude_match = exclude.search(normalized_path) if exclude else None
5816             if exclude_match and exclude_match.group(0):
5817                 report.path_ignored(child, "matches the --exclude regular expression")
5818                 is_excluded = True
5819                 break
5820         if is_excluded:
5821             continue
5822
5823         if child.is_dir():
5824             yield from gen_python_files(
5825                 child.iterdir(), root, include, exclude_regexes, report, gitignore
5826             )
5827
5828         elif child.is_file():
5829             include_match = include.search(normalized_path) if include else True
5830             if include_match:
5831                 yield child
5832
5833
5834 @lru_cache()
5835 def find_project_root(srcs: Iterable[str]) -> Path:
5836     """Return a directory containing .git, .hg, or pyproject.toml.
5837
5838     That directory can be one of the directories passed in `srcs` or their
5839     common parent.
5840
5841     If no directory in the tree contains a marker that would specify it's the
5842     project root, the root of the file system is returned.
5843     """
5844     if not srcs:
5845         return Path("/").resolve()
5846
5847     common_base = min(Path(src).resolve() for src in srcs)
5848     if common_base.is_dir():
5849         # Append a fake file so `parents` below returns `common_base_dir`, too.
5850         common_base /= "fake-file"
5851     for directory in common_base.parents:
5852         if (directory / ".git").exists():
5853             return directory
5854
5855         if (directory / ".hg").is_dir():
5856             return directory
5857
5858         if (directory / "pyproject.toml").is_file():
5859             return directory
5860
5861     return directory
5862
5863
5864 @dataclass
5865 class Report:
5866     """Provides a reformatting counter. Can be rendered with `str(report)`."""
5867
5868     check: bool = False
5869     diff: bool = False
5870     quiet: bool = False
5871     verbose: bool = False
5872     change_count: int = 0
5873     same_count: int = 0
5874     failure_count: int = 0
5875
5876     def done(self, src: Path, changed: Changed) -> None:
5877         """Increment the counter for successful reformatting. Write out a message."""
5878         if changed is Changed.YES:
5879             reformatted = "would reformat" if self.check or self.diff else "reformatted"
5880             if self.verbose or not self.quiet:
5881                 out(f"{reformatted} {src}")
5882             self.change_count += 1
5883         else:
5884             if self.verbose:
5885                 if changed is Changed.NO:
5886                     msg = f"{src} already well formatted, good job."
5887                 else:
5888                     msg = f"{src} wasn't modified on disk since last run."
5889                 out(msg, bold=False)
5890             self.same_count += 1
5891
5892     def failed(self, src: Path, message: str) -> None:
5893         """Increment the counter for failed reformatting. Write out a message."""
5894         err(f"error: cannot format {src}: {message}")
5895         self.failure_count += 1
5896
5897     def path_ignored(self, path: Path, message: str) -> None:
5898         if self.verbose:
5899             out(f"{path} ignored: {message}", bold=False)
5900
5901     @property
5902     def return_code(self) -> int:
5903         """Return the exit code that the app should use.
5904
5905         This considers the current state of changed files and failures:
5906         - if there were any failures, return 123;
5907         - if any files were changed and --check is being used, return 1;
5908         - otherwise return 0.
5909         """
5910         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
5911         # 126 we have special return codes reserved by the shell.
5912         if self.failure_count:
5913             return 123
5914
5915         elif self.change_count and self.check:
5916             return 1
5917
5918         return 0
5919
5920     def __str__(self) -> str:
5921         """Render a color report of the current state.
5922
5923         Use `click.unstyle` to remove colors.
5924         """
5925         if self.check or self.diff:
5926             reformatted = "would be reformatted"
5927             unchanged = "would be left unchanged"
5928             failed = "would fail to reformat"
5929         else:
5930             reformatted = "reformatted"
5931             unchanged = "left unchanged"
5932             failed = "failed to reformat"
5933         report = []
5934         if self.change_count:
5935             s = "s" if self.change_count > 1 else ""
5936             report.append(
5937                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
5938             )
5939         if self.same_count:
5940             s = "s" if self.same_count > 1 else ""
5941             report.append(f"{self.same_count} file{s} {unchanged}")
5942         if self.failure_count:
5943             s = "s" if self.failure_count > 1 else ""
5944             report.append(
5945                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
5946             )
5947         return ", ".join(report) + "."
5948
5949
5950 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
5951     filename = "<unknown>"
5952     if sys.version_info >= (3, 8):
5953         # TODO: support Python 4+ ;)
5954         for minor_version in range(sys.version_info[1], 4, -1):
5955             try:
5956                 return ast.parse(src, filename, feature_version=(3, minor_version))
5957             except SyntaxError:
5958                 continue
5959     else:
5960         for feature_version in (7, 6):
5961             try:
5962                 return ast3.parse(src, filename, feature_version=feature_version)
5963             except SyntaxError:
5964                 continue
5965
5966     return ast27.parse(src)
5967
5968
5969 def _fixup_ast_constants(
5970     node: Union[ast.AST, ast3.AST, ast27.AST]
5971 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
5972     """Map ast nodes deprecated in 3.8 to Constant."""
5973     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
5974         return ast.Constant(value=node.s)
5975
5976     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
5977         return ast.Constant(value=node.n)
5978
5979     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
5980         return ast.Constant(value=node.value)
5981
5982     return node
5983
5984
5985 def _stringify_ast(
5986     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
5987 ) -> Iterator[str]:
5988     """Simple visitor generating strings to compare ASTs by content."""
5989
5990     node = _fixup_ast_constants(node)
5991
5992     yield f"{'  ' * depth}{node.__class__.__name__}("
5993
5994     for field in sorted(node._fields):  # noqa: F402
5995         # TypeIgnore has only one field 'lineno' which breaks this comparison
5996         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
5997         if sys.version_info >= (3, 8):
5998             type_ignore_classes += (ast.TypeIgnore,)
5999         if isinstance(node, type_ignore_classes):
6000             break
6001
6002         try:
6003             value = getattr(node, field)
6004         except AttributeError:
6005             continue
6006
6007         yield f"{'  ' * (depth+1)}{field}="
6008
6009         if isinstance(value, list):
6010             for item in value:
6011                 # Ignore nested tuples within del statements, because we may insert
6012                 # parentheses and they change the AST.
6013                 if (
6014                     field == "targets"
6015                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6016                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6017                 ):
6018                     for item in item.elts:
6019                         yield from _stringify_ast(item, depth + 2)
6020
6021                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6022                     yield from _stringify_ast(item, depth + 2)
6023
6024         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6025             yield from _stringify_ast(value, depth + 2)
6026
6027         else:
6028             # Constant strings may be indented across newlines, if they are
6029             # docstrings; fold spaces after newlines when comparing
6030             if (
6031                 isinstance(node, ast.Constant)
6032                 and field == "value"
6033                 and isinstance(value, str)
6034             ):
6035                 normalized = re.sub(r"\n[ \t]+", "\n ", value)
6036             else:
6037                 normalized = value
6038             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6039
6040     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6041
6042
6043 def assert_equivalent(src: str, dst: str) -> None:
6044     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6045     try:
6046         src_ast = parse_ast(src)
6047     except Exception as exc:
6048         raise AssertionError(
6049             "cannot use --safe with this file; failed to parse source file.  AST"
6050             f" error message: {exc}"
6051         )
6052
6053     try:
6054         dst_ast = parse_ast(dst)
6055     except Exception as exc:
6056         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6057         raise AssertionError(
6058             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6059             " on https://github.com/psf/black/issues.  This invalid output might be"
6060             f" helpful: {log}"
6061         ) from None
6062
6063     src_ast_str = "\n".join(_stringify_ast(src_ast))
6064     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6065     if src_ast_str != dst_ast_str:
6066         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6067         raise AssertionError(
6068             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6069             " source.  Please report a bug on https://github.com/psf/black/issues. "
6070             f" This diff might be helpful: {log}"
6071         ) from None
6072
6073
6074 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6075     """Raise AssertionError if `dst` reformats differently the second time."""
6076     newdst = format_str(dst, mode=mode)
6077     if dst != newdst:
6078         log = dump_to_file(
6079             diff(src, dst, "source", "first pass"),
6080             diff(dst, newdst, "first pass", "second pass"),
6081         )
6082         raise AssertionError(
6083             "INTERNAL ERROR: Black produced different code on the second pass of the"
6084             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6085             f"  This diff might be helpful: {log}"
6086         ) from None
6087
6088
6089 @mypyc_attr(patchable=True)
6090 def dump_to_file(*output: str) -> str:
6091     """Dump `output` to a temporary file. Return path to the file."""
6092     with tempfile.NamedTemporaryFile(
6093         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6094     ) as f:
6095         for lines in output:
6096             f.write(lines)
6097             if lines and lines[-1] != "\n":
6098                 f.write("\n")
6099     return f.name
6100
6101
6102 @contextmanager
6103 def nullcontext() -> Iterator[None]:
6104     """Return an empty context manager.
6105
6106     To be used like `nullcontext` in Python 3.7.
6107     """
6108     yield
6109
6110
6111 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6112     """Return a unified diff string between strings `a` and `b`."""
6113     import difflib
6114
6115     a_lines = [line + "\n" for line in a.splitlines()]
6116     b_lines = [line + "\n" for line in b.splitlines()]
6117     return "".join(
6118         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6119     )
6120
6121
6122 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6123     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6124     err("Aborted!")
6125     for task in tasks:
6126         task.cancel()
6127
6128
6129 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6130     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6131     try:
6132         if sys.version_info[:2] >= (3, 7):
6133             all_tasks = asyncio.all_tasks
6134         else:
6135             all_tasks = asyncio.Task.all_tasks
6136         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6137         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6138         if not to_cancel:
6139             return
6140
6141         for task in to_cancel:
6142             task.cancel()
6143         loop.run_until_complete(
6144             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6145         )
6146     finally:
6147         # `concurrent.futures.Future` objects cannot be cancelled once they
6148         # are already running. There might be some when the `shutdown()` happened.
6149         # Silence their logger's spew about the event loop being closed.
6150         cf_logger = logging.getLogger("concurrent.futures")
6151         cf_logger.setLevel(logging.CRITICAL)
6152         loop.close()
6153
6154
6155 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6156     """Replace `regex` with `replacement` twice on `original`.
6157
6158     This is used by string normalization to perform replaces on
6159     overlapping matches.
6160     """
6161     return regex.sub(replacement, regex.sub(replacement, original))
6162
6163
6164 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6165     """Compile a regular expression string in `regex`.
6166
6167     If it contains newlines, use verbose mode.
6168     """
6169     if "\n" in regex:
6170         regex = "(?x)" + regex
6171     compiled: Pattern[str] = re.compile(regex)
6172     return compiled
6173
6174
6175 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6176     """Like `reversed(enumerate(sequence))` if that were possible."""
6177     index = len(sequence) - 1
6178     for element in reversed(sequence):
6179         yield (index, element)
6180         index -= 1
6181
6182
6183 def enumerate_with_length(
6184     line: Line, reversed: bool = False
6185 ) -> Iterator[Tuple[Index, Leaf, int]]:
6186     """Return an enumeration of leaves with their length.
6187
6188     Stops prematurely on multiline strings and standalone comments.
6189     """
6190     op = cast(
6191         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6192         enumerate_reversed if reversed else enumerate,
6193     )
6194     for index, leaf in op(line.leaves):
6195         length = len(leaf.prefix) + len(leaf.value)
6196         if "\n" in leaf.value:
6197             return  # Multiline strings, we can't continue.
6198
6199         for comment in line.comments_after(leaf):
6200             length += len(comment.value)
6201
6202         yield index, leaf, length
6203
6204
6205 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6206     """Return True if `line` is no longer than `line_length`.
6207
6208     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6209     """
6210     if not line_str:
6211         line_str = line_to_string(line)
6212     return (
6213         len(line_str) <= line_length
6214         and "\n" not in line_str  # multiline strings
6215         and not line.contains_standalone_comments()
6216     )
6217
6218
6219 def can_be_split(line: Line) -> bool:
6220     """Return False if the line cannot be split *for sure*.
6221
6222     This is not an exhaustive search but a cheap heuristic that we can use to
6223     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6224     in unnecessary parentheses).
6225     """
6226     leaves = line.leaves
6227     if len(leaves) < 2:
6228         return False
6229
6230     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6231         call_count = 0
6232         dot_count = 0
6233         next = leaves[-1]
6234         for leaf in leaves[-2::-1]:
6235             if leaf.type in OPENING_BRACKETS:
6236                 if next.type not in CLOSING_BRACKETS:
6237                     return False
6238
6239                 call_count += 1
6240             elif leaf.type == token.DOT:
6241                 dot_count += 1
6242             elif leaf.type == token.NAME:
6243                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6244                     return False
6245
6246             elif leaf.type not in CLOSING_BRACKETS:
6247                 return False
6248
6249             if dot_count > 1 and call_count > 1:
6250                 return False
6251
6252     return True
6253
6254
6255 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
6256     """Does `line` have a shape safe to reformat without optional parens around it?
6257
6258     Returns True for only a subset of potentially nice looking formattings but
6259     the point is to not return false positives that end up producing lines that
6260     are too long.
6261     """
6262     bt = line.bracket_tracker
6263     if not bt.delimiters:
6264         # Without delimiters the optional parentheses are useless.
6265         return True
6266
6267     max_priority = bt.max_delimiter_priority()
6268     if bt.delimiter_count_with_priority(max_priority) > 1:
6269         # With more than one delimiter of a kind the optional parentheses read better.
6270         return False
6271
6272     if max_priority == DOT_PRIORITY:
6273         # A single stranded method call doesn't require optional parentheses.
6274         return True
6275
6276     assert len(line.leaves) >= 2, "Stranded delimiter"
6277
6278     first = line.leaves[0]
6279     second = line.leaves[1]
6280     penultimate = line.leaves[-2]
6281     last = line.leaves[-1]
6282
6283     # With a single delimiter, omit if the expression starts or ends with
6284     # a bracket.
6285     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6286         remainder = False
6287         length = 4 * line.depth
6288         for _index, leaf, leaf_length in enumerate_with_length(line):
6289             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6290                 remainder = True
6291             if remainder:
6292                 length += leaf_length
6293                 if length > line_length:
6294                     break
6295
6296                 if leaf.type in OPENING_BRACKETS:
6297                     # There are brackets we can further split on.
6298                     remainder = False
6299
6300         else:
6301             # checked the entire string and line length wasn't exceeded
6302             if len(line.leaves) == _index + 1:
6303                 return True
6304
6305         # Note: we are not returning False here because a line might have *both*
6306         # a leading opening bracket and a trailing closing bracket.  If the
6307         # opening bracket doesn't match our rule, maybe the closing will.
6308
6309     if (
6310         last.type == token.RPAR
6311         or last.type == token.RBRACE
6312         or (
6313             # don't use indexing for omitting optional parentheses;
6314             # it looks weird
6315             last.type == token.RSQB
6316             and last.parent
6317             and last.parent.type != syms.trailer
6318         )
6319     ):
6320         if penultimate.type in OPENING_BRACKETS:
6321             # Empty brackets don't help.
6322             return False
6323
6324         if is_multiline_string(first):
6325             # Additional wrapping of a multiline string in this situation is
6326             # unnecessary.
6327             return True
6328
6329         length = 4 * line.depth
6330         seen_other_brackets = False
6331         for _index, leaf, leaf_length in enumerate_with_length(line):
6332             length += leaf_length
6333             if leaf is last.opening_bracket:
6334                 if seen_other_brackets or length <= line_length:
6335                     return True
6336
6337             elif leaf.type in OPENING_BRACKETS:
6338                 # There are brackets we can further split on.
6339                 seen_other_brackets = True
6340
6341     return False
6342
6343
6344 def get_cache_file(mode: Mode) -> Path:
6345     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6346
6347
6348 def read_cache(mode: Mode) -> Cache:
6349     """Read the cache if it exists and is well formed.
6350
6351     If it is not well formed, the call to write_cache later should resolve the issue.
6352     """
6353     cache_file = get_cache_file(mode)
6354     if not cache_file.exists():
6355         return {}
6356
6357     with cache_file.open("rb") as fobj:
6358         try:
6359             cache: Cache = pickle.load(fobj)
6360         except (pickle.UnpicklingError, ValueError):
6361             return {}
6362
6363     return cache
6364
6365
6366 def get_cache_info(path: Path) -> CacheInfo:
6367     """Return the information used to check if a file is already formatted or not."""
6368     stat = path.stat()
6369     return stat.st_mtime, stat.st_size
6370
6371
6372 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6373     """Split an iterable of paths in `sources` into two sets.
6374
6375     The first contains paths of files that modified on disk or are not in the
6376     cache. The other contains paths to non-modified files.
6377     """
6378     todo, done = set(), set()
6379     for src in sources:
6380         src = src.resolve()
6381         if cache.get(src) != get_cache_info(src):
6382             todo.add(src)
6383         else:
6384             done.add(src)
6385     return todo, done
6386
6387
6388 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6389     """Update the cache file."""
6390     cache_file = get_cache_file(mode)
6391     try:
6392         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6393         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6394         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6395             pickle.dump(new_cache, f, protocol=4)
6396         os.replace(f.name, cache_file)
6397     except OSError:
6398         pass
6399
6400
6401 def patch_click() -> None:
6402     """Make Click not crash.
6403
6404     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6405     default which restricts paths that it can access during the lifetime of the
6406     application.  Click refuses to work in this scenario by raising a RuntimeError.
6407
6408     In case of Black the likelihood that non-ASCII characters are going to be used in
6409     file paths is minimal since it's Python source code.  Moreover, this crash was
6410     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6411     """
6412     try:
6413         from click import core
6414         from click import _unicodefun  # type: ignore
6415     except ModuleNotFoundError:
6416         return
6417
6418     for module in (core, _unicodefun):
6419         if hasattr(module, "_verify_python3_env"):
6420             module._verify_python3_env = lambda: None
6421
6422
6423 def patched_main() -> None:
6424     freeze_support()
6425     patch_click()
6426     main()
6427
6428
6429 def fix_docstring(docstring: str, prefix: str) -> str:
6430     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6431     if not docstring:
6432         return ""
6433     # Convert tabs to spaces (following the normal Python rules)
6434     # and split into a list of lines:
6435     lines = docstring.expandtabs().splitlines()
6436     # Determine minimum indentation (first line doesn't count):
6437     indent = sys.maxsize
6438     for line in lines[1:]:
6439         stripped = line.lstrip()
6440         if stripped:
6441             indent = min(indent, len(line) - len(stripped))
6442     # Remove indentation (first line is special):
6443     trimmed = [lines[0].strip()]
6444     if indent < sys.maxsize:
6445         last_line_idx = len(lines) - 2
6446         for i, line in enumerate(lines[1:]):
6447             stripped_line = line[indent:].rstrip()
6448             if stripped_line or i == last_line_idx:
6449                 trimmed.append(prefix + stripped_line)
6450             else:
6451                 trimmed.append("")
6452     # Return a single string:
6453     return "\n".join(trimmed)
6454
6455
6456 if __name__ == "__main__":
6457     patched_main()