]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Disable string splitting/merging by default (#1609)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 T = TypeVar("T")
116 E = TypeVar("E", bound=Exception)
117
118
119 class Ok(Generic[T]):
120     def __init__(self, value: T) -> None:
121         self._value = value
122
123     def ok(self) -> T:
124         return self._value
125
126
127 class Err(Generic[E]):
128     def __init__(self, e: E) -> None:
129         self._e = e
130
131     def err(self) -> E:
132         return self._e
133
134
135 # The 'Result' return type is used to implement an error-handling model heavily
136 # influenced by that used by the Rust programming language
137 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
138 Result = Union[Ok[T], Err[E]]
139 TResult = Result[T, CannotTransform]  # (T)ransform Result
140 TMatchResult = TResult[Index]
141
142
143 class WriteBack(Enum):
144     NO = 0
145     YES = 1
146     DIFF = 2
147     CHECK = 3
148     COLOR_DIFF = 4
149
150     @classmethod
151     def from_configuration(
152         cls, *, check: bool, diff: bool, color: bool = False
153     ) -> "WriteBack":
154         if check and not diff:
155             return cls.CHECK
156
157         if diff and color:
158             return cls.COLOR_DIFF
159
160         return cls.DIFF if diff else cls.YES
161
162
163 class Changed(Enum):
164     NO = 0
165     CACHED = 1
166     YES = 2
167
168
169 class TargetVersion(Enum):
170     PY27 = 2
171     PY33 = 3
172     PY34 = 4
173     PY35 = 5
174     PY36 = 6
175     PY37 = 7
176     PY38 = 8
177
178     def is_python2(self) -> bool:
179         return self is TargetVersion.PY27
180
181
182 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
183
184
185 class Feature(Enum):
186     # All string literals are unicode
187     UNICODE_LITERALS = 1
188     F_STRINGS = 2
189     NUMERIC_UNDERSCORES = 3
190     TRAILING_COMMA_IN_CALL = 4
191     TRAILING_COMMA_IN_DEF = 5
192     # The following two feature-flags are mutually exclusive, and exactly one should be
193     # set for every version of python.
194     ASYNC_IDENTIFIERS = 6
195     ASYNC_KEYWORDS = 7
196     ASSIGNMENT_EXPRESSIONS = 8
197     POS_ONLY_ARGUMENTS = 9
198
199
200 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
201     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
202     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
203     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
204     TargetVersion.PY35: {
205         Feature.UNICODE_LITERALS,
206         Feature.TRAILING_COMMA_IN_CALL,
207         Feature.ASYNC_IDENTIFIERS,
208     },
209     TargetVersion.PY36: {
210         Feature.UNICODE_LITERALS,
211         Feature.F_STRINGS,
212         Feature.NUMERIC_UNDERSCORES,
213         Feature.TRAILING_COMMA_IN_CALL,
214         Feature.TRAILING_COMMA_IN_DEF,
215         Feature.ASYNC_IDENTIFIERS,
216     },
217     TargetVersion.PY37: {
218         Feature.UNICODE_LITERALS,
219         Feature.F_STRINGS,
220         Feature.NUMERIC_UNDERSCORES,
221         Feature.TRAILING_COMMA_IN_CALL,
222         Feature.TRAILING_COMMA_IN_DEF,
223         Feature.ASYNC_KEYWORDS,
224     },
225     TargetVersion.PY38: {
226         Feature.UNICODE_LITERALS,
227         Feature.F_STRINGS,
228         Feature.NUMERIC_UNDERSCORES,
229         Feature.TRAILING_COMMA_IN_CALL,
230         Feature.TRAILING_COMMA_IN_DEF,
231         Feature.ASYNC_KEYWORDS,
232         Feature.ASSIGNMENT_EXPRESSIONS,
233         Feature.POS_ONLY_ARGUMENTS,
234     },
235 }
236
237
238 @dataclass
239 class Mode:
240     target_versions: Set[TargetVersion] = field(default_factory=set)
241     line_length: int = DEFAULT_LINE_LENGTH
242     string_normalization: bool = True
243     experimental_string_processing: bool = False
244     is_pyi: bool = False
245
246     def get_cache_key(self) -> str:
247         if self.target_versions:
248             version_str = ",".join(
249                 str(version.value)
250                 for version in sorted(self.target_versions, key=lambda v: v.value)
251             )
252         else:
253             version_str = "-"
254         parts = [
255             version_str,
256             str(self.line_length),
257             str(int(self.string_normalization)),
258             str(int(self.is_pyi)),
259         ]
260         return ".".join(parts)
261
262
263 # Legacy name, left for integrations.
264 FileMode = Mode
265
266
267 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
268     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
269
270
271 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
272     """Find the absolute filepath to a pyproject.toml if it exists"""
273     path_project_root = find_project_root(path_search_start)
274     path_pyproject_toml = path_project_root / "pyproject.toml"
275     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
276
277
278 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
279     """Parse a pyproject toml file, pulling out relevant parts for Black
280
281     If parsing fails, will raise a toml.TomlDecodeError
282     """
283     pyproject_toml = toml.load(path_config)
284     config = pyproject_toml.get("tool", {}).get("black", {})
285     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
286
287
288 def read_pyproject_toml(
289     ctx: click.Context, param: click.Parameter, value: Optional[str]
290 ) -> Optional[str]:
291     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
292
293     Returns the path to a successfully found and read configuration file, None
294     otherwise.
295     """
296     if not value:
297         value = find_pyproject_toml(ctx.params.get("src", ()))
298         if value is None:
299             return None
300
301     try:
302         config = parse_pyproject_toml(value)
303     except (toml.TomlDecodeError, OSError) as e:
304         raise click.FileError(
305             filename=value, hint=f"Error reading configuration file: {e}"
306         )
307
308     if not config:
309         return None
310     else:
311         # Sanitize the values to be Click friendly. For more information please see:
312         # https://github.com/psf/black/issues/1458
313         # https://github.com/pallets/click/issues/1567
314         config = {
315             k: str(v) if not isinstance(v, (list, dict)) else v
316             for k, v in config.items()
317         }
318
319     target_version = config.get("target_version")
320     if target_version is not None and not isinstance(target_version, list):
321         raise click.BadOptionUsage(
322             "target-version", "Config key target-version must be a list"
323         )
324
325     default_map: Dict[str, Any] = {}
326     if ctx.default_map:
327         default_map.update(ctx.default_map)
328     default_map.update(config)
329
330     ctx.default_map = default_map
331     return value
332
333
334 def target_version_option_callback(
335     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
336 ) -> List[TargetVersion]:
337     """Compute the target versions from a --target-version flag.
338
339     This is its own function because mypy couldn't infer the type correctly
340     when it was a lambda, causing mypyc trouble.
341     """
342     return [TargetVersion[val.upper()] for val in v]
343
344
345 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
346 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
347 @click.option(
348     "-l",
349     "--line-length",
350     type=int,
351     default=DEFAULT_LINE_LENGTH,
352     help="How many characters per line to allow.",
353     show_default=True,
354 )
355 @click.option(
356     "-t",
357     "--target-version",
358     type=click.Choice([v.name.lower() for v in TargetVersion]),
359     callback=target_version_option_callback,
360     multiple=True,
361     help=(
362         "Python versions that should be supported by Black's output. [default: per-file"
363         " auto-detection]"
364     ),
365 )
366 @click.option(
367     "--pyi",
368     is_flag=True,
369     help=(
370         "Format all input files like typing stubs regardless of file extension (useful"
371         " when piping source on standard input)."
372     ),
373 )
374 @click.option(
375     "-S",
376     "--skip-string-normalization",
377     is_flag=True,
378     help="Don't normalize string quotes or prefixes.",
379 )
380 @click.option(
381     "--experimental-string-processing",
382     is_flag=True,
383     hidden=True,
384     help=(
385         "Experimental option that performs more normalization on string literals."
386         " Currently disabled because it leads to some crashes."
387     ),
388 )
389 @click.option(
390     "--check",
391     is_flag=True,
392     help=(
393         "Don't write the files back, just return the status.  Return code 0 means"
394         " nothing would change.  Return code 1 means some files would be reformatted."
395         " Return code 123 means there was an internal error."
396     ),
397 )
398 @click.option(
399     "--diff",
400     is_flag=True,
401     help="Don't write the files back, just output a diff for each file on stdout.",
402 )
403 @click.option(
404     "--color/--no-color",
405     is_flag=True,
406     help="Show colored diff. Only applies when `--diff` is given.",
407 )
408 @click.option(
409     "--fast/--safe",
410     is_flag=True,
411     help="If --fast given, skip temporary sanity checks. [default: --safe]",
412 )
413 @click.option(
414     "--include",
415     type=str,
416     default=DEFAULT_INCLUDES,
417     help=(
418         "A regular expression that matches files and directories that should be"
419         " included on recursive searches.  An empty value means all files are included"
420         " regardless of the name.  Use forward slashes for directories on all platforms"
421         " (Windows, too).  Exclusions are calculated first, inclusions later."
422     ),
423     show_default=True,
424 )
425 @click.option(
426     "--exclude",
427     type=str,
428     default=DEFAULT_EXCLUDES,
429     help=(
430         "A regular expression that matches files and directories that should be"
431         " excluded on recursive searches.  An empty value means no paths are excluded."
432         " Use forward slashes for directories on all platforms (Windows, too). "
433         " Exclusions are calculated first, inclusions later."
434     ),
435     show_default=True,
436 )
437 @click.option(
438     "--force-exclude",
439     type=str,
440     help=(
441         "Like --exclude, but files and directories matching this regex will be "
442         "excluded even when they are passed explicitly as arguments"
443     ),
444 )
445 @click.option(
446     "-q",
447     "--quiet",
448     is_flag=True,
449     help=(
450         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
451         " those with 2>/dev/null."
452     ),
453 )
454 @click.option(
455     "-v",
456     "--verbose",
457     is_flag=True,
458     help=(
459         "Also emit messages to stderr about files that were not changed or were ignored"
460         " due to --exclude=."
461     ),
462 )
463 @click.version_option(version=__version__)
464 @click.argument(
465     "src",
466     nargs=-1,
467     type=click.Path(
468         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
469     ),
470     is_eager=True,
471 )
472 @click.option(
473     "--config",
474     type=click.Path(
475         exists=True,
476         file_okay=True,
477         dir_okay=False,
478         readable=True,
479         allow_dash=False,
480         path_type=str,
481     ),
482     is_eager=True,
483     callback=read_pyproject_toml,
484     help="Read configuration from FILE path.",
485 )
486 @click.pass_context
487 def main(
488     ctx: click.Context,
489     code: Optional[str],
490     line_length: int,
491     target_version: List[TargetVersion],
492     check: bool,
493     diff: bool,
494     color: bool,
495     fast: bool,
496     pyi: bool,
497     skip_string_normalization: bool,
498     experimental_string_processing: bool,
499     quiet: bool,
500     verbose: bool,
501     include: str,
502     exclude: str,
503     force_exclude: Optional[str],
504     src: Tuple[str, ...],
505     config: Optional[str],
506 ) -> None:
507     """The uncompromising code formatter."""
508     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
509     if target_version:
510         versions = set(target_version)
511     else:
512         # We'll autodetect later.
513         versions = set()
514     mode = Mode(
515         target_versions=versions,
516         line_length=line_length,
517         is_pyi=pyi,
518         string_normalization=not skip_string_normalization,
519         experimental_string_processing=experimental_string_processing,
520     )
521     if config and verbose:
522         out(f"Using configuration from {config}.", bold=False, fg="blue")
523     if code is not None:
524         print(format_str(code, mode=mode))
525         ctx.exit(0)
526     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
527     sources = get_sources(
528         ctx=ctx,
529         src=src,
530         quiet=quiet,
531         verbose=verbose,
532         include=include,
533         exclude=exclude,
534         force_exclude=force_exclude,
535         report=report,
536     )
537
538     path_empty(
539         sources,
540         "No Python files are present to be formatted. Nothing to do 😴",
541         quiet,
542         verbose,
543         ctx,
544     )
545
546     if len(sources) == 1:
547         reformat_one(
548             src=sources.pop(),
549             fast=fast,
550             write_back=write_back,
551             mode=mode,
552             report=report,
553         )
554     else:
555         reformat_many(
556             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
557         )
558
559     if verbose or not quiet:
560         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
561         click.secho(str(report), err=True)
562     ctx.exit(report.return_code)
563
564
565 def get_sources(
566     *,
567     ctx: click.Context,
568     src: Tuple[str, ...],
569     quiet: bool,
570     verbose: bool,
571     include: str,
572     exclude: str,
573     force_exclude: Optional[str],
574     report: "Report",
575 ) -> Set[Path]:
576     """Compute the set of files to be formatted."""
577     try:
578         include_regex = re_compile_maybe_verbose(include)
579     except re.error:
580         err(f"Invalid regular expression for include given: {include!r}")
581         ctx.exit(2)
582     try:
583         exclude_regex = re_compile_maybe_verbose(exclude)
584     except re.error:
585         err(f"Invalid regular expression for exclude given: {exclude!r}")
586         ctx.exit(2)
587     try:
588         force_exclude_regex = (
589             re_compile_maybe_verbose(force_exclude) if force_exclude else None
590         )
591     except re.error:
592         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
593         ctx.exit(2)
594
595     root = find_project_root(src)
596     sources: Set[Path] = set()
597     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
598     gitignore = get_gitignore(root)
599
600     for s in src:
601         p = Path(s)
602         if p.is_dir():
603             sources.update(
604                 gen_python_files(
605                     p.iterdir(),
606                     root,
607                     include_regex,
608                     exclude_regex,
609                     force_exclude_regex,
610                     report,
611                     gitignore,
612                 )
613             )
614         elif s == "-":
615             sources.add(p)
616         elif p.is_file():
617             normalized_path = normalize_path_maybe_ignore(p, root, report)
618             if normalized_path is None:
619                 continue
620
621             normalized_path = "/" + normalized_path
622             # Hard-exclude any files that matches the `--force-exclude` regex.
623             if force_exclude_regex:
624                 force_exclude_match = force_exclude_regex.search(normalized_path)
625             else:
626                 force_exclude_match = None
627             if force_exclude_match and force_exclude_match.group(0):
628                 report.path_ignored(p, "matches the --force-exclude regular expression")
629                 continue
630
631             sources.add(p)
632         else:
633             err(f"invalid path: {s}")
634     return sources
635
636
637 def path_empty(
638     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
639 ) -> None:
640     """
641     Exit if there is no `src` provided for formatting
642     """
643     if len(src) == 0:
644         if verbose or not quiet:
645             out(msg)
646             ctx.exit(0)
647
648
649 def reformat_one(
650     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
651 ) -> None:
652     """Reformat a single file under `src` without spawning child processes.
653
654     `fast`, `write_back`, and `mode` options are passed to
655     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
656     """
657     try:
658         changed = Changed.NO
659         if not src.is_file() and str(src) == "-":
660             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
661                 changed = Changed.YES
662         else:
663             cache: Cache = {}
664             if write_back != WriteBack.DIFF:
665                 cache = read_cache(mode)
666                 res_src = src.resolve()
667                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
668                     changed = Changed.CACHED
669             if changed is not Changed.CACHED and format_file_in_place(
670                 src, fast=fast, write_back=write_back, mode=mode
671             ):
672                 changed = Changed.YES
673             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
674                 write_back is WriteBack.CHECK and changed is Changed.NO
675             ):
676                 write_cache(cache, [src], mode)
677         report.done(src, changed)
678     except Exception as exc:
679         if report.verbose:
680             traceback.print_exc()
681         report.failed(src, str(exc))
682
683
684 def reformat_many(
685     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
686 ) -> None:
687     """Reformat multiple files using a ProcessPoolExecutor."""
688     executor: Executor
689     loop = asyncio.get_event_loop()
690     worker_count = os.cpu_count()
691     if sys.platform == "win32":
692         # Work around https://bugs.python.org/issue26903
693         worker_count = min(worker_count, 61)
694     try:
695         executor = ProcessPoolExecutor(max_workers=worker_count)
696     except (ImportError, OSError):
697         # we arrive here if the underlying system does not support multi-processing
698         # like in AWS Lambda or Termux, in which case we gracefully fallback to
699         # a ThreadPollExecutor with just a single worker (more workers would not do us
700         # any good due to the Global Interpreter Lock)
701         executor = ThreadPoolExecutor(max_workers=1)
702
703     try:
704         loop.run_until_complete(
705             schedule_formatting(
706                 sources=sources,
707                 fast=fast,
708                 write_back=write_back,
709                 mode=mode,
710                 report=report,
711                 loop=loop,
712                 executor=executor,
713             )
714         )
715     finally:
716         shutdown(loop)
717         if executor is not None:
718             executor.shutdown()
719
720
721 async def schedule_formatting(
722     sources: Set[Path],
723     fast: bool,
724     write_back: WriteBack,
725     mode: Mode,
726     report: "Report",
727     loop: asyncio.AbstractEventLoop,
728     executor: Executor,
729 ) -> None:
730     """Run formatting of `sources` in parallel using the provided `executor`.
731
732     (Use ProcessPoolExecutors for actual parallelism.)
733
734     `write_back`, `fast`, and `mode` options are passed to
735     :func:`format_file_in_place`.
736     """
737     cache: Cache = {}
738     if write_back != WriteBack.DIFF:
739         cache = read_cache(mode)
740         sources, cached = filter_cached(cache, sources)
741         for src in sorted(cached):
742             report.done(src, Changed.CACHED)
743     if not sources:
744         return
745
746     cancelled = []
747     sources_to_cache = []
748     lock = None
749     if write_back == WriteBack.DIFF:
750         # For diff output, we need locks to ensure we don't interleave output
751         # from different processes.
752         manager = Manager()
753         lock = manager.Lock()
754     tasks = {
755         asyncio.ensure_future(
756             loop.run_in_executor(
757                 executor, format_file_in_place, src, fast, mode, write_back, lock
758             )
759         ): src
760         for src in sorted(sources)
761     }
762     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
763     try:
764         loop.add_signal_handler(signal.SIGINT, cancel, pending)
765         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
766     except NotImplementedError:
767         # There are no good alternatives for these on Windows.
768         pass
769     while pending:
770         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
771         for task in done:
772             src = tasks.pop(task)
773             if task.cancelled():
774                 cancelled.append(task)
775             elif task.exception():
776                 report.failed(src, str(task.exception()))
777             else:
778                 changed = Changed.YES if task.result() else Changed.NO
779                 # If the file was written back or was successfully checked as
780                 # well-formatted, store this information in the cache.
781                 if write_back is WriteBack.YES or (
782                     write_back is WriteBack.CHECK and changed is Changed.NO
783                 ):
784                     sources_to_cache.append(src)
785                 report.done(src, changed)
786     if cancelled:
787         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
788     if sources_to_cache:
789         write_cache(cache, sources_to_cache, mode)
790
791
792 def format_file_in_place(
793     src: Path,
794     fast: bool,
795     mode: Mode,
796     write_back: WriteBack = WriteBack.NO,
797     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
798 ) -> bool:
799     """Format file under `src` path. Return True if changed.
800
801     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
802     code to the file.
803     `mode` and `fast` options are passed to :func:`format_file_contents`.
804     """
805     if src.suffix == ".pyi":
806         mode = replace(mode, is_pyi=True)
807
808     then = datetime.utcfromtimestamp(src.stat().st_mtime)
809     with open(src, "rb") as buf:
810         src_contents, encoding, newline = decode_bytes(buf.read())
811     try:
812         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
813     except NothingChanged:
814         return False
815
816     if write_back == WriteBack.YES:
817         with open(src, "w", encoding=encoding, newline=newline) as f:
818             f.write(dst_contents)
819     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
820         now = datetime.utcnow()
821         src_name = f"{src}\t{then} +0000"
822         dst_name = f"{src}\t{now} +0000"
823         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
824
825         if write_back == write_back.COLOR_DIFF:
826             diff_contents = color_diff(diff_contents)
827
828         with lock or nullcontext():
829             f = io.TextIOWrapper(
830                 sys.stdout.buffer,
831                 encoding=encoding,
832                 newline=newline,
833                 write_through=True,
834             )
835             f = wrap_stream_for_windows(f)
836             f.write(diff_contents)
837             f.detach()
838
839     return True
840
841
842 def color_diff(contents: str) -> str:
843     """Inject the ANSI color codes to the diff."""
844     lines = contents.split("\n")
845     for i, line in enumerate(lines):
846         if line.startswith("+++") or line.startswith("---"):
847             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
848         if line.startswith("@@"):
849             line = "\033[36m" + line + "\033[0m"  # cyan, reset
850         if line.startswith("+"):
851             line = "\033[32m" + line + "\033[0m"  # green, reset
852         elif line.startswith("-"):
853             line = "\033[31m" + line + "\033[0m"  # red, reset
854         lines[i] = line
855     return "\n".join(lines)
856
857
858 def wrap_stream_for_windows(
859     f: io.TextIOWrapper,
860 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
861     """
862     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
863
864     If `colorama` is not found, then no change is made. If `colorama` does
865     exist, then it handles the logic to determine whether or not to change
866     things.
867     """
868     try:
869         from colorama import initialise
870
871         # We set `strip=False` so that we can don't have to modify
872         # test_express_diff_with_color.
873         f = initialise.wrap_stream(
874             f, convert=None, strip=False, autoreset=False, wrap=True
875         )
876
877         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
878         # which does not have a `detach()` method. So we fake one.
879         f.detach = lambda *args, **kwargs: None  # type: ignore
880     except ImportError:
881         pass
882
883     return f
884
885
886 def format_stdin_to_stdout(
887     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
888 ) -> bool:
889     """Format file on stdin. Return True if changed.
890
891     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
892     write a diff to stdout. The `mode` argument is passed to
893     :func:`format_file_contents`.
894     """
895     then = datetime.utcnow()
896     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
897     dst = src
898     try:
899         dst = format_file_contents(src, fast=fast, mode=mode)
900         return True
901
902     except NothingChanged:
903         return False
904
905     finally:
906         f = io.TextIOWrapper(
907             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
908         )
909         if write_back == WriteBack.YES:
910             f.write(dst)
911         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
912             now = datetime.utcnow()
913             src_name = f"STDIN\t{then} +0000"
914             dst_name = f"STDOUT\t{now} +0000"
915             d = diff(src, dst, src_name, dst_name)
916             if write_back == WriteBack.COLOR_DIFF:
917                 d = color_diff(d)
918                 f = wrap_stream_for_windows(f)
919             f.write(d)
920         f.detach()
921
922
923 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
924     """Reformat contents a file and return new contents.
925
926     If `fast` is False, additionally confirm that the reformatted code is
927     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
928     `mode` is passed to :func:`format_str`.
929     """
930     if src_contents.strip() == "":
931         raise NothingChanged
932
933     dst_contents = format_str(src_contents, mode=mode)
934     if src_contents == dst_contents:
935         raise NothingChanged
936
937     if not fast:
938         assert_equivalent(src_contents, dst_contents)
939         assert_stable(src_contents, dst_contents, mode=mode)
940     return dst_contents
941
942
943 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
944     """Reformat a string and return new contents.
945
946     `mode` determines formatting options, such as how many characters per line are
947     allowed.  Example:
948
949     >>> import black
950     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
951     def f(arg: str = "") -> None:
952         ...
953
954     A more complex example:
955     >>> print(
956     ...   black.format_str(
957     ...     "def f(arg:str='')->None: hey",
958     ...     mode=black.Mode(
959     ...       target_versions={black.TargetVersion.PY36},
960     ...       line_length=10,
961     ...       string_normalization=False,
962     ...       is_pyi=False,
963     ...     ),
964     ...   ),
965     ... )
966     def f(
967         arg: str = '',
968     ) -> None:
969         hey
970
971     """
972     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
973     dst_contents = []
974     future_imports = get_future_imports(src_node)
975     if mode.target_versions:
976         versions = mode.target_versions
977     else:
978         versions = detect_target_versions(src_node)
979     normalize_fmt_off(src_node)
980     lines = LineGenerator(
981         remove_u_prefix="unicode_literals" in future_imports
982         or supports_feature(versions, Feature.UNICODE_LITERALS),
983         is_pyi=mode.is_pyi,
984         normalize_strings=mode.string_normalization,
985     )
986     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
987     empty_line = Line()
988     after = 0
989     split_line_features = {
990         feature
991         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
992         if supports_feature(versions, feature)
993     }
994     for current_line in lines.visit(src_node):
995         dst_contents.append(str(empty_line) * after)
996         before, after = elt.maybe_empty_lines(current_line)
997         dst_contents.append(str(empty_line) * before)
998         for line in transform_line(
999             current_line, mode=mode, features=split_line_features
1000         ):
1001             dst_contents.append(str(line))
1002     return "".join(dst_contents)
1003
1004
1005 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1006     """Return a tuple of (decoded_contents, encoding, newline).
1007
1008     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1009     universal newlines (i.e. only contains LF).
1010     """
1011     srcbuf = io.BytesIO(src)
1012     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1013     if not lines:
1014         return "", encoding, "\n"
1015
1016     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1017     srcbuf.seek(0)
1018     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1019         return tiow.read(), encoding, newline
1020
1021
1022 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1023     if not target_versions:
1024         # No target_version specified, so try all grammars.
1025         return [
1026             # Python 3.7+
1027             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1028             # Python 3.0-3.6
1029             pygram.python_grammar_no_print_statement_no_exec_statement,
1030             # Python 2.7 with future print_function import
1031             pygram.python_grammar_no_print_statement,
1032             # Python 2.7
1033             pygram.python_grammar,
1034         ]
1035
1036     if all(version.is_python2() for version in target_versions):
1037         # Python 2-only code, so try Python 2 grammars.
1038         return [
1039             # Python 2.7 with future print_function import
1040             pygram.python_grammar_no_print_statement,
1041             # Python 2.7
1042             pygram.python_grammar,
1043         ]
1044
1045     # Python 3-compatible code, so only try Python 3 grammar.
1046     grammars = []
1047     # If we have to parse both, try to parse async as a keyword first
1048     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1049         # Python 3.7+
1050         grammars.append(
1051             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1052         )
1053     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1054         # Python 3.0-3.6
1055         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1056     # At least one of the above branches must have been taken, because every Python
1057     # version has exactly one of the two 'ASYNC_*' flags
1058     return grammars
1059
1060
1061 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1062     """Given a string with source, return the lib2to3 Node."""
1063     if src_txt[-1:] != "\n":
1064         src_txt += "\n"
1065
1066     for grammar in get_grammars(set(target_versions)):
1067         drv = driver.Driver(grammar, pytree.convert)
1068         try:
1069             result = drv.parse_string(src_txt, True)
1070             break
1071
1072         except ParseError as pe:
1073             lineno, column = pe.context[1]
1074             lines = src_txt.splitlines()
1075             try:
1076                 faulty_line = lines[lineno - 1]
1077             except IndexError:
1078                 faulty_line = "<line number missing in source>"
1079             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1080     else:
1081         raise exc from None
1082
1083     if isinstance(result, Leaf):
1084         result = Node(syms.file_input, [result])
1085     return result
1086
1087
1088 def lib2to3_unparse(node: Node) -> str:
1089     """Given a lib2to3 node, return its string representation."""
1090     code = str(node)
1091     return code
1092
1093
1094 class Visitor(Generic[T]):
1095     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1096
1097     def visit(self, node: LN) -> Iterator[T]:
1098         """Main method to visit `node` and its children.
1099
1100         It tries to find a `visit_*()` method for the given `node.type`, like
1101         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1102         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1103         instead.
1104
1105         Then yields objects of type `T` from the selected visitor.
1106         """
1107         if node.type < 256:
1108             name = token.tok_name[node.type]
1109         else:
1110             name = str(type_repr(node.type))
1111         # We explicitly branch on whether a visitor exists (instead of
1112         # using self.visit_default as the default arg to getattr) in order
1113         # to save needing to create a bound method object and so mypyc can
1114         # generate a native call to visit_default.
1115         visitf = getattr(self, f"visit_{name}", None)
1116         if visitf:
1117             yield from visitf(node)
1118         else:
1119             yield from self.visit_default(node)
1120
1121     def visit_default(self, node: LN) -> Iterator[T]:
1122         """Default `visit_*()` implementation. Recurses to children of `node`."""
1123         if isinstance(node, Node):
1124             for child in node.children:
1125                 yield from self.visit(child)
1126
1127
1128 @dataclass
1129 class DebugVisitor(Visitor[T]):
1130     tree_depth: int = 0
1131
1132     def visit_default(self, node: LN) -> Iterator[T]:
1133         indent = " " * (2 * self.tree_depth)
1134         if isinstance(node, Node):
1135             _type = type_repr(node.type)
1136             out(f"{indent}{_type}", fg="yellow")
1137             self.tree_depth += 1
1138             for child in node.children:
1139                 yield from self.visit(child)
1140
1141             self.tree_depth -= 1
1142             out(f"{indent}/{_type}", fg="yellow", bold=False)
1143         else:
1144             _type = token.tok_name.get(node.type, str(node.type))
1145             out(f"{indent}{_type}", fg="blue", nl=False)
1146             if node.prefix:
1147                 # We don't have to handle prefixes for `Node` objects since
1148                 # that delegates to the first child anyway.
1149                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1150             out(f" {node.value!r}", fg="blue", bold=False)
1151
1152     @classmethod
1153     def show(cls, code: Union[str, Leaf, Node]) -> None:
1154         """Pretty-print the lib2to3 AST of a given string of `code`.
1155
1156         Convenience method for debugging.
1157         """
1158         v: DebugVisitor[None] = DebugVisitor()
1159         if isinstance(code, str):
1160             code = lib2to3_parse(code)
1161         list(v.visit(code))
1162
1163
1164 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1165 STATEMENT: Final = {
1166     syms.if_stmt,
1167     syms.while_stmt,
1168     syms.for_stmt,
1169     syms.try_stmt,
1170     syms.except_clause,
1171     syms.with_stmt,
1172     syms.funcdef,
1173     syms.classdef,
1174 }
1175 STANDALONE_COMMENT: Final = 153
1176 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1177 LOGIC_OPERATORS: Final = {"and", "or"}
1178 COMPARATORS: Final = {
1179     token.LESS,
1180     token.GREATER,
1181     token.EQEQUAL,
1182     token.NOTEQUAL,
1183     token.LESSEQUAL,
1184     token.GREATEREQUAL,
1185 }
1186 MATH_OPERATORS: Final = {
1187     token.VBAR,
1188     token.CIRCUMFLEX,
1189     token.AMPER,
1190     token.LEFTSHIFT,
1191     token.RIGHTSHIFT,
1192     token.PLUS,
1193     token.MINUS,
1194     token.STAR,
1195     token.SLASH,
1196     token.DOUBLESLASH,
1197     token.PERCENT,
1198     token.AT,
1199     token.TILDE,
1200     token.DOUBLESTAR,
1201 }
1202 STARS: Final = {token.STAR, token.DOUBLESTAR}
1203 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1204 VARARGS_PARENTS: Final = {
1205     syms.arglist,
1206     syms.argument,  # double star in arglist
1207     syms.trailer,  # single argument to call
1208     syms.typedargslist,
1209     syms.varargslist,  # lambdas
1210 }
1211 UNPACKING_PARENTS: Final = {
1212     syms.atom,  # single element of a list or set literal
1213     syms.dictsetmaker,
1214     syms.listmaker,
1215     syms.testlist_gexp,
1216     syms.testlist_star_expr,
1217 }
1218 TEST_DESCENDANTS: Final = {
1219     syms.test,
1220     syms.lambdef,
1221     syms.or_test,
1222     syms.and_test,
1223     syms.not_test,
1224     syms.comparison,
1225     syms.star_expr,
1226     syms.expr,
1227     syms.xor_expr,
1228     syms.and_expr,
1229     syms.shift_expr,
1230     syms.arith_expr,
1231     syms.trailer,
1232     syms.term,
1233     syms.power,
1234 }
1235 ASSIGNMENTS: Final = {
1236     "=",
1237     "+=",
1238     "-=",
1239     "*=",
1240     "@=",
1241     "/=",
1242     "%=",
1243     "&=",
1244     "|=",
1245     "^=",
1246     "<<=",
1247     ">>=",
1248     "**=",
1249     "//=",
1250 }
1251 COMPREHENSION_PRIORITY: Final = 20
1252 COMMA_PRIORITY: Final = 18
1253 TERNARY_PRIORITY: Final = 16
1254 LOGIC_PRIORITY: Final = 14
1255 STRING_PRIORITY: Final = 12
1256 COMPARATOR_PRIORITY: Final = 10
1257 MATH_PRIORITIES: Final = {
1258     token.VBAR: 9,
1259     token.CIRCUMFLEX: 8,
1260     token.AMPER: 7,
1261     token.LEFTSHIFT: 6,
1262     token.RIGHTSHIFT: 6,
1263     token.PLUS: 5,
1264     token.MINUS: 5,
1265     token.STAR: 4,
1266     token.SLASH: 4,
1267     token.DOUBLESLASH: 4,
1268     token.PERCENT: 4,
1269     token.AT: 4,
1270     token.TILDE: 3,
1271     token.DOUBLESTAR: 2,
1272 }
1273 DOT_PRIORITY: Final = 1
1274
1275
1276 @dataclass
1277 class BracketTracker:
1278     """Keeps track of brackets on a line."""
1279
1280     depth: int = 0
1281     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1282     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1283     previous: Optional[Leaf] = None
1284     _for_loop_depths: List[int] = field(default_factory=list)
1285     _lambda_argument_depths: List[int] = field(default_factory=list)
1286
1287     def mark(self, leaf: Leaf) -> None:
1288         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1289
1290         All leaves receive an int `bracket_depth` field that stores how deep
1291         within brackets a given leaf is. 0 means there are no enclosing brackets
1292         that started on this line.
1293
1294         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1295         field that it forms a pair with. This is a one-directional link to
1296         avoid reference cycles.
1297
1298         If a leaf is a delimiter (a token on which Black can split the line if
1299         needed) and it's on depth 0, its `id()` is stored in the tracker's
1300         `delimiters` field.
1301         """
1302         if leaf.type == token.COMMENT:
1303             return
1304
1305         self.maybe_decrement_after_for_loop_variable(leaf)
1306         self.maybe_decrement_after_lambda_arguments(leaf)
1307         if leaf.type in CLOSING_BRACKETS:
1308             self.depth -= 1
1309             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1310             leaf.opening_bracket = opening_bracket
1311         leaf.bracket_depth = self.depth
1312         if self.depth == 0:
1313             delim = is_split_before_delimiter(leaf, self.previous)
1314             if delim and self.previous is not None:
1315                 self.delimiters[id(self.previous)] = delim
1316             else:
1317                 delim = is_split_after_delimiter(leaf, self.previous)
1318                 if delim:
1319                     self.delimiters[id(leaf)] = delim
1320         if leaf.type in OPENING_BRACKETS:
1321             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1322             self.depth += 1
1323         self.previous = leaf
1324         self.maybe_increment_lambda_arguments(leaf)
1325         self.maybe_increment_for_loop_variable(leaf)
1326
1327     def any_open_brackets(self) -> bool:
1328         """Return True if there is an yet unmatched open bracket on the line."""
1329         return bool(self.bracket_match)
1330
1331     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1332         """Return the highest priority of a delimiter found on the line.
1333
1334         Values are consistent with what `is_split_*_delimiter()` return.
1335         Raises ValueError on no delimiters.
1336         """
1337         return max(v for k, v in self.delimiters.items() if k not in exclude)
1338
1339     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1340         """Return the number of delimiters with the given `priority`.
1341
1342         If no `priority` is passed, defaults to max priority on the line.
1343         """
1344         if not self.delimiters:
1345             return 0
1346
1347         priority = priority or self.max_delimiter_priority()
1348         return sum(1 for p in self.delimiters.values() if p == priority)
1349
1350     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1351         """In a for loop, or comprehension, the variables are often unpacks.
1352
1353         To avoid splitting on the comma in this situation, increase the depth of
1354         tokens between `for` and `in`.
1355         """
1356         if leaf.type == token.NAME and leaf.value == "for":
1357             self.depth += 1
1358             self._for_loop_depths.append(self.depth)
1359             return True
1360
1361         return False
1362
1363     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1364         """See `maybe_increment_for_loop_variable` above for explanation."""
1365         if (
1366             self._for_loop_depths
1367             and self._for_loop_depths[-1] == self.depth
1368             and leaf.type == token.NAME
1369             and leaf.value == "in"
1370         ):
1371             self.depth -= 1
1372             self._for_loop_depths.pop()
1373             return True
1374
1375         return False
1376
1377     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1378         """In a lambda expression, there might be more than one argument.
1379
1380         To avoid splitting on the comma in this situation, increase the depth of
1381         tokens between `lambda` and `:`.
1382         """
1383         if leaf.type == token.NAME and leaf.value == "lambda":
1384             self.depth += 1
1385             self._lambda_argument_depths.append(self.depth)
1386             return True
1387
1388         return False
1389
1390     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1391         """See `maybe_increment_lambda_arguments` above for explanation."""
1392         if (
1393             self._lambda_argument_depths
1394             and self._lambda_argument_depths[-1] == self.depth
1395             and leaf.type == token.COLON
1396         ):
1397             self.depth -= 1
1398             self._lambda_argument_depths.pop()
1399             return True
1400
1401         return False
1402
1403     def get_open_lsqb(self) -> Optional[Leaf]:
1404         """Return the most recent opening square bracket (if any)."""
1405         return self.bracket_match.get((self.depth - 1, token.RSQB))
1406
1407
1408 @dataclass
1409 class Line:
1410     """Holds leaves and comments. Can be printed with `str(line)`."""
1411
1412     depth: int = 0
1413     leaves: List[Leaf] = field(default_factory=list)
1414     # keys ordered like `leaves`
1415     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1416     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1417     inside_brackets: bool = False
1418     should_explode: bool = False
1419
1420     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1421         """Add a new `leaf` to the end of the line.
1422
1423         Unless `preformatted` is True, the `leaf` will receive a new consistent
1424         whitespace prefix and metadata applied by :class:`BracketTracker`.
1425         Trailing commas are maybe removed, unpacked for loop variables are
1426         demoted from being delimiters.
1427
1428         Inline comments are put aside.
1429         """
1430         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1431         if not has_value:
1432             return
1433
1434         if token.COLON == leaf.type and self.is_class_paren_empty:
1435             del self.leaves[-2:]
1436         if self.leaves and not preformatted:
1437             # Note: at this point leaf.prefix should be empty except for
1438             # imports, for which we only preserve newlines.
1439             leaf.prefix += whitespace(
1440                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1441             )
1442         if self.inside_brackets or not preformatted:
1443             self.bracket_tracker.mark(leaf)
1444             self.maybe_remove_trailing_comma(leaf)
1445         if not self.append_comment(leaf):
1446             self.leaves.append(leaf)
1447
1448     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1449         """Like :func:`append()` but disallow invalid standalone comment structure.
1450
1451         Raises ValueError when any `leaf` is appended after a standalone comment
1452         or when a standalone comment is not the first leaf on the line.
1453         """
1454         if self.bracket_tracker.depth == 0:
1455             if self.is_comment:
1456                 raise ValueError("cannot append to standalone comments")
1457
1458             if self.leaves and leaf.type == STANDALONE_COMMENT:
1459                 raise ValueError(
1460                     "cannot append standalone comments to a populated line"
1461                 )
1462
1463         self.append(leaf, preformatted=preformatted)
1464
1465     @property
1466     def is_comment(self) -> bool:
1467         """Is this line a standalone comment?"""
1468         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1469
1470     @property
1471     def is_decorator(self) -> bool:
1472         """Is this line a decorator?"""
1473         return bool(self) and self.leaves[0].type == token.AT
1474
1475     @property
1476     def is_import(self) -> bool:
1477         """Is this an import line?"""
1478         return bool(self) and is_import(self.leaves[0])
1479
1480     @property
1481     def is_class(self) -> bool:
1482         """Is this line a class definition?"""
1483         return (
1484             bool(self)
1485             and self.leaves[0].type == token.NAME
1486             and self.leaves[0].value == "class"
1487         )
1488
1489     @property
1490     def is_stub_class(self) -> bool:
1491         """Is this line a class definition with a body consisting only of "..."?"""
1492         return self.is_class and self.leaves[-3:] == [
1493             Leaf(token.DOT, ".") for _ in range(3)
1494         ]
1495
1496     @property
1497     def is_collection_with_optional_trailing_comma(self) -> bool:
1498         """Is this line a collection literal with a trailing comma that's optional?
1499
1500         Note that the trailing comma in a 1-tuple is not optional.
1501         """
1502         if not self.leaves or len(self.leaves) < 4:
1503             return False
1504
1505         # Look for and address a trailing colon.
1506         if self.leaves[-1].type == token.COLON:
1507             closer = self.leaves[-2]
1508             close_index = -2
1509         else:
1510             closer = self.leaves[-1]
1511             close_index = -1
1512         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1513             return False
1514
1515         if closer.type == token.RPAR:
1516             # Tuples require an extra check, because if there's only
1517             # one element in the tuple removing the comma unmakes the
1518             # tuple.
1519             #
1520             # We also check for parens before looking for the trailing
1521             # comma because in some cases (eg assigning a dict
1522             # literal) the literal gets wrapped in temporary parens
1523             # during parsing. This case is covered by the
1524             # collections.py test data.
1525             opener = closer.opening_bracket
1526             for _open_index, leaf in enumerate(self.leaves):
1527                 if leaf is opener:
1528                     break
1529
1530             else:
1531                 # Couldn't find the matching opening paren, play it safe.
1532                 return False
1533
1534             commas = 0
1535             comma_depth = self.leaves[close_index - 1].bracket_depth
1536             for leaf in self.leaves[_open_index + 1 : close_index]:
1537                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1538                     commas += 1
1539             if commas > 1:
1540                 # We haven't looked yet for the trailing comma because
1541                 # we might also have caught noop parens.
1542                 return self.leaves[close_index - 1].type == token.COMMA
1543
1544             elif commas == 1:
1545                 return False  # it's either a one-tuple or didn't have a trailing comma
1546
1547             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1548                 close_index -= 1
1549                 closer = self.leaves[close_index]
1550                 if closer.type == token.RPAR:
1551                     # TODO: this is a gut feeling. Will we ever see this?
1552                     return False
1553
1554         if self.leaves[close_index - 1].type != token.COMMA:
1555             return False
1556
1557         return True
1558
1559     @property
1560     def is_def(self) -> bool:
1561         """Is this a function definition? (Also returns True for async defs.)"""
1562         try:
1563             first_leaf = self.leaves[0]
1564         except IndexError:
1565             return False
1566
1567         try:
1568             second_leaf: Optional[Leaf] = self.leaves[1]
1569         except IndexError:
1570             second_leaf = None
1571         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1572             first_leaf.type == token.ASYNC
1573             and second_leaf is not None
1574             and second_leaf.type == token.NAME
1575             and second_leaf.value == "def"
1576         )
1577
1578     @property
1579     def is_class_paren_empty(self) -> bool:
1580         """Is this a class with no base classes but using parentheses?
1581
1582         Those are unnecessary and should be removed.
1583         """
1584         return (
1585             bool(self)
1586             and len(self.leaves) == 4
1587             and self.is_class
1588             and self.leaves[2].type == token.LPAR
1589             and self.leaves[2].value == "("
1590             and self.leaves[3].type == token.RPAR
1591             and self.leaves[3].value == ")"
1592         )
1593
1594     @property
1595     def is_triple_quoted_string(self) -> bool:
1596         """Is the line a triple quoted string?"""
1597         return (
1598             bool(self)
1599             and self.leaves[0].type == token.STRING
1600             and self.leaves[0].value.startswith(('"""', "'''"))
1601         )
1602
1603     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1604         """If so, needs to be split before emitting."""
1605         for leaf in self.leaves:
1606             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1607                 return True
1608
1609         return False
1610
1611     def contains_uncollapsable_type_comments(self) -> bool:
1612         ignored_ids = set()
1613         try:
1614             last_leaf = self.leaves[-1]
1615             ignored_ids.add(id(last_leaf))
1616             if last_leaf.type == token.COMMA or (
1617                 last_leaf.type == token.RPAR and not last_leaf.value
1618             ):
1619                 # When trailing commas or optional parens are inserted by Black for
1620                 # consistency, comments after the previous last element are not moved
1621                 # (they don't have to, rendering will still be correct).  So we ignore
1622                 # trailing commas and invisible.
1623                 last_leaf = self.leaves[-2]
1624                 ignored_ids.add(id(last_leaf))
1625         except IndexError:
1626             return False
1627
1628         # A type comment is uncollapsable if it is attached to a leaf
1629         # that isn't at the end of the line (since that could cause it
1630         # to get associated to a different argument) or if there are
1631         # comments before it (since that could cause it to get hidden
1632         # behind a comment.
1633         comment_seen = False
1634         for leaf_id, comments in self.comments.items():
1635             for comment in comments:
1636                 if is_type_comment(comment):
1637                     if comment_seen or (
1638                         not is_type_comment(comment, " ignore")
1639                         and leaf_id not in ignored_ids
1640                     ):
1641                         return True
1642
1643                 comment_seen = True
1644
1645         return False
1646
1647     def contains_unsplittable_type_ignore(self) -> bool:
1648         if not self.leaves:
1649             return False
1650
1651         # If a 'type: ignore' is attached to the end of a line, we
1652         # can't split the line, because we can't know which of the
1653         # subexpressions the ignore was meant to apply to.
1654         #
1655         # We only want this to apply to actual physical lines from the
1656         # original source, though: we don't want the presence of a
1657         # 'type: ignore' at the end of a multiline expression to
1658         # justify pushing it all onto one line. Thus we
1659         # (unfortunately) need to check the actual source lines and
1660         # only report an unsplittable 'type: ignore' if this line was
1661         # one line in the original code.
1662
1663         # Grab the first and last line numbers, skipping generated leaves
1664         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1665         last_line = next(
1666             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1667         )
1668
1669         if first_line == last_line:
1670             # We look at the last two leaves since a comma or an
1671             # invisible paren could have been added at the end of the
1672             # line.
1673             for node in self.leaves[-2:]:
1674                 for comment in self.comments.get(id(node), []):
1675                     if is_type_comment(comment, " ignore"):
1676                         return True
1677
1678         return False
1679
1680     def contains_multiline_strings(self) -> bool:
1681         return any(is_multiline_string(leaf) for leaf in self.leaves)
1682
1683     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1684         """Remove trailing comma if there is one and it's safe."""
1685         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1686             return False
1687
1688         # We remove trailing commas only in the case of importing a
1689         # single name from a module.
1690         if not (
1691             self.leaves
1692             and self.is_import
1693             and len(self.leaves) > 4
1694             and self.leaves[-1].type == token.COMMA
1695             and closing.type in CLOSING_BRACKETS
1696             and self.leaves[-4].type == token.NAME
1697             and (
1698                 # regular `from foo import bar,`
1699                 self.leaves[-4].value == "import"
1700                 # `from foo import (bar as baz,)
1701                 or (
1702                     len(self.leaves) > 6
1703                     and self.leaves[-6].value == "import"
1704                     and self.leaves[-3].value == "as"
1705                 )
1706                 # `from foo import bar as baz,`
1707                 or (
1708                     len(self.leaves) > 5
1709                     and self.leaves[-5].value == "import"
1710                     and self.leaves[-3].value == "as"
1711                 )
1712             )
1713             and closing.type == token.RPAR
1714         ):
1715             return False
1716
1717         self.remove_trailing_comma()
1718         return True
1719
1720     def append_comment(self, comment: Leaf) -> bool:
1721         """Add an inline or standalone comment to the line."""
1722         if (
1723             comment.type == STANDALONE_COMMENT
1724             and self.bracket_tracker.any_open_brackets()
1725         ):
1726             comment.prefix = ""
1727             return False
1728
1729         if comment.type != token.COMMENT:
1730             return False
1731
1732         if not self.leaves:
1733             comment.type = STANDALONE_COMMENT
1734             comment.prefix = ""
1735             return False
1736
1737         last_leaf = self.leaves[-1]
1738         if (
1739             last_leaf.type == token.RPAR
1740             and not last_leaf.value
1741             and last_leaf.parent
1742             and len(list(last_leaf.parent.leaves())) <= 3
1743             and not is_type_comment(comment)
1744         ):
1745             # Comments on an optional parens wrapping a single leaf should belong to
1746             # the wrapped node except if it's a type comment. Pinning the comment like
1747             # this avoids unstable formatting caused by comment migration.
1748             if len(self.leaves) < 2:
1749                 comment.type = STANDALONE_COMMENT
1750                 comment.prefix = ""
1751                 return False
1752
1753             last_leaf = self.leaves[-2]
1754         self.comments.setdefault(id(last_leaf), []).append(comment)
1755         return True
1756
1757     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1758         """Generate comments that should appear directly after `leaf`."""
1759         return self.comments.get(id(leaf), [])
1760
1761     def remove_trailing_comma(self) -> None:
1762         """Remove the trailing comma and moves the comments attached to it."""
1763         trailing_comma = self.leaves.pop()
1764         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1765         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1766             trailing_comma_comments
1767         )
1768
1769     def is_complex_subscript(self, leaf: Leaf) -> bool:
1770         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1771         open_lsqb = self.bracket_tracker.get_open_lsqb()
1772         if open_lsqb is None:
1773             return False
1774
1775         subscript_start = open_lsqb.next_sibling
1776
1777         if isinstance(subscript_start, Node):
1778             if subscript_start.type == syms.listmaker:
1779                 return False
1780
1781             if subscript_start.type == syms.subscriptlist:
1782                 subscript_start = child_towards(subscript_start, leaf)
1783         return subscript_start is not None and any(
1784             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1785         )
1786
1787     def clone(self) -> "Line":
1788         return Line(
1789             depth=self.depth,
1790             inside_brackets=self.inside_brackets,
1791             should_explode=self.should_explode,
1792         )
1793
1794     def __str__(self) -> str:
1795         """Render the line."""
1796         if not self:
1797             return "\n"
1798
1799         indent = "    " * self.depth
1800         leaves = iter(self.leaves)
1801         first = next(leaves)
1802         res = f"{first.prefix}{indent}{first.value}"
1803         for leaf in leaves:
1804             res += str(leaf)
1805         for comment in itertools.chain.from_iterable(self.comments.values()):
1806             res += str(comment)
1807
1808         return res + "\n"
1809
1810     def __bool__(self) -> bool:
1811         """Return True if the line has leaves or comments."""
1812         return bool(self.leaves or self.comments)
1813
1814
1815 @dataclass
1816 class EmptyLineTracker:
1817     """Provides a stateful method that returns the number of potential extra
1818     empty lines needed before and after the currently processed line.
1819
1820     Note: this tracker works on lines that haven't been split yet.  It assumes
1821     the prefix of the first leaf consists of optional newlines.  Those newlines
1822     are consumed by `maybe_empty_lines()` and included in the computation.
1823     """
1824
1825     is_pyi: bool = False
1826     previous_line: Optional[Line] = None
1827     previous_after: int = 0
1828     previous_defs: List[int] = field(default_factory=list)
1829
1830     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1831         """Return the number of extra empty lines before and after the `current_line`.
1832
1833         This is for separating `def`, `async def` and `class` with extra empty
1834         lines (two on module-level).
1835         """
1836         before, after = self._maybe_empty_lines(current_line)
1837         before = (
1838             # Black should not insert empty lines at the beginning
1839             # of the file
1840             0
1841             if self.previous_line is None
1842             else before - self.previous_after
1843         )
1844         self.previous_after = after
1845         self.previous_line = current_line
1846         return before, after
1847
1848     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1849         max_allowed = 1
1850         if current_line.depth == 0:
1851             max_allowed = 1 if self.is_pyi else 2
1852         if current_line.leaves:
1853             # Consume the first leaf's extra newlines.
1854             first_leaf = current_line.leaves[0]
1855             before = first_leaf.prefix.count("\n")
1856             before = min(before, max_allowed)
1857             first_leaf.prefix = ""
1858         else:
1859             before = 0
1860         depth = current_line.depth
1861         while self.previous_defs and self.previous_defs[-1] >= depth:
1862             self.previous_defs.pop()
1863             if self.is_pyi:
1864                 before = 0 if depth else 1
1865             else:
1866                 before = 1 if depth else 2
1867         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1868             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1869
1870         if (
1871             self.previous_line
1872             and self.previous_line.is_import
1873             and not current_line.is_import
1874             and depth == self.previous_line.depth
1875         ):
1876             return (before or 1), 0
1877
1878         if (
1879             self.previous_line
1880             and self.previous_line.is_class
1881             and current_line.is_triple_quoted_string
1882         ):
1883             return before, 1
1884
1885         return before, 0
1886
1887     def _maybe_empty_lines_for_class_or_def(
1888         self, current_line: Line, before: int
1889     ) -> Tuple[int, int]:
1890         if not current_line.is_decorator:
1891             self.previous_defs.append(current_line.depth)
1892         if self.previous_line is None:
1893             # Don't insert empty lines before the first line in the file.
1894             return 0, 0
1895
1896         if self.previous_line.is_decorator:
1897             return 0, 0
1898
1899         if self.previous_line.depth < current_line.depth and (
1900             self.previous_line.is_class or self.previous_line.is_def
1901         ):
1902             return 0, 0
1903
1904         if (
1905             self.previous_line.is_comment
1906             and self.previous_line.depth == current_line.depth
1907             and before == 0
1908         ):
1909             return 0, 0
1910
1911         if self.is_pyi:
1912             if self.previous_line.depth > current_line.depth:
1913                 newlines = 1
1914             elif current_line.is_class or self.previous_line.is_class:
1915                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1916                     # No blank line between classes with an empty body
1917                     newlines = 0
1918                 else:
1919                     newlines = 1
1920             elif current_line.is_def and not self.previous_line.is_def:
1921                 # Blank line between a block of functions and a block of non-functions
1922                 newlines = 1
1923             else:
1924                 newlines = 0
1925         else:
1926             newlines = 2
1927         if current_line.depth and newlines:
1928             newlines -= 1
1929         return newlines, 0
1930
1931
1932 @dataclass
1933 class LineGenerator(Visitor[Line]):
1934     """Generates reformatted Line objects.  Empty lines are not emitted.
1935
1936     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1937     in ways that will no longer stringify to valid Python code on the tree.
1938     """
1939
1940     is_pyi: bool = False
1941     normalize_strings: bool = True
1942     current_line: Line = field(default_factory=Line)
1943     remove_u_prefix: bool = False
1944
1945     def line(self, indent: int = 0) -> Iterator[Line]:
1946         """Generate a line.
1947
1948         If the line is empty, only emit if it makes sense.
1949         If the line is too long, split it first and then generate.
1950
1951         If any lines were generated, set up a new current_line.
1952         """
1953         if not self.current_line:
1954             self.current_line.depth += indent
1955             return  # Line is empty, don't emit. Creating a new one unnecessary.
1956
1957         complete_line = self.current_line
1958         self.current_line = Line(depth=complete_line.depth + indent)
1959         yield complete_line
1960
1961     def visit_default(self, node: LN) -> Iterator[Line]:
1962         """Default `visit_*()` implementation. Recurses to children of `node`."""
1963         if isinstance(node, Leaf):
1964             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1965             for comment in generate_comments(node):
1966                 if any_open_brackets:
1967                     # any comment within brackets is subject to splitting
1968                     self.current_line.append(comment)
1969                 elif comment.type == token.COMMENT:
1970                     # regular trailing comment
1971                     self.current_line.append(comment)
1972                     yield from self.line()
1973
1974                 else:
1975                     # regular standalone comment
1976                     yield from self.line()
1977
1978                     self.current_line.append(comment)
1979                     yield from self.line()
1980
1981             normalize_prefix(node, inside_brackets=any_open_brackets)
1982             if self.normalize_strings and node.type == token.STRING:
1983                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1984                 normalize_string_quotes(node)
1985             if node.type == token.NUMBER:
1986                 normalize_numeric_literal(node)
1987             if node.type not in WHITESPACE:
1988                 self.current_line.append(node)
1989         yield from super().visit_default(node)
1990
1991     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1992         """Increase indentation level, maybe yield a line."""
1993         # In blib2to3 INDENT never holds comments.
1994         yield from self.line(+1)
1995         yield from self.visit_default(node)
1996
1997     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1998         """Decrease indentation level, maybe yield a line."""
1999         # The current line might still wait for trailing comments.  At DEDENT time
2000         # there won't be any (they would be prefixes on the preceding NEWLINE).
2001         # Emit the line then.
2002         yield from self.line()
2003
2004         # While DEDENT has no value, its prefix may contain standalone comments
2005         # that belong to the current indentation level.  Get 'em.
2006         yield from self.visit_default(node)
2007
2008         # Finally, emit the dedent.
2009         yield from self.line(-1)
2010
2011     def visit_stmt(
2012         self, node: Node, keywords: Set[str], parens: Set[str]
2013     ) -> Iterator[Line]:
2014         """Visit a statement.
2015
2016         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2017         `def`, `with`, `class`, `assert` and assignments.
2018
2019         The relevant Python language `keywords` for a given statement will be
2020         NAME leaves within it. This methods puts those on a separate line.
2021
2022         `parens` holds a set of string leaf values immediately after which
2023         invisible parens should be put.
2024         """
2025         normalize_invisible_parens(node, parens_after=parens)
2026         for child in node.children:
2027             if child.type == token.NAME and child.value in keywords:  # type: ignore
2028                 yield from self.line()
2029
2030             yield from self.visit(child)
2031
2032     def visit_suite(self, node: Node) -> Iterator[Line]:
2033         """Visit a suite."""
2034         if self.is_pyi and is_stub_suite(node):
2035             yield from self.visit(node.children[2])
2036         else:
2037             yield from self.visit_default(node)
2038
2039     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2040         """Visit a statement without nested statements."""
2041         is_suite_like = node.parent and node.parent.type in STATEMENT
2042         if is_suite_like:
2043             if self.is_pyi and is_stub_body(node):
2044                 yield from self.visit_default(node)
2045             else:
2046                 yield from self.line(+1)
2047                 yield from self.visit_default(node)
2048                 yield from self.line(-1)
2049
2050         else:
2051             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2052                 yield from self.line()
2053             yield from self.visit_default(node)
2054
2055     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2056         """Visit `async def`, `async for`, `async with`."""
2057         yield from self.line()
2058
2059         children = iter(node.children)
2060         for child in children:
2061             yield from self.visit(child)
2062
2063             if child.type == token.ASYNC:
2064                 break
2065
2066         internal_stmt = next(children)
2067         for child in internal_stmt.children:
2068             yield from self.visit(child)
2069
2070     def visit_decorators(self, node: Node) -> Iterator[Line]:
2071         """Visit decorators."""
2072         for child in node.children:
2073             yield from self.line()
2074             yield from self.visit(child)
2075
2076     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2077         """Remove a semicolon and put the other statement on a separate line."""
2078         yield from self.line()
2079
2080     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2081         """End of file. Process outstanding comments and end with a newline."""
2082         yield from self.visit_default(leaf)
2083         yield from self.line()
2084
2085     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2086         if not self.current_line.bracket_tracker.any_open_brackets():
2087             yield from self.line()
2088         yield from self.visit_default(leaf)
2089
2090     def visit_factor(self, node: Node) -> Iterator[Line]:
2091         """Force parentheses between a unary op and a binary power:
2092
2093         -2 ** 8 -> -(2 ** 8)
2094         """
2095         _operator, operand = node.children
2096         if (
2097             operand.type == syms.power
2098             and len(operand.children) == 3
2099             and operand.children[1].type == token.DOUBLESTAR
2100         ):
2101             lpar = Leaf(token.LPAR, "(")
2102             rpar = Leaf(token.RPAR, ")")
2103             index = operand.remove() or 0
2104             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2105         yield from self.visit_default(node)
2106
2107     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2108         # Check if it's a docstring
2109         if prev_siblings_are(
2110             leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
2111         ) and is_multiline_string(leaf):
2112             prefix = "    " * self.current_line.depth
2113             docstring = fix_docstring(leaf.value[3:-3], prefix)
2114             leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
2115             normalize_string_quotes(leaf)
2116
2117         yield from self.visit_default(leaf)
2118
2119     def __post_init__(self) -> None:
2120         """You are in a twisty little maze of passages."""
2121         v = self.visit_stmt
2122         Ø: Set[str] = set()
2123         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2124         self.visit_if_stmt = partial(
2125             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2126         )
2127         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2128         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2129         self.visit_try_stmt = partial(
2130             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2131         )
2132         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2133         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2134         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2135         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2136         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2137         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2138         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2139         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2140         self.visit_async_funcdef = self.visit_async_stmt
2141         self.visit_decorated = self.visit_decorators
2142
2143
2144 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2145 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2146 OPENING_BRACKETS = set(BRACKET.keys())
2147 CLOSING_BRACKETS = set(BRACKET.values())
2148 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2149 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2150
2151
2152 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2153     """Return whitespace prefix if needed for the given `leaf`.
2154
2155     `complex_subscript` signals whether the given leaf is part of a subscription
2156     which has non-trivial arguments, like arithmetic expressions or function calls.
2157     """
2158     NO = ""
2159     SPACE = " "
2160     DOUBLESPACE = "  "
2161     t = leaf.type
2162     p = leaf.parent
2163     v = leaf.value
2164     if t in ALWAYS_NO_SPACE:
2165         return NO
2166
2167     if t == token.COMMENT:
2168         return DOUBLESPACE
2169
2170     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2171     if t == token.COLON and p.type not in {
2172         syms.subscript,
2173         syms.subscriptlist,
2174         syms.sliceop,
2175     }:
2176         return NO
2177
2178     prev = leaf.prev_sibling
2179     if not prev:
2180         prevp = preceding_leaf(p)
2181         if not prevp or prevp.type in OPENING_BRACKETS:
2182             return NO
2183
2184         if t == token.COLON:
2185             if prevp.type == token.COLON:
2186                 return NO
2187
2188             elif prevp.type != token.COMMA and not complex_subscript:
2189                 return NO
2190
2191             return SPACE
2192
2193         if prevp.type == token.EQUAL:
2194             if prevp.parent:
2195                 if prevp.parent.type in {
2196                     syms.arglist,
2197                     syms.argument,
2198                     syms.parameters,
2199                     syms.varargslist,
2200                 }:
2201                     return NO
2202
2203                 elif prevp.parent.type == syms.typedargslist:
2204                     # A bit hacky: if the equal sign has whitespace, it means we
2205                     # previously found it's a typed argument.  So, we're using
2206                     # that, too.
2207                     return prevp.prefix
2208
2209         elif prevp.type in VARARGS_SPECIALS:
2210             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2211                 return NO
2212
2213         elif prevp.type == token.COLON:
2214             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2215                 return SPACE if complex_subscript else NO
2216
2217         elif (
2218             prevp.parent
2219             and prevp.parent.type == syms.factor
2220             and prevp.type in MATH_OPERATORS
2221         ):
2222             return NO
2223
2224         elif (
2225             prevp.type == token.RIGHTSHIFT
2226             and prevp.parent
2227             and prevp.parent.type == syms.shift_expr
2228             and prevp.prev_sibling
2229             and prevp.prev_sibling.type == token.NAME
2230             and prevp.prev_sibling.value == "print"  # type: ignore
2231         ):
2232             # Python 2 print chevron
2233             return NO
2234
2235     elif prev.type in OPENING_BRACKETS:
2236         return NO
2237
2238     if p.type in {syms.parameters, syms.arglist}:
2239         # untyped function signatures or calls
2240         if not prev or prev.type != token.COMMA:
2241             return NO
2242
2243     elif p.type == syms.varargslist:
2244         # lambdas
2245         if prev and prev.type != token.COMMA:
2246             return NO
2247
2248     elif p.type == syms.typedargslist:
2249         # typed function signatures
2250         if not prev:
2251             return NO
2252
2253         if t == token.EQUAL:
2254             if prev.type != syms.tname:
2255                 return NO
2256
2257         elif prev.type == token.EQUAL:
2258             # A bit hacky: if the equal sign has whitespace, it means we
2259             # previously found it's a typed argument.  So, we're using that, too.
2260             return prev.prefix
2261
2262         elif prev.type != token.COMMA:
2263             return NO
2264
2265     elif p.type == syms.tname:
2266         # type names
2267         if not prev:
2268             prevp = preceding_leaf(p)
2269             if not prevp or prevp.type != token.COMMA:
2270                 return NO
2271
2272     elif p.type == syms.trailer:
2273         # attributes and calls
2274         if t == token.LPAR or t == token.RPAR:
2275             return NO
2276
2277         if not prev:
2278             if t == token.DOT:
2279                 prevp = preceding_leaf(p)
2280                 if not prevp or prevp.type != token.NUMBER:
2281                     return NO
2282
2283             elif t == token.LSQB:
2284                 return NO
2285
2286         elif prev.type != token.COMMA:
2287             return NO
2288
2289     elif p.type == syms.argument:
2290         # single argument
2291         if t == token.EQUAL:
2292             return NO
2293
2294         if not prev:
2295             prevp = preceding_leaf(p)
2296             if not prevp or prevp.type == token.LPAR:
2297                 return NO
2298
2299         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2300             return NO
2301
2302     elif p.type == syms.decorator:
2303         # decorators
2304         return NO
2305
2306     elif p.type == syms.dotted_name:
2307         if prev:
2308             return NO
2309
2310         prevp = preceding_leaf(p)
2311         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2312             return NO
2313
2314     elif p.type == syms.classdef:
2315         if t == token.LPAR:
2316             return NO
2317
2318         if prev and prev.type == token.LPAR:
2319             return NO
2320
2321     elif p.type in {syms.subscript, syms.sliceop}:
2322         # indexing
2323         if not prev:
2324             assert p.parent is not None, "subscripts are always parented"
2325             if p.parent.type == syms.subscriptlist:
2326                 return SPACE
2327
2328             return NO
2329
2330         elif not complex_subscript:
2331             return NO
2332
2333     elif p.type == syms.atom:
2334         if prev and t == token.DOT:
2335             # dots, but not the first one.
2336             return NO
2337
2338     elif p.type == syms.dictsetmaker:
2339         # dict unpacking
2340         if prev and prev.type == token.DOUBLESTAR:
2341             return NO
2342
2343     elif p.type in {syms.factor, syms.star_expr}:
2344         # unary ops
2345         if not prev:
2346             prevp = preceding_leaf(p)
2347             if not prevp or prevp.type in OPENING_BRACKETS:
2348                 return NO
2349
2350             prevp_parent = prevp.parent
2351             assert prevp_parent is not None
2352             if prevp.type == token.COLON and prevp_parent.type in {
2353                 syms.subscript,
2354                 syms.sliceop,
2355             }:
2356                 return NO
2357
2358             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2359                 return NO
2360
2361         elif t in {token.NAME, token.NUMBER, token.STRING}:
2362             return NO
2363
2364     elif p.type == syms.import_from:
2365         if t == token.DOT:
2366             if prev and prev.type == token.DOT:
2367                 return NO
2368
2369         elif t == token.NAME:
2370             if v == "import":
2371                 return SPACE
2372
2373             if prev and prev.type == token.DOT:
2374                 return NO
2375
2376     elif p.type == syms.sliceop:
2377         return NO
2378
2379     return SPACE
2380
2381
2382 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2383     """Return the first leaf that precedes `node`, if any."""
2384     while node:
2385         res = node.prev_sibling
2386         if res:
2387             if isinstance(res, Leaf):
2388                 return res
2389
2390             try:
2391                 return list(res.leaves())[-1]
2392
2393             except IndexError:
2394                 return None
2395
2396         node = node.parent
2397     return None
2398
2399
2400 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2401     """Return if the `node` and its previous siblings match types against the provided
2402     list of tokens; the provided `node`has its type matched against the last element in
2403     the list.  `None` can be used as the first element to declare that the start of the
2404     list is anchored at the start of its parent's children."""
2405     if not tokens:
2406         return True
2407     if tokens[-1] is None:
2408         return node is None
2409     if not node:
2410         return False
2411     if node.type != tokens[-1]:
2412         return False
2413     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2414
2415
2416 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2417     """Return the child of `ancestor` that contains `descendant`."""
2418     node: Optional[LN] = descendant
2419     while node and node.parent != ancestor:
2420         node = node.parent
2421     return node
2422
2423
2424 def container_of(leaf: Leaf) -> LN:
2425     """Return `leaf` or one of its ancestors that is the topmost container of it.
2426
2427     By "container" we mean a node where `leaf` is the very first child.
2428     """
2429     same_prefix = leaf.prefix
2430     container: LN = leaf
2431     while container:
2432         parent = container.parent
2433         if parent is None:
2434             break
2435
2436         if parent.children[0].prefix != same_prefix:
2437             break
2438
2439         if parent.type == syms.file_input:
2440             break
2441
2442         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2443             break
2444
2445         container = parent
2446     return container
2447
2448
2449 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2450     """Return the priority of the `leaf` delimiter, given a line break after it.
2451
2452     The delimiter priorities returned here are from those delimiters that would
2453     cause a line break after themselves.
2454
2455     Higher numbers are higher priority.
2456     """
2457     if leaf.type == token.COMMA:
2458         return COMMA_PRIORITY
2459
2460     return 0
2461
2462
2463 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2464     """Return the priority of the `leaf` delimiter, given a line break before it.
2465
2466     The delimiter priorities returned here are from those delimiters that would
2467     cause a line break before themselves.
2468
2469     Higher numbers are higher priority.
2470     """
2471     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2472         # * and ** might also be MATH_OPERATORS but in this case they are not.
2473         # Don't treat them as a delimiter.
2474         return 0
2475
2476     if (
2477         leaf.type == token.DOT
2478         and leaf.parent
2479         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2480         and (previous is None or previous.type in CLOSING_BRACKETS)
2481     ):
2482         return DOT_PRIORITY
2483
2484     if (
2485         leaf.type in MATH_OPERATORS
2486         and leaf.parent
2487         and leaf.parent.type not in {syms.factor, syms.star_expr}
2488     ):
2489         return MATH_PRIORITIES[leaf.type]
2490
2491     if leaf.type in COMPARATORS:
2492         return COMPARATOR_PRIORITY
2493
2494     if (
2495         leaf.type == token.STRING
2496         and previous is not None
2497         and previous.type == token.STRING
2498     ):
2499         return STRING_PRIORITY
2500
2501     if leaf.type not in {token.NAME, token.ASYNC}:
2502         return 0
2503
2504     if (
2505         leaf.value == "for"
2506         and leaf.parent
2507         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2508         or leaf.type == token.ASYNC
2509     ):
2510         if (
2511             not isinstance(leaf.prev_sibling, Leaf)
2512             or leaf.prev_sibling.value != "async"
2513         ):
2514             return COMPREHENSION_PRIORITY
2515
2516     if (
2517         leaf.value == "if"
2518         and leaf.parent
2519         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2520     ):
2521         return COMPREHENSION_PRIORITY
2522
2523     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2524         return TERNARY_PRIORITY
2525
2526     if leaf.value == "is":
2527         return COMPARATOR_PRIORITY
2528
2529     if (
2530         leaf.value == "in"
2531         and leaf.parent
2532         and leaf.parent.type in {syms.comp_op, syms.comparison}
2533         and not (
2534             previous is not None
2535             and previous.type == token.NAME
2536             and previous.value == "not"
2537         )
2538     ):
2539         return COMPARATOR_PRIORITY
2540
2541     if (
2542         leaf.value == "not"
2543         and leaf.parent
2544         and leaf.parent.type == syms.comp_op
2545         and not (
2546             previous is not None
2547             and previous.type == token.NAME
2548             and previous.value == "is"
2549         )
2550     ):
2551         return COMPARATOR_PRIORITY
2552
2553     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2554         return LOGIC_PRIORITY
2555
2556     return 0
2557
2558
2559 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2560 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2561
2562
2563 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2564     """Clean the prefix of the `leaf` and generate comments from it, if any.
2565
2566     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2567     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2568     move because it does away with modifying the grammar to include all the
2569     possible places in which comments can be placed.
2570
2571     The sad consequence for us though is that comments don't "belong" anywhere.
2572     This is why this function generates simple parentless Leaf objects for
2573     comments.  We simply don't know what the correct parent should be.
2574
2575     No matter though, we can live without this.  We really only need to
2576     differentiate between inline and standalone comments.  The latter don't
2577     share the line with any code.
2578
2579     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2580     are emitted with a fake STANDALONE_COMMENT token identifier.
2581     """
2582     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2583         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2584
2585
2586 @dataclass
2587 class ProtoComment:
2588     """Describes a piece of syntax that is a comment.
2589
2590     It's not a :class:`blib2to3.pytree.Leaf` so that:
2591
2592     * it can be cached (`Leaf` objects should not be reused more than once as
2593       they store their lineno, column, prefix, and parent information);
2594     * `newlines` and `consumed` fields are kept separate from the `value`. This
2595       simplifies handling of special marker comments like ``# fmt: off/on``.
2596     """
2597
2598     type: int  # token.COMMENT or STANDALONE_COMMENT
2599     value: str  # content of the comment
2600     newlines: int  # how many newlines before the comment
2601     consumed: int  # how many characters of the original leaf's prefix did we consume
2602
2603
2604 @lru_cache(maxsize=4096)
2605 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2606     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2607     result: List[ProtoComment] = []
2608     if not prefix or "#" not in prefix:
2609         return result
2610
2611     consumed = 0
2612     nlines = 0
2613     ignored_lines = 0
2614     for index, line in enumerate(prefix.split("\n")):
2615         consumed += len(line) + 1  # adding the length of the split '\n'
2616         line = line.lstrip()
2617         if not line:
2618             nlines += 1
2619         if not line.startswith("#"):
2620             # Escaped newlines outside of a comment are not really newlines at
2621             # all. We treat a single-line comment following an escaped newline
2622             # as a simple trailing comment.
2623             if line.endswith("\\"):
2624                 ignored_lines += 1
2625             continue
2626
2627         if index == ignored_lines and not is_endmarker:
2628             comment_type = token.COMMENT  # simple trailing comment
2629         else:
2630             comment_type = STANDALONE_COMMENT
2631         comment = make_comment(line)
2632         result.append(
2633             ProtoComment(
2634                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2635             )
2636         )
2637         nlines = 0
2638     return result
2639
2640
2641 def make_comment(content: str) -> str:
2642     """Return a consistently formatted comment from the given `content` string.
2643
2644     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2645     space between the hash sign and the content.
2646
2647     If `content` didn't start with a hash sign, one is provided.
2648     """
2649     content = content.rstrip()
2650     if not content:
2651         return "#"
2652
2653     if content[0] == "#":
2654         content = content[1:]
2655     if content and content[0] not in " !:#'%":
2656         content = " " + content
2657     return "#" + content
2658
2659
2660 def transform_line(
2661     line: Line, mode: Mode, features: Collection[Feature] = ()
2662 ) -> Iterator[Line]:
2663     """Transform a `line`, potentially splitting it into many lines.
2664
2665     They should fit in the allotted `line_length` but might not be able to.
2666
2667     `features` are syntactical features that may be used in the output.
2668     """
2669     if line.is_comment:
2670         yield line
2671         return
2672
2673     line_str = line_to_string(line)
2674
2675     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2676         """Initialize StringTransformer"""
2677         return ST(mode.line_length, mode.string_normalization)
2678
2679     string_merge = init_st(StringMerger)
2680     string_paren_strip = init_st(StringParenStripper)
2681     string_split = init_st(StringSplitter)
2682     string_paren_wrap = init_st(StringParenWrapper)
2683
2684     transformers: List[Transformer]
2685     if (
2686         not line.contains_uncollapsable_type_comments()
2687         and not line.should_explode
2688         and not line.is_collection_with_optional_trailing_comma
2689         and (
2690             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2691             or line.contains_unsplittable_type_ignore()
2692         )
2693         and not (line.contains_standalone_comments() and line.inside_brackets)
2694     ):
2695         # Only apply basic string preprocessing, since lines shouldn't be split here.
2696         if mode.experimental_string_processing:
2697             transformers = [string_merge, string_paren_strip]
2698         else:
2699             transformers = []
2700     elif line.is_def:
2701         transformers = [left_hand_split]
2702     else:
2703
2704         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2705             for omit in generate_trailers_to_omit(line, mode.line_length):
2706                 lines = list(
2707                     right_hand_split(line, mode.line_length, features, omit=omit)
2708                 )
2709                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2710                     yield from lines
2711                     return
2712
2713             # All splits failed, best effort split with no omits.
2714             # This mostly happens to multiline strings that are by definition
2715             # reported as not fitting a single line.
2716             # line_length=1 here was historically a bug that somehow became a feature.
2717             # See #762 and #781 for the full story.
2718             yield from right_hand_split(line, line_length=1, features=features)
2719
2720         if mode.experimental_string_processing:
2721             if line.inside_brackets:
2722                 transformers = [
2723                     string_merge,
2724                     string_paren_strip,
2725                     delimiter_split,
2726                     standalone_comment_split,
2727                     string_split,
2728                     string_paren_wrap,
2729                     rhs,
2730                 ]
2731             else:
2732                 transformers = [
2733                     string_merge,
2734                     string_paren_strip,
2735                     string_split,
2736                     string_paren_wrap,
2737                     rhs,
2738                 ]
2739         else:
2740             if line.inside_brackets:
2741                 transformers = [delimiter_split, standalone_comment_split, rhs]
2742             else:
2743                 transformers = [rhs]
2744
2745     for transform in transformers:
2746         # We are accumulating lines in `result` because we might want to abort
2747         # mission and return the original line in the end, or attempt a different
2748         # split altogether.
2749         result: List[Line] = []
2750         try:
2751             for transformed_line in transform(line, features):
2752                 if str(transformed_line).strip("\n") == line_str:
2753                     raise CannotTransform(
2754                         "Line transformer returned an unchanged result"
2755                     )
2756
2757                 result.extend(
2758                     transform_line(transformed_line, mode=mode, features=features)
2759                 )
2760         except CannotTransform:
2761             continue
2762         else:
2763             yield from result
2764             break
2765
2766     else:
2767         yield line
2768
2769
2770 @dataclass  # type: ignore
2771 class StringTransformer(ABC):
2772     """
2773     An implementation of the Transformer protocol that relies on its
2774     subclasses overriding the template methods `do_match(...)` and
2775     `do_transform(...)`.
2776
2777     This Transformer works exclusively on strings (for example, by merging
2778     or splitting them).
2779
2780     The following sections can be found among the docstrings of each concrete
2781     StringTransformer subclass.
2782
2783     Requirements:
2784         Which requirements must be met of the given Line for this
2785         StringTransformer to be applied?
2786
2787     Transformations:
2788         If the given Line meets all of the above requirements, which string
2789         transformations can you expect to be applied to it by this
2790         StringTransformer?
2791
2792     Collaborations:
2793         What contractual agreements does this StringTransformer have with other
2794         StringTransfomers? Such collaborations should be eliminated/minimized
2795         as much as possible.
2796     """
2797
2798     line_length: int
2799     normalize_strings: bool
2800
2801     @abstractmethod
2802     def do_match(self, line: Line) -> TMatchResult:
2803         """
2804         Returns:
2805             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2806             string, if a match was able to be made.
2807                 OR
2808             * Err(CannotTransform), if a match was not able to be made.
2809         """
2810
2811     @abstractmethod
2812     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2813         """
2814         Yields:
2815             * Ok(new_line) where new_line is the new transformed line.
2816                 OR
2817             * Err(CannotTransform) if the transformation failed for some reason. The
2818             `do_match(...)` template method should usually be used to reject
2819             the form of the given Line, but in some cases it is difficult to
2820             know whether or not a Line meets the StringTransformer's
2821             requirements until the transformation is already midway.
2822
2823         Side Effects:
2824             This method should NOT mutate @line directly, but it MAY mutate the
2825             Line's underlying Node structure. (WARNING: If the underlying Node
2826             structure IS altered, then this method should NOT be allowed to
2827             yield an CannotTransform after that point.)
2828         """
2829
2830     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2831         """
2832         StringTransformer instances have a call signature that mirrors that of
2833         the Transformer type.
2834
2835         Raises:
2836             CannotTransform(...) if the concrete StringTransformer class is unable
2837             to transform @line.
2838         """
2839         # Optimization to avoid calling `self.do_match(...)` when the line does
2840         # not contain any string.
2841         if not any(leaf.type == token.STRING for leaf in line.leaves):
2842             raise CannotTransform("There are no strings in this line.")
2843
2844         match_result = self.do_match(line)
2845
2846         if isinstance(match_result, Err):
2847             cant_transform = match_result.err()
2848             raise CannotTransform(
2849                 f"The string transformer {self.__class__.__name__} does not recognize"
2850                 " this line as one that it can transform."
2851             ) from cant_transform
2852
2853         string_idx = match_result.ok()
2854
2855         for line_result in self.do_transform(line, string_idx):
2856             if isinstance(line_result, Err):
2857                 cant_transform = line_result.err()
2858                 raise CannotTransform(
2859                     "StringTransformer failed while attempting to transform string."
2860                 ) from cant_transform
2861             line = line_result.ok()
2862             yield line
2863
2864
2865 @dataclass
2866 class CustomSplit:
2867     """A custom (i.e. manual) string split.
2868
2869     A single CustomSplit instance represents a single substring.
2870
2871     Examples:
2872         Consider the following string:
2873         ```
2874         "Hi there friend."
2875         " This is a custom"
2876         f" string {split}."
2877         ```
2878
2879         This string will correspond to the following three CustomSplit instances:
2880         ```
2881         CustomSplit(False, 16)
2882         CustomSplit(False, 17)
2883         CustomSplit(True, 16)
2884         ```
2885     """
2886
2887     has_prefix: bool
2888     break_idx: int
2889
2890
2891 class CustomSplitMapMixin:
2892     """
2893     This mixin class is used to map merged strings to a sequence of
2894     CustomSplits, which will then be used to re-split the strings iff none of
2895     the resultant substrings go over the configured max line length.
2896     """
2897
2898     _Key = Tuple[StringID, str]
2899     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2900
2901     @staticmethod
2902     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2903         """
2904         Returns:
2905             A unique identifier that is used internally to map @string to a
2906             group of custom splits.
2907         """
2908         return (id(string), string)
2909
2910     def add_custom_splits(
2911         self, string: str, custom_splits: Iterable[CustomSplit]
2912     ) -> None:
2913         """Custom Split Map Setter Method
2914
2915         Side Effects:
2916             Adds a mapping from @string to the custom splits @custom_splits.
2917         """
2918         key = self._get_key(string)
2919         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2920
2921     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2922         """Custom Split Map Getter Method
2923
2924         Returns:
2925             * A list of the custom splits that are mapped to @string, if any
2926             exist.
2927                 OR
2928             * [], otherwise.
2929
2930         Side Effects:
2931             Deletes the mapping between @string and its associated custom
2932             splits (which are returned to the caller).
2933         """
2934         key = self._get_key(string)
2935
2936         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2937         del self._CUSTOM_SPLIT_MAP[key]
2938
2939         return list(custom_splits)
2940
2941     def has_custom_splits(self, string: str) -> bool:
2942         """
2943         Returns:
2944             True iff @string is associated with a set of custom splits.
2945         """
2946         key = self._get_key(string)
2947         return key in self._CUSTOM_SPLIT_MAP
2948
2949
2950 class StringMerger(CustomSplitMapMixin, StringTransformer):
2951     """StringTransformer that merges strings together.
2952
2953     Requirements:
2954         (A) The line contains adjacent strings such that at most one substring
2955         has inline comments AND none of those inline comments are pragmas AND
2956         the set of all substring prefixes is either of length 1 or equal to
2957         {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
2958         with 'r').
2959             OR
2960         (B) The line contains a string which uses line continuation backslashes.
2961
2962     Transformations:
2963         Depending on which of the two requirements above where met, either:
2964
2965         (A) The string group associated with the target string is merged.
2966             OR
2967         (B) All line-continuation backslashes are removed from the target string.
2968
2969     Collaborations:
2970         StringMerger provides custom split information to StringSplitter.
2971     """
2972
2973     def do_match(self, line: Line) -> TMatchResult:
2974         LL = line.leaves
2975
2976         is_valid_index = is_valid_index_factory(LL)
2977
2978         for (i, leaf) in enumerate(LL):
2979             if (
2980                 leaf.type == token.STRING
2981                 and is_valid_index(i + 1)
2982                 and LL[i + 1].type == token.STRING
2983             ):
2984                 return Ok(i)
2985
2986             if leaf.type == token.STRING and "\\\n" in leaf.value:
2987                 return Ok(i)
2988
2989         return TErr("This line has no strings that need merging.")
2990
2991     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2992         new_line = line
2993         rblc_result = self.__remove_backslash_line_continuation_chars(
2994             new_line, string_idx
2995         )
2996         if isinstance(rblc_result, Ok):
2997             new_line = rblc_result.ok()
2998
2999         msg_result = self.__merge_string_group(new_line, string_idx)
3000         if isinstance(msg_result, Ok):
3001             new_line = msg_result.ok()
3002
3003         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3004             msg_cant_transform = msg_result.err()
3005             rblc_cant_transform = rblc_result.err()
3006             cant_transform = CannotTransform(
3007                 "StringMerger failed to merge any strings in this line."
3008             )
3009
3010             # Chain the errors together using `__cause__`.
3011             msg_cant_transform.__cause__ = rblc_cant_transform
3012             cant_transform.__cause__ = msg_cant_transform
3013
3014             yield Err(cant_transform)
3015         else:
3016             yield Ok(new_line)
3017
3018     @staticmethod
3019     def __remove_backslash_line_continuation_chars(
3020         line: Line, string_idx: int
3021     ) -> TResult[Line]:
3022         """
3023         Merge strings that were split across multiple lines using
3024         line-continuation backslashes.
3025
3026         Returns:
3027             Ok(new_line), if @line contains backslash line-continuation
3028             characters.
3029                 OR
3030             Err(CannotTransform), otherwise.
3031         """
3032         LL = line.leaves
3033
3034         string_leaf = LL[string_idx]
3035         if not (
3036             string_leaf.type == token.STRING
3037             and "\\\n" in string_leaf.value
3038             and not has_triple_quotes(string_leaf.value)
3039         ):
3040             return TErr(
3041                 f"String leaf {string_leaf} does not contain any backslash line"
3042                 " continuation characters."
3043             )
3044
3045         new_line = line.clone()
3046         new_line.comments = line.comments
3047         append_leaves(new_line, line, LL)
3048
3049         new_string_leaf = new_line.leaves[string_idx]
3050         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3051
3052         return Ok(new_line)
3053
3054     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3055         """
3056         Merges string group (i.e. set of adjacent strings) where the first
3057         string in the group is `line.leaves[string_idx]`.
3058
3059         Returns:
3060             Ok(new_line), if ALL of the validation checks found in
3061             __validate_msg(...) pass.
3062                 OR
3063             Err(CannotTransform), otherwise.
3064         """
3065         LL = line.leaves
3066
3067         is_valid_index = is_valid_index_factory(LL)
3068
3069         vresult = self.__validate_msg(line, string_idx)
3070         if isinstance(vresult, Err):
3071             return vresult
3072
3073         # If the string group is wrapped inside an Atom node, we must make sure
3074         # to later replace that Atom with our new (merged) string leaf.
3075         atom_node = LL[string_idx].parent
3076
3077         # We will place BREAK_MARK in between every two substrings that we
3078         # merge. We will then later go through our final result and use the
3079         # various instances of BREAK_MARK we find to add the right values to
3080         # the custom split map.
3081         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3082
3083         QUOTE = LL[string_idx].value[-1]
3084
3085         def make_naked(string: str, string_prefix: str) -> str:
3086             """Strip @string (i.e. make it a "naked" string)
3087
3088             Pre-conditions:
3089                 * assert_is_leaf_string(@string)
3090
3091             Returns:
3092                 A string that is identical to @string except that
3093                 @string_prefix has been stripped, the surrounding QUOTE
3094                 characters have been removed, and any remaining QUOTE
3095                 characters have been escaped.
3096             """
3097             assert_is_leaf_string(string)
3098
3099             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3100             naked_string = string[len(string_prefix) + 1 : -1]
3101             naked_string = re.sub(
3102                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3103             )
3104             return naked_string
3105
3106         # Holds the CustomSplit objects that will later be added to the custom
3107         # split map.
3108         custom_splits = []
3109
3110         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3111         prefix_tracker = []
3112
3113         # Sets the 'prefix' variable. This is the prefix that the final merged
3114         # string will have.
3115         next_str_idx = string_idx
3116         prefix = ""
3117         while (
3118             not prefix
3119             and is_valid_index(next_str_idx)
3120             and LL[next_str_idx].type == token.STRING
3121         ):
3122             prefix = get_string_prefix(LL[next_str_idx].value)
3123             next_str_idx += 1
3124
3125         # The next loop merges the string group. The final string will be
3126         # contained in 'S'.
3127         #
3128         # The following convenience variables are used:
3129         #
3130         #   S: string
3131         #   NS: naked string
3132         #   SS: next string
3133         #   NSS: naked next string
3134         S = ""
3135         NS = ""
3136         num_of_strings = 0
3137         next_str_idx = string_idx
3138         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3139             num_of_strings += 1
3140
3141             SS = LL[next_str_idx].value
3142             next_prefix = get_string_prefix(SS)
3143
3144             # If this is an f-string group but this substring is not prefixed
3145             # with 'f'...
3146             if "f" in prefix and "f" not in next_prefix:
3147                 # Then we must escape any braces contained in this substring.
3148                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3149
3150             NSS = make_naked(SS, next_prefix)
3151
3152             has_prefix = bool(next_prefix)
3153             prefix_tracker.append(has_prefix)
3154
3155             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3156             NS = make_naked(S, prefix)
3157
3158             next_str_idx += 1
3159
3160         S_leaf = Leaf(token.STRING, S)
3161         if self.normalize_strings:
3162             normalize_string_quotes(S_leaf)
3163
3164         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3165         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3166         for has_prefix in prefix_tracker:
3167             mark_idx = temp_string.find(BREAK_MARK)
3168             assert (
3169                 mark_idx >= 0
3170             ), "Logic error while filling the custom string breakpoint cache."
3171
3172             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3173             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3174             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3175
3176         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3177
3178         if atom_node is not None:
3179             replace_child(atom_node, string_leaf)
3180
3181         # Build the final line ('new_line') that this method will later return.
3182         new_line = line.clone()
3183         for (i, leaf) in enumerate(LL):
3184             if i == string_idx:
3185                 new_line.append(string_leaf)
3186
3187             if string_idx <= i < string_idx + num_of_strings:
3188                 for comment_leaf in line.comments_after(LL[i]):
3189                     new_line.append(comment_leaf, preformatted=True)
3190                 continue
3191
3192             append_leaves(new_line, line, [leaf])
3193
3194         self.add_custom_splits(string_leaf.value, custom_splits)
3195         return Ok(new_line)
3196
3197     @staticmethod
3198     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3199         """Validate (M)erge (S)tring (G)roup
3200
3201         Transform-time string validation logic for __merge_string_group(...).
3202
3203         Returns:
3204             * Ok(None), if ALL validation checks (listed below) pass.
3205                 OR
3206             * Err(CannotTransform), if any of the following are true:
3207                 - The target string is not in a string group (i.e. it has no
3208                   adjacent strings).
3209                 - The string group has more than one inline comment.
3210                 - The string group has an inline comment that appears to be a pragma.
3211                 - The set of all string prefixes in the string group is of
3212                   length greater than one and is not equal to {"", "f"}.
3213                 - The string group consists of raw strings.
3214         """
3215         num_of_inline_string_comments = 0
3216         set_of_prefixes = set()
3217         num_of_strings = 0
3218         for leaf in line.leaves[string_idx:]:
3219             if leaf.type != token.STRING:
3220                 # If the string group is trailed by a comma, we count the
3221                 # comments trailing the comma to be one of the string group's
3222                 # comments.
3223                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3224                     num_of_inline_string_comments += 1
3225                 break
3226
3227             if has_triple_quotes(leaf.value):
3228                 return TErr("StringMerger does NOT merge multiline strings.")
3229
3230             num_of_strings += 1
3231             prefix = get_string_prefix(leaf.value)
3232             if "r" in prefix:
3233                 return TErr("StringMerger does NOT merge raw strings.")
3234
3235             set_of_prefixes.add(prefix)
3236
3237             if id(leaf) in line.comments:
3238                 num_of_inline_string_comments += 1
3239                 if contains_pragma_comment(line.comments[id(leaf)]):
3240                     return TErr("Cannot merge strings which have pragma comments.")
3241
3242         if num_of_strings < 2:
3243             return TErr(
3244                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3245             )
3246
3247         if num_of_inline_string_comments > 1:
3248             return TErr(
3249                 f"Too many inline string comments ({num_of_inline_string_comments})."
3250             )
3251
3252         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3253             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3254
3255         return Ok(None)
3256
3257
3258 class StringParenStripper(StringTransformer):
3259     """StringTransformer that strips surrounding parentheses from strings.
3260
3261     Requirements:
3262         The line contains a string which is surrounded by parentheses and:
3263             - The target string is NOT the only argument to a function call).
3264             - If the target string contains a PERCENT, the brackets are not
3265               preceeded or followed by an operator with higher precedence than
3266               PERCENT.
3267
3268     Transformations:
3269         The parentheses mentioned in the 'Requirements' section are stripped.
3270
3271     Collaborations:
3272         StringParenStripper has its own inherent usefulness, but it is also
3273         relied on to clean up the parentheses created by StringParenWrapper (in
3274         the event that they are no longer needed).
3275     """
3276
3277     def do_match(self, line: Line) -> TMatchResult:
3278         LL = line.leaves
3279
3280         is_valid_index = is_valid_index_factory(LL)
3281
3282         for (idx, leaf) in enumerate(LL):
3283             # Should be a string...
3284             if leaf.type != token.STRING:
3285                 continue
3286
3287             # Should be preceded by a non-empty LPAR...
3288             if (
3289                 not is_valid_index(idx - 1)
3290                 or LL[idx - 1].type != token.LPAR
3291                 or is_empty_lpar(LL[idx - 1])
3292             ):
3293                 continue
3294
3295             # That LPAR should NOT be preceded by a function name or a closing
3296             # bracket (which could be a function which returns a function or a
3297             # list/dictionary that contains a function)...
3298             if is_valid_index(idx - 2) and (
3299                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3300             ):
3301                 continue
3302
3303             string_idx = idx
3304
3305             # Skip the string trailer, if one exists.
3306             string_parser = StringParser()
3307             next_idx = string_parser.parse(LL, string_idx)
3308
3309             # if the leaves in the parsed string include a PERCENT, we need to
3310             # make sure the initial LPAR is NOT preceded by an operator with
3311             # higher or equal precedence to PERCENT
3312             if is_valid_index(idx - 2):
3313                 # mypy can't quite follow unless we name this
3314                 before_lpar = LL[idx - 2]
3315                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3316                     (
3317                         before_lpar.type
3318                         in {
3319                             token.STAR,
3320                             token.AT,
3321                             token.SLASH,
3322                             token.DOUBLESLASH,
3323                             token.PERCENT,
3324                             token.TILDE,
3325                             token.DOUBLESTAR,
3326                             token.AWAIT,
3327                             token.LSQB,
3328                             token.LPAR,
3329                         }
3330                     )
3331                     or (
3332                         # only unary PLUS/MINUS
3333                         before_lpar.parent
3334                         and before_lpar.parent.type == syms.factor
3335                         and (before_lpar.type in {token.PLUS, token.MINUS})
3336                     )
3337                 ):
3338                     continue
3339
3340             # Should be followed by a non-empty RPAR...
3341             if (
3342                 is_valid_index(next_idx)
3343                 and LL[next_idx].type == token.RPAR
3344                 and not is_empty_rpar(LL[next_idx])
3345             ):
3346                 # That RPAR should NOT be followed by anything with higher
3347                 # precedence than PERCENT
3348                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3349                     token.DOUBLESTAR,
3350                     token.LSQB,
3351                     token.LPAR,
3352                     token.DOT,
3353                 }:
3354                     continue
3355
3356                 return Ok(string_idx)
3357
3358         return TErr("This line has no strings wrapped in parens.")
3359
3360     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3361         LL = line.leaves
3362
3363         string_parser = StringParser()
3364         rpar_idx = string_parser.parse(LL, string_idx)
3365
3366         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3367             if line.comments_after(leaf):
3368                 yield TErr(
3369                     "Will not strip parentheses which have comments attached to them."
3370                 )
3371
3372         new_line = line.clone()
3373         new_line.comments = line.comments.copy()
3374
3375         append_leaves(new_line, line, LL[: string_idx - 1])
3376
3377         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3378         LL[string_idx - 1].remove()
3379         replace_child(LL[string_idx], string_leaf)
3380         new_line.append(string_leaf)
3381
3382         append_leaves(
3383             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3384         )
3385
3386         LL[rpar_idx].remove()
3387
3388         yield Ok(new_line)
3389
3390
3391 class BaseStringSplitter(StringTransformer):
3392     """
3393     Abstract class for StringTransformers which transform a Line's strings by splitting
3394     them or placing them on their own lines where necessary to avoid going over
3395     the configured line length.
3396
3397     Requirements:
3398         * The target string value is responsible for the line going over the
3399         line length limit. It follows that after all of black's other line
3400         split methods have been exhausted, this line (or one of the resulting
3401         lines after all line splits are performed) would still be over the
3402         line_length limit unless we split this string.
3403             AND
3404         * The target string is NOT a "pointless" string (i.e. a string that has
3405         no parent or siblings).
3406             AND
3407         * The target string is not followed by an inline comment that appears
3408         to be a pragma.
3409             AND
3410         * The target string is not a multiline (i.e. triple-quote) string.
3411     """
3412
3413     @abstractmethod
3414     def do_splitter_match(self, line: Line) -> TMatchResult:
3415         """
3416         BaseStringSplitter asks its clients to override this method instead of
3417         `StringTransformer.do_match(...)`.
3418
3419         Follows the same protocol as `StringTransformer.do_match(...)`.
3420
3421         Refer to `help(StringTransformer.do_match)` for more information.
3422         """
3423
3424     def do_match(self, line: Line) -> TMatchResult:
3425         match_result = self.do_splitter_match(line)
3426         if isinstance(match_result, Err):
3427             return match_result
3428
3429         string_idx = match_result.ok()
3430         vresult = self.__validate(line, string_idx)
3431         if isinstance(vresult, Err):
3432             return vresult
3433
3434         return match_result
3435
3436     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3437         """
3438         Checks that @line meets all of the requirements listed in this classes'
3439         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3440         description of those requirements.
3441
3442         Returns:
3443             * Ok(None), if ALL of the requirements are met.
3444                 OR
3445             * Err(CannotTransform), if ANY of the requirements are NOT met.
3446         """
3447         LL = line.leaves
3448
3449         string_leaf = LL[string_idx]
3450
3451         max_string_length = self.__get_max_string_length(line, string_idx)
3452         if len(string_leaf.value) <= max_string_length:
3453             return TErr(
3454                 "The string itself is not what is causing this line to be too long."
3455             )
3456
3457         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3458             token.STRING,
3459             token.NEWLINE,
3460         ]:
3461             return TErr(
3462                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3463                 " no parent)."
3464             )
3465
3466         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3467             line.comments[id(line.leaves[string_idx])]
3468         ):
3469             return TErr(
3470                 "Line appears to end with an inline pragma comment. Splitting the line"
3471                 " could modify the pragma's behavior."
3472             )
3473
3474         if has_triple_quotes(string_leaf.value):
3475             return TErr("We cannot split multiline strings.")
3476
3477         return Ok(None)
3478
3479     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3480         """
3481         Calculates the max string length used when attempting to determine
3482         whether or not the target string is responsible for causing the line to
3483         go over the line length limit.
3484
3485         WARNING: This method is tightly coupled to both StringSplitter and
3486         (especially) StringParenWrapper. There is probably a better way to
3487         accomplish what is being done here.
3488
3489         Returns:
3490             max_string_length: such that `line.leaves[string_idx].value >
3491             max_string_length` implies that the target string IS responsible
3492             for causing this line to exceed the line length limit.
3493         """
3494         LL = line.leaves
3495
3496         is_valid_index = is_valid_index_factory(LL)
3497
3498         # We use the shorthand "WMA4" in comments to abbreviate "We must
3499         # account for". When giving examples, we use STRING to mean some/any
3500         # valid string.
3501         #
3502         # Finally, we use the following convenience variables:
3503         #
3504         #   P:  The leaf that is before the target string leaf.
3505         #   N:  The leaf that is after the target string leaf.
3506         #   NN: The leaf that is after N.
3507
3508         # WMA4 the whitespace at the beginning of the line.
3509         offset = line.depth * 4
3510
3511         if is_valid_index(string_idx - 1):
3512             p_idx = string_idx - 1
3513             if (
3514                 LL[string_idx - 1].type == token.LPAR
3515                 and LL[string_idx - 1].value == ""
3516                 and string_idx >= 2
3517             ):
3518                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3519                 p_idx -= 1
3520
3521             P = LL[p_idx]
3522             if P.type == token.PLUS:
3523                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3524                 offset += 2
3525
3526             if P.type == token.COMMA:
3527                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3528                 offset += 3
3529
3530             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3531                 # This conditional branch is meant to handle dictionary keys,
3532                 # variable assignments, 'return STRING' statement lines, and
3533                 # 'else STRING' ternary expression lines.
3534
3535                 # WMA4 a single space.
3536                 offset += 1
3537
3538                 # WMA4 the lengths of any leaves that came before that space.
3539                 for leaf in LL[: p_idx + 1]:
3540                     offset += len(str(leaf))
3541
3542         if is_valid_index(string_idx + 1):
3543             N = LL[string_idx + 1]
3544             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3545                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3546                 N = LL[string_idx + 2]
3547
3548             if N.type == token.COMMA:
3549                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3550                 offset += 1
3551
3552             if is_valid_index(string_idx + 2):
3553                 NN = LL[string_idx + 2]
3554
3555                 if N.type == token.DOT and NN.type == token.NAME:
3556                     # This conditional branch is meant to handle method calls invoked
3557                     # off of a string literal up to and including the LPAR character.
3558
3559                     # WMA4 the '.' character.
3560                     offset += 1
3561
3562                     if (
3563                         is_valid_index(string_idx + 3)
3564                         and LL[string_idx + 3].type == token.LPAR
3565                     ):
3566                         # WMA4 the left parenthesis character.
3567                         offset += 1
3568
3569                     # WMA4 the length of the method's name.
3570                     offset += len(NN.value)
3571
3572         has_comments = False
3573         for comment_leaf in line.comments_after(LL[string_idx]):
3574             if not has_comments:
3575                 has_comments = True
3576                 # WMA4 two spaces before the '#' character.
3577                 offset += 2
3578
3579             # WMA4 the length of the inline comment.
3580             offset += len(comment_leaf.value)
3581
3582         max_string_length = self.line_length - offset
3583         return max_string_length
3584
3585
3586 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3587     """
3588     StringTransformer that splits "atom" strings (i.e. strings which exist on
3589     lines by themselves).
3590
3591     Requirements:
3592         * The line consists ONLY of a single string (with the exception of a
3593         '+' symbol which MAY exist at the start of the line), MAYBE a string
3594         trailer, and MAYBE a trailing comma.
3595             AND
3596         * All of the requirements listed in BaseStringSplitter's docstring.
3597
3598     Transformations:
3599         The string mentioned in the 'Requirements' section is split into as
3600         many substrings as necessary to adhere to the configured line length.
3601
3602         In the final set of substrings, no substring should be smaller than
3603         MIN_SUBSTR_SIZE characters.
3604
3605         The string will ONLY be split on spaces (i.e. each new substring should
3606         start with a space).
3607
3608         If the string is an f-string, it will NOT be split in the middle of an
3609         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3610         else bar()} is an f-expression).
3611
3612         If the string that is being split has an associated set of custom split
3613         records and those custom splits will NOT result in any line going over
3614         the configured line length, those custom splits are used. Otherwise the
3615         string is split as late as possible (from left-to-right) while still
3616         adhering to the transformation rules listed above.
3617
3618     Collaborations:
3619         StringSplitter relies on StringMerger to construct the appropriate
3620         CustomSplit objects and add them to the custom split map.
3621     """
3622
3623     MIN_SUBSTR_SIZE = 6
3624     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3625     RE_FEXPR = r"""
3626     (?<!\{)\{
3627         (?:
3628             [^\{\}]
3629             | \{\{
3630             | \}\}
3631         )+?
3632     (?<!\})(?:\}\})*\}(?!\})
3633     """
3634
3635     def do_splitter_match(self, line: Line) -> TMatchResult:
3636         LL = line.leaves
3637
3638         is_valid_index = is_valid_index_factory(LL)
3639
3640         idx = 0
3641
3642         # The first leaf MAY be a '+' symbol...
3643         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3644             idx += 1
3645
3646         # The next/first leaf MAY be an empty LPAR...
3647         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3648             idx += 1
3649
3650         # The next/first leaf MUST be a string...
3651         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3652             return TErr("Line does not start with a string.")
3653
3654         string_idx = idx
3655
3656         # Skip the string trailer, if one exists.
3657         string_parser = StringParser()
3658         idx = string_parser.parse(LL, string_idx)
3659
3660         # That string MAY be followed by an empty RPAR...
3661         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3662             idx += 1
3663
3664         # That string / empty RPAR leaf MAY be followed by a comma...
3665         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3666             idx += 1
3667
3668         # But no more leaves are allowed...
3669         if is_valid_index(idx):
3670             return TErr("This line does not end with a string.")
3671
3672         return Ok(string_idx)
3673
3674     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3675         LL = line.leaves
3676
3677         QUOTE = LL[string_idx].value[-1]
3678
3679         is_valid_index = is_valid_index_factory(LL)
3680         insert_str_child = insert_str_child_factory(LL[string_idx])
3681
3682         prefix = get_string_prefix(LL[string_idx].value)
3683
3684         # We MAY choose to drop the 'f' prefix from substrings that don't
3685         # contain any f-expressions, but ONLY if the original f-string
3686         # contains at least one f-expression. Otherwise, we will alter the AST
3687         # of the program.
3688         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3689             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3690         )
3691
3692         first_string_line = True
3693         starts_with_plus = LL[0].type == token.PLUS
3694
3695         def line_needs_plus() -> bool:
3696             return first_string_line and starts_with_plus
3697
3698         def maybe_append_plus(new_line: Line) -> None:
3699             """
3700             Side Effects:
3701                 If @line starts with a plus and this is the first line we are
3702                 constructing, this function appends a PLUS leaf to @new_line
3703                 and replaces the old PLUS leaf in the node structure. Otherwise
3704                 this function does nothing.
3705             """
3706             if line_needs_plus():
3707                 plus_leaf = Leaf(token.PLUS, "+")
3708                 replace_child(LL[0], plus_leaf)
3709                 new_line.append(plus_leaf)
3710
3711         ends_with_comma = (
3712             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3713         )
3714
3715         def max_last_string() -> int:
3716             """
3717             Returns:
3718                 The max allowed length of the string value used for the last
3719                 line we will construct.
3720             """
3721             result = self.line_length
3722             result -= line.depth * 4
3723             result -= 1 if ends_with_comma else 0
3724             result -= 2 if line_needs_plus() else 0
3725             return result
3726
3727         # --- Calculate Max Break Index (for string value)
3728         # We start with the line length limit
3729         max_break_idx = self.line_length
3730         # The last index of a string of length N is N-1.
3731         max_break_idx -= 1
3732         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3733         max_break_idx -= line.depth * 4
3734         if max_break_idx < 0:
3735             yield TErr(
3736                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3737                 f" {line.depth}"
3738             )
3739             return
3740
3741         # Check if StringMerger registered any custom splits.
3742         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3743         # We use them ONLY if none of them would produce lines that exceed the
3744         # line limit.
3745         use_custom_breakpoints = bool(
3746             custom_splits
3747             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3748         )
3749
3750         # Temporary storage for the remaining chunk of the string line that
3751         # can't fit onto the line currently being constructed.
3752         rest_value = LL[string_idx].value
3753
3754         def more_splits_should_be_made() -> bool:
3755             """
3756             Returns:
3757                 True iff `rest_value` (the remaining string value from the last
3758                 split), should be split again.
3759             """
3760             if use_custom_breakpoints:
3761                 return len(custom_splits) > 1
3762             else:
3763                 return len(rest_value) > max_last_string()
3764
3765         string_line_results: List[Ok[Line]] = []
3766         while more_splits_should_be_made():
3767             if use_custom_breakpoints:
3768                 # Custom User Split (manual)
3769                 csplit = custom_splits.pop(0)
3770                 break_idx = csplit.break_idx
3771             else:
3772                 # Algorithmic Split (automatic)
3773                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3774                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3775                 if maybe_break_idx is None:
3776                     # If we are unable to algorithmically determine a good split
3777                     # and this string has custom splits registered to it, we
3778                     # fall back to using them--which means we have to start
3779                     # over from the beginning.
3780                     if custom_splits:
3781                         rest_value = LL[string_idx].value
3782                         string_line_results = []
3783                         first_string_line = True
3784                         use_custom_breakpoints = True
3785                         continue
3786
3787                     # Otherwise, we stop splitting here.
3788                     break
3789
3790                 break_idx = maybe_break_idx
3791
3792             # --- Construct `next_value`
3793             next_value = rest_value[:break_idx] + QUOTE
3794             if (
3795                 # Are we allowed to try to drop a pointless 'f' prefix?
3796                 drop_pointless_f_prefix
3797                 # If we are, will we be successful?
3798                 and next_value != self.__normalize_f_string(next_value, prefix)
3799             ):
3800                 # If the current custom split did NOT originally use a prefix,
3801                 # then `csplit.break_idx` will be off by one after removing
3802                 # the 'f' prefix.
3803                 break_idx = (
3804                     break_idx + 1
3805                     if use_custom_breakpoints and not csplit.has_prefix
3806                     else break_idx
3807                 )
3808                 next_value = rest_value[:break_idx] + QUOTE
3809                 next_value = self.__normalize_f_string(next_value, prefix)
3810
3811             # --- Construct `next_leaf`
3812             next_leaf = Leaf(token.STRING, next_value)
3813             insert_str_child(next_leaf)
3814             self.__maybe_normalize_string_quotes(next_leaf)
3815
3816             # --- Construct `next_line`
3817             next_line = line.clone()
3818             maybe_append_plus(next_line)
3819             next_line.append(next_leaf)
3820             string_line_results.append(Ok(next_line))
3821
3822             rest_value = prefix + QUOTE + rest_value[break_idx:]
3823             first_string_line = False
3824
3825         yield from string_line_results
3826
3827         if drop_pointless_f_prefix:
3828             rest_value = self.__normalize_f_string(rest_value, prefix)
3829
3830         rest_leaf = Leaf(token.STRING, rest_value)
3831         insert_str_child(rest_leaf)
3832
3833         # NOTE: I could not find a test case that verifies that the following
3834         # line is actually necessary, but it seems to be. Otherwise we risk
3835         # not normalizing the last substring, right?
3836         self.__maybe_normalize_string_quotes(rest_leaf)
3837
3838         last_line = line.clone()
3839         maybe_append_plus(last_line)
3840
3841         # If there are any leaves to the right of the target string...
3842         if is_valid_index(string_idx + 1):
3843             # We use `temp_value` here to determine how long the last line
3844             # would be if we were to append all the leaves to the right of the
3845             # target string to the last string line.
3846             temp_value = rest_value
3847             for leaf in LL[string_idx + 1 :]:
3848                 temp_value += str(leaf)
3849                 if leaf.type == token.LPAR:
3850                     break
3851
3852             # Try to fit them all on the same line with the last substring...
3853             if (
3854                 len(temp_value) <= max_last_string()
3855                 or LL[string_idx + 1].type == token.COMMA
3856             ):
3857                 last_line.append(rest_leaf)
3858                 append_leaves(last_line, line, LL[string_idx + 1 :])
3859                 yield Ok(last_line)
3860             # Otherwise, place the last substring on one line and everything
3861             # else on a line below that...
3862             else:
3863                 last_line.append(rest_leaf)
3864                 yield Ok(last_line)
3865
3866                 non_string_line = line.clone()
3867                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3868                 yield Ok(non_string_line)
3869         # Else the target string was the last leaf...
3870         else:
3871             last_line.append(rest_leaf)
3872             last_line.comments = line.comments.copy()
3873             yield Ok(last_line)
3874
3875     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3876         """
3877         This method contains the algorithm that StringSplitter uses to
3878         determine which character to split each string at.
3879
3880         Args:
3881             @string: The substring that we are attempting to split.
3882             @max_break_idx: The ideal break index. We will return this value if it
3883             meets all the necessary conditions. In the likely event that it
3884             doesn't we will try to find the closest index BELOW @max_break_idx
3885             that does. If that fails, we will expand our search by also
3886             considering all valid indices ABOVE @max_break_idx.
3887
3888         Pre-Conditions:
3889             * assert_is_leaf_string(@string)
3890             * 0 <= @max_break_idx < len(@string)
3891
3892         Returns:
3893             break_idx, if an index is able to be found that meets all of the
3894             conditions listed in the 'Transformations' section of this classes'
3895             docstring.
3896                 OR
3897             None, otherwise.
3898         """
3899         is_valid_index = is_valid_index_factory(string)
3900
3901         assert is_valid_index(max_break_idx)
3902         assert_is_leaf_string(string)
3903
3904         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3905
3906         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3907             """
3908             Yields:
3909                 All ranges of @string which, if @string were to be split there,
3910                 would result in the splitting of an f-expression (which is NOT
3911                 allowed).
3912             """
3913             nonlocal _fexpr_slices
3914
3915             if _fexpr_slices is None:
3916                 _fexpr_slices = []
3917                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3918                     _fexpr_slices.append(match.span())
3919
3920             yield from _fexpr_slices
3921
3922         is_fstring = "f" in get_string_prefix(string)
3923
3924         def breaks_fstring_expression(i: Index) -> bool:
3925             """
3926             Returns:
3927                 True iff returning @i would result in the splitting of an
3928                 f-expression (which is NOT allowed).
3929             """
3930             if not is_fstring:
3931                 return False
3932
3933             for (start, end) in fexpr_slices():
3934                 if start <= i < end:
3935                     return True
3936
3937             return False
3938
3939         def passes_all_checks(i: Index) -> bool:
3940             """
3941             Returns:
3942                 True iff ALL of the conditions listed in the 'Transformations'
3943                 section of this classes' docstring would be be met by returning @i.
3944             """
3945             is_space = string[i] == " "
3946             is_big_enough = (
3947                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3948                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3949             )
3950             return is_space and is_big_enough and not breaks_fstring_expression(i)
3951
3952         # First, we check all indices BELOW @max_break_idx.
3953         break_idx = max_break_idx
3954         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3955             break_idx -= 1
3956
3957         if not passes_all_checks(break_idx):
3958             # If that fails, we check all indices ABOVE @max_break_idx.
3959             #
3960             # If we are able to find a valid index here, the next line is going
3961             # to be longer than the specified line length, but it's probably
3962             # better than doing nothing at all.
3963             break_idx = max_break_idx + 1
3964             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3965                 break_idx += 1
3966
3967             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3968                 return None
3969
3970         return break_idx
3971
3972     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3973         if self.normalize_strings:
3974             normalize_string_quotes(leaf)
3975
3976     def __normalize_f_string(self, string: str, prefix: str) -> str:
3977         """
3978         Pre-Conditions:
3979             * assert_is_leaf_string(@string)
3980
3981         Returns:
3982             * If @string is an f-string that contains no f-expressions, we
3983             return a string identical to @string except that the 'f' prefix
3984             has been stripped and all double braces (i.e. '{{' or '}}') have
3985             been normalized (i.e. turned into '{' or '}').
3986                 OR
3987             * Otherwise, we return @string.
3988         """
3989         assert_is_leaf_string(string)
3990
3991         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3992             new_prefix = prefix.replace("f", "")
3993
3994             temp = string[len(prefix) :]
3995             temp = re.sub(r"\{\{", "{", temp)
3996             temp = re.sub(r"\}\}", "}", temp)
3997             new_string = temp
3998
3999             return f"{new_prefix}{new_string}"
4000         else:
4001             return string
4002
4003
4004 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4005     """
4006     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4007     exist on lines by themselves).
4008
4009     Requirements:
4010         All of the requirements listed in BaseStringSplitter's docstring in
4011         addition to the requirements listed below:
4012
4013         * The line is a return/yield statement, which returns/yields a string.
4014             OR
4015         * The line is part of a ternary expression (e.g. `x = y if cond else
4016         z`) such that the line starts with `else <string>`, where <string> is
4017         some string.
4018             OR
4019         * The line is an assert statement, which ends with a string.
4020             OR
4021         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4022         <string>`) such that the variable is being assigned the value of some
4023         string.
4024             OR
4025         * The line is a dictionary key assignment where some valid key is being
4026         assigned the value of some string.
4027
4028     Transformations:
4029         The chosen string is wrapped in parentheses and then split at the LPAR.
4030
4031         We then have one line which ends with an LPAR and another line that
4032         starts with the chosen string. The latter line is then split again at
4033         the RPAR. This results in the RPAR (and possibly a trailing comma)
4034         being placed on its own line.
4035
4036         NOTE: If any leaves exist to the right of the chosen string (except
4037         for a trailing comma, which would be placed after the RPAR), those
4038         leaves are placed inside the parentheses.  In effect, the chosen
4039         string is not necessarily being "wrapped" by parentheses. We can,
4040         however, count on the LPAR being placed directly before the chosen
4041         string.
4042
4043         In other words, StringParenWrapper creates "atom" strings. These
4044         can then be split again by StringSplitter, if necessary.
4045
4046     Collaborations:
4047         In the event that a string line split by StringParenWrapper is
4048         changed such that it no longer needs to be given its own line,
4049         StringParenWrapper relies on StringParenStripper to clean up the
4050         parentheses it created.
4051     """
4052
4053     def do_splitter_match(self, line: Line) -> TMatchResult:
4054         LL = line.leaves
4055
4056         string_idx = None
4057         string_idx = string_idx or self._return_match(LL)
4058         string_idx = string_idx or self._else_match(LL)
4059         string_idx = string_idx or self._assert_match(LL)
4060         string_idx = string_idx or self._assign_match(LL)
4061         string_idx = string_idx or self._dict_match(LL)
4062
4063         if string_idx is not None:
4064             string_value = line.leaves[string_idx].value
4065             # If the string has no spaces...
4066             if " " not in string_value:
4067                 # And will still violate the line length limit when split...
4068                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4069                 if len(string_value) > max_string_length:
4070                     # And has no associated custom splits...
4071                     if not self.has_custom_splits(string_value):
4072                         # Then we should NOT put this string on its own line.
4073                         return TErr(
4074                             "We do not wrap long strings in parentheses when the"
4075                             " resultant line would still be over the specified line"
4076                             " length and can't be split further by StringSplitter."
4077                         )
4078             return Ok(string_idx)
4079
4080         return TErr("This line does not contain any non-atomic strings.")
4081
4082     @staticmethod
4083     def _return_match(LL: List[Leaf]) -> Optional[int]:
4084         """
4085         Returns:
4086             string_idx such that @LL[string_idx] is equal to our target (i.e.
4087             matched) string, if this line matches the return/yield statement
4088             requirements listed in the 'Requirements' section of this classes'
4089             docstring.
4090                 OR
4091             None, otherwise.
4092         """
4093         # If this line is apart of a return/yield statement and the first leaf
4094         # contains either the "return" or "yield" keywords...
4095         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4096             0
4097         ].value in ["return", "yield"]:
4098             is_valid_index = is_valid_index_factory(LL)
4099
4100             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4101             # The next visible leaf MUST contain a string...
4102             if is_valid_index(idx) and LL[idx].type == token.STRING:
4103                 return idx
4104
4105         return None
4106
4107     @staticmethod
4108     def _else_match(LL: List[Leaf]) -> Optional[int]:
4109         """
4110         Returns:
4111             string_idx such that @LL[string_idx] is equal to our target (i.e.
4112             matched) string, if this line matches the ternary expression
4113             requirements listed in the 'Requirements' section of this classes'
4114             docstring.
4115                 OR
4116             None, otherwise.
4117         """
4118         # If this line is apart of a ternary expression and the first leaf
4119         # contains the "else" keyword...
4120         if (
4121             parent_type(LL[0]) == syms.test
4122             and LL[0].type == token.NAME
4123             and LL[0].value == "else"
4124         ):
4125             is_valid_index = is_valid_index_factory(LL)
4126
4127             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4128             # The next visible leaf MUST contain a string...
4129             if is_valid_index(idx) and LL[idx].type == token.STRING:
4130                 return idx
4131
4132         return None
4133
4134     @staticmethod
4135     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4136         """
4137         Returns:
4138             string_idx such that @LL[string_idx] is equal to our target (i.e.
4139             matched) string, if this line matches the assert statement
4140             requirements listed in the 'Requirements' section of this classes'
4141             docstring.
4142                 OR
4143             None, otherwise.
4144         """
4145         # If this line is apart of an assert statement and the first leaf
4146         # contains the "assert" keyword...
4147         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4148             is_valid_index = is_valid_index_factory(LL)
4149
4150             for (i, leaf) in enumerate(LL):
4151                 # We MUST find a comma...
4152                 if leaf.type == token.COMMA:
4153                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4154
4155                     # That comma MUST be followed by a string...
4156                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4157                         string_idx = idx
4158
4159                         # Skip the string trailer, if one exists.
4160                         string_parser = StringParser()
4161                         idx = string_parser.parse(LL, string_idx)
4162
4163                         # But no more leaves are allowed...
4164                         if not is_valid_index(idx):
4165                             return string_idx
4166
4167         return None
4168
4169     @staticmethod
4170     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4171         """
4172         Returns:
4173             string_idx such that @LL[string_idx] is equal to our target (i.e.
4174             matched) string, if this line matches the assignment statement
4175             requirements listed in the 'Requirements' section of this classes'
4176             docstring.
4177                 OR
4178             None, otherwise.
4179         """
4180         # If this line is apart of an expression statement or is a function
4181         # argument AND the first leaf contains a variable name...
4182         if (
4183             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4184             and LL[0].type == token.NAME
4185         ):
4186             is_valid_index = is_valid_index_factory(LL)
4187
4188             for (i, leaf) in enumerate(LL):
4189                 # We MUST find either an '=' or '+=' symbol...
4190                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4191                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4192
4193                     # That symbol MUST be followed by a string...
4194                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4195                         string_idx = idx
4196
4197                         # Skip the string trailer, if one exists.
4198                         string_parser = StringParser()
4199                         idx = string_parser.parse(LL, string_idx)
4200
4201                         # The next leaf MAY be a comma iff this line is apart
4202                         # of a function argument...
4203                         if (
4204                             parent_type(LL[0]) == syms.argument
4205                             and is_valid_index(idx)
4206                             and LL[idx].type == token.COMMA
4207                         ):
4208                             idx += 1
4209
4210                         # But no more leaves are allowed...
4211                         if not is_valid_index(idx):
4212                             return string_idx
4213
4214         return None
4215
4216     @staticmethod
4217     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4218         """
4219         Returns:
4220             string_idx such that @LL[string_idx] is equal to our target (i.e.
4221             matched) string, if this line matches the dictionary key assignment
4222             statement requirements listed in the 'Requirements' section of this
4223             classes' docstring.
4224                 OR
4225             None, otherwise.
4226         """
4227         # If this line is apart of a dictionary key assignment...
4228         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4229             is_valid_index = is_valid_index_factory(LL)
4230
4231             for (i, leaf) in enumerate(LL):
4232                 # We MUST find a colon...
4233                 if leaf.type == token.COLON:
4234                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4235
4236                     # That colon MUST be followed by a string...
4237                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4238                         string_idx = idx
4239
4240                         # Skip the string trailer, if one exists.
4241                         string_parser = StringParser()
4242                         idx = string_parser.parse(LL, string_idx)
4243
4244                         # That string MAY be followed by a comma...
4245                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4246                             idx += 1
4247
4248                         # But no more leaves are allowed...
4249                         if not is_valid_index(idx):
4250                             return string_idx
4251
4252         return None
4253
4254     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4255         LL = line.leaves
4256
4257         is_valid_index = is_valid_index_factory(LL)
4258         insert_str_child = insert_str_child_factory(LL[string_idx])
4259
4260         comma_idx = len(LL) - 1
4261         ends_with_comma = False
4262         if LL[comma_idx].type == token.COMMA:
4263             ends_with_comma = True
4264
4265         leaves_to_steal_comments_from = [LL[string_idx]]
4266         if ends_with_comma:
4267             leaves_to_steal_comments_from.append(LL[comma_idx])
4268
4269         # --- First Line
4270         first_line = line.clone()
4271         left_leaves = LL[:string_idx]
4272
4273         # We have to remember to account for (possibly invisible) LPAR and RPAR
4274         # leaves that already wrapped the target string. If these leaves do
4275         # exist, we will replace them with our own LPAR and RPAR leaves.
4276         old_parens_exist = False
4277         if left_leaves and left_leaves[-1].type == token.LPAR:
4278             old_parens_exist = True
4279             leaves_to_steal_comments_from.append(left_leaves[-1])
4280             left_leaves.pop()
4281
4282         append_leaves(first_line, line, left_leaves)
4283
4284         lpar_leaf = Leaf(token.LPAR, "(")
4285         if old_parens_exist:
4286             replace_child(LL[string_idx - 1], lpar_leaf)
4287         else:
4288             insert_str_child(lpar_leaf)
4289         first_line.append(lpar_leaf)
4290
4291         # We throw inline comments that were originally to the right of the
4292         # target string to the top line. They will now be shown to the right of
4293         # the LPAR.
4294         for leaf in leaves_to_steal_comments_from:
4295             for comment_leaf in line.comments_after(leaf):
4296                 first_line.append(comment_leaf, preformatted=True)
4297
4298         yield Ok(first_line)
4299
4300         # --- Middle (String) Line
4301         # We only need to yield one (possibly too long) string line, since the
4302         # `StringSplitter` will break it down further if necessary.
4303         string_value = LL[string_idx].value
4304         string_line = Line(
4305             depth=line.depth + 1,
4306             inside_brackets=True,
4307             should_explode=line.should_explode,
4308         )
4309         string_leaf = Leaf(token.STRING, string_value)
4310         insert_str_child(string_leaf)
4311         string_line.append(string_leaf)
4312
4313         old_rpar_leaf = None
4314         if is_valid_index(string_idx + 1):
4315             right_leaves = LL[string_idx + 1 :]
4316             if ends_with_comma:
4317                 right_leaves.pop()
4318
4319             if old_parens_exist:
4320                 assert (
4321                     right_leaves and right_leaves[-1].type == token.RPAR
4322                 ), "Apparently, old parentheses do NOT exist?!"
4323                 old_rpar_leaf = right_leaves.pop()
4324
4325             append_leaves(string_line, line, right_leaves)
4326
4327         yield Ok(string_line)
4328
4329         # --- Last Line
4330         last_line = line.clone()
4331         last_line.bracket_tracker = first_line.bracket_tracker
4332
4333         new_rpar_leaf = Leaf(token.RPAR, ")")
4334         if old_rpar_leaf is not None:
4335             replace_child(old_rpar_leaf, new_rpar_leaf)
4336         else:
4337             insert_str_child(new_rpar_leaf)
4338         last_line.append(new_rpar_leaf)
4339
4340         # If the target string ended with a comma, we place this comma to the
4341         # right of the RPAR on the last line.
4342         if ends_with_comma:
4343             comma_leaf = Leaf(token.COMMA, ",")
4344             replace_child(LL[comma_idx], comma_leaf)
4345             last_line.append(comma_leaf)
4346
4347         yield Ok(last_line)
4348
4349
4350 class StringParser:
4351     """
4352     A state machine that aids in parsing a string's "trailer", which can be
4353     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4354     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4355     varY)`).
4356
4357     NOTE: A new StringParser object MUST be instantiated for each string
4358     trailer we need to parse.
4359
4360     Examples:
4361         We shall assume that `line` equals the `Line` object that corresponds
4362         to the following line of python code:
4363         ```
4364         x = "Some {}.".format("String") + some_other_string
4365         ```
4366
4367         Furthermore, we will assume that `string_idx` is some index such that:
4368         ```
4369         assert line.leaves[string_idx].value == "Some {}."
4370         ```
4371
4372         The following code snippet then holds:
4373         ```
4374         string_parser = StringParser()
4375         idx = string_parser.parse(line.leaves, string_idx)
4376         assert line.leaves[idx].type == token.PLUS
4377         ```
4378     """
4379
4380     DEFAULT_TOKEN = -1
4381
4382     # String Parser States
4383     START = 1
4384     DOT = 2
4385     NAME = 3
4386     PERCENT = 4
4387     SINGLE_FMT_ARG = 5
4388     LPAR = 6
4389     RPAR = 7
4390     DONE = 8
4391
4392     # Lookup Table for Next State
4393     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4394         # A string trailer may start with '.' OR '%'.
4395         (START, token.DOT): DOT,
4396         (START, token.PERCENT): PERCENT,
4397         (START, DEFAULT_TOKEN): DONE,
4398         # A '.' MUST be followed by an attribute or method name.
4399         (DOT, token.NAME): NAME,
4400         # A method name MUST be followed by an '(', whereas an attribute name
4401         # is the last symbol in the string trailer.
4402         (NAME, token.LPAR): LPAR,
4403         (NAME, DEFAULT_TOKEN): DONE,
4404         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4405         # string or variable name).
4406         (PERCENT, token.LPAR): LPAR,
4407         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4408         # If a '%' symbol is followed by a single argument, that argument is
4409         # the last leaf in the string trailer.
4410         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4411         # If present, a ')' symbol is the last symbol in a string trailer.
4412         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4413         # since they are treated as a special case by the parsing logic in this
4414         # classes' implementation.)
4415         (RPAR, DEFAULT_TOKEN): DONE,
4416     }
4417
4418     def __init__(self) -> None:
4419         self._state = self.START
4420         self._unmatched_lpars = 0
4421
4422     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4423         """
4424         Pre-conditions:
4425             * @leaves[@string_idx].type == token.STRING
4426
4427         Returns:
4428             The index directly after the last leaf which is apart of the string
4429             trailer, if a "trailer" exists.
4430                 OR
4431             @string_idx + 1, if no string "trailer" exists.
4432         """
4433         assert leaves[string_idx].type == token.STRING
4434
4435         idx = string_idx + 1
4436         while idx < len(leaves) and self._next_state(leaves[idx]):
4437             idx += 1
4438         return idx
4439
4440     def _next_state(self, leaf: Leaf) -> bool:
4441         """
4442         Pre-conditions:
4443             * On the first call to this function, @leaf MUST be the leaf that
4444             was directly after the string leaf in question (e.g. if our target
4445             string is `line.leaves[i]` then the first call to this method must
4446             be `line.leaves[i + 1]`).
4447             * On the next call to this function, the leaf parameter passed in
4448             MUST be the leaf directly following @leaf.
4449
4450         Returns:
4451             True iff @leaf is apart of the string's trailer.
4452         """
4453         # We ignore empty LPAR or RPAR leaves.
4454         if is_empty_par(leaf):
4455             return True
4456
4457         next_token = leaf.type
4458         if next_token == token.LPAR:
4459             self._unmatched_lpars += 1
4460
4461         current_state = self._state
4462
4463         # The LPAR parser state is a special case. We will return True until we
4464         # find the matching RPAR token.
4465         if current_state == self.LPAR:
4466             if next_token == token.RPAR:
4467                 self._unmatched_lpars -= 1
4468                 if self._unmatched_lpars == 0:
4469                     self._state = self.RPAR
4470         # Otherwise, we use a lookup table to determine the next state.
4471         else:
4472             # If the lookup table matches the current state to the next
4473             # token, we use the lookup table.
4474             if (current_state, next_token) in self._goto:
4475                 self._state = self._goto[current_state, next_token]
4476             else:
4477                 # Otherwise, we check if a the current state was assigned a
4478                 # default.
4479                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4480                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4481                 # If no default has been assigned, then this parser has a logic
4482                 # error.
4483                 else:
4484                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4485
4486             if self._state == self.DONE:
4487                 return False
4488
4489         return True
4490
4491
4492 def TErr(err_msg: str) -> Err[CannotTransform]:
4493     """(T)ransform Err
4494
4495     Convenience function used when working with the TResult type.
4496     """
4497     cant_transform = CannotTransform(err_msg)
4498     return Err(cant_transform)
4499
4500
4501 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4502     """
4503     Returns:
4504         True iff one of the comments in @comment_list is a pragma used by one
4505         of the more common static analysis tools for python (e.g. mypy, flake8,
4506         pylint).
4507     """
4508     for comment in comment_list:
4509         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4510             return True
4511
4512     return False
4513
4514
4515 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4516     """
4517     Factory for a convenience function that is used to orphan @string_leaf
4518     and then insert multiple new leaves into the same part of the node
4519     structure that @string_leaf had originally occupied.
4520
4521     Examples:
4522         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4523         string_leaf.parent`. Assume the node `N` has the following
4524         original structure:
4525
4526         Node(
4527             expr_stmt, [
4528                 Leaf(NAME, 'x'),
4529                 Leaf(EQUAL, '='),
4530                 Leaf(STRING, '"foo"'),
4531             ]
4532         )
4533
4534         We then run the code snippet shown below.
4535         ```
4536         insert_str_child = insert_str_child_factory(string_leaf)
4537
4538         lpar = Leaf(token.LPAR, '(')
4539         insert_str_child(lpar)
4540
4541         bar = Leaf(token.STRING, '"bar"')
4542         insert_str_child(bar)
4543
4544         rpar = Leaf(token.RPAR, ')')
4545         insert_str_child(rpar)
4546         ```
4547
4548         After which point, it follows that `string_leaf.parent is None` and
4549         the node `N` now has the following structure:
4550
4551         Node(
4552             expr_stmt, [
4553                 Leaf(NAME, 'x'),
4554                 Leaf(EQUAL, '='),
4555                 Leaf(LPAR, '('),
4556                 Leaf(STRING, '"bar"'),
4557                 Leaf(RPAR, ')'),
4558             ]
4559         )
4560     """
4561     string_parent = string_leaf.parent
4562     string_child_idx = string_leaf.remove()
4563
4564     def insert_str_child(child: LN) -> None:
4565         nonlocal string_child_idx
4566
4567         assert string_parent is not None
4568         assert string_child_idx is not None
4569
4570         string_parent.insert_child(string_child_idx, child)
4571         string_child_idx += 1
4572
4573     return insert_str_child
4574
4575
4576 def has_triple_quotes(string: str) -> bool:
4577     """
4578     Returns:
4579         True iff @string starts with three quotation characters.
4580     """
4581     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4582     return raw_string[:3] in {'"""', "'''"}
4583
4584
4585 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4586     """
4587     Returns:
4588         @node.parent.type, if @node is not None and has a parent.
4589             OR
4590         None, otherwise.
4591     """
4592     if node is None or node.parent is None:
4593         return None
4594
4595     return node.parent.type
4596
4597
4598 def is_empty_par(leaf: Leaf) -> bool:
4599     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4600
4601
4602 def is_empty_lpar(leaf: Leaf) -> bool:
4603     return leaf.type == token.LPAR and leaf.value == ""
4604
4605
4606 def is_empty_rpar(leaf: Leaf) -> bool:
4607     return leaf.type == token.RPAR and leaf.value == ""
4608
4609
4610 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4611     """
4612     Examples:
4613         ```
4614         my_list = [1, 2, 3]
4615
4616         is_valid_index = is_valid_index_factory(my_list)
4617
4618         assert is_valid_index(0)
4619         assert is_valid_index(2)
4620
4621         assert not is_valid_index(3)
4622         assert not is_valid_index(-1)
4623         ```
4624     """
4625
4626     def is_valid_index(idx: int) -> bool:
4627         """
4628         Returns:
4629             True iff @idx is positive AND seq[@idx] does NOT raise an
4630             IndexError.
4631         """
4632         return 0 <= idx < len(seq)
4633
4634     return is_valid_index
4635
4636
4637 def line_to_string(line: Line) -> str:
4638     """Returns the string representation of @line.
4639
4640     WARNING: This is known to be computationally expensive.
4641     """
4642     return str(line).strip("\n")
4643
4644
4645 def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
4646     """
4647     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4648     underlying Node structure where appropriate.
4649
4650     All of the leaves in @leaves are duplicated. The duplicates are then
4651     appended to @new_line and used to replace their originals in the underlying
4652     Node structure. Any comments attached to the old leaves are reattached to
4653     the new leaves.
4654
4655     Pre-conditions:
4656         set(@leaves) is a subset of set(@old_line.leaves).
4657     """
4658     for old_leaf in leaves:
4659         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4660         replace_child(old_leaf, new_leaf)
4661         new_line.append(new_leaf)
4662
4663         for comment_leaf in old_line.comments_after(old_leaf):
4664             new_line.append(comment_leaf, preformatted=True)
4665
4666
4667 def replace_child(old_child: LN, new_child: LN) -> None:
4668     """
4669     Side Effects:
4670         * If @old_child.parent is set, replace @old_child with @new_child in
4671         @old_child's underlying Node structure.
4672             OR
4673         * Otherwise, this function does nothing.
4674     """
4675     parent = old_child.parent
4676     if not parent:
4677         return
4678
4679     child_idx = old_child.remove()
4680     if child_idx is not None:
4681         parent.insert_child(child_idx, new_child)
4682
4683
4684 def get_string_prefix(string: str) -> str:
4685     """
4686     Pre-conditions:
4687         * assert_is_leaf_string(@string)
4688
4689     Returns:
4690         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4691     """
4692     assert_is_leaf_string(string)
4693
4694     prefix = ""
4695     prefix_idx = 0
4696     while string[prefix_idx] in STRING_PREFIX_CHARS:
4697         prefix += string[prefix_idx].lower()
4698         prefix_idx += 1
4699
4700     return prefix
4701
4702
4703 def assert_is_leaf_string(string: str) -> None:
4704     """
4705     Checks the pre-condition that @string has the format that you would expect
4706     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4707     token.STRING`. A more precise description of the pre-conditions that are
4708     checked are listed below.
4709
4710     Pre-conditions:
4711         * @string starts with either ', ", <prefix>', or <prefix>" where
4712         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4713         * @string ends with a quote character (' or ").
4714
4715     Raises:
4716         AssertionError(...) if the pre-conditions listed above are not
4717         satisfied.
4718     """
4719     dquote_idx = string.find('"')
4720     squote_idx = string.find("'")
4721     if -1 in [dquote_idx, squote_idx]:
4722         quote_idx = max(dquote_idx, squote_idx)
4723     else:
4724         quote_idx = min(squote_idx, dquote_idx)
4725
4726     assert (
4727         0 <= quote_idx < len(string) - 1
4728     ), f"{string!r} is missing a starting quote character (' or \")."
4729     assert string[-1] in (
4730         "'",
4731         '"',
4732     ), f"{string!r} is missing an ending quote character (' or \")."
4733     assert set(string[:quote_idx]).issubset(
4734         set(STRING_PREFIX_CHARS)
4735     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4736
4737
4738 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4739     """Split line into many lines, starting with the first matching bracket pair.
4740
4741     Note: this usually looks weird, only use this for function definitions.
4742     Prefer RHS otherwise.  This is why this function is not symmetrical with
4743     :func:`right_hand_split` which also handles optional parentheses.
4744     """
4745     tail_leaves: List[Leaf] = []
4746     body_leaves: List[Leaf] = []
4747     head_leaves: List[Leaf] = []
4748     current_leaves = head_leaves
4749     matching_bracket: Optional[Leaf] = None
4750     for leaf in line.leaves:
4751         if (
4752             current_leaves is body_leaves
4753             and leaf.type in CLOSING_BRACKETS
4754             and leaf.opening_bracket is matching_bracket
4755         ):
4756             current_leaves = tail_leaves if body_leaves else head_leaves
4757         current_leaves.append(leaf)
4758         if current_leaves is head_leaves:
4759             if leaf.type in OPENING_BRACKETS:
4760                 matching_bracket = leaf
4761                 current_leaves = body_leaves
4762     if not matching_bracket:
4763         raise CannotSplit("No brackets found")
4764
4765     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4766     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4767     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4768     bracket_split_succeeded_or_raise(head, body, tail)
4769     for result in (head, body, tail):
4770         if result:
4771             yield result
4772
4773
4774 def right_hand_split(
4775     line: Line,
4776     line_length: int,
4777     features: Collection[Feature] = (),
4778     omit: Collection[LeafID] = (),
4779 ) -> Iterator[Line]:
4780     """Split line into many lines, starting with the last matching bracket pair.
4781
4782     If the split was by optional parentheses, attempt splitting without them, too.
4783     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4784     this split.
4785
4786     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4787     """
4788     tail_leaves: List[Leaf] = []
4789     body_leaves: List[Leaf] = []
4790     head_leaves: List[Leaf] = []
4791     current_leaves = tail_leaves
4792     opening_bracket: Optional[Leaf] = None
4793     closing_bracket: Optional[Leaf] = None
4794     for leaf in reversed(line.leaves):
4795         if current_leaves is body_leaves:
4796             if leaf is opening_bracket:
4797                 current_leaves = head_leaves if body_leaves else tail_leaves
4798         current_leaves.append(leaf)
4799         if current_leaves is tail_leaves:
4800             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4801                 opening_bracket = leaf.opening_bracket
4802                 closing_bracket = leaf
4803                 current_leaves = body_leaves
4804     if not (opening_bracket and closing_bracket and head_leaves):
4805         # If there is no opening or closing_bracket that means the split failed and
4806         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4807         # the matching `opening_bracket` wasn't available on `line` anymore.
4808         raise CannotSplit("No brackets found")
4809
4810     tail_leaves.reverse()
4811     body_leaves.reverse()
4812     head_leaves.reverse()
4813     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4814     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4815     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4816     bracket_split_succeeded_or_raise(head, body, tail)
4817     if (
4818         # the body shouldn't be exploded
4819         not body.should_explode
4820         # the opening bracket is an optional paren
4821         and opening_bracket.type == token.LPAR
4822         and not opening_bracket.value
4823         # the closing bracket is an optional paren
4824         and closing_bracket.type == token.RPAR
4825         and not closing_bracket.value
4826         # it's not an import (optional parens are the only thing we can split on
4827         # in this case; attempting a split without them is a waste of time)
4828         and not line.is_import
4829         # there are no standalone comments in the body
4830         and not body.contains_standalone_comments(0)
4831         # and we can actually remove the parens
4832         and can_omit_invisible_parens(body, line_length)
4833     ):
4834         omit = {id(closing_bracket), *omit}
4835         try:
4836             yield from right_hand_split(line, line_length, features=features, omit=omit)
4837             return
4838
4839         except CannotSplit:
4840             if not (
4841                 can_be_split(body)
4842                 or is_line_short_enough(body, line_length=line_length)
4843             ):
4844                 raise CannotSplit(
4845                     "Splitting failed, body is still too long and can't be split."
4846                 )
4847
4848             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4849                 raise CannotSplit(
4850                     "The current optional pair of parentheses is bound to fail to"
4851                     " satisfy the splitting algorithm because the head or the tail"
4852                     " contains multiline strings which by definition never fit one"
4853                     " line."
4854                 )
4855
4856     ensure_visible(opening_bracket)
4857     ensure_visible(closing_bracket)
4858     for result in (head, body, tail):
4859         if result:
4860             yield result
4861
4862
4863 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4864     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4865
4866     Do nothing otherwise.
4867
4868     A left- or right-hand split is based on a pair of brackets. Content before
4869     (and including) the opening bracket is left on one line, content inside the
4870     brackets is put on a separate line, and finally content starting with and
4871     following the closing bracket is put on a separate line.
4872
4873     Those are called `head`, `body`, and `tail`, respectively. If the split
4874     produced the same line (all content in `head`) or ended up with an empty `body`
4875     and the `tail` is just the closing bracket, then it's considered failed.
4876     """
4877     tail_len = len(str(tail).strip())
4878     if not body:
4879         if tail_len == 0:
4880             raise CannotSplit("Splitting brackets produced the same line")
4881
4882         elif tail_len < 3:
4883             raise CannotSplit(
4884                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4885                 " not worth it"
4886             )
4887
4888
4889 def bracket_split_build_line(
4890     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4891 ) -> Line:
4892     """Return a new line with given `leaves` and respective comments from `original`.
4893
4894     If `is_body` is True, the result line is one-indented inside brackets and as such
4895     has its first leaf's prefix normalized and a trailing comma added when expected.
4896     """
4897     result = Line(depth=original.depth)
4898     if is_body:
4899         result.inside_brackets = True
4900         result.depth += 1
4901         if leaves:
4902             # Since body is a new indent level, remove spurious leading whitespace.
4903             normalize_prefix(leaves[0], inside_brackets=True)
4904             # Ensure a trailing comma for imports and standalone function arguments, but
4905             # be careful not to add one after any comments or within type annotations.
4906             no_commas = (
4907                 original.is_def
4908                 and opening_bracket.value == "("
4909                 and not any(leaf.type == token.COMMA for leaf in leaves)
4910             )
4911
4912             if original.is_import or no_commas:
4913                 for i in range(len(leaves) - 1, -1, -1):
4914                     if leaves[i].type == STANDALONE_COMMENT:
4915                         continue
4916
4917                     if leaves[i].type != token.COMMA:
4918                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
4919                     break
4920
4921     # Populate the line
4922     for leaf in leaves:
4923         result.append(leaf, preformatted=True)
4924         for comment_after in original.comments_after(leaf):
4925             result.append(comment_after, preformatted=True)
4926     if is_body:
4927         result.should_explode = should_explode(result, opening_bracket)
4928     return result
4929
4930
4931 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4932     """Normalize prefix of the first leaf in every line returned by `split_func`.
4933
4934     This is a decorator over relevant split functions.
4935     """
4936
4937     @wraps(split_func)
4938     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4939         for line in split_func(line, features):
4940             normalize_prefix(line.leaves[0], inside_brackets=True)
4941             yield line
4942
4943     return split_wrapper
4944
4945
4946 @dont_increase_indentation
4947 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4948     """Split according to delimiters of the highest priority.
4949
4950     If the appropriate Features are given, the split will add trailing commas
4951     also in function signatures and calls that contain `*` and `**`.
4952     """
4953     try:
4954         last_leaf = line.leaves[-1]
4955     except IndexError:
4956         raise CannotSplit("Line empty")
4957
4958     bt = line.bracket_tracker
4959     try:
4960         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4961     except ValueError:
4962         raise CannotSplit("No delimiters found")
4963
4964     if delimiter_priority == DOT_PRIORITY:
4965         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4966             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4967
4968     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4969     lowest_depth = sys.maxsize
4970     trailing_comma_safe = True
4971
4972     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4973         """Append `leaf` to current line or to new line if appending impossible."""
4974         nonlocal current_line
4975         try:
4976             current_line.append_safe(leaf, preformatted=True)
4977         except ValueError:
4978             yield current_line
4979
4980             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4981             current_line.append(leaf)
4982
4983     for leaf in line.leaves:
4984         yield from append_to_line(leaf)
4985
4986         for comment_after in line.comments_after(leaf):
4987             yield from append_to_line(comment_after)
4988
4989         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4990         if leaf.bracket_depth == lowest_depth:
4991             if is_vararg(leaf, within={syms.typedargslist}):
4992                 trailing_comma_safe = (
4993                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4994                 )
4995             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4996                 trailing_comma_safe = (
4997                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4998                 )
4999
5000         leaf_priority = bt.delimiters.get(id(leaf))
5001         if leaf_priority == delimiter_priority:
5002             yield current_line
5003
5004             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5005     if current_line:
5006         if (
5007             trailing_comma_safe
5008             and delimiter_priority == COMMA_PRIORITY
5009             and current_line.leaves[-1].type != token.COMMA
5010             and current_line.leaves[-1].type != STANDALONE_COMMENT
5011         ):
5012             current_line.append(Leaf(token.COMMA, ","))
5013         yield current_line
5014
5015
5016 @dont_increase_indentation
5017 def standalone_comment_split(
5018     line: Line, features: Collection[Feature] = ()
5019 ) -> Iterator[Line]:
5020     """Split standalone comments from the rest of the line."""
5021     if not line.contains_standalone_comments(0):
5022         raise CannotSplit("Line does not have any standalone comments")
5023
5024     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5025
5026     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5027         """Append `leaf` to current line or to new line if appending impossible."""
5028         nonlocal current_line
5029         try:
5030             current_line.append_safe(leaf, preformatted=True)
5031         except ValueError:
5032             yield current_line
5033
5034             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
5035             current_line.append(leaf)
5036
5037     for leaf in line.leaves:
5038         yield from append_to_line(leaf)
5039
5040         for comment_after in line.comments_after(leaf):
5041             yield from append_to_line(comment_after)
5042
5043     if current_line:
5044         yield current_line
5045
5046
5047 def is_import(leaf: Leaf) -> bool:
5048     """Return True if the given leaf starts an import statement."""
5049     p = leaf.parent
5050     t = leaf.type
5051     v = leaf.value
5052     return bool(
5053         t == token.NAME
5054         and (
5055             (v == "import" and p and p.type == syms.import_name)
5056             or (v == "from" and p and p.type == syms.import_from)
5057         )
5058     )
5059
5060
5061 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5062     """Return True if the given leaf is a special comment.
5063     Only returns true for type comments for now."""
5064     t = leaf.type
5065     v = leaf.value
5066     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5067
5068
5069 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5070     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5071     else.
5072
5073     Note: don't use backslashes for formatting or you'll lose your voting rights.
5074     """
5075     if not inside_brackets:
5076         spl = leaf.prefix.split("#")
5077         if "\\" not in spl[0]:
5078             nl_count = spl[-1].count("\n")
5079             if len(spl) > 1:
5080                 nl_count -= 1
5081             leaf.prefix = "\n" * nl_count
5082             return
5083
5084     leaf.prefix = ""
5085
5086
5087 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5088     """Make all string prefixes lowercase.
5089
5090     If remove_u_prefix is given, also removes any u prefix from the string.
5091
5092     Note: Mutates its argument.
5093     """
5094     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5095     assert match is not None, f"failed to match string {leaf.value!r}"
5096     orig_prefix = match.group(1)
5097     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5098     if remove_u_prefix:
5099         new_prefix = new_prefix.replace("u", "")
5100     leaf.value = f"{new_prefix}{match.group(2)}"
5101
5102
5103 def normalize_string_quotes(leaf: Leaf) -> None:
5104     """Prefer double quotes but only if it doesn't cause more escaping.
5105
5106     Adds or removes backslashes as appropriate. Doesn't parse and fix
5107     strings nested in f-strings (yet).
5108
5109     Note: Mutates its argument.
5110     """
5111     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5112     if value[:3] == '"""':
5113         return
5114
5115     elif value[:3] == "'''":
5116         orig_quote = "'''"
5117         new_quote = '"""'
5118     elif value[0] == '"':
5119         orig_quote = '"'
5120         new_quote = "'"
5121     else:
5122         orig_quote = "'"
5123         new_quote = '"'
5124     first_quote_pos = leaf.value.find(orig_quote)
5125     if first_quote_pos == -1:
5126         return  # There's an internal error
5127
5128     prefix = leaf.value[:first_quote_pos]
5129     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5130     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5131     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5132     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5133     if "r" in prefix.casefold():
5134         if unescaped_new_quote.search(body):
5135             # There's at least one unescaped new_quote in this raw string
5136             # so converting is impossible
5137             return
5138
5139         # Do not introduce or remove backslashes in raw strings
5140         new_body = body
5141     else:
5142         # remove unnecessary escapes
5143         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5144         if body != new_body:
5145             # Consider the string without unnecessary escapes as the original
5146             body = new_body
5147             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5148         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5149         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5150     if "f" in prefix.casefold():
5151         matches = re.findall(
5152             r"""
5153             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5154                 ([^{].*?)  # contents of the brackets except if begins with {{
5155             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5156             """,
5157             new_body,
5158             re.VERBOSE,
5159         )
5160         for m in matches:
5161             if "\\" in str(m):
5162                 # Do not introduce backslashes in interpolated expressions
5163                 return
5164
5165     if new_quote == '"""' and new_body[-1:] == '"':
5166         # edge case:
5167         new_body = new_body[:-1] + '\\"'
5168     orig_escape_count = body.count("\\")
5169     new_escape_count = new_body.count("\\")
5170     if new_escape_count > orig_escape_count:
5171         return  # Do not introduce more escaping
5172
5173     if new_escape_count == orig_escape_count and orig_quote == '"':
5174         return  # Prefer double quotes
5175
5176     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5177
5178
5179 def normalize_numeric_literal(leaf: Leaf) -> None:
5180     """Normalizes numeric (float, int, and complex) literals.
5181
5182     All letters used in the representation are normalized to lowercase (except
5183     in Python 2 long literals).
5184     """
5185     text = leaf.value.lower()
5186     if text.startswith(("0o", "0b")):
5187         # Leave octal and binary literals alone.
5188         pass
5189     elif text.startswith("0x"):
5190         # Change hex literals to upper case.
5191         before, after = text[:2], text[2:]
5192         text = f"{before}{after.upper()}"
5193     elif "e" in text:
5194         before, after = text.split("e")
5195         sign = ""
5196         if after.startswith("-"):
5197             after = after[1:]
5198             sign = "-"
5199         elif after.startswith("+"):
5200             after = after[1:]
5201         before = format_float_or_int_string(before)
5202         text = f"{before}e{sign}{after}"
5203     elif text.endswith(("j", "l")):
5204         number = text[:-1]
5205         suffix = text[-1]
5206         # Capitalize in "2L" because "l" looks too similar to "1".
5207         if suffix == "l":
5208             suffix = "L"
5209         text = f"{format_float_or_int_string(number)}{suffix}"
5210     else:
5211         text = format_float_or_int_string(text)
5212     leaf.value = text
5213
5214
5215 def format_float_or_int_string(text: str) -> str:
5216     """Formats a float string like "1.0"."""
5217     if "." not in text:
5218         return text
5219
5220     before, after = text.split(".")
5221     return f"{before or 0}.{after or 0}"
5222
5223
5224 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5225     """Make existing optional parentheses invisible or create new ones.
5226
5227     `parens_after` is a set of string leaf values immediately after which parens
5228     should be put.
5229
5230     Standardizes on visible parentheses for single-element tuples, and keeps
5231     existing visible parentheses for other tuples and generator expressions.
5232     """
5233     for pc in list_comments(node.prefix, is_endmarker=False):
5234         if pc.value in FMT_OFF:
5235             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5236             return
5237     check_lpar = False
5238     for index, child in enumerate(list(node.children)):
5239         # Fixes a bug where invisible parens are not properly stripped from
5240         # assignment statements that contain type annotations.
5241         if isinstance(child, Node) and child.type == syms.annassign:
5242             normalize_invisible_parens(child, parens_after=parens_after)
5243
5244         # Add parentheses around long tuple unpacking in assignments.
5245         if (
5246             index == 0
5247             and isinstance(child, Node)
5248             and child.type == syms.testlist_star_expr
5249         ):
5250             check_lpar = True
5251
5252         if check_lpar:
5253             if is_walrus_assignment(child):
5254                 continue
5255
5256             if child.type == syms.atom:
5257                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5258                     wrap_in_parentheses(node, child, visible=False)
5259             elif is_one_tuple(child):
5260                 wrap_in_parentheses(node, child, visible=True)
5261             elif node.type == syms.import_from:
5262                 # "import from" nodes store parentheses directly as part of
5263                 # the statement
5264                 if child.type == token.LPAR:
5265                     # make parentheses invisible
5266                     child.value = ""  # type: ignore
5267                     node.children[-1].value = ""  # type: ignore
5268                 elif child.type != token.STAR:
5269                     # insert invisible parentheses
5270                     node.insert_child(index, Leaf(token.LPAR, ""))
5271                     node.append_child(Leaf(token.RPAR, ""))
5272                 break
5273
5274             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5275                 wrap_in_parentheses(node, child, visible=False)
5276
5277         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5278
5279
5280 def normalize_fmt_off(node: Node) -> None:
5281     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5282     try_again = True
5283     while try_again:
5284         try_again = convert_one_fmt_off_pair(node)
5285
5286
5287 def convert_one_fmt_off_pair(node: Node) -> bool:
5288     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5289
5290     Returns True if a pair was converted.
5291     """
5292     for leaf in node.leaves():
5293         previous_consumed = 0
5294         for comment in list_comments(leaf.prefix, is_endmarker=False):
5295             if comment.value in FMT_OFF:
5296                 # We only want standalone comments. If there's no previous leaf or
5297                 # the previous leaf is indentation, it's a standalone comment in
5298                 # disguise.
5299                 if comment.type != STANDALONE_COMMENT:
5300                     prev = preceding_leaf(leaf)
5301                     if prev and prev.type not in WHITESPACE:
5302                         continue
5303
5304                 ignored_nodes = list(generate_ignored_nodes(leaf))
5305                 if not ignored_nodes:
5306                     continue
5307
5308                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5309                 parent = first.parent
5310                 prefix = first.prefix
5311                 first.prefix = prefix[comment.consumed :]
5312                 hidden_value = (
5313                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5314                 )
5315                 if hidden_value.endswith("\n"):
5316                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5317                     # leaf (possibly followed by a DEDENT).
5318                     hidden_value = hidden_value[:-1]
5319                 first_idx: Optional[int] = None
5320                 for ignored in ignored_nodes:
5321                     index = ignored.remove()
5322                     if first_idx is None:
5323                         first_idx = index
5324                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5325                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5326                 parent.insert_child(
5327                     first_idx,
5328                     Leaf(
5329                         STANDALONE_COMMENT,
5330                         hidden_value,
5331                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5332                     ),
5333                 )
5334                 return True
5335
5336             previous_consumed = comment.consumed
5337
5338     return False
5339
5340
5341 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5342     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5343
5344     Stops at the end of the block.
5345     """
5346     container: Optional[LN] = container_of(leaf)
5347     while container is not None and container.type != token.ENDMARKER:
5348         if is_fmt_on(container):
5349             return
5350
5351         # fix for fmt: on in children
5352         if contains_fmt_on_at_column(container, leaf.column):
5353             for child in container.children:
5354                 if contains_fmt_on_at_column(child, leaf.column):
5355                     return
5356                 yield child
5357         else:
5358             yield container
5359             container = container.next_sibling
5360
5361
5362 def is_fmt_on(container: LN) -> bool:
5363     """Determine whether formatting is switched on within a container.
5364     Determined by whether the last `# fmt:` comment is `on` or `off`.
5365     """
5366     fmt_on = False
5367     for comment in list_comments(container.prefix, is_endmarker=False):
5368         if comment.value in FMT_ON:
5369             fmt_on = True
5370         elif comment.value in FMT_OFF:
5371             fmt_on = False
5372     return fmt_on
5373
5374
5375 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5376     """Determine if children at a given column have formatting switched on."""
5377     for child in container.children:
5378         if (
5379             isinstance(child, Node)
5380             and first_leaf_column(child) == column
5381             or isinstance(child, Leaf)
5382             and child.column == column
5383         ):
5384             if is_fmt_on(child):
5385                 return True
5386
5387     return False
5388
5389
5390 def first_leaf_column(node: Node) -> Optional[int]:
5391     """Returns the column of the first leaf child of a node."""
5392     for child in node.children:
5393         if isinstance(child, Leaf):
5394             return child.column
5395     return None
5396
5397
5398 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5399     """If it's safe, make the parens in the atom `node` invisible, recursively.
5400     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5401     as they are redundant.
5402
5403     Returns whether the node should itself be wrapped in invisible parentheses.
5404
5405     """
5406     if (
5407         node.type != syms.atom
5408         or is_empty_tuple(node)
5409         or is_one_tuple(node)
5410         or (is_yield(node) and parent.type != syms.expr_stmt)
5411         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5412     ):
5413         return False
5414
5415     first = node.children[0]
5416     last = node.children[-1]
5417     if first.type == token.LPAR and last.type == token.RPAR:
5418         middle = node.children[1]
5419         # make parentheses invisible
5420         first.value = ""  # type: ignore
5421         last.value = ""  # type: ignore
5422         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5423
5424         if is_atom_with_invisible_parens(middle):
5425             # Strip the invisible parens from `middle` by replacing
5426             # it with the child in-between the invisible parens
5427             middle.replace(middle.children[1])
5428
5429         return False
5430
5431     return True
5432
5433
5434 def is_atom_with_invisible_parens(node: LN) -> bool:
5435     """Given a `LN`, determines whether it's an atom `node` with invisible
5436     parens. Useful in dedupe-ing and normalizing parens.
5437     """
5438     if isinstance(node, Leaf) or node.type != syms.atom:
5439         return False
5440
5441     first, last = node.children[0], node.children[-1]
5442     return (
5443         isinstance(first, Leaf)
5444         and first.type == token.LPAR
5445         and first.value == ""
5446         and isinstance(last, Leaf)
5447         and last.type == token.RPAR
5448         and last.value == ""
5449     )
5450
5451
5452 def is_empty_tuple(node: LN) -> bool:
5453     """Return True if `node` holds an empty tuple."""
5454     return (
5455         node.type == syms.atom
5456         and len(node.children) == 2
5457         and node.children[0].type == token.LPAR
5458         and node.children[1].type == token.RPAR
5459     )
5460
5461
5462 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5463     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5464
5465     Parenthesis can be optional. Returns None otherwise"""
5466     if len(node.children) != 3:
5467         return None
5468
5469     lpar, wrapped, rpar = node.children
5470     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5471         return None
5472
5473     return wrapped
5474
5475
5476 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5477     """Wrap `child` in parentheses.
5478
5479     This replaces `child` with an atom holding the parentheses and the old
5480     child.  That requires moving the prefix.
5481
5482     If `visible` is False, the leaves will be valueless (and thus invisible).
5483     """
5484     lpar = Leaf(token.LPAR, "(" if visible else "")
5485     rpar = Leaf(token.RPAR, ")" if visible else "")
5486     prefix = child.prefix
5487     child.prefix = ""
5488     index = child.remove() or 0
5489     new_child = Node(syms.atom, [lpar, child, rpar])
5490     new_child.prefix = prefix
5491     parent.insert_child(index, new_child)
5492
5493
5494 def is_one_tuple(node: LN) -> bool:
5495     """Return True if `node` holds a tuple with one element, with or without parens."""
5496     if node.type == syms.atom:
5497         gexp = unwrap_singleton_parenthesis(node)
5498         if gexp is None or gexp.type != syms.testlist_gexp:
5499             return False
5500
5501         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5502
5503     return (
5504         node.type in IMPLICIT_TUPLE
5505         and len(node.children) == 2
5506         and node.children[1].type == token.COMMA
5507     )
5508
5509
5510 def is_walrus_assignment(node: LN) -> bool:
5511     """Return True iff `node` is of the shape ( test := test )"""
5512     inner = unwrap_singleton_parenthesis(node)
5513     return inner is not None and inner.type == syms.namedexpr_test
5514
5515
5516 def is_yield(node: LN) -> bool:
5517     """Return True if `node` holds a `yield` or `yield from` expression."""
5518     if node.type == syms.yield_expr:
5519         return True
5520
5521     if node.type == token.NAME and node.value == "yield":  # type: ignore
5522         return True
5523
5524     if node.type != syms.atom:
5525         return False
5526
5527     if len(node.children) != 3:
5528         return False
5529
5530     lpar, expr, rpar = node.children
5531     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5532         return is_yield(expr)
5533
5534     return False
5535
5536
5537 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5538     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5539
5540     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5541     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5542     extended iterable unpacking (PEP 3132) and additional unpacking
5543     generalizations (PEP 448).
5544     """
5545     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5546         return False
5547
5548     p = leaf.parent
5549     if p.type == syms.star_expr:
5550         # Star expressions are also used as assignment targets in extended
5551         # iterable unpacking (PEP 3132).  See what its parent is instead.
5552         if not p.parent:
5553             return False
5554
5555         p = p.parent
5556
5557     return p.type in within
5558
5559
5560 def is_multiline_string(leaf: Leaf) -> bool:
5561     """Return True if `leaf` is a multiline string that actually spans many lines."""
5562     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5563
5564
5565 def is_stub_suite(node: Node) -> bool:
5566     """Return True if `node` is a suite with a stub body."""
5567     if (
5568         len(node.children) != 4
5569         or node.children[0].type != token.NEWLINE
5570         or node.children[1].type != token.INDENT
5571         or node.children[3].type != token.DEDENT
5572     ):
5573         return False
5574
5575     return is_stub_body(node.children[2])
5576
5577
5578 def is_stub_body(node: LN) -> bool:
5579     """Return True if `node` is a simple statement containing an ellipsis."""
5580     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5581         return False
5582
5583     if len(node.children) != 2:
5584         return False
5585
5586     child = node.children[0]
5587     return (
5588         child.type == syms.atom
5589         and len(child.children) == 3
5590         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5591     )
5592
5593
5594 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5595     """Return maximum delimiter priority inside `node`.
5596
5597     This is specific to atoms with contents contained in a pair of parentheses.
5598     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5599     """
5600     if node.type != syms.atom:
5601         return 0
5602
5603     first = node.children[0]
5604     last = node.children[-1]
5605     if not (first.type == token.LPAR and last.type == token.RPAR):
5606         return 0
5607
5608     bt = BracketTracker()
5609     for c in node.children[1:-1]:
5610         if isinstance(c, Leaf):
5611             bt.mark(c)
5612         else:
5613             for leaf in c.leaves():
5614                 bt.mark(leaf)
5615     try:
5616         return bt.max_delimiter_priority()
5617
5618     except ValueError:
5619         return 0
5620
5621
5622 def ensure_visible(leaf: Leaf) -> None:
5623     """Make sure parentheses are visible.
5624
5625     They could be invisible as part of some statements (see
5626     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5627     """
5628     if leaf.type == token.LPAR:
5629         leaf.value = "("
5630     elif leaf.type == token.RPAR:
5631         leaf.value = ")"
5632
5633
5634 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
5635     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
5636
5637     if not (
5638         opening_bracket.parent
5639         and opening_bracket.parent.type in {syms.atom, syms.import_from}
5640         and opening_bracket.value in "[{("
5641     ):
5642         return False
5643
5644     try:
5645         last_leaf = line.leaves[-1]
5646         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
5647         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5648     except (IndexError, ValueError):
5649         return False
5650
5651     return max_priority == COMMA_PRIORITY
5652
5653
5654 def get_features_used(node: Node) -> Set[Feature]:
5655     """Return a set of (relatively) new Python features used in this file.
5656
5657     Currently looking for:
5658     - f-strings;
5659     - underscores in numeric literals;
5660     - trailing commas after * or ** in function signatures and calls;
5661     - positional only arguments in function signatures and lambdas;
5662     """
5663     features: Set[Feature] = set()
5664     for n in node.pre_order():
5665         if n.type == token.STRING:
5666             value_head = n.value[:2]  # type: ignore
5667             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5668                 features.add(Feature.F_STRINGS)
5669
5670         elif n.type == token.NUMBER:
5671             if "_" in n.value:  # type: ignore
5672                 features.add(Feature.NUMERIC_UNDERSCORES)
5673
5674         elif n.type == token.SLASH:
5675             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5676                 features.add(Feature.POS_ONLY_ARGUMENTS)
5677
5678         elif n.type == token.COLONEQUAL:
5679             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5680
5681         elif (
5682             n.type in {syms.typedargslist, syms.arglist}
5683             and n.children
5684             and n.children[-1].type == token.COMMA
5685         ):
5686             if n.type == syms.typedargslist:
5687                 feature = Feature.TRAILING_COMMA_IN_DEF
5688             else:
5689                 feature = Feature.TRAILING_COMMA_IN_CALL
5690
5691             for ch in n.children:
5692                 if ch.type in STARS:
5693                     features.add(feature)
5694
5695                 if ch.type == syms.argument:
5696                     for argch in ch.children:
5697                         if argch.type in STARS:
5698                             features.add(feature)
5699
5700     return features
5701
5702
5703 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5704     """Detect the version to target based on the nodes used."""
5705     features = get_features_used(node)
5706     return {
5707         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5708     }
5709
5710
5711 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5712     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5713
5714     Brackets can be omitted if the entire trailer up to and including
5715     a preceding closing bracket fits in one line.
5716
5717     Yielded sets are cumulative (contain results of previous yields, too).  First
5718     set is empty.
5719     """
5720
5721     omit: Set[LeafID] = set()
5722     yield omit
5723
5724     length = 4 * line.depth
5725     opening_bracket: Optional[Leaf] = None
5726     closing_bracket: Optional[Leaf] = None
5727     inner_brackets: Set[LeafID] = set()
5728     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5729         length += leaf_length
5730         if length > line_length:
5731             break
5732
5733         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5734         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5735             break
5736
5737         if opening_bracket:
5738             if leaf is opening_bracket:
5739                 opening_bracket = None
5740             elif leaf.type in CLOSING_BRACKETS:
5741                 inner_brackets.add(id(leaf))
5742         elif leaf.type in CLOSING_BRACKETS:
5743             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
5744                 # Empty brackets would fail a split so treat them as "inner"
5745                 # brackets (e.g. only add them to the `omit` set if another
5746                 # pair of brackets was good enough.
5747                 inner_brackets.add(id(leaf))
5748                 continue
5749
5750             if closing_bracket:
5751                 omit.add(id(closing_bracket))
5752                 omit.update(inner_brackets)
5753                 inner_brackets.clear()
5754                 yield omit
5755
5756             if leaf.value:
5757                 opening_bracket = leaf.opening_bracket
5758                 closing_bracket = leaf
5759
5760
5761 def get_future_imports(node: Node) -> Set[str]:
5762     """Return a set of __future__ imports in the file."""
5763     imports: Set[str] = set()
5764
5765     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5766         for child in children:
5767             if isinstance(child, Leaf):
5768                 if child.type == token.NAME:
5769                     yield child.value
5770
5771             elif child.type == syms.import_as_name:
5772                 orig_name = child.children[0]
5773                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5774                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5775                 yield orig_name.value
5776
5777             elif child.type == syms.import_as_names:
5778                 yield from get_imports_from_children(child.children)
5779
5780             else:
5781                 raise AssertionError("Invalid syntax parsing imports")
5782
5783     for child in node.children:
5784         if child.type != syms.simple_stmt:
5785             break
5786
5787         first_child = child.children[0]
5788         if isinstance(first_child, Leaf):
5789             # Continue looking if we see a docstring; otherwise stop.
5790             if (
5791                 len(child.children) == 2
5792                 and first_child.type == token.STRING
5793                 and child.children[1].type == token.NEWLINE
5794             ):
5795                 continue
5796
5797             break
5798
5799         elif first_child.type == syms.import_from:
5800             module_name = first_child.children[1]
5801             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5802                 break
5803
5804             imports |= set(get_imports_from_children(first_child.children[3:]))
5805         else:
5806             break
5807
5808     return imports
5809
5810
5811 @lru_cache()
5812 def get_gitignore(root: Path) -> PathSpec:
5813     """ Return a PathSpec matching gitignore content if present."""
5814     gitignore = root / ".gitignore"
5815     lines: List[str] = []
5816     if gitignore.is_file():
5817         with gitignore.open() as gf:
5818             lines = gf.readlines()
5819     return PathSpec.from_lines("gitwildmatch", lines)
5820
5821
5822 def normalize_path_maybe_ignore(
5823     path: Path, root: Path, report: "Report"
5824 ) -> Optional[str]:
5825     """Normalize `path`. May return `None` if `path` was ignored.
5826
5827     `report` is where "path ignored" output goes.
5828     """
5829     try:
5830         normalized_path = path.resolve().relative_to(root).as_posix()
5831     except OSError as e:
5832         report.path_ignored(path, f"cannot be read because {e}")
5833         return None
5834
5835     except ValueError:
5836         if path.is_symlink():
5837             report.path_ignored(path, f"is a symbolic link that points outside {root}")
5838             return None
5839
5840         raise
5841
5842     return normalized_path
5843
5844
5845 def gen_python_files(
5846     paths: Iterable[Path],
5847     root: Path,
5848     include: Optional[Pattern[str]],
5849     exclude: Pattern[str],
5850     force_exclude: Optional[Pattern[str]],
5851     report: "Report",
5852     gitignore: PathSpec,
5853 ) -> Iterator[Path]:
5854     """Generate all files under `path` whose paths are not excluded by the
5855     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
5856
5857     Symbolic links pointing outside of the `root` directory are ignored.
5858
5859     `report` is where output about exclusions goes.
5860     """
5861     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5862     for child in paths:
5863         normalized_path = normalize_path_maybe_ignore(child, root, report)
5864         if normalized_path is None:
5865             continue
5866
5867         # First ignore files matching .gitignore
5868         if gitignore.match_file(normalized_path):
5869             report.path_ignored(child, "matches the .gitignore file content")
5870             continue
5871
5872         # Then ignore with `--exclude` and `--force-exclude` options.
5873         normalized_path = "/" + normalized_path
5874         if child.is_dir():
5875             normalized_path += "/"
5876
5877         exclude_match = exclude.search(normalized_path) if exclude else None
5878         if exclude_match and exclude_match.group(0):
5879             report.path_ignored(child, "matches the --exclude regular expression")
5880             continue
5881
5882         force_exclude_match = (
5883             force_exclude.search(normalized_path) if force_exclude else None
5884         )
5885         if force_exclude_match and force_exclude_match.group(0):
5886             report.path_ignored(child, "matches the --force-exclude regular expression")
5887             continue
5888
5889         if child.is_dir():
5890             yield from gen_python_files(
5891                 child.iterdir(),
5892                 root,
5893                 include,
5894                 exclude,
5895                 force_exclude,
5896                 report,
5897                 gitignore,
5898             )
5899
5900         elif child.is_file():
5901             include_match = include.search(normalized_path) if include else True
5902             if include_match:
5903                 yield child
5904
5905
5906 @lru_cache()
5907 def find_project_root(srcs: Iterable[str]) -> Path:
5908     """Return a directory containing .git, .hg, or pyproject.toml.
5909
5910     That directory will be a common parent of all files and directories
5911     passed in `srcs`.
5912
5913     If no directory in the tree contains a marker that would specify it's the
5914     project root, the root of the file system is returned.
5915     """
5916     if not srcs:
5917         return Path("/").resolve()
5918
5919     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
5920
5921     # A list of lists of parents for each 'src'. 'src' is included as a
5922     # "parent" of itself if it is a directory
5923     src_parents = [
5924         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
5925     ]
5926
5927     common_base = max(
5928         set.intersection(*(set(parents) for parents in src_parents)),
5929         key=lambda path: path.parts,
5930     )
5931
5932     for directory in (common_base, *common_base.parents):
5933         if (directory / ".git").exists():
5934             return directory
5935
5936         if (directory / ".hg").is_dir():
5937             return directory
5938
5939         if (directory / "pyproject.toml").is_file():
5940             return directory
5941
5942     return directory
5943
5944
5945 @dataclass
5946 class Report:
5947     """Provides a reformatting counter. Can be rendered with `str(report)`."""
5948
5949     check: bool = False
5950     diff: bool = False
5951     quiet: bool = False
5952     verbose: bool = False
5953     change_count: int = 0
5954     same_count: int = 0
5955     failure_count: int = 0
5956
5957     def done(self, src: Path, changed: Changed) -> None:
5958         """Increment the counter for successful reformatting. Write out a message."""
5959         if changed is Changed.YES:
5960             reformatted = "would reformat" if self.check or self.diff else "reformatted"
5961             if self.verbose or not self.quiet:
5962                 out(f"{reformatted} {src}")
5963             self.change_count += 1
5964         else:
5965             if self.verbose:
5966                 if changed is Changed.NO:
5967                     msg = f"{src} already well formatted, good job."
5968                 else:
5969                     msg = f"{src} wasn't modified on disk since last run."
5970                 out(msg, bold=False)
5971             self.same_count += 1
5972
5973     def failed(self, src: Path, message: str) -> None:
5974         """Increment the counter for failed reformatting. Write out a message."""
5975         err(f"error: cannot format {src}: {message}")
5976         self.failure_count += 1
5977
5978     def path_ignored(self, path: Path, message: str) -> None:
5979         if self.verbose:
5980             out(f"{path} ignored: {message}", bold=False)
5981
5982     @property
5983     def return_code(self) -> int:
5984         """Return the exit code that the app should use.
5985
5986         This considers the current state of changed files and failures:
5987         - if there were any failures, return 123;
5988         - if any files were changed and --check is being used, return 1;
5989         - otherwise return 0.
5990         """
5991         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
5992         # 126 we have special return codes reserved by the shell.
5993         if self.failure_count:
5994             return 123
5995
5996         elif self.change_count and self.check:
5997             return 1
5998
5999         return 0
6000
6001     def __str__(self) -> str:
6002         """Render a color report of the current state.
6003
6004         Use `click.unstyle` to remove colors.
6005         """
6006         if self.check or self.diff:
6007             reformatted = "would be reformatted"
6008             unchanged = "would be left unchanged"
6009             failed = "would fail to reformat"
6010         else:
6011             reformatted = "reformatted"
6012             unchanged = "left unchanged"
6013             failed = "failed to reformat"
6014         report = []
6015         if self.change_count:
6016             s = "s" if self.change_count > 1 else ""
6017             report.append(
6018                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6019             )
6020         if self.same_count:
6021             s = "s" if self.same_count > 1 else ""
6022             report.append(f"{self.same_count} file{s} {unchanged}")
6023         if self.failure_count:
6024             s = "s" if self.failure_count > 1 else ""
6025             report.append(
6026                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6027             )
6028         return ", ".join(report) + "."
6029
6030
6031 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6032     filename = "<unknown>"
6033     if sys.version_info >= (3, 8):
6034         # TODO: support Python 4+ ;)
6035         for minor_version in range(sys.version_info[1], 4, -1):
6036             try:
6037                 return ast.parse(src, filename, feature_version=(3, minor_version))
6038             except SyntaxError:
6039                 continue
6040     else:
6041         for feature_version in (7, 6):
6042             try:
6043                 return ast3.parse(src, filename, feature_version=feature_version)
6044             except SyntaxError:
6045                 continue
6046
6047     return ast27.parse(src)
6048
6049
6050 def _fixup_ast_constants(
6051     node: Union[ast.AST, ast3.AST, ast27.AST]
6052 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6053     """Map ast nodes deprecated in 3.8 to Constant."""
6054     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6055         return ast.Constant(value=node.s)
6056
6057     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6058         return ast.Constant(value=node.n)
6059
6060     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6061         return ast.Constant(value=node.value)
6062
6063     return node
6064
6065
6066 def _stringify_ast(
6067     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6068 ) -> Iterator[str]:
6069     """Simple visitor generating strings to compare ASTs by content."""
6070
6071     node = _fixup_ast_constants(node)
6072
6073     yield f"{'  ' * depth}{node.__class__.__name__}("
6074
6075     for field in sorted(node._fields):  # noqa: F402
6076         # TypeIgnore has only one field 'lineno' which breaks this comparison
6077         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6078         if sys.version_info >= (3, 8):
6079             type_ignore_classes += (ast.TypeIgnore,)
6080         if isinstance(node, type_ignore_classes):
6081             break
6082
6083         try:
6084             value = getattr(node, field)
6085         except AttributeError:
6086             continue
6087
6088         yield f"{'  ' * (depth+1)}{field}="
6089
6090         if isinstance(value, list):
6091             for item in value:
6092                 # Ignore nested tuples within del statements, because we may insert
6093                 # parentheses and they change the AST.
6094                 if (
6095                     field == "targets"
6096                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6097                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6098                 ):
6099                     for item in item.elts:
6100                         yield from _stringify_ast(item, depth + 2)
6101
6102                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6103                     yield from _stringify_ast(item, depth + 2)
6104
6105         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6106             yield from _stringify_ast(value, depth + 2)
6107
6108         else:
6109             # Constant strings may be indented across newlines, if they are
6110             # docstrings; fold spaces after newlines when comparing. Similarly,
6111             # trailing and leading space may be removed.
6112             if (
6113                 isinstance(node, ast.Constant)
6114                 and field == "value"
6115                 and isinstance(value, str)
6116             ):
6117                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6118             else:
6119                 normalized = value
6120             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6121
6122     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6123
6124
6125 def assert_equivalent(src: str, dst: str) -> None:
6126     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6127     try:
6128         src_ast = parse_ast(src)
6129     except Exception as exc:
6130         raise AssertionError(
6131             "cannot use --safe with this file; failed to parse source file.  AST"
6132             f" error message: {exc}"
6133         )
6134
6135     try:
6136         dst_ast = parse_ast(dst)
6137     except Exception as exc:
6138         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6139         raise AssertionError(
6140             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6141             " on https://github.com/psf/black/issues.  This invalid output might be"
6142             f" helpful: {log}"
6143         ) from None
6144
6145     src_ast_str = "\n".join(_stringify_ast(src_ast))
6146     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6147     if src_ast_str != dst_ast_str:
6148         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6149         raise AssertionError(
6150             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6151             " source.  Please report a bug on https://github.com/psf/black/issues. "
6152             f" This diff might be helpful: {log}"
6153         ) from None
6154
6155
6156 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6157     """Raise AssertionError if `dst` reformats differently the second time."""
6158     newdst = format_str(dst, mode=mode)
6159     if dst != newdst:
6160         log = dump_to_file(
6161             diff(src, dst, "source", "first pass"),
6162             diff(dst, newdst, "first pass", "second pass"),
6163         )
6164         raise AssertionError(
6165             "INTERNAL ERROR: Black produced different code on the second pass of the"
6166             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6167             f"  This diff might be helpful: {log}"
6168         ) from None
6169
6170
6171 @mypyc_attr(patchable=True)
6172 def dump_to_file(*output: str) -> str:
6173     """Dump `output` to a temporary file. Return path to the file."""
6174     with tempfile.NamedTemporaryFile(
6175         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6176     ) as f:
6177         for lines in output:
6178             f.write(lines)
6179             if lines and lines[-1] != "\n":
6180                 f.write("\n")
6181     return f.name
6182
6183
6184 @contextmanager
6185 def nullcontext() -> Iterator[None]:
6186     """Return an empty context manager.
6187
6188     To be used like `nullcontext` in Python 3.7.
6189     """
6190     yield
6191
6192
6193 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6194     """Return a unified diff string between strings `a` and `b`."""
6195     import difflib
6196
6197     a_lines = [line + "\n" for line in a.splitlines()]
6198     b_lines = [line + "\n" for line in b.splitlines()]
6199     return "".join(
6200         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6201     )
6202
6203
6204 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6205     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6206     err("Aborted!")
6207     for task in tasks:
6208         task.cancel()
6209
6210
6211 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6212     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6213     try:
6214         if sys.version_info[:2] >= (3, 7):
6215             all_tasks = asyncio.all_tasks
6216         else:
6217             all_tasks = asyncio.Task.all_tasks
6218         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6219         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6220         if not to_cancel:
6221             return
6222
6223         for task in to_cancel:
6224             task.cancel()
6225         loop.run_until_complete(
6226             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6227         )
6228     finally:
6229         # `concurrent.futures.Future` objects cannot be cancelled once they
6230         # are already running. There might be some when the `shutdown()` happened.
6231         # Silence their logger's spew about the event loop being closed.
6232         cf_logger = logging.getLogger("concurrent.futures")
6233         cf_logger.setLevel(logging.CRITICAL)
6234         loop.close()
6235
6236
6237 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6238     """Replace `regex` with `replacement` twice on `original`.
6239
6240     This is used by string normalization to perform replaces on
6241     overlapping matches.
6242     """
6243     return regex.sub(replacement, regex.sub(replacement, original))
6244
6245
6246 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6247     """Compile a regular expression string in `regex`.
6248
6249     If it contains newlines, use verbose mode.
6250     """
6251     if "\n" in regex:
6252         regex = "(?x)" + regex
6253     compiled: Pattern[str] = re.compile(regex)
6254     return compiled
6255
6256
6257 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6258     """Like `reversed(enumerate(sequence))` if that were possible."""
6259     index = len(sequence) - 1
6260     for element in reversed(sequence):
6261         yield (index, element)
6262         index -= 1
6263
6264
6265 def enumerate_with_length(
6266     line: Line, reversed: bool = False
6267 ) -> Iterator[Tuple[Index, Leaf, int]]:
6268     """Return an enumeration of leaves with their length.
6269
6270     Stops prematurely on multiline strings and standalone comments.
6271     """
6272     op = cast(
6273         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6274         enumerate_reversed if reversed else enumerate,
6275     )
6276     for index, leaf in op(line.leaves):
6277         length = len(leaf.prefix) + len(leaf.value)
6278         if "\n" in leaf.value:
6279             return  # Multiline strings, we can't continue.
6280
6281         for comment in line.comments_after(leaf):
6282             length += len(comment.value)
6283
6284         yield index, leaf, length
6285
6286
6287 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6288     """Return True if `line` is no longer than `line_length`.
6289
6290     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6291     """
6292     if not line_str:
6293         line_str = line_to_string(line)
6294     return (
6295         len(line_str) <= line_length
6296         and "\n" not in line_str  # multiline strings
6297         and not line.contains_standalone_comments()
6298     )
6299
6300
6301 def can_be_split(line: Line) -> bool:
6302     """Return False if the line cannot be split *for sure*.
6303
6304     This is not an exhaustive search but a cheap heuristic that we can use to
6305     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6306     in unnecessary parentheses).
6307     """
6308     leaves = line.leaves
6309     if len(leaves) < 2:
6310         return False
6311
6312     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6313         call_count = 0
6314         dot_count = 0
6315         next = leaves[-1]
6316         for leaf in leaves[-2::-1]:
6317             if leaf.type in OPENING_BRACKETS:
6318                 if next.type not in CLOSING_BRACKETS:
6319                     return False
6320
6321                 call_count += 1
6322             elif leaf.type == token.DOT:
6323                 dot_count += 1
6324             elif leaf.type == token.NAME:
6325                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6326                     return False
6327
6328             elif leaf.type not in CLOSING_BRACKETS:
6329                 return False
6330
6331             if dot_count > 1 and call_count > 1:
6332                 return False
6333
6334     return True
6335
6336
6337 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
6338     """Does `line` have a shape safe to reformat without optional parens around it?
6339
6340     Returns True for only a subset of potentially nice looking formattings but
6341     the point is to not return false positives that end up producing lines that
6342     are too long.
6343     """
6344     bt = line.bracket_tracker
6345     if not bt.delimiters:
6346         # Without delimiters the optional parentheses are useless.
6347         return True
6348
6349     max_priority = bt.max_delimiter_priority()
6350     if bt.delimiter_count_with_priority(max_priority) > 1:
6351         # With more than one delimiter of a kind the optional parentheses read better.
6352         return False
6353
6354     if max_priority == DOT_PRIORITY:
6355         # A single stranded method call doesn't require optional parentheses.
6356         return True
6357
6358     assert len(line.leaves) >= 2, "Stranded delimiter"
6359
6360     first = line.leaves[0]
6361     second = line.leaves[1]
6362     penultimate = line.leaves[-2]
6363     last = line.leaves[-1]
6364
6365     # With a single delimiter, omit if the expression starts or ends with
6366     # a bracket.
6367     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6368         remainder = False
6369         length = 4 * line.depth
6370         for _index, leaf, leaf_length in enumerate_with_length(line):
6371             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6372                 remainder = True
6373             if remainder:
6374                 length += leaf_length
6375                 if length > line_length:
6376                     break
6377
6378                 if leaf.type in OPENING_BRACKETS:
6379                     # There are brackets we can further split on.
6380                     remainder = False
6381
6382         else:
6383             # checked the entire string and line length wasn't exceeded
6384             if len(line.leaves) == _index + 1:
6385                 return True
6386
6387         # Note: we are not returning False here because a line might have *both*
6388         # a leading opening bracket and a trailing closing bracket.  If the
6389         # opening bracket doesn't match our rule, maybe the closing will.
6390
6391     if (
6392         last.type == token.RPAR
6393         or last.type == token.RBRACE
6394         or (
6395             # don't use indexing for omitting optional parentheses;
6396             # it looks weird
6397             last.type == token.RSQB
6398             and last.parent
6399             and last.parent.type != syms.trailer
6400         )
6401     ):
6402         if penultimate.type in OPENING_BRACKETS:
6403             # Empty brackets don't help.
6404             return False
6405
6406         if is_multiline_string(first):
6407             # Additional wrapping of a multiline string in this situation is
6408             # unnecessary.
6409             return True
6410
6411         length = 4 * line.depth
6412         seen_other_brackets = False
6413         for _index, leaf, leaf_length in enumerate_with_length(line):
6414             length += leaf_length
6415             if leaf is last.opening_bracket:
6416                 if seen_other_brackets or length <= line_length:
6417                     return True
6418
6419             elif leaf.type in OPENING_BRACKETS:
6420                 # There are brackets we can further split on.
6421                 seen_other_brackets = True
6422
6423     return False
6424
6425
6426 def get_cache_file(mode: Mode) -> Path:
6427     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6428
6429
6430 def read_cache(mode: Mode) -> Cache:
6431     """Read the cache if it exists and is well formed.
6432
6433     If it is not well formed, the call to write_cache later should resolve the issue.
6434     """
6435     cache_file = get_cache_file(mode)
6436     if not cache_file.exists():
6437         return {}
6438
6439     with cache_file.open("rb") as fobj:
6440         try:
6441             cache: Cache = pickle.load(fobj)
6442         except (pickle.UnpicklingError, ValueError):
6443             return {}
6444
6445     return cache
6446
6447
6448 def get_cache_info(path: Path) -> CacheInfo:
6449     """Return the information used to check if a file is already formatted or not."""
6450     stat = path.stat()
6451     return stat.st_mtime, stat.st_size
6452
6453
6454 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6455     """Split an iterable of paths in `sources` into two sets.
6456
6457     The first contains paths of files that modified on disk or are not in the
6458     cache. The other contains paths to non-modified files.
6459     """
6460     todo, done = set(), set()
6461     for src in sources:
6462         src = src.resolve()
6463         if cache.get(src) != get_cache_info(src):
6464             todo.add(src)
6465         else:
6466             done.add(src)
6467     return todo, done
6468
6469
6470 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6471     """Update the cache file."""
6472     cache_file = get_cache_file(mode)
6473     try:
6474         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6475         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6476         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6477             pickle.dump(new_cache, f, protocol=4)
6478         os.replace(f.name, cache_file)
6479     except OSError:
6480         pass
6481
6482
6483 def patch_click() -> None:
6484     """Make Click not crash.
6485
6486     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6487     default which restricts paths that it can access during the lifetime of the
6488     application.  Click refuses to work in this scenario by raising a RuntimeError.
6489
6490     In case of Black the likelihood that non-ASCII characters are going to be used in
6491     file paths is minimal since it's Python source code.  Moreover, this crash was
6492     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6493     """
6494     try:
6495         from click import core
6496         from click import _unicodefun  # type: ignore
6497     except ModuleNotFoundError:
6498         return
6499
6500     for module in (core, _unicodefun):
6501         if hasattr(module, "_verify_python3_env"):
6502             module._verify_python3_env = lambda: None
6503
6504
6505 def patched_main() -> None:
6506     freeze_support()
6507     patch_click()
6508     main()
6509
6510
6511 def fix_docstring(docstring: str, prefix: str) -> str:
6512     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6513     if not docstring:
6514         return ""
6515     # Convert tabs to spaces (following the normal Python rules)
6516     # and split into a list of lines:
6517     lines = docstring.expandtabs().splitlines()
6518     # Determine minimum indentation (first line doesn't count):
6519     indent = sys.maxsize
6520     for line in lines[1:]:
6521         stripped = line.lstrip()
6522         if stripped:
6523             indent = min(indent, len(line) - len(stripped))
6524     # Remove indentation (first line is special):
6525     trimmed = [lines[0].strip()]
6526     if indent < sys.maxsize:
6527         last_line_idx = len(lines) - 2
6528         for i, line in enumerate(lines[1:]):
6529             stripped_line = line[indent:].rstrip()
6530             if stripped_line or i == last_line_idx:
6531                 trimmed.append(prefix + stripped_line)
6532             else:
6533                 trimmed.append("")
6534     # Return a single string:
6535     return "\n".join(trimmed)
6536
6537
6538 if __name__ == "__main__":
6539     patched_main()