]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Upgrade docs to Sphinx 3+ and add doc build test (#1613)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 T = TypeVar("T")
116 E = TypeVar("E", bound=Exception)
117
118
119 class Ok(Generic[T]):
120     def __init__(self, value: T) -> None:
121         self._value = value
122
123     def ok(self) -> T:
124         return self._value
125
126
127 class Err(Generic[E]):
128     def __init__(self, e: E) -> None:
129         self._e = e
130
131     def err(self) -> E:
132         return self._e
133
134
135 # The 'Result' return type is used to implement an error-handling model heavily
136 # influenced by that used by the Rust programming language
137 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
138 Result = Union[Ok[T], Err[E]]
139 TResult = Result[T, CannotTransform]  # (T)ransform Result
140 TMatchResult = TResult[Index]
141
142
143 class WriteBack(Enum):
144     NO = 0
145     YES = 1
146     DIFF = 2
147     CHECK = 3
148     COLOR_DIFF = 4
149
150     @classmethod
151     def from_configuration(
152         cls, *, check: bool, diff: bool, color: bool = False
153     ) -> "WriteBack":
154         if check and not diff:
155             return cls.CHECK
156
157         if diff and color:
158             return cls.COLOR_DIFF
159
160         return cls.DIFF if diff else cls.YES
161
162
163 class Changed(Enum):
164     NO = 0
165     CACHED = 1
166     YES = 2
167
168
169 class TargetVersion(Enum):
170     PY27 = 2
171     PY33 = 3
172     PY34 = 4
173     PY35 = 5
174     PY36 = 6
175     PY37 = 7
176     PY38 = 8
177
178     def is_python2(self) -> bool:
179         return self is TargetVersion.PY27
180
181
182 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
183
184
185 class Feature(Enum):
186     # All string literals are unicode
187     UNICODE_LITERALS = 1
188     F_STRINGS = 2
189     NUMERIC_UNDERSCORES = 3
190     TRAILING_COMMA_IN_CALL = 4
191     TRAILING_COMMA_IN_DEF = 5
192     # The following two feature-flags are mutually exclusive, and exactly one should be
193     # set for every version of python.
194     ASYNC_IDENTIFIERS = 6
195     ASYNC_KEYWORDS = 7
196     ASSIGNMENT_EXPRESSIONS = 8
197     POS_ONLY_ARGUMENTS = 9
198
199
200 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
201     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
202     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
203     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
204     TargetVersion.PY35: {
205         Feature.UNICODE_LITERALS,
206         Feature.TRAILING_COMMA_IN_CALL,
207         Feature.ASYNC_IDENTIFIERS,
208     },
209     TargetVersion.PY36: {
210         Feature.UNICODE_LITERALS,
211         Feature.F_STRINGS,
212         Feature.NUMERIC_UNDERSCORES,
213         Feature.TRAILING_COMMA_IN_CALL,
214         Feature.TRAILING_COMMA_IN_DEF,
215         Feature.ASYNC_IDENTIFIERS,
216     },
217     TargetVersion.PY37: {
218         Feature.UNICODE_LITERALS,
219         Feature.F_STRINGS,
220         Feature.NUMERIC_UNDERSCORES,
221         Feature.TRAILING_COMMA_IN_CALL,
222         Feature.TRAILING_COMMA_IN_DEF,
223         Feature.ASYNC_KEYWORDS,
224     },
225     TargetVersion.PY38: {
226         Feature.UNICODE_LITERALS,
227         Feature.F_STRINGS,
228         Feature.NUMERIC_UNDERSCORES,
229         Feature.TRAILING_COMMA_IN_CALL,
230         Feature.TRAILING_COMMA_IN_DEF,
231         Feature.ASYNC_KEYWORDS,
232         Feature.ASSIGNMENT_EXPRESSIONS,
233         Feature.POS_ONLY_ARGUMENTS,
234     },
235 }
236
237
238 @dataclass
239 class Mode:
240     target_versions: Set[TargetVersion] = field(default_factory=set)
241     line_length: int = DEFAULT_LINE_LENGTH
242     string_normalization: bool = True
243     experimental_string_processing: bool = False
244     is_pyi: bool = False
245
246     def get_cache_key(self) -> str:
247         if self.target_versions:
248             version_str = ",".join(
249                 str(version.value)
250                 for version in sorted(self.target_versions, key=lambda v: v.value)
251             )
252         else:
253             version_str = "-"
254         parts = [
255             version_str,
256             str(self.line_length),
257             str(int(self.string_normalization)),
258             str(int(self.is_pyi)),
259         ]
260         return ".".join(parts)
261
262
263 # Legacy name, left for integrations.
264 FileMode = Mode
265
266
267 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
268     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
269
270
271 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
272     """Find the absolute filepath to a pyproject.toml if it exists"""
273     path_project_root = find_project_root(path_search_start)
274     path_pyproject_toml = path_project_root / "pyproject.toml"
275     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
276
277
278 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
279     """Parse a pyproject toml file, pulling out relevant parts for Black
280
281     If parsing fails, will raise a toml.TomlDecodeError
282     """
283     pyproject_toml = toml.load(path_config)
284     config = pyproject_toml.get("tool", {}).get("black", {})
285     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
286
287
288 def read_pyproject_toml(
289     ctx: click.Context, param: click.Parameter, value: Optional[str]
290 ) -> Optional[str]:
291     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
292
293     Returns the path to a successfully found and read configuration file, None
294     otherwise.
295     """
296     if not value:
297         value = find_pyproject_toml(ctx.params.get("src", ()))
298         if value is None:
299             return None
300
301     try:
302         config = parse_pyproject_toml(value)
303     except (toml.TomlDecodeError, OSError) as e:
304         raise click.FileError(
305             filename=value, hint=f"Error reading configuration file: {e}"
306         )
307
308     if not config:
309         return None
310     else:
311         # Sanitize the values to be Click friendly. For more information please see:
312         # https://github.com/psf/black/issues/1458
313         # https://github.com/pallets/click/issues/1567
314         config = {
315             k: str(v) if not isinstance(v, (list, dict)) else v
316             for k, v in config.items()
317         }
318
319     target_version = config.get("target_version")
320     if target_version is not None and not isinstance(target_version, list):
321         raise click.BadOptionUsage(
322             "target-version", "Config key target-version must be a list"
323         )
324
325     default_map: Dict[str, Any] = {}
326     if ctx.default_map:
327         default_map.update(ctx.default_map)
328     default_map.update(config)
329
330     ctx.default_map = default_map
331     return value
332
333
334 def target_version_option_callback(
335     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
336 ) -> List[TargetVersion]:
337     """Compute the target versions from a --target-version flag.
338
339     This is its own function because mypy couldn't infer the type correctly
340     when it was a lambda, causing mypyc trouble.
341     """
342     return [TargetVersion[val.upper()] for val in v]
343
344
345 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
346 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
347 @click.option(
348     "-l",
349     "--line-length",
350     type=int,
351     default=DEFAULT_LINE_LENGTH,
352     help="How many characters per line to allow.",
353     show_default=True,
354 )
355 @click.option(
356     "-t",
357     "--target-version",
358     type=click.Choice([v.name.lower() for v in TargetVersion]),
359     callback=target_version_option_callback,
360     multiple=True,
361     help=(
362         "Python versions that should be supported by Black's output. [default: per-file"
363         " auto-detection]"
364     ),
365 )
366 @click.option(
367     "--pyi",
368     is_flag=True,
369     help=(
370         "Format all input files like typing stubs regardless of file extension (useful"
371         " when piping source on standard input)."
372     ),
373 )
374 @click.option(
375     "-S",
376     "--skip-string-normalization",
377     is_flag=True,
378     help="Don't normalize string quotes or prefixes.",
379 )
380 @click.option(
381     "--experimental-string-processing",
382     is_flag=True,
383     hidden=True,
384     help=(
385         "Experimental option that performs more normalization on string literals."
386         " Currently disabled because it leads to some crashes."
387     ),
388 )
389 @click.option(
390     "--check",
391     is_flag=True,
392     help=(
393         "Don't write the files back, just return the status.  Return code 0 means"
394         " nothing would change.  Return code 1 means some files would be reformatted."
395         " Return code 123 means there was an internal error."
396     ),
397 )
398 @click.option(
399     "--diff",
400     is_flag=True,
401     help="Don't write the files back, just output a diff for each file on stdout.",
402 )
403 @click.option(
404     "--color/--no-color",
405     is_flag=True,
406     help="Show colored diff. Only applies when `--diff` is given.",
407 )
408 @click.option(
409     "--fast/--safe",
410     is_flag=True,
411     help="If --fast given, skip temporary sanity checks. [default: --safe]",
412 )
413 @click.option(
414     "--include",
415     type=str,
416     default=DEFAULT_INCLUDES,
417     help=(
418         "A regular expression that matches files and directories that should be"
419         " included on recursive searches.  An empty value means all files are included"
420         " regardless of the name.  Use forward slashes for directories on all platforms"
421         " (Windows, too).  Exclusions are calculated first, inclusions later."
422     ),
423     show_default=True,
424 )
425 @click.option(
426     "--exclude",
427     type=str,
428     default=DEFAULT_EXCLUDES,
429     help=(
430         "A regular expression that matches files and directories that should be"
431         " excluded on recursive searches.  An empty value means no paths are excluded."
432         " Use forward slashes for directories on all platforms (Windows, too). "
433         " Exclusions are calculated first, inclusions later."
434     ),
435     show_default=True,
436 )
437 @click.option(
438     "--force-exclude",
439     type=str,
440     help=(
441         "Like --exclude, but files and directories matching this regex will be "
442         "excluded even when they are passed explicitly as arguments"
443     ),
444 )
445 @click.option(
446     "-q",
447     "--quiet",
448     is_flag=True,
449     help=(
450         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
451         " those with 2>/dev/null."
452     ),
453 )
454 @click.option(
455     "-v",
456     "--verbose",
457     is_flag=True,
458     help=(
459         "Also emit messages to stderr about files that were not changed or were ignored"
460         " due to --exclude=."
461     ),
462 )
463 @click.version_option(version=__version__)
464 @click.argument(
465     "src",
466     nargs=-1,
467     type=click.Path(
468         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
469     ),
470     is_eager=True,
471 )
472 @click.option(
473     "--config",
474     type=click.Path(
475         exists=True,
476         file_okay=True,
477         dir_okay=False,
478         readable=True,
479         allow_dash=False,
480         path_type=str,
481     ),
482     is_eager=True,
483     callback=read_pyproject_toml,
484     help="Read configuration from FILE path.",
485 )
486 @click.pass_context
487 def main(
488     ctx: click.Context,
489     code: Optional[str],
490     line_length: int,
491     target_version: List[TargetVersion],
492     check: bool,
493     diff: bool,
494     color: bool,
495     fast: bool,
496     pyi: bool,
497     skip_string_normalization: bool,
498     experimental_string_processing: bool,
499     quiet: bool,
500     verbose: bool,
501     include: str,
502     exclude: str,
503     force_exclude: Optional[str],
504     src: Tuple[str, ...],
505     config: Optional[str],
506 ) -> None:
507     """The uncompromising code formatter."""
508     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
509     if target_version:
510         versions = set(target_version)
511     else:
512         # We'll autodetect later.
513         versions = set()
514     mode = Mode(
515         target_versions=versions,
516         line_length=line_length,
517         is_pyi=pyi,
518         string_normalization=not skip_string_normalization,
519         experimental_string_processing=experimental_string_processing,
520     )
521     if config and verbose:
522         out(f"Using configuration from {config}.", bold=False, fg="blue")
523     if code is not None:
524         print(format_str(code, mode=mode))
525         ctx.exit(0)
526     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
527     sources = get_sources(
528         ctx=ctx,
529         src=src,
530         quiet=quiet,
531         verbose=verbose,
532         include=include,
533         exclude=exclude,
534         force_exclude=force_exclude,
535         report=report,
536     )
537
538     path_empty(
539         sources,
540         "No Python files are present to be formatted. Nothing to do 😴",
541         quiet,
542         verbose,
543         ctx,
544     )
545
546     if len(sources) == 1:
547         reformat_one(
548             src=sources.pop(),
549             fast=fast,
550             write_back=write_back,
551             mode=mode,
552             report=report,
553         )
554     else:
555         reformat_many(
556             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
557         )
558
559     if verbose or not quiet:
560         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
561         click.secho(str(report), err=True)
562     ctx.exit(report.return_code)
563
564
565 def get_sources(
566     *,
567     ctx: click.Context,
568     src: Tuple[str, ...],
569     quiet: bool,
570     verbose: bool,
571     include: str,
572     exclude: str,
573     force_exclude: Optional[str],
574     report: "Report",
575 ) -> Set[Path]:
576     """Compute the set of files to be formatted."""
577     try:
578         include_regex = re_compile_maybe_verbose(include)
579     except re.error:
580         err(f"Invalid regular expression for include given: {include!r}")
581         ctx.exit(2)
582     try:
583         exclude_regex = re_compile_maybe_verbose(exclude)
584     except re.error:
585         err(f"Invalid regular expression for exclude given: {exclude!r}")
586         ctx.exit(2)
587     try:
588         force_exclude_regex = (
589             re_compile_maybe_verbose(force_exclude) if force_exclude else None
590         )
591     except re.error:
592         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
593         ctx.exit(2)
594
595     root = find_project_root(src)
596     sources: Set[Path] = set()
597     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
598     gitignore = get_gitignore(root)
599
600     for s in src:
601         p = Path(s)
602         if p.is_dir():
603             sources.update(
604                 gen_python_files(
605                     p.iterdir(),
606                     root,
607                     include_regex,
608                     exclude_regex,
609                     force_exclude_regex,
610                     report,
611                     gitignore,
612                 )
613             )
614         elif s == "-":
615             sources.add(p)
616         elif p.is_file():
617             normalized_path = normalize_path_maybe_ignore(p, root, report)
618             if normalized_path is None:
619                 continue
620
621             normalized_path = "/" + normalized_path
622             # Hard-exclude any files that matches the `--force-exclude` regex.
623             if force_exclude_regex:
624                 force_exclude_match = force_exclude_regex.search(normalized_path)
625             else:
626                 force_exclude_match = None
627             if force_exclude_match and force_exclude_match.group(0):
628                 report.path_ignored(p, "matches the --force-exclude regular expression")
629                 continue
630
631             sources.add(p)
632         else:
633             err(f"invalid path: {s}")
634     return sources
635
636
637 def path_empty(
638     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
639 ) -> None:
640     """
641     Exit if there is no `src` provided for formatting
642     """
643     if len(src) == 0:
644         if verbose or not quiet:
645             out(msg)
646             ctx.exit(0)
647
648
649 def reformat_one(
650     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
651 ) -> None:
652     """Reformat a single file under `src` without spawning child processes.
653
654     `fast`, `write_back`, and `mode` options are passed to
655     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
656     """
657     try:
658         changed = Changed.NO
659         if not src.is_file() and str(src) == "-":
660             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
661                 changed = Changed.YES
662         else:
663             cache: Cache = {}
664             if write_back != WriteBack.DIFF:
665                 cache = read_cache(mode)
666                 res_src = src.resolve()
667                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
668                     changed = Changed.CACHED
669             if changed is not Changed.CACHED and format_file_in_place(
670                 src, fast=fast, write_back=write_back, mode=mode
671             ):
672                 changed = Changed.YES
673             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
674                 write_back is WriteBack.CHECK and changed is Changed.NO
675             ):
676                 write_cache(cache, [src], mode)
677         report.done(src, changed)
678     except Exception as exc:
679         if report.verbose:
680             traceback.print_exc()
681         report.failed(src, str(exc))
682
683
684 def reformat_many(
685     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
686 ) -> None:
687     """Reformat multiple files using a ProcessPoolExecutor."""
688     executor: Executor
689     loop = asyncio.get_event_loop()
690     worker_count = os.cpu_count()
691     if sys.platform == "win32":
692         # Work around https://bugs.python.org/issue26903
693         worker_count = min(worker_count, 61)
694     try:
695         executor = ProcessPoolExecutor(max_workers=worker_count)
696     except (ImportError, OSError):
697         # we arrive here if the underlying system does not support multi-processing
698         # like in AWS Lambda or Termux, in which case we gracefully fallback to
699         # a ThreadPollExecutor with just a single worker (more workers would not do us
700         # any good due to the Global Interpreter Lock)
701         executor = ThreadPoolExecutor(max_workers=1)
702
703     try:
704         loop.run_until_complete(
705             schedule_formatting(
706                 sources=sources,
707                 fast=fast,
708                 write_back=write_back,
709                 mode=mode,
710                 report=report,
711                 loop=loop,
712                 executor=executor,
713             )
714         )
715     finally:
716         shutdown(loop)
717         if executor is not None:
718             executor.shutdown()
719
720
721 async def schedule_formatting(
722     sources: Set[Path],
723     fast: bool,
724     write_back: WriteBack,
725     mode: Mode,
726     report: "Report",
727     loop: asyncio.AbstractEventLoop,
728     executor: Executor,
729 ) -> None:
730     """Run formatting of `sources` in parallel using the provided `executor`.
731
732     (Use ProcessPoolExecutors for actual parallelism.)
733
734     `write_back`, `fast`, and `mode` options are passed to
735     :func:`format_file_in_place`.
736     """
737     cache: Cache = {}
738     if write_back != WriteBack.DIFF:
739         cache = read_cache(mode)
740         sources, cached = filter_cached(cache, sources)
741         for src in sorted(cached):
742             report.done(src, Changed.CACHED)
743     if not sources:
744         return
745
746     cancelled = []
747     sources_to_cache = []
748     lock = None
749     if write_back == WriteBack.DIFF:
750         # For diff output, we need locks to ensure we don't interleave output
751         # from different processes.
752         manager = Manager()
753         lock = manager.Lock()
754     tasks = {
755         asyncio.ensure_future(
756             loop.run_in_executor(
757                 executor, format_file_in_place, src, fast, mode, write_back, lock
758             )
759         ): src
760         for src in sorted(sources)
761     }
762     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
763     try:
764         loop.add_signal_handler(signal.SIGINT, cancel, pending)
765         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
766     except NotImplementedError:
767         # There are no good alternatives for these on Windows.
768         pass
769     while pending:
770         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
771         for task in done:
772             src = tasks.pop(task)
773             if task.cancelled():
774                 cancelled.append(task)
775             elif task.exception():
776                 report.failed(src, str(task.exception()))
777             else:
778                 changed = Changed.YES if task.result() else Changed.NO
779                 # If the file was written back or was successfully checked as
780                 # well-formatted, store this information in the cache.
781                 if write_back is WriteBack.YES or (
782                     write_back is WriteBack.CHECK and changed is Changed.NO
783                 ):
784                     sources_to_cache.append(src)
785                 report.done(src, changed)
786     if cancelled:
787         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
788     if sources_to_cache:
789         write_cache(cache, sources_to_cache, mode)
790
791
792 def format_file_in_place(
793     src: Path,
794     fast: bool,
795     mode: Mode,
796     write_back: WriteBack = WriteBack.NO,
797     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
798 ) -> bool:
799     """Format file under `src` path. Return True if changed.
800
801     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
802     code to the file.
803     `mode` and `fast` options are passed to :func:`format_file_contents`.
804     """
805     if src.suffix == ".pyi":
806         mode = replace(mode, is_pyi=True)
807
808     then = datetime.utcfromtimestamp(src.stat().st_mtime)
809     with open(src, "rb") as buf:
810         src_contents, encoding, newline = decode_bytes(buf.read())
811     try:
812         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
813     except NothingChanged:
814         return False
815
816     if write_back == WriteBack.YES:
817         with open(src, "w", encoding=encoding, newline=newline) as f:
818             f.write(dst_contents)
819     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
820         now = datetime.utcnow()
821         src_name = f"{src}\t{then} +0000"
822         dst_name = f"{src}\t{now} +0000"
823         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
824
825         if write_back == write_back.COLOR_DIFF:
826             diff_contents = color_diff(diff_contents)
827
828         with lock or nullcontext():
829             f = io.TextIOWrapper(
830                 sys.stdout.buffer,
831                 encoding=encoding,
832                 newline=newline,
833                 write_through=True,
834             )
835             f = wrap_stream_for_windows(f)
836             f.write(diff_contents)
837             f.detach()
838
839     return True
840
841
842 def color_diff(contents: str) -> str:
843     """Inject the ANSI color codes to the diff."""
844     lines = contents.split("\n")
845     for i, line in enumerate(lines):
846         if line.startswith("+++") or line.startswith("---"):
847             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
848         if line.startswith("@@"):
849             line = "\033[36m" + line + "\033[0m"  # cyan, reset
850         if line.startswith("+"):
851             line = "\033[32m" + line + "\033[0m"  # green, reset
852         elif line.startswith("-"):
853             line = "\033[31m" + line + "\033[0m"  # red, reset
854         lines[i] = line
855     return "\n".join(lines)
856
857
858 def wrap_stream_for_windows(
859     f: io.TextIOWrapper,
860 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
861     """
862     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
863
864     If `colorama` is not found, then no change is made. If `colorama` does
865     exist, then it handles the logic to determine whether or not to change
866     things.
867     """
868     try:
869         from colorama import initialise
870
871         # We set `strip=False` so that we can don't have to modify
872         # test_express_diff_with_color.
873         f = initialise.wrap_stream(
874             f, convert=None, strip=False, autoreset=False, wrap=True
875         )
876
877         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
878         # which does not have a `detach()` method. So we fake one.
879         f.detach = lambda *args, **kwargs: None  # type: ignore
880     except ImportError:
881         pass
882
883     return f
884
885
886 def format_stdin_to_stdout(
887     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
888 ) -> bool:
889     """Format file on stdin. Return True if changed.
890
891     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
892     write a diff to stdout. The `mode` argument is passed to
893     :func:`format_file_contents`.
894     """
895     then = datetime.utcnow()
896     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
897     dst = src
898     try:
899         dst = format_file_contents(src, fast=fast, mode=mode)
900         return True
901
902     except NothingChanged:
903         return False
904
905     finally:
906         f = io.TextIOWrapper(
907             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
908         )
909         if write_back == WriteBack.YES:
910             f.write(dst)
911         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
912             now = datetime.utcnow()
913             src_name = f"STDIN\t{then} +0000"
914             dst_name = f"STDOUT\t{now} +0000"
915             d = diff(src, dst, src_name, dst_name)
916             if write_back == WriteBack.COLOR_DIFF:
917                 d = color_diff(d)
918                 f = wrap_stream_for_windows(f)
919             f.write(d)
920         f.detach()
921
922
923 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
924     """Reformat contents a file and return new contents.
925
926     If `fast` is False, additionally confirm that the reformatted code is
927     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
928     `mode` is passed to :func:`format_str`.
929     """
930     if src_contents.strip() == "":
931         raise NothingChanged
932
933     dst_contents = format_str(src_contents, mode=mode)
934     if src_contents == dst_contents:
935         raise NothingChanged
936
937     if not fast:
938         assert_equivalent(src_contents, dst_contents)
939         assert_stable(src_contents, dst_contents, mode=mode)
940     return dst_contents
941
942
943 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
944     """Reformat a string and return new contents.
945
946     `mode` determines formatting options, such as how many characters per line are
947     allowed.  Example:
948
949     >>> import black
950     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
951     def f(arg: str = "") -> None:
952         ...
953
954     A more complex example:
955
956     >>> print(
957     ...   black.format_str(
958     ...     "def f(arg:str='')->None: hey",
959     ...     mode=black.Mode(
960     ...       target_versions={black.TargetVersion.PY36},
961     ...       line_length=10,
962     ...       string_normalization=False,
963     ...       is_pyi=False,
964     ...     ),
965     ...   ),
966     ... )
967     def f(
968         arg: str = '',
969     ) -> None:
970         hey
971
972     """
973     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
974     dst_contents = []
975     future_imports = get_future_imports(src_node)
976     if mode.target_versions:
977         versions = mode.target_versions
978     else:
979         versions = detect_target_versions(src_node)
980     normalize_fmt_off(src_node)
981     lines = LineGenerator(
982         remove_u_prefix="unicode_literals" in future_imports
983         or supports_feature(versions, Feature.UNICODE_LITERALS),
984         is_pyi=mode.is_pyi,
985         normalize_strings=mode.string_normalization,
986     )
987     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
988     empty_line = Line()
989     after = 0
990     split_line_features = {
991         feature
992         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
993         if supports_feature(versions, feature)
994     }
995     for current_line in lines.visit(src_node):
996         dst_contents.append(str(empty_line) * after)
997         before, after = elt.maybe_empty_lines(current_line)
998         dst_contents.append(str(empty_line) * before)
999         for line in transform_line(
1000             current_line, mode=mode, features=split_line_features
1001         ):
1002             dst_contents.append(str(line))
1003     return "".join(dst_contents)
1004
1005
1006 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1007     """Return a tuple of (decoded_contents, encoding, newline).
1008
1009     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1010     universal newlines (i.e. only contains LF).
1011     """
1012     srcbuf = io.BytesIO(src)
1013     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1014     if not lines:
1015         return "", encoding, "\n"
1016
1017     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1018     srcbuf.seek(0)
1019     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1020         return tiow.read(), encoding, newline
1021
1022
1023 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1024     if not target_versions:
1025         # No target_version specified, so try all grammars.
1026         return [
1027             # Python 3.7+
1028             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1029             # Python 3.0-3.6
1030             pygram.python_grammar_no_print_statement_no_exec_statement,
1031             # Python 2.7 with future print_function import
1032             pygram.python_grammar_no_print_statement,
1033             # Python 2.7
1034             pygram.python_grammar,
1035         ]
1036
1037     if all(version.is_python2() for version in target_versions):
1038         # Python 2-only code, so try Python 2 grammars.
1039         return [
1040             # Python 2.7 with future print_function import
1041             pygram.python_grammar_no_print_statement,
1042             # Python 2.7
1043             pygram.python_grammar,
1044         ]
1045
1046     # Python 3-compatible code, so only try Python 3 grammar.
1047     grammars = []
1048     # If we have to parse both, try to parse async as a keyword first
1049     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1050         # Python 3.7+
1051         grammars.append(
1052             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1053         )
1054     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1055         # Python 3.0-3.6
1056         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1057     # At least one of the above branches must have been taken, because every Python
1058     # version has exactly one of the two 'ASYNC_*' flags
1059     return grammars
1060
1061
1062 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1063     """Given a string with source, return the lib2to3 Node."""
1064     if src_txt[-1:] != "\n":
1065         src_txt += "\n"
1066
1067     for grammar in get_grammars(set(target_versions)):
1068         drv = driver.Driver(grammar, pytree.convert)
1069         try:
1070             result = drv.parse_string(src_txt, True)
1071             break
1072
1073         except ParseError as pe:
1074             lineno, column = pe.context[1]
1075             lines = src_txt.splitlines()
1076             try:
1077                 faulty_line = lines[lineno - 1]
1078             except IndexError:
1079                 faulty_line = "<line number missing in source>"
1080             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1081     else:
1082         raise exc from None
1083
1084     if isinstance(result, Leaf):
1085         result = Node(syms.file_input, [result])
1086     return result
1087
1088
1089 def lib2to3_unparse(node: Node) -> str:
1090     """Given a lib2to3 node, return its string representation."""
1091     code = str(node)
1092     return code
1093
1094
1095 class Visitor(Generic[T]):
1096     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1097
1098     def visit(self, node: LN) -> Iterator[T]:
1099         """Main method to visit `node` and its children.
1100
1101         It tries to find a `visit_*()` method for the given `node.type`, like
1102         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1103         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1104         instead.
1105
1106         Then yields objects of type `T` from the selected visitor.
1107         """
1108         if node.type < 256:
1109             name = token.tok_name[node.type]
1110         else:
1111             name = str(type_repr(node.type))
1112         # We explicitly branch on whether a visitor exists (instead of
1113         # using self.visit_default as the default arg to getattr) in order
1114         # to save needing to create a bound method object and so mypyc can
1115         # generate a native call to visit_default.
1116         visitf = getattr(self, f"visit_{name}", None)
1117         if visitf:
1118             yield from visitf(node)
1119         else:
1120             yield from self.visit_default(node)
1121
1122     def visit_default(self, node: LN) -> Iterator[T]:
1123         """Default `visit_*()` implementation. Recurses to children of `node`."""
1124         if isinstance(node, Node):
1125             for child in node.children:
1126                 yield from self.visit(child)
1127
1128
1129 @dataclass
1130 class DebugVisitor(Visitor[T]):
1131     tree_depth: int = 0
1132
1133     def visit_default(self, node: LN) -> Iterator[T]:
1134         indent = " " * (2 * self.tree_depth)
1135         if isinstance(node, Node):
1136             _type = type_repr(node.type)
1137             out(f"{indent}{_type}", fg="yellow")
1138             self.tree_depth += 1
1139             for child in node.children:
1140                 yield from self.visit(child)
1141
1142             self.tree_depth -= 1
1143             out(f"{indent}/{_type}", fg="yellow", bold=False)
1144         else:
1145             _type = token.tok_name.get(node.type, str(node.type))
1146             out(f"{indent}{_type}", fg="blue", nl=False)
1147             if node.prefix:
1148                 # We don't have to handle prefixes for `Node` objects since
1149                 # that delegates to the first child anyway.
1150                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1151             out(f" {node.value!r}", fg="blue", bold=False)
1152
1153     @classmethod
1154     def show(cls, code: Union[str, Leaf, Node]) -> None:
1155         """Pretty-print the lib2to3 AST of a given string of `code`.
1156
1157         Convenience method for debugging.
1158         """
1159         v: DebugVisitor[None] = DebugVisitor()
1160         if isinstance(code, str):
1161             code = lib2to3_parse(code)
1162         list(v.visit(code))
1163
1164
1165 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1166 STATEMENT: Final = {
1167     syms.if_stmt,
1168     syms.while_stmt,
1169     syms.for_stmt,
1170     syms.try_stmt,
1171     syms.except_clause,
1172     syms.with_stmt,
1173     syms.funcdef,
1174     syms.classdef,
1175 }
1176 STANDALONE_COMMENT: Final = 153
1177 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1178 LOGIC_OPERATORS: Final = {"and", "or"}
1179 COMPARATORS: Final = {
1180     token.LESS,
1181     token.GREATER,
1182     token.EQEQUAL,
1183     token.NOTEQUAL,
1184     token.LESSEQUAL,
1185     token.GREATEREQUAL,
1186 }
1187 MATH_OPERATORS: Final = {
1188     token.VBAR,
1189     token.CIRCUMFLEX,
1190     token.AMPER,
1191     token.LEFTSHIFT,
1192     token.RIGHTSHIFT,
1193     token.PLUS,
1194     token.MINUS,
1195     token.STAR,
1196     token.SLASH,
1197     token.DOUBLESLASH,
1198     token.PERCENT,
1199     token.AT,
1200     token.TILDE,
1201     token.DOUBLESTAR,
1202 }
1203 STARS: Final = {token.STAR, token.DOUBLESTAR}
1204 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1205 VARARGS_PARENTS: Final = {
1206     syms.arglist,
1207     syms.argument,  # double star in arglist
1208     syms.trailer,  # single argument to call
1209     syms.typedargslist,
1210     syms.varargslist,  # lambdas
1211 }
1212 UNPACKING_PARENTS: Final = {
1213     syms.atom,  # single element of a list or set literal
1214     syms.dictsetmaker,
1215     syms.listmaker,
1216     syms.testlist_gexp,
1217     syms.testlist_star_expr,
1218 }
1219 TEST_DESCENDANTS: Final = {
1220     syms.test,
1221     syms.lambdef,
1222     syms.or_test,
1223     syms.and_test,
1224     syms.not_test,
1225     syms.comparison,
1226     syms.star_expr,
1227     syms.expr,
1228     syms.xor_expr,
1229     syms.and_expr,
1230     syms.shift_expr,
1231     syms.arith_expr,
1232     syms.trailer,
1233     syms.term,
1234     syms.power,
1235 }
1236 ASSIGNMENTS: Final = {
1237     "=",
1238     "+=",
1239     "-=",
1240     "*=",
1241     "@=",
1242     "/=",
1243     "%=",
1244     "&=",
1245     "|=",
1246     "^=",
1247     "<<=",
1248     ">>=",
1249     "**=",
1250     "//=",
1251 }
1252 COMPREHENSION_PRIORITY: Final = 20
1253 COMMA_PRIORITY: Final = 18
1254 TERNARY_PRIORITY: Final = 16
1255 LOGIC_PRIORITY: Final = 14
1256 STRING_PRIORITY: Final = 12
1257 COMPARATOR_PRIORITY: Final = 10
1258 MATH_PRIORITIES: Final = {
1259     token.VBAR: 9,
1260     token.CIRCUMFLEX: 8,
1261     token.AMPER: 7,
1262     token.LEFTSHIFT: 6,
1263     token.RIGHTSHIFT: 6,
1264     token.PLUS: 5,
1265     token.MINUS: 5,
1266     token.STAR: 4,
1267     token.SLASH: 4,
1268     token.DOUBLESLASH: 4,
1269     token.PERCENT: 4,
1270     token.AT: 4,
1271     token.TILDE: 3,
1272     token.DOUBLESTAR: 2,
1273 }
1274 DOT_PRIORITY: Final = 1
1275
1276
1277 @dataclass
1278 class BracketTracker:
1279     """Keeps track of brackets on a line."""
1280
1281     depth: int = 0
1282     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1283     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1284     previous: Optional[Leaf] = None
1285     _for_loop_depths: List[int] = field(default_factory=list)
1286     _lambda_argument_depths: List[int] = field(default_factory=list)
1287
1288     def mark(self, leaf: Leaf) -> None:
1289         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1290
1291         All leaves receive an int `bracket_depth` field that stores how deep
1292         within brackets a given leaf is. 0 means there are no enclosing brackets
1293         that started on this line.
1294
1295         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1296         field that it forms a pair with. This is a one-directional link to
1297         avoid reference cycles.
1298
1299         If a leaf is a delimiter (a token on which Black can split the line if
1300         needed) and it's on depth 0, its `id()` is stored in the tracker's
1301         `delimiters` field.
1302         """
1303         if leaf.type == token.COMMENT:
1304             return
1305
1306         self.maybe_decrement_after_for_loop_variable(leaf)
1307         self.maybe_decrement_after_lambda_arguments(leaf)
1308         if leaf.type in CLOSING_BRACKETS:
1309             self.depth -= 1
1310             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1311             leaf.opening_bracket = opening_bracket
1312         leaf.bracket_depth = self.depth
1313         if self.depth == 0:
1314             delim = is_split_before_delimiter(leaf, self.previous)
1315             if delim and self.previous is not None:
1316                 self.delimiters[id(self.previous)] = delim
1317             else:
1318                 delim = is_split_after_delimiter(leaf, self.previous)
1319                 if delim:
1320                     self.delimiters[id(leaf)] = delim
1321         if leaf.type in OPENING_BRACKETS:
1322             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1323             self.depth += 1
1324         self.previous = leaf
1325         self.maybe_increment_lambda_arguments(leaf)
1326         self.maybe_increment_for_loop_variable(leaf)
1327
1328     def any_open_brackets(self) -> bool:
1329         """Return True if there is an yet unmatched open bracket on the line."""
1330         return bool(self.bracket_match)
1331
1332     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1333         """Return the highest priority of a delimiter found on the line.
1334
1335         Values are consistent with what `is_split_*_delimiter()` return.
1336         Raises ValueError on no delimiters.
1337         """
1338         return max(v for k, v in self.delimiters.items() if k not in exclude)
1339
1340     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1341         """Return the number of delimiters with the given `priority`.
1342
1343         If no `priority` is passed, defaults to max priority on the line.
1344         """
1345         if not self.delimiters:
1346             return 0
1347
1348         priority = priority or self.max_delimiter_priority()
1349         return sum(1 for p in self.delimiters.values() if p == priority)
1350
1351     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1352         """In a for loop, or comprehension, the variables are often unpacks.
1353
1354         To avoid splitting on the comma in this situation, increase the depth of
1355         tokens between `for` and `in`.
1356         """
1357         if leaf.type == token.NAME and leaf.value == "for":
1358             self.depth += 1
1359             self._for_loop_depths.append(self.depth)
1360             return True
1361
1362         return False
1363
1364     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1365         """See `maybe_increment_for_loop_variable` above for explanation."""
1366         if (
1367             self._for_loop_depths
1368             and self._for_loop_depths[-1] == self.depth
1369             and leaf.type == token.NAME
1370             and leaf.value == "in"
1371         ):
1372             self.depth -= 1
1373             self._for_loop_depths.pop()
1374             return True
1375
1376         return False
1377
1378     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1379         """In a lambda expression, there might be more than one argument.
1380
1381         To avoid splitting on the comma in this situation, increase the depth of
1382         tokens between `lambda` and `:`.
1383         """
1384         if leaf.type == token.NAME and leaf.value == "lambda":
1385             self.depth += 1
1386             self._lambda_argument_depths.append(self.depth)
1387             return True
1388
1389         return False
1390
1391     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1392         """See `maybe_increment_lambda_arguments` above for explanation."""
1393         if (
1394             self._lambda_argument_depths
1395             and self._lambda_argument_depths[-1] == self.depth
1396             and leaf.type == token.COLON
1397         ):
1398             self.depth -= 1
1399             self._lambda_argument_depths.pop()
1400             return True
1401
1402         return False
1403
1404     def get_open_lsqb(self) -> Optional[Leaf]:
1405         """Return the most recent opening square bracket (if any)."""
1406         return self.bracket_match.get((self.depth - 1, token.RSQB))
1407
1408
1409 @dataclass
1410 class Line:
1411     """Holds leaves and comments. Can be printed with `str(line)`."""
1412
1413     depth: int = 0
1414     leaves: List[Leaf] = field(default_factory=list)
1415     # keys ordered like `leaves`
1416     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1417     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1418     inside_brackets: bool = False
1419     should_explode: bool = False
1420
1421     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1422         """Add a new `leaf` to the end of the line.
1423
1424         Unless `preformatted` is True, the `leaf` will receive a new consistent
1425         whitespace prefix and metadata applied by :class:`BracketTracker`.
1426         Trailing commas are maybe removed, unpacked for loop variables are
1427         demoted from being delimiters.
1428
1429         Inline comments are put aside.
1430         """
1431         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1432         if not has_value:
1433             return
1434
1435         if token.COLON == leaf.type and self.is_class_paren_empty:
1436             del self.leaves[-2:]
1437         if self.leaves and not preformatted:
1438             # Note: at this point leaf.prefix should be empty except for
1439             # imports, for which we only preserve newlines.
1440             leaf.prefix += whitespace(
1441                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1442             )
1443         if self.inside_brackets or not preformatted:
1444             self.bracket_tracker.mark(leaf)
1445             self.maybe_remove_trailing_comma(leaf)
1446         if not self.append_comment(leaf):
1447             self.leaves.append(leaf)
1448
1449     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1450         """Like :func:`append()` but disallow invalid standalone comment structure.
1451
1452         Raises ValueError when any `leaf` is appended after a standalone comment
1453         or when a standalone comment is not the first leaf on the line.
1454         """
1455         if self.bracket_tracker.depth == 0:
1456             if self.is_comment:
1457                 raise ValueError("cannot append to standalone comments")
1458
1459             if self.leaves and leaf.type == STANDALONE_COMMENT:
1460                 raise ValueError(
1461                     "cannot append standalone comments to a populated line"
1462                 )
1463
1464         self.append(leaf, preformatted=preformatted)
1465
1466     @property
1467     def is_comment(self) -> bool:
1468         """Is this line a standalone comment?"""
1469         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1470
1471     @property
1472     def is_decorator(self) -> bool:
1473         """Is this line a decorator?"""
1474         return bool(self) and self.leaves[0].type == token.AT
1475
1476     @property
1477     def is_import(self) -> bool:
1478         """Is this an import line?"""
1479         return bool(self) and is_import(self.leaves[0])
1480
1481     @property
1482     def is_class(self) -> bool:
1483         """Is this line a class definition?"""
1484         return (
1485             bool(self)
1486             and self.leaves[0].type == token.NAME
1487             and self.leaves[0].value == "class"
1488         )
1489
1490     @property
1491     def is_stub_class(self) -> bool:
1492         """Is this line a class definition with a body consisting only of "..."?"""
1493         return self.is_class and self.leaves[-3:] == [
1494             Leaf(token.DOT, ".") for _ in range(3)
1495         ]
1496
1497     @property
1498     def is_collection_with_optional_trailing_comma(self) -> bool:
1499         """Is this line a collection literal with a trailing comma that's optional?
1500
1501         Note that the trailing comma in a 1-tuple is not optional.
1502         """
1503         if not self.leaves or len(self.leaves) < 4:
1504             return False
1505
1506         # Look for and address a trailing colon.
1507         if self.leaves[-1].type == token.COLON:
1508             closer = self.leaves[-2]
1509             close_index = -2
1510         else:
1511             closer = self.leaves[-1]
1512             close_index = -1
1513         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1514             return False
1515
1516         if closer.type == token.RPAR:
1517             # Tuples require an extra check, because if there's only
1518             # one element in the tuple removing the comma unmakes the
1519             # tuple.
1520             #
1521             # We also check for parens before looking for the trailing
1522             # comma because in some cases (eg assigning a dict
1523             # literal) the literal gets wrapped in temporary parens
1524             # during parsing. This case is covered by the
1525             # collections.py test data.
1526             opener = closer.opening_bracket
1527             for _open_index, leaf in enumerate(self.leaves):
1528                 if leaf is opener:
1529                     break
1530
1531             else:
1532                 # Couldn't find the matching opening paren, play it safe.
1533                 return False
1534
1535             commas = 0
1536             comma_depth = self.leaves[close_index - 1].bracket_depth
1537             for leaf in self.leaves[_open_index + 1 : close_index]:
1538                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1539                     commas += 1
1540             if commas > 1:
1541                 # We haven't looked yet for the trailing comma because
1542                 # we might also have caught noop parens.
1543                 return self.leaves[close_index - 1].type == token.COMMA
1544
1545             elif commas == 1:
1546                 return False  # it's either a one-tuple or didn't have a trailing comma
1547
1548             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1549                 close_index -= 1
1550                 closer = self.leaves[close_index]
1551                 if closer.type == token.RPAR:
1552                     # TODO: this is a gut feeling. Will we ever see this?
1553                     return False
1554
1555         if self.leaves[close_index - 1].type != token.COMMA:
1556             return False
1557
1558         return True
1559
1560     @property
1561     def is_def(self) -> bool:
1562         """Is this a function definition? (Also returns True for async defs.)"""
1563         try:
1564             first_leaf = self.leaves[0]
1565         except IndexError:
1566             return False
1567
1568         try:
1569             second_leaf: Optional[Leaf] = self.leaves[1]
1570         except IndexError:
1571             second_leaf = None
1572         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1573             first_leaf.type == token.ASYNC
1574             and second_leaf is not None
1575             and second_leaf.type == token.NAME
1576             and second_leaf.value == "def"
1577         )
1578
1579     @property
1580     def is_class_paren_empty(self) -> bool:
1581         """Is this a class with no base classes but using parentheses?
1582
1583         Those are unnecessary and should be removed.
1584         """
1585         return (
1586             bool(self)
1587             and len(self.leaves) == 4
1588             and self.is_class
1589             and self.leaves[2].type == token.LPAR
1590             and self.leaves[2].value == "("
1591             and self.leaves[3].type == token.RPAR
1592             and self.leaves[3].value == ")"
1593         )
1594
1595     @property
1596     def is_triple_quoted_string(self) -> bool:
1597         """Is the line a triple quoted string?"""
1598         return (
1599             bool(self)
1600             and self.leaves[0].type == token.STRING
1601             and self.leaves[0].value.startswith(('"""', "'''"))
1602         )
1603
1604     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1605         """If so, needs to be split before emitting."""
1606         for leaf in self.leaves:
1607             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1608                 return True
1609
1610         return False
1611
1612     def contains_uncollapsable_type_comments(self) -> bool:
1613         ignored_ids = set()
1614         try:
1615             last_leaf = self.leaves[-1]
1616             ignored_ids.add(id(last_leaf))
1617             if last_leaf.type == token.COMMA or (
1618                 last_leaf.type == token.RPAR and not last_leaf.value
1619             ):
1620                 # When trailing commas or optional parens are inserted by Black for
1621                 # consistency, comments after the previous last element are not moved
1622                 # (they don't have to, rendering will still be correct).  So we ignore
1623                 # trailing commas and invisible.
1624                 last_leaf = self.leaves[-2]
1625                 ignored_ids.add(id(last_leaf))
1626         except IndexError:
1627             return False
1628
1629         # A type comment is uncollapsable if it is attached to a leaf
1630         # that isn't at the end of the line (since that could cause it
1631         # to get associated to a different argument) or if there are
1632         # comments before it (since that could cause it to get hidden
1633         # behind a comment.
1634         comment_seen = False
1635         for leaf_id, comments in self.comments.items():
1636             for comment in comments:
1637                 if is_type_comment(comment):
1638                     if comment_seen or (
1639                         not is_type_comment(comment, " ignore")
1640                         and leaf_id not in ignored_ids
1641                     ):
1642                         return True
1643
1644                 comment_seen = True
1645
1646         return False
1647
1648     def contains_unsplittable_type_ignore(self) -> bool:
1649         if not self.leaves:
1650             return False
1651
1652         # If a 'type: ignore' is attached to the end of a line, we
1653         # can't split the line, because we can't know which of the
1654         # subexpressions the ignore was meant to apply to.
1655         #
1656         # We only want this to apply to actual physical lines from the
1657         # original source, though: we don't want the presence of a
1658         # 'type: ignore' at the end of a multiline expression to
1659         # justify pushing it all onto one line. Thus we
1660         # (unfortunately) need to check the actual source lines and
1661         # only report an unsplittable 'type: ignore' if this line was
1662         # one line in the original code.
1663
1664         # Grab the first and last line numbers, skipping generated leaves
1665         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1666         last_line = next(
1667             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1668         )
1669
1670         if first_line == last_line:
1671             # We look at the last two leaves since a comma or an
1672             # invisible paren could have been added at the end of the
1673             # line.
1674             for node in self.leaves[-2:]:
1675                 for comment in self.comments.get(id(node), []):
1676                     if is_type_comment(comment, " ignore"):
1677                         return True
1678
1679         return False
1680
1681     def contains_multiline_strings(self) -> bool:
1682         return any(is_multiline_string(leaf) for leaf in self.leaves)
1683
1684     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1685         """Remove trailing comma if there is one and it's safe."""
1686         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1687             return False
1688
1689         # We remove trailing commas only in the case of importing a
1690         # single name from a module.
1691         if not (
1692             self.leaves
1693             and self.is_import
1694             and len(self.leaves) > 4
1695             and self.leaves[-1].type == token.COMMA
1696             and closing.type in CLOSING_BRACKETS
1697             and self.leaves[-4].type == token.NAME
1698             and (
1699                 # regular `from foo import bar,`
1700                 self.leaves[-4].value == "import"
1701                 # `from foo import (bar as baz,)
1702                 or (
1703                     len(self.leaves) > 6
1704                     and self.leaves[-6].value == "import"
1705                     and self.leaves[-3].value == "as"
1706                 )
1707                 # `from foo import bar as baz,`
1708                 or (
1709                     len(self.leaves) > 5
1710                     and self.leaves[-5].value == "import"
1711                     and self.leaves[-3].value == "as"
1712                 )
1713             )
1714             and closing.type == token.RPAR
1715         ):
1716             return False
1717
1718         self.remove_trailing_comma()
1719         return True
1720
1721     def append_comment(self, comment: Leaf) -> bool:
1722         """Add an inline or standalone comment to the line."""
1723         if (
1724             comment.type == STANDALONE_COMMENT
1725             and self.bracket_tracker.any_open_brackets()
1726         ):
1727             comment.prefix = ""
1728             return False
1729
1730         if comment.type != token.COMMENT:
1731             return False
1732
1733         if not self.leaves:
1734             comment.type = STANDALONE_COMMENT
1735             comment.prefix = ""
1736             return False
1737
1738         last_leaf = self.leaves[-1]
1739         if (
1740             last_leaf.type == token.RPAR
1741             and not last_leaf.value
1742             and last_leaf.parent
1743             and len(list(last_leaf.parent.leaves())) <= 3
1744             and not is_type_comment(comment)
1745         ):
1746             # Comments on an optional parens wrapping a single leaf should belong to
1747             # the wrapped node except if it's a type comment. Pinning the comment like
1748             # this avoids unstable formatting caused by comment migration.
1749             if len(self.leaves) < 2:
1750                 comment.type = STANDALONE_COMMENT
1751                 comment.prefix = ""
1752                 return False
1753
1754             last_leaf = self.leaves[-2]
1755         self.comments.setdefault(id(last_leaf), []).append(comment)
1756         return True
1757
1758     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1759         """Generate comments that should appear directly after `leaf`."""
1760         return self.comments.get(id(leaf), [])
1761
1762     def remove_trailing_comma(self) -> None:
1763         """Remove the trailing comma and moves the comments attached to it."""
1764         trailing_comma = self.leaves.pop()
1765         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1766         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1767             trailing_comma_comments
1768         )
1769
1770     def is_complex_subscript(self, leaf: Leaf) -> bool:
1771         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1772         open_lsqb = self.bracket_tracker.get_open_lsqb()
1773         if open_lsqb is None:
1774             return False
1775
1776         subscript_start = open_lsqb.next_sibling
1777
1778         if isinstance(subscript_start, Node):
1779             if subscript_start.type == syms.listmaker:
1780                 return False
1781
1782             if subscript_start.type == syms.subscriptlist:
1783                 subscript_start = child_towards(subscript_start, leaf)
1784         return subscript_start is not None and any(
1785             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1786         )
1787
1788     def clone(self) -> "Line":
1789         return Line(
1790             depth=self.depth,
1791             inside_brackets=self.inside_brackets,
1792             should_explode=self.should_explode,
1793         )
1794
1795     def __str__(self) -> str:
1796         """Render the line."""
1797         if not self:
1798             return "\n"
1799
1800         indent = "    " * self.depth
1801         leaves = iter(self.leaves)
1802         first = next(leaves)
1803         res = f"{first.prefix}{indent}{first.value}"
1804         for leaf in leaves:
1805             res += str(leaf)
1806         for comment in itertools.chain.from_iterable(self.comments.values()):
1807             res += str(comment)
1808
1809         return res + "\n"
1810
1811     def __bool__(self) -> bool:
1812         """Return True if the line has leaves or comments."""
1813         return bool(self.leaves or self.comments)
1814
1815
1816 @dataclass
1817 class EmptyLineTracker:
1818     """Provides a stateful method that returns the number of potential extra
1819     empty lines needed before and after the currently processed line.
1820
1821     Note: this tracker works on lines that haven't been split yet.  It assumes
1822     the prefix of the first leaf consists of optional newlines.  Those newlines
1823     are consumed by `maybe_empty_lines()` and included in the computation.
1824     """
1825
1826     is_pyi: bool = False
1827     previous_line: Optional[Line] = None
1828     previous_after: int = 0
1829     previous_defs: List[int] = field(default_factory=list)
1830
1831     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1832         """Return the number of extra empty lines before and after the `current_line`.
1833
1834         This is for separating `def`, `async def` and `class` with extra empty
1835         lines (two on module-level).
1836         """
1837         before, after = self._maybe_empty_lines(current_line)
1838         before = (
1839             # Black should not insert empty lines at the beginning
1840             # of the file
1841             0
1842             if self.previous_line is None
1843             else before - self.previous_after
1844         )
1845         self.previous_after = after
1846         self.previous_line = current_line
1847         return before, after
1848
1849     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1850         max_allowed = 1
1851         if current_line.depth == 0:
1852             max_allowed = 1 if self.is_pyi else 2
1853         if current_line.leaves:
1854             # Consume the first leaf's extra newlines.
1855             first_leaf = current_line.leaves[0]
1856             before = first_leaf.prefix.count("\n")
1857             before = min(before, max_allowed)
1858             first_leaf.prefix = ""
1859         else:
1860             before = 0
1861         depth = current_line.depth
1862         while self.previous_defs and self.previous_defs[-1] >= depth:
1863             self.previous_defs.pop()
1864             if self.is_pyi:
1865                 before = 0 if depth else 1
1866             else:
1867                 before = 1 if depth else 2
1868         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1869             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1870
1871         if (
1872             self.previous_line
1873             and self.previous_line.is_import
1874             and not current_line.is_import
1875             and depth == self.previous_line.depth
1876         ):
1877             return (before or 1), 0
1878
1879         if (
1880             self.previous_line
1881             and self.previous_line.is_class
1882             and current_line.is_triple_quoted_string
1883         ):
1884             return before, 1
1885
1886         return before, 0
1887
1888     def _maybe_empty_lines_for_class_or_def(
1889         self, current_line: Line, before: int
1890     ) -> Tuple[int, int]:
1891         if not current_line.is_decorator:
1892             self.previous_defs.append(current_line.depth)
1893         if self.previous_line is None:
1894             # Don't insert empty lines before the first line in the file.
1895             return 0, 0
1896
1897         if self.previous_line.is_decorator:
1898             return 0, 0
1899
1900         if self.previous_line.depth < current_line.depth and (
1901             self.previous_line.is_class or self.previous_line.is_def
1902         ):
1903             return 0, 0
1904
1905         if (
1906             self.previous_line.is_comment
1907             and self.previous_line.depth == current_line.depth
1908             and before == 0
1909         ):
1910             return 0, 0
1911
1912         if self.is_pyi:
1913             if self.previous_line.depth > current_line.depth:
1914                 newlines = 1
1915             elif current_line.is_class or self.previous_line.is_class:
1916                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1917                     # No blank line between classes with an empty body
1918                     newlines = 0
1919                 else:
1920                     newlines = 1
1921             elif current_line.is_def and not self.previous_line.is_def:
1922                 # Blank line between a block of functions and a block of non-functions
1923                 newlines = 1
1924             else:
1925                 newlines = 0
1926         else:
1927             newlines = 2
1928         if current_line.depth and newlines:
1929             newlines -= 1
1930         return newlines, 0
1931
1932
1933 @dataclass
1934 class LineGenerator(Visitor[Line]):
1935     """Generates reformatted Line objects.  Empty lines are not emitted.
1936
1937     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1938     in ways that will no longer stringify to valid Python code on the tree.
1939     """
1940
1941     is_pyi: bool = False
1942     normalize_strings: bool = True
1943     current_line: Line = field(default_factory=Line)
1944     remove_u_prefix: bool = False
1945
1946     def line(self, indent: int = 0) -> Iterator[Line]:
1947         """Generate a line.
1948
1949         If the line is empty, only emit if it makes sense.
1950         If the line is too long, split it first and then generate.
1951
1952         If any lines were generated, set up a new current_line.
1953         """
1954         if not self.current_line:
1955             self.current_line.depth += indent
1956             return  # Line is empty, don't emit. Creating a new one unnecessary.
1957
1958         complete_line = self.current_line
1959         self.current_line = Line(depth=complete_line.depth + indent)
1960         yield complete_line
1961
1962     def visit_default(self, node: LN) -> Iterator[Line]:
1963         """Default `visit_*()` implementation. Recurses to children of `node`."""
1964         if isinstance(node, Leaf):
1965             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1966             for comment in generate_comments(node):
1967                 if any_open_brackets:
1968                     # any comment within brackets is subject to splitting
1969                     self.current_line.append(comment)
1970                 elif comment.type == token.COMMENT:
1971                     # regular trailing comment
1972                     self.current_line.append(comment)
1973                     yield from self.line()
1974
1975                 else:
1976                     # regular standalone comment
1977                     yield from self.line()
1978
1979                     self.current_line.append(comment)
1980                     yield from self.line()
1981
1982             normalize_prefix(node, inside_brackets=any_open_brackets)
1983             if self.normalize_strings and node.type == token.STRING:
1984                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1985                 normalize_string_quotes(node)
1986             if node.type == token.NUMBER:
1987                 normalize_numeric_literal(node)
1988             if node.type not in WHITESPACE:
1989                 self.current_line.append(node)
1990         yield from super().visit_default(node)
1991
1992     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1993         """Increase indentation level, maybe yield a line."""
1994         # In blib2to3 INDENT never holds comments.
1995         yield from self.line(+1)
1996         yield from self.visit_default(node)
1997
1998     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1999         """Decrease indentation level, maybe yield a line."""
2000         # The current line might still wait for trailing comments.  At DEDENT time
2001         # there won't be any (they would be prefixes on the preceding NEWLINE).
2002         # Emit the line then.
2003         yield from self.line()
2004
2005         # While DEDENT has no value, its prefix may contain standalone comments
2006         # that belong to the current indentation level.  Get 'em.
2007         yield from self.visit_default(node)
2008
2009         # Finally, emit the dedent.
2010         yield from self.line(-1)
2011
2012     def visit_stmt(
2013         self, node: Node, keywords: Set[str], parens: Set[str]
2014     ) -> Iterator[Line]:
2015         """Visit a statement.
2016
2017         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2018         `def`, `with`, `class`, `assert` and assignments.
2019
2020         The relevant Python language `keywords` for a given statement will be
2021         NAME leaves within it. This methods puts those on a separate line.
2022
2023         `parens` holds a set of string leaf values immediately after which
2024         invisible parens should be put.
2025         """
2026         normalize_invisible_parens(node, parens_after=parens)
2027         for child in node.children:
2028             if child.type == token.NAME and child.value in keywords:  # type: ignore
2029                 yield from self.line()
2030
2031             yield from self.visit(child)
2032
2033     def visit_suite(self, node: Node) -> Iterator[Line]:
2034         """Visit a suite."""
2035         if self.is_pyi and is_stub_suite(node):
2036             yield from self.visit(node.children[2])
2037         else:
2038             yield from self.visit_default(node)
2039
2040     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2041         """Visit a statement without nested statements."""
2042         is_suite_like = node.parent and node.parent.type in STATEMENT
2043         if is_suite_like:
2044             if self.is_pyi and is_stub_body(node):
2045                 yield from self.visit_default(node)
2046             else:
2047                 yield from self.line(+1)
2048                 yield from self.visit_default(node)
2049                 yield from self.line(-1)
2050
2051         else:
2052             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2053                 yield from self.line()
2054             yield from self.visit_default(node)
2055
2056     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2057         """Visit `async def`, `async for`, `async with`."""
2058         yield from self.line()
2059
2060         children = iter(node.children)
2061         for child in children:
2062             yield from self.visit(child)
2063
2064             if child.type == token.ASYNC:
2065                 break
2066
2067         internal_stmt = next(children)
2068         for child in internal_stmt.children:
2069             yield from self.visit(child)
2070
2071     def visit_decorators(self, node: Node) -> Iterator[Line]:
2072         """Visit decorators."""
2073         for child in node.children:
2074             yield from self.line()
2075             yield from self.visit(child)
2076
2077     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2078         """Remove a semicolon and put the other statement on a separate line."""
2079         yield from self.line()
2080
2081     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2082         """End of file. Process outstanding comments and end with a newline."""
2083         yield from self.visit_default(leaf)
2084         yield from self.line()
2085
2086     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2087         if not self.current_line.bracket_tracker.any_open_brackets():
2088             yield from self.line()
2089         yield from self.visit_default(leaf)
2090
2091     def visit_factor(self, node: Node) -> Iterator[Line]:
2092         """Force parentheses between a unary op and a binary power:
2093
2094         -2 ** 8 -> -(2 ** 8)
2095         """
2096         _operator, operand = node.children
2097         if (
2098             operand.type == syms.power
2099             and len(operand.children) == 3
2100             and operand.children[1].type == token.DOUBLESTAR
2101         ):
2102             lpar = Leaf(token.LPAR, "(")
2103             rpar = Leaf(token.RPAR, ")")
2104             index = operand.remove() or 0
2105             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2106         yield from self.visit_default(node)
2107
2108     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2109         # Check if it's a docstring
2110         if prev_siblings_are(
2111             leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
2112         ) and is_multiline_string(leaf):
2113             prefix = "    " * self.current_line.depth
2114             docstring = fix_docstring(leaf.value[3:-3], prefix)
2115             leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
2116             normalize_string_quotes(leaf)
2117
2118         yield from self.visit_default(leaf)
2119
2120     def __post_init__(self) -> None:
2121         """You are in a twisty little maze of passages."""
2122         v = self.visit_stmt
2123         Ø: Set[str] = set()
2124         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2125         self.visit_if_stmt = partial(
2126             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2127         )
2128         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2129         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2130         self.visit_try_stmt = partial(
2131             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2132         )
2133         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2134         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2135         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2136         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2137         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2138         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2139         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2140         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2141         self.visit_async_funcdef = self.visit_async_stmt
2142         self.visit_decorated = self.visit_decorators
2143
2144
2145 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2146 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2147 OPENING_BRACKETS = set(BRACKET.keys())
2148 CLOSING_BRACKETS = set(BRACKET.values())
2149 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2150 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2151
2152
2153 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2154     """Return whitespace prefix if needed for the given `leaf`.
2155
2156     `complex_subscript` signals whether the given leaf is part of a subscription
2157     which has non-trivial arguments, like arithmetic expressions or function calls.
2158     """
2159     NO = ""
2160     SPACE = " "
2161     DOUBLESPACE = "  "
2162     t = leaf.type
2163     p = leaf.parent
2164     v = leaf.value
2165     if t in ALWAYS_NO_SPACE:
2166         return NO
2167
2168     if t == token.COMMENT:
2169         return DOUBLESPACE
2170
2171     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2172     if t == token.COLON and p.type not in {
2173         syms.subscript,
2174         syms.subscriptlist,
2175         syms.sliceop,
2176     }:
2177         return NO
2178
2179     prev = leaf.prev_sibling
2180     if not prev:
2181         prevp = preceding_leaf(p)
2182         if not prevp or prevp.type in OPENING_BRACKETS:
2183             return NO
2184
2185         if t == token.COLON:
2186             if prevp.type == token.COLON:
2187                 return NO
2188
2189             elif prevp.type != token.COMMA and not complex_subscript:
2190                 return NO
2191
2192             return SPACE
2193
2194         if prevp.type == token.EQUAL:
2195             if prevp.parent:
2196                 if prevp.parent.type in {
2197                     syms.arglist,
2198                     syms.argument,
2199                     syms.parameters,
2200                     syms.varargslist,
2201                 }:
2202                     return NO
2203
2204                 elif prevp.parent.type == syms.typedargslist:
2205                     # A bit hacky: if the equal sign has whitespace, it means we
2206                     # previously found it's a typed argument.  So, we're using
2207                     # that, too.
2208                     return prevp.prefix
2209
2210         elif prevp.type in VARARGS_SPECIALS:
2211             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2212                 return NO
2213
2214         elif prevp.type == token.COLON:
2215             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2216                 return SPACE if complex_subscript else NO
2217
2218         elif (
2219             prevp.parent
2220             and prevp.parent.type == syms.factor
2221             and prevp.type in MATH_OPERATORS
2222         ):
2223             return NO
2224
2225         elif (
2226             prevp.type == token.RIGHTSHIFT
2227             and prevp.parent
2228             and prevp.parent.type == syms.shift_expr
2229             and prevp.prev_sibling
2230             and prevp.prev_sibling.type == token.NAME
2231             and prevp.prev_sibling.value == "print"  # type: ignore
2232         ):
2233             # Python 2 print chevron
2234             return NO
2235
2236     elif prev.type in OPENING_BRACKETS:
2237         return NO
2238
2239     if p.type in {syms.parameters, syms.arglist}:
2240         # untyped function signatures or calls
2241         if not prev or prev.type != token.COMMA:
2242             return NO
2243
2244     elif p.type == syms.varargslist:
2245         # lambdas
2246         if prev and prev.type != token.COMMA:
2247             return NO
2248
2249     elif p.type == syms.typedargslist:
2250         # typed function signatures
2251         if not prev:
2252             return NO
2253
2254         if t == token.EQUAL:
2255             if prev.type != syms.tname:
2256                 return NO
2257
2258         elif prev.type == token.EQUAL:
2259             # A bit hacky: if the equal sign has whitespace, it means we
2260             # previously found it's a typed argument.  So, we're using that, too.
2261             return prev.prefix
2262
2263         elif prev.type != token.COMMA:
2264             return NO
2265
2266     elif p.type == syms.tname:
2267         # type names
2268         if not prev:
2269             prevp = preceding_leaf(p)
2270             if not prevp or prevp.type != token.COMMA:
2271                 return NO
2272
2273     elif p.type == syms.trailer:
2274         # attributes and calls
2275         if t == token.LPAR or t == token.RPAR:
2276             return NO
2277
2278         if not prev:
2279             if t == token.DOT:
2280                 prevp = preceding_leaf(p)
2281                 if not prevp or prevp.type != token.NUMBER:
2282                     return NO
2283
2284             elif t == token.LSQB:
2285                 return NO
2286
2287         elif prev.type != token.COMMA:
2288             return NO
2289
2290     elif p.type == syms.argument:
2291         # single argument
2292         if t == token.EQUAL:
2293             return NO
2294
2295         if not prev:
2296             prevp = preceding_leaf(p)
2297             if not prevp or prevp.type == token.LPAR:
2298                 return NO
2299
2300         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2301             return NO
2302
2303     elif p.type == syms.decorator:
2304         # decorators
2305         return NO
2306
2307     elif p.type == syms.dotted_name:
2308         if prev:
2309             return NO
2310
2311         prevp = preceding_leaf(p)
2312         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2313             return NO
2314
2315     elif p.type == syms.classdef:
2316         if t == token.LPAR:
2317             return NO
2318
2319         if prev and prev.type == token.LPAR:
2320             return NO
2321
2322     elif p.type in {syms.subscript, syms.sliceop}:
2323         # indexing
2324         if not prev:
2325             assert p.parent is not None, "subscripts are always parented"
2326             if p.parent.type == syms.subscriptlist:
2327                 return SPACE
2328
2329             return NO
2330
2331         elif not complex_subscript:
2332             return NO
2333
2334     elif p.type == syms.atom:
2335         if prev and t == token.DOT:
2336             # dots, but not the first one.
2337             return NO
2338
2339     elif p.type == syms.dictsetmaker:
2340         # dict unpacking
2341         if prev and prev.type == token.DOUBLESTAR:
2342             return NO
2343
2344     elif p.type in {syms.factor, syms.star_expr}:
2345         # unary ops
2346         if not prev:
2347             prevp = preceding_leaf(p)
2348             if not prevp or prevp.type in OPENING_BRACKETS:
2349                 return NO
2350
2351             prevp_parent = prevp.parent
2352             assert prevp_parent is not None
2353             if prevp.type == token.COLON and prevp_parent.type in {
2354                 syms.subscript,
2355                 syms.sliceop,
2356             }:
2357                 return NO
2358
2359             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2360                 return NO
2361
2362         elif t in {token.NAME, token.NUMBER, token.STRING}:
2363             return NO
2364
2365     elif p.type == syms.import_from:
2366         if t == token.DOT:
2367             if prev and prev.type == token.DOT:
2368                 return NO
2369
2370         elif t == token.NAME:
2371             if v == "import":
2372                 return SPACE
2373
2374             if prev and prev.type == token.DOT:
2375                 return NO
2376
2377     elif p.type == syms.sliceop:
2378         return NO
2379
2380     return SPACE
2381
2382
2383 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2384     """Return the first leaf that precedes `node`, if any."""
2385     while node:
2386         res = node.prev_sibling
2387         if res:
2388             if isinstance(res, Leaf):
2389                 return res
2390
2391             try:
2392                 return list(res.leaves())[-1]
2393
2394             except IndexError:
2395                 return None
2396
2397         node = node.parent
2398     return None
2399
2400
2401 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2402     """Return if the `node` and its previous siblings match types against the provided
2403     list of tokens; the provided `node`has its type matched against the last element in
2404     the list.  `None` can be used as the first element to declare that the start of the
2405     list is anchored at the start of its parent's children."""
2406     if not tokens:
2407         return True
2408     if tokens[-1] is None:
2409         return node is None
2410     if not node:
2411         return False
2412     if node.type != tokens[-1]:
2413         return False
2414     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2415
2416
2417 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2418     """Return the child of `ancestor` that contains `descendant`."""
2419     node: Optional[LN] = descendant
2420     while node and node.parent != ancestor:
2421         node = node.parent
2422     return node
2423
2424
2425 def container_of(leaf: Leaf) -> LN:
2426     """Return `leaf` or one of its ancestors that is the topmost container of it.
2427
2428     By "container" we mean a node where `leaf` is the very first child.
2429     """
2430     same_prefix = leaf.prefix
2431     container: LN = leaf
2432     while container:
2433         parent = container.parent
2434         if parent is None:
2435             break
2436
2437         if parent.children[0].prefix != same_prefix:
2438             break
2439
2440         if parent.type == syms.file_input:
2441             break
2442
2443         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2444             break
2445
2446         container = parent
2447     return container
2448
2449
2450 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2451     """Return the priority of the `leaf` delimiter, given a line break after it.
2452
2453     The delimiter priorities returned here are from those delimiters that would
2454     cause a line break after themselves.
2455
2456     Higher numbers are higher priority.
2457     """
2458     if leaf.type == token.COMMA:
2459         return COMMA_PRIORITY
2460
2461     return 0
2462
2463
2464 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2465     """Return the priority of the `leaf` delimiter, given a line break before it.
2466
2467     The delimiter priorities returned here are from those delimiters that would
2468     cause a line break before themselves.
2469
2470     Higher numbers are higher priority.
2471     """
2472     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2473         # * and ** might also be MATH_OPERATORS but in this case they are not.
2474         # Don't treat them as a delimiter.
2475         return 0
2476
2477     if (
2478         leaf.type == token.DOT
2479         and leaf.parent
2480         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2481         and (previous is None or previous.type in CLOSING_BRACKETS)
2482     ):
2483         return DOT_PRIORITY
2484
2485     if (
2486         leaf.type in MATH_OPERATORS
2487         and leaf.parent
2488         and leaf.parent.type not in {syms.factor, syms.star_expr}
2489     ):
2490         return MATH_PRIORITIES[leaf.type]
2491
2492     if leaf.type in COMPARATORS:
2493         return COMPARATOR_PRIORITY
2494
2495     if (
2496         leaf.type == token.STRING
2497         and previous is not None
2498         and previous.type == token.STRING
2499     ):
2500         return STRING_PRIORITY
2501
2502     if leaf.type not in {token.NAME, token.ASYNC}:
2503         return 0
2504
2505     if (
2506         leaf.value == "for"
2507         and leaf.parent
2508         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2509         or leaf.type == token.ASYNC
2510     ):
2511         if (
2512             not isinstance(leaf.prev_sibling, Leaf)
2513             or leaf.prev_sibling.value != "async"
2514         ):
2515             return COMPREHENSION_PRIORITY
2516
2517     if (
2518         leaf.value == "if"
2519         and leaf.parent
2520         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2521     ):
2522         return COMPREHENSION_PRIORITY
2523
2524     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2525         return TERNARY_PRIORITY
2526
2527     if leaf.value == "is":
2528         return COMPARATOR_PRIORITY
2529
2530     if (
2531         leaf.value == "in"
2532         and leaf.parent
2533         and leaf.parent.type in {syms.comp_op, syms.comparison}
2534         and not (
2535             previous is not None
2536             and previous.type == token.NAME
2537             and previous.value == "not"
2538         )
2539     ):
2540         return COMPARATOR_PRIORITY
2541
2542     if (
2543         leaf.value == "not"
2544         and leaf.parent
2545         and leaf.parent.type == syms.comp_op
2546         and not (
2547             previous is not None
2548             and previous.type == token.NAME
2549             and previous.value == "is"
2550         )
2551     ):
2552         return COMPARATOR_PRIORITY
2553
2554     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2555         return LOGIC_PRIORITY
2556
2557     return 0
2558
2559
2560 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2561 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2562
2563
2564 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2565     """Clean the prefix of the `leaf` and generate comments from it, if any.
2566
2567     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2568     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2569     move because it does away with modifying the grammar to include all the
2570     possible places in which comments can be placed.
2571
2572     The sad consequence for us though is that comments don't "belong" anywhere.
2573     This is why this function generates simple parentless Leaf objects for
2574     comments.  We simply don't know what the correct parent should be.
2575
2576     No matter though, we can live without this.  We really only need to
2577     differentiate between inline and standalone comments.  The latter don't
2578     share the line with any code.
2579
2580     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2581     are emitted with a fake STANDALONE_COMMENT token identifier.
2582     """
2583     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2584         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2585
2586
2587 @dataclass
2588 class ProtoComment:
2589     """Describes a piece of syntax that is a comment.
2590
2591     It's not a :class:`blib2to3.pytree.Leaf` so that:
2592
2593     * it can be cached (`Leaf` objects should not be reused more than once as
2594       they store their lineno, column, prefix, and parent information);
2595     * `newlines` and `consumed` fields are kept separate from the `value`. This
2596       simplifies handling of special marker comments like ``# fmt: off/on``.
2597     """
2598
2599     type: int  # token.COMMENT or STANDALONE_COMMENT
2600     value: str  # content of the comment
2601     newlines: int  # how many newlines before the comment
2602     consumed: int  # how many characters of the original leaf's prefix did we consume
2603
2604
2605 @lru_cache(maxsize=4096)
2606 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2607     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2608     result: List[ProtoComment] = []
2609     if not prefix or "#" not in prefix:
2610         return result
2611
2612     consumed = 0
2613     nlines = 0
2614     ignored_lines = 0
2615     for index, line in enumerate(prefix.split("\n")):
2616         consumed += len(line) + 1  # adding the length of the split '\n'
2617         line = line.lstrip()
2618         if not line:
2619             nlines += 1
2620         if not line.startswith("#"):
2621             # Escaped newlines outside of a comment are not really newlines at
2622             # all. We treat a single-line comment following an escaped newline
2623             # as a simple trailing comment.
2624             if line.endswith("\\"):
2625                 ignored_lines += 1
2626             continue
2627
2628         if index == ignored_lines and not is_endmarker:
2629             comment_type = token.COMMENT  # simple trailing comment
2630         else:
2631             comment_type = STANDALONE_COMMENT
2632         comment = make_comment(line)
2633         result.append(
2634             ProtoComment(
2635                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2636             )
2637         )
2638         nlines = 0
2639     return result
2640
2641
2642 def make_comment(content: str) -> str:
2643     """Return a consistently formatted comment from the given `content` string.
2644
2645     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2646     space between the hash sign and the content.
2647
2648     If `content` didn't start with a hash sign, one is provided.
2649     """
2650     content = content.rstrip()
2651     if not content:
2652         return "#"
2653
2654     if content[0] == "#":
2655         content = content[1:]
2656     if content and content[0] not in " !:#'%":
2657         content = " " + content
2658     return "#" + content
2659
2660
2661 def transform_line(
2662     line: Line, mode: Mode, features: Collection[Feature] = ()
2663 ) -> Iterator[Line]:
2664     """Transform a `line`, potentially splitting it into many lines.
2665
2666     They should fit in the allotted `line_length` but might not be able to.
2667
2668     `features` are syntactical features that may be used in the output.
2669     """
2670     if line.is_comment:
2671         yield line
2672         return
2673
2674     line_str = line_to_string(line)
2675
2676     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2677         """Initialize StringTransformer"""
2678         return ST(mode.line_length, mode.string_normalization)
2679
2680     string_merge = init_st(StringMerger)
2681     string_paren_strip = init_st(StringParenStripper)
2682     string_split = init_st(StringSplitter)
2683     string_paren_wrap = init_st(StringParenWrapper)
2684
2685     transformers: List[Transformer]
2686     if (
2687         not line.contains_uncollapsable_type_comments()
2688         and not line.should_explode
2689         and not line.is_collection_with_optional_trailing_comma
2690         and (
2691             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2692             or line.contains_unsplittable_type_ignore()
2693         )
2694         and not (line.contains_standalone_comments() and line.inside_brackets)
2695     ):
2696         # Only apply basic string preprocessing, since lines shouldn't be split here.
2697         if mode.experimental_string_processing:
2698             transformers = [string_merge, string_paren_strip]
2699         else:
2700             transformers = []
2701     elif line.is_def:
2702         transformers = [left_hand_split]
2703     else:
2704
2705         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2706             for omit in generate_trailers_to_omit(line, mode.line_length):
2707                 lines = list(
2708                     right_hand_split(line, mode.line_length, features, omit=omit)
2709                 )
2710                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2711                     yield from lines
2712                     return
2713
2714             # All splits failed, best effort split with no omits.
2715             # This mostly happens to multiline strings that are by definition
2716             # reported as not fitting a single line.
2717             # line_length=1 here was historically a bug that somehow became a feature.
2718             # See #762 and #781 for the full story.
2719             yield from right_hand_split(line, line_length=1, features=features)
2720
2721         if mode.experimental_string_processing:
2722             if line.inside_brackets:
2723                 transformers = [
2724                     string_merge,
2725                     string_paren_strip,
2726                     delimiter_split,
2727                     standalone_comment_split,
2728                     string_split,
2729                     string_paren_wrap,
2730                     rhs,
2731                 ]
2732             else:
2733                 transformers = [
2734                     string_merge,
2735                     string_paren_strip,
2736                     string_split,
2737                     string_paren_wrap,
2738                     rhs,
2739                 ]
2740         else:
2741             if line.inside_brackets:
2742                 transformers = [delimiter_split, standalone_comment_split, rhs]
2743             else:
2744                 transformers = [rhs]
2745
2746     for transform in transformers:
2747         # We are accumulating lines in `result` because we might want to abort
2748         # mission and return the original line in the end, or attempt a different
2749         # split altogether.
2750         result: List[Line] = []
2751         try:
2752             for transformed_line in transform(line, features):
2753                 if str(transformed_line).strip("\n") == line_str:
2754                     raise CannotTransform(
2755                         "Line transformer returned an unchanged result"
2756                     )
2757
2758                 result.extend(
2759                     transform_line(transformed_line, mode=mode, features=features)
2760                 )
2761         except CannotTransform:
2762             continue
2763         else:
2764             yield from result
2765             break
2766
2767     else:
2768         yield line
2769
2770
2771 @dataclass  # type: ignore
2772 class StringTransformer(ABC):
2773     """
2774     An implementation of the Transformer protocol that relies on its
2775     subclasses overriding the template methods `do_match(...)` and
2776     `do_transform(...)`.
2777
2778     This Transformer works exclusively on strings (for example, by merging
2779     or splitting them).
2780
2781     The following sections can be found among the docstrings of each concrete
2782     StringTransformer subclass.
2783
2784     Requirements:
2785         Which requirements must be met of the given Line for this
2786         StringTransformer to be applied?
2787
2788     Transformations:
2789         If the given Line meets all of the above requirements, which string
2790         transformations can you expect to be applied to it by this
2791         StringTransformer?
2792
2793     Collaborations:
2794         What contractual agreements does this StringTransformer have with other
2795         StringTransfomers? Such collaborations should be eliminated/minimized
2796         as much as possible.
2797     """
2798
2799     line_length: int
2800     normalize_strings: bool
2801
2802     @abstractmethod
2803     def do_match(self, line: Line) -> TMatchResult:
2804         """
2805         Returns:
2806             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2807             string, if a match was able to be made.
2808                 OR
2809             * Err(CannotTransform), if a match was not able to be made.
2810         """
2811
2812     @abstractmethod
2813     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2814         """
2815         Yields:
2816             * Ok(new_line) where new_line is the new transformed line.
2817                 OR
2818             * Err(CannotTransform) if the transformation failed for some reason. The
2819             `do_match(...)` template method should usually be used to reject
2820             the form of the given Line, but in some cases it is difficult to
2821             know whether or not a Line meets the StringTransformer's
2822             requirements until the transformation is already midway.
2823
2824         Side Effects:
2825             This method should NOT mutate @line directly, but it MAY mutate the
2826             Line's underlying Node structure. (WARNING: If the underlying Node
2827             structure IS altered, then this method should NOT be allowed to
2828             yield an CannotTransform after that point.)
2829         """
2830
2831     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2832         """
2833         StringTransformer instances have a call signature that mirrors that of
2834         the Transformer type.
2835
2836         Raises:
2837             CannotTransform(...) if the concrete StringTransformer class is unable
2838             to transform @line.
2839         """
2840         # Optimization to avoid calling `self.do_match(...)` when the line does
2841         # not contain any string.
2842         if not any(leaf.type == token.STRING for leaf in line.leaves):
2843             raise CannotTransform("There are no strings in this line.")
2844
2845         match_result = self.do_match(line)
2846
2847         if isinstance(match_result, Err):
2848             cant_transform = match_result.err()
2849             raise CannotTransform(
2850                 f"The string transformer {self.__class__.__name__} does not recognize"
2851                 " this line as one that it can transform."
2852             ) from cant_transform
2853
2854         string_idx = match_result.ok()
2855
2856         for line_result in self.do_transform(line, string_idx):
2857             if isinstance(line_result, Err):
2858                 cant_transform = line_result.err()
2859                 raise CannotTransform(
2860                     "StringTransformer failed while attempting to transform string."
2861                 ) from cant_transform
2862             line = line_result.ok()
2863             yield line
2864
2865
2866 @dataclass
2867 class CustomSplit:
2868     """A custom (i.e. manual) string split.
2869
2870     A single CustomSplit instance represents a single substring.
2871
2872     Examples:
2873         Consider the following string:
2874         ```
2875         "Hi there friend."
2876         " This is a custom"
2877         f" string {split}."
2878         ```
2879
2880         This string will correspond to the following three CustomSplit instances:
2881         ```
2882         CustomSplit(False, 16)
2883         CustomSplit(False, 17)
2884         CustomSplit(True, 16)
2885         ```
2886     """
2887
2888     has_prefix: bool
2889     break_idx: int
2890
2891
2892 class CustomSplitMapMixin:
2893     """
2894     This mixin class is used to map merged strings to a sequence of
2895     CustomSplits, which will then be used to re-split the strings iff none of
2896     the resultant substrings go over the configured max line length.
2897     """
2898
2899     _Key = Tuple[StringID, str]
2900     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2901
2902     @staticmethod
2903     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2904         """
2905         Returns:
2906             A unique identifier that is used internally to map @string to a
2907             group of custom splits.
2908         """
2909         return (id(string), string)
2910
2911     def add_custom_splits(
2912         self, string: str, custom_splits: Iterable[CustomSplit]
2913     ) -> None:
2914         """Custom Split Map Setter Method
2915
2916         Side Effects:
2917             Adds a mapping from @string to the custom splits @custom_splits.
2918         """
2919         key = self._get_key(string)
2920         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2921
2922     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2923         """Custom Split Map Getter Method
2924
2925         Returns:
2926             * A list of the custom splits that are mapped to @string, if any
2927             exist.
2928                 OR
2929             * [], otherwise.
2930
2931         Side Effects:
2932             Deletes the mapping between @string and its associated custom
2933             splits (which are returned to the caller).
2934         """
2935         key = self._get_key(string)
2936
2937         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2938         del self._CUSTOM_SPLIT_MAP[key]
2939
2940         return list(custom_splits)
2941
2942     def has_custom_splits(self, string: str) -> bool:
2943         """
2944         Returns:
2945             True iff @string is associated with a set of custom splits.
2946         """
2947         key = self._get_key(string)
2948         return key in self._CUSTOM_SPLIT_MAP
2949
2950
2951 class StringMerger(CustomSplitMapMixin, StringTransformer):
2952     """StringTransformer that merges strings together.
2953
2954     Requirements:
2955         (A) The line contains adjacent strings such that at most one substring
2956         has inline comments AND none of those inline comments are pragmas AND
2957         the set of all substring prefixes is either of length 1 or equal to
2958         {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
2959         with 'r').
2960             OR
2961         (B) The line contains a string which uses line continuation backslashes.
2962
2963     Transformations:
2964         Depending on which of the two requirements above where met, either:
2965
2966         (A) The string group associated with the target string is merged.
2967             OR
2968         (B) All line-continuation backslashes are removed from the target string.
2969
2970     Collaborations:
2971         StringMerger provides custom split information to StringSplitter.
2972     """
2973
2974     def do_match(self, line: Line) -> TMatchResult:
2975         LL = line.leaves
2976
2977         is_valid_index = is_valid_index_factory(LL)
2978
2979         for (i, leaf) in enumerate(LL):
2980             if (
2981                 leaf.type == token.STRING
2982                 and is_valid_index(i + 1)
2983                 and LL[i + 1].type == token.STRING
2984             ):
2985                 return Ok(i)
2986
2987             if leaf.type == token.STRING and "\\\n" in leaf.value:
2988                 return Ok(i)
2989
2990         return TErr("This line has no strings that need merging.")
2991
2992     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2993         new_line = line
2994         rblc_result = self.__remove_backslash_line_continuation_chars(
2995             new_line, string_idx
2996         )
2997         if isinstance(rblc_result, Ok):
2998             new_line = rblc_result.ok()
2999
3000         msg_result = self.__merge_string_group(new_line, string_idx)
3001         if isinstance(msg_result, Ok):
3002             new_line = msg_result.ok()
3003
3004         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3005             msg_cant_transform = msg_result.err()
3006             rblc_cant_transform = rblc_result.err()
3007             cant_transform = CannotTransform(
3008                 "StringMerger failed to merge any strings in this line."
3009             )
3010
3011             # Chain the errors together using `__cause__`.
3012             msg_cant_transform.__cause__ = rblc_cant_transform
3013             cant_transform.__cause__ = msg_cant_transform
3014
3015             yield Err(cant_transform)
3016         else:
3017             yield Ok(new_line)
3018
3019     @staticmethod
3020     def __remove_backslash_line_continuation_chars(
3021         line: Line, string_idx: int
3022     ) -> TResult[Line]:
3023         """
3024         Merge strings that were split across multiple lines using
3025         line-continuation backslashes.
3026
3027         Returns:
3028             Ok(new_line), if @line contains backslash line-continuation
3029             characters.
3030                 OR
3031             Err(CannotTransform), otherwise.
3032         """
3033         LL = line.leaves
3034
3035         string_leaf = LL[string_idx]
3036         if not (
3037             string_leaf.type == token.STRING
3038             and "\\\n" in string_leaf.value
3039             and not has_triple_quotes(string_leaf.value)
3040         ):
3041             return TErr(
3042                 f"String leaf {string_leaf} does not contain any backslash line"
3043                 " continuation characters."
3044             )
3045
3046         new_line = line.clone()
3047         new_line.comments = line.comments
3048         append_leaves(new_line, line, LL)
3049
3050         new_string_leaf = new_line.leaves[string_idx]
3051         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3052
3053         return Ok(new_line)
3054
3055     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3056         """
3057         Merges string group (i.e. set of adjacent strings) where the first
3058         string in the group is `line.leaves[string_idx]`.
3059
3060         Returns:
3061             Ok(new_line), if ALL of the validation checks found in
3062             __validate_msg(...) pass.
3063                 OR
3064             Err(CannotTransform), otherwise.
3065         """
3066         LL = line.leaves
3067
3068         is_valid_index = is_valid_index_factory(LL)
3069
3070         vresult = self.__validate_msg(line, string_idx)
3071         if isinstance(vresult, Err):
3072             return vresult
3073
3074         # If the string group is wrapped inside an Atom node, we must make sure
3075         # to later replace that Atom with our new (merged) string leaf.
3076         atom_node = LL[string_idx].parent
3077
3078         # We will place BREAK_MARK in between every two substrings that we
3079         # merge. We will then later go through our final result and use the
3080         # various instances of BREAK_MARK we find to add the right values to
3081         # the custom split map.
3082         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3083
3084         QUOTE = LL[string_idx].value[-1]
3085
3086         def make_naked(string: str, string_prefix: str) -> str:
3087             """Strip @string (i.e. make it a "naked" string)
3088
3089             Pre-conditions:
3090                 * assert_is_leaf_string(@string)
3091
3092             Returns:
3093                 A string that is identical to @string except that
3094                 @string_prefix has been stripped, the surrounding QUOTE
3095                 characters have been removed, and any remaining QUOTE
3096                 characters have been escaped.
3097             """
3098             assert_is_leaf_string(string)
3099
3100             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3101             naked_string = string[len(string_prefix) + 1 : -1]
3102             naked_string = re.sub(
3103                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3104             )
3105             return naked_string
3106
3107         # Holds the CustomSplit objects that will later be added to the custom
3108         # split map.
3109         custom_splits = []
3110
3111         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3112         prefix_tracker = []
3113
3114         # Sets the 'prefix' variable. This is the prefix that the final merged
3115         # string will have.
3116         next_str_idx = string_idx
3117         prefix = ""
3118         while (
3119             not prefix
3120             and is_valid_index(next_str_idx)
3121             and LL[next_str_idx].type == token.STRING
3122         ):
3123             prefix = get_string_prefix(LL[next_str_idx].value)
3124             next_str_idx += 1
3125
3126         # The next loop merges the string group. The final string will be
3127         # contained in 'S'.
3128         #
3129         # The following convenience variables are used:
3130         #
3131         #   S: string
3132         #   NS: naked string
3133         #   SS: next string
3134         #   NSS: naked next string
3135         S = ""
3136         NS = ""
3137         num_of_strings = 0
3138         next_str_idx = string_idx
3139         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3140             num_of_strings += 1
3141
3142             SS = LL[next_str_idx].value
3143             next_prefix = get_string_prefix(SS)
3144
3145             # If this is an f-string group but this substring is not prefixed
3146             # with 'f'...
3147             if "f" in prefix and "f" not in next_prefix:
3148                 # Then we must escape any braces contained in this substring.
3149                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3150
3151             NSS = make_naked(SS, next_prefix)
3152
3153             has_prefix = bool(next_prefix)
3154             prefix_tracker.append(has_prefix)
3155
3156             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3157             NS = make_naked(S, prefix)
3158
3159             next_str_idx += 1
3160
3161         S_leaf = Leaf(token.STRING, S)
3162         if self.normalize_strings:
3163             normalize_string_quotes(S_leaf)
3164
3165         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3166         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3167         for has_prefix in prefix_tracker:
3168             mark_idx = temp_string.find(BREAK_MARK)
3169             assert (
3170                 mark_idx >= 0
3171             ), "Logic error while filling the custom string breakpoint cache."
3172
3173             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3174             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3175             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3176
3177         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3178
3179         if atom_node is not None:
3180             replace_child(atom_node, string_leaf)
3181
3182         # Build the final line ('new_line') that this method will later return.
3183         new_line = line.clone()
3184         for (i, leaf) in enumerate(LL):
3185             if i == string_idx:
3186                 new_line.append(string_leaf)
3187
3188             if string_idx <= i < string_idx + num_of_strings:
3189                 for comment_leaf in line.comments_after(LL[i]):
3190                     new_line.append(comment_leaf, preformatted=True)
3191                 continue
3192
3193             append_leaves(new_line, line, [leaf])
3194
3195         self.add_custom_splits(string_leaf.value, custom_splits)
3196         return Ok(new_line)
3197
3198     @staticmethod
3199     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3200         """Validate (M)erge (S)tring (G)roup
3201
3202         Transform-time string validation logic for __merge_string_group(...).
3203
3204         Returns:
3205             * Ok(None), if ALL validation checks (listed below) pass.
3206                 OR
3207             * Err(CannotTransform), if any of the following are true:
3208                 - The target string is not in a string group (i.e. it has no
3209                   adjacent strings).
3210                 - The string group has more than one inline comment.
3211                 - The string group has an inline comment that appears to be a pragma.
3212                 - The set of all string prefixes in the string group is of
3213                   length greater than one and is not equal to {"", "f"}.
3214                 - The string group consists of raw strings.
3215         """
3216         num_of_inline_string_comments = 0
3217         set_of_prefixes = set()
3218         num_of_strings = 0
3219         for leaf in line.leaves[string_idx:]:
3220             if leaf.type != token.STRING:
3221                 # If the string group is trailed by a comma, we count the
3222                 # comments trailing the comma to be one of the string group's
3223                 # comments.
3224                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3225                     num_of_inline_string_comments += 1
3226                 break
3227
3228             if has_triple_quotes(leaf.value):
3229                 return TErr("StringMerger does NOT merge multiline strings.")
3230
3231             num_of_strings += 1
3232             prefix = get_string_prefix(leaf.value)
3233             if "r" in prefix:
3234                 return TErr("StringMerger does NOT merge raw strings.")
3235
3236             set_of_prefixes.add(prefix)
3237
3238             if id(leaf) in line.comments:
3239                 num_of_inline_string_comments += 1
3240                 if contains_pragma_comment(line.comments[id(leaf)]):
3241                     return TErr("Cannot merge strings which have pragma comments.")
3242
3243         if num_of_strings < 2:
3244             return TErr(
3245                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3246             )
3247
3248         if num_of_inline_string_comments > 1:
3249             return TErr(
3250                 f"Too many inline string comments ({num_of_inline_string_comments})."
3251             )
3252
3253         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3254             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3255
3256         return Ok(None)
3257
3258
3259 class StringParenStripper(StringTransformer):
3260     """StringTransformer that strips surrounding parentheses from strings.
3261
3262     Requirements:
3263         The line contains a string which is surrounded by parentheses and:
3264             - The target string is NOT the only argument to a function call).
3265             - If the target string contains a PERCENT, the brackets are not
3266               preceeded or followed by an operator with higher precedence than
3267               PERCENT.
3268
3269     Transformations:
3270         The parentheses mentioned in the 'Requirements' section are stripped.
3271
3272     Collaborations:
3273         StringParenStripper has its own inherent usefulness, but it is also
3274         relied on to clean up the parentheses created by StringParenWrapper (in
3275         the event that they are no longer needed).
3276     """
3277
3278     def do_match(self, line: Line) -> TMatchResult:
3279         LL = line.leaves
3280
3281         is_valid_index = is_valid_index_factory(LL)
3282
3283         for (idx, leaf) in enumerate(LL):
3284             # Should be a string...
3285             if leaf.type != token.STRING:
3286                 continue
3287
3288             # Should be preceded by a non-empty LPAR...
3289             if (
3290                 not is_valid_index(idx - 1)
3291                 or LL[idx - 1].type != token.LPAR
3292                 or is_empty_lpar(LL[idx - 1])
3293             ):
3294                 continue
3295
3296             # That LPAR should NOT be preceded by a function name or a closing
3297             # bracket (which could be a function which returns a function or a
3298             # list/dictionary that contains a function)...
3299             if is_valid_index(idx - 2) and (
3300                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3301             ):
3302                 continue
3303
3304             string_idx = idx
3305
3306             # Skip the string trailer, if one exists.
3307             string_parser = StringParser()
3308             next_idx = string_parser.parse(LL, string_idx)
3309
3310             # if the leaves in the parsed string include a PERCENT, we need to
3311             # make sure the initial LPAR is NOT preceded by an operator with
3312             # higher or equal precedence to PERCENT
3313             if is_valid_index(idx - 2):
3314                 # mypy can't quite follow unless we name this
3315                 before_lpar = LL[idx - 2]
3316                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3317                     (
3318                         before_lpar.type
3319                         in {
3320                             token.STAR,
3321                             token.AT,
3322                             token.SLASH,
3323                             token.DOUBLESLASH,
3324                             token.PERCENT,
3325                             token.TILDE,
3326                             token.DOUBLESTAR,
3327                             token.AWAIT,
3328                             token.LSQB,
3329                             token.LPAR,
3330                         }
3331                     )
3332                     or (
3333                         # only unary PLUS/MINUS
3334                         before_lpar.parent
3335                         and before_lpar.parent.type == syms.factor
3336                         and (before_lpar.type in {token.PLUS, token.MINUS})
3337                     )
3338                 ):
3339                     continue
3340
3341             # Should be followed by a non-empty RPAR...
3342             if (
3343                 is_valid_index(next_idx)
3344                 and LL[next_idx].type == token.RPAR
3345                 and not is_empty_rpar(LL[next_idx])
3346             ):
3347                 # That RPAR should NOT be followed by anything with higher
3348                 # precedence than PERCENT
3349                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3350                     token.DOUBLESTAR,
3351                     token.LSQB,
3352                     token.LPAR,
3353                     token.DOT,
3354                 }:
3355                     continue
3356
3357                 return Ok(string_idx)
3358
3359         return TErr("This line has no strings wrapped in parens.")
3360
3361     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3362         LL = line.leaves
3363
3364         string_parser = StringParser()
3365         rpar_idx = string_parser.parse(LL, string_idx)
3366
3367         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3368             if line.comments_after(leaf):
3369                 yield TErr(
3370                     "Will not strip parentheses which have comments attached to them."
3371                 )
3372
3373         new_line = line.clone()
3374         new_line.comments = line.comments.copy()
3375
3376         append_leaves(new_line, line, LL[: string_idx - 1])
3377
3378         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3379         LL[string_idx - 1].remove()
3380         replace_child(LL[string_idx], string_leaf)
3381         new_line.append(string_leaf)
3382
3383         append_leaves(
3384             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3385         )
3386
3387         LL[rpar_idx].remove()
3388
3389         yield Ok(new_line)
3390
3391
3392 class BaseStringSplitter(StringTransformer):
3393     """
3394     Abstract class for StringTransformers which transform a Line's strings by splitting
3395     them or placing them on their own lines where necessary to avoid going over
3396     the configured line length.
3397
3398     Requirements:
3399         * The target string value is responsible for the line going over the
3400         line length limit. It follows that after all of black's other line
3401         split methods have been exhausted, this line (or one of the resulting
3402         lines after all line splits are performed) would still be over the
3403         line_length limit unless we split this string.
3404             AND
3405         * The target string is NOT a "pointless" string (i.e. a string that has
3406         no parent or siblings).
3407             AND
3408         * The target string is not followed by an inline comment that appears
3409         to be a pragma.
3410             AND
3411         * The target string is not a multiline (i.e. triple-quote) string.
3412     """
3413
3414     @abstractmethod
3415     def do_splitter_match(self, line: Line) -> TMatchResult:
3416         """
3417         BaseStringSplitter asks its clients to override this method instead of
3418         `StringTransformer.do_match(...)`.
3419
3420         Follows the same protocol as `StringTransformer.do_match(...)`.
3421
3422         Refer to `help(StringTransformer.do_match)` for more information.
3423         """
3424
3425     def do_match(self, line: Line) -> TMatchResult:
3426         match_result = self.do_splitter_match(line)
3427         if isinstance(match_result, Err):
3428             return match_result
3429
3430         string_idx = match_result.ok()
3431         vresult = self.__validate(line, string_idx)
3432         if isinstance(vresult, Err):
3433             return vresult
3434
3435         return match_result
3436
3437     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3438         """
3439         Checks that @line meets all of the requirements listed in this classes'
3440         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3441         description of those requirements.
3442
3443         Returns:
3444             * Ok(None), if ALL of the requirements are met.
3445                 OR
3446             * Err(CannotTransform), if ANY of the requirements are NOT met.
3447         """
3448         LL = line.leaves
3449
3450         string_leaf = LL[string_idx]
3451
3452         max_string_length = self.__get_max_string_length(line, string_idx)
3453         if len(string_leaf.value) <= max_string_length:
3454             return TErr(
3455                 "The string itself is not what is causing this line to be too long."
3456             )
3457
3458         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3459             token.STRING,
3460             token.NEWLINE,
3461         ]:
3462             return TErr(
3463                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3464                 " no parent)."
3465             )
3466
3467         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3468             line.comments[id(line.leaves[string_idx])]
3469         ):
3470             return TErr(
3471                 "Line appears to end with an inline pragma comment. Splitting the line"
3472                 " could modify the pragma's behavior."
3473             )
3474
3475         if has_triple_quotes(string_leaf.value):
3476             return TErr("We cannot split multiline strings.")
3477
3478         return Ok(None)
3479
3480     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3481         """
3482         Calculates the max string length used when attempting to determine
3483         whether or not the target string is responsible for causing the line to
3484         go over the line length limit.
3485
3486         WARNING: This method is tightly coupled to both StringSplitter and
3487         (especially) StringParenWrapper. There is probably a better way to
3488         accomplish what is being done here.
3489
3490         Returns:
3491             max_string_length: such that `line.leaves[string_idx].value >
3492             max_string_length` implies that the target string IS responsible
3493             for causing this line to exceed the line length limit.
3494         """
3495         LL = line.leaves
3496
3497         is_valid_index = is_valid_index_factory(LL)
3498
3499         # We use the shorthand "WMA4" in comments to abbreviate "We must
3500         # account for". When giving examples, we use STRING to mean some/any
3501         # valid string.
3502         #
3503         # Finally, we use the following convenience variables:
3504         #
3505         #   P:  The leaf that is before the target string leaf.
3506         #   N:  The leaf that is after the target string leaf.
3507         #   NN: The leaf that is after N.
3508
3509         # WMA4 the whitespace at the beginning of the line.
3510         offset = line.depth * 4
3511
3512         if is_valid_index(string_idx - 1):
3513             p_idx = string_idx - 1
3514             if (
3515                 LL[string_idx - 1].type == token.LPAR
3516                 and LL[string_idx - 1].value == ""
3517                 and string_idx >= 2
3518             ):
3519                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3520                 p_idx -= 1
3521
3522             P = LL[p_idx]
3523             if P.type == token.PLUS:
3524                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3525                 offset += 2
3526
3527             if P.type == token.COMMA:
3528                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3529                 offset += 3
3530
3531             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3532                 # This conditional branch is meant to handle dictionary keys,
3533                 # variable assignments, 'return STRING' statement lines, and
3534                 # 'else STRING' ternary expression lines.
3535
3536                 # WMA4 a single space.
3537                 offset += 1
3538
3539                 # WMA4 the lengths of any leaves that came before that space.
3540                 for leaf in LL[: p_idx + 1]:
3541                     offset += len(str(leaf))
3542
3543         if is_valid_index(string_idx + 1):
3544             N = LL[string_idx + 1]
3545             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3546                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3547                 N = LL[string_idx + 2]
3548
3549             if N.type == token.COMMA:
3550                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3551                 offset += 1
3552
3553             if is_valid_index(string_idx + 2):
3554                 NN = LL[string_idx + 2]
3555
3556                 if N.type == token.DOT and NN.type == token.NAME:
3557                     # This conditional branch is meant to handle method calls invoked
3558                     # off of a string literal up to and including the LPAR character.
3559
3560                     # WMA4 the '.' character.
3561                     offset += 1
3562
3563                     if (
3564                         is_valid_index(string_idx + 3)
3565                         and LL[string_idx + 3].type == token.LPAR
3566                     ):
3567                         # WMA4 the left parenthesis character.
3568                         offset += 1
3569
3570                     # WMA4 the length of the method's name.
3571                     offset += len(NN.value)
3572
3573         has_comments = False
3574         for comment_leaf in line.comments_after(LL[string_idx]):
3575             if not has_comments:
3576                 has_comments = True
3577                 # WMA4 two spaces before the '#' character.
3578                 offset += 2
3579
3580             # WMA4 the length of the inline comment.
3581             offset += len(comment_leaf.value)
3582
3583         max_string_length = self.line_length - offset
3584         return max_string_length
3585
3586
3587 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3588     """
3589     StringTransformer that splits "atom" strings (i.e. strings which exist on
3590     lines by themselves).
3591
3592     Requirements:
3593         * The line consists ONLY of a single string (with the exception of a
3594         '+' symbol which MAY exist at the start of the line), MAYBE a string
3595         trailer, and MAYBE a trailing comma.
3596             AND
3597         * All of the requirements listed in BaseStringSplitter's docstring.
3598
3599     Transformations:
3600         The string mentioned in the 'Requirements' section is split into as
3601         many substrings as necessary to adhere to the configured line length.
3602
3603         In the final set of substrings, no substring should be smaller than
3604         MIN_SUBSTR_SIZE characters.
3605
3606         The string will ONLY be split on spaces (i.e. each new substring should
3607         start with a space).
3608
3609         If the string is an f-string, it will NOT be split in the middle of an
3610         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3611         else bar()} is an f-expression).
3612
3613         If the string that is being split has an associated set of custom split
3614         records and those custom splits will NOT result in any line going over
3615         the configured line length, those custom splits are used. Otherwise the
3616         string is split as late as possible (from left-to-right) while still
3617         adhering to the transformation rules listed above.
3618
3619     Collaborations:
3620         StringSplitter relies on StringMerger to construct the appropriate
3621         CustomSplit objects and add them to the custom split map.
3622     """
3623
3624     MIN_SUBSTR_SIZE = 6
3625     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3626     RE_FEXPR = r"""
3627     (?<!\{)\{
3628         (?:
3629             [^\{\}]
3630             | \{\{
3631             | \}\}
3632         )+?
3633     (?<!\})(?:\}\})*\}(?!\})
3634     """
3635
3636     def do_splitter_match(self, line: Line) -> TMatchResult:
3637         LL = line.leaves
3638
3639         is_valid_index = is_valid_index_factory(LL)
3640
3641         idx = 0
3642
3643         # The first leaf MAY be a '+' symbol...
3644         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3645             idx += 1
3646
3647         # The next/first leaf MAY be an empty LPAR...
3648         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3649             idx += 1
3650
3651         # The next/first leaf MUST be a string...
3652         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3653             return TErr("Line does not start with a string.")
3654
3655         string_idx = idx
3656
3657         # Skip the string trailer, if one exists.
3658         string_parser = StringParser()
3659         idx = string_parser.parse(LL, string_idx)
3660
3661         # That string MAY be followed by an empty RPAR...
3662         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3663             idx += 1
3664
3665         # That string / empty RPAR leaf MAY be followed by a comma...
3666         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3667             idx += 1
3668
3669         # But no more leaves are allowed...
3670         if is_valid_index(idx):
3671             return TErr("This line does not end with a string.")
3672
3673         return Ok(string_idx)
3674
3675     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3676         LL = line.leaves
3677
3678         QUOTE = LL[string_idx].value[-1]
3679
3680         is_valid_index = is_valid_index_factory(LL)
3681         insert_str_child = insert_str_child_factory(LL[string_idx])
3682
3683         prefix = get_string_prefix(LL[string_idx].value)
3684
3685         # We MAY choose to drop the 'f' prefix from substrings that don't
3686         # contain any f-expressions, but ONLY if the original f-string
3687         # contains at least one f-expression. Otherwise, we will alter the AST
3688         # of the program.
3689         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3690             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3691         )
3692
3693         first_string_line = True
3694         starts_with_plus = LL[0].type == token.PLUS
3695
3696         def line_needs_plus() -> bool:
3697             return first_string_line and starts_with_plus
3698
3699         def maybe_append_plus(new_line: Line) -> None:
3700             """
3701             Side Effects:
3702                 If @line starts with a plus and this is the first line we are
3703                 constructing, this function appends a PLUS leaf to @new_line
3704                 and replaces the old PLUS leaf in the node structure. Otherwise
3705                 this function does nothing.
3706             """
3707             if line_needs_plus():
3708                 plus_leaf = Leaf(token.PLUS, "+")
3709                 replace_child(LL[0], plus_leaf)
3710                 new_line.append(plus_leaf)
3711
3712         ends_with_comma = (
3713             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3714         )
3715
3716         def max_last_string() -> int:
3717             """
3718             Returns:
3719                 The max allowed length of the string value used for the last
3720                 line we will construct.
3721             """
3722             result = self.line_length
3723             result -= line.depth * 4
3724             result -= 1 if ends_with_comma else 0
3725             result -= 2 if line_needs_plus() else 0
3726             return result
3727
3728         # --- Calculate Max Break Index (for string value)
3729         # We start with the line length limit
3730         max_break_idx = self.line_length
3731         # The last index of a string of length N is N-1.
3732         max_break_idx -= 1
3733         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3734         max_break_idx -= line.depth * 4
3735         if max_break_idx < 0:
3736             yield TErr(
3737                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3738                 f" {line.depth}"
3739             )
3740             return
3741
3742         # Check if StringMerger registered any custom splits.
3743         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3744         # We use them ONLY if none of them would produce lines that exceed the
3745         # line limit.
3746         use_custom_breakpoints = bool(
3747             custom_splits
3748             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3749         )
3750
3751         # Temporary storage for the remaining chunk of the string line that
3752         # can't fit onto the line currently being constructed.
3753         rest_value = LL[string_idx].value
3754
3755         def more_splits_should_be_made() -> bool:
3756             """
3757             Returns:
3758                 True iff `rest_value` (the remaining string value from the last
3759                 split), should be split again.
3760             """
3761             if use_custom_breakpoints:
3762                 return len(custom_splits) > 1
3763             else:
3764                 return len(rest_value) > max_last_string()
3765
3766         string_line_results: List[Ok[Line]] = []
3767         while more_splits_should_be_made():
3768             if use_custom_breakpoints:
3769                 # Custom User Split (manual)
3770                 csplit = custom_splits.pop(0)
3771                 break_idx = csplit.break_idx
3772             else:
3773                 # Algorithmic Split (automatic)
3774                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3775                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3776                 if maybe_break_idx is None:
3777                     # If we are unable to algorithmically determine a good split
3778                     # and this string has custom splits registered to it, we
3779                     # fall back to using them--which means we have to start
3780                     # over from the beginning.
3781                     if custom_splits:
3782                         rest_value = LL[string_idx].value
3783                         string_line_results = []
3784                         first_string_line = True
3785                         use_custom_breakpoints = True
3786                         continue
3787
3788                     # Otherwise, we stop splitting here.
3789                     break
3790
3791                 break_idx = maybe_break_idx
3792
3793             # --- Construct `next_value`
3794             next_value = rest_value[:break_idx] + QUOTE
3795             if (
3796                 # Are we allowed to try to drop a pointless 'f' prefix?
3797                 drop_pointless_f_prefix
3798                 # If we are, will we be successful?
3799                 and next_value != self.__normalize_f_string(next_value, prefix)
3800             ):
3801                 # If the current custom split did NOT originally use a prefix,
3802                 # then `csplit.break_idx` will be off by one after removing
3803                 # the 'f' prefix.
3804                 break_idx = (
3805                     break_idx + 1
3806                     if use_custom_breakpoints and not csplit.has_prefix
3807                     else break_idx
3808                 )
3809                 next_value = rest_value[:break_idx] + QUOTE
3810                 next_value = self.__normalize_f_string(next_value, prefix)
3811
3812             # --- Construct `next_leaf`
3813             next_leaf = Leaf(token.STRING, next_value)
3814             insert_str_child(next_leaf)
3815             self.__maybe_normalize_string_quotes(next_leaf)
3816
3817             # --- Construct `next_line`
3818             next_line = line.clone()
3819             maybe_append_plus(next_line)
3820             next_line.append(next_leaf)
3821             string_line_results.append(Ok(next_line))
3822
3823             rest_value = prefix + QUOTE + rest_value[break_idx:]
3824             first_string_line = False
3825
3826         yield from string_line_results
3827
3828         if drop_pointless_f_prefix:
3829             rest_value = self.__normalize_f_string(rest_value, prefix)
3830
3831         rest_leaf = Leaf(token.STRING, rest_value)
3832         insert_str_child(rest_leaf)
3833
3834         # NOTE: I could not find a test case that verifies that the following
3835         # line is actually necessary, but it seems to be. Otherwise we risk
3836         # not normalizing the last substring, right?
3837         self.__maybe_normalize_string_quotes(rest_leaf)
3838
3839         last_line = line.clone()
3840         maybe_append_plus(last_line)
3841
3842         # If there are any leaves to the right of the target string...
3843         if is_valid_index(string_idx + 1):
3844             # We use `temp_value` here to determine how long the last line
3845             # would be if we were to append all the leaves to the right of the
3846             # target string to the last string line.
3847             temp_value = rest_value
3848             for leaf in LL[string_idx + 1 :]:
3849                 temp_value += str(leaf)
3850                 if leaf.type == token.LPAR:
3851                     break
3852
3853             # Try to fit them all on the same line with the last substring...
3854             if (
3855                 len(temp_value) <= max_last_string()
3856                 or LL[string_idx + 1].type == token.COMMA
3857             ):
3858                 last_line.append(rest_leaf)
3859                 append_leaves(last_line, line, LL[string_idx + 1 :])
3860                 yield Ok(last_line)
3861             # Otherwise, place the last substring on one line and everything
3862             # else on a line below that...
3863             else:
3864                 last_line.append(rest_leaf)
3865                 yield Ok(last_line)
3866
3867                 non_string_line = line.clone()
3868                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3869                 yield Ok(non_string_line)
3870         # Else the target string was the last leaf...
3871         else:
3872             last_line.append(rest_leaf)
3873             last_line.comments = line.comments.copy()
3874             yield Ok(last_line)
3875
3876     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3877         """
3878         This method contains the algorithm that StringSplitter uses to
3879         determine which character to split each string at.
3880
3881         Args:
3882             @string: The substring that we are attempting to split.
3883             @max_break_idx: The ideal break index. We will return this value if it
3884             meets all the necessary conditions. In the likely event that it
3885             doesn't we will try to find the closest index BELOW @max_break_idx
3886             that does. If that fails, we will expand our search by also
3887             considering all valid indices ABOVE @max_break_idx.
3888
3889         Pre-Conditions:
3890             * assert_is_leaf_string(@string)
3891             * 0 <= @max_break_idx < len(@string)
3892
3893         Returns:
3894             break_idx, if an index is able to be found that meets all of the
3895             conditions listed in the 'Transformations' section of this classes'
3896             docstring.
3897                 OR
3898             None, otherwise.
3899         """
3900         is_valid_index = is_valid_index_factory(string)
3901
3902         assert is_valid_index(max_break_idx)
3903         assert_is_leaf_string(string)
3904
3905         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3906
3907         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3908             """
3909             Yields:
3910                 All ranges of @string which, if @string were to be split there,
3911                 would result in the splitting of an f-expression (which is NOT
3912                 allowed).
3913             """
3914             nonlocal _fexpr_slices
3915
3916             if _fexpr_slices is None:
3917                 _fexpr_slices = []
3918                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3919                     _fexpr_slices.append(match.span())
3920
3921             yield from _fexpr_slices
3922
3923         is_fstring = "f" in get_string_prefix(string)
3924
3925         def breaks_fstring_expression(i: Index) -> bool:
3926             """
3927             Returns:
3928                 True iff returning @i would result in the splitting of an
3929                 f-expression (which is NOT allowed).
3930             """
3931             if not is_fstring:
3932                 return False
3933
3934             for (start, end) in fexpr_slices():
3935                 if start <= i < end:
3936                     return True
3937
3938             return False
3939
3940         def passes_all_checks(i: Index) -> bool:
3941             """
3942             Returns:
3943                 True iff ALL of the conditions listed in the 'Transformations'
3944                 section of this classes' docstring would be be met by returning @i.
3945             """
3946             is_space = string[i] == " "
3947             is_big_enough = (
3948                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3949                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3950             )
3951             return is_space and is_big_enough and not breaks_fstring_expression(i)
3952
3953         # First, we check all indices BELOW @max_break_idx.
3954         break_idx = max_break_idx
3955         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3956             break_idx -= 1
3957
3958         if not passes_all_checks(break_idx):
3959             # If that fails, we check all indices ABOVE @max_break_idx.
3960             #
3961             # If we are able to find a valid index here, the next line is going
3962             # to be longer than the specified line length, but it's probably
3963             # better than doing nothing at all.
3964             break_idx = max_break_idx + 1
3965             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3966                 break_idx += 1
3967
3968             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3969                 return None
3970
3971         return break_idx
3972
3973     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3974         if self.normalize_strings:
3975             normalize_string_quotes(leaf)
3976
3977     def __normalize_f_string(self, string: str, prefix: str) -> str:
3978         """
3979         Pre-Conditions:
3980             * assert_is_leaf_string(@string)
3981
3982         Returns:
3983             * If @string is an f-string that contains no f-expressions, we
3984             return a string identical to @string except that the 'f' prefix
3985             has been stripped and all double braces (i.e. '{{' or '}}') have
3986             been normalized (i.e. turned into '{' or '}').
3987                 OR
3988             * Otherwise, we return @string.
3989         """
3990         assert_is_leaf_string(string)
3991
3992         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3993             new_prefix = prefix.replace("f", "")
3994
3995             temp = string[len(prefix) :]
3996             temp = re.sub(r"\{\{", "{", temp)
3997             temp = re.sub(r"\}\}", "}", temp)
3998             new_string = temp
3999
4000             return f"{new_prefix}{new_string}"
4001         else:
4002             return string
4003
4004
4005 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4006     """
4007     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4008     exist on lines by themselves).
4009
4010     Requirements:
4011         All of the requirements listed in BaseStringSplitter's docstring in
4012         addition to the requirements listed below:
4013
4014         * The line is a return/yield statement, which returns/yields a string.
4015             OR
4016         * The line is part of a ternary expression (e.g. `x = y if cond else
4017         z`) such that the line starts with `else <string>`, where <string> is
4018         some string.
4019             OR
4020         * The line is an assert statement, which ends with a string.
4021             OR
4022         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4023         <string>`) such that the variable is being assigned the value of some
4024         string.
4025             OR
4026         * The line is a dictionary key assignment where some valid key is being
4027         assigned the value of some string.
4028
4029     Transformations:
4030         The chosen string is wrapped in parentheses and then split at the LPAR.
4031
4032         We then have one line which ends with an LPAR and another line that
4033         starts with the chosen string. The latter line is then split again at
4034         the RPAR. This results in the RPAR (and possibly a trailing comma)
4035         being placed on its own line.
4036
4037         NOTE: If any leaves exist to the right of the chosen string (except
4038         for a trailing comma, which would be placed after the RPAR), those
4039         leaves are placed inside the parentheses.  In effect, the chosen
4040         string is not necessarily being "wrapped" by parentheses. We can,
4041         however, count on the LPAR being placed directly before the chosen
4042         string.
4043
4044         In other words, StringParenWrapper creates "atom" strings. These
4045         can then be split again by StringSplitter, if necessary.
4046
4047     Collaborations:
4048         In the event that a string line split by StringParenWrapper is
4049         changed such that it no longer needs to be given its own line,
4050         StringParenWrapper relies on StringParenStripper to clean up the
4051         parentheses it created.
4052     """
4053
4054     def do_splitter_match(self, line: Line) -> TMatchResult:
4055         LL = line.leaves
4056
4057         string_idx = None
4058         string_idx = string_idx or self._return_match(LL)
4059         string_idx = string_idx or self._else_match(LL)
4060         string_idx = string_idx or self._assert_match(LL)
4061         string_idx = string_idx or self._assign_match(LL)
4062         string_idx = string_idx or self._dict_match(LL)
4063
4064         if string_idx is not None:
4065             string_value = line.leaves[string_idx].value
4066             # If the string has no spaces...
4067             if " " not in string_value:
4068                 # And will still violate the line length limit when split...
4069                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4070                 if len(string_value) > max_string_length:
4071                     # And has no associated custom splits...
4072                     if not self.has_custom_splits(string_value):
4073                         # Then we should NOT put this string on its own line.
4074                         return TErr(
4075                             "We do not wrap long strings in parentheses when the"
4076                             " resultant line would still be over the specified line"
4077                             " length and can't be split further by StringSplitter."
4078                         )
4079             return Ok(string_idx)
4080
4081         return TErr("This line does not contain any non-atomic strings.")
4082
4083     @staticmethod
4084     def _return_match(LL: List[Leaf]) -> Optional[int]:
4085         """
4086         Returns:
4087             string_idx such that @LL[string_idx] is equal to our target (i.e.
4088             matched) string, if this line matches the return/yield statement
4089             requirements listed in the 'Requirements' section of this classes'
4090             docstring.
4091                 OR
4092             None, otherwise.
4093         """
4094         # If this line is apart of a return/yield statement and the first leaf
4095         # contains either the "return" or "yield" keywords...
4096         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4097             0
4098         ].value in ["return", "yield"]:
4099             is_valid_index = is_valid_index_factory(LL)
4100
4101             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4102             # The next visible leaf MUST contain a string...
4103             if is_valid_index(idx) and LL[idx].type == token.STRING:
4104                 return idx
4105
4106         return None
4107
4108     @staticmethod
4109     def _else_match(LL: List[Leaf]) -> Optional[int]:
4110         """
4111         Returns:
4112             string_idx such that @LL[string_idx] is equal to our target (i.e.
4113             matched) string, if this line matches the ternary expression
4114             requirements listed in the 'Requirements' section of this classes'
4115             docstring.
4116                 OR
4117             None, otherwise.
4118         """
4119         # If this line is apart of a ternary expression and the first leaf
4120         # contains the "else" keyword...
4121         if (
4122             parent_type(LL[0]) == syms.test
4123             and LL[0].type == token.NAME
4124             and LL[0].value == "else"
4125         ):
4126             is_valid_index = is_valid_index_factory(LL)
4127
4128             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4129             # The next visible leaf MUST contain a string...
4130             if is_valid_index(idx) and LL[idx].type == token.STRING:
4131                 return idx
4132
4133         return None
4134
4135     @staticmethod
4136     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4137         """
4138         Returns:
4139             string_idx such that @LL[string_idx] is equal to our target (i.e.
4140             matched) string, if this line matches the assert statement
4141             requirements listed in the 'Requirements' section of this classes'
4142             docstring.
4143                 OR
4144             None, otherwise.
4145         """
4146         # If this line is apart of an assert statement and the first leaf
4147         # contains the "assert" keyword...
4148         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4149             is_valid_index = is_valid_index_factory(LL)
4150
4151             for (i, leaf) in enumerate(LL):
4152                 # We MUST find a comma...
4153                 if leaf.type == token.COMMA:
4154                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4155
4156                     # That comma MUST be followed by a string...
4157                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4158                         string_idx = idx
4159
4160                         # Skip the string trailer, if one exists.
4161                         string_parser = StringParser()
4162                         idx = string_parser.parse(LL, string_idx)
4163
4164                         # But no more leaves are allowed...
4165                         if not is_valid_index(idx):
4166                             return string_idx
4167
4168         return None
4169
4170     @staticmethod
4171     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4172         """
4173         Returns:
4174             string_idx such that @LL[string_idx] is equal to our target (i.e.
4175             matched) string, if this line matches the assignment statement
4176             requirements listed in the 'Requirements' section of this classes'
4177             docstring.
4178                 OR
4179             None, otherwise.
4180         """
4181         # If this line is apart of an expression statement or is a function
4182         # argument AND the first leaf contains a variable name...
4183         if (
4184             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4185             and LL[0].type == token.NAME
4186         ):
4187             is_valid_index = is_valid_index_factory(LL)
4188
4189             for (i, leaf) in enumerate(LL):
4190                 # We MUST find either an '=' or '+=' symbol...
4191                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4192                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4193
4194                     # That symbol MUST be followed by a string...
4195                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4196                         string_idx = idx
4197
4198                         # Skip the string trailer, if one exists.
4199                         string_parser = StringParser()
4200                         idx = string_parser.parse(LL, string_idx)
4201
4202                         # The next leaf MAY be a comma iff this line is apart
4203                         # of a function argument...
4204                         if (
4205                             parent_type(LL[0]) == syms.argument
4206                             and is_valid_index(idx)
4207                             and LL[idx].type == token.COMMA
4208                         ):
4209                             idx += 1
4210
4211                         # But no more leaves are allowed...
4212                         if not is_valid_index(idx):
4213                             return string_idx
4214
4215         return None
4216
4217     @staticmethod
4218     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4219         """
4220         Returns:
4221             string_idx such that @LL[string_idx] is equal to our target (i.e.
4222             matched) string, if this line matches the dictionary key assignment
4223             statement requirements listed in the 'Requirements' section of this
4224             classes' docstring.
4225                 OR
4226             None, otherwise.
4227         """
4228         # If this line is apart of a dictionary key assignment...
4229         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4230             is_valid_index = is_valid_index_factory(LL)
4231
4232             for (i, leaf) in enumerate(LL):
4233                 # We MUST find a colon...
4234                 if leaf.type == token.COLON:
4235                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4236
4237                     # That colon MUST be followed by a string...
4238                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4239                         string_idx = idx
4240
4241                         # Skip the string trailer, if one exists.
4242                         string_parser = StringParser()
4243                         idx = string_parser.parse(LL, string_idx)
4244
4245                         # That string MAY be followed by a comma...
4246                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4247                             idx += 1
4248
4249                         # But no more leaves are allowed...
4250                         if not is_valid_index(idx):
4251                             return string_idx
4252
4253         return None
4254
4255     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4256         LL = line.leaves
4257
4258         is_valid_index = is_valid_index_factory(LL)
4259         insert_str_child = insert_str_child_factory(LL[string_idx])
4260
4261         comma_idx = len(LL) - 1
4262         ends_with_comma = False
4263         if LL[comma_idx].type == token.COMMA:
4264             ends_with_comma = True
4265
4266         leaves_to_steal_comments_from = [LL[string_idx]]
4267         if ends_with_comma:
4268             leaves_to_steal_comments_from.append(LL[comma_idx])
4269
4270         # --- First Line
4271         first_line = line.clone()
4272         left_leaves = LL[:string_idx]
4273
4274         # We have to remember to account for (possibly invisible) LPAR and RPAR
4275         # leaves that already wrapped the target string. If these leaves do
4276         # exist, we will replace them with our own LPAR and RPAR leaves.
4277         old_parens_exist = False
4278         if left_leaves and left_leaves[-1].type == token.LPAR:
4279             old_parens_exist = True
4280             leaves_to_steal_comments_from.append(left_leaves[-1])
4281             left_leaves.pop()
4282
4283         append_leaves(first_line, line, left_leaves)
4284
4285         lpar_leaf = Leaf(token.LPAR, "(")
4286         if old_parens_exist:
4287             replace_child(LL[string_idx - 1], lpar_leaf)
4288         else:
4289             insert_str_child(lpar_leaf)
4290         first_line.append(lpar_leaf)
4291
4292         # We throw inline comments that were originally to the right of the
4293         # target string to the top line. They will now be shown to the right of
4294         # the LPAR.
4295         for leaf in leaves_to_steal_comments_from:
4296             for comment_leaf in line.comments_after(leaf):
4297                 first_line.append(comment_leaf, preformatted=True)
4298
4299         yield Ok(first_line)
4300
4301         # --- Middle (String) Line
4302         # We only need to yield one (possibly too long) string line, since the
4303         # `StringSplitter` will break it down further if necessary.
4304         string_value = LL[string_idx].value
4305         string_line = Line(
4306             depth=line.depth + 1,
4307             inside_brackets=True,
4308             should_explode=line.should_explode,
4309         )
4310         string_leaf = Leaf(token.STRING, string_value)
4311         insert_str_child(string_leaf)
4312         string_line.append(string_leaf)
4313
4314         old_rpar_leaf = None
4315         if is_valid_index(string_idx + 1):
4316             right_leaves = LL[string_idx + 1 :]
4317             if ends_with_comma:
4318                 right_leaves.pop()
4319
4320             if old_parens_exist:
4321                 assert (
4322                     right_leaves and right_leaves[-1].type == token.RPAR
4323                 ), "Apparently, old parentheses do NOT exist?!"
4324                 old_rpar_leaf = right_leaves.pop()
4325
4326             append_leaves(string_line, line, right_leaves)
4327
4328         yield Ok(string_line)
4329
4330         # --- Last Line
4331         last_line = line.clone()
4332         last_line.bracket_tracker = first_line.bracket_tracker
4333
4334         new_rpar_leaf = Leaf(token.RPAR, ")")
4335         if old_rpar_leaf is not None:
4336             replace_child(old_rpar_leaf, new_rpar_leaf)
4337         else:
4338             insert_str_child(new_rpar_leaf)
4339         last_line.append(new_rpar_leaf)
4340
4341         # If the target string ended with a comma, we place this comma to the
4342         # right of the RPAR on the last line.
4343         if ends_with_comma:
4344             comma_leaf = Leaf(token.COMMA, ",")
4345             replace_child(LL[comma_idx], comma_leaf)
4346             last_line.append(comma_leaf)
4347
4348         yield Ok(last_line)
4349
4350
4351 class StringParser:
4352     """
4353     A state machine that aids in parsing a string's "trailer", which can be
4354     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4355     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4356     varY)`).
4357
4358     NOTE: A new StringParser object MUST be instantiated for each string
4359     trailer we need to parse.
4360
4361     Examples:
4362         We shall assume that `line` equals the `Line` object that corresponds
4363         to the following line of python code:
4364         ```
4365         x = "Some {}.".format("String") + some_other_string
4366         ```
4367
4368         Furthermore, we will assume that `string_idx` is some index such that:
4369         ```
4370         assert line.leaves[string_idx].value == "Some {}."
4371         ```
4372
4373         The following code snippet then holds:
4374         ```
4375         string_parser = StringParser()
4376         idx = string_parser.parse(line.leaves, string_idx)
4377         assert line.leaves[idx].type == token.PLUS
4378         ```
4379     """
4380
4381     DEFAULT_TOKEN = -1
4382
4383     # String Parser States
4384     START = 1
4385     DOT = 2
4386     NAME = 3
4387     PERCENT = 4
4388     SINGLE_FMT_ARG = 5
4389     LPAR = 6
4390     RPAR = 7
4391     DONE = 8
4392
4393     # Lookup Table for Next State
4394     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4395         # A string trailer may start with '.' OR '%'.
4396         (START, token.DOT): DOT,
4397         (START, token.PERCENT): PERCENT,
4398         (START, DEFAULT_TOKEN): DONE,
4399         # A '.' MUST be followed by an attribute or method name.
4400         (DOT, token.NAME): NAME,
4401         # A method name MUST be followed by an '(', whereas an attribute name
4402         # is the last symbol in the string trailer.
4403         (NAME, token.LPAR): LPAR,
4404         (NAME, DEFAULT_TOKEN): DONE,
4405         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4406         # string or variable name).
4407         (PERCENT, token.LPAR): LPAR,
4408         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4409         # If a '%' symbol is followed by a single argument, that argument is
4410         # the last leaf in the string trailer.
4411         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4412         # If present, a ')' symbol is the last symbol in a string trailer.
4413         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4414         # since they are treated as a special case by the parsing logic in this
4415         # classes' implementation.)
4416         (RPAR, DEFAULT_TOKEN): DONE,
4417     }
4418
4419     def __init__(self) -> None:
4420         self._state = self.START
4421         self._unmatched_lpars = 0
4422
4423     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4424         """
4425         Pre-conditions:
4426             * @leaves[@string_idx].type == token.STRING
4427
4428         Returns:
4429             The index directly after the last leaf which is apart of the string
4430             trailer, if a "trailer" exists.
4431                 OR
4432             @string_idx + 1, if no string "trailer" exists.
4433         """
4434         assert leaves[string_idx].type == token.STRING
4435
4436         idx = string_idx + 1
4437         while idx < len(leaves) and self._next_state(leaves[idx]):
4438             idx += 1
4439         return idx
4440
4441     def _next_state(self, leaf: Leaf) -> bool:
4442         """
4443         Pre-conditions:
4444             * On the first call to this function, @leaf MUST be the leaf that
4445             was directly after the string leaf in question (e.g. if our target
4446             string is `line.leaves[i]` then the first call to this method must
4447             be `line.leaves[i + 1]`).
4448             * On the next call to this function, the leaf parameter passed in
4449             MUST be the leaf directly following @leaf.
4450
4451         Returns:
4452             True iff @leaf is apart of the string's trailer.
4453         """
4454         # We ignore empty LPAR or RPAR leaves.
4455         if is_empty_par(leaf):
4456             return True
4457
4458         next_token = leaf.type
4459         if next_token == token.LPAR:
4460             self._unmatched_lpars += 1
4461
4462         current_state = self._state
4463
4464         # The LPAR parser state is a special case. We will return True until we
4465         # find the matching RPAR token.
4466         if current_state == self.LPAR:
4467             if next_token == token.RPAR:
4468                 self._unmatched_lpars -= 1
4469                 if self._unmatched_lpars == 0:
4470                     self._state = self.RPAR
4471         # Otherwise, we use a lookup table to determine the next state.
4472         else:
4473             # If the lookup table matches the current state to the next
4474             # token, we use the lookup table.
4475             if (current_state, next_token) in self._goto:
4476                 self._state = self._goto[current_state, next_token]
4477             else:
4478                 # Otherwise, we check if a the current state was assigned a
4479                 # default.
4480                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4481                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4482                 # If no default has been assigned, then this parser has a logic
4483                 # error.
4484                 else:
4485                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4486
4487             if self._state == self.DONE:
4488                 return False
4489
4490         return True
4491
4492
4493 def TErr(err_msg: str) -> Err[CannotTransform]:
4494     """(T)ransform Err
4495
4496     Convenience function used when working with the TResult type.
4497     """
4498     cant_transform = CannotTransform(err_msg)
4499     return Err(cant_transform)
4500
4501
4502 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4503     """
4504     Returns:
4505         True iff one of the comments in @comment_list is a pragma used by one
4506         of the more common static analysis tools for python (e.g. mypy, flake8,
4507         pylint).
4508     """
4509     for comment in comment_list:
4510         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4511             return True
4512
4513     return False
4514
4515
4516 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4517     """
4518     Factory for a convenience function that is used to orphan @string_leaf
4519     and then insert multiple new leaves into the same part of the node
4520     structure that @string_leaf had originally occupied.
4521
4522     Examples:
4523         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4524         string_leaf.parent`. Assume the node `N` has the following
4525         original structure:
4526
4527         Node(
4528             expr_stmt, [
4529                 Leaf(NAME, 'x'),
4530                 Leaf(EQUAL, '='),
4531                 Leaf(STRING, '"foo"'),
4532             ]
4533         )
4534
4535         We then run the code snippet shown below.
4536         ```
4537         insert_str_child = insert_str_child_factory(string_leaf)
4538
4539         lpar = Leaf(token.LPAR, '(')
4540         insert_str_child(lpar)
4541
4542         bar = Leaf(token.STRING, '"bar"')
4543         insert_str_child(bar)
4544
4545         rpar = Leaf(token.RPAR, ')')
4546         insert_str_child(rpar)
4547         ```
4548
4549         After which point, it follows that `string_leaf.parent is None` and
4550         the node `N` now has the following structure:
4551
4552         Node(
4553             expr_stmt, [
4554                 Leaf(NAME, 'x'),
4555                 Leaf(EQUAL, '='),
4556                 Leaf(LPAR, '('),
4557                 Leaf(STRING, '"bar"'),
4558                 Leaf(RPAR, ')'),
4559             ]
4560         )
4561     """
4562     string_parent = string_leaf.parent
4563     string_child_idx = string_leaf.remove()
4564
4565     def insert_str_child(child: LN) -> None:
4566         nonlocal string_child_idx
4567
4568         assert string_parent is not None
4569         assert string_child_idx is not None
4570
4571         string_parent.insert_child(string_child_idx, child)
4572         string_child_idx += 1
4573
4574     return insert_str_child
4575
4576
4577 def has_triple_quotes(string: str) -> bool:
4578     """
4579     Returns:
4580         True iff @string starts with three quotation characters.
4581     """
4582     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4583     return raw_string[:3] in {'"""', "'''"}
4584
4585
4586 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4587     """
4588     Returns:
4589         @node.parent.type, if @node is not None and has a parent.
4590             OR
4591         None, otherwise.
4592     """
4593     if node is None or node.parent is None:
4594         return None
4595
4596     return node.parent.type
4597
4598
4599 def is_empty_par(leaf: Leaf) -> bool:
4600     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4601
4602
4603 def is_empty_lpar(leaf: Leaf) -> bool:
4604     return leaf.type == token.LPAR and leaf.value == ""
4605
4606
4607 def is_empty_rpar(leaf: Leaf) -> bool:
4608     return leaf.type == token.RPAR and leaf.value == ""
4609
4610
4611 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4612     """
4613     Examples:
4614         ```
4615         my_list = [1, 2, 3]
4616
4617         is_valid_index = is_valid_index_factory(my_list)
4618
4619         assert is_valid_index(0)
4620         assert is_valid_index(2)
4621
4622         assert not is_valid_index(3)
4623         assert not is_valid_index(-1)
4624         ```
4625     """
4626
4627     def is_valid_index(idx: int) -> bool:
4628         """
4629         Returns:
4630             True iff @idx is positive AND seq[@idx] does NOT raise an
4631             IndexError.
4632         """
4633         return 0 <= idx < len(seq)
4634
4635     return is_valid_index
4636
4637
4638 def line_to_string(line: Line) -> str:
4639     """Returns the string representation of @line.
4640
4641     WARNING: This is known to be computationally expensive.
4642     """
4643     return str(line).strip("\n")
4644
4645
4646 def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
4647     """
4648     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4649     underlying Node structure where appropriate.
4650
4651     All of the leaves in @leaves are duplicated. The duplicates are then
4652     appended to @new_line and used to replace their originals in the underlying
4653     Node structure. Any comments attached to the old leaves are reattached to
4654     the new leaves.
4655
4656     Pre-conditions:
4657         set(@leaves) is a subset of set(@old_line.leaves).
4658     """
4659     for old_leaf in leaves:
4660         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4661         replace_child(old_leaf, new_leaf)
4662         new_line.append(new_leaf)
4663
4664         for comment_leaf in old_line.comments_after(old_leaf):
4665             new_line.append(comment_leaf, preformatted=True)
4666
4667
4668 def replace_child(old_child: LN, new_child: LN) -> None:
4669     """
4670     Side Effects:
4671         * If @old_child.parent is set, replace @old_child with @new_child in
4672         @old_child's underlying Node structure.
4673             OR
4674         * Otherwise, this function does nothing.
4675     """
4676     parent = old_child.parent
4677     if not parent:
4678         return
4679
4680     child_idx = old_child.remove()
4681     if child_idx is not None:
4682         parent.insert_child(child_idx, new_child)
4683
4684
4685 def get_string_prefix(string: str) -> str:
4686     """
4687     Pre-conditions:
4688         * assert_is_leaf_string(@string)
4689
4690     Returns:
4691         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4692     """
4693     assert_is_leaf_string(string)
4694
4695     prefix = ""
4696     prefix_idx = 0
4697     while string[prefix_idx] in STRING_PREFIX_CHARS:
4698         prefix += string[prefix_idx].lower()
4699         prefix_idx += 1
4700
4701     return prefix
4702
4703
4704 def assert_is_leaf_string(string: str) -> None:
4705     """
4706     Checks the pre-condition that @string has the format that you would expect
4707     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4708     token.STRING`. A more precise description of the pre-conditions that are
4709     checked are listed below.
4710
4711     Pre-conditions:
4712         * @string starts with either ', ", <prefix>', or <prefix>" where
4713         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4714         * @string ends with a quote character (' or ").
4715
4716     Raises:
4717         AssertionError(...) if the pre-conditions listed above are not
4718         satisfied.
4719     """
4720     dquote_idx = string.find('"')
4721     squote_idx = string.find("'")
4722     if -1 in [dquote_idx, squote_idx]:
4723         quote_idx = max(dquote_idx, squote_idx)
4724     else:
4725         quote_idx = min(squote_idx, dquote_idx)
4726
4727     assert (
4728         0 <= quote_idx < len(string) - 1
4729     ), f"{string!r} is missing a starting quote character (' or \")."
4730     assert string[-1] in (
4731         "'",
4732         '"',
4733     ), f"{string!r} is missing an ending quote character (' or \")."
4734     assert set(string[:quote_idx]).issubset(
4735         set(STRING_PREFIX_CHARS)
4736     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4737
4738
4739 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4740     """Split line into many lines, starting with the first matching bracket pair.
4741
4742     Note: this usually looks weird, only use this for function definitions.
4743     Prefer RHS otherwise.  This is why this function is not symmetrical with
4744     :func:`right_hand_split` which also handles optional parentheses.
4745     """
4746     tail_leaves: List[Leaf] = []
4747     body_leaves: List[Leaf] = []
4748     head_leaves: List[Leaf] = []
4749     current_leaves = head_leaves
4750     matching_bracket: Optional[Leaf] = None
4751     for leaf in line.leaves:
4752         if (
4753             current_leaves is body_leaves
4754             and leaf.type in CLOSING_BRACKETS
4755             and leaf.opening_bracket is matching_bracket
4756         ):
4757             current_leaves = tail_leaves if body_leaves else head_leaves
4758         current_leaves.append(leaf)
4759         if current_leaves is head_leaves:
4760             if leaf.type in OPENING_BRACKETS:
4761                 matching_bracket = leaf
4762                 current_leaves = body_leaves
4763     if not matching_bracket:
4764         raise CannotSplit("No brackets found")
4765
4766     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4767     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4768     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4769     bracket_split_succeeded_or_raise(head, body, tail)
4770     for result in (head, body, tail):
4771         if result:
4772             yield result
4773
4774
4775 def right_hand_split(
4776     line: Line,
4777     line_length: int,
4778     features: Collection[Feature] = (),
4779     omit: Collection[LeafID] = (),
4780 ) -> Iterator[Line]:
4781     """Split line into many lines, starting with the last matching bracket pair.
4782
4783     If the split was by optional parentheses, attempt splitting without them, too.
4784     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4785     this split.
4786
4787     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4788     """
4789     tail_leaves: List[Leaf] = []
4790     body_leaves: List[Leaf] = []
4791     head_leaves: List[Leaf] = []
4792     current_leaves = tail_leaves
4793     opening_bracket: Optional[Leaf] = None
4794     closing_bracket: Optional[Leaf] = None
4795     for leaf in reversed(line.leaves):
4796         if current_leaves is body_leaves:
4797             if leaf is opening_bracket:
4798                 current_leaves = head_leaves if body_leaves else tail_leaves
4799         current_leaves.append(leaf)
4800         if current_leaves is tail_leaves:
4801             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4802                 opening_bracket = leaf.opening_bracket
4803                 closing_bracket = leaf
4804                 current_leaves = body_leaves
4805     if not (opening_bracket and closing_bracket and head_leaves):
4806         # If there is no opening or closing_bracket that means the split failed and
4807         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4808         # the matching `opening_bracket` wasn't available on `line` anymore.
4809         raise CannotSplit("No brackets found")
4810
4811     tail_leaves.reverse()
4812     body_leaves.reverse()
4813     head_leaves.reverse()
4814     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4815     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4816     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4817     bracket_split_succeeded_or_raise(head, body, tail)
4818     if (
4819         # the body shouldn't be exploded
4820         not body.should_explode
4821         # the opening bracket is an optional paren
4822         and opening_bracket.type == token.LPAR
4823         and not opening_bracket.value
4824         # the closing bracket is an optional paren
4825         and closing_bracket.type == token.RPAR
4826         and not closing_bracket.value
4827         # it's not an import (optional parens are the only thing we can split on
4828         # in this case; attempting a split without them is a waste of time)
4829         and not line.is_import
4830         # there are no standalone comments in the body
4831         and not body.contains_standalone_comments(0)
4832         # and we can actually remove the parens
4833         and can_omit_invisible_parens(body, line_length)
4834     ):
4835         omit = {id(closing_bracket), *omit}
4836         try:
4837             yield from right_hand_split(line, line_length, features=features, omit=omit)
4838             return
4839
4840         except CannotSplit:
4841             if not (
4842                 can_be_split(body)
4843                 or is_line_short_enough(body, line_length=line_length)
4844             ):
4845                 raise CannotSplit(
4846                     "Splitting failed, body is still too long and can't be split."
4847                 )
4848
4849             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4850                 raise CannotSplit(
4851                     "The current optional pair of parentheses is bound to fail to"
4852                     " satisfy the splitting algorithm because the head or the tail"
4853                     " contains multiline strings which by definition never fit one"
4854                     " line."
4855                 )
4856
4857     ensure_visible(opening_bracket)
4858     ensure_visible(closing_bracket)
4859     for result in (head, body, tail):
4860         if result:
4861             yield result
4862
4863
4864 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4865     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4866
4867     Do nothing otherwise.
4868
4869     A left- or right-hand split is based on a pair of brackets. Content before
4870     (and including) the opening bracket is left on one line, content inside the
4871     brackets is put on a separate line, and finally content starting with and
4872     following the closing bracket is put on a separate line.
4873
4874     Those are called `head`, `body`, and `tail`, respectively. If the split
4875     produced the same line (all content in `head`) or ended up with an empty `body`
4876     and the `tail` is just the closing bracket, then it's considered failed.
4877     """
4878     tail_len = len(str(tail).strip())
4879     if not body:
4880         if tail_len == 0:
4881             raise CannotSplit("Splitting brackets produced the same line")
4882
4883         elif tail_len < 3:
4884             raise CannotSplit(
4885                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4886                 " not worth it"
4887             )
4888
4889
4890 def bracket_split_build_line(
4891     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4892 ) -> Line:
4893     """Return a new line with given `leaves` and respective comments from `original`.
4894
4895     If `is_body` is True, the result line is one-indented inside brackets and as such
4896     has its first leaf's prefix normalized and a trailing comma added when expected.
4897     """
4898     result = Line(depth=original.depth)
4899     if is_body:
4900         result.inside_brackets = True
4901         result.depth += 1
4902         if leaves:
4903             # Since body is a new indent level, remove spurious leading whitespace.
4904             normalize_prefix(leaves[0], inside_brackets=True)
4905             # Ensure a trailing comma for imports and standalone function arguments, but
4906             # be careful not to add one after any comments or within type annotations.
4907             no_commas = (
4908                 original.is_def
4909                 and opening_bracket.value == "("
4910                 and not any(leaf.type == token.COMMA for leaf in leaves)
4911             )
4912
4913             if original.is_import or no_commas:
4914                 for i in range(len(leaves) - 1, -1, -1):
4915                     if leaves[i].type == STANDALONE_COMMENT:
4916                         continue
4917
4918                     if leaves[i].type != token.COMMA:
4919                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
4920                     break
4921
4922     # Populate the line
4923     for leaf in leaves:
4924         result.append(leaf, preformatted=True)
4925         for comment_after in original.comments_after(leaf):
4926             result.append(comment_after, preformatted=True)
4927     if is_body:
4928         result.should_explode = should_explode(result, opening_bracket)
4929     return result
4930
4931
4932 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4933     """Normalize prefix of the first leaf in every line returned by `split_func`.
4934
4935     This is a decorator over relevant split functions.
4936     """
4937
4938     @wraps(split_func)
4939     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4940         for line in split_func(line, features):
4941             normalize_prefix(line.leaves[0], inside_brackets=True)
4942             yield line
4943
4944     return split_wrapper
4945
4946
4947 @dont_increase_indentation
4948 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4949     """Split according to delimiters of the highest priority.
4950
4951     If the appropriate Features are given, the split will add trailing commas
4952     also in function signatures and calls that contain `*` and `**`.
4953     """
4954     try:
4955         last_leaf = line.leaves[-1]
4956     except IndexError:
4957         raise CannotSplit("Line empty")
4958
4959     bt = line.bracket_tracker
4960     try:
4961         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4962     except ValueError:
4963         raise CannotSplit("No delimiters found")
4964
4965     if delimiter_priority == DOT_PRIORITY:
4966         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4967             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4968
4969     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4970     lowest_depth = sys.maxsize
4971     trailing_comma_safe = True
4972
4973     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4974         """Append `leaf` to current line or to new line if appending impossible."""
4975         nonlocal current_line
4976         try:
4977             current_line.append_safe(leaf, preformatted=True)
4978         except ValueError:
4979             yield current_line
4980
4981             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4982             current_line.append(leaf)
4983
4984     for leaf in line.leaves:
4985         yield from append_to_line(leaf)
4986
4987         for comment_after in line.comments_after(leaf):
4988             yield from append_to_line(comment_after)
4989
4990         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4991         if leaf.bracket_depth == lowest_depth:
4992             if is_vararg(leaf, within={syms.typedargslist}):
4993                 trailing_comma_safe = (
4994                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4995                 )
4996             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4997                 trailing_comma_safe = (
4998                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4999                 )
5000
5001         leaf_priority = bt.delimiters.get(id(leaf))
5002         if leaf_priority == delimiter_priority:
5003             yield current_line
5004
5005             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5006     if current_line:
5007         if (
5008             trailing_comma_safe
5009             and delimiter_priority == COMMA_PRIORITY
5010             and current_line.leaves[-1].type != token.COMMA
5011             and current_line.leaves[-1].type != STANDALONE_COMMENT
5012         ):
5013             current_line.append(Leaf(token.COMMA, ","))
5014         yield current_line
5015
5016
5017 @dont_increase_indentation
5018 def standalone_comment_split(
5019     line: Line, features: Collection[Feature] = ()
5020 ) -> Iterator[Line]:
5021     """Split standalone comments from the rest of the line."""
5022     if not line.contains_standalone_comments(0):
5023         raise CannotSplit("Line does not have any standalone comments")
5024
5025     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5026
5027     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5028         """Append `leaf` to current line or to new line if appending impossible."""
5029         nonlocal current_line
5030         try:
5031             current_line.append_safe(leaf, preformatted=True)
5032         except ValueError:
5033             yield current_line
5034
5035             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5036             current_line.append(leaf)
5037
5038     for leaf in line.leaves:
5039         yield from append_to_line(leaf)
5040
5041         for comment_after in line.comments_after(leaf):
5042             yield from append_to_line(comment_after)
5043
5044     if current_line:
5045         yield current_line
5046
5047
5048 def is_import(leaf: Leaf) -> bool:
5049     """Return True if the given leaf starts an import statement."""
5050     p = leaf.parent
5051     t = leaf.type
5052     v = leaf.value
5053     return bool(
5054         t == token.NAME
5055         and (
5056             (v == "import" and p and p.type == syms.import_name)
5057             or (v == "from" and p and p.type == syms.import_from)
5058         )
5059     )
5060
5061
5062 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5063     """Return True if the given leaf is a special comment.
5064     Only returns true for type comments for now."""
5065     t = leaf.type
5066     v = leaf.value
5067     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5068
5069
5070 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5071     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5072     else.
5073
5074     Note: don't use backslashes for formatting or you'll lose your voting rights.
5075     """
5076     if not inside_brackets:
5077         spl = leaf.prefix.split("#")
5078         if "\\" not in spl[0]:
5079             nl_count = spl[-1].count("\n")
5080             if len(spl) > 1:
5081                 nl_count -= 1
5082             leaf.prefix = "\n" * nl_count
5083             return
5084
5085     leaf.prefix = ""
5086
5087
5088 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5089     """Make all string prefixes lowercase.
5090
5091     If remove_u_prefix is given, also removes any u prefix from the string.
5092
5093     Note: Mutates its argument.
5094     """
5095     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5096     assert match is not None, f"failed to match string {leaf.value!r}"
5097     orig_prefix = match.group(1)
5098     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5099     if remove_u_prefix:
5100         new_prefix = new_prefix.replace("u", "")
5101     leaf.value = f"{new_prefix}{match.group(2)}"
5102
5103
5104 def normalize_string_quotes(leaf: Leaf) -> None:
5105     """Prefer double quotes but only if it doesn't cause more escaping.
5106
5107     Adds or removes backslashes as appropriate. Doesn't parse and fix
5108     strings nested in f-strings (yet).
5109
5110     Note: Mutates its argument.
5111     """
5112     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5113     if value[:3] == '"""':
5114         return
5115
5116     elif value[:3] == "'''":
5117         orig_quote = "'''"
5118         new_quote = '"""'
5119     elif value[0] == '"':
5120         orig_quote = '"'
5121         new_quote = "'"
5122     else:
5123         orig_quote = "'"
5124         new_quote = '"'
5125     first_quote_pos = leaf.value.find(orig_quote)
5126     if first_quote_pos == -1:
5127         return  # There's an internal error
5128
5129     prefix = leaf.value[:first_quote_pos]
5130     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5131     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5132     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5133     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5134     if "r" in prefix.casefold():
5135         if unescaped_new_quote.search(body):
5136             # There's at least one unescaped new_quote in this raw string
5137             # so converting is impossible
5138             return
5139
5140         # Do not introduce or remove backslashes in raw strings
5141         new_body = body
5142     else:
5143         # remove unnecessary escapes
5144         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5145         if body != new_body:
5146             # Consider the string without unnecessary escapes as the original
5147             body = new_body
5148             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5149         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5150         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5151     if "f" in prefix.casefold():
5152         matches = re.findall(
5153             r"""
5154             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5155                 ([^{].*?)  # contents of the brackets except if begins with {{
5156             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5157             """,
5158             new_body,
5159             re.VERBOSE,
5160         )
5161         for m in matches:
5162             if "\\" in str(m):
5163                 # Do not introduce backslashes in interpolated expressions
5164                 return
5165
5166     if new_quote == '"""' and new_body[-1:] == '"':
5167         # edge case:
5168         new_body = new_body[:-1] + '\\"'
5169     orig_escape_count = body.count("\\")
5170     new_escape_count = new_body.count("\\")
5171     if new_escape_count > orig_escape_count:
5172         return  # Do not introduce more escaping
5173
5174     if new_escape_count == orig_escape_count and orig_quote == '"':
5175         return  # Prefer double quotes
5176
5177     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5178
5179
5180 def normalize_numeric_literal(leaf: Leaf) -> None:
5181     """Normalizes numeric (float, int, and complex) literals.
5182
5183     All letters used in the representation are normalized to lowercase (except
5184     in Python 2 long literals).
5185     """
5186     text = leaf.value.lower()
5187     if text.startswith(("0o", "0b")):
5188         # Leave octal and binary literals alone.
5189         pass
5190     elif text.startswith("0x"):
5191         # Change hex literals to upper case.
5192         before, after = text[:2], text[2:]
5193         text = f"{before}{after.upper()}"
5194     elif "e" in text:
5195         before, after = text.split("e")
5196         sign = ""
5197         if after.startswith("-"):
5198             after = after[1:]
5199             sign = "-"
5200         elif after.startswith("+"):
5201             after = after[1:]
5202         before = format_float_or_int_string(before)
5203         text = f"{before}e{sign}{after}"
5204     elif text.endswith(("j", "l")):
5205         number = text[:-1]
5206         suffix = text[-1]
5207         # Capitalize in "2L" because "l" looks too similar to "1".
5208         if suffix == "l":
5209             suffix = "L"
5210         text = f"{format_float_or_int_string(number)}{suffix}"
5211     else:
5212         text = format_float_or_int_string(text)
5213     leaf.value = text
5214
5215
5216 def format_float_or_int_string(text: str) -> str:
5217     """Formats a float string like "1.0"."""
5218     if "." not in text:
5219         return text
5220
5221     before, after = text.split(".")
5222     return f"{before or 0}.{after or 0}"
5223
5224
5225 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5226     """Make existing optional parentheses invisible or create new ones.
5227
5228     `parens_after` is a set of string leaf values immediately after which parens
5229     should be put.
5230
5231     Standardizes on visible parentheses for single-element tuples, and keeps
5232     existing visible parentheses for other tuples and generator expressions.
5233     """
5234     for pc in list_comments(node.prefix, is_endmarker=False):
5235         if pc.value in FMT_OFF:
5236             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5237             return
5238     check_lpar = False
5239     for index, child in enumerate(list(node.children)):
5240         # Fixes a bug where invisible parens are not properly stripped from
5241         # assignment statements that contain type annotations.
5242         if isinstance(child, Node) and child.type == syms.annassign:
5243             normalize_invisible_parens(child, parens_after=parens_after)
5244
5245         # Add parentheses around long tuple unpacking in assignments.
5246         if (
5247             index == 0
5248             and isinstance(child, Node)
5249             and child.type == syms.testlist_star_expr
5250         ):
5251             check_lpar = True
5252
5253         if check_lpar:
5254             if is_walrus_assignment(child):
5255                 continue
5256
5257             if child.type == syms.atom:
5258                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5259                     wrap_in_parentheses(node, child, visible=False)
5260             elif is_one_tuple(child):
5261                 wrap_in_parentheses(node, child, visible=True)
5262             elif node.type == syms.import_from:
5263                 # "import from" nodes store parentheses directly as part of
5264                 # the statement
5265                 if child.type == token.LPAR:
5266                     # make parentheses invisible
5267                     child.value = ""  # type: ignore
5268                     node.children[-1].value = ""  # type: ignore
5269                 elif child.type != token.STAR:
5270                     # insert invisible parentheses
5271                     node.insert_child(index, Leaf(token.LPAR, ""))
5272                     node.append_child(Leaf(token.RPAR, ""))
5273                 break
5274
5275             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5276                 wrap_in_parentheses(node, child, visible=False)
5277
5278         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5279
5280
5281 def normalize_fmt_off(node: Node) -> None:
5282     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5283     try_again = True
5284     while try_again:
5285         try_again = convert_one_fmt_off_pair(node)
5286
5287
5288 def convert_one_fmt_off_pair(node: Node) -> bool:
5289     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5290
5291     Returns True if a pair was converted.
5292     """
5293     for leaf in node.leaves():
5294         previous_consumed = 0
5295         for comment in list_comments(leaf.prefix, is_endmarker=False):
5296             if comment.value in FMT_OFF:
5297                 # We only want standalone comments. If there's no previous leaf or
5298                 # the previous leaf is indentation, it's a standalone comment in
5299                 # disguise.
5300                 if comment.type != STANDALONE_COMMENT:
5301                     prev = preceding_leaf(leaf)
5302                     if prev and prev.type not in WHITESPACE:
5303                         continue
5304
5305                 ignored_nodes = list(generate_ignored_nodes(leaf))
5306                 if not ignored_nodes:
5307                     continue
5308
5309                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5310                 parent = first.parent
5311                 prefix = first.prefix
5312                 first.prefix = prefix[comment.consumed :]
5313                 hidden_value = (
5314                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5315                 )
5316                 if hidden_value.endswith("\n"):
5317                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5318                     # leaf (possibly followed by a DEDENT).
5319                     hidden_value = hidden_value[:-1]
5320                 first_idx: Optional[int] = None
5321                 for ignored in ignored_nodes:
5322                     index = ignored.remove()
5323                     if first_idx is None:
5324                         first_idx = index
5325                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5326                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5327                 parent.insert_child(
5328                     first_idx,
5329                     Leaf(
5330                         STANDALONE_COMMENT,
5331                         hidden_value,
5332                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5333                     ),
5334                 )
5335                 return True
5336
5337             previous_consumed = comment.consumed
5338
5339     return False
5340
5341
5342 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5343     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5344
5345     Stops at the end of the block.
5346     """
5347     container: Optional[LN] = container_of(leaf)
5348     while container is not None and container.type != token.ENDMARKER:
5349         if is_fmt_on(container):
5350             return
5351
5352         # fix for fmt: on in children
5353         if contains_fmt_on_at_column(container, leaf.column):
5354             for child in container.children:
5355                 if contains_fmt_on_at_column(child, leaf.column):
5356                     return
5357                 yield child
5358         else:
5359             yield container
5360             container = container.next_sibling
5361
5362
5363 def is_fmt_on(container: LN) -> bool:
5364     """Determine whether formatting is switched on within a container.
5365     Determined by whether the last `# fmt:` comment is `on` or `off`.
5366     """
5367     fmt_on = False
5368     for comment in list_comments(container.prefix, is_endmarker=False):
5369         if comment.value in FMT_ON:
5370             fmt_on = True
5371         elif comment.value in FMT_OFF:
5372             fmt_on = False
5373     return fmt_on
5374
5375
5376 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5377     """Determine if children at a given column have formatting switched on."""
5378     for child in container.children:
5379         if (
5380             isinstance(child, Node)
5381             and first_leaf_column(child) == column
5382             or isinstance(child, Leaf)
5383             and child.column == column
5384         ):
5385             if is_fmt_on(child):
5386                 return True
5387
5388     return False
5389
5390
5391 def first_leaf_column(node: Node) -> Optional[int]:
5392     """Returns the column of the first leaf child of a node."""
5393     for child in node.children:
5394         if isinstance(child, Leaf):
5395             return child.column
5396     return None
5397
5398
5399 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5400     """If it's safe, make the parens in the atom `node` invisible, recursively.
5401     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5402     as they are redundant.
5403
5404     Returns whether the node should itself be wrapped in invisible parentheses.
5405
5406     """
5407     if (
5408         node.type != syms.atom
5409         or is_empty_tuple(node)
5410         or is_one_tuple(node)
5411         or (is_yield(node) and parent.type != syms.expr_stmt)
5412         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5413     ):
5414         return False
5415
5416     first = node.children[0]
5417     last = node.children[-1]
5418     if first.type == token.LPAR and last.type == token.RPAR:
5419         middle = node.children[1]
5420         # make parentheses invisible
5421         first.value = ""  # type: ignore
5422         last.value = ""  # type: ignore
5423         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5424
5425         if is_atom_with_invisible_parens(middle):
5426             # Strip the invisible parens from `middle` by replacing
5427             # it with the child in-between the invisible parens
5428             middle.replace(middle.children[1])
5429
5430         return False
5431
5432     return True
5433
5434
5435 def is_atom_with_invisible_parens(node: LN) -> bool:
5436     """Given a `LN`, determines whether it's an atom `node` with invisible
5437     parens. Useful in dedupe-ing and normalizing parens.
5438     """
5439     if isinstance(node, Leaf) or node.type != syms.atom:
5440         return False
5441
5442     first, last = node.children[0], node.children[-1]
5443     return (
5444         isinstance(first, Leaf)
5445         and first.type == token.LPAR
5446         and first.value == ""
5447         and isinstance(last, Leaf)
5448         and last.type == token.RPAR
5449         and last.value == ""
5450     )
5451
5452
5453 def is_empty_tuple(node: LN) -> bool:
5454     """Return True if `node` holds an empty tuple."""
5455     return (
5456         node.type == syms.atom
5457         and len(node.children) == 2
5458         and node.children[0].type == token.LPAR
5459         and node.children[1].type == token.RPAR
5460     )
5461
5462
5463 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5464     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5465
5466     Parenthesis can be optional. Returns None otherwise"""
5467     if len(node.children) != 3:
5468         return None
5469
5470     lpar, wrapped, rpar = node.children
5471     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5472         return None
5473
5474     return wrapped
5475
5476
5477 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5478     """Wrap `child` in parentheses.
5479
5480     This replaces `child` with an atom holding the parentheses and the old
5481     child.  That requires moving the prefix.
5482
5483     If `visible` is False, the leaves will be valueless (and thus invisible).
5484     """
5485     lpar = Leaf(token.LPAR, "(" if visible else "")
5486     rpar = Leaf(token.RPAR, ")" if visible else "")
5487     prefix = child.prefix
5488     child.prefix = ""
5489     index = child.remove() or 0
5490     new_child = Node(syms.atom, [lpar, child, rpar])
5491     new_child.prefix = prefix
5492     parent.insert_child(index, new_child)
5493
5494
5495 def is_one_tuple(node: LN) -> bool:
5496     """Return True if `node` holds a tuple with one element, with or without parens."""
5497     if node.type == syms.atom:
5498         gexp = unwrap_singleton_parenthesis(node)
5499         if gexp is None or gexp.type != syms.testlist_gexp:
5500             return False
5501
5502         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5503
5504     return (
5505         node.type in IMPLICIT_TUPLE
5506         and len(node.children) == 2
5507         and node.children[1].type == token.COMMA
5508     )
5509
5510
5511 def is_walrus_assignment(node: LN) -> bool:
5512     """Return True iff `node` is of the shape ( test := test )"""
5513     inner = unwrap_singleton_parenthesis(node)
5514     return inner is not None and inner.type == syms.namedexpr_test
5515
5516
5517 def is_yield(node: LN) -> bool:
5518     """Return True if `node` holds a `yield` or `yield from` expression."""
5519     if node.type == syms.yield_expr:
5520         return True
5521
5522     if node.type == token.NAME and node.value == "yield":  # type: ignore
5523         return True
5524
5525     if node.type != syms.atom:
5526         return False
5527
5528     if len(node.children) != 3:
5529         return False
5530
5531     lpar, expr, rpar = node.children
5532     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5533         return is_yield(expr)
5534
5535     return False
5536
5537
5538 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5539     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5540
5541     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5542     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5543     extended iterable unpacking (PEP 3132) and additional unpacking
5544     generalizations (PEP 448).
5545     """
5546     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5547         return False
5548
5549     p = leaf.parent
5550     if p.type == syms.star_expr:
5551         # Star expressions are also used as assignment targets in extended
5552         # iterable unpacking (PEP 3132).  See what its parent is instead.
5553         if not p.parent:
5554             return False
5555
5556         p = p.parent
5557
5558     return p.type in within
5559
5560
5561 def is_multiline_string(leaf: Leaf) -> bool:
5562     """Return True if `leaf` is a multiline string that actually spans many lines."""
5563     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5564
5565
5566 def is_stub_suite(node: Node) -> bool:
5567     """Return True if `node` is a suite with a stub body."""
5568     if (
5569         len(node.children) != 4
5570         or node.children[0].type != token.NEWLINE
5571         or node.children[1].type != token.INDENT
5572         or node.children[3].type != token.DEDENT
5573     ):
5574         return False
5575
5576     return is_stub_body(node.children[2])
5577
5578
5579 def is_stub_body(node: LN) -> bool:
5580     """Return True if `node` is a simple statement containing an ellipsis."""
5581     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5582         return False
5583
5584     if len(node.children) != 2:
5585         return False
5586
5587     child = node.children[0]
5588     return (
5589         child.type == syms.atom
5590         and len(child.children) == 3
5591         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5592     )
5593
5594
5595 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5596     """Return maximum delimiter priority inside `node`.
5597
5598     This is specific to atoms with contents contained in a pair of parentheses.
5599     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5600     """
5601     if node.type != syms.atom:
5602         return 0
5603
5604     first = node.children[0]
5605     last = node.children[-1]
5606     if not (first.type == token.LPAR and last.type == token.RPAR):
5607         return 0
5608
5609     bt = BracketTracker()
5610     for c in node.children[1:-1]:
5611         if isinstance(c, Leaf):
5612             bt.mark(c)
5613         else:
5614             for leaf in c.leaves():
5615                 bt.mark(leaf)
5616     try:
5617         return bt.max_delimiter_priority()
5618
5619     except ValueError:
5620         return 0
5621
5622
5623 def ensure_visible(leaf: Leaf) -> None:
5624     """Make sure parentheses are visible.
5625
5626     They could be invisible as part of some statements (see
5627     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5628     """
5629     if leaf.type == token.LPAR:
5630         leaf.value = "("
5631     elif leaf.type == token.RPAR:
5632         leaf.value = ")"
5633
5634
5635 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
5636     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
5637
5638     if not (
5639         opening_bracket.parent
5640         and opening_bracket.parent.type in {syms.atom, syms.import_from}
5641         and opening_bracket.value in "[{("
5642     ):
5643         return False
5644
5645     try:
5646         last_leaf = line.leaves[-1]
5647         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
5648         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5649     except (IndexError, ValueError):
5650         return False
5651
5652     return max_priority == COMMA_PRIORITY
5653
5654
5655 def get_features_used(node: Node) -> Set[Feature]:
5656     """Return a set of (relatively) new Python features used in this file.
5657
5658     Currently looking for:
5659     - f-strings;
5660     - underscores in numeric literals;
5661     - trailing commas after * or ** in function signatures and calls;
5662     - positional only arguments in function signatures and lambdas;
5663     """
5664     features: Set[Feature] = set()
5665     for n in node.pre_order():
5666         if n.type == token.STRING:
5667             value_head = n.value[:2]  # type: ignore
5668             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5669                 features.add(Feature.F_STRINGS)
5670
5671         elif n.type == token.NUMBER:
5672             if "_" in n.value:  # type: ignore
5673                 features.add(Feature.NUMERIC_UNDERSCORES)
5674
5675         elif n.type == token.SLASH:
5676             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5677                 features.add(Feature.POS_ONLY_ARGUMENTS)
5678
5679         elif n.type == token.COLONEQUAL:
5680             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5681
5682         elif (
5683             n.type in {syms.typedargslist, syms.arglist}
5684             and n.children
5685             and n.children[-1].type == token.COMMA
5686         ):
5687             if n.type == syms.typedargslist:
5688                 feature = Feature.TRAILING_COMMA_IN_DEF
5689             else:
5690                 feature = Feature.TRAILING_COMMA_IN_CALL
5691
5692             for ch in n.children:
5693                 if ch.type in STARS:
5694                     features.add(feature)
5695
5696                 if ch.type == syms.argument:
5697                     for argch in ch.children:
5698                         if argch.type in STARS:
5699                             features.add(feature)
5700
5701     return features
5702
5703
5704 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5705     """Detect the version to target based on the nodes used."""
5706     features = get_features_used(node)
5707     return {
5708         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5709     }
5710
5711
5712 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5713     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5714
5715     Brackets can be omitted if the entire trailer up to and including
5716     a preceding closing bracket fits in one line.
5717
5718     Yielded sets are cumulative (contain results of previous yields, too).  First
5719     set is empty.
5720     """
5721
5722     omit: Set[LeafID] = set()
5723     yield omit
5724
5725     length = 4 * line.depth
5726     opening_bracket: Optional[Leaf] = None
5727     closing_bracket: Optional[Leaf] = None
5728     inner_brackets: Set[LeafID] = set()
5729     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5730         length += leaf_length
5731         if length > line_length:
5732             break
5733
5734         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5735         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5736             break
5737
5738         if opening_bracket:
5739             if leaf is opening_bracket:
5740                 opening_bracket = None
5741             elif leaf.type in CLOSING_BRACKETS:
5742                 inner_brackets.add(id(leaf))
5743         elif leaf.type in CLOSING_BRACKETS:
5744             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
5745                 # Empty brackets would fail a split so treat them as "inner"
5746                 # brackets (e.g. only add them to the `omit` set if another
5747                 # pair of brackets was good enough.
5748                 inner_brackets.add(id(leaf))
5749                 continue
5750
5751             if closing_bracket:
5752                 omit.add(id(closing_bracket))
5753                 omit.update(inner_brackets)
5754                 inner_brackets.clear()
5755                 yield omit
5756
5757             if leaf.value:
5758                 opening_bracket = leaf.opening_bracket
5759                 closing_bracket = leaf
5760
5761
5762 def get_future_imports(node: Node) -> Set[str]:
5763     """Return a set of __future__ imports in the file."""
5764     imports: Set[str] = set()
5765
5766     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5767         for child in children:
5768             if isinstance(child, Leaf):
5769                 if child.type == token.NAME:
5770                     yield child.value
5771
5772             elif child.type == syms.import_as_name:
5773                 orig_name = child.children[0]
5774                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5775                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5776                 yield orig_name.value
5777
5778             elif child.type == syms.import_as_names:
5779                 yield from get_imports_from_children(child.children)
5780
5781             else:
5782                 raise AssertionError("Invalid syntax parsing imports")
5783
5784     for child in node.children:
5785         if child.type != syms.simple_stmt:
5786             break
5787
5788         first_child = child.children[0]
5789         if isinstance(first_child, Leaf):
5790             # Continue looking if we see a docstring; otherwise stop.
5791             if (
5792                 len(child.children) == 2
5793                 and first_child.type == token.STRING
5794                 and child.children[1].type == token.NEWLINE
5795             ):
5796                 continue
5797
5798             break
5799
5800         elif first_child.type == syms.import_from:
5801             module_name = first_child.children[1]
5802             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5803                 break
5804
5805             imports |= set(get_imports_from_children(first_child.children[3:]))
5806         else:
5807             break
5808
5809     return imports
5810
5811
5812 @lru_cache()
5813 def get_gitignore(root: Path) -> PathSpec:
5814     """ Return a PathSpec matching gitignore content if present."""
5815     gitignore = root / ".gitignore"
5816     lines: List[str] = []
5817     if gitignore.is_file():
5818         with gitignore.open() as gf:
5819             lines = gf.readlines()
5820     return PathSpec.from_lines("gitwildmatch", lines)
5821
5822
5823 def normalize_path_maybe_ignore(
5824     path: Path, root: Path, report: "Report"
5825 ) -> Optional[str]:
5826     """Normalize `path`. May return `None` if `path` was ignored.
5827
5828     `report` is where "path ignored" output goes.
5829     """
5830     try:
5831         normalized_path = path.resolve().relative_to(root).as_posix()
5832     except OSError as e:
5833         report.path_ignored(path, f"cannot be read because {e}")
5834         return None
5835
5836     except ValueError:
5837         if path.is_symlink():
5838             report.path_ignored(path, f"is a symbolic link that points outside {root}")
5839             return None
5840
5841         raise
5842
5843     return normalized_path
5844
5845
5846 def gen_python_files(
5847     paths: Iterable[Path],
5848     root: Path,
5849     include: Optional[Pattern[str]],
5850     exclude: Pattern[str],
5851     force_exclude: Optional[Pattern[str]],
5852     report: "Report",
5853     gitignore: PathSpec,
5854 ) -> Iterator[Path]:
5855     """Generate all files under `path` whose paths are not excluded by the
5856     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
5857
5858     Symbolic links pointing outside of the `root` directory are ignored.
5859
5860     `report` is where output about exclusions goes.
5861     """
5862     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5863     for child in paths:
5864         normalized_path = normalize_path_maybe_ignore(child, root, report)
5865         if normalized_path is None:
5866             continue
5867
5868         # First ignore files matching .gitignore
5869         if gitignore.match_file(normalized_path):
5870             report.path_ignored(child, "matches the .gitignore file content")
5871             continue
5872
5873         # Then ignore with `--exclude` and `--force-exclude` options.
5874         normalized_path = "/" + normalized_path
5875         if child.is_dir():
5876             normalized_path += "/"
5877
5878         exclude_match = exclude.search(normalized_path) if exclude else None
5879         if exclude_match and exclude_match.group(0):
5880             report.path_ignored(child, "matches the --exclude regular expression")
5881             continue
5882
5883         force_exclude_match = (
5884             force_exclude.search(normalized_path) if force_exclude else None
5885         )
5886         if force_exclude_match and force_exclude_match.group(0):
5887             report.path_ignored(child, "matches the --force-exclude regular expression")
5888             continue
5889
5890         if child.is_dir():
5891             yield from gen_python_files(
5892                 child.iterdir(),
5893                 root,
5894                 include,
5895                 exclude,
5896                 force_exclude,
5897                 report,
5898                 gitignore,
5899             )
5900
5901         elif child.is_file():
5902             include_match = include.search(normalized_path) if include else True
5903             if include_match:
5904                 yield child
5905
5906
5907 @lru_cache()
5908 def find_project_root(srcs: Iterable[str]) -> Path:
5909     """Return a directory containing .git, .hg, or pyproject.toml.
5910
5911     That directory will be a common parent of all files and directories
5912     passed in `srcs`.
5913
5914     If no directory in the tree contains a marker that would specify it's the
5915     project root, the root of the file system is returned.
5916     """
5917     if not srcs:
5918         return Path("/").resolve()
5919
5920     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
5921
5922     # A list of lists of parents for each 'src'. 'src' is included as a
5923     # "parent" of itself if it is a directory
5924     src_parents = [
5925         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
5926     ]
5927
5928     common_base = max(
5929         set.intersection(*(set(parents) for parents in src_parents)),
5930         key=lambda path: path.parts,
5931     )
5932
5933     for directory in (common_base, *common_base.parents):
5934         if (directory / ".git").exists():
5935             return directory
5936
5937         if (directory / ".hg").is_dir():
5938             return directory
5939
5940         if (directory / "pyproject.toml").is_file():
5941             return directory
5942
5943     return directory
5944
5945
5946 @dataclass
5947 class Report:
5948     """Provides a reformatting counter. Can be rendered with `str(report)`."""
5949
5950     check: bool = False
5951     diff: bool = False
5952     quiet: bool = False
5953     verbose: bool = False
5954     change_count: int = 0
5955     same_count: int = 0
5956     failure_count: int = 0
5957
5958     def done(self, src: Path, changed: Changed) -> None:
5959         """Increment the counter for successful reformatting. Write out a message."""
5960         if changed is Changed.YES:
5961             reformatted = "would reformat" if self.check or self.diff else "reformatted"
5962             if self.verbose or not self.quiet:
5963                 out(f"{reformatted} {src}")
5964             self.change_count += 1
5965         else:
5966             if self.verbose:
5967                 if changed is Changed.NO:
5968                     msg = f"{src} already well formatted, good job."
5969                 else:
5970                     msg = f"{src} wasn't modified on disk since last run."
5971                 out(msg, bold=False)
5972             self.same_count += 1
5973
5974     def failed(self, src: Path, message: str) -> None:
5975         """Increment the counter for failed reformatting. Write out a message."""
5976         err(f"error: cannot format {src}: {message}")
5977         self.failure_count += 1
5978
5979     def path_ignored(self, path: Path, message: str) -> None:
5980         if self.verbose:
5981             out(f"{path} ignored: {message}", bold=False)
5982
5983     @property
5984     def return_code(self) -> int:
5985         """Return the exit code that the app should use.
5986
5987         This considers the current state of changed files and failures:
5988         - if there were any failures, return 123;
5989         - if any files were changed and --check is being used, return 1;
5990         - otherwise return 0.
5991         """
5992         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
5993         # 126 we have special return codes reserved by the shell.
5994         if self.failure_count:
5995             return 123
5996
5997         elif self.change_count and self.check:
5998             return 1
5999
6000         return 0
6001
6002     def __str__(self) -> str:
6003         """Render a color report of the current state.
6004
6005         Use `click.unstyle` to remove colors.
6006         """
6007         if self.check or self.diff:
6008             reformatted = "would be reformatted"
6009             unchanged = "would be left unchanged"
6010             failed = "would fail to reformat"
6011         else:
6012             reformatted = "reformatted"
6013             unchanged = "left unchanged"
6014             failed = "failed to reformat"
6015         report = []
6016         if self.change_count:
6017             s = "s" if self.change_count > 1 else ""
6018             report.append(
6019                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6020             )
6021         if self.same_count:
6022             s = "s" if self.same_count > 1 else ""
6023             report.append(f"{self.same_count} file{s} {unchanged}")
6024         if self.failure_count:
6025             s = "s" if self.failure_count > 1 else ""
6026             report.append(
6027                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6028             )
6029         return ", ".join(report) + "."
6030
6031
6032 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6033     filename = "<unknown>"
6034     if sys.version_info >= (3, 8):
6035         # TODO: support Python 4+ ;)
6036         for minor_version in range(sys.version_info[1], 4, -1):
6037             try:
6038                 return ast.parse(src, filename, feature_version=(3, minor_version))
6039             except SyntaxError:
6040                 continue
6041     else:
6042         for feature_version in (7, 6):
6043             try:
6044                 return ast3.parse(src, filename, feature_version=feature_version)
6045             except SyntaxError:
6046                 continue
6047
6048     return ast27.parse(src)
6049
6050
6051 def _fixup_ast_constants(
6052     node: Union[ast.AST, ast3.AST, ast27.AST]
6053 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6054     """Map ast nodes deprecated in 3.8 to Constant."""
6055     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6056         return ast.Constant(value=node.s)
6057
6058     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6059         return ast.Constant(value=node.n)
6060
6061     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6062         return ast.Constant(value=node.value)
6063
6064     return node
6065
6066
6067 def _stringify_ast(
6068     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6069 ) -> Iterator[str]:
6070     """Simple visitor generating strings to compare ASTs by content."""
6071
6072     node = _fixup_ast_constants(node)
6073
6074     yield f"{'  ' * depth}{node.__class__.__name__}("
6075
6076     for field in sorted(node._fields):  # noqa: F402
6077         # TypeIgnore has only one field 'lineno' which breaks this comparison
6078         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6079         if sys.version_info >= (3, 8):
6080             type_ignore_classes += (ast.TypeIgnore,)
6081         if isinstance(node, type_ignore_classes):
6082             break
6083
6084         try:
6085             value = getattr(node, field)
6086         except AttributeError:
6087             continue
6088
6089         yield f"{'  ' * (depth+1)}{field}="
6090
6091         if isinstance(value, list):
6092             for item in value:
6093                 # Ignore nested tuples within del statements, because we may insert
6094                 # parentheses and they change the AST.
6095                 if (
6096                     field == "targets"
6097                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6098                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6099                 ):
6100                     for item in item.elts:
6101                         yield from _stringify_ast(item, depth + 2)
6102
6103                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6104                     yield from _stringify_ast(item, depth + 2)
6105
6106         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6107             yield from _stringify_ast(value, depth + 2)
6108
6109         else:
6110             # Constant strings may be indented across newlines, if they are
6111             # docstrings; fold spaces after newlines when comparing. Similarly,
6112             # trailing and leading space may be removed.
6113             if (
6114                 isinstance(node, ast.Constant)
6115                 and field == "value"
6116                 and isinstance(value, str)
6117             ):
6118                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6119             else:
6120                 normalized = value
6121             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6122
6123     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6124
6125
6126 def assert_equivalent(src: str, dst: str) -> None:
6127     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6128     try:
6129         src_ast = parse_ast(src)
6130     except Exception as exc:
6131         raise AssertionError(
6132             "cannot use --safe with this file; failed to parse source file.  AST"
6133             f" error message: {exc}"
6134         )
6135
6136     try:
6137         dst_ast = parse_ast(dst)
6138     except Exception as exc:
6139         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6140         raise AssertionError(
6141             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6142             " on https://github.com/psf/black/issues.  This invalid output might be"
6143             f" helpful: {log}"
6144         ) from None
6145
6146     src_ast_str = "\n".join(_stringify_ast(src_ast))
6147     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6148     if src_ast_str != dst_ast_str:
6149         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6150         raise AssertionError(
6151             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6152             " source.  Please report a bug on https://github.com/psf/black/issues. "
6153             f" This diff might be helpful: {log}"
6154         ) from None
6155
6156
6157 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6158     """Raise AssertionError if `dst` reformats differently the second time."""
6159     newdst = format_str(dst, mode=mode)
6160     if dst != newdst:
6161         log = dump_to_file(
6162             diff(src, dst, "source", "first pass"),
6163             diff(dst, newdst, "first pass", "second pass"),
6164         )
6165         raise AssertionError(
6166             "INTERNAL ERROR: Black produced different code on the second pass of the"
6167             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6168             f"  This diff might be helpful: {log}"
6169         ) from None
6170
6171
6172 @mypyc_attr(patchable=True)
6173 def dump_to_file(*output: str) -> str:
6174     """Dump `output` to a temporary file. Return path to the file."""
6175     with tempfile.NamedTemporaryFile(
6176         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6177     ) as f:
6178         for lines in output:
6179             f.write(lines)
6180             if lines and lines[-1] != "\n":
6181                 f.write("\n")
6182     return f.name
6183
6184
6185 @contextmanager
6186 def nullcontext() -> Iterator[None]:
6187     """Return an empty context manager.
6188
6189     To be used like `nullcontext` in Python 3.7.
6190     """
6191     yield
6192
6193
6194 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6195     """Return a unified diff string between strings `a` and `b`."""
6196     import difflib
6197
6198     a_lines = [line + "\n" for line in a.splitlines()]
6199     b_lines = [line + "\n" for line in b.splitlines()]
6200     return "".join(
6201         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6202     )
6203
6204
6205 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6206     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6207     err("Aborted!")
6208     for task in tasks:
6209         task.cancel()
6210
6211
6212 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6213     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6214     try:
6215         if sys.version_info[:2] >= (3, 7):
6216             all_tasks = asyncio.all_tasks
6217         else:
6218             all_tasks = asyncio.Task.all_tasks
6219         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6220         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6221         if not to_cancel:
6222             return
6223
6224         for task in to_cancel:
6225             task.cancel()
6226         loop.run_until_complete(
6227             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6228         )
6229     finally:
6230         # `concurrent.futures.Future` objects cannot be cancelled once they
6231         # are already running. There might be some when the `shutdown()` happened.
6232         # Silence their logger's spew about the event loop being closed.
6233         cf_logger = logging.getLogger("concurrent.futures")
6234         cf_logger.setLevel(logging.CRITICAL)
6235         loop.close()
6236
6237
6238 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6239     """Replace `regex` with `replacement` twice on `original`.
6240
6241     This is used by string normalization to perform replaces on
6242     overlapping matches.
6243     """
6244     return regex.sub(replacement, regex.sub(replacement, original))
6245
6246
6247 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6248     """Compile a regular expression string in `regex`.
6249
6250     If it contains newlines, use verbose mode.
6251     """
6252     if "\n" in regex:
6253         regex = "(?x)" + regex
6254     compiled: Pattern[str] = re.compile(regex)
6255     return compiled
6256
6257
6258 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6259     """Like `reversed(enumerate(sequence))` if that were possible."""
6260     index = len(sequence) - 1
6261     for element in reversed(sequence):
6262         yield (index, element)
6263         index -= 1
6264
6265
6266 def enumerate_with_length(
6267     line: Line, reversed: bool = False
6268 ) -> Iterator[Tuple[Index, Leaf, int]]:
6269     """Return an enumeration of leaves with their length.
6270
6271     Stops prematurely on multiline strings and standalone comments.
6272     """
6273     op = cast(
6274         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6275         enumerate_reversed if reversed else enumerate,
6276     )
6277     for index, leaf in op(line.leaves):
6278         length = len(leaf.prefix) + len(leaf.value)
6279         if "\n" in leaf.value:
6280             return  # Multiline strings, we can't continue.
6281
6282         for comment in line.comments_after(leaf):
6283             length += len(comment.value)
6284
6285         yield index, leaf, length
6286
6287
6288 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6289     """Return True if `line` is no longer than `line_length`.
6290
6291     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6292     """
6293     if not line_str:
6294         line_str = line_to_string(line)
6295     return (
6296         len(line_str) <= line_length
6297         and "\n" not in line_str  # multiline strings
6298         and not line.contains_standalone_comments()
6299     )
6300
6301
6302 def can_be_split(line: Line) -> bool:
6303     """Return False if the line cannot be split *for sure*.
6304
6305     This is not an exhaustive search but a cheap heuristic that we can use to
6306     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6307     in unnecessary parentheses).
6308     """
6309     leaves = line.leaves
6310     if len(leaves) < 2:
6311         return False
6312
6313     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6314         call_count = 0
6315         dot_count = 0
6316         next = leaves[-1]
6317         for leaf in leaves[-2::-1]:
6318             if leaf.type in OPENING_BRACKETS:
6319                 if next.type not in CLOSING_BRACKETS:
6320                     return False
6321
6322                 call_count += 1
6323             elif leaf.type == token.DOT:
6324                 dot_count += 1
6325             elif leaf.type == token.NAME:
6326                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6327                     return False
6328
6329             elif leaf.type not in CLOSING_BRACKETS:
6330                 return False
6331
6332             if dot_count > 1 and call_count > 1:
6333                 return False
6334
6335     return True
6336
6337
6338 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
6339     """Does `line` have a shape safe to reformat without optional parens around it?
6340
6341     Returns True for only a subset of potentially nice looking formattings but
6342     the point is to not return false positives that end up producing lines that
6343     are too long.
6344     """
6345     bt = line.bracket_tracker
6346     if not bt.delimiters:
6347         # Without delimiters the optional parentheses are useless.
6348         return True
6349
6350     max_priority = bt.max_delimiter_priority()
6351     if bt.delimiter_count_with_priority(max_priority) > 1:
6352         # With more than one delimiter of a kind the optional parentheses read better.
6353         return False
6354
6355     if max_priority == DOT_PRIORITY:
6356         # A single stranded method call doesn't require optional parentheses.
6357         return True
6358
6359     assert len(line.leaves) >= 2, "Stranded delimiter"
6360
6361     first = line.leaves[0]
6362     second = line.leaves[1]
6363     penultimate = line.leaves[-2]
6364     last = line.leaves[-1]
6365
6366     # With a single delimiter, omit if the expression starts or ends with
6367     # a bracket.
6368     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6369         remainder = False
6370         length = 4 * line.depth
6371         for _index, leaf, leaf_length in enumerate_with_length(line):
6372             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6373                 remainder = True
6374             if remainder:
6375                 length += leaf_length
6376                 if length > line_length:
6377                     break
6378
6379                 if leaf.type in OPENING_BRACKETS:
6380                     # There are brackets we can further split on.
6381                     remainder = False
6382
6383         else:
6384             # checked the entire string and line length wasn't exceeded
6385             if len(line.leaves) == _index + 1:
6386                 return True
6387
6388         # Note: we are not returning False here because a line might have *both*
6389         # a leading opening bracket and a trailing closing bracket.  If the
6390         # opening bracket doesn't match our rule, maybe the closing will.
6391
6392     if (
6393         last.type == token.RPAR
6394         or last.type == token.RBRACE
6395         or (
6396             # don't use indexing for omitting optional parentheses;
6397             # it looks weird
6398             last.type == token.RSQB
6399             and last.parent
6400             and last.parent.type != syms.trailer
6401         )
6402     ):
6403         if penultimate.type in OPENING_BRACKETS:
6404             # Empty brackets don't help.
6405             return False
6406
6407         if is_multiline_string(first):
6408             # Additional wrapping of a multiline string in this situation is
6409             # unnecessary.
6410             return True
6411
6412         length = 4 * line.depth
6413         seen_other_brackets = False
6414         for _index, leaf, leaf_length in enumerate_with_length(line):
6415             length += leaf_length
6416             if leaf is last.opening_bracket:
6417                 if seen_other_brackets or length <= line_length:
6418                     return True
6419
6420             elif leaf.type in OPENING_BRACKETS:
6421                 # There are brackets we can further split on.
6422                 seen_other_brackets = True
6423
6424     return False
6425
6426
6427 def get_cache_file(mode: Mode) -> Path:
6428     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6429
6430
6431 def read_cache(mode: Mode) -> Cache:
6432     """Read the cache if it exists and is well formed.
6433
6434     If it is not well formed, the call to write_cache later should resolve the issue.
6435     """
6436     cache_file = get_cache_file(mode)
6437     if not cache_file.exists():
6438         return {}
6439
6440     with cache_file.open("rb") as fobj:
6441         try:
6442             cache: Cache = pickle.load(fobj)
6443         except (pickle.UnpicklingError, ValueError):
6444             return {}
6445
6446     return cache
6447
6448
6449 def get_cache_info(path: Path) -> CacheInfo:
6450     """Return the information used to check if a file is already formatted or not."""
6451     stat = path.stat()
6452     return stat.st_mtime, stat.st_size
6453
6454
6455 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6456     """Split an iterable of paths in `sources` into two sets.
6457
6458     The first contains paths of files that modified on disk or are not in the
6459     cache. The other contains paths to non-modified files.
6460     """
6461     todo, done = set(), set()
6462     for src in sources:
6463         src = src.resolve()
6464         if cache.get(src) != get_cache_info(src):
6465             todo.add(src)
6466         else:
6467             done.add(src)
6468     return todo, done
6469
6470
6471 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6472     """Update the cache file."""
6473     cache_file = get_cache_file(mode)
6474     try:
6475         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6476         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6477         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6478             pickle.dump(new_cache, f, protocol=4)
6479         os.replace(f.name, cache_file)
6480     except OSError:
6481         pass
6482
6483
6484 def patch_click() -> None:
6485     """Make Click not crash.
6486
6487     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6488     default which restricts paths that it can access during the lifetime of the
6489     application.  Click refuses to work in this scenario by raising a RuntimeError.
6490
6491     In case of Black the likelihood that non-ASCII characters are going to be used in
6492     file paths is minimal since it's Python source code.  Moreover, this crash was
6493     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6494     """
6495     try:
6496         from click import core
6497         from click import _unicodefun  # type: ignore
6498     except ModuleNotFoundError:
6499         return
6500
6501     for module in (core, _unicodefun):
6502         if hasattr(module, "_verify_python3_env"):
6503             module._verify_python3_env = lambda: None
6504
6505
6506 def patched_main() -> None:
6507     freeze_support()
6508     patch_click()
6509     main()
6510
6511
6512 def fix_docstring(docstring: str, prefix: str) -> str:
6513     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6514     if not docstring:
6515         return ""
6516     # Convert tabs to spaces (following the normal Python rules)
6517     # and split into a list of lines:
6518     lines = docstring.expandtabs().splitlines()
6519     # Determine minimum indentation (first line doesn't count):
6520     indent = sys.maxsize
6521     for line in lines[1:]:
6522         stripped = line.lstrip()
6523         if stripped:
6524             indent = min(indent, len(line) - len(stripped))
6525     # Remove indentation (first line is special):
6526     trimmed = [lines[0].strip()]
6527     if indent < sys.maxsize:
6528         last_line_idx = len(lines) - 2
6529         for i, line in enumerate(lines[1:]):
6530             stripped_line = line[indent:].rstrip()
6531             if stripped_line or i == last_line_idx:
6532                 trimmed.append(prefix + stripped_line)
6533             else:
6534                 trimmed.append("")
6535     # Return a single string:
6536     return "\n".join(trimmed)
6537
6538
6539 if __name__ == "__main__":
6540     patched_main()