]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

77dda2b52cc4bbebc3a57f4ed272f42071d45cfe
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 T = TypeVar("T")
116 E = TypeVar("E", bound=Exception)
117
118
119 class Ok(Generic[T]):
120     def __init__(self, value: T) -> None:
121         self._value = value
122
123     def ok(self) -> T:
124         return self._value
125
126
127 class Err(Generic[E]):
128     def __init__(self, e: E) -> None:
129         self._e = e
130
131     def err(self) -> E:
132         return self._e
133
134
135 # The 'Result' return type is used to implement an error-handling model heavily
136 # influenced by that used by the Rust programming language
137 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
138 Result = Union[Ok[T], Err[E]]
139 TResult = Result[T, CannotTransform]  # (T)ransform Result
140 TMatchResult = TResult[Index]
141
142
143 class WriteBack(Enum):
144     NO = 0
145     YES = 1
146     DIFF = 2
147     CHECK = 3
148     COLOR_DIFF = 4
149
150     @classmethod
151     def from_configuration(
152         cls, *, check: bool, diff: bool, color: bool = False
153     ) -> "WriteBack":
154         if check and not diff:
155             return cls.CHECK
156
157         if diff and color:
158             return cls.COLOR_DIFF
159
160         return cls.DIFF if diff else cls.YES
161
162
163 class Changed(Enum):
164     NO = 0
165     CACHED = 1
166     YES = 2
167
168
169 class TargetVersion(Enum):
170     PY27 = 2
171     PY33 = 3
172     PY34 = 4
173     PY35 = 5
174     PY36 = 6
175     PY37 = 7
176     PY38 = 8
177
178     def is_python2(self) -> bool:
179         return self is TargetVersion.PY27
180
181
182 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
183
184
185 class Feature(Enum):
186     # All string literals are unicode
187     UNICODE_LITERALS = 1
188     F_STRINGS = 2
189     NUMERIC_UNDERSCORES = 3
190     TRAILING_COMMA_IN_CALL = 4
191     TRAILING_COMMA_IN_DEF = 5
192     # The following two feature-flags are mutually exclusive, and exactly one should be
193     # set for every version of python.
194     ASYNC_IDENTIFIERS = 6
195     ASYNC_KEYWORDS = 7
196     ASSIGNMENT_EXPRESSIONS = 8
197     POS_ONLY_ARGUMENTS = 9
198
199
200 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
201     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
202     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
203     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
204     TargetVersion.PY35: {
205         Feature.UNICODE_LITERALS,
206         Feature.TRAILING_COMMA_IN_CALL,
207         Feature.ASYNC_IDENTIFIERS,
208     },
209     TargetVersion.PY36: {
210         Feature.UNICODE_LITERALS,
211         Feature.F_STRINGS,
212         Feature.NUMERIC_UNDERSCORES,
213         Feature.TRAILING_COMMA_IN_CALL,
214         Feature.TRAILING_COMMA_IN_DEF,
215         Feature.ASYNC_IDENTIFIERS,
216     },
217     TargetVersion.PY37: {
218         Feature.UNICODE_LITERALS,
219         Feature.F_STRINGS,
220         Feature.NUMERIC_UNDERSCORES,
221         Feature.TRAILING_COMMA_IN_CALL,
222         Feature.TRAILING_COMMA_IN_DEF,
223         Feature.ASYNC_KEYWORDS,
224     },
225     TargetVersion.PY38: {
226         Feature.UNICODE_LITERALS,
227         Feature.F_STRINGS,
228         Feature.NUMERIC_UNDERSCORES,
229         Feature.TRAILING_COMMA_IN_CALL,
230         Feature.TRAILING_COMMA_IN_DEF,
231         Feature.ASYNC_KEYWORDS,
232         Feature.ASSIGNMENT_EXPRESSIONS,
233         Feature.POS_ONLY_ARGUMENTS,
234     },
235 }
236
237
238 @dataclass
239 class Mode:
240     target_versions: Set[TargetVersion] = field(default_factory=set)
241     line_length: int = DEFAULT_LINE_LENGTH
242     string_normalization: bool = True
243     is_pyi: bool = False
244
245     def get_cache_key(self) -> str:
246         if self.target_versions:
247             version_str = ",".join(
248                 str(version.value)
249                 for version in sorted(self.target_versions, key=lambda v: v.value)
250             )
251         else:
252             version_str = "-"
253         parts = [
254             version_str,
255             str(self.line_length),
256             str(int(self.string_normalization)),
257             str(int(self.is_pyi)),
258         ]
259         return ".".join(parts)
260
261
262 # Legacy name, left for integrations.
263 FileMode = Mode
264
265
266 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
267     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
268
269
270 def find_pyproject_toml(path_search_start: str) -> Optional[str]:
271     """Find the absolute filepath to a pyproject.toml if it exists"""
272     path_project_root = find_project_root(path_search_start)
273     path_pyproject_toml = path_project_root / "pyproject.toml"
274     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
275
276
277 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
278     """Parse a pyproject toml file, pulling out relevant parts for Black
279
280     If parsing fails, will raise a toml.TomlDecodeError
281     """
282     pyproject_toml = toml.load(path_config)
283     config = pyproject_toml.get("tool", {}).get("black", {})
284     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
285
286
287 def read_pyproject_toml(
288     ctx: click.Context, param: click.Parameter, value: Optional[str]
289 ) -> Optional[str]:
290     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
291
292     Returns the path to a successfully found and read configuration file, None
293     otherwise.
294     """
295     if not value:
296         value = find_pyproject_toml(ctx.params.get("src", ()))
297         if value is None:
298             return None
299
300     try:
301         config = parse_pyproject_toml(value)
302     except (toml.TomlDecodeError, OSError) as e:
303         raise click.FileError(
304             filename=value, hint=f"Error reading configuration file: {e}"
305         )
306
307     if not config:
308         return None
309
310     target_version = config.get("target_version")
311     if target_version is not None and not isinstance(target_version, list):
312         raise click.BadOptionUsage(
313             "target-version", f"Config key target-version must be a list"
314         )
315
316     default_map: Dict[str, Any] = {}
317     if ctx.default_map:
318         default_map.update(ctx.default_map)
319     default_map.update(config)
320
321     ctx.default_map = default_map
322     return value
323
324
325 def target_version_option_callback(
326     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
327 ) -> List[TargetVersion]:
328     """Compute the target versions from a --target-version flag.
329
330     This is its own function because mypy couldn't infer the type correctly
331     when it was a lambda, causing mypyc trouble.
332     """
333     return [TargetVersion[val.upper()] for val in v]
334
335
336 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
337 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
338 @click.option(
339     "-l",
340     "--line-length",
341     type=int,
342     default=DEFAULT_LINE_LENGTH,
343     help="How many characters per line to allow.",
344     show_default=True,
345 )
346 @click.option(
347     "-t",
348     "--target-version",
349     type=click.Choice([v.name.lower() for v in TargetVersion]),
350     callback=target_version_option_callback,
351     multiple=True,
352     help=(
353         "Python versions that should be supported by Black's output. [default: per-file"
354         " auto-detection]"
355     ),
356 )
357 @click.option(
358     "--pyi",
359     is_flag=True,
360     help=(
361         "Format all input files like typing stubs regardless of file extension (useful"
362         " when piping source on standard input)."
363     ),
364 )
365 @click.option(
366     "-S",
367     "--skip-string-normalization",
368     is_flag=True,
369     help="Don't normalize string quotes or prefixes.",
370 )
371 @click.option(
372     "--check",
373     is_flag=True,
374     help=(
375         "Don't write the files back, just return the status.  Return code 0 means"
376         " nothing would change.  Return code 1 means some files would be reformatted."
377         " Return code 123 means there was an internal error."
378     ),
379 )
380 @click.option(
381     "--diff",
382     is_flag=True,
383     help="Don't write the files back, just output a diff for each file on stdout.",
384 )
385 @click.option(
386     "--color/--no-color",
387     is_flag=True,
388     help="Show colored diff. Only applies when `--diff` is given.",
389 )
390 @click.option(
391     "--fast/--safe",
392     is_flag=True,
393     help="If --fast given, skip temporary sanity checks. [default: --safe]",
394 )
395 @click.option(
396     "--include",
397     type=str,
398     default=DEFAULT_INCLUDES,
399     help=(
400         "A regular expression that matches files and directories that should be"
401         " included on recursive searches.  An empty value means all files are included"
402         " regardless of the name.  Use forward slashes for directories on all platforms"
403         " (Windows, too).  Exclusions are calculated first, inclusions later."
404     ),
405     show_default=True,
406 )
407 @click.option(
408     "--exclude",
409     type=str,
410     default=DEFAULT_EXCLUDES,
411     help=(
412         "A regular expression that matches files and directories that should be"
413         " excluded on recursive searches.  An empty value means no paths are excluded."
414         " Use forward slashes for directories on all platforms (Windows, too). "
415         " Exclusions are calculated first, inclusions later."
416     ),
417     show_default=True,
418 )
419 @click.option(
420     "--force-exclude",
421     type=str,
422     help=(
423         "Like --exclude, but files and directories matching this regex will be "
424         "excluded even when they are passed explicitly as arguments"
425     ),
426 )
427 @click.option(
428     "-q",
429     "--quiet",
430     is_flag=True,
431     help=(
432         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
433         " those with 2>/dev/null."
434     ),
435 )
436 @click.option(
437     "-v",
438     "--verbose",
439     is_flag=True,
440     help=(
441         "Also emit messages to stderr about files that were not changed or were ignored"
442         " due to --exclude=."
443     ),
444 )
445 @click.version_option(version=__version__)
446 @click.argument(
447     "src",
448     nargs=-1,
449     type=click.Path(
450         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
451     ),
452     is_eager=True,
453 )
454 @click.option(
455     "--config",
456     type=click.Path(
457         exists=True,
458         file_okay=True,
459         dir_okay=False,
460         readable=True,
461         allow_dash=False,
462         path_type=str,
463     ),
464     is_eager=True,
465     callback=read_pyproject_toml,
466     help="Read configuration from PATH.",
467 )
468 @click.pass_context
469 def main(
470     ctx: click.Context,
471     code: Optional[str],
472     line_length: int,
473     target_version: List[TargetVersion],
474     check: bool,
475     diff: bool,
476     color: bool,
477     fast: bool,
478     pyi: bool,
479     skip_string_normalization: bool,
480     quiet: bool,
481     verbose: bool,
482     include: str,
483     exclude: str,
484     force_exclude: Optional[str],
485     src: Tuple[str, ...],
486     config: Optional[str],
487 ) -> None:
488     """The uncompromising code formatter."""
489     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
490     if target_version:
491         versions = set(target_version)
492     else:
493         # We'll autodetect later.
494         versions = set()
495     mode = Mode(
496         target_versions=versions,
497         line_length=line_length,
498         is_pyi=pyi,
499         string_normalization=not skip_string_normalization,
500     )
501     if config and verbose:
502         out(f"Using configuration from {config}.", bold=False, fg="blue")
503     if code is not None:
504         print(format_str(code, mode=mode))
505         ctx.exit(0)
506     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
507     sources = get_sources(
508         ctx=ctx,
509         src=src,
510         quiet=quiet,
511         verbose=verbose,
512         include=include,
513         exclude=exclude,
514         force_exclude=force_exclude,
515         report=report,
516     )
517
518     path_empty(
519         sources,
520         "No Python files are present to be formatted. Nothing to do 😴",
521         quiet,
522         verbose,
523         ctx,
524     )
525
526     if len(sources) == 1:
527         reformat_one(
528             src=sources.pop(),
529             fast=fast,
530             write_back=write_back,
531             mode=mode,
532             report=report,
533         )
534     else:
535         reformat_many(
536             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
537         )
538
539     if verbose or not quiet:
540         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
541         click.secho(str(report), err=True)
542     ctx.exit(report.return_code)
543
544
545 def get_sources(
546     *,
547     ctx: click.Context,
548     src: Tuple[str, ...],
549     quiet: bool,
550     verbose: bool,
551     include: str,
552     exclude: str,
553     force_exclude: Optional[str],
554     report: "Report",
555 ) -> Set[Path]:
556     """Compute the set of files to be formatted."""
557     try:
558         include_regex = re_compile_maybe_verbose(include)
559     except re.error:
560         err(f"Invalid regular expression for include given: {include!r}")
561         ctx.exit(2)
562     try:
563         exclude_regex = re_compile_maybe_verbose(exclude)
564     except re.error:
565         err(f"Invalid regular expression for exclude given: {exclude!r}")
566         ctx.exit(2)
567     try:
568         force_exclude_regex = (
569             re_compile_maybe_verbose(force_exclude) if force_exclude else None
570         )
571     except re.error:
572         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
573         ctx.exit(2)
574
575     root = find_project_root(src)
576     sources: Set[Path] = set()
577     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
578     exclude_regexes = [exclude_regex]
579     if force_exclude_regex is not None:
580         exclude_regexes.append(force_exclude_regex)
581
582     for s in src:
583         p = Path(s)
584         if p.is_dir():
585             sources.update(
586                 gen_python_files(
587                     p.iterdir(),
588                     root,
589                     include_regex,
590                     exclude_regexes,
591                     report,
592                     get_gitignore(root),
593                 )
594             )
595         elif s == "-":
596             sources.add(p)
597         elif p.is_file():
598             sources.update(
599                 gen_python_files(
600                     [p], root, None, exclude_regexes, report, get_gitignore(root)
601                 )
602             )
603         else:
604             err(f"invalid path: {s}")
605     return sources
606
607
608 def path_empty(
609     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
610 ) -> None:
611     """
612     Exit if there is no `src` provided for formatting
613     """
614     if len(src) == 0:
615         if verbose or not quiet:
616             out(msg)
617             ctx.exit(0)
618
619
620 def reformat_one(
621     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
622 ) -> None:
623     """Reformat a single file under `src` without spawning child processes.
624
625     `fast`, `write_back`, and `mode` options are passed to
626     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
627     """
628     try:
629         changed = Changed.NO
630         if not src.is_file() and str(src) == "-":
631             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
632                 changed = Changed.YES
633         else:
634             cache: Cache = {}
635             if write_back != WriteBack.DIFF:
636                 cache = read_cache(mode)
637                 res_src = src.resolve()
638                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
639                     changed = Changed.CACHED
640             if changed is not Changed.CACHED and format_file_in_place(
641                 src, fast=fast, write_back=write_back, mode=mode
642             ):
643                 changed = Changed.YES
644             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
645                 write_back is WriteBack.CHECK and changed is Changed.NO
646             ):
647                 write_cache(cache, [src], mode)
648         report.done(src, changed)
649     except Exception as exc:
650         report.failed(src, str(exc))
651
652
653 def reformat_many(
654     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
655 ) -> None:
656     """Reformat multiple files using a ProcessPoolExecutor."""
657     executor: Executor
658     loop = asyncio.get_event_loop()
659     worker_count = os.cpu_count()
660     if sys.platform == "win32":
661         # Work around https://bugs.python.org/issue26903
662         worker_count = min(worker_count, 61)
663     try:
664         executor = ProcessPoolExecutor(max_workers=worker_count)
665     except (ImportError, OSError):
666         # we arrive here if the underlying system does not support multi-processing
667         # like in AWS Lambda or Termux, in which case we gracefully fallback to
668         # a ThreadPollExecutor with just a single worker (more workers would not do us
669         # any good due to the Global Interpreter Lock)
670         executor = ThreadPoolExecutor(max_workers=1)
671
672     try:
673         loop.run_until_complete(
674             schedule_formatting(
675                 sources=sources,
676                 fast=fast,
677                 write_back=write_back,
678                 mode=mode,
679                 report=report,
680                 loop=loop,
681                 executor=executor,
682             )
683         )
684     finally:
685         shutdown(loop)
686         if executor is not None:
687             executor.shutdown()
688
689
690 async def schedule_formatting(
691     sources: Set[Path],
692     fast: bool,
693     write_back: WriteBack,
694     mode: Mode,
695     report: "Report",
696     loop: asyncio.AbstractEventLoop,
697     executor: Executor,
698 ) -> None:
699     """Run formatting of `sources` in parallel using the provided `executor`.
700
701     (Use ProcessPoolExecutors for actual parallelism.)
702
703     `write_back`, `fast`, and `mode` options are passed to
704     :func:`format_file_in_place`.
705     """
706     cache: Cache = {}
707     if write_back != WriteBack.DIFF:
708         cache = read_cache(mode)
709         sources, cached = filter_cached(cache, sources)
710         for src in sorted(cached):
711             report.done(src, Changed.CACHED)
712     if not sources:
713         return
714
715     cancelled = []
716     sources_to_cache = []
717     lock = None
718     if write_back == WriteBack.DIFF:
719         # For diff output, we need locks to ensure we don't interleave output
720         # from different processes.
721         manager = Manager()
722         lock = manager.Lock()
723     tasks = {
724         asyncio.ensure_future(
725             loop.run_in_executor(
726                 executor, format_file_in_place, src, fast, mode, write_back, lock
727             )
728         ): src
729         for src in sorted(sources)
730     }
731     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
732     try:
733         loop.add_signal_handler(signal.SIGINT, cancel, pending)
734         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
735     except NotImplementedError:
736         # There are no good alternatives for these on Windows.
737         pass
738     while pending:
739         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
740         for task in done:
741             src = tasks.pop(task)
742             if task.cancelled():
743                 cancelled.append(task)
744             elif task.exception():
745                 report.failed(src, str(task.exception()))
746             else:
747                 changed = Changed.YES if task.result() else Changed.NO
748                 # If the file was written back or was successfully checked as
749                 # well-formatted, store this information in the cache.
750                 if write_back is WriteBack.YES or (
751                     write_back is WriteBack.CHECK and changed is Changed.NO
752                 ):
753                     sources_to_cache.append(src)
754                 report.done(src, changed)
755     if cancelled:
756         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
757     if sources_to_cache:
758         write_cache(cache, sources_to_cache, mode)
759
760
761 def format_file_in_place(
762     src: Path,
763     fast: bool,
764     mode: Mode,
765     write_back: WriteBack = WriteBack.NO,
766     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
767 ) -> bool:
768     """Format file under `src` path. Return True if changed.
769
770     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
771     code to the file.
772     `mode` and `fast` options are passed to :func:`format_file_contents`.
773     """
774     if src.suffix == ".pyi":
775         mode = replace(mode, is_pyi=True)
776
777     then = datetime.utcfromtimestamp(src.stat().st_mtime)
778     with open(src, "rb") as buf:
779         src_contents, encoding, newline = decode_bytes(buf.read())
780     try:
781         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
782     except NothingChanged:
783         return False
784
785     if write_back == WriteBack.YES:
786         with open(src, "w", encoding=encoding, newline=newline) as f:
787             f.write(dst_contents)
788     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
789         now = datetime.utcnow()
790         src_name = f"{src}\t{then} +0000"
791         dst_name = f"{src}\t{now} +0000"
792         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
793
794         if write_back == write_back.COLOR_DIFF:
795             diff_contents = color_diff(diff_contents)
796
797         with lock or nullcontext():
798             f = io.TextIOWrapper(
799                 sys.stdout.buffer,
800                 encoding=encoding,
801                 newline=newline,
802                 write_through=True,
803             )
804             f = wrap_stream_for_windows(f)
805             f.write(diff_contents)
806             f.detach()
807
808     return True
809
810
811 def color_diff(contents: str) -> str:
812     """Inject the ANSI color codes to the diff."""
813     lines = contents.split("\n")
814     for i, line in enumerate(lines):
815         if line.startswith("+++") or line.startswith("---"):
816             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
817         if line.startswith("@@"):
818             line = "\033[36m" + line + "\033[0m"  # cyan, reset
819         if line.startswith("+"):
820             line = "\033[32m" + line + "\033[0m"  # green, reset
821         elif line.startswith("-"):
822             line = "\033[31m" + line + "\033[0m"  # red, reset
823         lines[i] = line
824     return "\n".join(lines)
825
826
827 def wrap_stream_for_windows(
828     f: io.TextIOWrapper,
829 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
830     """
831     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
832
833     If `colorama` is not found, then no change is made. If `colorama` does
834     exist, then it handles the logic to determine whether or not to change
835     things.
836     """
837     try:
838         from colorama import initialise
839
840         # We set `strip=False` so that we can don't have to modify
841         # test_express_diff_with_color.
842         f = initialise.wrap_stream(
843             f, convert=None, strip=False, autoreset=False, wrap=True
844         )
845
846         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
847         # which does not have a `detach()` method. So we fake one.
848         f.detach = lambda *args, **kwargs: None  # type: ignore
849     except ImportError:
850         pass
851
852     return f
853
854
855 def format_stdin_to_stdout(
856     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
857 ) -> bool:
858     """Format file on stdin. Return True if changed.
859
860     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
861     write a diff to stdout. The `mode` argument is passed to
862     :func:`format_file_contents`.
863     """
864     then = datetime.utcnow()
865     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
866     dst = src
867     try:
868         dst = format_file_contents(src, fast=fast, mode=mode)
869         return True
870
871     except NothingChanged:
872         return False
873
874     finally:
875         f = io.TextIOWrapper(
876             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
877         )
878         if write_back == WriteBack.YES:
879             f.write(dst)
880         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
881             now = datetime.utcnow()
882             src_name = f"STDIN\t{then} +0000"
883             dst_name = f"STDOUT\t{now} +0000"
884             d = diff(src, dst, src_name, dst_name)
885             if write_back == WriteBack.COLOR_DIFF:
886                 d = color_diff(d)
887                 f = wrap_stream_for_windows(f)
888             f.write(d)
889         f.detach()
890
891
892 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
893     """Reformat contents a file and return new contents.
894
895     If `fast` is False, additionally confirm that the reformatted code is
896     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
897     `mode` is passed to :func:`format_str`.
898     """
899     if src_contents.strip() == "":
900         raise NothingChanged
901
902     dst_contents = format_str(src_contents, mode=mode)
903     if src_contents == dst_contents:
904         raise NothingChanged
905
906     if not fast:
907         assert_equivalent(src_contents, dst_contents)
908         assert_stable(src_contents, dst_contents, mode=mode)
909     return dst_contents
910
911
912 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
913     """Reformat a string and return new contents.
914
915     `mode` determines formatting options, such as how many characters per line are
916     allowed.  Example:
917
918     >>> import black
919     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
920     def f(arg: str = "") -> None:
921         ...
922
923     A more complex example:
924     >>> print(
925     ...   black.format_str(
926     ...     "def f(arg:str='')->None: hey",
927     ...     mode=black.Mode(
928     ...       target_versions={black.TargetVersion.PY36},
929     ...       line_length=10,
930     ...       string_normalization=False,
931     ...       is_pyi=False,
932     ...     ),
933     ...   ),
934     ... )
935     def f(
936         arg: str = '',
937     ) -> None:
938         hey
939
940     """
941     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
942     dst_contents = []
943     future_imports = get_future_imports(src_node)
944     if mode.target_versions:
945         versions = mode.target_versions
946     else:
947         versions = detect_target_versions(src_node)
948     normalize_fmt_off(src_node)
949     lines = LineGenerator(
950         remove_u_prefix="unicode_literals" in future_imports
951         or supports_feature(versions, Feature.UNICODE_LITERALS),
952         is_pyi=mode.is_pyi,
953         normalize_strings=mode.string_normalization,
954     )
955     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
956     empty_line = Line()
957     after = 0
958     split_line_features = {
959         feature
960         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
961         if supports_feature(versions, feature)
962     }
963     for current_line in lines.visit(src_node):
964         dst_contents.append(str(empty_line) * after)
965         before, after = elt.maybe_empty_lines(current_line)
966         dst_contents.append(str(empty_line) * before)
967         for line in transform_line(
968             current_line,
969             line_length=mode.line_length,
970             normalize_strings=mode.string_normalization,
971             features=split_line_features,
972         ):
973             dst_contents.append(str(line))
974     return "".join(dst_contents)
975
976
977 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
978     """Return a tuple of (decoded_contents, encoding, newline).
979
980     `newline` is either CRLF or LF but `decoded_contents` is decoded with
981     universal newlines (i.e. only contains LF).
982     """
983     srcbuf = io.BytesIO(src)
984     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
985     if not lines:
986         return "", encoding, "\n"
987
988     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
989     srcbuf.seek(0)
990     with io.TextIOWrapper(srcbuf, encoding) as tiow:
991         return tiow.read(), encoding, newline
992
993
994 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
995     if not target_versions:
996         # No target_version specified, so try all grammars.
997         return [
998             # Python 3.7+
999             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1000             # Python 3.0-3.6
1001             pygram.python_grammar_no_print_statement_no_exec_statement,
1002             # Python 2.7 with future print_function import
1003             pygram.python_grammar_no_print_statement,
1004             # Python 2.7
1005             pygram.python_grammar,
1006         ]
1007
1008     if all(version.is_python2() for version in target_versions):
1009         # Python 2-only code, so try Python 2 grammars.
1010         return [
1011             # Python 2.7 with future print_function import
1012             pygram.python_grammar_no_print_statement,
1013             # Python 2.7
1014             pygram.python_grammar,
1015         ]
1016
1017     # Python 3-compatible code, so only try Python 3 grammar.
1018     grammars = []
1019     # If we have to parse both, try to parse async as a keyword first
1020     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1021         # Python 3.7+
1022         grammars.append(
1023             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1024         )
1025     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1026         # Python 3.0-3.6
1027         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1028     # At least one of the above branches must have been taken, because every Python
1029     # version has exactly one of the two 'ASYNC_*' flags
1030     return grammars
1031
1032
1033 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1034     """Given a string with source, return the lib2to3 Node."""
1035     if src_txt[-1:] != "\n":
1036         src_txt += "\n"
1037
1038     for grammar in get_grammars(set(target_versions)):
1039         drv = driver.Driver(grammar, pytree.convert)
1040         try:
1041             result = drv.parse_string(src_txt, True)
1042             break
1043
1044         except ParseError as pe:
1045             lineno, column = pe.context[1]
1046             lines = src_txt.splitlines()
1047             try:
1048                 faulty_line = lines[lineno - 1]
1049             except IndexError:
1050                 faulty_line = "<line number missing in source>"
1051             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1052     else:
1053         raise exc from None
1054
1055     if isinstance(result, Leaf):
1056         result = Node(syms.file_input, [result])
1057     return result
1058
1059
1060 def lib2to3_unparse(node: Node) -> str:
1061     """Given a lib2to3 node, return its string representation."""
1062     code = str(node)
1063     return code
1064
1065
1066 class Visitor(Generic[T]):
1067     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1068
1069     def visit(self, node: LN) -> Iterator[T]:
1070         """Main method to visit `node` and its children.
1071
1072         It tries to find a `visit_*()` method for the given `node.type`, like
1073         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1074         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1075         instead.
1076
1077         Then yields objects of type `T` from the selected visitor.
1078         """
1079         if node.type < 256:
1080             name = token.tok_name[node.type]
1081         else:
1082             name = str(type_repr(node.type))
1083         # We explicitly branch on whether a visitor exists (instead of
1084         # using self.visit_default as the default arg to getattr) in order
1085         # to save needing to create a bound method object and so mypyc can
1086         # generate a native call to visit_default.
1087         visitf = getattr(self, f"visit_{name}", None)
1088         if visitf:
1089             yield from visitf(node)
1090         else:
1091             yield from self.visit_default(node)
1092
1093     def visit_default(self, node: LN) -> Iterator[T]:
1094         """Default `visit_*()` implementation. Recurses to children of `node`."""
1095         if isinstance(node, Node):
1096             for child in node.children:
1097                 yield from self.visit(child)
1098
1099
1100 @dataclass
1101 class DebugVisitor(Visitor[T]):
1102     tree_depth: int = 0
1103
1104     def visit_default(self, node: LN) -> Iterator[T]:
1105         indent = " " * (2 * self.tree_depth)
1106         if isinstance(node, Node):
1107             _type = type_repr(node.type)
1108             out(f"{indent}{_type}", fg="yellow")
1109             self.tree_depth += 1
1110             for child in node.children:
1111                 yield from self.visit(child)
1112
1113             self.tree_depth -= 1
1114             out(f"{indent}/{_type}", fg="yellow", bold=False)
1115         else:
1116             _type = token.tok_name.get(node.type, str(node.type))
1117             out(f"{indent}{_type}", fg="blue", nl=False)
1118             if node.prefix:
1119                 # We don't have to handle prefixes for `Node` objects since
1120                 # that delegates to the first child anyway.
1121                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1122             out(f" {node.value!r}", fg="blue", bold=False)
1123
1124     @classmethod
1125     def show(cls, code: Union[str, Leaf, Node]) -> None:
1126         """Pretty-print the lib2to3 AST of a given string of `code`.
1127
1128         Convenience method for debugging.
1129         """
1130         v: DebugVisitor[None] = DebugVisitor()
1131         if isinstance(code, str):
1132             code = lib2to3_parse(code)
1133         list(v.visit(code))
1134
1135
1136 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1137 STATEMENT: Final = {
1138     syms.if_stmt,
1139     syms.while_stmt,
1140     syms.for_stmt,
1141     syms.try_stmt,
1142     syms.except_clause,
1143     syms.with_stmt,
1144     syms.funcdef,
1145     syms.classdef,
1146 }
1147 STANDALONE_COMMENT: Final = 153
1148 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1149 LOGIC_OPERATORS: Final = {"and", "or"}
1150 COMPARATORS: Final = {
1151     token.LESS,
1152     token.GREATER,
1153     token.EQEQUAL,
1154     token.NOTEQUAL,
1155     token.LESSEQUAL,
1156     token.GREATEREQUAL,
1157 }
1158 MATH_OPERATORS: Final = {
1159     token.VBAR,
1160     token.CIRCUMFLEX,
1161     token.AMPER,
1162     token.LEFTSHIFT,
1163     token.RIGHTSHIFT,
1164     token.PLUS,
1165     token.MINUS,
1166     token.STAR,
1167     token.SLASH,
1168     token.DOUBLESLASH,
1169     token.PERCENT,
1170     token.AT,
1171     token.TILDE,
1172     token.DOUBLESTAR,
1173 }
1174 STARS: Final = {token.STAR, token.DOUBLESTAR}
1175 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1176 VARARGS_PARENTS: Final = {
1177     syms.arglist,
1178     syms.argument,  # double star in arglist
1179     syms.trailer,  # single argument to call
1180     syms.typedargslist,
1181     syms.varargslist,  # lambdas
1182 }
1183 UNPACKING_PARENTS: Final = {
1184     syms.atom,  # single element of a list or set literal
1185     syms.dictsetmaker,
1186     syms.listmaker,
1187     syms.testlist_gexp,
1188     syms.testlist_star_expr,
1189 }
1190 TEST_DESCENDANTS: Final = {
1191     syms.test,
1192     syms.lambdef,
1193     syms.or_test,
1194     syms.and_test,
1195     syms.not_test,
1196     syms.comparison,
1197     syms.star_expr,
1198     syms.expr,
1199     syms.xor_expr,
1200     syms.and_expr,
1201     syms.shift_expr,
1202     syms.arith_expr,
1203     syms.trailer,
1204     syms.term,
1205     syms.power,
1206 }
1207 ASSIGNMENTS: Final = {
1208     "=",
1209     "+=",
1210     "-=",
1211     "*=",
1212     "@=",
1213     "/=",
1214     "%=",
1215     "&=",
1216     "|=",
1217     "^=",
1218     "<<=",
1219     ">>=",
1220     "**=",
1221     "//=",
1222 }
1223 COMPREHENSION_PRIORITY: Final = 20
1224 COMMA_PRIORITY: Final = 18
1225 TERNARY_PRIORITY: Final = 16
1226 LOGIC_PRIORITY: Final = 14
1227 STRING_PRIORITY: Final = 12
1228 COMPARATOR_PRIORITY: Final = 10
1229 MATH_PRIORITIES: Final = {
1230     token.VBAR: 9,
1231     token.CIRCUMFLEX: 8,
1232     token.AMPER: 7,
1233     token.LEFTSHIFT: 6,
1234     token.RIGHTSHIFT: 6,
1235     token.PLUS: 5,
1236     token.MINUS: 5,
1237     token.STAR: 4,
1238     token.SLASH: 4,
1239     token.DOUBLESLASH: 4,
1240     token.PERCENT: 4,
1241     token.AT: 4,
1242     token.TILDE: 3,
1243     token.DOUBLESTAR: 2,
1244 }
1245 DOT_PRIORITY: Final = 1
1246
1247
1248 @dataclass
1249 class BracketTracker:
1250     """Keeps track of brackets on a line."""
1251
1252     depth: int = 0
1253     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1254     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1255     previous: Optional[Leaf] = None
1256     _for_loop_depths: List[int] = field(default_factory=list)
1257     _lambda_argument_depths: List[int] = field(default_factory=list)
1258
1259     def mark(self, leaf: Leaf) -> None:
1260         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1261
1262         All leaves receive an int `bracket_depth` field that stores how deep
1263         within brackets a given leaf is. 0 means there are no enclosing brackets
1264         that started on this line.
1265
1266         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1267         field that it forms a pair with. This is a one-directional link to
1268         avoid reference cycles.
1269
1270         If a leaf is a delimiter (a token on which Black can split the line if
1271         needed) and it's on depth 0, its `id()` is stored in the tracker's
1272         `delimiters` field.
1273         """
1274         if leaf.type == token.COMMENT:
1275             return
1276
1277         self.maybe_decrement_after_for_loop_variable(leaf)
1278         self.maybe_decrement_after_lambda_arguments(leaf)
1279         if leaf.type in CLOSING_BRACKETS:
1280             self.depth -= 1
1281             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1282             leaf.opening_bracket = opening_bracket
1283         leaf.bracket_depth = self.depth
1284         if self.depth == 0:
1285             delim = is_split_before_delimiter(leaf, self.previous)
1286             if delim and self.previous is not None:
1287                 self.delimiters[id(self.previous)] = delim
1288             else:
1289                 delim = is_split_after_delimiter(leaf, self.previous)
1290                 if delim:
1291                     self.delimiters[id(leaf)] = delim
1292         if leaf.type in OPENING_BRACKETS:
1293             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1294             self.depth += 1
1295         self.previous = leaf
1296         self.maybe_increment_lambda_arguments(leaf)
1297         self.maybe_increment_for_loop_variable(leaf)
1298
1299     def any_open_brackets(self) -> bool:
1300         """Return True if there is an yet unmatched open bracket on the line."""
1301         return bool(self.bracket_match)
1302
1303     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1304         """Return the highest priority of a delimiter found on the line.
1305
1306         Values are consistent with what `is_split_*_delimiter()` return.
1307         Raises ValueError on no delimiters.
1308         """
1309         return max(v for k, v in self.delimiters.items() if k not in exclude)
1310
1311     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1312         """Return the number of delimiters with the given `priority`.
1313
1314         If no `priority` is passed, defaults to max priority on the line.
1315         """
1316         if not self.delimiters:
1317             return 0
1318
1319         priority = priority or self.max_delimiter_priority()
1320         return sum(1 for p in self.delimiters.values() if p == priority)
1321
1322     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1323         """In a for loop, or comprehension, the variables are often unpacks.
1324
1325         To avoid splitting on the comma in this situation, increase the depth of
1326         tokens between `for` and `in`.
1327         """
1328         if leaf.type == token.NAME and leaf.value == "for":
1329             self.depth += 1
1330             self._for_loop_depths.append(self.depth)
1331             return True
1332
1333         return False
1334
1335     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1336         """See `maybe_increment_for_loop_variable` above for explanation."""
1337         if (
1338             self._for_loop_depths
1339             and self._for_loop_depths[-1] == self.depth
1340             and leaf.type == token.NAME
1341             and leaf.value == "in"
1342         ):
1343             self.depth -= 1
1344             self._for_loop_depths.pop()
1345             return True
1346
1347         return False
1348
1349     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1350         """In a lambda expression, there might be more than one argument.
1351
1352         To avoid splitting on the comma in this situation, increase the depth of
1353         tokens between `lambda` and `:`.
1354         """
1355         if leaf.type == token.NAME and leaf.value == "lambda":
1356             self.depth += 1
1357             self._lambda_argument_depths.append(self.depth)
1358             return True
1359
1360         return False
1361
1362     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1363         """See `maybe_increment_lambda_arguments` above for explanation."""
1364         if (
1365             self._lambda_argument_depths
1366             and self._lambda_argument_depths[-1] == self.depth
1367             and leaf.type == token.COLON
1368         ):
1369             self.depth -= 1
1370             self._lambda_argument_depths.pop()
1371             return True
1372
1373         return False
1374
1375     def get_open_lsqb(self) -> Optional[Leaf]:
1376         """Return the most recent opening square bracket (if any)."""
1377         return self.bracket_match.get((self.depth - 1, token.RSQB))
1378
1379
1380 @dataclass
1381 class Line:
1382     """Holds leaves and comments. Can be printed with `str(line)`."""
1383
1384     depth: int = 0
1385     leaves: List[Leaf] = field(default_factory=list)
1386     # keys ordered like `leaves`
1387     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1388     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1389     inside_brackets: bool = False
1390     should_explode: bool = False
1391
1392     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1393         """Add a new `leaf` to the end of the line.
1394
1395         Unless `preformatted` is True, the `leaf` will receive a new consistent
1396         whitespace prefix and metadata applied by :class:`BracketTracker`.
1397         Trailing commas are maybe removed, unpacked for loop variables are
1398         demoted from being delimiters.
1399
1400         Inline comments are put aside.
1401         """
1402         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1403         if not has_value:
1404             return
1405
1406         if token.COLON == leaf.type and self.is_class_paren_empty:
1407             del self.leaves[-2:]
1408         if self.leaves and not preformatted:
1409             # Note: at this point leaf.prefix should be empty except for
1410             # imports, for which we only preserve newlines.
1411             leaf.prefix += whitespace(
1412                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1413             )
1414         if self.inside_brackets or not preformatted:
1415             self.bracket_tracker.mark(leaf)
1416             self.maybe_remove_trailing_comma(leaf)
1417         if not self.append_comment(leaf):
1418             self.leaves.append(leaf)
1419
1420     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1421         """Like :func:`append()` but disallow invalid standalone comment structure.
1422
1423         Raises ValueError when any `leaf` is appended after a standalone comment
1424         or when a standalone comment is not the first leaf on the line.
1425         """
1426         if self.bracket_tracker.depth == 0:
1427             if self.is_comment:
1428                 raise ValueError("cannot append to standalone comments")
1429
1430             if self.leaves and leaf.type == STANDALONE_COMMENT:
1431                 raise ValueError(
1432                     "cannot append standalone comments to a populated line"
1433                 )
1434
1435         self.append(leaf, preformatted=preformatted)
1436
1437     @property
1438     def is_comment(self) -> bool:
1439         """Is this line a standalone comment?"""
1440         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1441
1442     @property
1443     def is_decorator(self) -> bool:
1444         """Is this line a decorator?"""
1445         return bool(self) and self.leaves[0].type == token.AT
1446
1447     @property
1448     def is_import(self) -> bool:
1449         """Is this an import line?"""
1450         return bool(self) and is_import(self.leaves[0])
1451
1452     @property
1453     def is_class(self) -> bool:
1454         """Is this line a class definition?"""
1455         return (
1456             bool(self)
1457             and self.leaves[0].type == token.NAME
1458             and self.leaves[0].value == "class"
1459         )
1460
1461     @property
1462     def is_stub_class(self) -> bool:
1463         """Is this line a class definition with a body consisting only of "..."?"""
1464         return self.is_class and self.leaves[-3:] == [
1465             Leaf(token.DOT, ".") for _ in range(3)
1466         ]
1467
1468     @property
1469     def is_collection_with_optional_trailing_comma(self) -> bool:
1470         """Is this line a collection literal with a trailing comma that's optional?
1471
1472         Note that the trailing comma in a 1-tuple is not optional.
1473         """
1474         if not self.leaves or len(self.leaves) < 4:
1475             return False
1476
1477         # Look for and address a trailing colon.
1478         if self.leaves[-1].type == token.COLON:
1479             closer = self.leaves[-2]
1480             close_index = -2
1481         else:
1482             closer = self.leaves[-1]
1483             close_index = -1
1484         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1485             return False
1486
1487         if closer.type == token.RPAR:
1488             # Tuples require an extra check, because if there's only
1489             # one element in the tuple removing the comma unmakes the
1490             # tuple.
1491             #
1492             # We also check for parens before looking for the trailing
1493             # comma because in some cases (eg assigning a dict
1494             # literal) the literal gets wrapped in temporary parens
1495             # during parsing. This case is covered by the
1496             # collections.py test data.
1497             opener = closer.opening_bracket
1498             for _open_index, leaf in enumerate(self.leaves):
1499                 if leaf is opener:
1500                     break
1501
1502             else:
1503                 # Couldn't find the matching opening paren, play it safe.
1504                 return False
1505
1506             commas = 0
1507             comma_depth = self.leaves[close_index - 1].bracket_depth
1508             for leaf in self.leaves[_open_index + 1 : close_index]:
1509                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1510                     commas += 1
1511             if commas > 1:
1512                 # We haven't looked yet for the trailing comma because
1513                 # we might also have caught noop parens.
1514                 return self.leaves[close_index - 1].type == token.COMMA
1515
1516             elif commas == 1:
1517                 return False  # it's either a one-tuple or didn't have a trailing comma
1518
1519             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1520                 close_index -= 1
1521                 closer = self.leaves[close_index]
1522                 if closer.type == token.RPAR:
1523                     # TODO: this is a gut feeling. Will we ever see this?
1524                     return False
1525
1526         if self.leaves[close_index - 1].type != token.COMMA:
1527             return False
1528
1529         return True
1530
1531     @property
1532     def is_def(self) -> bool:
1533         """Is this a function definition? (Also returns True for async defs.)"""
1534         try:
1535             first_leaf = self.leaves[0]
1536         except IndexError:
1537             return False
1538
1539         try:
1540             second_leaf: Optional[Leaf] = self.leaves[1]
1541         except IndexError:
1542             second_leaf = None
1543         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1544             first_leaf.type == token.ASYNC
1545             and second_leaf is not None
1546             and second_leaf.type == token.NAME
1547             and second_leaf.value == "def"
1548         )
1549
1550     @property
1551     def is_class_paren_empty(self) -> bool:
1552         """Is this a class with no base classes but using parentheses?
1553
1554         Those are unnecessary and should be removed.
1555         """
1556         return (
1557             bool(self)
1558             and len(self.leaves) == 4
1559             and self.is_class
1560             and self.leaves[2].type == token.LPAR
1561             and self.leaves[2].value == "("
1562             and self.leaves[3].type == token.RPAR
1563             and self.leaves[3].value == ")"
1564         )
1565
1566     @property
1567     def is_triple_quoted_string(self) -> bool:
1568         """Is the line a triple quoted string?"""
1569         return (
1570             bool(self)
1571             and self.leaves[0].type == token.STRING
1572             and self.leaves[0].value.startswith(('"""', "'''"))
1573         )
1574
1575     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1576         """If so, needs to be split before emitting."""
1577         for leaf in self.leaves:
1578             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1579                 return True
1580
1581         return False
1582
1583     def contains_uncollapsable_type_comments(self) -> bool:
1584         ignored_ids = set()
1585         try:
1586             last_leaf = self.leaves[-1]
1587             ignored_ids.add(id(last_leaf))
1588             if last_leaf.type == token.COMMA or (
1589                 last_leaf.type == token.RPAR and not last_leaf.value
1590             ):
1591                 # When trailing commas or optional parens are inserted by Black for
1592                 # consistency, comments after the previous last element are not moved
1593                 # (they don't have to, rendering will still be correct).  So we ignore
1594                 # trailing commas and invisible.
1595                 last_leaf = self.leaves[-2]
1596                 ignored_ids.add(id(last_leaf))
1597         except IndexError:
1598             return False
1599
1600         # A type comment is uncollapsable if it is attached to a leaf
1601         # that isn't at the end of the line (since that could cause it
1602         # to get associated to a different argument) or if there are
1603         # comments before it (since that could cause it to get hidden
1604         # behind a comment.
1605         comment_seen = False
1606         for leaf_id, comments in self.comments.items():
1607             for comment in comments:
1608                 if is_type_comment(comment):
1609                     if comment_seen or (
1610                         not is_type_comment(comment, " ignore")
1611                         and leaf_id not in ignored_ids
1612                     ):
1613                         return True
1614
1615                 comment_seen = True
1616
1617         return False
1618
1619     def contains_unsplittable_type_ignore(self) -> bool:
1620         if not self.leaves:
1621             return False
1622
1623         # If a 'type: ignore' is attached to the end of a line, we
1624         # can't split the line, because we can't know which of the
1625         # subexpressions the ignore was meant to apply to.
1626         #
1627         # We only want this to apply to actual physical lines from the
1628         # original source, though: we don't want the presence of a
1629         # 'type: ignore' at the end of a multiline expression to
1630         # justify pushing it all onto one line. Thus we
1631         # (unfortunately) need to check the actual source lines and
1632         # only report an unsplittable 'type: ignore' if this line was
1633         # one line in the original code.
1634
1635         # Grab the first and last line numbers, skipping generated leaves
1636         first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0)
1637         last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0)
1638
1639         if first_line == last_line:
1640             # We look at the last two leaves since a comma or an
1641             # invisible paren could have been added at the end of the
1642             # line.
1643             for node in self.leaves[-2:]:
1644                 for comment in self.comments.get(id(node), []):
1645                     if is_type_comment(comment, " ignore"):
1646                         return True
1647
1648         return False
1649
1650     def contains_multiline_strings(self) -> bool:
1651         return any(is_multiline_string(leaf) for leaf in self.leaves)
1652
1653     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1654         """Remove trailing comma if there is one and it's safe."""
1655         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1656             return False
1657
1658         # We remove trailing commas only in the case of importing a
1659         # single name from a module.
1660         if not (
1661             self.leaves
1662             and self.is_import
1663             and len(self.leaves) > 4
1664             and self.leaves[-1].type == token.COMMA
1665             and closing.type in CLOSING_BRACKETS
1666             and self.leaves[-4].type == token.NAME
1667             and (
1668                 # regular `from foo import bar,`
1669                 self.leaves[-4].value == "import"
1670                 # `from foo import (bar as baz,)
1671                 or (
1672                     len(self.leaves) > 6
1673                     and self.leaves[-6].value == "import"
1674                     and self.leaves[-3].value == "as"
1675                 )
1676                 # `from foo import bar as baz,`
1677                 or (
1678                     len(self.leaves) > 5
1679                     and self.leaves[-5].value == "import"
1680                     and self.leaves[-3].value == "as"
1681                 )
1682             )
1683             and closing.type == token.RPAR
1684         ):
1685             return False
1686
1687         self.remove_trailing_comma()
1688         return True
1689
1690     def append_comment(self, comment: Leaf) -> bool:
1691         """Add an inline or standalone comment to the line."""
1692         if (
1693             comment.type == STANDALONE_COMMENT
1694             and self.bracket_tracker.any_open_brackets()
1695         ):
1696             comment.prefix = ""
1697             return False
1698
1699         if comment.type != token.COMMENT:
1700             return False
1701
1702         if not self.leaves:
1703             comment.type = STANDALONE_COMMENT
1704             comment.prefix = ""
1705             return False
1706
1707         last_leaf = self.leaves[-1]
1708         if (
1709             last_leaf.type == token.RPAR
1710             and not last_leaf.value
1711             and last_leaf.parent
1712             and len(list(last_leaf.parent.leaves())) <= 3
1713             and not is_type_comment(comment)
1714         ):
1715             # Comments on an optional parens wrapping a single leaf should belong to
1716             # the wrapped node except if it's a type comment. Pinning the comment like
1717             # this avoids unstable formatting caused by comment migration.
1718             if len(self.leaves) < 2:
1719                 comment.type = STANDALONE_COMMENT
1720                 comment.prefix = ""
1721                 return False
1722
1723             last_leaf = self.leaves[-2]
1724         self.comments.setdefault(id(last_leaf), []).append(comment)
1725         return True
1726
1727     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1728         """Generate comments that should appear directly after `leaf`."""
1729         return self.comments.get(id(leaf), [])
1730
1731     def remove_trailing_comma(self) -> None:
1732         """Remove the trailing comma and moves the comments attached to it."""
1733         trailing_comma = self.leaves.pop()
1734         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1735         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1736             trailing_comma_comments
1737         )
1738
1739     def is_complex_subscript(self, leaf: Leaf) -> bool:
1740         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1741         open_lsqb = self.bracket_tracker.get_open_lsqb()
1742         if open_lsqb is None:
1743             return False
1744
1745         subscript_start = open_lsqb.next_sibling
1746
1747         if isinstance(subscript_start, Node):
1748             if subscript_start.type == syms.listmaker:
1749                 return False
1750
1751             if subscript_start.type == syms.subscriptlist:
1752                 subscript_start = child_towards(subscript_start, leaf)
1753         return subscript_start is not None and any(
1754             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1755         )
1756
1757     def clone(self) -> "Line":
1758         return Line(
1759             depth=self.depth,
1760             inside_brackets=self.inside_brackets,
1761             should_explode=self.should_explode,
1762         )
1763
1764     def __str__(self) -> str:
1765         """Render the line."""
1766         if not self:
1767             return "\n"
1768
1769         indent = "    " * self.depth
1770         leaves = iter(self.leaves)
1771         first = next(leaves)
1772         res = f"{first.prefix}{indent}{first.value}"
1773         for leaf in leaves:
1774             res += str(leaf)
1775         for comment in itertools.chain.from_iterable(self.comments.values()):
1776             res += str(comment)
1777
1778         return res + "\n"
1779
1780     def __bool__(self) -> bool:
1781         """Return True if the line has leaves or comments."""
1782         return bool(self.leaves or self.comments)
1783
1784
1785 @dataclass
1786 class EmptyLineTracker:
1787     """Provides a stateful method that returns the number of potential extra
1788     empty lines needed before and after the currently processed line.
1789
1790     Note: this tracker works on lines that haven't been split yet.  It assumes
1791     the prefix of the first leaf consists of optional newlines.  Those newlines
1792     are consumed by `maybe_empty_lines()` and included in the computation.
1793     """
1794
1795     is_pyi: bool = False
1796     previous_line: Optional[Line] = None
1797     previous_after: int = 0
1798     previous_defs: List[int] = field(default_factory=list)
1799
1800     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1801         """Return the number of extra empty lines before and after the `current_line`.
1802
1803         This is for separating `def`, `async def` and `class` with extra empty
1804         lines (two on module-level).
1805         """
1806         before, after = self._maybe_empty_lines(current_line)
1807         before = (
1808             # Black should not insert empty lines at the beginning
1809             # of the file
1810             0
1811             if self.previous_line is None
1812             else before - self.previous_after
1813         )
1814         self.previous_after = after
1815         self.previous_line = current_line
1816         return before, after
1817
1818     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1819         max_allowed = 1
1820         if current_line.depth == 0:
1821             max_allowed = 1 if self.is_pyi else 2
1822         if current_line.leaves:
1823             # Consume the first leaf's extra newlines.
1824             first_leaf = current_line.leaves[0]
1825             before = first_leaf.prefix.count("\n")
1826             before = min(before, max_allowed)
1827             first_leaf.prefix = ""
1828         else:
1829             before = 0
1830         depth = current_line.depth
1831         while self.previous_defs and self.previous_defs[-1] >= depth:
1832             self.previous_defs.pop()
1833             if self.is_pyi:
1834                 before = 0 if depth else 1
1835             else:
1836                 before = 1 if depth else 2
1837         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1838             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1839
1840         if (
1841             self.previous_line
1842             and self.previous_line.is_import
1843             and not current_line.is_import
1844             and depth == self.previous_line.depth
1845         ):
1846             return (before or 1), 0
1847
1848         if (
1849             self.previous_line
1850             and self.previous_line.is_class
1851             and current_line.is_triple_quoted_string
1852         ):
1853             return before, 1
1854
1855         return before, 0
1856
1857     def _maybe_empty_lines_for_class_or_def(
1858         self, current_line: Line, before: int
1859     ) -> Tuple[int, int]:
1860         if not current_line.is_decorator:
1861             self.previous_defs.append(current_line.depth)
1862         if self.previous_line is None:
1863             # Don't insert empty lines before the first line in the file.
1864             return 0, 0
1865
1866         if self.previous_line.is_decorator:
1867             return 0, 0
1868
1869         if self.previous_line.depth < current_line.depth and (
1870             self.previous_line.is_class or self.previous_line.is_def
1871         ):
1872             return 0, 0
1873
1874         if (
1875             self.previous_line.is_comment
1876             and self.previous_line.depth == current_line.depth
1877             and before == 0
1878         ):
1879             return 0, 0
1880
1881         if self.is_pyi:
1882             if self.previous_line.depth > current_line.depth:
1883                 newlines = 1
1884             elif current_line.is_class or self.previous_line.is_class:
1885                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1886                     # No blank line between classes with an empty body
1887                     newlines = 0
1888                 else:
1889                     newlines = 1
1890             elif current_line.is_def and not self.previous_line.is_def:
1891                 # Blank line between a block of functions and a block of non-functions
1892                 newlines = 1
1893             else:
1894                 newlines = 0
1895         else:
1896             newlines = 2
1897         if current_line.depth and newlines:
1898             newlines -= 1
1899         return newlines, 0
1900
1901
1902 @dataclass
1903 class LineGenerator(Visitor[Line]):
1904     """Generates reformatted Line objects.  Empty lines are not emitted.
1905
1906     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1907     in ways that will no longer stringify to valid Python code on the tree.
1908     """
1909
1910     is_pyi: bool = False
1911     normalize_strings: bool = True
1912     current_line: Line = field(default_factory=Line)
1913     remove_u_prefix: bool = False
1914
1915     def line(self, indent: int = 0) -> Iterator[Line]:
1916         """Generate a line.
1917
1918         If the line is empty, only emit if it makes sense.
1919         If the line is too long, split it first and then generate.
1920
1921         If any lines were generated, set up a new current_line.
1922         """
1923         if not self.current_line:
1924             self.current_line.depth += indent
1925             return  # Line is empty, don't emit. Creating a new one unnecessary.
1926
1927         complete_line = self.current_line
1928         self.current_line = Line(depth=complete_line.depth + indent)
1929         yield complete_line
1930
1931     def visit_default(self, node: LN) -> Iterator[Line]:
1932         """Default `visit_*()` implementation. Recurses to children of `node`."""
1933         if isinstance(node, Leaf):
1934             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1935             for comment in generate_comments(node):
1936                 if any_open_brackets:
1937                     # any comment within brackets is subject to splitting
1938                     self.current_line.append(comment)
1939                 elif comment.type == token.COMMENT:
1940                     # regular trailing comment
1941                     self.current_line.append(comment)
1942                     yield from self.line()
1943
1944                 else:
1945                     # regular standalone comment
1946                     yield from self.line()
1947
1948                     self.current_line.append(comment)
1949                     yield from self.line()
1950
1951             normalize_prefix(node, inside_brackets=any_open_brackets)
1952             if self.normalize_strings and node.type == token.STRING:
1953                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1954                 normalize_string_quotes(node)
1955             if node.type == token.NUMBER:
1956                 normalize_numeric_literal(node)
1957             if node.type not in WHITESPACE:
1958                 self.current_line.append(node)
1959         yield from super().visit_default(node)
1960
1961     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1962         """Increase indentation level, maybe yield a line."""
1963         # In blib2to3 INDENT never holds comments.
1964         yield from self.line(+1)
1965         yield from self.visit_default(node)
1966
1967     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1968         """Decrease indentation level, maybe yield a line."""
1969         # The current line might still wait for trailing comments.  At DEDENT time
1970         # there won't be any (they would be prefixes on the preceding NEWLINE).
1971         # Emit the line then.
1972         yield from self.line()
1973
1974         # While DEDENT has no value, its prefix may contain standalone comments
1975         # that belong to the current indentation level.  Get 'em.
1976         yield from self.visit_default(node)
1977
1978         # Finally, emit the dedent.
1979         yield from self.line(-1)
1980
1981     def visit_stmt(
1982         self, node: Node, keywords: Set[str], parens: Set[str]
1983     ) -> Iterator[Line]:
1984         """Visit a statement.
1985
1986         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1987         `def`, `with`, `class`, `assert` and assignments.
1988
1989         The relevant Python language `keywords` for a given statement will be
1990         NAME leaves within it. This methods puts those on a separate line.
1991
1992         `parens` holds a set of string leaf values immediately after which
1993         invisible parens should be put.
1994         """
1995         normalize_invisible_parens(node, parens_after=parens)
1996         for child in node.children:
1997             if child.type == token.NAME and child.value in keywords:  # type: ignore
1998                 yield from self.line()
1999
2000             yield from self.visit(child)
2001
2002     def visit_suite(self, node: Node) -> Iterator[Line]:
2003         """Visit a suite."""
2004         if self.is_pyi and is_stub_suite(node):
2005             yield from self.visit(node.children[2])
2006         else:
2007             yield from self.visit_default(node)
2008
2009     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2010         """Visit a statement without nested statements."""
2011         is_suite_like = node.parent and node.parent.type in STATEMENT
2012         if is_suite_like:
2013             if self.is_pyi and is_stub_body(node):
2014                 yield from self.visit_default(node)
2015             else:
2016                 yield from self.line(+1)
2017                 yield from self.visit_default(node)
2018                 yield from self.line(-1)
2019
2020         else:
2021             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2022                 yield from self.line()
2023             yield from self.visit_default(node)
2024
2025     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2026         """Visit `async def`, `async for`, `async with`."""
2027         yield from self.line()
2028
2029         children = iter(node.children)
2030         for child in children:
2031             yield from self.visit(child)
2032
2033             if child.type == token.ASYNC:
2034                 break
2035
2036         internal_stmt = next(children)
2037         for child in internal_stmt.children:
2038             yield from self.visit(child)
2039
2040     def visit_decorators(self, node: Node) -> Iterator[Line]:
2041         """Visit decorators."""
2042         for child in node.children:
2043             yield from self.line()
2044             yield from self.visit(child)
2045
2046     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2047         """Remove a semicolon and put the other statement on a separate line."""
2048         yield from self.line()
2049
2050     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2051         """End of file. Process outstanding comments and end with a newline."""
2052         yield from self.visit_default(leaf)
2053         yield from self.line()
2054
2055     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2056         if not self.current_line.bracket_tracker.any_open_brackets():
2057             yield from self.line()
2058         yield from self.visit_default(leaf)
2059
2060     def visit_factor(self, node: Node) -> Iterator[Line]:
2061         """Force parentheses between a unary op and a binary power:
2062
2063         -2 ** 8 -> -(2 ** 8)
2064         """
2065         _operator, operand = node.children
2066         if (
2067             operand.type == syms.power
2068             and len(operand.children) == 3
2069             and operand.children[1].type == token.DOUBLESTAR
2070         ):
2071             lpar = Leaf(token.LPAR, "(")
2072             rpar = Leaf(token.RPAR, ")")
2073             index = operand.remove() or 0
2074             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2075         yield from self.visit_default(node)
2076
2077     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2078         # Check if it's a docstring
2079         if prev_siblings_are(
2080             leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
2081         ) and is_multiline_string(leaf):
2082             prefix = "    " * self.current_line.depth
2083             docstring = fix_docstring(leaf.value[3:-3], prefix)
2084             leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
2085             normalize_string_quotes(leaf)
2086
2087         yield from self.visit_default(leaf)
2088
2089     def __post_init__(self) -> None:
2090         """You are in a twisty little maze of passages."""
2091         v = self.visit_stmt
2092         Ø: Set[str] = set()
2093         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2094         self.visit_if_stmt = partial(
2095             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2096         )
2097         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2098         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2099         self.visit_try_stmt = partial(
2100             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2101         )
2102         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2103         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2104         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2105         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2106         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2107         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2108         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2109         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2110         self.visit_async_funcdef = self.visit_async_stmt
2111         self.visit_decorated = self.visit_decorators
2112
2113
2114 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2115 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2116 OPENING_BRACKETS = set(BRACKET.keys())
2117 CLOSING_BRACKETS = set(BRACKET.values())
2118 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2119 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2120
2121
2122 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2123     """Return whitespace prefix if needed for the given `leaf`.
2124
2125     `complex_subscript` signals whether the given leaf is part of a subscription
2126     which has non-trivial arguments, like arithmetic expressions or function calls.
2127     """
2128     NO = ""
2129     SPACE = " "
2130     DOUBLESPACE = "  "
2131     t = leaf.type
2132     p = leaf.parent
2133     v = leaf.value
2134     if t in ALWAYS_NO_SPACE:
2135         return NO
2136
2137     if t == token.COMMENT:
2138         return DOUBLESPACE
2139
2140     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2141     if t == token.COLON and p.type not in {
2142         syms.subscript,
2143         syms.subscriptlist,
2144         syms.sliceop,
2145     }:
2146         return NO
2147
2148     prev = leaf.prev_sibling
2149     if not prev:
2150         prevp = preceding_leaf(p)
2151         if not prevp or prevp.type in OPENING_BRACKETS:
2152             return NO
2153
2154         if t == token.COLON:
2155             if prevp.type == token.COLON:
2156                 return NO
2157
2158             elif prevp.type != token.COMMA and not complex_subscript:
2159                 return NO
2160
2161             return SPACE
2162
2163         if prevp.type == token.EQUAL:
2164             if prevp.parent:
2165                 if prevp.parent.type in {
2166                     syms.arglist,
2167                     syms.argument,
2168                     syms.parameters,
2169                     syms.varargslist,
2170                 }:
2171                     return NO
2172
2173                 elif prevp.parent.type == syms.typedargslist:
2174                     # A bit hacky: if the equal sign has whitespace, it means we
2175                     # previously found it's a typed argument.  So, we're using
2176                     # that, too.
2177                     return prevp.prefix
2178
2179         elif prevp.type in VARARGS_SPECIALS:
2180             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2181                 return NO
2182
2183         elif prevp.type == token.COLON:
2184             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2185                 return SPACE if complex_subscript else NO
2186
2187         elif (
2188             prevp.parent
2189             and prevp.parent.type == syms.factor
2190             and prevp.type in MATH_OPERATORS
2191         ):
2192             return NO
2193
2194         elif (
2195             prevp.type == token.RIGHTSHIFT
2196             and prevp.parent
2197             and prevp.parent.type == syms.shift_expr
2198             and prevp.prev_sibling
2199             and prevp.prev_sibling.type == token.NAME
2200             and prevp.prev_sibling.value == "print"  # type: ignore
2201         ):
2202             # Python 2 print chevron
2203             return NO
2204
2205     elif prev.type in OPENING_BRACKETS:
2206         return NO
2207
2208     if p.type in {syms.parameters, syms.arglist}:
2209         # untyped function signatures or calls
2210         if not prev or prev.type != token.COMMA:
2211             return NO
2212
2213     elif p.type == syms.varargslist:
2214         # lambdas
2215         if prev and prev.type != token.COMMA:
2216             return NO
2217
2218     elif p.type == syms.typedargslist:
2219         # typed function signatures
2220         if not prev:
2221             return NO
2222
2223         if t == token.EQUAL:
2224             if prev.type != syms.tname:
2225                 return NO
2226
2227         elif prev.type == token.EQUAL:
2228             # A bit hacky: if the equal sign has whitespace, it means we
2229             # previously found it's a typed argument.  So, we're using that, too.
2230             return prev.prefix
2231
2232         elif prev.type != token.COMMA:
2233             return NO
2234
2235     elif p.type == syms.tname:
2236         # type names
2237         if not prev:
2238             prevp = preceding_leaf(p)
2239             if not prevp or prevp.type != token.COMMA:
2240                 return NO
2241
2242     elif p.type == syms.trailer:
2243         # attributes and calls
2244         if t == token.LPAR or t == token.RPAR:
2245             return NO
2246
2247         if not prev:
2248             if t == token.DOT:
2249                 prevp = preceding_leaf(p)
2250                 if not prevp or prevp.type != token.NUMBER:
2251                     return NO
2252
2253             elif t == token.LSQB:
2254                 return NO
2255
2256         elif prev.type != token.COMMA:
2257             return NO
2258
2259     elif p.type == syms.argument:
2260         # single argument
2261         if t == token.EQUAL:
2262             return NO
2263
2264         if not prev:
2265             prevp = preceding_leaf(p)
2266             if not prevp or prevp.type == token.LPAR:
2267                 return NO
2268
2269         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2270             return NO
2271
2272     elif p.type == syms.decorator:
2273         # decorators
2274         return NO
2275
2276     elif p.type == syms.dotted_name:
2277         if prev:
2278             return NO
2279
2280         prevp = preceding_leaf(p)
2281         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2282             return NO
2283
2284     elif p.type == syms.classdef:
2285         if t == token.LPAR:
2286             return NO
2287
2288         if prev and prev.type == token.LPAR:
2289             return NO
2290
2291     elif p.type in {syms.subscript, syms.sliceop}:
2292         # indexing
2293         if not prev:
2294             assert p.parent is not None, "subscripts are always parented"
2295             if p.parent.type == syms.subscriptlist:
2296                 return SPACE
2297
2298             return NO
2299
2300         elif not complex_subscript:
2301             return NO
2302
2303     elif p.type == syms.atom:
2304         if prev and t == token.DOT:
2305             # dots, but not the first one.
2306             return NO
2307
2308     elif p.type == syms.dictsetmaker:
2309         # dict unpacking
2310         if prev and prev.type == token.DOUBLESTAR:
2311             return NO
2312
2313     elif p.type in {syms.factor, syms.star_expr}:
2314         # unary ops
2315         if not prev:
2316             prevp = preceding_leaf(p)
2317             if not prevp or prevp.type in OPENING_BRACKETS:
2318                 return NO
2319
2320             prevp_parent = prevp.parent
2321             assert prevp_parent is not None
2322             if prevp.type == token.COLON and prevp_parent.type in {
2323                 syms.subscript,
2324                 syms.sliceop,
2325             }:
2326                 return NO
2327
2328             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2329                 return NO
2330
2331         elif t in {token.NAME, token.NUMBER, token.STRING}:
2332             return NO
2333
2334     elif p.type == syms.import_from:
2335         if t == token.DOT:
2336             if prev and prev.type == token.DOT:
2337                 return NO
2338
2339         elif t == token.NAME:
2340             if v == "import":
2341                 return SPACE
2342
2343             if prev and prev.type == token.DOT:
2344                 return NO
2345
2346     elif p.type == syms.sliceop:
2347         return NO
2348
2349     return SPACE
2350
2351
2352 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2353     """Return the first leaf that precedes `node`, if any."""
2354     while node:
2355         res = node.prev_sibling
2356         if res:
2357             if isinstance(res, Leaf):
2358                 return res
2359
2360             try:
2361                 return list(res.leaves())[-1]
2362
2363             except IndexError:
2364                 return None
2365
2366         node = node.parent
2367     return None
2368
2369
2370 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2371     """Return if the `node` and its previous siblings match types against the provided
2372     list of tokens; the provided `node`has its type matched against the last element in
2373     the list.  `None` can be used as the first element to declare that the start of the
2374     list is anchored at the start of its parent's children."""
2375     if not tokens:
2376         return True
2377     if tokens[-1] is None:
2378         return node is None
2379     if not node:
2380         return False
2381     if node.type != tokens[-1]:
2382         return False
2383     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2384
2385
2386 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2387     """Return the child of `ancestor` that contains `descendant`."""
2388     node: Optional[LN] = descendant
2389     while node and node.parent != ancestor:
2390         node = node.parent
2391     return node
2392
2393
2394 def container_of(leaf: Leaf) -> LN:
2395     """Return `leaf` or one of its ancestors that is the topmost container of it.
2396
2397     By "container" we mean a node where `leaf` is the very first child.
2398     """
2399     same_prefix = leaf.prefix
2400     container: LN = leaf
2401     while container:
2402         parent = container.parent
2403         if parent is None:
2404             break
2405
2406         if parent.children[0].prefix != same_prefix:
2407             break
2408
2409         if parent.type == syms.file_input:
2410             break
2411
2412         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2413             break
2414
2415         container = parent
2416     return container
2417
2418
2419 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2420     """Return the priority of the `leaf` delimiter, given a line break after it.
2421
2422     The delimiter priorities returned here are from those delimiters that would
2423     cause a line break after themselves.
2424
2425     Higher numbers are higher priority.
2426     """
2427     if leaf.type == token.COMMA:
2428         return COMMA_PRIORITY
2429
2430     return 0
2431
2432
2433 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2434     """Return the priority of the `leaf` delimiter, given a line break before it.
2435
2436     The delimiter priorities returned here are from those delimiters that would
2437     cause a line break before themselves.
2438
2439     Higher numbers are higher priority.
2440     """
2441     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2442         # * and ** might also be MATH_OPERATORS but in this case they are not.
2443         # Don't treat them as a delimiter.
2444         return 0
2445
2446     if (
2447         leaf.type == token.DOT
2448         and leaf.parent
2449         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2450         and (previous is None or previous.type in CLOSING_BRACKETS)
2451     ):
2452         return DOT_PRIORITY
2453
2454     if (
2455         leaf.type in MATH_OPERATORS
2456         and leaf.parent
2457         and leaf.parent.type not in {syms.factor, syms.star_expr}
2458     ):
2459         return MATH_PRIORITIES[leaf.type]
2460
2461     if leaf.type in COMPARATORS:
2462         return COMPARATOR_PRIORITY
2463
2464     if (
2465         leaf.type == token.STRING
2466         and previous is not None
2467         and previous.type == token.STRING
2468     ):
2469         return STRING_PRIORITY
2470
2471     if leaf.type not in {token.NAME, token.ASYNC}:
2472         return 0
2473
2474     if (
2475         leaf.value == "for"
2476         and leaf.parent
2477         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2478         or leaf.type == token.ASYNC
2479     ):
2480         if (
2481             not isinstance(leaf.prev_sibling, Leaf)
2482             or leaf.prev_sibling.value != "async"
2483         ):
2484             return COMPREHENSION_PRIORITY
2485
2486     if (
2487         leaf.value == "if"
2488         and leaf.parent
2489         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2490     ):
2491         return COMPREHENSION_PRIORITY
2492
2493     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2494         return TERNARY_PRIORITY
2495
2496     if leaf.value == "is":
2497         return COMPARATOR_PRIORITY
2498
2499     if (
2500         leaf.value == "in"
2501         and leaf.parent
2502         and leaf.parent.type in {syms.comp_op, syms.comparison}
2503         and not (
2504             previous is not None
2505             and previous.type == token.NAME
2506             and previous.value == "not"
2507         )
2508     ):
2509         return COMPARATOR_PRIORITY
2510
2511     if (
2512         leaf.value == "not"
2513         and leaf.parent
2514         and leaf.parent.type == syms.comp_op
2515         and not (
2516             previous is not None
2517             and previous.type == token.NAME
2518             and previous.value == "is"
2519         )
2520     ):
2521         return COMPARATOR_PRIORITY
2522
2523     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2524         return LOGIC_PRIORITY
2525
2526     return 0
2527
2528
2529 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2530 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2531
2532
2533 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2534     """Clean the prefix of the `leaf` and generate comments from it, if any.
2535
2536     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2537     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2538     move because it does away with modifying the grammar to include all the
2539     possible places in which comments can be placed.
2540
2541     The sad consequence for us though is that comments don't "belong" anywhere.
2542     This is why this function generates simple parentless Leaf objects for
2543     comments.  We simply don't know what the correct parent should be.
2544
2545     No matter though, we can live without this.  We really only need to
2546     differentiate between inline and standalone comments.  The latter don't
2547     share the line with any code.
2548
2549     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2550     are emitted with a fake STANDALONE_COMMENT token identifier.
2551     """
2552     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2553         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2554
2555
2556 @dataclass
2557 class ProtoComment:
2558     """Describes a piece of syntax that is a comment.
2559
2560     It's not a :class:`blib2to3.pytree.Leaf` so that:
2561
2562     * it can be cached (`Leaf` objects should not be reused more than once as
2563       they store their lineno, column, prefix, and parent information);
2564     * `newlines` and `consumed` fields are kept separate from the `value`. This
2565       simplifies handling of special marker comments like ``# fmt: off/on``.
2566     """
2567
2568     type: int  # token.COMMENT or STANDALONE_COMMENT
2569     value: str  # content of the comment
2570     newlines: int  # how many newlines before the comment
2571     consumed: int  # how many characters of the original leaf's prefix did we consume
2572
2573
2574 @lru_cache(maxsize=4096)
2575 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2576     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2577     result: List[ProtoComment] = []
2578     if not prefix or "#" not in prefix:
2579         return result
2580
2581     consumed = 0
2582     nlines = 0
2583     ignored_lines = 0
2584     for index, line in enumerate(prefix.split("\n")):
2585         consumed += len(line) + 1  # adding the length of the split '\n'
2586         line = line.lstrip()
2587         if not line:
2588             nlines += 1
2589         if not line.startswith("#"):
2590             # Escaped newlines outside of a comment are not really newlines at
2591             # all. We treat a single-line comment following an escaped newline
2592             # as a simple trailing comment.
2593             if line.endswith("\\"):
2594                 ignored_lines += 1
2595             continue
2596
2597         if index == ignored_lines and not is_endmarker:
2598             comment_type = token.COMMENT  # simple trailing comment
2599         else:
2600             comment_type = STANDALONE_COMMENT
2601         comment = make_comment(line)
2602         result.append(
2603             ProtoComment(
2604                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2605             )
2606         )
2607         nlines = 0
2608     return result
2609
2610
2611 def make_comment(content: str) -> str:
2612     """Return a consistently formatted comment from the given `content` string.
2613
2614     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2615     space between the hash sign and the content.
2616
2617     If `content` didn't start with a hash sign, one is provided.
2618     """
2619     content = content.rstrip()
2620     if not content:
2621         return "#"
2622
2623     if content[0] == "#":
2624         content = content[1:]
2625     if content and content[0] not in " !:#'%":
2626         content = " " + content
2627     return "#" + content
2628
2629
2630 def transform_line(
2631     line: Line,
2632     line_length: int,
2633     normalize_strings: bool,
2634     features: Collection[Feature] = (),
2635 ) -> Iterator[Line]:
2636     """Transform a `line`, potentially splitting it into many lines.
2637
2638     They should fit in the allotted `line_length` but might not be able to.
2639
2640     `features` are syntactical features that may be used in the output.
2641     """
2642     if line.is_comment:
2643         yield line
2644         return
2645
2646     line_str = line_to_string(line)
2647
2648     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2649         """Initialize StringTransformer"""
2650         return ST(line_length, normalize_strings)
2651
2652     string_merge = init_st(StringMerger)
2653     string_paren_strip = init_st(StringParenStripper)
2654     string_split = init_st(StringSplitter)
2655     string_paren_wrap = init_st(StringParenWrapper)
2656
2657     transformers: List[Transformer]
2658     if (
2659         not line.contains_uncollapsable_type_comments()
2660         and not line.should_explode
2661         and not line.is_collection_with_optional_trailing_comma
2662         and (
2663             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2664             or line.contains_unsplittable_type_ignore()
2665         )
2666         and not (line.contains_standalone_comments() and line.inside_brackets)
2667     ):
2668         # Only apply basic string preprocessing, since lines shouldn't be split here.
2669         transformers = [string_merge, string_paren_strip]
2670     elif line.is_def:
2671         transformers = [left_hand_split]
2672     else:
2673
2674         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2675             for omit in generate_trailers_to_omit(line, line_length):
2676                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2677                 if is_line_short_enough(lines[0], line_length=line_length):
2678                     yield from lines
2679                     return
2680
2681             # All splits failed, best effort split with no omits.
2682             # This mostly happens to multiline strings that are by definition
2683             # reported as not fitting a single line.
2684             # line_length=1 here was historically a bug that somehow became a feature.
2685             # See #762 and #781 for the full story.
2686             yield from right_hand_split(line, line_length=1, features=features)
2687
2688         if line.inside_brackets:
2689             transformers = [
2690                 string_merge,
2691                 string_paren_strip,
2692                 delimiter_split,
2693                 standalone_comment_split,
2694                 string_split,
2695                 string_paren_wrap,
2696                 rhs,
2697             ]
2698         else:
2699             transformers = [
2700                 string_merge,
2701                 string_paren_strip,
2702                 string_split,
2703                 string_paren_wrap,
2704                 rhs,
2705             ]
2706
2707     for transform in transformers:
2708         # We are accumulating lines in `result` because we might want to abort
2709         # mission and return the original line in the end, or attempt a different
2710         # split altogether.
2711         result: List[Line] = []
2712         try:
2713             for l in transform(line, features):
2714                 if str(l).strip("\n") == line_str:
2715                     raise CannotTransform(
2716                         "Line transformer returned an unchanged result"
2717                     )
2718
2719                 result.extend(
2720                     transform_line(
2721                         l,
2722                         line_length=line_length,
2723                         normalize_strings=normalize_strings,
2724                         features=features,
2725                     )
2726                 )
2727         except CannotTransform:
2728             continue
2729         else:
2730             yield from result
2731             break
2732
2733     else:
2734         yield line
2735
2736
2737 @dataclass  # type: ignore
2738 class StringTransformer(ABC):
2739     """
2740     An implementation of the Transformer protocol that relies on its
2741     subclasses overriding the template methods `do_match(...)` and
2742     `do_transform(...)`.
2743
2744     This Transformer works exclusively on strings (for example, by merging
2745     or splitting them).
2746
2747     The following sections can be found among the docstrings of each concrete
2748     StringTransformer subclass.
2749
2750     Requirements:
2751         Which requirements must be met of the given Line for this
2752         StringTransformer to be applied?
2753
2754     Transformations:
2755         If the given Line meets all of the above requirments, which string
2756         transformations can you expect to be applied to it by this
2757         StringTransformer?
2758
2759     Collaborations:
2760         What contractual agreements does this StringTransformer have with other
2761         StringTransfomers? Such collaborations should be eliminated/minimized
2762         as much as possible.
2763     """
2764
2765     line_length: int
2766     normalize_strings: bool
2767
2768     @abstractmethod
2769     def do_match(self, line: Line) -> TMatchResult:
2770         """
2771         Returns:
2772             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2773             string, if a match was able to be made.
2774                 OR
2775             * Err(CannotTransform), if a match was not able to be made.
2776         """
2777
2778     @abstractmethod
2779     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2780         """
2781         Yields:
2782             * Ok(new_line) where new_line is the new transformed line.
2783                 OR
2784             * Err(CannotTransform) if the transformation failed for some reason. The
2785             `do_match(...)` template method should usually be used to reject
2786             the form of the given Line, but in some cases it is difficult to
2787             know whether or not a Line meets the StringTransformer's
2788             requirements until the transformation is already midway.
2789
2790         Side Effects:
2791             This method should NOT mutate @line directly, but it MAY mutate the
2792             Line's underlying Node structure. (WARNING: If the underlying Node
2793             structure IS altered, then this method should NOT be allowed to
2794             yield an CannotTransform after that point.)
2795         """
2796
2797     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2798         """
2799         StringTransformer instances have a call signature that mirrors that of
2800         the Transformer type.
2801
2802         Raises:
2803             CannotTransform(...) if the concrete StringTransformer class is unable
2804             to transform @line.
2805         """
2806         # Optimization to avoid calling `self.do_match(...)` when the line does
2807         # not contain any string.
2808         if not any(leaf.type == token.STRING for leaf in line.leaves):
2809             raise CannotTransform("There are no strings in this line.")
2810
2811         match_result = self.do_match(line)
2812
2813         if isinstance(match_result, Err):
2814             cant_transform = match_result.err()
2815             raise CannotTransform(
2816                 f"The string transformer {self.__class__.__name__} does not recognize"
2817                 " this line as one that it can transform."
2818             ) from cant_transform
2819
2820         string_idx = match_result.ok()
2821
2822         for line_result in self.do_transform(line, string_idx):
2823             if isinstance(line_result, Err):
2824                 cant_transform = line_result.err()
2825                 raise CannotTransform(
2826                     "StringTransformer failed while attempting to transform string."
2827                 ) from cant_transform
2828             line = line_result.ok()
2829             yield line
2830
2831
2832 @dataclass
2833 class CustomSplit:
2834     """A custom (i.e. manual) string split.
2835
2836     A single CustomSplit instance represents a single substring.
2837
2838     Examples:
2839         Consider the following string:
2840         ```
2841         "Hi there friend."
2842         " This is a custom"
2843         f" string {split}."
2844         ```
2845
2846         This string will correspond to the following three CustomSplit instances:
2847         ```
2848         CustomSplit(False, 16)
2849         CustomSplit(False, 17)
2850         CustomSplit(True, 16)
2851         ```
2852     """
2853
2854     has_prefix: bool
2855     break_idx: int
2856
2857
2858 class CustomSplitMapMixin:
2859     """
2860     This mixin class is used to map merged strings to a sequence of
2861     CustomSplits, which will then be used to re-split the strings iff none of
2862     the resultant substrings go over the configured max line length.
2863     """
2864
2865     _Key = Tuple[StringID, str]
2866     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2867
2868     @staticmethod
2869     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2870         """
2871         Returns:
2872             A unique identifier that is used internally to map @string to a
2873             group of custom splits.
2874         """
2875         return (id(string), string)
2876
2877     def add_custom_splits(
2878         self, string: str, custom_splits: Iterable[CustomSplit]
2879     ) -> None:
2880         """Custom Split Map Setter Method
2881
2882         Side Effects:
2883             Adds a mapping from @string to the custom splits @custom_splits.
2884         """
2885         key = self._get_key(string)
2886         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2887
2888     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2889         """Custom Split Map Getter Method
2890
2891         Returns:
2892             * A list of the custom splits that are mapped to @string, if any
2893             exist.
2894                 OR
2895             * [], otherwise.
2896
2897         Side Effects:
2898             Deletes the mapping between @string and its associated custom
2899             splits (which are returned to the caller).
2900         """
2901         key = self._get_key(string)
2902
2903         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2904         del self._CUSTOM_SPLIT_MAP[key]
2905
2906         return list(custom_splits)
2907
2908     def has_custom_splits(self, string: str) -> bool:
2909         """
2910         Returns:
2911             True iff @string is associated with a set of custom splits.
2912         """
2913         key = self._get_key(string)
2914         return key in self._CUSTOM_SPLIT_MAP
2915
2916
2917 class StringMerger(CustomSplitMapMixin, StringTransformer):
2918     """StringTransformer that merges strings together.
2919
2920     Requirements:
2921         (A) The line contains adjacent strings such that at most one substring
2922         has inline comments AND none of those inline comments are pragmas AND
2923         the set of all substring prefixes is either of length 1 or equal to
2924         {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
2925         with 'r').
2926             OR
2927         (B) The line contains a string which uses line continuation backslashes.
2928
2929     Transformations:
2930         Depending on which of the two requirements above where met, either:
2931
2932         (A) The string group associated with the target string is merged.
2933             OR
2934         (B) All line-continuation backslashes are removed from the target string.
2935
2936     Collaborations:
2937         StringMerger provides custom split information to StringSplitter.
2938     """
2939
2940     def do_match(self, line: Line) -> TMatchResult:
2941         LL = line.leaves
2942
2943         is_valid_index = is_valid_index_factory(LL)
2944
2945         for (i, leaf) in enumerate(LL):
2946             if (
2947                 leaf.type == token.STRING
2948                 and is_valid_index(i + 1)
2949                 and LL[i + 1].type == token.STRING
2950             ):
2951                 return Ok(i)
2952
2953             if leaf.type == token.STRING and "\\\n" in leaf.value:
2954                 return Ok(i)
2955
2956         return TErr("This line has no strings that need merging.")
2957
2958     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2959         new_line = line
2960         rblc_result = self.__remove_backslash_line_continuation_chars(
2961             new_line, string_idx
2962         )
2963         if isinstance(rblc_result, Ok):
2964             new_line = rblc_result.ok()
2965
2966         msg_result = self.__merge_string_group(new_line, string_idx)
2967         if isinstance(msg_result, Ok):
2968             new_line = msg_result.ok()
2969
2970         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
2971             msg_cant_transform = msg_result.err()
2972             rblc_cant_transform = rblc_result.err()
2973             cant_transform = CannotTransform(
2974                 "StringMerger failed to merge any strings in this line."
2975             )
2976
2977             # Chain the errors together using `__cause__`.
2978             msg_cant_transform.__cause__ = rblc_cant_transform
2979             cant_transform.__cause__ = msg_cant_transform
2980
2981             yield Err(cant_transform)
2982         else:
2983             yield Ok(new_line)
2984
2985     @staticmethod
2986     def __remove_backslash_line_continuation_chars(
2987         line: Line, string_idx: int
2988     ) -> TResult[Line]:
2989         """
2990         Merge strings that were split across multiple lines using
2991         line-continuation backslashes.
2992
2993         Returns:
2994             Ok(new_line), if @line contains backslash line-continuation
2995             characters.
2996                 OR
2997             Err(CannotTransform), otherwise.
2998         """
2999         LL = line.leaves
3000
3001         string_leaf = LL[string_idx]
3002         if not (
3003             string_leaf.type == token.STRING
3004             and "\\\n" in string_leaf.value
3005             and not has_triple_quotes(string_leaf.value)
3006         ):
3007             return TErr(
3008                 f"String leaf {string_leaf} does not contain any backslash line"
3009                 " continuation characters."
3010             )
3011
3012         new_line = line.clone()
3013         new_line.comments = line.comments
3014         append_leaves(new_line, line, LL)
3015
3016         new_string_leaf = new_line.leaves[string_idx]
3017         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3018
3019         return Ok(new_line)
3020
3021     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3022         """
3023         Merges string group (i.e. set of adjacent strings) where the first
3024         string in the group is `line.leaves[string_idx]`.
3025
3026         Returns:
3027             Ok(new_line), if ALL of the validation checks found in
3028             __validate_msg(...) pass.
3029                 OR
3030             Err(CannotTransform), otherwise.
3031         """
3032         LL = line.leaves
3033
3034         is_valid_index = is_valid_index_factory(LL)
3035
3036         vresult = self.__validate_msg(line, string_idx)
3037         if isinstance(vresult, Err):
3038             return vresult
3039
3040         # If the string group is wrapped inside an Atom node, we must make sure
3041         # to later replace that Atom with our new (merged) string leaf.
3042         atom_node = LL[string_idx].parent
3043
3044         # We will place BREAK_MARK in between every two substrings that we
3045         # merge. We will then later go through our final result and use the
3046         # various instances of BREAK_MARK we find to add the right values to
3047         # the custom split map.
3048         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3049
3050         QUOTE = LL[string_idx].value[-1]
3051
3052         def make_naked(string: str, string_prefix: str) -> str:
3053             """Strip @string (i.e. make it a "naked" string)
3054
3055             Pre-conditions:
3056                 * assert_is_leaf_string(@string)
3057
3058             Returns:
3059                 A string that is identical to @string except that
3060                 @string_prefix has been stripped, the surrounding QUOTE
3061                 characters have been removed, and any remaining QUOTE
3062                 characters have been escaped.
3063             """
3064             assert_is_leaf_string(string)
3065
3066             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3067             naked_string = string[len(string_prefix) + 1 : -1]
3068             naked_string = re.sub(
3069                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3070             )
3071             return naked_string
3072
3073         # Holds the CustomSplit objects that will later be added to the custom
3074         # split map.
3075         custom_splits = []
3076
3077         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3078         prefix_tracker = []
3079
3080         # Sets the 'prefix' variable. This is the prefix that the final merged
3081         # string will have.
3082         next_str_idx = string_idx
3083         prefix = ""
3084         while (
3085             not prefix
3086             and is_valid_index(next_str_idx)
3087             and LL[next_str_idx].type == token.STRING
3088         ):
3089             prefix = get_string_prefix(LL[next_str_idx].value)
3090             next_str_idx += 1
3091
3092         # The next loop merges the string group. The final string will be
3093         # contained in 'S'.
3094         #
3095         # The following convenience variables are used:
3096         #
3097         #   S: string
3098         #   NS: naked string
3099         #   SS: next string
3100         #   NSS: naked next string
3101         S = ""
3102         NS = ""
3103         num_of_strings = 0
3104         next_str_idx = string_idx
3105         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3106             num_of_strings += 1
3107
3108             SS = LL[next_str_idx].value
3109             next_prefix = get_string_prefix(SS)
3110
3111             # If this is an f-string group but this substring is not prefixed
3112             # with 'f'...
3113             if "f" in prefix and "f" not in next_prefix:
3114                 # Then we must escape any braces contained in this substring.
3115                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3116
3117             NSS = make_naked(SS, next_prefix)
3118
3119             has_prefix = bool(next_prefix)
3120             prefix_tracker.append(has_prefix)
3121
3122             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3123             NS = make_naked(S, prefix)
3124
3125             next_str_idx += 1
3126
3127         S_leaf = Leaf(token.STRING, S)
3128         if self.normalize_strings:
3129             normalize_string_quotes(S_leaf)
3130
3131         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3132         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3133         for has_prefix in prefix_tracker:
3134             mark_idx = temp_string.find(BREAK_MARK)
3135             assert (
3136                 mark_idx >= 0
3137             ), "Logic error while filling the custom string breakpoint cache."
3138
3139             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3140             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3141             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3142
3143         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3144
3145         if atom_node is not None:
3146             replace_child(atom_node, string_leaf)
3147
3148         # Build the final line ('new_line') that this method will later return.
3149         new_line = line.clone()
3150         for (i, leaf) in enumerate(LL):
3151             if i == string_idx:
3152                 new_line.append(string_leaf)
3153
3154             if string_idx <= i < string_idx + num_of_strings:
3155                 for comment_leaf in line.comments_after(LL[i]):
3156                     new_line.append(comment_leaf, preformatted=True)
3157                 continue
3158
3159             append_leaves(new_line, line, [leaf])
3160
3161         self.add_custom_splits(string_leaf.value, custom_splits)
3162         return Ok(new_line)
3163
3164     @staticmethod
3165     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3166         """Validate (M)erge (S)tring (G)roup
3167
3168         Transform-time string validation logic for __merge_string_group(...).
3169
3170         Returns:
3171             * Ok(None), if ALL validation checks (listed below) pass.
3172                 OR
3173             * Err(CannotTransform), if any of the following are true:
3174                 - The target string is not in a string group (i.e. it has no
3175                   adjacent strings).
3176                 - The string group has more than one inline comment.
3177                 - The string group has an inline comment that appears to be a pragma.
3178                 - The set of all string prefixes in the string group is of
3179                   length greater than one and is not equal to {"", "f"}.
3180                 - The string group consists of raw strings.
3181         """
3182         num_of_inline_string_comments = 0
3183         set_of_prefixes = set()
3184         num_of_strings = 0
3185         for leaf in line.leaves[string_idx:]:
3186             if leaf.type != token.STRING:
3187                 # If the string group is trailed by a comma, we count the
3188                 # comments trailing the comma to be one of the string group's
3189                 # comments.
3190                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3191                     num_of_inline_string_comments += 1
3192                 break
3193
3194             if has_triple_quotes(leaf.value):
3195                 return TErr("StringMerger does NOT merge multiline strings.")
3196
3197             num_of_strings += 1
3198             prefix = get_string_prefix(leaf.value)
3199             if "r" in prefix:
3200                 return TErr("StringMerger does NOT merge raw strings.")
3201
3202             set_of_prefixes.add(prefix)
3203
3204             if id(leaf) in line.comments:
3205                 num_of_inline_string_comments += 1
3206                 if contains_pragma_comment(line.comments[id(leaf)]):
3207                     return TErr("Cannot merge strings which have pragma comments.")
3208
3209         if num_of_strings < 2:
3210             return TErr(
3211                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3212             )
3213
3214         if num_of_inline_string_comments > 1:
3215             return TErr(
3216                 f"Too many inline string comments ({num_of_inline_string_comments})."
3217             )
3218
3219         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3220             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3221
3222         return Ok(None)
3223
3224
3225 class StringParenStripper(StringTransformer):
3226     """StringTransformer that strips surrounding parentheses from strings.
3227
3228     Requirements:
3229         The line contains a string which is surrounded by parentheses and:
3230             - The target string is NOT the only argument to a function call).
3231             - The RPAR is NOT followed by an attribute access (i.e. a dot).
3232
3233     Transformations:
3234         The parentheses mentioned in the 'Requirements' section are stripped.
3235
3236     Collaborations:
3237         StringParenStripper has its own inherent usefulness, but it is also
3238         relied on to clean up the parentheses created by StringParenWrapper (in
3239         the event that they are no longer needed).
3240     """
3241
3242     def do_match(self, line: Line) -> TMatchResult:
3243         LL = line.leaves
3244
3245         is_valid_index = is_valid_index_factory(LL)
3246
3247         for (idx, leaf) in enumerate(LL):
3248             # Should be a string...
3249             if leaf.type != token.STRING:
3250                 continue
3251
3252             # Should be preceded by a non-empty LPAR...
3253             if (
3254                 not is_valid_index(idx - 1)
3255                 or LL[idx - 1].type != token.LPAR
3256                 or is_empty_lpar(LL[idx - 1])
3257             ):
3258                 continue
3259
3260             # That LPAR should NOT be preceded by a function name or a closing
3261             # bracket (which could be a function which returns a function or a
3262             # list/dictionary that contains a function)...
3263             if is_valid_index(idx - 2) and (
3264                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3265             ):
3266                 continue
3267
3268             string_idx = idx
3269
3270             # Skip the string trailer, if one exists.
3271             string_parser = StringParser()
3272             next_idx = string_parser.parse(LL, string_idx)
3273
3274             # Should be followed by a non-empty RPAR...
3275             if (
3276                 is_valid_index(next_idx)
3277                 and LL[next_idx].type == token.RPAR
3278                 and not is_empty_rpar(LL[next_idx])
3279             ):
3280                 # That RPAR should NOT be followed by a '.' symbol.
3281                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type == token.DOT:
3282                     continue
3283
3284                 return Ok(string_idx)
3285
3286         return TErr("This line has no strings wrapped in parens.")
3287
3288     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3289         LL = line.leaves
3290
3291         string_parser = StringParser()
3292         rpar_idx = string_parser.parse(LL, string_idx)
3293
3294         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3295             if line.comments_after(leaf):
3296                 yield TErr(
3297                     "Will not strip parentheses which have comments attached to them."
3298                 )
3299
3300         new_line = line.clone()
3301         new_line.comments = line.comments.copy()
3302
3303         append_leaves(new_line, line, LL[: string_idx - 1])
3304
3305         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3306         LL[string_idx - 1].remove()
3307         replace_child(LL[string_idx], string_leaf)
3308         new_line.append(string_leaf)
3309
3310         append_leaves(
3311             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :],
3312         )
3313
3314         LL[rpar_idx].remove()
3315
3316         yield Ok(new_line)
3317
3318
3319 class BaseStringSplitter(StringTransformer):
3320     """
3321     Abstract class for StringTransformers which transform a Line's strings by splitting
3322     them or placing them on their own lines where necessary to avoid going over
3323     the configured line length.
3324
3325     Requirements:
3326         * The target string value is responsible for the line going over the
3327         line length limit. It follows that after all of black's other line
3328         split methods have been exhausted, this line (or one of the resulting
3329         lines after all line splits are performed) would still be over the
3330         line_length limit unless we split this string.
3331             AND
3332         * The target string is NOT a "pointless" string (i.e. a string that has
3333         no parent or siblings).
3334             AND
3335         * The target string is not followed by an inline comment that appears
3336         to be a pragma.
3337             AND
3338         * The target string is not a multiline (i.e. triple-quote) string.
3339     """
3340
3341     @abstractmethod
3342     def do_splitter_match(self, line: Line) -> TMatchResult:
3343         """
3344         BaseStringSplitter asks its clients to override this method instead of
3345         `StringTransformer.do_match(...)`.
3346
3347         Follows the same protocol as `StringTransformer.do_match(...)`.
3348
3349         Refer to `help(StringTransformer.do_match)` for more information.
3350         """
3351
3352     def do_match(self, line: Line) -> TMatchResult:
3353         match_result = self.do_splitter_match(line)
3354         if isinstance(match_result, Err):
3355             return match_result
3356
3357         string_idx = match_result.ok()
3358         vresult = self.__validate(line, string_idx)
3359         if isinstance(vresult, Err):
3360             return vresult
3361
3362         return match_result
3363
3364     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3365         """
3366         Checks that @line meets all of the requirements listed in this classes'
3367         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3368         description of those requirements.
3369
3370         Returns:
3371             * Ok(None), if ALL of the requirements are met.
3372                 OR
3373             * Err(CannotTransform), if ANY of the requirements are NOT met.
3374         """
3375         LL = line.leaves
3376
3377         string_leaf = LL[string_idx]
3378
3379         max_string_length = self.__get_max_string_length(line, string_idx)
3380         if len(string_leaf.value) <= max_string_length:
3381             return TErr(
3382                 "The string itself is not what is causing this line to be too long."
3383             )
3384
3385         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3386             token.STRING,
3387             token.NEWLINE,
3388         ]:
3389             return TErr(
3390                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3391                 " no parent)."
3392             )
3393
3394         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3395             line.comments[id(line.leaves[string_idx])]
3396         ):
3397             return TErr(
3398                 "Line appears to end with an inline pragma comment. Splitting the line"
3399                 " could modify the pragma's behavior."
3400             )
3401
3402         if has_triple_quotes(string_leaf.value):
3403             return TErr("We cannot split multiline strings.")
3404
3405         return Ok(None)
3406
3407     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3408         """
3409         Calculates the max string length used when attempting to determine
3410         whether or not the target string is responsible for causing the line to
3411         go over the line length limit.
3412
3413         WARNING: This method is tightly coupled to both StringSplitter and
3414         (especially) StringParenWrapper. There is probably a better way to
3415         accomplish what is being done here.
3416
3417         Returns:
3418             max_string_length: such that `line.leaves[string_idx].value >
3419             max_string_length` implies that the target string IS responsible
3420             for causing this line to exceed the line length limit.
3421         """
3422         LL = line.leaves
3423
3424         is_valid_index = is_valid_index_factory(LL)
3425
3426         # We use the shorthand "WMA4" in comments to abbreviate "We must
3427         # account for". When giving examples, we use STRING to mean some/any
3428         # valid string.
3429         #
3430         # Finally, we use the following convenience variables:
3431         #
3432         #   P:  The leaf that is before the target string leaf.
3433         #   N:  The leaf that is after the target string leaf.
3434         #   NN: The leaf that is after N.
3435
3436         # WMA4 the whitespace at the beginning of the line.
3437         offset = line.depth * 4
3438
3439         if is_valid_index(string_idx - 1):
3440             p_idx = string_idx - 1
3441             if (
3442                 LL[string_idx - 1].type == token.LPAR
3443                 and LL[string_idx - 1].value == ""
3444                 and string_idx >= 2
3445             ):
3446                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3447                 p_idx -= 1
3448
3449             P = LL[p_idx]
3450             if P.type == token.PLUS:
3451                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3452                 offset += 2
3453
3454             if P.type == token.COMMA:
3455                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3456                 offset += 3
3457
3458             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3459                 # This conditional branch is meant to handle dictionary keys,
3460                 # variable assignments, 'return STRING' statement lines, and
3461                 # 'else STRING' ternary expression lines.
3462
3463                 # WMA4 a single space.
3464                 offset += 1
3465
3466                 # WMA4 the lengths of any leaves that came before that space.
3467                 for leaf in LL[: p_idx + 1]:
3468                     offset += len(str(leaf))
3469
3470         if is_valid_index(string_idx + 1):
3471             N = LL[string_idx + 1]
3472             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3473                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3474                 N = LL[string_idx + 2]
3475
3476             if N.type == token.COMMA:
3477                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3478                 offset += 1
3479
3480             if is_valid_index(string_idx + 2):
3481                 NN = LL[string_idx + 2]
3482
3483                 if N.type == token.DOT and NN.type == token.NAME:
3484                     # This conditional branch is meant to handle method calls invoked
3485                     # off of a string literal up to and including the LPAR character.
3486
3487                     # WMA4 the '.' character.
3488                     offset += 1
3489
3490                     if (
3491                         is_valid_index(string_idx + 3)
3492                         and LL[string_idx + 3].type == token.LPAR
3493                     ):
3494                         # WMA4 the left parenthesis character.
3495                         offset += 1
3496
3497                     # WMA4 the length of the method's name.
3498                     offset += len(NN.value)
3499
3500         has_comments = False
3501         for comment_leaf in line.comments_after(LL[string_idx]):
3502             if not has_comments:
3503                 has_comments = True
3504                 # WMA4 two spaces before the '#' character.
3505                 offset += 2
3506
3507             # WMA4 the length of the inline comment.
3508             offset += len(comment_leaf.value)
3509
3510         max_string_length = self.line_length - offset
3511         return max_string_length
3512
3513
3514 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3515     """
3516     StringTransformer that splits "atom" strings (i.e. strings which exist on
3517     lines by themselves).
3518
3519     Requirements:
3520         * The line consists ONLY of a single string (with the exception of a
3521         '+' symbol which MAY exist at the start of the line), MAYBE a string
3522         trailer, and MAYBE a trailing comma.
3523             AND
3524         * All of the requirements listed in BaseStringSplitter's docstring.
3525
3526     Transformations:
3527         The string mentioned in the 'Requirements' section is split into as
3528         many substrings as necessary to adhere to the configured line length.
3529
3530         In the final set of substrings, no substring should be smaller than
3531         MIN_SUBSTR_SIZE characters.
3532
3533         The string will ONLY be split on spaces (i.e. each new substring should
3534         start with a space).
3535
3536         If the string is an f-string, it will NOT be split in the middle of an
3537         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3538         else bar()} is an f-expression).
3539
3540         If the string that is being split has an associated set of custom split
3541         records and those custom splits will NOT result in any line going over
3542         the configured line length, those custom splits are used. Otherwise the
3543         string is split as late as possible (from left-to-right) while still
3544         adhering to the transformation rules listed above.
3545
3546     Collaborations:
3547         StringSplitter relies on StringMerger to construct the appropriate
3548         CustomSplit objects and add them to the custom split map.
3549     """
3550
3551     MIN_SUBSTR_SIZE = 6
3552     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3553     RE_FEXPR = r"""
3554     (?<!\{)\{
3555         (?:
3556             [^\{\}]
3557             | \{\{
3558             | \}\}
3559         )+?
3560     (?<!\})(?:\}\})*\}(?!\})
3561     """
3562
3563     def do_splitter_match(self, line: Line) -> TMatchResult:
3564         LL = line.leaves
3565
3566         is_valid_index = is_valid_index_factory(LL)
3567
3568         idx = 0
3569
3570         # The first leaf MAY be a '+' symbol...
3571         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3572             idx += 1
3573
3574         # The next/first leaf MAY be an empty LPAR...
3575         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3576             idx += 1
3577
3578         # The next/first leaf MUST be a string...
3579         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3580             return TErr("Line does not start with a string.")
3581
3582         string_idx = idx
3583
3584         # Skip the string trailer, if one exists.
3585         string_parser = StringParser()
3586         idx = string_parser.parse(LL, string_idx)
3587
3588         # That string MAY be followed by an empty RPAR...
3589         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3590             idx += 1
3591
3592         # That string / empty RPAR leaf MAY be followed by a comma...
3593         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3594             idx += 1
3595
3596         # But no more leaves are allowed...
3597         if is_valid_index(idx):
3598             return TErr("This line does not end with a string.")
3599
3600         return Ok(string_idx)
3601
3602     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3603         LL = line.leaves
3604
3605         QUOTE = LL[string_idx].value[-1]
3606
3607         is_valid_index = is_valid_index_factory(LL)
3608         insert_str_child = insert_str_child_factory(LL[string_idx])
3609
3610         prefix = get_string_prefix(LL[string_idx].value)
3611
3612         # We MAY choose to drop the 'f' prefix from substrings that don't
3613         # contain any f-expressions, but ONLY if the original f-string
3614         # containes at least one f-expression. Otherwise, we will alter the AST
3615         # of the program.
3616         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3617             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3618         )
3619
3620         first_string_line = True
3621         starts_with_plus = LL[0].type == token.PLUS
3622
3623         def line_needs_plus() -> bool:
3624             return first_string_line and starts_with_plus
3625
3626         def maybe_append_plus(new_line: Line) -> None:
3627             """
3628             Side Effects:
3629                 If @line starts with a plus and this is the first line we are
3630                 constructing, this function appends a PLUS leaf to @new_line
3631                 and replaces the old PLUS leaf in the node structure. Otherwise
3632                 this function does nothing.
3633             """
3634             if line_needs_plus():
3635                 plus_leaf = Leaf(token.PLUS, "+")
3636                 replace_child(LL[0], plus_leaf)
3637                 new_line.append(plus_leaf)
3638
3639         ends_with_comma = (
3640             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3641         )
3642
3643         def max_last_string() -> int:
3644             """
3645             Returns:
3646                 The max allowed length of the string value used for the last
3647                 line we will construct.
3648             """
3649             result = self.line_length
3650             result -= line.depth * 4
3651             result -= 1 if ends_with_comma else 0
3652             result -= 2 if line_needs_plus() else 0
3653             return result
3654
3655         # --- Calculate Max Break Index (for string value)
3656         # We start with the line length limit
3657         max_break_idx = self.line_length
3658         # The last index of a string of length N is N-1.
3659         max_break_idx -= 1
3660         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3661         max_break_idx -= line.depth * 4
3662         if max_break_idx < 0:
3663             yield TErr(
3664                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3665                 f" {line.depth}"
3666             )
3667             return
3668
3669         # Check if StringMerger registered any custom splits.
3670         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3671         # We use them ONLY if none of them would produce lines that exceed the
3672         # line limit.
3673         use_custom_breakpoints = bool(
3674             custom_splits
3675             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3676         )
3677
3678         # Temporary storage for the remaining chunk of the string line that
3679         # can't fit onto the line currently being constructed.
3680         rest_value = LL[string_idx].value
3681
3682         def more_splits_should_be_made() -> bool:
3683             """
3684             Returns:
3685                 True iff `rest_value` (the remaining string value from the last
3686                 split), should be split again.
3687             """
3688             if use_custom_breakpoints:
3689                 return len(custom_splits) > 1
3690             else:
3691                 return len(rest_value) > max_last_string()
3692
3693         string_line_results: List[Ok[Line]] = []
3694         while more_splits_should_be_made():
3695             if use_custom_breakpoints:
3696                 # Custom User Split (manual)
3697                 csplit = custom_splits.pop(0)
3698                 break_idx = csplit.break_idx
3699             else:
3700                 # Algorithmic Split (automatic)
3701                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3702                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3703                 if maybe_break_idx is None:
3704                     # If we are unable to algorthmically determine a good split
3705                     # and this string has custom splits registered to it, we
3706                     # fall back to using them--which means we have to start
3707                     # over from the beginning.
3708                     if custom_splits:
3709                         rest_value = LL[string_idx].value
3710                         string_line_results = []
3711                         first_string_line = True
3712                         use_custom_breakpoints = True
3713                         continue
3714
3715                     # Otherwise, we stop splitting here.
3716                     break
3717
3718                 break_idx = maybe_break_idx
3719
3720             # --- Construct `next_value`
3721             next_value = rest_value[:break_idx] + QUOTE
3722             if (
3723                 # Are we allowed to try to drop a pointless 'f' prefix?
3724                 drop_pointless_f_prefix
3725                 # If we are, will we be successful?
3726                 and next_value != self.__normalize_f_string(next_value, prefix)
3727             ):
3728                 # If the current custom split did NOT originally use a prefix,
3729                 # then `csplit.break_idx` will be off by one after removing
3730                 # the 'f' prefix.
3731                 break_idx = (
3732                     break_idx + 1
3733                     if use_custom_breakpoints and not csplit.has_prefix
3734                     else break_idx
3735                 )
3736                 next_value = rest_value[:break_idx] + QUOTE
3737                 next_value = self.__normalize_f_string(next_value, prefix)
3738
3739             # --- Construct `next_leaf`
3740             next_leaf = Leaf(token.STRING, next_value)
3741             insert_str_child(next_leaf)
3742             self.__maybe_normalize_string_quotes(next_leaf)
3743
3744             # --- Construct `next_line`
3745             next_line = line.clone()
3746             maybe_append_plus(next_line)
3747             next_line.append(next_leaf)
3748             string_line_results.append(Ok(next_line))
3749
3750             rest_value = prefix + QUOTE + rest_value[break_idx:]
3751             first_string_line = False
3752
3753         yield from string_line_results
3754
3755         if drop_pointless_f_prefix:
3756             rest_value = self.__normalize_f_string(rest_value, prefix)
3757
3758         rest_leaf = Leaf(token.STRING, rest_value)
3759         insert_str_child(rest_leaf)
3760
3761         # NOTE: I could not find a test case that verifies that the following
3762         # line is actually necessary, but it seems to be. Otherwise we risk
3763         # not normalizing the last substring, right?
3764         self.__maybe_normalize_string_quotes(rest_leaf)
3765
3766         last_line = line.clone()
3767         maybe_append_plus(last_line)
3768
3769         # If there are any leaves to the right of the target string...
3770         if is_valid_index(string_idx + 1):
3771             # We use `temp_value` here to determine how long the last line
3772             # would be if we were to append all the leaves to the right of the
3773             # target string to the last string line.
3774             temp_value = rest_value
3775             for leaf in LL[string_idx + 1 :]:
3776                 temp_value += str(leaf)
3777                 if leaf.type == token.LPAR:
3778                     break
3779
3780             # Try to fit them all on the same line with the last substring...
3781             if (
3782                 len(temp_value) <= max_last_string()
3783                 or LL[string_idx + 1].type == token.COMMA
3784             ):
3785                 last_line.append(rest_leaf)
3786                 append_leaves(last_line, line, LL[string_idx + 1 :])
3787                 yield Ok(last_line)
3788             # Otherwise, place the last substring on one line and everything
3789             # else on a line below that...
3790             else:
3791                 last_line.append(rest_leaf)
3792                 yield Ok(last_line)
3793
3794                 non_string_line = line.clone()
3795                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3796                 yield Ok(non_string_line)
3797         # Else the target string was the last leaf...
3798         else:
3799             last_line.append(rest_leaf)
3800             last_line.comments = line.comments.copy()
3801             yield Ok(last_line)
3802
3803     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3804         """
3805         This method contains the algorithm that StringSplitter uses to
3806         determine which character to split each string at.
3807
3808         Args:
3809             @string: The substring that we are attempting to split.
3810             @max_break_idx: The ideal break index. We will return this value if it
3811             meets all the necessary conditions. In the likely event that it
3812             doesn't we will try to find the closest index BELOW @max_break_idx
3813             that does. If that fails, we will expand our search by also
3814             considering all valid indices ABOVE @max_break_idx.
3815
3816         Pre-Conditions:
3817             * assert_is_leaf_string(@string)
3818             * 0 <= @max_break_idx < len(@string)
3819
3820         Returns:
3821             break_idx, if an index is able to be found that meets all of the
3822             conditions listed in the 'Transformations' section of this classes'
3823             docstring.
3824                 OR
3825             None, otherwise.
3826         """
3827         is_valid_index = is_valid_index_factory(string)
3828
3829         assert is_valid_index(max_break_idx)
3830         assert_is_leaf_string(string)
3831
3832         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3833
3834         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3835             """
3836             Yields:
3837                 All ranges of @string which, if @string were to be split there,
3838                 would result in the splitting of an f-expression (which is NOT
3839                 allowed).
3840             """
3841             nonlocal _fexpr_slices
3842
3843             if _fexpr_slices is None:
3844                 _fexpr_slices = []
3845                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3846                     _fexpr_slices.append(match.span())
3847
3848             yield from _fexpr_slices
3849
3850         is_fstring = "f" in get_string_prefix(string)
3851
3852         def breaks_fstring_expression(i: Index) -> bool:
3853             """
3854             Returns:
3855                 True iff returning @i would result in the splitting of an
3856                 f-expression (which is NOT allowed).
3857             """
3858             if not is_fstring:
3859                 return False
3860
3861             for (start, end) in fexpr_slices():
3862                 if start <= i < end:
3863                     return True
3864
3865             return False
3866
3867         def passes_all_checks(i: Index) -> bool:
3868             """
3869             Returns:
3870                 True iff ALL of the conditions listed in the 'Transformations'
3871                 section of this classes' docstring would be be met by returning @i.
3872             """
3873             is_space = string[i] == " "
3874             is_big_enough = (
3875                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3876                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3877             )
3878             return is_space and is_big_enough and not breaks_fstring_expression(i)
3879
3880         # First, we check all indices BELOW @max_break_idx.
3881         break_idx = max_break_idx
3882         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3883             break_idx -= 1
3884
3885         if not passes_all_checks(break_idx):
3886             # If that fails, we check all indices ABOVE @max_break_idx.
3887             #
3888             # If we are able to find a valid index here, the next line is going
3889             # to be longer than the specified line length, but it's probably
3890             # better than doing nothing at all.
3891             break_idx = max_break_idx + 1
3892             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3893                 break_idx += 1
3894
3895             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3896                 return None
3897
3898         return break_idx
3899
3900     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3901         if self.normalize_strings:
3902             normalize_string_quotes(leaf)
3903
3904     def __normalize_f_string(self, string: str, prefix: str) -> str:
3905         """
3906         Pre-Conditions:
3907             * assert_is_leaf_string(@string)
3908
3909         Returns:
3910             * If @string is an f-string that contains no f-expressions, we
3911             return a string identical to @string except that the 'f' prefix
3912             has been stripped and all double braces (i.e. '{{' or '}}') have
3913             been normalized (i.e. turned into '{' or '}').
3914                 OR
3915             * Otherwise, we return @string.
3916         """
3917         assert_is_leaf_string(string)
3918
3919         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3920             new_prefix = prefix.replace("f", "")
3921
3922             temp = string[len(prefix) :]
3923             temp = re.sub(r"\{\{", "{", temp)
3924             temp = re.sub(r"\}\}", "}", temp)
3925             new_string = temp
3926
3927             return f"{new_prefix}{new_string}"
3928         else:
3929             return string
3930
3931
3932 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
3933     """
3934     StringTransformer that splits non-"atom" strings (i.e. strings that do not
3935     exist on lines by themselves).
3936
3937     Requirements:
3938         All of the requirements listed in BaseStringSplitter's docstring in
3939         addition to the requirements listed below:
3940
3941         * The line is a return/yield statement, which returns/yields a string.
3942             OR
3943         * The line is part of a ternary expression (e.g. `x = y if cond else
3944         z`) such that the line starts with `else <string>`, where <string> is
3945         some string.
3946             OR
3947         * The line is an assert statement, which ends with a string.
3948             OR
3949         * The line is an assignment statement (e.g. `x = <string>` or `x +=
3950         <string>`) such that the variable is being assigned the value of some
3951         string.
3952             OR
3953         * The line is a dictionary key assignment where some valid key is being
3954         assigned the value of some string.
3955
3956     Transformations:
3957         The chosen string is wrapped in parentheses and then split at the LPAR.
3958
3959         We then have one line which ends with an LPAR and another line that
3960         starts with the chosen string. The latter line is then split again at
3961         the RPAR. This results in the RPAR (and possibly a trailing comma)
3962         being placed on its own line.
3963
3964         NOTE: If any leaves exist to the right of the chosen string (except
3965         for a trailing comma, which would be placed after the RPAR), those
3966         leaves are placed inside the parentheses.  In effect, the chosen
3967         string is not necessarily being "wrapped" by parentheses. We can,
3968         however, count on the LPAR being placed directly before the chosen
3969         string.
3970
3971         In other words, StringParenWrapper creates "atom" strings. These
3972         can then be split again by StringSplitter, if necessary.
3973
3974     Collaborations:
3975         In the event that a string line split by StringParenWrapper is
3976         changed such that it no longer needs to be given its own line,
3977         StringParenWrapper relies on StringParenStripper to clean up the
3978         parentheses it created.
3979     """
3980
3981     def do_splitter_match(self, line: Line) -> TMatchResult:
3982         LL = line.leaves
3983
3984         string_idx = None
3985         string_idx = string_idx or self._return_match(LL)
3986         string_idx = string_idx or self._else_match(LL)
3987         string_idx = string_idx or self._assert_match(LL)
3988         string_idx = string_idx or self._assign_match(LL)
3989         string_idx = string_idx or self._dict_match(LL)
3990
3991         if string_idx is not None:
3992             string_value = line.leaves[string_idx].value
3993             # If the string has no spaces...
3994             if " " not in string_value:
3995                 # And will still violate the line length limit when split...
3996                 max_string_length = self.line_length - ((line.depth + 1) * 4)
3997                 if len(string_value) > max_string_length:
3998                     # And has no associated custom splits...
3999                     if not self.has_custom_splits(string_value):
4000                         # Then we should NOT put this string on its own line.
4001                         return TErr(
4002                             "We do not wrap long strings in parentheses when the"
4003                             " resultant line would still be over the specified line"
4004                             " length and can't be split further by StringSplitter."
4005                         )
4006             return Ok(string_idx)
4007
4008         return TErr("This line does not contain any non-atomic strings.")
4009
4010     @staticmethod
4011     def _return_match(LL: List[Leaf]) -> Optional[int]:
4012         """
4013         Returns:
4014             string_idx such that @LL[string_idx] is equal to our target (i.e.
4015             matched) string, if this line matches the return/yield statement
4016             requirements listed in the 'Requirements' section of this classes'
4017             docstring.
4018                 OR
4019             None, otherwise.
4020         """
4021         # If this line is apart of a return/yield statement and the first leaf
4022         # contains either the "return" or "yield" keywords...
4023         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4024             0
4025         ].value in ["return", "yield"]:
4026             is_valid_index = is_valid_index_factory(LL)
4027
4028             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4029             # The next visible leaf MUST contain a string...
4030             if is_valid_index(idx) and LL[idx].type == token.STRING:
4031                 return idx
4032
4033         return None
4034
4035     @staticmethod
4036     def _else_match(LL: List[Leaf]) -> Optional[int]:
4037         """
4038         Returns:
4039             string_idx such that @LL[string_idx] is equal to our target (i.e.
4040             matched) string, if this line matches the ternary expression
4041             requirements listed in the 'Requirements' section of this classes'
4042             docstring.
4043                 OR
4044             None, otherwise.
4045         """
4046         # If this line is apart of a ternary expression and the first leaf
4047         # contains the "else" keyword...
4048         if (
4049             parent_type(LL[0]) == syms.test
4050             and LL[0].type == token.NAME
4051             and LL[0].value == "else"
4052         ):
4053             is_valid_index = is_valid_index_factory(LL)
4054
4055             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4056             # The next visible leaf MUST contain a string...
4057             if is_valid_index(idx) and LL[idx].type == token.STRING:
4058                 return idx
4059
4060         return None
4061
4062     @staticmethod
4063     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4064         """
4065         Returns:
4066             string_idx such that @LL[string_idx] is equal to our target (i.e.
4067             matched) string, if this line matches the assert statement
4068             requirements listed in the 'Requirements' section of this classes'
4069             docstring.
4070                 OR
4071             None, otherwise.
4072         """
4073         # If this line is apart of an assert statement and the first leaf
4074         # contains the "assert" keyword...
4075         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4076             is_valid_index = is_valid_index_factory(LL)
4077
4078             for (i, leaf) in enumerate(LL):
4079                 # We MUST find a comma...
4080                 if leaf.type == token.COMMA:
4081                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4082
4083                     # That comma MUST be followed by a string...
4084                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4085                         string_idx = idx
4086
4087                         # Skip the string trailer, if one exists.
4088                         string_parser = StringParser()
4089                         idx = string_parser.parse(LL, string_idx)
4090
4091                         # But no more leaves are allowed...
4092                         if not is_valid_index(idx):
4093                             return string_idx
4094
4095         return None
4096
4097     @staticmethod
4098     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4099         """
4100         Returns:
4101             string_idx such that @LL[string_idx] is equal to our target (i.e.
4102             matched) string, if this line matches the assignment statement
4103             requirements listed in the 'Requirements' section of this classes'
4104             docstring.
4105                 OR
4106             None, otherwise.
4107         """
4108         # If this line is apart of an expression statement or is a function
4109         # argument AND the first leaf contains a variable name...
4110         if (
4111             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4112             and LL[0].type == token.NAME
4113         ):
4114             is_valid_index = is_valid_index_factory(LL)
4115
4116             for (i, leaf) in enumerate(LL):
4117                 # We MUST find either an '=' or '+=' symbol...
4118                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4119                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4120
4121                     # That symbol MUST be followed by a string...
4122                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4123                         string_idx = idx
4124
4125                         # Skip the string trailer, if one exists.
4126                         string_parser = StringParser()
4127                         idx = string_parser.parse(LL, string_idx)
4128
4129                         # The next leaf MAY be a comma iff this line is apart
4130                         # of a function argument...
4131                         if (
4132                             parent_type(LL[0]) == syms.argument
4133                             and is_valid_index(idx)
4134                             and LL[idx].type == token.COMMA
4135                         ):
4136                             idx += 1
4137
4138                         # But no more leaves are allowed...
4139                         if not is_valid_index(idx):
4140                             return string_idx
4141
4142         return None
4143
4144     @staticmethod
4145     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4146         """
4147         Returns:
4148             string_idx such that @LL[string_idx] is equal to our target (i.e.
4149             matched) string, if this line matches the dictionary key assignment
4150             statement requirements listed in the 'Requirements' section of this
4151             classes' docstring.
4152                 OR
4153             None, otherwise.
4154         """
4155         # If this line is apart of a dictionary key assignment...
4156         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4157             is_valid_index = is_valid_index_factory(LL)
4158
4159             for (i, leaf) in enumerate(LL):
4160                 # We MUST find a colon...
4161                 if leaf.type == token.COLON:
4162                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4163
4164                     # That colon MUST be followed by a string...
4165                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4166                         string_idx = idx
4167
4168                         # Skip the string trailer, if one exists.
4169                         string_parser = StringParser()
4170                         idx = string_parser.parse(LL, string_idx)
4171
4172                         # That string MAY be followed by a comma...
4173                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4174                             idx += 1
4175
4176                         # But no more leaves are allowed...
4177                         if not is_valid_index(idx):
4178                             return string_idx
4179
4180         return None
4181
4182     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4183         LL = line.leaves
4184
4185         is_valid_index = is_valid_index_factory(LL)
4186         insert_str_child = insert_str_child_factory(LL[string_idx])
4187
4188         comma_idx = len(LL) - 1
4189         ends_with_comma = False
4190         if LL[comma_idx].type == token.COMMA:
4191             ends_with_comma = True
4192
4193         leaves_to_steal_comments_from = [LL[string_idx]]
4194         if ends_with_comma:
4195             leaves_to_steal_comments_from.append(LL[comma_idx])
4196
4197         # --- First Line
4198         first_line = line.clone()
4199         left_leaves = LL[:string_idx]
4200
4201         # We have to remember to account for (possibly invisible) LPAR and RPAR
4202         # leaves that already wrapped the target string. If these leaves do
4203         # exist, we will replace them with our own LPAR and RPAR leaves.
4204         old_parens_exist = False
4205         if left_leaves and left_leaves[-1].type == token.LPAR:
4206             old_parens_exist = True
4207             leaves_to_steal_comments_from.append(left_leaves[-1])
4208             left_leaves.pop()
4209
4210         append_leaves(first_line, line, left_leaves)
4211
4212         lpar_leaf = Leaf(token.LPAR, "(")
4213         if old_parens_exist:
4214             replace_child(LL[string_idx - 1], lpar_leaf)
4215         else:
4216             insert_str_child(lpar_leaf)
4217         first_line.append(lpar_leaf)
4218
4219         # We throw inline comments that were originally to the right of the
4220         # target string to the top line. They will now be shown to the right of
4221         # the LPAR.
4222         for leaf in leaves_to_steal_comments_from:
4223             for comment_leaf in line.comments_after(leaf):
4224                 first_line.append(comment_leaf, preformatted=True)
4225
4226         yield Ok(first_line)
4227
4228         # --- Middle (String) Line
4229         # We only need to yield one (possibly too long) string line, since the
4230         # `StringSplitter` will break it down further if necessary.
4231         string_value = LL[string_idx].value
4232         string_line = Line(
4233             depth=line.depth + 1,
4234             inside_brackets=True,
4235             should_explode=line.should_explode,
4236         )
4237         string_leaf = Leaf(token.STRING, string_value)
4238         insert_str_child(string_leaf)
4239         string_line.append(string_leaf)
4240
4241         old_rpar_leaf = None
4242         if is_valid_index(string_idx + 1):
4243             right_leaves = LL[string_idx + 1 :]
4244             if ends_with_comma:
4245                 right_leaves.pop()
4246
4247             if old_parens_exist:
4248                 assert (
4249                     right_leaves and right_leaves[-1].type == token.RPAR
4250                 ), "Apparently, old parentheses do NOT exist?!"
4251                 old_rpar_leaf = right_leaves.pop()
4252
4253             append_leaves(string_line, line, right_leaves)
4254
4255         yield Ok(string_line)
4256
4257         # --- Last Line
4258         last_line = line.clone()
4259         last_line.bracket_tracker = first_line.bracket_tracker
4260
4261         new_rpar_leaf = Leaf(token.RPAR, ")")
4262         if old_rpar_leaf is not None:
4263             replace_child(old_rpar_leaf, new_rpar_leaf)
4264         else:
4265             insert_str_child(new_rpar_leaf)
4266         last_line.append(new_rpar_leaf)
4267
4268         # If the target string ended with a comma, we place this comma to the
4269         # right of the RPAR on the last line.
4270         if ends_with_comma:
4271             comma_leaf = Leaf(token.COMMA, ",")
4272             replace_child(LL[comma_idx], comma_leaf)
4273             last_line.append(comma_leaf)
4274
4275         yield Ok(last_line)
4276
4277
4278 class StringParser:
4279     """
4280     A state machine that aids in parsing a string's "trailer", which can be
4281     either non-existant, an old-style formatting sequence (e.g. `% varX` or `%
4282     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4283     varY)`).
4284
4285     NOTE: A new StringParser object MUST be instantiated for each string
4286     trailer we need to parse.
4287
4288     Examples:
4289         We shall assume that `line` equals the `Line` object that corresponds
4290         to the following line of python code:
4291         ```
4292         x = "Some {}.".format("String") + some_other_string
4293         ```
4294
4295         Furthermore, we will assume that `string_idx` is some index such that:
4296         ```
4297         assert line.leaves[string_idx].value == "Some {}."
4298         ```
4299
4300         The following code snippet then holds:
4301         ```
4302         string_parser = StringParser()
4303         idx = string_parser.parse(line.leaves, string_idx)
4304         assert line.leaves[idx].type == token.PLUS
4305         ```
4306     """
4307
4308     DEFAULT_TOKEN = -1
4309
4310     # String Parser States
4311     START = 1
4312     DOT = 2
4313     NAME = 3
4314     PERCENT = 4
4315     SINGLE_FMT_ARG = 5
4316     LPAR = 6
4317     RPAR = 7
4318     DONE = 8
4319
4320     # Lookup Table for Next State
4321     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4322         # A string trailer may start with '.' OR '%'.
4323         (START, token.DOT): DOT,
4324         (START, token.PERCENT): PERCENT,
4325         (START, DEFAULT_TOKEN): DONE,
4326         # A '.' MUST be followed by an attribute or method name.
4327         (DOT, token.NAME): NAME,
4328         # A method name MUST be followed by an '(', whereas an attribute name
4329         # is the last symbol in the string trailer.
4330         (NAME, token.LPAR): LPAR,
4331         (NAME, DEFAULT_TOKEN): DONE,
4332         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4333         # string or variable name).
4334         (PERCENT, token.LPAR): LPAR,
4335         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4336         # If a '%' symbol is followed by a single argument, that argument is
4337         # the last leaf in the string trailer.
4338         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4339         # If present, a ')' symbol is the last symbol in a string trailer.
4340         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4341         # since they are treated as a special case by the parsing logic in this
4342         # classes' implementation.)
4343         (RPAR, DEFAULT_TOKEN): DONE,
4344     }
4345
4346     def __init__(self) -> None:
4347         self._state = self.START
4348         self._unmatched_lpars = 0
4349
4350     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4351         """
4352         Pre-conditions:
4353             * @leaves[@string_idx].type == token.STRING
4354
4355         Returns:
4356             The index directly after the last leaf which is apart of the string
4357             trailer, if a "trailer" exists.
4358                 OR
4359             @string_idx + 1, if no string "trailer" exists.
4360         """
4361         assert leaves[string_idx].type == token.STRING
4362
4363         idx = string_idx + 1
4364         while idx < len(leaves) and self._next_state(leaves[idx]):
4365             idx += 1
4366         return idx
4367
4368     def _next_state(self, leaf: Leaf) -> bool:
4369         """
4370         Pre-conditions:
4371             * On the first call to this function, @leaf MUST be the leaf that
4372             was directly after the string leaf in question (e.g. if our target
4373             string is `line.leaves[i]` then the first call to this method must
4374             be `line.leaves[i + 1]`).
4375             * On the next call to this function, the leaf paramater passed in
4376             MUST be the leaf directly following @leaf.
4377
4378         Returns:
4379             True iff @leaf is apart of the string's trailer.
4380         """
4381         # We ignore empty LPAR or RPAR leaves.
4382         if is_empty_par(leaf):
4383             return True
4384
4385         next_token = leaf.type
4386         if next_token == token.LPAR:
4387             self._unmatched_lpars += 1
4388
4389         current_state = self._state
4390
4391         # The LPAR parser state is a special case. We will return True until we
4392         # find the matching RPAR token.
4393         if current_state == self.LPAR:
4394             if next_token == token.RPAR:
4395                 self._unmatched_lpars -= 1
4396                 if self._unmatched_lpars == 0:
4397                     self._state = self.RPAR
4398         # Otherwise, we use a lookup table to determine the next state.
4399         else:
4400             # If the lookup table matches the current state to the next
4401             # token, we use the lookup table.
4402             if (current_state, next_token) in self._goto:
4403                 self._state = self._goto[current_state, next_token]
4404             else:
4405                 # Otherwise, we check if a the current state was assigned a
4406                 # default.
4407                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4408                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4409                 # If no default has been assigned, then this parser has a logic
4410                 # error.
4411                 else:
4412                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4413
4414             if self._state == self.DONE:
4415                 return False
4416
4417         return True
4418
4419
4420 def TErr(err_msg: str) -> Err[CannotTransform]:
4421     """(T)ransform Err
4422
4423     Convenience function used when working with the TResult type.
4424     """
4425     cant_transform = CannotTransform(err_msg)
4426     return Err(cant_transform)
4427
4428
4429 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4430     """
4431     Returns:
4432         True iff one of the comments in @comment_list is a pragma used by one
4433         of the more common static analysis tools for python (e.g. mypy, flake8,
4434         pylint).
4435     """
4436     for comment in comment_list:
4437         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4438             return True
4439
4440     return False
4441
4442
4443 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4444     """
4445     Factory for a convenience function that is used to orphan @string_leaf
4446     and then insert multiple new leaves into the same part of the node
4447     structure that @string_leaf had originally occupied.
4448
4449     Examples:
4450         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4451         string_leaf.parent`. Assume the node `N` has the following
4452         original structure:
4453
4454         Node(
4455             expr_stmt, [
4456                 Leaf(NAME, 'x'),
4457                 Leaf(EQUAL, '='),
4458                 Leaf(STRING, '"foo"'),
4459             ]
4460         )
4461
4462         We then run the code snippet shown below.
4463         ```
4464         insert_str_child = insert_str_child_factory(string_leaf)
4465
4466         lpar = Leaf(token.LPAR, '(')
4467         insert_str_child(lpar)
4468
4469         bar = Leaf(token.STRING, '"bar"')
4470         insert_str_child(bar)
4471
4472         rpar = Leaf(token.RPAR, ')')
4473         insert_str_child(rpar)
4474         ```
4475
4476         After which point, it follows that `string_leaf.parent is None` and
4477         the node `N` now has the following structure:
4478
4479         Node(
4480             expr_stmt, [
4481                 Leaf(NAME, 'x'),
4482                 Leaf(EQUAL, '='),
4483                 Leaf(LPAR, '('),
4484                 Leaf(STRING, '"bar"'),
4485                 Leaf(RPAR, ')'),
4486             ]
4487         )
4488     """
4489     string_parent = string_leaf.parent
4490     string_child_idx = string_leaf.remove()
4491
4492     def insert_str_child(child: LN) -> None:
4493         nonlocal string_child_idx
4494
4495         assert string_parent is not None
4496         assert string_child_idx is not None
4497
4498         string_parent.insert_child(string_child_idx, child)
4499         string_child_idx += 1
4500
4501     return insert_str_child
4502
4503
4504 def has_triple_quotes(string: str) -> bool:
4505     """
4506     Returns:
4507         True iff @string starts with three quotation characters.
4508     """
4509     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4510     return raw_string[:3] in {'"""', "'''"}
4511
4512
4513 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4514     """
4515     Returns:
4516         @node.parent.type, if @node is not None and has a parent.
4517             OR
4518         None, otherwise.
4519     """
4520     if node is None or node.parent is None:
4521         return None
4522
4523     return node.parent.type
4524
4525
4526 def is_empty_par(leaf: Leaf) -> bool:
4527     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4528
4529
4530 def is_empty_lpar(leaf: Leaf) -> bool:
4531     return leaf.type == token.LPAR and leaf.value == ""
4532
4533
4534 def is_empty_rpar(leaf: Leaf) -> bool:
4535     return leaf.type == token.RPAR and leaf.value == ""
4536
4537
4538 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4539     """
4540     Examples:
4541         ```
4542         my_list = [1, 2, 3]
4543
4544         is_valid_index = is_valid_index_factory(my_list)
4545
4546         assert is_valid_index(0)
4547         assert is_valid_index(2)
4548
4549         assert not is_valid_index(3)
4550         assert not is_valid_index(-1)
4551         ```
4552     """
4553
4554     def is_valid_index(idx: int) -> bool:
4555         """
4556         Returns:
4557             True iff @idx is positive AND seq[@idx] does NOT raise an
4558             IndexError.
4559         """
4560         return 0 <= idx < len(seq)
4561
4562     return is_valid_index
4563
4564
4565 def line_to_string(line: Line) -> str:
4566     """Returns the string representation of @line.
4567
4568     WARNING: This is known to be computationally expensive.
4569     """
4570     return str(line).strip("\n")
4571
4572
4573 def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
4574     """
4575     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4576     underlying Node structure where appropriate.
4577
4578     All of the leaves in @leaves are duplicated. The duplicates are then
4579     appended to @new_line and used to replace their originals in the underlying
4580     Node structure. Any comments attatched to the old leaves are reattached to
4581     the new leaves.
4582
4583     Pre-conditions:
4584         set(@leaves) is a subset of set(@old_line.leaves).
4585     """
4586     for old_leaf in leaves:
4587         assert old_leaf in old_line.leaves
4588
4589         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4590         replace_child(old_leaf, new_leaf)
4591         new_line.append(new_leaf)
4592
4593         for comment_leaf in old_line.comments_after(old_leaf):
4594             new_line.append(comment_leaf, preformatted=True)
4595
4596
4597 def replace_child(old_child: LN, new_child: LN) -> None:
4598     """
4599     Side Effects:
4600         * If @old_child.parent is set, replace @old_child with @new_child in
4601         @old_child's underlying Node structure.
4602             OR
4603         * Otherwise, this function does nothing.
4604     """
4605     parent = old_child.parent
4606     if not parent:
4607         return
4608
4609     child_idx = old_child.remove()
4610     if child_idx is not None:
4611         parent.insert_child(child_idx, new_child)
4612
4613
4614 def get_string_prefix(string: str) -> str:
4615     """
4616     Pre-conditions:
4617         * assert_is_leaf_string(@string)
4618
4619     Returns:
4620         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4621     """
4622     assert_is_leaf_string(string)
4623
4624     prefix = ""
4625     prefix_idx = 0
4626     while string[prefix_idx] in STRING_PREFIX_CHARS:
4627         prefix += string[prefix_idx].lower()
4628         prefix_idx += 1
4629
4630     return prefix
4631
4632
4633 def assert_is_leaf_string(string: str) -> None:
4634     """
4635     Checks the pre-condition that @string has the format that you would expect
4636     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4637     token.STRING`. A more precise description of the pre-conditions that are
4638     checked are listed below.
4639
4640     Pre-conditions:
4641         * @string starts with either ', ", <prefix>', or <prefix>" where
4642         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4643         * @string ends with a quote character (' or ").
4644
4645     Raises:
4646         AssertionError(...) if the pre-conditions listed above are not
4647         satisfied.
4648     """
4649     dquote_idx = string.find('"')
4650     squote_idx = string.find("'")
4651     if -1 in [dquote_idx, squote_idx]:
4652         quote_idx = max(dquote_idx, squote_idx)
4653     else:
4654         quote_idx = min(squote_idx, dquote_idx)
4655
4656     assert (
4657         0 <= quote_idx < len(string) - 1
4658     ), f"{string!r} is missing a starting quote character (' or \")."
4659     assert string[-1] in (
4660         "'",
4661         '"',
4662     ), f"{string!r} is missing an ending quote character (' or \")."
4663     assert set(string[:quote_idx]).issubset(
4664         set(STRING_PREFIX_CHARS)
4665     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4666
4667
4668 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4669     """Split line into many lines, starting with the first matching bracket pair.
4670
4671     Note: this usually looks weird, only use this for function definitions.
4672     Prefer RHS otherwise.  This is why this function is not symmetrical with
4673     :func:`right_hand_split` which also handles optional parentheses.
4674     """
4675     tail_leaves: List[Leaf] = []
4676     body_leaves: List[Leaf] = []
4677     head_leaves: List[Leaf] = []
4678     current_leaves = head_leaves
4679     matching_bracket: Optional[Leaf] = None
4680     for leaf in line.leaves:
4681         if (
4682             current_leaves is body_leaves
4683             and leaf.type in CLOSING_BRACKETS
4684             and leaf.opening_bracket is matching_bracket
4685         ):
4686             current_leaves = tail_leaves if body_leaves else head_leaves
4687         current_leaves.append(leaf)
4688         if current_leaves is head_leaves:
4689             if leaf.type in OPENING_BRACKETS:
4690                 matching_bracket = leaf
4691                 current_leaves = body_leaves
4692     if not matching_bracket:
4693         raise CannotSplit("No brackets found")
4694
4695     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4696     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4697     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4698     bracket_split_succeeded_or_raise(head, body, tail)
4699     for result in (head, body, tail):
4700         if result:
4701             yield result
4702
4703
4704 def right_hand_split(
4705     line: Line,
4706     line_length: int,
4707     features: Collection[Feature] = (),
4708     omit: Collection[LeafID] = (),
4709 ) -> Iterator[Line]:
4710     """Split line into many lines, starting with the last matching bracket pair.
4711
4712     If the split was by optional parentheses, attempt splitting without them, too.
4713     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4714     this split.
4715
4716     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4717     """
4718     tail_leaves: List[Leaf] = []
4719     body_leaves: List[Leaf] = []
4720     head_leaves: List[Leaf] = []
4721     current_leaves = tail_leaves
4722     opening_bracket: Optional[Leaf] = None
4723     closing_bracket: Optional[Leaf] = None
4724     for leaf in reversed(line.leaves):
4725         if current_leaves is body_leaves:
4726             if leaf is opening_bracket:
4727                 current_leaves = head_leaves if body_leaves else tail_leaves
4728         current_leaves.append(leaf)
4729         if current_leaves is tail_leaves:
4730             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4731                 opening_bracket = leaf.opening_bracket
4732                 closing_bracket = leaf
4733                 current_leaves = body_leaves
4734     if not (opening_bracket and closing_bracket and head_leaves):
4735         # If there is no opening or closing_bracket that means the split failed and
4736         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4737         # the matching `opening_bracket` wasn't available on `line` anymore.
4738         raise CannotSplit("No brackets found")
4739
4740     tail_leaves.reverse()
4741     body_leaves.reverse()
4742     head_leaves.reverse()
4743     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4744     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4745     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4746     bracket_split_succeeded_or_raise(head, body, tail)
4747     if (
4748         # the body shouldn't be exploded
4749         not body.should_explode
4750         # the opening bracket is an optional paren
4751         and opening_bracket.type == token.LPAR
4752         and not opening_bracket.value
4753         # the closing bracket is an optional paren
4754         and closing_bracket.type == token.RPAR
4755         and not closing_bracket.value
4756         # it's not an import (optional parens are the only thing we can split on
4757         # in this case; attempting a split without them is a waste of time)
4758         and not line.is_import
4759         # there are no standalone comments in the body
4760         and not body.contains_standalone_comments(0)
4761         # and we can actually remove the parens
4762         and can_omit_invisible_parens(body, line_length)
4763     ):
4764         omit = {id(closing_bracket), *omit}
4765         try:
4766             yield from right_hand_split(line, line_length, features=features, omit=omit)
4767             return
4768
4769         except CannotSplit:
4770             if not (
4771                 can_be_split(body)
4772                 or is_line_short_enough(body, line_length=line_length)
4773             ):
4774                 raise CannotSplit(
4775                     "Splitting failed, body is still too long and can't be split."
4776                 )
4777
4778             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4779                 raise CannotSplit(
4780                     "The current optional pair of parentheses is bound to fail to"
4781                     " satisfy the splitting algorithm because the head or the tail"
4782                     " contains multiline strings which by definition never fit one"
4783                     " line."
4784                 )
4785
4786     ensure_visible(opening_bracket)
4787     ensure_visible(closing_bracket)
4788     for result in (head, body, tail):
4789         if result:
4790             yield result
4791
4792
4793 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4794     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4795
4796     Do nothing otherwise.
4797
4798     A left- or right-hand split is based on a pair of brackets. Content before
4799     (and including) the opening bracket is left on one line, content inside the
4800     brackets is put on a separate line, and finally content starting with and
4801     following the closing bracket is put on a separate line.
4802
4803     Those are called `head`, `body`, and `tail`, respectively. If the split
4804     produced the same line (all content in `head`) or ended up with an empty `body`
4805     and the `tail` is just the closing bracket, then it's considered failed.
4806     """
4807     tail_len = len(str(tail).strip())
4808     if not body:
4809         if tail_len == 0:
4810             raise CannotSplit("Splitting brackets produced the same line")
4811
4812         elif tail_len < 3:
4813             raise CannotSplit(
4814                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4815                 " not worth it"
4816             )
4817
4818
4819 def bracket_split_build_line(
4820     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4821 ) -> Line:
4822     """Return a new line with given `leaves` and respective comments from `original`.
4823
4824     If `is_body` is True, the result line is one-indented inside brackets and as such
4825     has its first leaf's prefix normalized and a trailing comma added when expected.
4826     """
4827     result = Line(depth=original.depth)
4828     if is_body:
4829         result.inside_brackets = True
4830         result.depth += 1
4831         if leaves:
4832             # Since body is a new indent level, remove spurious leading whitespace.
4833             normalize_prefix(leaves[0], inside_brackets=True)
4834             # Ensure a trailing comma for imports and standalone function arguments, but
4835             # be careful not to add one after any comments or within type annotations.
4836             no_commas = (
4837                 original.is_def
4838                 and opening_bracket.value == "("
4839                 and not any(l.type == token.COMMA for l in leaves)
4840             )
4841
4842             if original.is_import or no_commas:
4843                 for i in range(len(leaves) - 1, -1, -1):
4844                     if leaves[i].type == STANDALONE_COMMENT:
4845                         continue
4846
4847                     if leaves[i].type != token.COMMA:
4848                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
4849                     break
4850
4851     # Populate the line
4852     for leaf in leaves:
4853         result.append(leaf, preformatted=True)
4854         for comment_after in original.comments_after(leaf):
4855             result.append(comment_after, preformatted=True)
4856     if is_body:
4857         result.should_explode = should_explode(result, opening_bracket)
4858     return result
4859
4860
4861 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4862     """Normalize prefix of the first leaf in every line returned by `split_func`.
4863
4864     This is a decorator over relevant split functions.
4865     """
4866
4867     @wraps(split_func)
4868     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4869         for l in split_func(line, features):
4870             normalize_prefix(l.leaves[0], inside_brackets=True)
4871             yield l
4872
4873     return split_wrapper
4874
4875
4876 @dont_increase_indentation
4877 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4878     """Split according to delimiters of the highest priority.
4879
4880     If the appropriate Features are given, the split will add trailing commas
4881     also in function signatures and calls that contain `*` and `**`.
4882     """
4883     try:
4884         last_leaf = line.leaves[-1]
4885     except IndexError:
4886         raise CannotSplit("Line empty")
4887
4888     bt = line.bracket_tracker
4889     try:
4890         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4891     except ValueError:
4892         raise CannotSplit("No delimiters found")
4893
4894     if delimiter_priority == DOT_PRIORITY:
4895         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4896             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4897
4898     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4899     lowest_depth = sys.maxsize
4900     trailing_comma_safe = True
4901
4902     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4903         """Append `leaf` to current line or to new line if appending impossible."""
4904         nonlocal current_line
4905         try:
4906             current_line.append_safe(leaf, preformatted=True)
4907         except ValueError:
4908             yield current_line
4909
4910             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4911             current_line.append(leaf)
4912
4913     for leaf in line.leaves:
4914         yield from append_to_line(leaf)
4915
4916         for comment_after in line.comments_after(leaf):
4917             yield from append_to_line(comment_after)
4918
4919         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4920         if leaf.bracket_depth == lowest_depth:
4921             if is_vararg(leaf, within={syms.typedargslist}):
4922                 trailing_comma_safe = (
4923                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4924                 )
4925             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4926                 trailing_comma_safe = (
4927                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4928                 )
4929
4930         leaf_priority = bt.delimiters.get(id(leaf))
4931         if leaf_priority == delimiter_priority:
4932             yield current_line
4933
4934             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4935     if current_line:
4936         if (
4937             trailing_comma_safe
4938             and delimiter_priority == COMMA_PRIORITY
4939             and current_line.leaves[-1].type != token.COMMA
4940             and current_line.leaves[-1].type != STANDALONE_COMMENT
4941         ):
4942             current_line.append(Leaf(token.COMMA, ","))
4943         yield current_line
4944
4945
4946 @dont_increase_indentation
4947 def standalone_comment_split(
4948     line: Line, features: Collection[Feature] = ()
4949 ) -> Iterator[Line]:
4950     """Split standalone comments from the rest of the line."""
4951     if not line.contains_standalone_comments(0):
4952         raise CannotSplit("Line does not have any standalone comments")
4953
4954     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4955
4956     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4957         """Append `leaf` to current line or to new line if appending impossible."""
4958         nonlocal current_line
4959         try:
4960             current_line.append_safe(leaf, preformatted=True)
4961         except ValueError:
4962             yield current_line
4963
4964             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4965             current_line.append(leaf)
4966
4967     for leaf in line.leaves:
4968         yield from append_to_line(leaf)
4969
4970         for comment_after in line.comments_after(leaf):
4971             yield from append_to_line(comment_after)
4972
4973     if current_line:
4974         yield current_line
4975
4976
4977 def is_import(leaf: Leaf) -> bool:
4978     """Return True if the given leaf starts an import statement."""
4979     p = leaf.parent
4980     t = leaf.type
4981     v = leaf.value
4982     return bool(
4983         t == token.NAME
4984         and (
4985             (v == "import" and p and p.type == syms.import_name)
4986             or (v == "from" and p and p.type == syms.import_from)
4987         )
4988     )
4989
4990
4991 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
4992     """Return True if the given leaf is a special comment.
4993     Only returns true for type comments for now."""
4994     t = leaf.type
4995     v = leaf.value
4996     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
4997
4998
4999 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5000     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5001     else.
5002
5003     Note: don't use backslashes for formatting or you'll lose your voting rights.
5004     """
5005     if not inside_brackets:
5006         spl = leaf.prefix.split("#")
5007         if "\\" not in spl[0]:
5008             nl_count = spl[-1].count("\n")
5009             if len(spl) > 1:
5010                 nl_count -= 1
5011             leaf.prefix = "\n" * nl_count
5012             return
5013
5014     leaf.prefix = ""
5015
5016
5017 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5018     """Make all string prefixes lowercase.
5019
5020     If remove_u_prefix is given, also removes any u prefix from the string.
5021
5022     Note: Mutates its argument.
5023     """
5024     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5025     assert match is not None, f"failed to match string {leaf.value!r}"
5026     orig_prefix = match.group(1)
5027     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5028     if remove_u_prefix:
5029         new_prefix = new_prefix.replace("u", "")
5030     leaf.value = f"{new_prefix}{match.group(2)}"
5031
5032
5033 def normalize_string_quotes(leaf: Leaf) -> None:
5034     """Prefer double quotes but only if it doesn't cause more escaping.
5035
5036     Adds or removes backslashes as appropriate. Doesn't parse and fix
5037     strings nested in f-strings (yet).
5038
5039     Note: Mutates its argument.
5040     """
5041     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5042     if value[:3] == '"""':
5043         return
5044
5045     elif value[:3] == "'''":
5046         orig_quote = "'''"
5047         new_quote = '"""'
5048     elif value[0] == '"':
5049         orig_quote = '"'
5050         new_quote = "'"
5051     else:
5052         orig_quote = "'"
5053         new_quote = '"'
5054     first_quote_pos = leaf.value.find(orig_quote)
5055     if first_quote_pos == -1:
5056         return  # There's an internal error
5057
5058     prefix = leaf.value[:first_quote_pos]
5059     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5060     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5061     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5062     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5063     if "r" in prefix.casefold():
5064         if unescaped_new_quote.search(body):
5065             # There's at least one unescaped new_quote in this raw string
5066             # so converting is impossible
5067             return
5068
5069         # Do not introduce or remove backslashes in raw strings
5070         new_body = body
5071     else:
5072         # remove unnecessary escapes
5073         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5074         if body != new_body:
5075             # Consider the string without unnecessary escapes as the original
5076             body = new_body
5077             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5078         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5079         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5080     if "f" in prefix.casefold():
5081         matches = re.findall(
5082             r"""
5083             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5084                 ([^{].*?)  # contents of the brackets except if begins with {{
5085             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5086             """,
5087             new_body,
5088             re.VERBOSE,
5089         )
5090         for m in matches:
5091             if "\\" in str(m):
5092                 # Do not introduce backslashes in interpolated expressions
5093                 return
5094
5095     if new_quote == '"""' and new_body[-1:] == '"':
5096         # edge case:
5097         new_body = new_body[:-1] + '\\"'
5098     orig_escape_count = body.count("\\")
5099     new_escape_count = new_body.count("\\")
5100     if new_escape_count > orig_escape_count:
5101         return  # Do not introduce more escaping
5102
5103     if new_escape_count == orig_escape_count and orig_quote == '"':
5104         return  # Prefer double quotes
5105
5106     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5107
5108
5109 def normalize_numeric_literal(leaf: Leaf) -> None:
5110     """Normalizes numeric (float, int, and complex) literals.
5111
5112     All letters used in the representation are normalized to lowercase (except
5113     in Python 2 long literals).
5114     """
5115     text = leaf.value.lower()
5116     if text.startswith(("0o", "0b")):
5117         # Leave octal and binary literals alone.
5118         pass
5119     elif text.startswith("0x"):
5120         # Change hex literals to upper case.
5121         before, after = text[:2], text[2:]
5122         text = f"{before}{after.upper()}"
5123     elif "e" in text:
5124         before, after = text.split("e")
5125         sign = ""
5126         if after.startswith("-"):
5127             after = after[1:]
5128             sign = "-"
5129         elif after.startswith("+"):
5130             after = after[1:]
5131         before = format_float_or_int_string(before)
5132         text = f"{before}e{sign}{after}"
5133     elif text.endswith(("j", "l")):
5134         number = text[:-1]
5135         suffix = text[-1]
5136         # Capitalize in "2L" because "l" looks too similar to "1".
5137         if suffix == "l":
5138             suffix = "L"
5139         text = f"{format_float_or_int_string(number)}{suffix}"
5140     else:
5141         text = format_float_or_int_string(text)
5142     leaf.value = text
5143
5144
5145 def format_float_or_int_string(text: str) -> str:
5146     """Formats a float string like "1.0"."""
5147     if "." not in text:
5148         return text
5149
5150     before, after = text.split(".")
5151     return f"{before or 0}.{after or 0}"
5152
5153
5154 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5155     """Make existing optional parentheses invisible or create new ones.
5156
5157     `parens_after` is a set of string leaf values immediately after which parens
5158     should be put.
5159
5160     Standardizes on visible parentheses for single-element tuples, and keeps
5161     existing visible parentheses for other tuples and generator expressions.
5162     """
5163     for pc in list_comments(node.prefix, is_endmarker=False):
5164         if pc.value in FMT_OFF:
5165             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5166             return
5167     check_lpar = False
5168     for index, child in enumerate(list(node.children)):
5169         # Fixes a bug where invisible parens are not properly stripped from
5170         # assignment statements that contain type annotations.
5171         if isinstance(child, Node) and child.type == syms.annassign:
5172             normalize_invisible_parens(child, parens_after=parens_after)
5173
5174         # Add parentheses around long tuple unpacking in assignments.
5175         if (
5176             index == 0
5177             and isinstance(child, Node)
5178             and child.type == syms.testlist_star_expr
5179         ):
5180             check_lpar = True
5181
5182         if check_lpar:
5183             if is_walrus_assignment(child):
5184                 continue
5185
5186             if child.type == syms.atom:
5187                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5188                     wrap_in_parentheses(node, child, visible=False)
5189             elif is_one_tuple(child):
5190                 wrap_in_parentheses(node, child, visible=True)
5191             elif node.type == syms.import_from:
5192                 # "import from" nodes store parentheses directly as part of
5193                 # the statement
5194                 if child.type == token.LPAR:
5195                     # make parentheses invisible
5196                     child.value = ""  # type: ignore
5197                     node.children[-1].value = ""  # type: ignore
5198                 elif child.type != token.STAR:
5199                     # insert invisible parentheses
5200                     node.insert_child(index, Leaf(token.LPAR, ""))
5201                     node.append_child(Leaf(token.RPAR, ""))
5202                 break
5203
5204             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5205                 wrap_in_parentheses(node, child, visible=False)
5206
5207         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5208
5209
5210 def normalize_fmt_off(node: Node) -> None:
5211     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5212     try_again = True
5213     while try_again:
5214         try_again = convert_one_fmt_off_pair(node)
5215
5216
5217 def convert_one_fmt_off_pair(node: Node) -> bool:
5218     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5219
5220     Returns True if a pair was converted.
5221     """
5222     for leaf in node.leaves():
5223         previous_consumed = 0
5224         for comment in list_comments(leaf.prefix, is_endmarker=False):
5225             if comment.value in FMT_OFF:
5226                 # We only want standalone comments. If there's no previous leaf or
5227                 # the previous leaf is indentation, it's a standalone comment in
5228                 # disguise.
5229                 if comment.type != STANDALONE_COMMENT:
5230                     prev = preceding_leaf(leaf)
5231                     if prev and prev.type not in WHITESPACE:
5232                         continue
5233
5234                 ignored_nodes = list(generate_ignored_nodes(leaf))
5235                 if not ignored_nodes:
5236                     continue
5237
5238                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5239                 parent = first.parent
5240                 prefix = first.prefix
5241                 first.prefix = prefix[comment.consumed :]
5242                 hidden_value = (
5243                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5244                 )
5245                 if hidden_value.endswith("\n"):
5246                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5247                     # leaf (possibly followed by a DEDENT).
5248                     hidden_value = hidden_value[:-1]
5249                 first_idx: Optional[int] = None
5250                 for ignored in ignored_nodes:
5251                     index = ignored.remove()
5252                     if first_idx is None:
5253                         first_idx = index
5254                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5255                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5256                 parent.insert_child(
5257                     first_idx,
5258                     Leaf(
5259                         STANDALONE_COMMENT,
5260                         hidden_value,
5261                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5262                     ),
5263                 )
5264                 return True
5265
5266             previous_consumed = comment.consumed
5267
5268     return False
5269
5270
5271 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5272     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5273
5274     Stops at the end of the block.
5275     """
5276     container: Optional[LN] = container_of(leaf)
5277     while container is not None and container.type != token.ENDMARKER:
5278         if is_fmt_on(container):
5279             return
5280
5281         # fix for fmt: on in children
5282         if contains_fmt_on_at_column(container, leaf.column):
5283             for child in container.children:
5284                 if contains_fmt_on_at_column(child, leaf.column):
5285                     return
5286                 yield child
5287         else:
5288             yield container
5289             container = container.next_sibling
5290
5291
5292 def is_fmt_on(container: LN) -> bool:
5293     """Determine whether formatting is switched on within a container.
5294     Determined by whether the last `# fmt:` comment is `on` or `off`.
5295     """
5296     fmt_on = False
5297     for comment in list_comments(container.prefix, is_endmarker=False):
5298         if comment.value in FMT_ON:
5299             fmt_on = True
5300         elif comment.value in FMT_OFF:
5301             fmt_on = False
5302     return fmt_on
5303
5304
5305 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5306     """Determine if children at a given column have formatting switched on."""
5307     for child in container.children:
5308         if (
5309             isinstance(child, Node)
5310             and first_leaf_column(child) == column
5311             or isinstance(child, Leaf)
5312             and child.column == column
5313         ):
5314             if is_fmt_on(child):
5315                 return True
5316
5317     return False
5318
5319
5320 def first_leaf_column(node: Node) -> Optional[int]:
5321     """Returns the column of the first leaf child of a node."""
5322     for child in node.children:
5323         if isinstance(child, Leaf):
5324             return child.column
5325     return None
5326
5327
5328 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5329     """If it's safe, make the parens in the atom `node` invisible, recursively.
5330     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5331     as they are redundant.
5332
5333     Returns whether the node should itself be wrapped in invisible parentheses.
5334
5335     """
5336     if (
5337         node.type != syms.atom
5338         or is_empty_tuple(node)
5339         or is_one_tuple(node)
5340         or (is_yield(node) and parent.type != syms.expr_stmt)
5341         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5342     ):
5343         return False
5344
5345     first = node.children[0]
5346     last = node.children[-1]
5347     if first.type == token.LPAR and last.type == token.RPAR:
5348         middle = node.children[1]
5349         # make parentheses invisible
5350         first.value = ""  # type: ignore
5351         last.value = ""  # type: ignore
5352         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5353
5354         if is_atom_with_invisible_parens(middle):
5355             # Strip the invisible parens from `middle` by replacing
5356             # it with the child in-between the invisible parens
5357             middle.replace(middle.children[1])
5358
5359         return False
5360
5361     return True
5362
5363
5364 def is_atom_with_invisible_parens(node: LN) -> bool:
5365     """Given a `LN`, determines whether it's an atom `node` with invisible
5366     parens. Useful in dedupe-ing and normalizing parens.
5367     """
5368     if isinstance(node, Leaf) or node.type != syms.atom:
5369         return False
5370
5371     first, last = node.children[0], node.children[-1]
5372     return (
5373         isinstance(first, Leaf)
5374         and first.type == token.LPAR
5375         and first.value == ""
5376         and isinstance(last, Leaf)
5377         and last.type == token.RPAR
5378         and last.value == ""
5379     )
5380
5381
5382 def is_empty_tuple(node: LN) -> bool:
5383     """Return True if `node` holds an empty tuple."""
5384     return (
5385         node.type == syms.atom
5386         and len(node.children) == 2
5387         and node.children[0].type == token.LPAR
5388         and node.children[1].type == token.RPAR
5389     )
5390
5391
5392 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5393     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5394
5395     Parenthesis can be optional. Returns None otherwise"""
5396     if len(node.children) != 3:
5397         return None
5398
5399     lpar, wrapped, rpar = node.children
5400     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5401         return None
5402
5403     return wrapped
5404
5405
5406 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5407     """Wrap `child` in parentheses.
5408
5409     This replaces `child` with an atom holding the parentheses and the old
5410     child.  That requires moving the prefix.
5411
5412     If `visible` is False, the leaves will be valueless (and thus invisible).
5413     """
5414     lpar = Leaf(token.LPAR, "(" if visible else "")
5415     rpar = Leaf(token.RPAR, ")" if visible else "")
5416     prefix = child.prefix
5417     child.prefix = ""
5418     index = child.remove() or 0
5419     new_child = Node(syms.atom, [lpar, child, rpar])
5420     new_child.prefix = prefix
5421     parent.insert_child(index, new_child)
5422
5423
5424 def is_one_tuple(node: LN) -> bool:
5425     """Return True if `node` holds a tuple with one element, with or without parens."""
5426     if node.type == syms.atom:
5427         gexp = unwrap_singleton_parenthesis(node)
5428         if gexp is None or gexp.type != syms.testlist_gexp:
5429             return False
5430
5431         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5432
5433     return (
5434         node.type in IMPLICIT_TUPLE
5435         and len(node.children) == 2
5436         and node.children[1].type == token.COMMA
5437     )
5438
5439
5440 def is_walrus_assignment(node: LN) -> bool:
5441     """Return True iff `node` is of the shape ( test := test )"""
5442     inner = unwrap_singleton_parenthesis(node)
5443     return inner is not None and inner.type == syms.namedexpr_test
5444
5445
5446 def is_yield(node: LN) -> bool:
5447     """Return True if `node` holds a `yield` or `yield from` expression."""
5448     if node.type == syms.yield_expr:
5449         return True
5450
5451     if node.type == token.NAME and node.value == "yield":  # type: ignore
5452         return True
5453
5454     if node.type != syms.atom:
5455         return False
5456
5457     if len(node.children) != 3:
5458         return False
5459
5460     lpar, expr, rpar = node.children
5461     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5462         return is_yield(expr)
5463
5464     return False
5465
5466
5467 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5468     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5469
5470     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5471     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5472     extended iterable unpacking (PEP 3132) and additional unpacking
5473     generalizations (PEP 448).
5474     """
5475     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5476         return False
5477
5478     p = leaf.parent
5479     if p.type == syms.star_expr:
5480         # Star expressions are also used as assignment targets in extended
5481         # iterable unpacking (PEP 3132).  See what its parent is instead.
5482         if not p.parent:
5483             return False
5484
5485         p = p.parent
5486
5487     return p.type in within
5488
5489
5490 def is_multiline_string(leaf: Leaf) -> bool:
5491     """Return True if `leaf` is a multiline string that actually spans many lines."""
5492     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5493
5494
5495 def is_stub_suite(node: Node) -> bool:
5496     """Return True if `node` is a suite with a stub body."""
5497     if (
5498         len(node.children) != 4
5499         or node.children[0].type != token.NEWLINE
5500         or node.children[1].type != token.INDENT
5501         or node.children[3].type != token.DEDENT
5502     ):
5503         return False
5504
5505     return is_stub_body(node.children[2])
5506
5507
5508 def is_stub_body(node: LN) -> bool:
5509     """Return True if `node` is a simple statement containing an ellipsis."""
5510     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5511         return False
5512
5513     if len(node.children) != 2:
5514         return False
5515
5516     child = node.children[0]
5517     return (
5518         child.type == syms.atom
5519         and len(child.children) == 3
5520         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5521     )
5522
5523
5524 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5525     """Return maximum delimiter priority inside `node`.
5526
5527     This is specific to atoms with contents contained in a pair of parentheses.
5528     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5529     """
5530     if node.type != syms.atom:
5531         return 0
5532
5533     first = node.children[0]
5534     last = node.children[-1]
5535     if not (first.type == token.LPAR and last.type == token.RPAR):
5536         return 0
5537
5538     bt = BracketTracker()
5539     for c in node.children[1:-1]:
5540         if isinstance(c, Leaf):
5541             bt.mark(c)
5542         else:
5543             for leaf in c.leaves():
5544                 bt.mark(leaf)
5545     try:
5546         return bt.max_delimiter_priority()
5547
5548     except ValueError:
5549         return 0
5550
5551
5552 def ensure_visible(leaf: Leaf) -> None:
5553     """Make sure parentheses are visible.
5554
5555     They could be invisible as part of some statements (see
5556     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5557     """
5558     if leaf.type == token.LPAR:
5559         leaf.value = "("
5560     elif leaf.type == token.RPAR:
5561         leaf.value = ")"
5562
5563
5564 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
5565     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
5566
5567     if not (
5568         opening_bracket.parent
5569         and opening_bracket.parent.type in {syms.atom, syms.import_from}
5570         and opening_bracket.value in "[{("
5571     ):
5572         return False
5573
5574     try:
5575         last_leaf = line.leaves[-1]
5576         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
5577         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5578     except (IndexError, ValueError):
5579         return False
5580
5581     return max_priority == COMMA_PRIORITY
5582
5583
5584 def get_features_used(node: Node) -> Set[Feature]:
5585     """Return a set of (relatively) new Python features used in this file.
5586
5587     Currently looking for:
5588     - f-strings;
5589     - underscores in numeric literals;
5590     - trailing commas after * or ** in function signatures and calls;
5591     - positional only arguments in function signatures and lambdas;
5592     """
5593     features: Set[Feature] = set()
5594     for n in node.pre_order():
5595         if n.type == token.STRING:
5596             value_head = n.value[:2]  # type: ignore
5597             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5598                 features.add(Feature.F_STRINGS)
5599
5600         elif n.type == token.NUMBER:
5601             if "_" in n.value:  # type: ignore
5602                 features.add(Feature.NUMERIC_UNDERSCORES)
5603
5604         elif n.type == token.SLASH:
5605             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5606                 features.add(Feature.POS_ONLY_ARGUMENTS)
5607
5608         elif n.type == token.COLONEQUAL:
5609             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5610
5611         elif (
5612             n.type in {syms.typedargslist, syms.arglist}
5613             and n.children
5614             and n.children[-1].type == token.COMMA
5615         ):
5616             if n.type == syms.typedargslist:
5617                 feature = Feature.TRAILING_COMMA_IN_DEF
5618             else:
5619                 feature = Feature.TRAILING_COMMA_IN_CALL
5620
5621             for ch in n.children:
5622                 if ch.type in STARS:
5623                     features.add(feature)
5624
5625                 if ch.type == syms.argument:
5626                     for argch in ch.children:
5627                         if argch.type in STARS:
5628                             features.add(feature)
5629
5630     return features
5631
5632
5633 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5634     """Detect the version to target based on the nodes used."""
5635     features = get_features_used(node)
5636     return {
5637         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5638     }
5639
5640
5641 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5642     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5643
5644     Brackets can be omitted if the entire trailer up to and including
5645     a preceding closing bracket fits in one line.
5646
5647     Yielded sets are cumulative (contain results of previous yields, too).  First
5648     set is empty.
5649     """
5650
5651     omit: Set[LeafID] = set()
5652     yield omit
5653
5654     length = 4 * line.depth
5655     opening_bracket: Optional[Leaf] = None
5656     closing_bracket: Optional[Leaf] = None
5657     inner_brackets: Set[LeafID] = set()
5658     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5659         length += leaf_length
5660         if length > line_length:
5661             break
5662
5663         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5664         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5665             break
5666
5667         if opening_bracket:
5668             if leaf is opening_bracket:
5669                 opening_bracket = None
5670             elif leaf.type in CLOSING_BRACKETS:
5671                 inner_brackets.add(id(leaf))
5672         elif leaf.type in CLOSING_BRACKETS:
5673             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
5674                 # Empty brackets would fail a split so treat them as "inner"
5675                 # brackets (e.g. only add them to the `omit` set if another
5676                 # pair of brackets was good enough.
5677                 inner_brackets.add(id(leaf))
5678                 continue
5679
5680             if closing_bracket:
5681                 omit.add(id(closing_bracket))
5682                 omit.update(inner_brackets)
5683                 inner_brackets.clear()
5684                 yield omit
5685
5686             if leaf.value:
5687                 opening_bracket = leaf.opening_bracket
5688                 closing_bracket = leaf
5689
5690
5691 def get_future_imports(node: Node) -> Set[str]:
5692     """Return a set of __future__ imports in the file."""
5693     imports: Set[str] = set()
5694
5695     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5696         for child in children:
5697             if isinstance(child, Leaf):
5698                 if child.type == token.NAME:
5699                     yield child.value
5700
5701             elif child.type == syms.import_as_name:
5702                 orig_name = child.children[0]
5703                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5704                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5705                 yield orig_name.value
5706
5707             elif child.type == syms.import_as_names:
5708                 yield from get_imports_from_children(child.children)
5709
5710             else:
5711                 raise AssertionError("Invalid syntax parsing imports")
5712
5713     for child in node.children:
5714         if child.type != syms.simple_stmt:
5715             break
5716
5717         first_child = child.children[0]
5718         if isinstance(first_child, Leaf):
5719             # Continue looking if we see a docstring; otherwise stop.
5720             if (
5721                 len(child.children) == 2
5722                 and first_child.type == token.STRING
5723                 and child.children[1].type == token.NEWLINE
5724             ):
5725                 continue
5726
5727             break
5728
5729         elif first_child.type == syms.import_from:
5730             module_name = first_child.children[1]
5731             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5732                 break
5733
5734             imports |= set(get_imports_from_children(first_child.children[3:]))
5735         else:
5736             break
5737
5738     return imports
5739
5740
5741 @lru_cache()
5742 def get_gitignore(root: Path) -> PathSpec:
5743     """ Return a PathSpec matching gitignore content if present."""
5744     gitignore = root / ".gitignore"
5745     lines: List[str] = []
5746     if gitignore.is_file():
5747         with gitignore.open() as gf:
5748             lines = gf.readlines()
5749     return PathSpec.from_lines("gitwildmatch", lines)
5750
5751
5752 def gen_python_files(
5753     paths: Iterable[Path],
5754     root: Path,
5755     include: Optional[Pattern[str]],
5756     exclude_regexes: Iterable[Pattern[str]],
5757     report: "Report",
5758     gitignore: PathSpec,
5759 ) -> Iterator[Path]:
5760     """Generate all files under `path` whose paths are not excluded by the
5761     `exclude` regex, but are included by the `include` regex.
5762
5763     Symbolic links pointing outside of the `root` directory are ignored.
5764
5765     `report` is where output about exclusions goes.
5766     """
5767     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5768     for child in paths:
5769         # Then ignore with `exclude` option.
5770         try:
5771             normalized_path = child.resolve().relative_to(root).as_posix()
5772         except OSError as e:
5773             report.path_ignored(child, f"cannot be read because {e}")
5774             continue
5775         except ValueError:
5776             if child.is_symlink():
5777                 report.path_ignored(
5778                     child, f"is a symbolic link that points outside {root}"
5779                 )
5780                 continue
5781
5782             raise
5783
5784         # First ignore files matching .gitignore
5785         if gitignore.match_file(normalized_path):
5786             report.path_ignored(child, "matches the .gitignore file content")
5787             continue
5788
5789         normalized_path = "/" + normalized_path
5790         if child.is_dir():
5791             normalized_path += "/"
5792
5793         is_excluded = False
5794         for exclude in exclude_regexes:
5795             exclude_match = exclude.search(normalized_path) if exclude else None
5796             if exclude_match and exclude_match.group(0):
5797                 report.path_ignored(child, "matches the --exclude regular expression")
5798                 is_excluded = True
5799                 break
5800         if is_excluded:
5801             continue
5802
5803         if child.is_dir():
5804             yield from gen_python_files(
5805                 child.iterdir(), root, include, exclude_regexes, report, gitignore
5806             )
5807
5808         elif child.is_file():
5809             include_match = include.search(normalized_path) if include else True
5810             if include_match:
5811                 yield child
5812
5813
5814 @lru_cache()
5815 def find_project_root(srcs: Iterable[str]) -> Path:
5816     """Return a directory containing .git, .hg, or pyproject.toml.
5817
5818     That directory can be one of the directories passed in `srcs` or their
5819     common parent.
5820
5821     If no directory in the tree contains a marker that would specify it's the
5822     project root, the root of the file system is returned.
5823     """
5824     if not srcs:
5825         return Path("/").resolve()
5826
5827     common_base = min(Path(src).resolve() for src in srcs)
5828     if common_base.is_dir():
5829         # Append a fake file so `parents` below returns `common_base_dir`, too.
5830         common_base /= "fake-file"
5831     for directory in common_base.parents:
5832         if (directory / ".git").exists():
5833             return directory
5834
5835         if (directory / ".hg").is_dir():
5836             return directory
5837
5838         if (directory / "pyproject.toml").is_file():
5839             return directory
5840
5841     return directory
5842
5843
5844 @dataclass
5845 class Report:
5846     """Provides a reformatting counter. Can be rendered with `str(report)`."""
5847
5848     check: bool = False
5849     diff: bool = False
5850     quiet: bool = False
5851     verbose: bool = False
5852     change_count: int = 0
5853     same_count: int = 0
5854     failure_count: int = 0
5855
5856     def done(self, src: Path, changed: Changed) -> None:
5857         """Increment the counter for successful reformatting. Write out a message."""
5858         if changed is Changed.YES:
5859             reformatted = "would reformat" if self.check or self.diff else "reformatted"
5860             if self.verbose or not self.quiet:
5861                 out(f"{reformatted} {src}")
5862             self.change_count += 1
5863         else:
5864             if self.verbose:
5865                 if changed is Changed.NO:
5866                     msg = f"{src} already well formatted, good job."
5867                 else:
5868                     msg = f"{src} wasn't modified on disk since last run."
5869                 out(msg, bold=False)
5870             self.same_count += 1
5871
5872     def failed(self, src: Path, message: str) -> None:
5873         """Increment the counter for failed reformatting. Write out a message."""
5874         err(f"error: cannot format {src}: {message}")
5875         self.failure_count += 1
5876
5877     def path_ignored(self, path: Path, message: str) -> None:
5878         if self.verbose:
5879             out(f"{path} ignored: {message}", bold=False)
5880
5881     @property
5882     def return_code(self) -> int:
5883         """Return the exit code that the app should use.
5884
5885         This considers the current state of changed files and failures:
5886         - if there were any failures, return 123;
5887         - if any files were changed and --check is being used, return 1;
5888         - otherwise return 0.
5889         """
5890         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
5891         # 126 we have special return codes reserved by the shell.
5892         if self.failure_count:
5893             return 123
5894
5895         elif self.change_count and self.check:
5896             return 1
5897
5898         return 0
5899
5900     def __str__(self) -> str:
5901         """Render a color report of the current state.
5902
5903         Use `click.unstyle` to remove colors.
5904         """
5905         if self.check or self.diff:
5906             reformatted = "would be reformatted"
5907             unchanged = "would be left unchanged"
5908             failed = "would fail to reformat"
5909         else:
5910             reformatted = "reformatted"
5911             unchanged = "left unchanged"
5912             failed = "failed to reformat"
5913         report = []
5914         if self.change_count:
5915             s = "s" if self.change_count > 1 else ""
5916             report.append(
5917                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
5918             )
5919         if self.same_count:
5920             s = "s" if self.same_count > 1 else ""
5921             report.append(f"{self.same_count} file{s} {unchanged}")
5922         if self.failure_count:
5923             s = "s" if self.failure_count > 1 else ""
5924             report.append(
5925                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
5926             )
5927         return ", ".join(report) + "."
5928
5929
5930 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
5931     filename = "<unknown>"
5932     if sys.version_info >= (3, 8):
5933         # TODO: support Python 4+ ;)
5934         for minor_version in range(sys.version_info[1], 4, -1):
5935             try:
5936                 return ast.parse(src, filename, feature_version=(3, minor_version))
5937             except SyntaxError:
5938                 continue
5939     else:
5940         for feature_version in (7, 6):
5941             try:
5942                 return ast3.parse(src, filename, feature_version=feature_version)
5943             except SyntaxError:
5944                 continue
5945
5946     return ast27.parse(src)
5947
5948
5949 def _fixup_ast_constants(
5950     node: Union[ast.AST, ast3.AST, ast27.AST]
5951 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
5952     """Map ast nodes deprecated in 3.8 to Constant."""
5953     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
5954         return ast.Constant(value=node.s)
5955
5956     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
5957         return ast.Constant(value=node.n)
5958
5959     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
5960         return ast.Constant(value=node.value)
5961
5962     return node
5963
5964
5965 def _stringify_ast(
5966     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
5967 ) -> Iterator[str]:
5968     """Simple visitor generating strings to compare ASTs by content."""
5969
5970     node = _fixup_ast_constants(node)
5971
5972     yield f"{'  ' * depth}{node.__class__.__name__}("
5973
5974     for field in sorted(node._fields):  # noqa: F402
5975         # TypeIgnore has only one field 'lineno' which breaks this comparison
5976         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
5977         if sys.version_info >= (3, 8):
5978             type_ignore_classes += (ast.TypeIgnore,)
5979         if isinstance(node, type_ignore_classes):
5980             break
5981
5982         try:
5983             value = getattr(node, field)
5984         except AttributeError:
5985             continue
5986
5987         yield f"{'  ' * (depth+1)}{field}="
5988
5989         if isinstance(value, list):
5990             for item in value:
5991                 # Ignore nested tuples within del statements, because we may insert
5992                 # parentheses and they change the AST.
5993                 if (
5994                     field == "targets"
5995                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
5996                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
5997                 ):
5998                     for item in item.elts:
5999                         yield from _stringify_ast(item, depth + 2)
6000
6001                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6002                     yield from _stringify_ast(item, depth + 2)
6003
6004         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6005             yield from _stringify_ast(value, depth + 2)
6006
6007         else:
6008             # Constant strings may be indented across newlines, if they are
6009             # docstrings; fold spaces after newlines when comparing. Similarly,
6010             # trailing and leading space may be removed.
6011             if (
6012                 isinstance(node, ast.Constant)
6013                 and field == "value"
6014                 and isinstance(value, str)
6015             ):
6016                 normalized = re.sub(r" *\n[ \t]+", "\n ", value).strip()
6017             else:
6018                 normalized = value
6019             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6020
6021     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6022
6023
6024 def assert_equivalent(src: str, dst: str) -> None:
6025     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6026     try:
6027         src_ast = parse_ast(src)
6028     except Exception as exc:
6029         raise AssertionError(
6030             "cannot use --safe with this file; failed to parse source file.  AST"
6031             f" error message: {exc}"
6032         )
6033
6034     try:
6035         dst_ast = parse_ast(dst)
6036     except Exception as exc:
6037         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6038         raise AssertionError(
6039             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6040             " on https://github.com/psf/black/issues.  This invalid output might be"
6041             f" helpful: {log}"
6042         ) from None
6043
6044     src_ast_str = "\n".join(_stringify_ast(src_ast))
6045     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6046     if src_ast_str != dst_ast_str:
6047         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6048         raise AssertionError(
6049             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6050             " source.  Please report a bug on https://github.com/psf/black/issues. "
6051             f" This diff might be helpful: {log}"
6052         ) from None
6053
6054
6055 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6056     """Raise AssertionError if `dst` reformats differently the second time."""
6057     newdst = format_str(dst, mode=mode)
6058     if dst != newdst:
6059         log = dump_to_file(
6060             diff(src, dst, "source", "first pass"),
6061             diff(dst, newdst, "first pass", "second pass"),
6062         )
6063         raise AssertionError(
6064             "INTERNAL ERROR: Black produced different code on the second pass of the"
6065             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6066             f"  This diff might be helpful: {log}"
6067         ) from None
6068
6069
6070 @mypyc_attr(patchable=True)
6071 def dump_to_file(*output: str) -> str:
6072     """Dump `output` to a temporary file. Return path to the file."""
6073     with tempfile.NamedTemporaryFile(
6074         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6075     ) as f:
6076         for lines in output:
6077             f.write(lines)
6078             if lines and lines[-1] != "\n":
6079                 f.write("\n")
6080     return f.name
6081
6082
6083 @contextmanager
6084 def nullcontext() -> Iterator[None]:
6085     """Return an empty context manager.
6086
6087     To be used like `nullcontext` in Python 3.7.
6088     """
6089     yield
6090
6091
6092 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6093     """Return a unified diff string between strings `a` and `b`."""
6094     import difflib
6095
6096     a_lines = [line + "\n" for line in a.splitlines()]
6097     b_lines = [line + "\n" for line in b.splitlines()]
6098     return "".join(
6099         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6100     )
6101
6102
6103 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6104     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6105     err("Aborted!")
6106     for task in tasks:
6107         task.cancel()
6108
6109
6110 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6111     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6112     try:
6113         if sys.version_info[:2] >= (3, 7):
6114             all_tasks = asyncio.all_tasks
6115         else:
6116             all_tasks = asyncio.Task.all_tasks
6117         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6118         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6119         if not to_cancel:
6120             return
6121
6122         for task in to_cancel:
6123             task.cancel()
6124         loop.run_until_complete(
6125             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6126         )
6127     finally:
6128         # `concurrent.futures.Future` objects cannot be cancelled once they
6129         # are already running. There might be some when the `shutdown()` happened.
6130         # Silence their logger's spew about the event loop being closed.
6131         cf_logger = logging.getLogger("concurrent.futures")
6132         cf_logger.setLevel(logging.CRITICAL)
6133         loop.close()
6134
6135
6136 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6137     """Replace `regex` with `replacement` twice on `original`.
6138
6139     This is used by string normalization to perform replaces on
6140     overlapping matches.
6141     """
6142     return regex.sub(replacement, regex.sub(replacement, original))
6143
6144
6145 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6146     """Compile a regular expression string in `regex`.
6147
6148     If it contains newlines, use verbose mode.
6149     """
6150     if "\n" in regex:
6151         regex = "(?x)" + regex
6152     compiled: Pattern[str] = re.compile(regex)
6153     return compiled
6154
6155
6156 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6157     """Like `reversed(enumerate(sequence))` if that were possible."""
6158     index = len(sequence) - 1
6159     for element in reversed(sequence):
6160         yield (index, element)
6161         index -= 1
6162
6163
6164 def enumerate_with_length(
6165     line: Line, reversed: bool = False
6166 ) -> Iterator[Tuple[Index, Leaf, int]]:
6167     """Return an enumeration of leaves with their length.
6168
6169     Stops prematurely on multiline strings and standalone comments.
6170     """
6171     op = cast(
6172         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6173         enumerate_reversed if reversed else enumerate,
6174     )
6175     for index, leaf in op(line.leaves):
6176         length = len(leaf.prefix) + len(leaf.value)
6177         if "\n" in leaf.value:
6178             return  # Multiline strings, we can't continue.
6179
6180         for comment in line.comments_after(leaf):
6181             length += len(comment.value)
6182
6183         yield index, leaf, length
6184
6185
6186 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6187     """Return True if `line` is no longer than `line_length`.
6188
6189     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6190     """
6191     if not line_str:
6192         line_str = line_to_string(line)
6193     return (
6194         len(line_str) <= line_length
6195         and "\n" not in line_str  # multiline strings
6196         and not line.contains_standalone_comments()
6197     )
6198
6199
6200 def can_be_split(line: Line) -> bool:
6201     """Return False if the line cannot be split *for sure*.
6202
6203     This is not an exhaustive search but a cheap heuristic that we can use to
6204     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6205     in unnecessary parentheses).
6206     """
6207     leaves = line.leaves
6208     if len(leaves) < 2:
6209         return False
6210
6211     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6212         call_count = 0
6213         dot_count = 0
6214         next = leaves[-1]
6215         for leaf in leaves[-2::-1]:
6216             if leaf.type in OPENING_BRACKETS:
6217                 if next.type not in CLOSING_BRACKETS:
6218                     return False
6219
6220                 call_count += 1
6221             elif leaf.type == token.DOT:
6222                 dot_count += 1
6223             elif leaf.type == token.NAME:
6224                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6225                     return False
6226
6227             elif leaf.type not in CLOSING_BRACKETS:
6228                 return False
6229
6230             if dot_count > 1 and call_count > 1:
6231                 return False
6232
6233     return True
6234
6235
6236 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
6237     """Does `line` have a shape safe to reformat without optional parens around it?
6238
6239     Returns True for only a subset of potentially nice looking formattings but
6240     the point is to not return false positives that end up producing lines that
6241     are too long.
6242     """
6243     bt = line.bracket_tracker
6244     if not bt.delimiters:
6245         # Without delimiters the optional parentheses are useless.
6246         return True
6247
6248     max_priority = bt.max_delimiter_priority()
6249     if bt.delimiter_count_with_priority(max_priority) > 1:
6250         # With more than one delimiter of a kind the optional parentheses read better.
6251         return False
6252
6253     if max_priority == DOT_PRIORITY:
6254         # A single stranded method call doesn't require optional parentheses.
6255         return True
6256
6257     assert len(line.leaves) >= 2, "Stranded delimiter"
6258
6259     first = line.leaves[0]
6260     second = line.leaves[1]
6261     penultimate = line.leaves[-2]
6262     last = line.leaves[-1]
6263
6264     # With a single delimiter, omit if the expression starts or ends with
6265     # a bracket.
6266     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6267         remainder = False
6268         length = 4 * line.depth
6269         for _index, leaf, leaf_length in enumerate_with_length(line):
6270             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6271                 remainder = True
6272             if remainder:
6273                 length += leaf_length
6274                 if length > line_length:
6275                     break
6276
6277                 if leaf.type in OPENING_BRACKETS:
6278                     # There are brackets we can further split on.
6279                     remainder = False
6280
6281         else:
6282             # checked the entire string and line length wasn't exceeded
6283             if len(line.leaves) == _index + 1:
6284                 return True
6285
6286         # Note: we are not returning False here because a line might have *both*
6287         # a leading opening bracket and a trailing closing bracket.  If the
6288         # opening bracket doesn't match our rule, maybe the closing will.
6289
6290     if (
6291         last.type == token.RPAR
6292         or last.type == token.RBRACE
6293         or (
6294             # don't use indexing for omitting optional parentheses;
6295             # it looks weird
6296             last.type == token.RSQB
6297             and last.parent
6298             and last.parent.type != syms.trailer
6299         )
6300     ):
6301         if penultimate.type in OPENING_BRACKETS:
6302             # Empty brackets don't help.
6303             return False
6304
6305         if is_multiline_string(first):
6306             # Additional wrapping of a multiline string in this situation is
6307             # unnecessary.
6308             return True
6309
6310         length = 4 * line.depth
6311         seen_other_brackets = False
6312         for _index, leaf, leaf_length in enumerate_with_length(line):
6313             length += leaf_length
6314             if leaf is last.opening_bracket:
6315                 if seen_other_brackets or length <= line_length:
6316                     return True
6317
6318             elif leaf.type in OPENING_BRACKETS:
6319                 # There are brackets we can further split on.
6320                 seen_other_brackets = True
6321
6322     return False
6323
6324
6325 def get_cache_file(mode: Mode) -> Path:
6326     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6327
6328
6329 def read_cache(mode: Mode) -> Cache:
6330     """Read the cache if it exists and is well formed.
6331
6332     If it is not well formed, the call to write_cache later should resolve the issue.
6333     """
6334     cache_file = get_cache_file(mode)
6335     if not cache_file.exists():
6336         return {}
6337
6338     with cache_file.open("rb") as fobj:
6339         try:
6340             cache: Cache = pickle.load(fobj)
6341         except (pickle.UnpicklingError, ValueError):
6342             return {}
6343
6344     return cache
6345
6346
6347 def get_cache_info(path: Path) -> CacheInfo:
6348     """Return the information used to check if a file is already formatted or not."""
6349     stat = path.stat()
6350     return stat.st_mtime, stat.st_size
6351
6352
6353 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6354     """Split an iterable of paths in `sources` into two sets.
6355
6356     The first contains paths of files that modified on disk or are not in the
6357     cache. The other contains paths to non-modified files.
6358     """
6359     todo, done = set(), set()
6360     for src in sources:
6361         src = src.resolve()
6362         if cache.get(src) != get_cache_info(src):
6363             todo.add(src)
6364         else:
6365             done.add(src)
6366     return todo, done
6367
6368
6369 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6370     """Update the cache file."""
6371     cache_file = get_cache_file(mode)
6372     try:
6373         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6374         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6375         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6376             pickle.dump(new_cache, f, protocol=4)
6377         os.replace(f.name, cache_file)
6378     except OSError:
6379         pass
6380
6381
6382 def patch_click() -> None:
6383     """Make Click not crash.
6384
6385     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6386     default which restricts paths that it can access during the lifetime of the
6387     application.  Click refuses to work in this scenario by raising a RuntimeError.
6388
6389     In case of Black the likelihood that non-ASCII characters are going to be used in
6390     file paths is minimal since it's Python source code.  Moreover, this crash was
6391     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6392     """
6393     try:
6394         from click import core
6395         from click import _unicodefun  # type: ignore
6396     except ModuleNotFoundError:
6397         return
6398
6399     for module in (core, _unicodefun):
6400         if hasattr(module, "_verify_python3_env"):
6401             module._verify_python3_env = lambda: None
6402
6403
6404 def patched_main() -> None:
6405     freeze_support()
6406     patch_click()
6407     main()
6408
6409
6410 def fix_docstring(docstring: str, prefix: str) -> str:
6411     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6412     if not docstring:
6413         return ""
6414     # Convert tabs to spaces (following the normal Python rules)
6415     # and split into a list of lines:
6416     lines = docstring.expandtabs().splitlines()
6417     # Determine minimum indentation (first line doesn't count):
6418     indent = sys.maxsize
6419     for line in lines[1:]:
6420         stripped = line.lstrip()
6421         if stripped:
6422             indent = min(indent, len(line) - len(stripped))
6423     # Remove indentation (first line is special):
6424     trimmed = [lines[0].strip()]
6425     if indent < sys.maxsize:
6426         last_line_idx = len(lines) - 2
6427         for i, line in enumerate(lines[1:]):
6428             stripped_line = line[indent:].rstrip()
6429             if stripped_line or i == last_line_idx:
6430                 trimmed.append(prefix + stripped_line)
6431             else:
6432                 trimmed.append("")
6433     # Return a single string:
6434     return "\n".join(trimmed)
6435
6436
6437 if __name__ == "__main__":
6438     patched_main()