]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Make --exclude only apply to recursively found files (#1591)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 T = TypeVar("T")
116 E = TypeVar("E", bound=Exception)
117
118
119 class Ok(Generic[T]):
120     def __init__(self, value: T) -> None:
121         self._value = value
122
123     def ok(self) -> T:
124         return self._value
125
126
127 class Err(Generic[E]):
128     def __init__(self, e: E) -> None:
129         self._e = e
130
131     def err(self) -> E:
132         return self._e
133
134
135 # The 'Result' return type is used to implement an error-handling model heavily
136 # influenced by that used by the Rust programming language
137 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
138 Result = Union[Ok[T], Err[E]]
139 TResult = Result[T, CannotTransform]  # (T)ransform Result
140 TMatchResult = TResult[Index]
141
142
143 class WriteBack(Enum):
144     NO = 0
145     YES = 1
146     DIFF = 2
147     CHECK = 3
148     COLOR_DIFF = 4
149
150     @classmethod
151     def from_configuration(
152         cls, *, check: bool, diff: bool, color: bool = False
153     ) -> "WriteBack":
154         if check and not diff:
155             return cls.CHECK
156
157         if diff and color:
158             return cls.COLOR_DIFF
159
160         return cls.DIFF if diff else cls.YES
161
162
163 class Changed(Enum):
164     NO = 0
165     CACHED = 1
166     YES = 2
167
168
169 class TargetVersion(Enum):
170     PY27 = 2
171     PY33 = 3
172     PY34 = 4
173     PY35 = 5
174     PY36 = 6
175     PY37 = 7
176     PY38 = 8
177
178     def is_python2(self) -> bool:
179         return self is TargetVersion.PY27
180
181
182 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
183
184
185 class Feature(Enum):
186     # All string literals are unicode
187     UNICODE_LITERALS = 1
188     F_STRINGS = 2
189     NUMERIC_UNDERSCORES = 3
190     TRAILING_COMMA_IN_CALL = 4
191     TRAILING_COMMA_IN_DEF = 5
192     # The following two feature-flags are mutually exclusive, and exactly one should be
193     # set for every version of python.
194     ASYNC_IDENTIFIERS = 6
195     ASYNC_KEYWORDS = 7
196     ASSIGNMENT_EXPRESSIONS = 8
197     POS_ONLY_ARGUMENTS = 9
198
199
200 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
201     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
202     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
203     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
204     TargetVersion.PY35: {
205         Feature.UNICODE_LITERALS,
206         Feature.TRAILING_COMMA_IN_CALL,
207         Feature.ASYNC_IDENTIFIERS,
208     },
209     TargetVersion.PY36: {
210         Feature.UNICODE_LITERALS,
211         Feature.F_STRINGS,
212         Feature.NUMERIC_UNDERSCORES,
213         Feature.TRAILING_COMMA_IN_CALL,
214         Feature.TRAILING_COMMA_IN_DEF,
215         Feature.ASYNC_IDENTIFIERS,
216     },
217     TargetVersion.PY37: {
218         Feature.UNICODE_LITERALS,
219         Feature.F_STRINGS,
220         Feature.NUMERIC_UNDERSCORES,
221         Feature.TRAILING_COMMA_IN_CALL,
222         Feature.TRAILING_COMMA_IN_DEF,
223         Feature.ASYNC_KEYWORDS,
224     },
225     TargetVersion.PY38: {
226         Feature.UNICODE_LITERALS,
227         Feature.F_STRINGS,
228         Feature.NUMERIC_UNDERSCORES,
229         Feature.TRAILING_COMMA_IN_CALL,
230         Feature.TRAILING_COMMA_IN_DEF,
231         Feature.ASYNC_KEYWORDS,
232         Feature.ASSIGNMENT_EXPRESSIONS,
233         Feature.POS_ONLY_ARGUMENTS,
234     },
235 }
236
237
238 @dataclass
239 class Mode:
240     target_versions: Set[TargetVersion] = field(default_factory=set)
241     line_length: int = DEFAULT_LINE_LENGTH
242     string_normalization: bool = True
243     is_pyi: bool = False
244
245     def get_cache_key(self) -> str:
246         if self.target_versions:
247             version_str = ",".join(
248                 str(version.value)
249                 for version in sorted(self.target_versions, key=lambda v: v.value)
250             )
251         else:
252             version_str = "-"
253         parts = [
254             version_str,
255             str(self.line_length),
256             str(int(self.string_normalization)),
257             str(int(self.is_pyi)),
258         ]
259         return ".".join(parts)
260
261
262 # Legacy name, left for integrations.
263 FileMode = Mode
264
265
266 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
267     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
268
269
270 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
271     """Find the absolute filepath to a pyproject.toml if it exists"""
272     path_project_root = find_project_root(path_search_start)
273     path_pyproject_toml = path_project_root / "pyproject.toml"
274     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
275
276
277 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
278     """Parse a pyproject toml file, pulling out relevant parts for Black
279
280     If parsing fails, will raise a toml.TomlDecodeError
281     """
282     pyproject_toml = toml.load(path_config)
283     config = pyproject_toml.get("tool", {}).get("black", {})
284     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
285
286
287 def read_pyproject_toml(
288     ctx: click.Context, param: click.Parameter, value: Optional[str]
289 ) -> Optional[str]:
290     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
291
292     Returns the path to a successfully found and read configuration file, None
293     otherwise.
294     """
295     if not value:
296         value = find_pyproject_toml(ctx.params.get("src", ()))
297         if value is None:
298             return None
299
300     try:
301         config = parse_pyproject_toml(value)
302     except (toml.TomlDecodeError, OSError) as e:
303         raise click.FileError(
304             filename=value, hint=f"Error reading configuration file: {e}"
305         )
306
307     if not config:
308         return None
309     else:
310         # Sanitize the values to be Click friendly. For more information please see:
311         # https://github.com/psf/black/issues/1458
312         # https://github.com/pallets/click/issues/1567
313         config = {
314             k: str(v) if not isinstance(v, (list, dict)) else v
315             for k, v in config.items()
316         }
317
318     target_version = config.get("target_version")
319     if target_version is not None and not isinstance(target_version, list):
320         raise click.BadOptionUsage(
321             "target-version", "Config key target-version must be a list"
322         )
323
324     default_map: Dict[str, Any] = {}
325     if ctx.default_map:
326         default_map.update(ctx.default_map)
327     default_map.update(config)
328
329     ctx.default_map = default_map
330     return value
331
332
333 def target_version_option_callback(
334     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
335 ) -> List[TargetVersion]:
336     """Compute the target versions from a --target-version flag.
337
338     This is its own function because mypy couldn't infer the type correctly
339     when it was a lambda, causing mypyc trouble.
340     """
341     return [TargetVersion[val.upper()] for val in v]
342
343
344 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
345 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
346 @click.option(
347     "-l",
348     "--line-length",
349     type=int,
350     default=DEFAULT_LINE_LENGTH,
351     help="How many characters per line to allow.",
352     show_default=True,
353 )
354 @click.option(
355     "-t",
356     "--target-version",
357     type=click.Choice([v.name.lower() for v in TargetVersion]),
358     callback=target_version_option_callback,
359     multiple=True,
360     help=(
361         "Python versions that should be supported by Black's output. [default: per-file"
362         " auto-detection]"
363     ),
364 )
365 @click.option(
366     "--pyi",
367     is_flag=True,
368     help=(
369         "Format all input files like typing stubs regardless of file extension (useful"
370         " when piping source on standard input)."
371     ),
372 )
373 @click.option(
374     "-S",
375     "--skip-string-normalization",
376     is_flag=True,
377     help="Don't normalize string quotes or prefixes.",
378 )
379 @click.option(
380     "--check",
381     is_flag=True,
382     help=(
383         "Don't write the files back, just return the status.  Return code 0 means"
384         " nothing would change.  Return code 1 means some files would be reformatted."
385         " Return code 123 means there was an internal error."
386     ),
387 )
388 @click.option(
389     "--diff",
390     is_flag=True,
391     help="Don't write the files back, just output a diff for each file on stdout.",
392 )
393 @click.option(
394     "--color/--no-color",
395     is_flag=True,
396     help="Show colored diff. Only applies when `--diff` is given.",
397 )
398 @click.option(
399     "--fast/--safe",
400     is_flag=True,
401     help="If --fast given, skip temporary sanity checks. [default: --safe]",
402 )
403 @click.option(
404     "--include",
405     type=str,
406     default=DEFAULT_INCLUDES,
407     help=(
408         "A regular expression that matches files and directories that should be"
409         " included on recursive searches.  An empty value means all files are included"
410         " regardless of the name.  Use forward slashes for directories on all platforms"
411         " (Windows, too).  Exclusions are calculated first, inclusions later."
412     ),
413     show_default=True,
414 )
415 @click.option(
416     "--exclude",
417     type=str,
418     default=DEFAULT_EXCLUDES,
419     help=(
420         "A regular expression that matches files and directories that should be"
421         " excluded on recursive searches.  An empty value means no paths are excluded."
422         " Use forward slashes for directories on all platforms (Windows, too). "
423         " Exclusions are calculated first, inclusions later."
424     ),
425     show_default=True,
426 )
427 @click.option(
428     "--force-exclude",
429     type=str,
430     help=(
431         "Like --exclude, but files and directories matching this regex will be "
432         "excluded even when they are passed explicitly as arguments"
433     ),
434 )
435 @click.option(
436     "-q",
437     "--quiet",
438     is_flag=True,
439     help=(
440         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
441         " those with 2>/dev/null."
442     ),
443 )
444 @click.option(
445     "-v",
446     "--verbose",
447     is_flag=True,
448     help=(
449         "Also emit messages to stderr about files that were not changed or were ignored"
450         " due to --exclude=."
451     ),
452 )
453 @click.version_option(version=__version__)
454 @click.argument(
455     "src",
456     nargs=-1,
457     type=click.Path(
458         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
459     ),
460     is_eager=True,
461 )
462 @click.option(
463     "--config",
464     type=click.Path(
465         exists=True,
466         file_okay=True,
467         dir_okay=False,
468         readable=True,
469         allow_dash=False,
470         path_type=str,
471     ),
472     is_eager=True,
473     callback=read_pyproject_toml,
474     help="Read configuration from FILE path.",
475 )
476 @click.pass_context
477 def main(
478     ctx: click.Context,
479     code: Optional[str],
480     line_length: int,
481     target_version: List[TargetVersion],
482     check: bool,
483     diff: bool,
484     color: bool,
485     fast: bool,
486     pyi: bool,
487     skip_string_normalization: bool,
488     quiet: bool,
489     verbose: bool,
490     include: str,
491     exclude: str,
492     force_exclude: Optional[str],
493     src: Tuple[str, ...],
494     config: Optional[str],
495 ) -> None:
496     """The uncompromising code formatter."""
497     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
498     if target_version:
499         versions = set(target_version)
500     else:
501         # We'll autodetect later.
502         versions = set()
503     mode = Mode(
504         target_versions=versions,
505         line_length=line_length,
506         is_pyi=pyi,
507         string_normalization=not skip_string_normalization,
508     )
509     if config and verbose:
510         out(f"Using configuration from {config}.", bold=False, fg="blue")
511     if code is not None:
512         print(format_str(code, mode=mode))
513         ctx.exit(0)
514     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
515     sources = get_sources(
516         ctx=ctx,
517         src=src,
518         quiet=quiet,
519         verbose=verbose,
520         include=include,
521         exclude=exclude,
522         force_exclude=force_exclude,
523         report=report,
524     )
525
526     path_empty(
527         sources,
528         "No Python files are present to be formatted. Nothing to do 😴",
529         quiet,
530         verbose,
531         ctx,
532     )
533
534     if len(sources) == 1:
535         reformat_one(
536             src=sources.pop(),
537             fast=fast,
538             write_back=write_back,
539             mode=mode,
540             report=report,
541         )
542     else:
543         reformat_many(
544             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
545         )
546
547     if verbose or not quiet:
548         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
549         click.secho(str(report), err=True)
550     ctx.exit(report.return_code)
551
552
553 def get_sources(
554     *,
555     ctx: click.Context,
556     src: Tuple[str, ...],
557     quiet: bool,
558     verbose: bool,
559     include: str,
560     exclude: str,
561     force_exclude: Optional[str],
562     report: "Report",
563 ) -> Set[Path]:
564     """Compute the set of files to be formatted."""
565     try:
566         include_regex = re_compile_maybe_verbose(include)
567     except re.error:
568         err(f"Invalid regular expression for include given: {include!r}")
569         ctx.exit(2)
570     try:
571         exclude_regex = re_compile_maybe_verbose(exclude)
572     except re.error:
573         err(f"Invalid regular expression for exclude given: {exclude!r}")
574         ctx.exit(2)
575     try:
576         force_exclude_regex = (
577             re_compile_maybe_verbose(force_exclude) if force_exclude else None
578         )
579     except re.error:
580         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
581         ctx.exit(2)
582
583     root = find_project_root(src)
584     sources: Set[Path] = set()
585     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
586     gitignore = get_gitignore(root)
587
588     for s in src:
589         p = Path(s)
590         if p.is_dir():
591             sources.update(
592                 gen_python_files(
593                     p.iterdir(),
594                     root,
595                     include_regex,
596                     exclude_regex,
597                     force_exclude_regex,
598                     report,
599                     gitignore,
600                 )
601             )
602         elif s == "-":
603             sources.add(p)
604         elif p.is_file():
605             normalized_path = normalize_path_maybe_ignore(p, root, report)
606             if normalized_path is None:
607                 continue
608
609             normalized_path = "/" + normalized_path
610             # Hard-exclude any files that matches the `--force-exclude` regex.
611             if force_exclude_regex:
612                 force_exclude_match = force_exclude_regex.search(normalized_path)
613             else:
614                 force_exclude_match = None
615             if force_exclude_match and force_exclude_match.group(0):
616                 report.path_ignored(p, "matches the --force-exclude regular expression")
617                 continue
618
619             sources.add(p)
620         else:
621             err(f"invalid path: {s}")
622     return sources
623
624
625 def path_empty(
626     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
627 ) -> None:
628     """
629     Exit if there is no `src` provided for formatting
630     """
631     if len(src) == 0:
632         if verbose or not quiet:
633             out(msg)
634             ctx.exit(0)
635
636
637 def reformat_one(
638     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
639 ) -> None:
640     """Reformat a single file under `src` without spawning child processes.
641
642     `fast`, `write_back`, and `mode` options are passed to
643     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
644     """
645     try:
646         changed = Changed.NO
647         if not src.is_file() and str(src) == "-":
648             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
649                 changed = Changed.YES
650         else:
651             cache: Cache = {}
652             if write_back != WriteBack.DIFF:
653                 cache = read_cache(mode)
654                 res_src = src.resolve()
655                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
656                     changed = Changed.CACHED
657             if changed is not Changed.CACHED and format_file_in_place(
658                 src, fast=fast, write_back=write_back, mode=mode
659             ):
660                 changed = Changed.YES
661             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
662                 write_back is WriteBack.CHECK and changed is Changed.NO
663             ):
664                 write_cache(cache, [src], mode)
665         report.done(src, changed)
666     except Exception as exc:
667         report.failed(src, str(exc))
668
669
670 def reformat_many(
671     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
672 ) -> None:
673     """Reformat multiple files using a ProcessPoolExecutor."""
674     executor: Executor
675     loop = asyncio.get_event_loop()
676     worker_count = os.cpu_count()
677     if sys.platform == "win32":
678         # Work around https://bugs.python.org/issue26903
679         worker_count = min(worker_count, 61)
680     try:
681         executor = ProcessPoolExecutor(max_workers=worker_count)
682     except (ImportError, OSError):
683         # we arrive here if the underlying system does not support multi-processing
684         # like in AWS Lambda or Termux, in which case we gracefully fallback to
685         # a ThreadPollExecutor with just a single worker (more workers would not do us
686         # any good due to the Global Interpreter Lock)
687         executor = ThreadPoolExecutor(max_workers=1)
688
689     try:
690         loop.run_until_complete(
691             schedule_formatting(
692                 sources=sources,
693                 fast=fast,
694                 write_back=write_back,
695                 mode=mode,
696                 report=report,
697                 loop=loop,
698                 executor=executor,
699             )
700         )
701     finally:
702         shutdown(loop)
703         if executor is not None:
704             executor.shutdown()
705
706
707 async def schedule_formatting(
708     sources: Set[Path],
709     fast: bool,
710     write_back: WriteBack,
711     mode: Mode,
712     report: "Report",
713     loop: asyncio.AbstractEventLoop,
714     executor: Executor,
715 ) -> None:
716     """Run formatting of `sources` in parallel using the provided `executor`.
717
718     (Use ProcessPoolExecutors for actual parallelism.)
719
720     `write_back`, `fast`, and `mode` options are passed to
721     :func:`format_file_in_place`.
722     """
723     cache: Cache = {}
724     if write_back != WriteBack.DIFF:
725         cache = read_cache(mode)
726         sources, cached = filter_cached(cache, sources)
727         for src in sorted(cached):
728             report.done(src, Changed.CACHED)
729     if not sources:
730         return
731
732     cancelled = []
733     sources_to_cache = []
734     lock = None
735     if write_back == WriteBack.DIFF:
736         # For diff output, we need locks to ensure we don't interleave output
737         # from different processes.
738         manager = Manager()
739         lock = manager.Lock()
740     tasks = {
741         asyncio.ensure_future(
742             loop.run_in_executor(
743                 executor, format_file_in_place, src, fast, mode, write_back, lock
744             )
745         ): src
746         for src in sorted(sources)
747     }
748     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
749     try:
750         loop.add_signal_handler(signal.SIGINT, cancel, pending)
751         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
752     except NotImplementedError:
753         # There are no good alternatives for these on Windows.
754         pass
755     while pending:
756         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
757         for task in done:
758             src = tasks.pop(task)
759             if task.cancelled():
760                 cancelled.append(task)
761             elif task.exception():
762                 report.failed(src, str(task.exception()))
763             else:
764                 changed = Changed.YES if task.result() else Changed.NO
765                 # If the file was written back or was successfully checked as
766                 # well-formatted, store this information in the cache.
767                 if write_back is WriteBack.YES or (
768                     write_back is WriteBack.CHECK and changed is Changed.NO
769                 ):
770                     sources_to_cache.append(src)
771                 report.done(src, changed)
772     if cancelled:
773         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
774     if sources_to_cache:
775         write_cache(cache, sources_to_cache, mode)
776
777
778 def format_file_in_place(
779     src: Path,
780     fast: bool,
781     mode: Mode,
782     write_back: WriteBack = WriteBack.NO,
783     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
784 ) -> bool:
785     """Format file under `src` path. Return True if changed.
786
787     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
788     code to the file.
789     `mode` and `fast` options are passed to :func:`format_file_contents`.
790     """
791     if src.suffix == ".pyi":
792         mode = replace(mode, is_pyi=True)
793
794     then = datetime.utcfromtimestamp(src.stat().st_mtime)
795     with open(src, "rb") as buf:
796         src_contents, encoding, newline = decode_bytes(buf.read())
797     try:
798         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
799     except NothingChanged:
800         return False
801
802     if write_back == WriteBack.YES:
803         with open(src, "w", encoding=encoding, newline=newline) as f:
804             f.write(dst_contents)
805     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
806         now = datetime.utcnow()
807         src_name = f"{src}\t{then} +0000"
808         dst_name = f"{src}\t{now} +0000"
809         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
810
811         if write_back == write_back.COLOR_DIFF:
812             diff_contents = color_diff(diff_contents)
813
814         with lock or nullcontext():
815             f = io.TextIOWrapper(
816                 sys.stdout.buffer,
817                 encoding=encoding,
818                 newline=newline,
819                 write_through=True,
820             )
821             f = wrap_stream_for_windows(f)
822             f.write(diff_contents)
823             f.detach()
824
825     return True
826
827
828 def color_diff(contents: str) -> str:
829     """Inject the ANSI color codes to the diff."""
830     lines = contents.split("\n")
831     for i, line in enumerate(lines):
832         if line.startswith("+++") or line.startswith("---"):
833             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
834         if line.startswith("@@"):
835             line = "\033[36m" + line + "\033[0m"  # cyan, reset
836         if line.startswith("+"):
837             line = "\033[32m" + line + "\033[0m"  # green, reset
838         elif line.startswith("-"):
839             line = "\033[31m" + line + "\033[0m"  # red, reset
840         lines[i] = line
841     return "\n".join(lines)
842
843
844 def wrap_stream_for_windows(
845     f: io.TextIOWrapper,
846 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
847     """
848     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
849
850     If `colorama` is not found, then no change is made. If `colorama` does
851     exist, then it handles the logic to determine whether or not to change
852     things.
853     """
854     try:
855         from colorama import initialise
856
857         # We set `strip=False` so that we can don't have to modify
858         # test_express_diff_with_color.
859         f = initialise.wrap_stream(
860             f, convert=None, strip=False, autoreset=False, wrap=True
861         )
862
863         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
864         # which does not have a `detach()` method. So we fake one.
865         f.detach = lambda *args, **kwargs: None  # type: ignore
866     except ImportError:
867         pass
868
869     return f
870
871
872 def format_stdin_to_stdout(
873     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
874 ) -> bool:
875     """Format file on stdin. Return True if changed.
876
877     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
878     write a diff to stdout. The `mode` argument is passed to
879     :func:`format_file_contents`.
880     """
881     then = datetime.utcnow()
882     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
883     dst = src
884     try:
885         dst = format_file_contents(src, fast=fast, mode=mode)
886         return True
887
888     except NothingChanged:
889         return False
890
891     finally:
892         f = io.TextIOWrapper(
893             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
894         )
895         if write_back == WriteBack.YES:
896             f.write(dst)
897         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
898             now = datetime.utcnow()
899             src_name = f"STDIN\t{then} +0000"
900             dst_name = f"STDOUT\t{now} +0000"
901             d = diff(src, dst, src_name, dst_name)
902             if write_back == WriteBack.COLOR_DIFF:
903                 d = color_diff(d)
904                 f = wrap_stream_for_windows(f)
905             f.write(d)
906         f.detach()
907
908
909 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
910     """Reformat contents a file and return new contents.
911
912     If `fast` is False, additionally confirm that the reformatted code is
913     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
914     `mode` is passed to :func:`format_str`.
915     """
916     if src_contents.strip() == "":
917         raise NothingChanged
918
919     dst_contents = format_str(src_contents, mode=mode)
920     if src_contents == dst_contents:
921         raise NothingChanged
922
923     if not fast:
924         assert_equivalent(src_contents, dst_contents)
925         assert_stable(src_contents, dst_contents, mode=mode)
926     return dst_contents
927
928
929 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
930     """Reformat a string and return new contents.
931
932     `mode` determines formatting options, such as how many characters per line are
933     allowed.  Example:
934
935     >>> import black
936     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
937     def f(arg: str = "") -> None:
938         ...
939
940     A more complex example:
941     >>> print(
942     ...   black.format_str(
943     ...     "def f(arg:str='')->None: hey",
944     ...     mode=black.Mode(
945     ...       target_versions={black.TargetVersion.PY36},
946     ...       line_length=10,
947     ...       string_normalization=False,
948     ...       is_pyi=False,
949     ...     ),
950     ...   ),
951     ... )
952     def f(
953         arg: str = '',
954     ) -> None:
955         hey
956
957     """
958     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
959     dst_contents = []
960     future_imports = get_future_imports(src_node)
961     if mode.target_versions:
962         versions = mode.target_versions
963     else:
964         versions = detect_target_versions(src_node)
965     normalize_fmt_off(src_node)
966     lines = LineGenerator(
967         remove_u_prefix="unicode_literals" in future_imports
968         or supports_feature(versions, Feature.UNICODE_LITERALS),
969         is_pyi=mode.is_pyi,
970         normalize_strings=mode.string_normalization,
971     )
972     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
973     empty_line = Line()
974     after = 0
975     split_line_features = {
976         feature
977         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
978         if supports_feature(versions, feature)
979     }
980     for current_line in lines.visit(src_node):
981         dst_contents.append(str(empty_line) * after)
982         before, after = elt.maybe_empty_lines(current_line)
983         dst_contents.append(str(empty_line) * before)
984         for line in transform_line(
985             current_line,
986             line_length=mode.line_length,
987             normalize_strings=mode.string_normalization,
988             features=split_line_features,
989         ):
990             dst_contents.append(str(line))
991     return "".join(dst_contents)
992
993
994 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
995     """Return a tuple of (decoded_contents, encoding, newline).
996
997     `newline` is either CRLF or LF but `decoded_contents` is decoded with
998     universal newlines (i.e. only contains LF).
999     """
1000     srcbuf = io.BytesIO(src)
1001     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1002     if not lines:
1003         return "", encoding, "\n"
1004
1005     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1006     srcbuf.seek(0)
1007     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1008         return tiow.read(), encoding, newline
1009
1010
1011 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1012     if not target_versions:
1013         # No target_version specified, so try all grammars.
1014         return [
1015             # Python 3.7+
1016             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1017             # Python 3.0-3.6
1018             pygram.python_grammar_no_print_statement_no_exec_statement,
1019             # Python 2.7 with future print_function import
1020             pygram.python_grammar_no_print_statement,
1021             # Python 2.7
1022             pygram.python_grammar,
1023         ]
1024
1025     if all(version.is_python2() for version in target_versions):
1026         # Python 2-only code, so try Python 2 grammars.
1027         return [
1028             # Python 2.7 with future print_function import
1029             pygram.python_grammar_no_print_statement,
1030             # Python 2.7
1031             pygram.python_grammar,
1032         ]
1033
1034     # Python 3-compatible code, so only try Python 3 grammar.
1035     grammars = []
1036     # If we have to parse both, try to parse async as a keyword first
1037     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1038         # Python 3.7+
1039         grammars.append(
1040             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1041         )
1042     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1043         # Python 3.0-3.6
1044         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1045     # At least one of the above branches must have been taken, because every Python
1046     # version has exactly one of the two 'ASYNC_*' flags
1047     return grammars
1048
1049
1050 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1051     """Given a string with source, return the lib2to3 Node."""
1052     if src_txt[-1:] != "\n":
1053         src_txt += "\n"
1054
1055     for grammar in get_grammars(set(target_versions)):
1056         drv = driver.Driver(grammar, pytree.convert)
1057         try:
1058             result = drv.parse_string(src_txt, True)
1059             break
1060
1061         except ParseError as pe:
1062             lineno, column = pe.context[1]
1063             lines = src_txt.splitlines()
1064             try:
1065                 faulty_line = lines[lineno - 1]
1066             except IndexError:
1067                 faulty_line = "<line number missing in source>"
1068             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1069     else:
1070         raise exc from None
1071
1072     if isinstance(result, Leaf):
1073         result = Node(syms.file_input, [result])
1074     return result
1075
1076
1077 def lib2to3_unparse(node: Node) -> str:
1078     """Given a lib2to3 node, return its string representation."""
1079     code = str(node)
1080     return code
1081
1082
1083 class Visitor(Generic[T]):
1084     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1085
1086     def visit(self, node: LN) -> Iterator[T]:
1087         """Main method to visit `node` and its children.
1088
1089         It tries to find a `visit_*()` method for the given `node.type`, like
1090         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1091         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1092         instead.
1093
1094         Then yields objects of type `T` from the selected visitor.
1095         """
1096         if node.type < 256:
1097             name = token.tok_name[node.type]
1098         else:
1099             name = str(type_repr(node.type))
1100         # We explicitly branch on whether a visitor exists (instead of
1101         # using self.visit_default as the default arg to getattr) in order
1102         # to save needing to create a bound method object and so mypyc can
1103         # generate a native call to visit_default.
1104         visitf = getattr(self, f"visit_{name}", None)
1105         if visitf:
1106             yield from visitf(node)
1107         else:
1108             yield from self.visit_default(node)
1109
1110     def visit_default(self, node: LN) -> Iterator[T]:
1111         """Default `visit_*()` implementation. Recurses to children of `node`."""
1112         if isinstance(node, Node):
1113             for child in node.children:
1114                 yield from self.visit(child)
1115
1116
1117 @dataclass
1118 class DebugVisitor(Visitor[T]):
1119     tree_depth: int = 0
1120
1121     def visit_default(self, node: LN) -> Iterator[T]:
1122         indent = " " * (2 * self.tree_depth)
1123         if isinstance(node, Node):
1124             _type = type_repr(node.type)
1125             out(f"{indent}{_type}", fg="yellow")
1126             self.tree_depth += 1
1127             for child in node.children:
1128                 yield from self.visit(child)
1129
1130             self.tree_depth -= 1
1131             out(f"{indent}/{_type}", fg="yellow", bold=False)
1132         else:
1133             _type = token.tok_name.get(node.type, str(node.type))
1134             out(f"{indent}{_type}", fg="blue", nl=False)
1135             if node.prefix:
1136                 # We don't have to handle prefixes for `Node` objects since
1137                 # that delegates to the first child anyway.
1138                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1139             out(f" {node.value!r}", fg="blue", bold=False)
1140
1141     @classmethod
1142     def show(cls, code: Union[str, Leaf, Node]) -> None:
1143         """Pretty-print the lib2to3 AST of a given string of `code`.
1144
1145         Convenience method for debugging.
1146         """
1147         v: DebugVisitor[None] = DebugVisitor()
1148         if isinstance(code, str):
1149             code = lib2to3_parse(code)
1150         list(v.visit(code))
1151
1152
1153 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1154 STATEMENT: Final = {
1155     syms.if_stmt,
1156     syms.while_stmt,
1157     syms.for_stmt,
1158     syms.try_stmt,
1159     syms.except_clause,
1160     syms.with_stmt,
1161     syms.funcdef,
1162     syms.classdef,
1163 }
1164 STANDALONE_COMMENT: Final = 153
1165 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1166 LOGIC_OPERATORS: Final = {"and", "or"}
1167 COMPARATORS: Final = {
1168     token.LESS,
1169     token.GREATER,
1170     token.EQEQUAL,
1171     token.NOTEQUAL,
1172     token.LESSEQUAL,
1173     token.GREATEREQUAL,
1174 }
1175 MATH_OPERATORS: Final = {
1176     token.VBAR,
1177     token.CIRCUMFLEX,
1178     token.AMPER,
1179     token.LEFTSHIFT,
1180     token.RIGHTSHIFT,
1181     token.PLUS,
1182     token.MINUS,
1183     token.STAR,
1184     token.SLASH,
1185     token.DOUBLESLASH,
1186     token.PERCENT,
1187     token.AT,
1188     token.TILDE,
1189     token.DOUBLESTAR,
1190 }
1191 STARS: Final = {token.STAR, token.DOUBLESTAR}
1192 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1193 VARARGS_PARENTS: Final = {
1194     syms.arglist,
1195     syms.argument,  # double star in arglist
1196     syms.trailer,  # single argument to call
1197     syms.typedargslist,
1198     syms.varargslist,  # lambdas
1199 }
1200 UNPACKING_PARENTS: Final = {
1201     syms.atom,  # single element of a list or set literal
1202     syms.dictsetmaker,
1203     syms.listmaker,
1204     syms.testlist_gexp,
1205     syms.testlist_star_expr,
1206 }
1207 TEST_DESCENDANTS: Final = {
1208     syms.test,
1209     syms.lambdef,
1210     syms.or_test,
1211     syms.and_test,
1212     syms.not_test,
1213     syms.comparison,
1214     syms.star_expr,
1215     syms.expr,
1216     syms.xor_expr,
1217     syms.and_expr,
1218     syms.shift_expr,
1219     syms.arith_expr,
1220     syms.trailer,
1221     syms.term,
1222     syms.power,
1223 }
1224 ASSIGNMENTS: Final = {
1225     "=",
1226     "+=",
1227     "-=",
1228     "*=",
1229     "@=",
1230     "/=",
1231     "%=",
1232     "&=",
1233     "|=",
1234     "^=",
1235     "<<=",
1236     ">>=",
1237     "**=",
1238     "//=",
1239 }
1240 COMPREHENSION_PRIORITY: Final = 20
1241 COMMA_PRIORITY: Final = 18
1242 TERNARY_PRIORITY: Final = 16
1243 LOGIC_PRIORITY: Final = 14
1244 STRING_PRIORITY: Final = 12
1245 COMPARATOR_PRIORITY: Final = 10
1246 MATH_PRIORITIES: Final = {
1247     token.VBAR: 9,
1248     token.CIRCUMFLEX: 8,
1249     token.AMPER: 7,
1250     token.LEFTSHIFT: 6,
1251     token.RIGHTSHIFT: 6,
1252     token.PLUS: 5,
1253     token.MINUS: 5,
1254     token.STAR: 4,
1255     token.SLASH: 4,
1256     token.DOUBLESLASH: 4,
1257     token.PERCENT: 4,
1258     token.AT: 4,
1259     token.TILDE: 3,
1260     token.DOUBLESTAR: 2,
1261 }
1262 DOT_PRIORITY: Final = 1
1263
1264
1265 @dataclass
1266 class BracketTracker:
1267     """Keeps track of brackets on a line."""
1268
1269     depth: int = 0
1270     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1271     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1272     previous: Optional[Leaf] = None
1273     _for_loop_depths: List[int] = field(default_factory=list)
1274     _lambda_argument_depths: List[int] = field(default_factory=list)
1275
1276     def mark(self, leaf: Leaf) -> None:
1277         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1278
1279         All leaves receive an int `bracket_depth` field that stores how deep
1280         within brackets a given leaf is. 0 means there are no enclosing brackets
1281         that started on this line.
1282
1283         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1284         field that it forms a pair with. This is a one-directional link to
1285         avoid reference cycles.
1286
1287         If a leaf is a delimiter (a token on which Black can split the line if
1288         needed) and it's on depth 0, its `id()` is stored in the tracker's
1289         `delimiters` field.
1290         """
1291         if leaf.type == token.COMMENT:
1292             return
1293
1294         self.maybe_decrement_after_for_loop_variable(leaf)
1295         self.maybe_decrement_after_lambda_arguments(leaf)
1296         if leaf.type in CLOSING_BRACKETS:
1297             self.depth -= 1
1298             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1299             leaf.opening_bracket = opening_bracket
1300         leaf.bracket_depth = self.depth
1301         if self.depth == 0:
1302             delim = is_split_before_delimiter(leaf, self.previous)
1303             if delim and self.previous is not None:
1304                 self.delimiters[id(self.previous)] = delim
1305             else:
1306                 delim = is_split_after_delimiter(leaf, self.previous)
1307                 if delim:
1308                     self.delimiters[id(leaf)] = delim
1309         if leaf.type in OPENING_BRACKETS:
1310             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1311             self.depth += 1
1312         self.previous = leaf
1313         self.maybe_increment_lambda_arguments(leaf)
1314         self.maybe_increment_for_loop_variable(leaf)
1315
1316     def any_open_brackets(self) -> bool:
1317         """Return True if there is an yet unmatched open bracket on the line."""
1318         return bool(self.bracket_match)
1319
1320     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1321         """Return the highest priority of a delimiter found on the line.
1322
1323         Values are consistent with what `is_split_*_delimiter()` return.
1324         Raises ValueError on no delimiters.
1325         """
1326         return max(v for k, v in self.delimiters.items() if k not in exclude)
1327
1328     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1329         """Return the number of delimiters with the given `priority`.
1330
1331         If no `priority` is passed, defaults to max priority on the line.
1332         """
1333         if not self.delimiters:
1334             return 0
1335
1336         priority = priority or self.max_delimiter_priority()
1337         return sum(1 for p in self.delimiters.values() if p == priority)
1338
1339     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1340         """In a for loop, or comprehension, the variables are often unpacks.
1341
1342         To avoid splitting on the comma in this situation, increase the depth of
1343         tokens between `for` and `in`.
1344         """
1345         if leaf.type == token.NAME and leaf.value == "for":
1346             self.depth += 1
1347             self._for_loop_depths.append(self.depth)
1348             return True
1349
1350         return False
1351
1352     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1353         """See `maybe_increment_for_loop_variable` above for explanation."""
1354         if (
1355             self._for_loop_depths
1356             and self._for_loop_depths[-1] == self.depth
1357             and leaf.type == token.NAME
1358             and leaf.value == "in"
1359         ):
1360             self.depth -= 1
1361             self._for_loop_depths.pop()
1362             return True
1363
1364         return False
1365
1366     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1367         """In a lambda expression, there might be more than one argument.
1368
1369         To avoid splitting on the comma in this situation, increase the depth of
1370         tokens between `lambda` and `:`.
1371         """
1372         if leaf.type == token.NAME and leaf.value == "lambda":
1373             self.depth += 1
1374             self._lambda_argument_depths.append(self.depth)
1375             return True
1376
1377         return False
1378
1379     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1380         """See `maybe_increment_lambda_arguments` above for explanation."""
1381         if (
1382             self._lambda_argument_depths
1383             and self._lambda_argument_depths[-1] == self.depth
1384             and leaf.type == token.COLON
1385         ):
1386             self.depth -= 1
1387             self._lambda_argument_depths.pop()
1388             return True
1389
1390         return False
1391
1392     def get_open_lsqb(self) -> Optional[Leaf]:
1393         """Return the most recent opening square bracket (if any)."""
1394         return self.bracket_match.get((self.depth - 1, token.RSQB))
1395
1396
1397 @dataclass
1398 class Line:
1399     """Holds leaves and comments. Can be printed with `str(line)`."""
1400
1401     depth: int = 0
1402     leaves: List[Leaf] = field(default_factory=list)
1403     # keys ordered like `leaves`
1404     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1405     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1406     inside_brackets: bool = False
1407     should_explode: bool = False
1408
1409     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1410         """Add a new `leaf` to the end of the line.
1411
1412         Unless `preformatted` is True, the `leaf` will receive a new consistent
1413         whitespace prefix and metadata applied by :class:`BracketTracker`.
1414         Trailing commas are maybe removed, unpacked for loop variables are
1415         demoted from being delimiters.
1416
1417         Inline comments are put aside.
1418         """
1419         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1420         if not has_value:
1421             return
1422
1423         if token.COLON == leaf.type and self.is_class_paren_empty:
1424             del self.leaves[-2:]
1425         if self.leaves and not preformatted:
1426             # Note: at this point leaf.prefix should be empty except for
1427             # imports, for which we only preserve newlines.
1428             leaf.prefix += whitespace(
1429                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1430             )
1431         if self.inside_brackets or not preformatted:
1432             self.bracket_tracker.mark(leaf)
1433             self.maybe_remove_trailing_comma(leaf)
1434         if not self.append_comment(leaf):
1435             self.leaves.append(leaf)
1436
1437     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1438         """Like :func:`append()` but disallow invalid standalone comment structure.
1439
1440         Raises ValueError when any `leaf` is appended after a standalone comment
1441         or when a standalone comment is not the first leaf on the line.
1442         """
1443         if self.bracket_tracker.depth == 0:
1444             if self.is_comment:
1445                 raise ValueError("cannot append to standalone comments")
1446
1447             if self.leaves and leaf.type == STANDALONE_COMMENT:
1448                 raise ValueError(
1449                     "cannot append standalone comments to a populated line"
1450                 )
1451
1452         self.append(leaf, preformatted=preformatted)
1453
1454     @property
1455     def is_comment(self) -> bool:
1456         """Is this line a standalone comment?"""
1457         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1458
1459     @property
1460     def is_decorator(self) -> bool:
1461         """Is this line a decorator?"""
1462         return bool(self) and self.leaves[0].type == token.AT
1463
1464     @property
1465     def is_import(self) -> bool:
1466         """Is this an import line?"""
1467         return bool(self) and is_import(self.leaves[0])
1468
1469     @property
1470     def is_class(self) -> bool:
1471         """Is this line a class definition?"""
1472         return (
1473             bool(self)
1474             and self.leaves[0].type == token.NAME
1475             and self.leaves[0].value == "class"
1476         )
1477
1478     @property
1479     def is_stub_class(self) -> bool:
1480         """Is this line a class definition with a body consisting only of "..."?"""
1481         return self.is_class and self.leaves[-3:] == [
1482             Leaf(token.DOT, ".") for _ in range(3)
1483         ]
1484
1485     @property
1486     def is_collection_with_optional_trailing_comma(self) -> bool:
1487         """Is this line a collection literal with a trailing comma that's optional?
1488
1489         Note that the trailing comma in a 1-tuple is not optional.
1490         """
1491         if not self.leaves or len(self.leaves) < 4:
1492             return False
1493
1494         # Look for and address a trailing colon.
1495         if self.leaves[-1].type == token.COLON:
1496             closer = self.leaves[-2]
1497             close_index = -2
1498         else:
1499             closer = self.leaves[-1]
1500             close_index = -1
1501         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1502             return False
1503
1504         if closer.type == token.RPAR:
1505             # Tuples require an extra check, because if there's only
1506             # one element in the tuple removing the comma unmakes the
1507             # tuple.
1508             #
1509             # We also check for parens before looking for the trailing
1510             # comma because in some cases (eg assigning a dict
1511             # literal) the literal gets wrapped in temporary parens
1512             # during parsing. This case is covered by the
1513             # collections.py test data.
1514             opener = closer.opening_bracket
1515             for _open_index, leaf in enumerate(self.leaves):
1516                 if leaf is opener:
1517                     break
1518
1519             else:
1520                 # Couldn't find the matching opening paren, play it safe.
1521                 return False
1522
1523             commas = 0
1524             comma_depth = self.leaves[close_index - 1].bracket_depth
1525             for leaf in self.leaves[_open_index + 1 : close_index]:
1526                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1527                     commas += 1
1528             if commas > 1:
1529                 # We haven't looked yet for the trailing comma because
1530                 # we might also have caught noop parens.
1531                 return self.leaves[close_index - 1].type == token.COMMA
1532
1533             elif commas == 1:
1534                 return False  # it's either a one-tuple or didn't have a trailing comma
1535
1536             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1537                 close_index -= 1
1538                 closer = self.leaves[close_index]
1539                 if closer.type == token.RPAR:
1540                     # TODO: this is a gut feeling. Will we ever see this?
1541                     return False
1542
1543         if self.leaves[close_index - 1].type != token.COMMA:
1544             return False
1545
1546         return True
1547
1548     @property
1549     def is_def(self) -> bool:
1550         """Is this a function definition? (Also returns True for async defs.)"""
1551         try:
1552             first_leaf = self.leaves[0]
1553         except IndexError:
1554             return False
1555
1556         try:
1557             second_leaf: Optional[Leaf] = self.leaves[1]
1558         except IndexError:
1559             second_leaf = None
1560         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1561             first_leaf.type == token.ASYNC
1562             and second_leaf is not None
1563             and second_leaf.type == token.NAME
1564             and second_leaf.value == "def"
1565         )
1566
1567     @property
1568     def is_class_paren_empty(self) -> bool:
1569         """Is this a class with no base classes but using parentheses?
1570
1571         Those are unnecessary and should be removed.
1572         """
1573         return (
1574             bool(self)
1575             and len(self.leaves) == 4
1576             and self.is_class
1577             and self.leaves[2].type == token.LPAR
1578             and self.leaves[2].value == "("
1579             and self.leaves[3].type == token.RPAR
1580             and self.leaves[3].value == ")"
1581         )
1582
1583     @property
1584     def is_triple_quoted_string(self) -> bool:
1585         """Is the line a triple quoted string?"""
1586         return (
1587             bool(self)
1588             and self.leaves[0].type == token.STRING
1589             and self.leaves[0].value.startswith(('"""', "'''"))
1590         )
1591
1592     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1593         """If so, needs to be split before emitting."""
1594         for leaf in self.leaves:
1595             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1596                 return True
1597
1598         return False
1599
1600     def contains_uncollapsable_type_comments(self) -> bool:
1601         ignored_ids = set()
1602         try:
1603             last_leaf = self.leaves[-1]
1604             ignored_ids.add(id(last_leaf))
1605             if last_leaf.type == token.COMMA or (
1606                 last_leaf.type == token.RPAR and not last_leaf.value
1607             ):
1608                 # When trailing commas or optional parens are inserted by Black for
1609                 # consistency, comments after the previous last element are not moved
1610                 # (they don't have to, rendering will still be correct).  So we ignore
1611                 # trailing commas and invisible.
1612                 last_leaf = self.leaves[-2]
1613                 ignored_ids.add(id(last_leaf))
1614         except IndexError:
1615             return False
1616
1617         # A type comment is uncollapsable if it is attached to a leaf
1618         # that isn't at the end of the line (since that could cause it
1619         # to get associated to a different argument) or if there are
1620         # comments before it (since that could cause it to get hidden
1621         # behind a comment.
1622         comment_seen = False
1623         for leaf_id, comments in self.comments.items():
1624             for comment in comments:
1625                 if is_type_comment(comment):
1626                     if comment_seen or (
1627                         not is_type_comment(comment, " ignore")
1628                         and leaf_id not in ignored_ids
1629                     ):
1630                         return True
1631
1632                 comment_seen = True
1633
1634         return False
1635
1636     def contains_unsplittable_type_ignore(self) -> bool:
1637         if not self.leaves:
1638             return False
1639
1640         # If a 'type: ignore' is attached to the end of a line, we
1641         # can't split the line, because we can't know which of the
1642         # subexpressions the ignore was meant to apply to.
1643         #
1644         # We only want this to apply to actual physical lines from the
1645         # original source, though: we don't want the presence of a
1646         # 'type: ignore' at the end of a multiline expression to
1647         # justify pushing it all onto one line. Thus we
1648         # (unfortunately) need to check the actual source lines and
1649         # only report an unsplittable 'type: ignore' if this line was
1650         # one line in the original code.
1651
1652         # Grab the first and last line numbers, skipping generated leaves
1653         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1654         last_line = next(
1655             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1656         )
1657
1658         if first_line == last_line:
1659             # We look at the last two leaves since a comma or an
1660             # invisible paren could have been added at the end of the
1661             # line.
1662             for node in self.leaves[-2:]:
1663                 for comment in self.comments.get(id(node), []):
1664                     if is_type_comment(comment, " ignore"):
1665                         return True
1666
1667         return False
1668
1669     def contains_multiline_strings(self) -> bool:
1670         return any(is_multiline_string(leaf) for leaf in self.leaves)
1671
1672     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1673         """Remove trailing comma if there is one and it's safe."""
1674         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1675             return False
1676
1677         # We remove trailing commas only in the case of importing a
1678         # single name from a module.
1679         if not (
1680             self.leaves
1681             and self.is_import
1682             and len(self.leaves) > 4
1683             and self.leaves[-1].type == token.COMMA
1684             and closing.type in CLOSING_BRACKETS
1685             and self.leaves[-4].type == token.NAME
1686             and (
1687                 # regular `from foo import bar,`
1688                 self.leaves[-4].value == "import"
1689                 # `from foo import (bar as baz,)
1690                 or (
1691                     len(self.leaves) > 6
1692                     and self.leaves[-6].value == "import"
1693                     and self.leaves[-3].value == "as"
1694                 )
1695                 # `from foo import bar as baz,`
1696                 or (
1697                     len(self.leaves) > 5
1698                     and self.leaves[-5].value == "import"
1699                     and self.leaves[-3].value == "as"
1700                 )
1701             )
1702             and closing.type == token.RPAR
1703         ):
1704             return False
1705
1706         self.remove_trailing_comma()
1707         return True
1708
1709     def append_comment(self, comment: Leaf) -> bool:
1710         """Add an inline or standalone comment to the line."""
1711         if (
1712             comment.type == STANDALONE_COMMENT
1713             and self.bracket_tracker.any_open_brackets()
1714         ):
1715             comment.prefix = ""
1716             return False
1717
1718         if comment.type != token.COMMENT:
1719             return False
1720
1721         if not self.leaves:
1722             comment.type = STANDALONE_COMMENT
1723             comment.prefix = ""
1724             return False
1725
1726         last_leaf = self.leaves[-1]
1727         if (
1728             last_leaf.type == token.RPAR
1729             and not last_leaf.value
1730             and last_leaf.parent
1731             and len(list(last_leaf.parent.leaves())) <= 3
1732             and not is_type_comment(comment)
1733         ):
1734             # Comments on an optional parens wrapping a single leaf should belong to
1735             # the wrapped node except if it's a type comment. Pinning the comment like
1736             # this avoids unstable formatting caused by comment migration.
1737             if len(self.leaves) < 2:
1738                 comment.type = STANDALONE_COMMENT
1739                 comment.prefix = ""
1740                 return False
1741
1742             last_leaf = self.leaves[-2]
1743         self.comments.setdefault(id(last_leaf), []).append(comment)
1744         return True
1745
1746     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1747         """Generate comments that should appear directly after `leaf`."""
1748         return self.comments.get(id(leaf), [])
1749
1750     def remove_trailing_comma(self) -> None:
1751         """Remove the trailing comma and moves the comments attached to it."""
1752         trailing_comma = self.leaves.pop()
1753         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1754         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1755             trailing_comma_comments
1756         )
1757
1758     def is_complex_subscript(self, leaf: Leaf) -> bool:
1759         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1760         open_lsqb = self.bracket_tracker.get_open_lsqb()
1761         if open_lsqb is None:
1762             return False
1763
1764         subscript_start = open_lsqb.next_sibling
1765
1766         if isinstance(subscript_start, Node):
1767             if subscript_start.type == syms.listmaker:
1768                 return False
1769
1770             if subscript_start.type == syms.subscriptlist:
1771                 subscript_start = child_towards(subscript_start, leaf)
1772         return subscript_start is not None and any(
1773             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1774         )
1775
1776     def clone(self) -> "Line":
1777         return Line(
1778             depth=self.depth,
1779             inside_brackets=self.inside_brackets,
1780             should_explode=self.should_explode,
1781         )
1782
1783     def __str__(self) -> str:
1784         """Render the line."""
1785         if not self:
1786             return "\n"
1787
1788         indent = "    " * self.depth
1789         leaves = iter(self.leaves)
1790         first = next(leaves)
1791         res = f"{first.prefix}{indent}{first.value}"
1792         for leaf in leaves:
1793             res += str(leaf)
1794         for comment in itertools.chain.from_iterable(self.comments.values()):
1795             res += str(comment)
1796
1797         return res + "\n"
1798
1799     def __bool__(self) -> bool:
1800         """Return True if the line has leaves or comments."""
1801         return bool(self.leaves or self.comments)
1802
1803
1804 @dataclass
1805 class EmptyLineTracker:
1806     """Provides a stateful method that returns the number of potential extra
1807     empty lines needed before and after the currently processed line.
1808
1809     Note: this tracker works on lines that haven't been split yet.  It assumes
1810     the prefix of the first leaf consists of optional newlines.  Those newlines
1811     are consumed by `maybe_empty_lines()` and included in the computation.
1812     """
1813
1814     is_pyi: bool = False
1815     previous_line: Optional[Line] = None
1816     previous_after: int = 0
1817     previous_defs: List[int] = field(default_factory=list)
1818
1819     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1820         """Return the number of extra empty lines before and after the `current_line`.
1821
1822         This is for separating `def`, `async def` and `class` with extra empty
1823         lines (two on module-level).
1824         """
1825         before, after = self._maybe_empty_lines(current_line)
1826         before = (
1827             # Black should not insert empty lines at the beginning
1828             # of the file
1829             0
1830             if self.previous_line is None
1831             else before - self.previous_after
1832         )
1833         self.previous_after = after
1834         self.previous_line = current_line
1835         return before, after
1836
1837     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1838         max_allowed = 1
1839         if current_line.depth == 0:
1840             max_allowed = 1 if self.is_pyi else 2
1841         if current_line.leaves:
1842             # Consume the first leaf's extra newlines.
1843             first_leaf = current_line.leaves[0]
1844             before = first_leaf.prefix.count("\n")
1845             before = min(before, max_allowed)
1846             first_leaf.prefix = ""
1847         else:
1848             before = 0
1849         depth = current_line.depth
1850         while self.previous_defs and self.previous_defs[-1] >= depth:
1851             self.previous_defs.pop()
1852             if self.is_pyi:
1853                 before = 0 if depth else 1
1854             else:
1855                 before = 1 if depth else 2
1856         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1857             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1858
1859         if (
1860             self.previous_line
1861             and self.previous_line.is_import
1862             and not current_line.is_import
1863             and depth == self.previous_line.depth
1864         ):
1865             return (before or 1), 0
1866
1867         if (
1868             self.previous_line
1869             and self.previous_line.is_class
1870             and current_line.is_triple_quoted_string
1871         ):
1872             return before, 1
1873
1874         return before, 0
1875
1876     def _maybe_empty_lines_for_class_or_def(
1877         self, current_line: Line, before: int
1878     ) -> Tuple[int, int]:
1879         if not current_line.is_decorator:
1880             self.previous_defs.append(current_line.depth)
1881         if self.previous_line is None:
1882             # Don't insert empty lines before the first line in the file.
1883             return 0, 0
1884
1885         if self.previous_line.is_decorator:
1886             return 0, 0
1887
1888         if self.previous_line.depth < current_line.depth and (
1889             self.previous_line.is_class or self.previous_line.is_def
1890         ):
1891             return 0, 0
1892
1893         if (
1894             self.previous_line.is_comment
1895             and self.previous_line.depth == current_line.depth
1896             and before == 0
1897         ):
1898             return 0, 0
1899
1900         if self.is_pyi:
1901             if self.previous_line.depth > current_line.depth:
1902                 newlines = 1
1903             elif current_line.is_class or self.previous_line.is_class:
1904                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1905                     # No blank line between classes with an empty body
1906                     newlines = 0
1907                 else:
1908                     newlines = 1
1909             elif current_line.is_def and not self.previous_line.is_def:
1910                 # Blank line between a block of functions and a block of non-functions
1911                 newlines = 1
1912             else:
1913                 newlines = 0
1914         else:
1915             newlines = 2
1916         if current_line.depth and newlines:
1917             newlines -= 1
1918         return newlines, 0
1919
1920
1921 @dataclass
1922 class LineGenerator(Visitor[Line]):
1923     """Generates reformatted Line objects.  Empty lines are not emitted.
1924
1925     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1926     in ways that will no longer stringify to valid Python code on the tree.
1927     """
1928
1929     is_pyi: bool = False
1930     normalize_strings: bool = True
1931     current_line: Line = field(default_factory=Line)
1932     remove_u_prefix: bool = False
1933
1934     def line(self, indent: int = 0) -> Iterator[Line]:
1935         """Generate a line.
1936
1937         If the line is empty, only emit if it makes sense.
1938         If the line is too long, split it first and then generate.
1939
1940         If any lines were generated, set up a new current_line.
1941         """
1942         if not self.current_line:
1943             self.current_line.depth += indent
1944             return  # Line is empty, don't emit. Creating a new one unnecessary.
1945
1946         complete_line = self.current_line
1947         self.current_line = Line(depth=complete_line.depth + indent)
1948         yield complete_line
1949
1950     def visit_default(self, node: LN) -> Iterator[Line]:
1951         """Default `visit_*()` implementation. Recurses to children of `node`."""
1952         if isinstance(node, Leaf):
1953             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1954             for comment in generate_comments(node):
1955                 if any_open_brackets:
1956                     # any comment within brackets is subject to splitting
1957                     self.current_line.append(comment)
1958                 elif comment.type == token.COMMENT:
1959                     # regular trailing comment
1960                     self.current_line.append(comment)
1961                     yield from self.line()
1962
1963                 else:
1964                     # regular standalone comment
1965                     yield from self.line()
1966
1967                     self.current_line.append(comment)
1968                     yield from self.line()
1969
1970             normalize_prefix(node, inside_brackets=any_open_brackets)
1971             if self.normalize_strings and node.type == token.STRING:
1972                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1973                 normalize_string_quotes(node)
1974             if node.type == token.NUMBER:
1975                 normalize_numeric_literal(node)
1976             if node.type not in WHITESPACE:
1977                 self.current_line.append(node)
1978         yield from super().visit_default(node)
1979
1980     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1981         """Increase indentation level, maybe yield a line."""
1982         # In blib2to3 INDENT never holds comments.
1983         yield from self.line(+1)
1984         yield from self.visit_default(node)
1985
1986     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1987         """Decrease indentation level, maybe yield a line."""
1988         # The current line might still wait for trailing comments.  At DEDENT time
1989         # there won't be any (they would be prefixes on the preceding NEWLINE).
1990         # Emit the line then.
1991         yield from self.line()
1992
1993         # While DEDENT has no value, its prefix may contain standalone comments
1994         # that belong to the current indentation level.  Get 'em.
1995         yield from self.visit_default(node)
1996
1997         # Finally, emit the dedent.
1998         yield from self.line(-1)
1999
2000     def visit_stmt(
2001         self, node: Node, keywords: Set[str], parens: Set[str]
2002     ) -> Iterator[Line]:
2003         """Visit a statement.
2004
2005         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2006         `def`, `with`, `class`, `assert` and assignments.
2007
2008         The relevant Python language `keywords` for a given statement will be
2009         NAME leaves within it. This methods puts those on a separate line.
2010
2011         `parens` holds a set of string leaf values immediately after which
2012         invisible parens should be put.
2013         """
2014         normalize_invisible_parens(node, parens_after=parens)
2015         for child in node.children:
2016             if child.type == token.NAME and child.value in keywords:  # type: ignore
2017                 yield from self.line()
2018
2019             yield from self.visit(child)
2020
2021     def visit_suite(self, node: Node) -> Iterator[Line]:
2022         """Visit a suite."""
2023         if self.is_pyi and is_stub_suite(node):
2024             yield from self.visit(node.children[2])
2025         else:
2026             yield from self.visit_default(node)
2027
2028     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2029         """Visit a statement without nested statements."""
2030         is_suite_like = node.parent and node.parent.type in STATEMENT
2031         if is_suite_like:
2032             if self.is_pyi and is_stub_body(node):
2033                 yield from self.visit_default(node)
2034             else:
2035                 yield from self.line(+1)
2036                 yield from self.visit_default(node)
2037                 yield from self.line(-1)
2038
2039         else:
2040             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
2041                 yield from self.line()
2042             yield from self.visit_default(node)
2043
2044     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2045         """Visit `async def`, `async for`, `async with`."""
2046         yield from self.line()
2047
2048         children = iter(node.children)
2049         for child in children:
2050             yield from self.visit(child)
2051
2052             if child.type == token.ASYNC:
2053                 break
2054
2055         internal_stmt = next(children)
2056         for child in internal_stmt.children:
2057             yield from self.visit(child)
2058
2059     def visit_decorators(self, node: Node) -> Iterator[Line]:
2060         """Visit decorators."""
2061         for child in node.children:
2062             yield from self.line()
2063             yield from self.visit(child)
2064
2065     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2066         """Remove a semicolon and put the other statement on a separate line."""
2067         yield from self.line()
2068
2069     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2070         """End of file. Process outstanding comments and end with a newline."""
2071         yield from self.visit_default(leaf)
2072         yield from self.line()
2073
2074     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2075         if not self.current_line.bracket_tracker.any_open_brackets():
2076             yield from self.line()
2077         yield from self.visit_default(leaf)
2078
2079     def visit_factor(self, node: Node) -> Iterator[Line]:
2080         """Force parentheses between a unary op and a binary power:
2081
2082         -2 ** 8 -> -(2 ** 8)
2083         """
2084         _operator, operand = node.children
2085         if (
2086             operand.type == syms.power
2087             and len(operand.children) == 3
2088             and operand.children[1].type == token.DOUBLESTAR
2089         ):
2090             lpar = Leaf(token.LPAR, "(")
2091             rpar = Leaf(token.RPAR, ")")
2092             index = operand.remove() or 0
2093             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2094         yield from self.visit_default(node)
2095
2096     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2097         # Check if it's a docstring
2098         if prev_siblings_are(
2099             leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
2100         ) and is_multiline_string(leaf):
2101             prefix = "    " * self.current_line.depth
2102             docstring = fix_docstring(leaf.value[3:-3], prefix)
2103             leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
2104             normalize_string_quotes(leaf)
2105
2106         yield from self.visit_default(leaf)
2107
2108     def __post_init__(self) -> None:
2109         """You are in a twisty little maze of passages."""
2110         v = self.visit_stmt
2111         Ø: Set[str] = set()
2112         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2113         self.visit_if_stmt = partial(
2114             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2115         )
2116         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2117         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2118         self.visit_try_stmt = partial(
2119             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2120         )
2121         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2122         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2123         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2124         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2125         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2126         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2127         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2128         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2129         self.visit_async_funcdef = self.visit_async_stmt
2130         self.visit_decorated = self.visit_decorators
2131
2132
2133 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2134 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2135 OPENING_BRACKETS = set(BRACKET.keys())
2136 CLOSING_BRACKETS = set(BRACKET.values())
2137 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2138 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2139
2140
2141 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2142     """Return whitespace prefix if needed for the given `leaf`.
2143
2144     `complex_subscript` signals whether the given leaf is part of a subscription
2145     which has non-trivial arguments, like arithmetic expressions or function calls.
2146     """
2147     NO = ""
2148     SPACE = " "
2149     DOUBLESPACE = "  "
2150     t = leaf.type
2151     p = leaf.parent
2152     v = leaf.value
2153     if t in ALWAYS_NO_SPACE:
2154         return NO
2155
2156     if t == token.COMMENT:
2157         return DOUBLESPACE
2158
2159     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2160     if t == token.COLON and p.type not in {
2161         syms.subscript,
2162         syms.subscriptlist,
2163         syms.sliceop,
2164     }:
2165         return NO
2166
2167     prev = leaf.prev_sibling
2168     if not prev:
2169         prevp = preceding_leaf(p)
2170         if not prevp or prevp.type in OPENING_BRACKETS:
2171             return NO
2172
2173         if t == token.COLON:
2174             if prevp.type == token.COLON:
2175                 return NO
2176
2177             elif prevp.type != token.COMMA and not complex_subscript:
2178                 return NO
2179
2180             return SPACE
2181
2182         if prevp.type == token.EQUAL:
2183             if prevp.parent:
2184                 if prevp.parent.type in {
2185                     syms.arglist,
2186                     syms.argument,
2187                     syms.parameters,
2188                     syms.varargslist,
2189                 }:
2190                     return NO
2191
2192                 elif prevp.parent.type == syms.typedargslist:
2193                     # A bit hacky: if the equal sign has whitespace, it means we
2194                     # previously found it's a typed argument.  So, we're using
2195                     # that, too.
2196                     return prevp.prefix
2197
2198         elif prevp.type in VARARGS_SPECIALS:
2199             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2200                 return NO
2201
2202         elif prevp.type == token.COLON:
2203             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2204                 return SPACE if complex_subscript else NO
2205
2206         elif (
2207             prevp.parent
2208             and prevp.parent.type == syms.factor
2209             and prevp.type in MATH_OPERATORS
2210         ):
2211             return NO
2212
2213         elif (
2214             prevp.type == token.RIGHTSHIFT
2215             and prevp.parent
2216             and prevp.parent.type == syms.shift_expr
2217             and prevp.prev_sibling
2218             and prevp.prev_sibling.type == token.NAME
2219             and prevp.prev_sibling.value == "print"  # type: ignore
2220         ):
2221             # Python 2 print chevron
2222             return NO
2223
2224     elif prev.type in OPENING_BRACKETS:
2225         return NO
2226
2227     if p.type in {syms.parameters, syms.arglist}:
2228         # untyped function signatures or calls
2229         if not prev or prev.type != token.COMMA:
2230             return NO
2231
2232     elif p.type == syms.varargslist:
2233         # lambdas
2234         if prev and prev.type != token.COMMA:
2235             return NO
2236
2237     elif p.type == syms.typedargslist:
2238         # typed function signatures
2239         if not prev:
2240             return NO
2241
2242         if t == token.EQUAL:
2243             if prev.type != syms.tname:
2244                 return NO
2245
2246         elif prev.type == token.EQUAL:
2247             # A bit hacky: if the equal sign has whitespace, it means we
2248             # previously found it's a typed argument.  So, we're using that, too.
2249             return prev.prefix
2250
2251         elif prev.type != token.COMMA:
2252             return NO
2253
2254     elif p.type == syms.tname:
2255         # type names
2256         if not prev:
2257             prevp = preceding_leaf(p)
2258             if not prevp or prevp.type != token.COMMA:
2259                 return NO
2260
2261     elif p.type == syms.trailer:
2262         # attributes and calls
2263         if t == token.LPAR or t == token.RPAR:
2264             return NO
2265
2266         if not prev:
2267             if t == token.DOT:
2268                 prevp = preceding_leaf(p)
2269                 if not prevp or prevp.type != token.NUMBER:
2270                     return NO
2271
2272             elif t == token.LSQB:
2273                 return NO
2274
2275         elif prev.type != token.COMMA:
2276             return NO
2277
2278     elif p.type == syms.argument:
2279         # single argument
2280         if t == token.EQUAL:
2281             return NO
2282
2283         if not prev:
2284             prevp = preceding_leaf(p)
2285             if not prevp or prevp.type == token.LPAR:
2286                 return NO
2287
2288         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2289             return NO
2290
2291     elif p.type == syms.decorator:
2292         # decorators
2293         return NO
2294
2295     elif p.type == syms.dotted_name:
2296         if prev:
2297             return NO
2298
2299         prevp = preceding_leaf(p)
2300         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2301             return NO
2302
2303     elif p.type == syms.classdef:
2304         if t == token.LPAR:
2305             return NO
2306
2307         if prev and prev.type == token.LPAR:
2308             return NO
2309
2310     elif p.type in {syms.subscript, syms.sliceop}:
2311         # indexing
2312         if not prev:
2313             assert p.parent is not None, "subscripts are always parented"
2314             if p.parent.type == syms.subscriptlist:
2315                 return SPACE
2316
2317             return NO
2318
2319         elif not complex_subscript:
2320             return NO
2321
2322     elif p.type == syms.atom:
2323         if prev and t == token.DOT:
2324             # dots, but not the first one.
2325             return NO
2326
2327     elif p.type == syms.dictsetmaker:
2328         # dict unpacking
2329         if prev and prev.type == token.DOUBLESTAR:
2330             return NO
2331
2332     elif p.type in {syms.factor, syms.star_expr}:
2333         # unary ops
2334         if not prev:
2335             prevp = preceding_leaf(p)
2336             if not prevp or prevp.type in OPENING_BRACKETS:
2337                 return NO
2338
2339             prevp_parent = prevp.parent
2340             assert prevp_parent is not None
2341             if prevp.type == token.COLON and prevp_parent.type in {
2342                 syms.subscript,
2343                 syms.sliceop,
2344             }:
2345                 return NO
2346
2347             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2348                 return NO
2349
2350         elif t in {token.NAME, token.NUMBER, token.STRING}:
2351             return NO
2352
2353     elif p.type == syms.import_from:
2354         if t == token.DOT:
2355             if prev and prev.type == token.DOT:
2356                 return NO
2357
2358         elif t == token.NAME:
2359             if v == "import":
2360                 return SPACE
2361
2362             if prev and prev.type == token.DOT:
2363                 return NO
2364
2365     elif p.type == syms.sliceop:
2366         return NO
2367
2368     return SPACE
2369
2370
2371 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2372     """Return the first leaf that precedes `node`, if any."""
2373     while node:
2374         res = node.prev_sibling
2375         if res:
2376             if isinstance(res, Leaf):
2377                 return res
2378
2379             try:
2380                 return list(res.leaves())[-1]
2381
2382             except IndexError:
2383                 return None
2384
2385         node = node.parent
2386     return None
2387
2388
2389 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2390     """Return if the `node` and its previous siblings match types against the provided
2391     list of tokens; the provided `node`has its type matched against the last element in
2392     the list.  `None` can be used as the first element to declare that the start of the
2393     list is anchored at the start of its parent's children."""
2394     if not tokens:
2395         return True
2396     if tokens[-1] is None:
2397         return node is None
2398     if not node:
2399         return False
2400     if node.type != tokens[-1]:
2401         return False
2402     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2403
2404
2405 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2406     """Return the child of `ancestor` that contains `descendant`."""
2407     node: Optional[LN] = descendant
2408     while node and node.parent != ancestor:
2409         node = node.parent
2410     return node
2411
2412
2413 def container_of(leaf: Leaf) -> LN:
2414     """Return `leaf` or one of its ancestors that is the topmost container of it.
2415
2416     By "container" we mean a node where `leaf` is the very first child.
2417     """
2418     same_prefix = leaf.prefix
2419     container: LN = leaf
2420     while container:
2421         parent = container.parent
2422         if parent is None:
2423             break
2424
2425         if parent.children[0].prefix != same_prefix:
2426             break
2427
2428         if parent.type == syms.file_input:
2429             break
2430
2431         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2432             break
2433
2434         container = parent
2435     return container
2436
2437
2438 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2439     """Return the priority of the `leaf` delimiter, given a line break after it.
2440
2441     The delimiter priorities returned here are from those delimiters that would
2442     cause a line break after themselves.
2443
2444     Higher numbers are higher priority.
2445     """
2446     if leaf.type == token.COMMA:
2447         return COMMA_PRIORITY
2448
2449     return 0
2450
2451
2452 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2453     """Return the priority of the `leaf` delimiter, given a line break before it.
2454
2455     The delimiter priorities returned here are from those delimiters that would
2456     cause a line break before themselves.
2457
2458     Higher numbers are higher priority.
2459     """
2460     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2461         # * and ** might also be MATH_OPERATORS but in this case they are not.
2462         # Don't treat them as a delimiter.
2463         return 0
2464
2465     if (
2466         leaf.type == token.DOT
2467         and leaf.parent
2468         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2469         and (previous is None or previous.type in CLOSING_BRACKETS)
2470     ):
2471         return DOT_PRIORITY
2472
2473     if (
2474         leaf.type in MATH_OPERATORS
2475         and leaf.parent
2476         and leaf.parent.type not in {syms.factor, syms.star_expr}
2477     ):
2478         return MATH_PRIORITIES[leaf.type]
2479
2480     if leaf.type in COMPARATORS:
2481         return COMPARATOR_PRIORITY
2482
2483     if (
2484         leaf.type == token.STRING
2485         and previous is not None
2486         and previous.type == token.STRING
2487     ):
2488         return STRING_PRIORITY
2489
2490     if leaf.type not in {token.NAME, token.ASYNC}:
2491         return 0
2492
2493     if (
2494         leaf.value == "for"
2495         and leaf.parent
2496         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2497         or leaf.type == token.ASYNC
2498     ):
2499         if (
2500             not isinstance(leaf.prev_sibling, Leaf)
2501             or leaf.prev_sibling.value != "async"
2502         ):
2503             return COMPREHENSION_PRIORITY
2504
2505     if (
2506         leaf.value == "if"
2507         and leaf.parent
2508         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2509     ):
2510         return COMPREHENSION_PRIORITY
2511
2512     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2513         return TERNARY_PRIORITY
2514
2515     if leaf.value == "is":
2516         return COMPARATOR_PRIORITY
2517
2518     if (
2519         leaf.value == "in"
2520         and leaf.parent
2521         and leaf.parent.type in {syms.comp_op, syms.comparison}
2522         and not (
2523             previous is not None
2524             and previous.type == token.NAME
2525             and previous.value == "not"
2526         )
2527     ):
2528         return COMPARATOR_PRIORITY
2529
2530     if (
2531         leaf.value == "not"
2532         and leaf.parent
2533         and leaf.parent.type == syms.comp_op
2534         and not (
2535             previous is not None
2536             and previous.type == token.NAME
2537             and previous.value == "is"
2538         )
2539     ):
2540         return COMPARATOR_PRIORITY
2541
2542     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2543         return LOGIC_PRIORITY
2544
2545     return 0
2546
2547
2548 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2549 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2550
2551
2552 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2553     """Clean the prefix of the `leaf` and generate comments from it, if any.
2554
2555     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2556     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2557     move because it does away with modifying the grammar to include all the
2558     possible places in which comments can be placed.
2559
2560     The sad consequence for us though is that comments don't "belong" anywhere.
2561     This is why this function generates simple parentless Leaf objects for
2562     comments.  We simply don't know what the correct parent should be.
2563
2564     No matter though, we can live without this.  We really only need to
2565     differentiate between inline and standalone comments.  The latter don't
2566     share the line with any code.
2567
2568     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2569     are emitted with a fake STANDALONE_COMMENT token identifier.
2570     """
2571     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2572         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2573
2574
2575 @dataclass
2576 class ProtoComment:
2577     """Describes a piece of syntax that is a comment.
2578
2579     It's not a :class:`blib2to3.pytree.Leaf` so that:
2580
2581     * it can be cached (`Leaf` objects should not be reused more than once as
2582       they store their lineno, column, prefix, and parent information);
2583     * `newlines` and `consumed` fields are kept separate from the `value`. This
2584       simplifies handling of special marker comments like ``# fmt: off/on``.
2585     """
2586
2587     type: int  # token.COMMENT or STANDALONE_COMMENT
2588     value: str  # content of the comment
2589     newlines: int  # how many newlines before the comment
2590     consumed: int  # how many characters of the original leaf's prefix did we consume
2591
2592
2593 @lru_cache(maxsize=4096)
2594 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2595     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2596     result: List[ProtoComment] = []
2597     if not prefix or "#" not in prefix:
2598         return result
2599
2600     consumed = 0
2601     nlines = 0
2602     ignored_lines = 0
2603     for index, line in enumerate(prefix.split("\n")):
2604         consumed += len(line) + 1  # adding the length of the split '\n'
2605         line = line.lstrip()
2606         if not line:
2607             nlines += 1
2608         if not line.startswith("#"):
2609             # Escaped newlines outside of a comment are not really newlines at
2610             # all. We treat a single-line comment following an escaped newline
2611             # as a simple trailing comment.
2612             if line.endswith("\\"):
2613                 ignored_lines += 1
2614             continue
2615
2616         if index == ignored_lines and not is_endmarker:
2617             comment_type = token.COMMENT  # simple trailing comment
2618         else:
2619             comment_type = STANDALONE_COMMENT
2620         comment = make_comment(line)
2621         result.append(
2622             ProtoComment(
2623                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2624             )
2625         )
2626         nlines = 0
2627     return result
2628
2629
2630 def make_comment(content: str) -> str:
2631     """Return a consistently formatted comment from the given `content` string.
2632
2633     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2634     space between the hash sign and the content.
2635
2636     If `content` didn't start with a hash sign, one is provided.
2637     """
2638     content = content.rstrip()
2639     if not content:
2640         return "#"
2641
2642     if content[0] == "#":
2643         content = content[1:]
2644     if content and content[0] not in " !:#'%":
2645         content = " " + content
2646     return "#" + content
2647
2648
2649 def transform_line(
2650     line: Line,
2651     line_length: int,
2652     normalize_strings: bool,
2653     features: Collection[Feature] = (),
2654 ) -> Iterator[Line]:
2655     """Transform a `line`, potentially splitting it into many lines.
2656
2657     They should fit in the allotted `line_length` but might not be able to.
2658
2659     `features` are syntactical features that may be used in the output.
2660     """
2661     if line.is_comment:
2662         yield line
2663         return
2664
2665     line_str = line_to_string(line)
2666
2667     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2668         """Initialize StringTransformer"""
2669         return ST(line_length, normalize_strings)
2670
2671     string_merge = init_st(StringMerger)
2672     string_paren_strip = init_st(StringParenStripper)
2673     string_split = init_st(StringSplitter)
2674     string_paren_wrap = init_st(StringParenWrapper)
2675
2676     transformers: List[Transformer]
2677     if (
2678         not line.contains_uncollapsable_type_comments()
2679         and not line.should_explode
2680         and not line.is_collection_with_optional_trailing_comma
2681         and (
2682             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2683             or line.contains_unsplittable_type_ignore()
2684         )
2685         and not (line.contains_standalone_comments() and line.inside_brackets)
2686     ):
2687         # Only apply basic string preprocessing, since lines shouldn't be split here.
2688         transformers = [string_merge, string_paren_strip]
2689     elif line.is_def:
2690         transformers = [left_hand_split]
2691     else:
2692
2693         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2694             for omit in generate_trailers_to_omit(line, line_length):
2695                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2696                 if is_line_short_enough(lines[0], line_length=line_length):
2697                     yield from lines
2698                     return
2699
2700             # All splits failed, best effort split with no omits.
2701             # This mostly happens to multiline strings that are by definition
2702             # reported as not fitting a single line.
2703             # line_length=1 here was historically a bug that somehow became a feature.
2704             # See #762 and #781 for the full story.
2705             yield from right_hand_split(line, line_length=1, features=features)
2706
2707         if line.inside_brackets:
2708             transformers = [
2709                 string_merge,
2710                 string_paren_strip,
2711                 delimiter_split,
2712                 standalone_comment_split,
2713                 string_split,
2714                 string_paren_wrap,
2715                 rhs,
2716             ]
2717         else:
2718             transformers = [
2719                 string_merge,
2720                 string_paren_strip,
2721                 string_split,
2722                 string_paren_wrap,
2723                 rhs,
2724             ]
2725
2726     for transform in transformers:
2727         # We are accumulating lines in `result` because we might want to abort
2728         # mission and return the original line in the end, or attempt a different
2729         # split altogether.
2730         result: List[Line] = []
2731         try:
2732             for transformed_line in transform(line, features):
2733                 if str(transformed_line).strip("\n") == line_str:
2734                     raise CannotTransform(
2735                         "Line transformer returned an unchanged result"
2736                     )
2737
2738                 result.extend(
2739                     transform_line(
2740                         transformed_line,
2741                         line_length=line_length,
2742                         normalize_strings=normalize_strings,
2743                         features=features,
2744                     )
2745                 )
2746         except CannotTransform:
2747             continue
2748         else:
2749             yield from result
2750             break
2751
2752     else:
2753         yield line
2754
2755
2756 @dataclass  # type: ignore
2757 class StringTransformer(ABC):
2758     """
2759     An implementation of the Transformer protocol that relies on its
2760     subclasses overriding the template methods `do_match(...)` and
2761     `do_transform(...)`.
2762
2763     This Transformer works exclusively on strings (for example, by merging
2764     or splitting them).
2765
2766     The following sections can be found among the docstrings of each concrete
2767     StringTransformer subclass.
2768
2769     Requirements:
2770         Which requirements must be met of the given Line for this
2771         StringTransformer to be applied?
2772
2773     Transformations:
2774         If the given Line meets all of the above requirements, which string
2775         transformations can you expect to be applied to it by this
2776         StringTransformer?
2777
2778     Collaborations:
2779         What contractual agreements does this StringTransformer have with other
2780         StringTransfomers? Such collaborations should be eliminated/minimized
2781         as much as possible.
2782     """
2783
2784     line_length: int
2785     normalize_strings: bool
2786
2787     @abstractmethod
2788     def do_match(self, line: Line) -> TMatchResult:
2789         """
2790         Returns:
2791             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2792             string, if a match was able to be made.
2793                 OR
2794             * Err(CannotTransform), if a match was not able to be made.
2795         """
2796
2797     @abstractmethod
2798     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2799         """
2800         Yields:
2801             * Ok(new_line) where new_line is the new transformed line.
2802                 OR
2803             * Err(CannotTransform) if the transformation failed for some reason. The
2804             `do_match(...)` template method should usually be used to reject
2805             the form of the given Line, but in some cases it is difficult to
2806             know whether or not a Line meets the StringTransformer's
2807             requirements until the transformation is already midway.
2808
2809         Side Effects:
2810             This method should NOT mutate @line directly, but it MAY mutate the
2811             Line's underlying Node structure. (WARNING: If the underlying Node
2812             structure IS altered, then this method should NOT be allowed to
2813             yield an CannotTransform after that point.)
2814         """
2815
2816     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2817         """
2818         StringTransformer instances have a call signature that mirrors that of
2819         the Transformer type.
2820
2821         Raises:
2822             CannotTransform(...) if the concrete StringTransformer class is unable
2823             to transform @line.
2824         """
2825         # Optimization to avoid calling `self.do_match(...)` when the line does
2826         # not contain any string.
2827         if not any(leaf.type == token.STRING for leaf in line.leaves):
2828             raise CannotTransform("There are no strings in this line.")
2829
2830         match_result = self.do_match(line)
2831
2832         if isinstance(match_result, Err):
2833             cant_transform = match_result.err()
2834             raise CannotTransform(
2835                 f"The string transformer {self.__class__.__name__} does not recognize"
2836                 " this line as one that it can transform."
2837             ) from cant_transform
2838
2839         string_idx = match_result.ok()
2840
2841         for line_result in self.do_transform(line, string_idx):
2842             if isinstance(line_result, Err):
2843                 cant_transform = line_result.err()
2844                 raise CannotTransform(
2845                     "StringTransformer failed while attempting to transform string."
2846                 ) from cant_transform
2847             line = line_result.ok()
2848             yield line
2849
2850
2851 @dataclass
2852 class CustomSplit:
2853     """A custom (i.e. manual) string split.
2854
2855     A single CustomSplit instance represents a single substring.
2856
2857     Examples:
2858         Consider the following string:
2859         ```
2860         "Hi there friend."
2861         " This is a custom"
2862         f" string {split}."
2863         ```
2864
2865         This string will correspond to the following three CustomSplit instances:
2866         ```
2867         CustomSplit(False, 16)
2868         CustomSplit(False, 17)
2869         CustomSplit(True, 16)
2870         ```
2871     """
2872
2873     has_prefix: bool
2874     break_idx: int
2875
2876
2877 class CustomSplitMapMixin:
2878     """
2879     This mixin class is used to map merged strings to a sequence of
2880     CustomSplits, which will then be used to re-split the strings iff none of
2881     the resultant substrings go over the configured max line length.
2882     """
2883
2884     _Key = Tuple[StringID, str]
2885     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2886
2887     @staticmethod
2888     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2889         """
2890         Returns:
2891             A unique identifier that is used internally to map @string to a
2892             group of custom splits.
2893         """
2894         return (id(string), string)
2895
2896     def add_custom_splits(
2897         self, string: str, custom_splits: Iterable[CustomSplit]
2898     ) -> None:
2899         """Custom Split Map Setter Method
2900
2901         Side Effects:
2902             Adds a mapping from @string to the custom splits @custom_splits.
2903         """
2904         key = self._get_key(string)
2905         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2906
2907     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2908         """Custom Split Map Getter Method
2909
2910         Returns:
2911             * A list of the custom splits that are mapped to @string, if any
2912             exist.
2913                 OR
2914             * [], otherwise.
2915
2916         Side Effects:
2917             Deletes the mapping between @string and its associated custom
2918             splits (which are returned to the caller).
2919         """
2920         key = self._get_key(string)
2921
2922         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2923         del self._CUSTOM_SPLIT_MAP[key]
2924
2925         return list(custom_splits)
2926
2927     def has_custom_splits(self, string: str) -> bool:
2928         """
2929         Returns:
2930             True iff @string is associated with a set of custom splits.
2931         """
2932         key = self._get_key(string)
2933         return key in self._CUSTOM_SPLIT_MAP
2934
2935
2936 class StringMerger(CustomSplitMapMixin, StringTransformer):
2937     """StringTransformer that merges strings together.
2938
2939     Requirements:
2940         (A) The line contains adjacent strings such that at most one substring
2941         has inline comments AND none of those inline comments are pragmas AND
2942         the set of all substring prefixes is either of length 1 or equal to
2943         {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
2944         with 'r').
2945             OR
2946         (B) The line contains a string which uses line continuation backslashes.
2947
2948     Transformations:
2949         Depending on which of the two requirements above where met, either:
2950
2951         (A) The string group associated with the target string is merged.
2952             OR
2953         (B) All line-continuation backslashes are removed from the target string.
2954
2955     Collaborations:
2956         StringMerger provides custom split information to StringSplitter.
2957     """
2958
2959     def do_match(self, line: Line) -> TMatchResult:
2960         LL = line.leaves
2961
2962         is_valid_index = is_valid_index_factory(LL)
2963
2964         for (i, leaf) in enumerate(LL):
2965             if (
2966                 leaf.type == token.STRING
2967                 and is_valid_index(i + 1)
2968                 and LL[i + 1].type == token.STRING
2969             ):
2970                 return Ok(i)
2971
2972             if leaf.type == token.STRING and "\\\n" in leaf.value:
2973                 return Ok(i)
2974
2975         return TErr("This line has no strings that need merging.")
2976
2977     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2978         new_line = line
2979         rblc_result = self.__remove_backslash_line_continuation_chars(
2980             new_line, string_idx
2981         )
2982         if isinstance(rblc_result, Ok):
2983             new_line = rblc_result.ok()
2984
2985         msg_result = self.__merge_string_group(new_line, string_idx)
2986         if isinstance(msg_result, Ok):
2987             new_line = msg_result.ok()
2988
2989         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
2990             msg_cant_transform = msg_result.err()
2991             rblc_cant_transform = rblc_result.err()
2992             cant_transform = CannotTransform(
2993                 "StringMerger failed to merge any strings in this line."
2994             )
2995
2996             # Chain the errors together using `__cause__`.
2997             msg_cant_transform.__cause__ = rblc_cant_transform
2998             cant_transform.__cause__ = msg_cant_transform
2999
3000             yield Err(cant_transform)
3001         else:
3002             yield Ok(new_line)
3003
3004     @staticmethod
3005     def __remove_backslash_line_continuation_chars(
3006         line: Line, string_idx: int
3007     ) -> TResult[Line]:
3008         """
3009         Merge strings that were split across multiple lines using
3010         line-continuation backslashes.
3011
3012         Returns:
3013             Ok(new_line), if @line contains backslash line-continuation
3014             characters.
3015                 OR
3016             Err(CannotTransform), otherwise.
3017         """
3018         LL = line.leaves
3019
3020         string_leaf = LL[string_idx]
3021         if not (
3022             string_leaf.type == token.STRING
3023             and "\\\n" in string_leaf.value
3024             and not has_triple_quotes(string_leaf.value)
3025         ):
3026             return TErr(
3027                 f"String leaf {string_leaf} does not contain any backslash line"
3028                 " continuation characters."
3029             )
3030
3031         new_line = line.clone()
3032         new_line.comments = line.comments
3033         append_leaves(new_line, line, LL)
3034
3035         new_string_leaf = new_line.leaves[string_idx]
3036         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3037
3038         return Ok(new_line)
3039
3040     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3041         """
3042         Merges string group (i.e. set of adjacent strings) where the first
3043         string in the group is `line.leaves[string_idx]`.
3044
3045         Returns:
3046             Ok(new_line), if ALL of the validation checks found in
3047             __validate_msg(...) pass.
3048                 OR
3049             Err(CannotTransform), otherwise.
3050         """
3051         LL = line.leaves
3052
3053         is_valid_index = is_valid_index_factory(LL)
3054
3055         vresult = self.__validate_msg(line, string_idx)
3056         if isinstance(vresult, Err):
3057             return vresult
3058
3059         # If the string group is wrapped inside an Atom node, we must make sure
3060         # to later replace that Atom with our new (merged) string leaf.
3061         atom_node = LL[string_idx].parent
3062
3063         # We will place BREAK_MARK in between every two substrings that we
3064         # merge. We will then later go through our final result and use the
3065         # various instances of BREAK_MARK we find to add the right values to
3066         # the custom split map.
3067         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3068
3069         QUOTE = LL[string_idx].value[-1]
3070
3071         def make_naked(string: str, string_prefix: str) -> str:
3072             """Strip @string (i.e. make it a "naked" string)
3073
3074             Pre-conditions:
3075                 * assert_is_leaf_string(@string)
3076
3077             Returns:
3078                 A string that is identical to @string except that
3079                 @string_prefix has been stripped, the surrounding QUOTE
3080                 characters have been removed, and any remaining QUOTE
3081                 characters have been escaped.
3082             """
3083             assert_is_leaf_string(string)
3084
3085             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3086             naked_string = string[len(string_prefix) + 1 : -1]
3087             naked_string = re.sub(
3088                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3089             )
3090             return naked_string
3091
3092         # Holds the CustomSplit objects that will later be added to the custom
3093         # split map.
3094         custom_splits = []
3095
3096         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3097         prefix_tracker = []
3098
3099         # Sets the 'prefix' variable. This is the prefix that the final merged
3100         # string will have.
3101         next_str_idx = string_idx
3102         prefix = ""
3103         while (
3104             not prefix
3105             and is_valid_index(next_str_idx)
3106             and LL[next_str_idx].type == token.STRING
3107         ):
3108             prefix = get_string_prefix(LL[next_str_idx].value)
3109             next_str_idx += 1
3110
3111         # The next loop merges the string group. The final string will be
3112         # contained in 'S'.
3113         #
3114         # The following convenience variables are used:
3115         #
3116         #   S: string
3117         #   NS: naked string
3118         #   SS: next string
3119         #   NSS: naked next string
3120         S = ""
3121         NS = ""
3122         num_of_strings = 0
3123         next_str_idx = string_idx
3124         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3125             num_of_strings += 1
3126
3127             SS = LL[next_str_idx].value
3128             next_prefix = get_string_prefix(SS)
3129
3130             # If this is an f-string group but this substring is not prefixed
3131             # with 'f'...
3132             if "f" in prefix and "f" not in next_prefix:
3133                 # Then we must escape any braces contained in this substring.
3134                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3135
3136             NSS = make_naked(SS, next_prefix)
3137
3138             has_prefix = bool(next_prefix)
3139             prefix_tracker.append(has_prefix)
3140
3141             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3142             NS = make_naked(S, prefix)
3143
3144             next_str_idx += 1
3145
3146         S_leaf = Leaf(token.STRING, S)
3147         if self.normalize_strings:
3148             normalize_string_quotes(S_leaf)
3149
3150         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3151         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3152         for has_prefix in prefix_tracker:
3153             mark_idx = temp_string.find(BREAK_MARK)
3154             assert (
3155                 mark_idx >= 0
3156             ), "Logic error while filling the custom string breakpoint cache."
3157
3158             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3159             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3160             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3161
3162         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3163
3164         if atom_node is not None:
3165             replace_child(atom_node, string_leaf)
3166
3167         # Build the final line ('new_line') that this method will later return.
3168         new_line = line.clone()
3169         for (i, leaf) in enumerate(LL):
3170             if i == string_idx:
3171                 new_line.append(string_leaf)
3172
3173             if string_idx <= i < string_idx + num_of_strings:
3174                 for comment_leaf in line.comments_after(LL[i]):
3175                     new_line.append(comment_leaf, preformatted=True)
3176                 continue
3177
3178             append_leaves(new_line, line, [leaf])
3179
3180         self.add_custom_splits(string_leaf.value, custom_splits)
3181         return Ok(new_line)
3182
3183     @staticmethod
3184     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3185         """Validate (M)erge (S)tring (G)roup
3186
3187         Transform-time string validation logic for __merge_string_group(...).
3188
3189         Returns:
3190             * Ok(None), if ALL validation checks (listed below) pass.
3191                 OR
3192             * Err(CannotTransform), if any of the following are true:
3193                 - The target string is not in a string group (i.e. it has no
3194                   adjacent strings).
3195                 - The string group has more than one inline comment.
3196                 - The string group has an inline comment that appears to be a pragma.
3197                 - The set of all string prefixes in the string group is of
3198                   length greater than one and is not equal to {"", "f"}.
3199                 - The string group consists of raw strings.
3200         """
3201         num_of_inline_string_comments = 0
3202         set_of_prefixes = set()
3203         num_of_strings = 0
3204         for leaf in line.leaves[string_idx:]:
3205             if leaf.type != token.STRING:
3206                 # If the string group is trailed by a comma, we count the
3207                 # comments trailing the comma to be one of the string group's
3208                 # comments.
3209                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3210                     num_of_inline_string_comments += 1
3211                 break
3212
3213             if has_triple_quotes(leaf.value):
3214                 return TErr("StringMerger does NOT merge multiline strings.")
3215
3216             num_of_strings += 1
3217             prefix = get_string_prefix(leaf.value)
3218             if "r" in prefix:
3219                 return TErr("StringMerger does NOT merge raw strings.")
3220
3221             set_of_prefixes.add(prefix)
3222
3223             if id(leaf) in line.comments:
3224                 num_of_inline_string_comments += 1
3225                 if contains_pragma_comment(line.comments[id(leaf)]):
3226                     return TErr("Cannot merge strings which have pragma comments.")
3227
3228         if num_of_strings < 2:
3229             return TErr(
3230                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3231             )
3232
3233         if num_of_inline_string_comments > 1:
3234             return TErr(
3235                 f"Too many inline string comments ({num_of_inline_string_comments})."
3236             )
3237
3238         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3239             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3240
3241         return Ok(None)
3242
3243
3244 class StringParenStripper(StringTransformer):
3245     """StringTransformer that strips surrounding parentheses from strings.
3246
3247     Requirements:
3248         The line contains a string which is surrounded by parentheses and:
3249             - The target string is NOT the only argument to a function call).
3250             - The RPAR is NOT followed by an attribute access (i.e. a dot).
3251
3252     Transformations:
3253         The parentheses mentioned in the 'Requirements' section are stripped.
3254
3255     Collaborations:
3256         StringParenStripper has its own inherent usefulness, but it is also
3257         relied on to clean up the parentheses created by StringParenWrapper (in
3258         the event that they are no longer needed).
3259     """
3260
3261     def do_match(self, line: Line) -> TMatchResult:
3262         LL = line.leaves
3263
3264         is_valid_index = is_valid_index_factory(LL)
3265
3266         for (idx, leaf) in enumerate(LL):
3267             # Should be a string...
3268             if leaf.type != token.STRING:
3269                 continue
3270
3271             # Should be preceded by a non-empty LPAR...
3272             if (
3273                 not is_valid_index(idx - 1)
3274                 or LL[idx - 1].type != token.LPAR
3275                 or is_empty_lpar(LL[idx - 1])
3276             ):
3277                 continue
3278
3279             # That LPAR should NOT be preceded by a function name or a closing
3280             # bracket (which could be a function which returns a function or a
3281             # list/dictionary that contains a function)...
3282             if is_valid_index(idx - 2) and (
3283                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3284             ):
3285                 continue
3286
3287             string_idx = idx
3288
3289             # Skip the string trailer, if one exists.
3290             string_parser = StringParser()
3291             next_idx = string_parser.parse(LL, string_idx)
3292
3293             # Should be followed by a non-empty RPAR...
3294             if (
3295                 is_valid_index(next_idx)
3296                 and LL[next_idx].type == token.RPAR
3297                 and not is_empty_rpar(LL[next_idx])
3298             ):
3299                 # That RPAR should NOT be followed by a '.' symbol.
3300                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type == token.DOT:
3301                     continue
3302
3303                 return Ok(string_idx)
3304
3305         return TErr("This line has no strings wrapped in parens.")
3306
3307     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3308         LL = line.leaves
3309
3310         string_parser = StringParser()
3311         rpar_idx = string_parser.parse(LL, string_idx)
3312
3313         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3314             if line.comments_after(leaf):
3315                 yield TErr(
3316                     "Will not strip parentheses which have comments attached to them."
3317                 )
3318
3319         new_line = line.clone()
3320         new_line.comments = line.comments.copy()
3321
3322         append_leaves(new_line, line, LL[: string_idx - 1])
3323
3324         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3325         LL[string_idx - 1].remove()
3326         replace_child(LL[string_idx], string_leaf)
3327         new_line.append(string_leaf)
3328
3329         append_leaves(
3330             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :],
3331         )
3332
3333         LL[rpar_idx].remove()
3334
3335         yield Ok(new_line)
3336
3337
3338 class BaseStringSplitter(StringTransformer):
3339     """
3340     Abstract class for StringTransformers which transform a Line's strings by splitting
3341     them or placing them on their own lines where necessary to avoid going over
3342     the configured line length.
3343
3344     Requirements:
3345         * The target string value is responsible for the line going over the
3346         line length limit. It follows that after all of black's other line
3347         split methods have been exhausted, this line (or one of the resulting
3348         lines after all line splits are performed) would still be over the
3349         line_length limit unless we split this string.
3350             AND
3351         * The target string is NOT a "pointless" string (i.e. a string that has
3352         no parent or siblings).
3353             AND
3354         * The target string is not followed by an inline comment that appears
3355         to be a pragma.
3356             AND
3357         * The target string is not a multiline (i.e. triple-quote) string.
3358     """
3359
3360     @abstractmethod
3361     def do_splitter_match(self, line: Line) -> TMatchResult:
3362         """
3363         BaseStringSplitter asks its clients to override this method instead of
3364         `StringTransformer.do_match(...)`.
3365
3366         Follows the same protocol as `StringTransformer.do_match(...)`.
3367
3368         Refer to `help(StringTransformer.do_match)` for more information.
3369         """
3370
3371     def do_match(self, line: Line) -> TMatchResult:
3372         match_result = self.do_splitter_match(line)
3373         if isinstance(match_result, Err):
3374             return match_result
3375
3376         string_idx = match_result.ok()
3377         vresult = self.__validate(line, string_idx)
3378         if isinstance(vresult, Err):
3379             return vresult
3380
3381         return match_result
3382
3383     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3384         """
3385         Checks that @line meets all of the requirements listed in this classes'
3386         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3387         description of those requirements.
3388
3389         Returns:
3390             * Ok(None), if ALL of the requirements are met.
3391                 OR
3392             * Err(CannotTransform), if ANY of the requirements are NOT met.
3393         """
3394         LL = line.leaves
3395
3396         string_leaf = LL[string_idx]
3397
3398         max_string_length = self.__get_max_string_length(line, string_idx)
3399         if len(string_leaf.value) <= max_string_length:
3400             return TErr(
3401                 "The string itself is not what is causing this line to be too long."
3402             )
3403
3404         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3405             token.STRING,
3406             token.NEWLINE,
3407         ]:
3408             return TErr(
3409                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3410                 " no parent)."
3411             )
3412
3413         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3414             line.comments[id(line.leaves[string_idx])]
3415         ):
3416             return TErr(
3417                 "Line appears to end with an inline pragma comment. Splitting the line"
3418                 " could modify the pragma's behavior."
3419             )
3420
3421         if has_triple_quotes(string_leaf.value):
3422             return TErr("We cannot split multiline strings.")
3423
3424         return Ok(None)
3425
3426     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3427         """
3428         Calculates the max string length used when attempting to determine
3429         whether or not the target string is responsible for causing the line to
3430         go over the line length limit.
3431
3432         WARNING: This method is tightly coupled to both StringSplitter and
3433         (especially) StringParenWrapper. There is probably a better way to
3434         accomplish what is being done here.
3435
3436         Returns:
3437             max_string_length: such that `line.leaves[string_idx].value >
3438             max_string_length` implies that the target string IS responsible
3439             for causing this line to exceed the line length limit.
3440         """
3441         LL = line.leaves
3442
3443         is_valid_index = is_valid_index_factory(LL)
3444
3445         # We use the shorthand "WMA4" in comments to abbreviate "We must
3446         # account for". When giving examples, we use STRING to mean some/any
3447         # valid string.
3448         #
3449         # Finally, we use the following convenience variables:
3450         #
3451         #   P:  The leaf that is before the target string leaf.
3452         #   N:  The leaf that is after the target string leaf.
3453         #   NN: The leaf that is after N.
3454
3455         # WMA4 the whitespace at the beginning of the line.
3456         offset = line.depth * 4
3457
3458         if is_valid_index(string_idx - 1):
3459             p_idx = string_idx - 1
3460             if (
3461                 LL[string_idx - 1].type == token.LPAR
3462                 and LL[string_idx - 1].value == ""
3463                 and string_idx >= 2
3464             ):
3465                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3466                 p_idx -= 1
3467
3468             P = LL[p_idx]
3469             if P.type == token.PLUS:
3470                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3471                 offset += 2
3472
3473             if P.type == token.COMMA:
3474                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3475                 offset += 3
3476
3477             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3478                 # This conditional branch is meant to handle dictionary keys,
3479                 # variable assignments, 'return STRING' statement lines, and
3480                 # 'else STRING' ternary expression lines.
3481
3482                 # WMA4 a single space.
3483                 offset += 1
3484
3485                 # WMA4 the lengths of any leaves that came before that space.
3486                 for leaf in LL[: p_idx + 1]:
3487                     offset += len(str(leaf))
3488
3489         if is_valid_index(string_idx + 1):
3490             N = LL[string_idx + 1]
3491             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3492                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3493                 N = LL[string_idx + 2]
3494
3495             if N.type == token.COMMA:
3496                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3497                 offset += 1
3498
3499             if is_valid_index(string_idx + 2):
3500                 NN = LL[string_idx + 2]
3501
3502                 if N.type == token.DOT and NN.type == token.NAME:
3503                     # This conditional branch is meant to handle method calls invoked
3504                     # off of a string literal up to and including the LPAR character.
3505
3506                     # WMA4 the '.' character.
3507                     offset += 1
3508
3509                     if (
3510                         is_valid_index(string_idx + 3)
3511                         and LL[string_idx + 3].type == token.LPAR
3512                     ):
3513                         # WMA4 the left parenthesis character.
3514                         offset += 1
3515
3516                     # WMA4 the length of the method's name.
3517                     offset += len(NN.value)
3518
3519         has_comments = False
3520         for comment_leaf in line.comments_after(LL[string_idx]):
3521             if not has_comments:
3522                 has_comments = True
3523                 # WMA4 two spaces before the '#' character.
3524                 offset += 2
3525
3526             # WMA4 the length of the inline comment.
3527             offset += len(comment_leaf.value)
3528
3529         max_string_length = self.line_length - offset
3530         return max_string_length
3531
3532
3533 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3534     """
3535     StringTransformer that splits "atom" strings (i.e. strings which exist on
3536     lines by themselves).
3537
3538     Requirements:
3539         * The line consists ONLY of a single string (with the exception of a
3540         '+' symbol which MAY exist at the start of the line), MAYBE a string
3541         trailer, and MAYBE a trailing comma.
3542             AND
3543         * All of the requirements listed in BaseStringSplitter's docstring.
3544
3545     Transformations:
3546         The string mentioned in the 'Requirements' section is split into as
3547         many substrings as necessary to adhere to the configured line length.
3548
3549         In the final set of substrings, no substring should be smaller than
3550         MIN_SUBSTR_SIZE characters.
3551
3552         The string will ONLY be split on spaces (i.e. each new substring should
3553         start with a space).
3554
3555         If the string is an f-string, it will NOT be split in the middle of an
3556         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3557         else bar()} is an f-expression).
3558
3559         If the string that is being split has an associated set of custom split
3560         records and those custom splits will NOT result in any line going over
3561         the configured line length, those custom splits are used. Otherwise the
3562         string is split as late as possible (from left-to-right) while still
3563         adhering to the transformation rules listed above.
3564
3565     Collaborations:
3566         StringSplitter relies on StringMerger to construct the appropriate
3567         CustomSplit objects and add them to the custom split map.
3568     """
3569
3570     MIN_SUBSTR_SIZE = 6
3571     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3572     RE_FEXPR = r"""
3573     (?<!\{)\{
3574         (?:
3575             [^\{\}]
3576             | \{\{
3577             | \}\}
3578         )+?
3579     (?<!\})(?:\}\})*\}(?!\})
3580     """
3581
3582     def do_splitter_match(self, line: Line) -> TMatchResult:
3583         LL = line.leaves
3584
3585         is_valid_index = is_valid_index_factory(LL)
3586
3587         idx = 0
3588
3589         # The first leaf MAY be a '+' symbol...
3590         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3591             idx += 1
3592
3593         # The next/first leaf MAY be an empty LPAR...
3594         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3595             idx += 1
3596
3597         # The next/first leaf MUST be a string...
3598         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3599             return TErr("Line does not start with a string.")
3600
3601         string_idx = idx
3602
3603         # Skip the string trailer, if one exists.
3604         string_parser = StringParser()
3605         idx = string_parser.parse(LL, string_idx)
3606
3607         # That string MAY be followed by an empty RPAR...
3608         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3609             idx += 1
3610
3611         # That string / empty RPAR leaf MAY be followed by a comma...
3612         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3613             idx += 1
3614
3615         # But no more leaves are allowed...
3616         if is_valid_index(idx):
3617             return TErr("This line does not end with a string.")
3618
3619         return Ok(string_idx)
3620
3621     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3622         LL = line.leaves
3623
3624         QUOTE = LL[string_idx].value[-1]
3625
3626         is_valid_index = is_valid_index_factory(LL)
3627         insert_str_child = insert_str_child_factory(LL[string_idx])
3628
3629         prefix = get_string_prefix(LL[string_idx].value)
3630
3631         # We MAY choose to drop the 'f' prefix from substrings that don't
3632         # contain any f-expressions, but ONLY if the original f-string
3633         # contains at least one f-expression. Otherwise, we will alter the AST
3634         # of the program.
3635         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3636             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3637         )
3638
3639         first_string_line = True
3640         starts_with_plus = LL[0].type == token.PLUS
3641
3642         def line_needs_plus() -> bool:
3643             return first_string_line and starts_with_plus
3644
3645         def maybe_append_plus(new_line: Line) -> None:
3646             """
3647             Side Effects:
3648                 If @line starts with a plus and this is the first line we are
3649                 constructing, this function appends a PLUS leaf to @new_line
3650                 and replaces the old PLUS leaf in the node structure. Otherwise
3651                 this function does nothing.
3652             """
3653             if line_needs_plus():
3654                 plus_leaf = Leaf(token.PLUS, "+")
3655                 replace_child(LL[0], plus_leaf)
3656                 new_line.append(plus_leaf)
3657
3658         ends_with_comma = (
3659             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3660         )
3661
3662         def max_last_string() -> int:
3663             """
3664             Returns:
3665                 The max allowed length of the string value used for the last
3666                 line we will construct.
3667             """
3668             result = self.line_length
3669             result -= line.depth * 4
3670             result -= 1 if ends_with_comma else 0
3671             result -= 2 if line_needs_plus() else 0
3672             return result
3673
3674         # --- Calculate Max Break Index (for string value)
3675         # We start with the line length limit
3676         max_break_idx = self.line_length
3677         # The last index of a string of length N is N-1.
3678         max_break_idx -= 1
3679         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3680         max_break_idx -= line.depth * 4
3681         if max_break_idx < 0:
3682             yield TErr(
3683                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3684                 f" {line.depth}"
3685             )
3686             return
3687
3688         # Check if StringMerger registered any custom splits.
3689         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3690         # We use them ONLY if none of them would produce lines that exceed the
3691         # line limit.
3692         use_custom_breakpoints = bool(
3693             custom_splits
3694             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3695         )
3696
3697         # Temporary storage for the remaining chunk of the string line that
3698         # can't fit onto the line currently being constructed.
3699         rest_value = LL[string_idx].value
3700
3701         def more_splits_should_be_made() -> bool:
3702             """
3703             Returns:
3704                 True iff `rest_value` (the remaining string value from the last
3705                 split), should be split again.
3706             """
3707             if use_custom_breakpoints:
3708                 return len(custom_splits) > 1
3709             else:
3710                 return len(rest_value) > max_last_string()
3711
3712         string_line_results: List[Ok[Line]] = []
3713         while more_splits_should_be_made():
3714             if use_custom_breakpoints:
3715                 # Custom User Split (manual)
3716                 csplit = custom_splits.pop(0)
3717                 break_idx = csplit.break_idx
3718             else:
3719                 # Algorithmic Split (automatic)
3720                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3721                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3722                 if maybe_break_idx is None:
3723                     # If we are unable to algorithmically determine a good split
3724                     # and this string has custom splits registered to it, we
3725                     # fall back to using them--which means we have to start
3726                     # over from the beginning.
3727                     if custom_splits:
3728                         rest_value = LL[string_idx].value
3729                         string_line_results = []
3730                         first_string_line = True
3731                         use_custom_breakpoints = True
3732                         continue
3733
3734                     # Otherwise, we stop splitting here.
3735                     break
3736
3737                 break_idx = maybe_break_idx
3738
3739             # --- Construct `next_value`
3740             next_value = rest_value[:break_idx] + QUOTE
3741             if (
3742                 # Are we allowed to try to drop a pointless 'f' prefix?
3743                 drop_pointless_f_prefix
3744                 # If we are, will we be successful?
3745                 and next_value != self.__normalize_f_string(next_value, prefix)
3746             ):
3747                 # If the current custom split did NOT originally use a prefix,
3748                 # then `csplit.break_idx` will be off by one after removing
3749                 # the 'f' prefix.
3750                 break_idx = (
3751                     break_idx + 1
3752                     if use_custom_breakpoints and not csplit.has_prefix
3753                     else break_idx
3754                 )
3755                 next_value = rest_value[:break_idx] + QUOTE
3756                 next_value = self.__normalize_f_string(next_value, prefix)
3757
3758             # --- Construct `next_leaf`
3759             next_leaf = Leaf(token.STRING, next_value)
3760             insert_str_child(next_leaf)
3761             self.__maybe_normalize_string_quotes(next_leaf)
3762
3763             # --- Construct `next_line`
3764             next_line = line.clone()
3765             maybe_append_plus(next_line)
3766             next_line.append(next_leaf)
3767             string_line_results.append(Ok(next_line))
3768
3769             rest_value = prefix + QUOTE + rest_value[break_idx:]
3770             first_string_line = False
3771
3772         yield from string_line_results
3773
3774         if drop_pointless_f_prefix:
3775             rest_value = self.__normalize_f_string(rest_value, prefix)
3776
3777         rest_leaf = Leaf(token.STRING, rest_value)
3778         insert_str_child(rest_leaf)
3779
3780         # NOTE: I could not find a test case that verifies that the following
3781         # line is actually necessary, but it seems to be. Otherwise we risk
3782         # not normalizing the last substring, right?
3783         self.__maybe_normalize_string_quotes(rest_leaf)
3784
3785         last_line = line.clone()
3786         maybe_append_plus(last_line)
3787
3788         # If there are any leaves to the right of the target string...
3789         if is_valid_index(string_idx + 1):
3790             # We use `temp_value` here to determine how long the last line
3791             # would be if we were to append all the leaves to the right of the
3792             # target string to the last string line.
3793             temp_value = rest_value
3794             for leaf in LL[string_idx + 1 :]:
3795                 temp_value += str(leaf)
3796                 if leaf.type == token.LPAR:
3797                     break
3798
3799             # Try to fit them all on the same line with the last substring...
3800             if (
3801                 len(temp_value) <= max_last_string()
3802                 or LL[string_idx + 1].type == token.COMMA
3803             ):
3804                 last_line.append(rest_leaf)
3805                 append_leaves(last_line, line, LL[string_idx + 1 :])
3806                 yield Ok(last_line)
3807             # Otherwise, place the last substring on one line and everything
3808             # else on a line below that...
3809             else:
3810                 last_line.append(rest_leaf)
3811                 yield Ok(last_line)
3812
3813                 non_string_line = line.clone()
3814                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3815                 yield Ok(non_string_line)
3816         # Else the target string was the last leaf...
3817         else:
3818             last_line.append(rest_leaf)
3819             last_line.comments = line.comments.copy()
3820             yield Ok(last_line)
3821
3822     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3823         """
3824         This method contains the algorithm that StringSplitter uses to
3825         determine which character to split each string at.
3826
3827         Args:
3828             @string: The substring that we are attempting to split.
3829             @max_break_idx: The ideal break index. We will return this value if it
3830             meets all the necessary conditions. In the likely event that it
3831             doesn't we will try to find the closest index BELOW @max_break_idx
3832             that does. If that fails, we will expand our search by also
3833             considering all valid indices ABOVE @max_break_idx.
3834
3835         Pre-Conditions:
3836             * assert_is_leaf_string(@string)
3837             * 0 <= @max_break_idx < len(@string)
3838
3839         Returns:
3840             break_idx, if an index is able to be found that meets all of the
3841             conditions listed in the 'Transformations' section of this classes'
3842             docstring.
3843                 OR
3844             None, otherwise.
3845         """
3846         is_valid_index = is_valid_index_factory(string)
3847
3848         assert is_valid_index(max_break_idx)
3849         assert_is_leaf_string(string)
3850
3851         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3852
3853         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3854             """
3855             Yields:
3856                 All ranges of @string which, if @string were to be split there,
3857                 would result in the splitting of an f-expression (which is NOT
3858                 allowed).
3859             """
3860             nonlocal _fexpr_slices
3861
3862             if _fexpr_slices is None:
3863                 _fexpr_slices = []
3864                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3865                     _fexpr_slices.append(match.span())
3866
3867             yield from _fexpr_slices
3868
3869         is_fstring = "f" in get_string_prefix(string)
3870
3871         def breaks_fstring_expression(i: Index) -> bool:
3872             """
3873             Returns:
3874                 True iff returning @i would result in the splitting of an
3875                 f-expression (which is NOT allowed).
3876             """
3877             if not is_fstring:
3878                 return False
3879
3880             for (start, end) in fexpr_slices():
3881                 if start <= i < end:
3882                     return True
3883
3884             return False
3885
3886         def passes_all_checks(i: Index) -> bool:
3887             """
3888             Returns:
3889                 True iff ALL of the conditions listed in the 'Transformations'
3890                 section of this classes' docstring would be be met by returning @i.
3891             """
3892             is_space = string[i] == " "
3893             is_big_enough = (
3894                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3895                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3896             )
3897             return is_space and is_big_enough and not breaks_fstring_expression(i)
3898
3899         # First, we check all indices BELOW @max_break_idx.
3900         break_idx = max_break_idx
3901         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3902             break_idx -= 1
3903
3904         if not passes_all_checks(break_idx):
3905             # If that fails, we check all indices ABOVE @max_break_idx.
3906             #
3907             # If we are able to find a valid index here, the next line is going
3908             # to be longer than the specified line length, but it's probably
3909             # better than doing nothing at all.
3910             break_idx = max_break_idx + 1
3911             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3912                 break_idx += 1
3913
3914             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3915                 return None
3916
3917         return break_idx
3918
3919     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3920         if self.normalize_strings:
3921             normalize_string_quotes(leaf)
3922
3923     def __normalize_f_string(self, string: str, prefix: str) -> str:
3924         """
3925         Pre-Conditions:
3926             * assert_is_leaf_string(@string)
3927
3928         Returns:
3929             * If @string is an f-string that contains no f-expressions, we
3930             return a string identical to @string except that the 'f' prefix
3931             has been stripped and all double braces (i.e. '{{' or '}}') have
3932             been normalized (i.e. turned into '{' or '}').
3933                 OR
3934             * Otherwise, we return @string.
3935         """
3936         assert_is_leaf_string(string)
3937
3938         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3939             new_prefix = prefix.replace("f", "")
3940
3941             temp = string[len(prefix) :]
3942             temp = re.sub(r"\{\{", "{", temp)
3943             temp = re.sub(r"\}\}", "}", temp)
3944             new_string = temp
3945
3946             return f"{new_prefix}{new_string}"
3947         else:
3948             return string
3949
3950
3951 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
3952     """
3953     StringTransformer that splits non-"atom" strings (i.e. strings that do not
3954     exist on lines by themselves).
3955
3956     Requirements:
3957         All of the requirements listed in BaseStringSplitter's docstring in
3958         addition to the requirements listed below:
3959
3960         * The line is a return/yield statement, which returns/yields a string.
3961             OR
3962         * The line is part of a ternary expression (e.g. `x = y if cond else
3963         z`) such that the line starts with `else <string>`, where <string> is
3964         some string.
3965             OR
3966         * The line is an assert statement, which ends with a string.
3967             OR
3968         * The line is an assignment statement (e.g. `x = <string>` or `x +=
3969         <string>`) such that the variable is being assigned the value of some
3970         string.
3971             OR
3972         * The line is a dictionary key assignment where some valid key is being
3973         assigned the value of some string.
3974
3975     Transformations:
3976         The chosen string is wrapped in parentheses and then split at the LPAR.
3977
3978         We then have one line which ends with an LPAR and another line that
3979         starts with the chosen string. The latter line is then split again at
3980         the RPAR. This results in the RPAR (and possibly a trailing comma)
3981         being placed on its own line.
3982
3983         NOTE: If any leaves exist to the right of the chosen string (except
3984         for a trailing comma, which would be placed after the RPAR), those
3985         leaves are placed inside the parentheses.  In effect, the chosen
3986         string is not necessarily being "wrapped" by parentheses. We can,
3987         however, count on the LPAR being placed directly before the chosen
3988         string.
3989
3990         In other words, StringParenWrapper creates "atom" strings. These
3991         can then be split again by StringSplitter, if necessary.
3992
3993     Collaborations:
3994         In the event that a string line split by StringParenWrapper is
3995         changed such that it no longer needs to be given its own line,
3996         StringParenWrapper relies on StringParenStripper to clean up the
3997         parentheses it created.
3998     """
3999
4000     def do_splitter_match(self, line: Line) -> TMatchResult:
4001         LL = line.leaves
4002
4003         string_idx = None
4004         string_idx = string_idx or self._return_match(LL)
4005         string_idx = string_idx or self._else_match(LL)
4006         string_idx = string_idx or self._assert_match(LL)
4007         string_idx = string_idx or self._assign_match(LL)
4008         string_idx = string_idx or self._dict_match(LL)
4009
4010         if string_idx is not None:
4011             string_value = line.leaves[string_idx].value
4012             # If the string has no spaces...
4013             if " " not in string_value:
4014                 # And will still violate the line length limit when split...
4015                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4016                 if len(string_value) > max_string_length:
4017                     # And has no associated custom splits...
4018                     if not self.has_custom_splits(string_value):
4019                         # Then we should NOT put this string on its own line.
4020                         return TErr(
4021                             "We do not wrap long strings in parentheses when the"
4022                             " resultant line would still be over the specified line"
4023                             " length and can't be split further by StringSplitter."
4024                         )
4025             return Ok(string_idx)
4026
4027         return TErr("This line does not contain any non-atomic strings.")
4028
4029     @staticmethod
4030     def _return_match(LL: List[Leaf]) -> Optional[int]:
4031         """
4032         Returns:
4033             string_idx such that @LL[string_idx] is equal to our target (i.e.
4034             matched) string, if this line matches the return/yield statement
4035             requirements listed in the 'Requirements' section of this classes'
4036             docstring.
4037                 OR
4038             None, otherwise.
4039         """
4040         # If this line is apart of a return/yield statement and the first leaf
4041         # contains either the "return" or "yield" keywords...
4042         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4043             0
4044         ].value in ["return", "yield"]:
4045             is_valid_index = is_valid_index_factory(LL)
4046
4047             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4048             # The next visible leaf MUST contain a string...
4049             if is_valid_index(idx) and LL[idx].type == token.STRING:
4050                 return idx
4051
4052         return None
4053
4054     @staticmethod
4055     def _else_match(LL: List[Leaf]) -> Optional[int]:
4056         """
4057         Returns:
4058             string_idx such that @LL[string_idx] is equal to our target (i.e.
4059             matched) string, if this line matches the ternary expression
4060             requirements listed in the 'Requirements' section of this classes'
4061             docstring.
4062                 OR
4063             None, otherwise.
4064         """
4065         # If this line is apart of a ternary expression and the first leaf
4066         # contains the "else" keyword...
4067         if (
4068             parent_type(LL[0]) == syms.test
4069             and LL[0].type == token.NAME
4070             and LL[0].value == "else"
4071         ):
4072             is_valid_index = is_valid_index_factory(LL)
4073
4074             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4075             # The next visible leaf MUST contain a string...
4076             if is_valid_index(idx) and LL[idx].type == token.STRING:
4077                 return idx
4078
4079         return None
4080
4081     @staticmethod
4082     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4083         """
4084         Returns:
4085             string_idx such that @LL[string_idx] is equal to our target (i.e.
4086             matched) string, if this line matches the assert statement
4087             requirements listed in the 'Requirements' section of this classes'
4088             docstring.
4089                 OR
4090             None, otherwise.
4091         """
4092         # If this line is apart of an assert statement and the first leaf
4093         # contains the "assert" keyword...
4094         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4095             is_valid_index = is_valid_index_factory(LL)
4096
4097             for (i, leaf) in enumerate(LL):
4098                 # We MUST find a comma...
4099                 if leaf.type == token.COMMA:
4100                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4101
4102                     # That comma MUST be followed by a string...
4103                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4104                         string_idx = idx
4105
4106                         # Skip the string trailer, if one exists.
4107                         string_parser = StringParser()
4108                         idx = string_parser.parse(LL, string_idx)
4109
4110                         # But no more leaves are allowed...
4111                         if not is_valid_index(idx):
4112                             return string_idx
4113
4114         return None
4115
4116     @staticmethod
4117     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4118         """
4119         Returns:
4120             string_idx such that @LL[string_idx] is equal to our target (i.e.
4121             matched) string, if this line matches the assignment statement
4122             requirements listed in the 'Requirements' section of this classes'
4123             docstring.
4124                 OR
4125             None, otherwise.
4126         """
4127         # If this line is apart of an expression statement or is a function
4128         # argument AND the first leaf contains a variable name...
4129         if (
4130             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4131             and LL[0].type == token.NAME
4132         ):
4133             is_valid_index = is_valid_index_factory(LL)
4134
4135             for (i, leaf) in enumerate(LL):
4136                 # We MUST find either an '=' or '+=' symbol...
4137                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4138                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4139
4140                     # That symbol MUST be followed by a string...
4141                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4142                         string_idx = idx
4143
4144                         # Skip the string trailer, if one exists.
4145                         string_parser = StringParser()
4146                         idx = string_parser.parse(LL, string_idx)
4147
4148                         # The next leaf MAY be a comma iff this line is apart
4149                         # of a function argument...
4150                         if (
4151                             parent_type(LL[0]) == syms.argument
4152                             and is_valid_index(idx)
4153                             and LL[idx].type == token.COMMA
4154                         ):
4155                             idx += 1
4156
4157                         # But no more leaves are allowed...
4158                         if not is_valid_index(idx):
4159                             return string_idx
4160
4161         return None
4162
4163     @staticmethod
4164     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4165         """
4166         Returns:
4167             string_idx such that @LL[string_idx] is equal to our target (i.e.
4168             matched) string, if this line matches the dictionary key assignment
4169             statement requirements listed in the 'Requirements' section of this
4170             classes' docstring.
4171                 OR
4172             None, otherwise.
4173         """
4174         # If this line is apart of a dictionary key assignment...
4175         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4176             is_valid_index = is_valid_index_factory(LL)
4177
4178             for (i, leaf) in enumerate(LL):
4179                 # We MUST find a colon...
4180                 if leaf.type == token.COLON:
4181                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4182
4183                     # That colon MUST be followed by a string...
4184                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4185                         string_idx = idx
4186
4187                         # Skip the string trailer, if one exists.
4188                         string_parser = StringParser()
4189                         idx = string_parser.parse(LL, string_idx)
4190
4191                         # That string MAY be followed by a comma...
4192                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4193                             idx += 1
4194
4195                         # But no more leaves are allowed...
4196                         if not is_valid_index(idx):
4197                             return string_idx
4198
4199         return None
4200
4201     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4202         LL = line.leaves
4203
4204         is_valid_index = is_valid_index_factory(LL)
4205         insert_str_child = insert_str_child_factory(LL[string_idx])
4206
4207         comma_idx = len(LL) - 1
4208         ends_with_comma = False
4209         if LL[comma_idx].type == token.COMMA:
4210             ends_with_comma = True
4211
4212         leaves_to_steal_comments_from = [LL[string_idx]]
4213         if ends_with_comma:
4214             leaves_to_steal_comments_from.append(LL[comma_idx])
4215
4216         # --- First Line
4217         first_line = line.clone()
4218         left_leaves = LL[:string_idx]
4219
4220         # We have to remember to account for (possibly invisible) LPAR and RPAR
4221         # leaves that already wrapped the target string. If these leaves do
4222         # exist, we will replace them with our own LPAR and RPAR leaves.
4223         old_parens_exist = False
4224         if left_leaves and left_leaves[-1].type == token.LPAR:
4225             old_parens_exist = True
4226             leaves_to_steal_comments_from.append(left_leaves[-1])
4227             left_leaves.pop()
4228
4229         append_leaves(first_line, line, left_leaves)
4230
4231         lpar_leaf = Leaf(token.LPAR, "(")
4232         if old_parens_exist:
4233             replace_child(LL[string_idx - 1], lpar_leaf)
4234         else:
4235             insert_str_child(lpar_leaf)
4236         first_line.append(lpar_leaf)
4237
4238         # We throw inline comments that were originally to the right of the
4239         # target string to the top line. They will now be shown to the right of
4240         # the LPAR.
4241         for leaf in leaves_to_steal_comments_from:
4242             for comment_leaf in line.comments_after(leaf):
4243                 first_line.append(comment_leaf, preformatted=True)
4244
4245         yield Ok(first_line)
4246
4247         # --- Middle (String) Line
4248         # We only need to yield one (possibly too long) string line, since the
4249         # `StringSplitter` will break it down further if necessary.
4250         string_value = LL[string_idx].value
4251         string_line = Line(
4252             depth=line.depth + 1,
4253             inside_brackets=True,
4254             should_explode=line.should_explode,
4255         )
4256         string_leaf = Leaf(token.STRING, string_value)
4257         insert_str_child(string_leaf)
4258         string_line.append(string_leaf)
4259
4260         old_rpar_leaf = None
4261         if is_valid_index(string_idx + 1):
4262             right_leaves = LL[string_idx + 1 :]
4263             if ends_with_comma:
4264                 right_leaves.pop()
4265
4266             if old_parens_exist:
4267                 assert (
4268                     right_leaves and right_leaves[-1].type == token.RPAR
4269                 ), "Apparently, old parentheses do NOT exist?!"
4270                 old_rpar_leaf = right_leaves.pop()
4271
4272             append_leaves(string_line, line, right_leaves)
4273
4274         yield Ok(string_line)
4275
4276         # --- Last Line
4277         last_line = line.clone()
4278         last_line.bracket_tracker = first_line.bracket_tracker
4279
4280         new_rpar_leaf = Leaf(token.RPAR, ")")
4281         if old_rpar_leaf is not None:
4282             replace_child(old_rpar_leaf, new_rpar_leaf)
4283         else:
4284             insert_str_child(new_rpar_leaf)
4285         last_line.append(new_rpar_leaf)
4286
4287         # If the target string ended with a comma, we place this comma to the
4288         # right of the RPAR on the last line.
4289         if ends_with_comma:
4290             comma_leaf = Leaf(token.COMMA, ",")
4291             replace_child(LL[comma_idx], comma_leaf)
4292             last_line.append(comma_leaf)
4293
4294         yield Ok(last_line)
4295
4296
4297 class StringParser:
4298     """
4299     A state machine that aids in parsing a string's "trailer", which can be
4300     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4301     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4302     varY)`).
4303
4304     NOTE: A new StringParser object MUST be instantiated for each string
4305     trailer we need to parse.
4306
4307     Examples:
4308         We shall assume that `line` equals the `Line` object that corresponds
4309         to the following line of python code:
4310         ```
4311         x = "Some {}.".format("String") + some_other_string
4312         ```
4313
4314         Furthermore, we will assume that `string_idx` is some index such that:
4315         ```
4316         assert line.leaves[string_idx].value == "Some {}."
4317         ```
4318
4319         The following code snippet then holds:
4320         ```
4321         string_parser = StringParser()
4322         idx = string_parser.parse(line.leaves, string_idx)
4323         assert line.leaves[idx].type == token.PLUS
4324         ```
4325     """
4326
4327     DEFAULT_TOKEN = -1
4328
4329     # String Parser States
4330     START = 1
4331     DOT = 2
4332     NAME = 3
4333     PERCENT = 4
4334     SINGLE_FMT_ARG = 5
4335     LPAR = 6
4336     RPAR = 7
4337     DONE = 8
4338
4339     # Lookup Table for Next State
4340     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4341         # A string trailer may start with '.' OR '%'.
4342         (START, token.DOT): DOT,
4343         (START, token.PERCENT): PERCENT,
4344         (START, DEFAULT_TOKEN): DONE,
4345         # A '.' MUST be followed by an attribute or method name.
4346         (DOT, token.NAME): NAME,
4347         # A method name MUST be followed by an '(', whereas an attribute name
4348         # is the last symbol in the string trailer.
4349         (NAME, token.LPAR): LPAR,
4350         (NAME, DEFAULT_TOKEN): DONE,
4351         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4352         # string or variable name).
4353         (PERCENT, token.LPAR): LPAR,
4354         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4355         # If a '%' symbol is followed by a single argument, that argument is
4356         # the last leaf in the string trailer.
4357         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4358         # If present, a ')' symbol is the last symbol in a string trailer.
4359         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4360         # since they are treated as a special case by the parsing logic in this
4361         # classes' implementation.)
4362         (RPAR, DEFAULT_TOKEN): DONE,
4363     }
4364
4365     def __init__(self) -> None:
4366         self._state = self.START
4367         self._unmatched_lpars = 0
4368
4369     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4370         """
4371         Pre-conditions:
4372             * @leaves[@string_idx].type == token.STRING
4373
4374         Returns:
4375             The index directly after the last leaf which is apart of the string
4376             trailer, if a "trailer" exists.
4377                 OR
4378             @string_idx + 1, if no string "trailer" exists.
4379         """
4380         assert leaves[string_idx].type == token.STRING
4381
4382         idx = string_idx + 1
4383         while idx < len(leaves) and self._next_state(leaves[idx]):
4384             idx += 1
4385         return idx
4386
4387     def _next_state(self, leaf: Leaf) -> bool:
4388         """
4389         Pre-conditions:
4390             * On the first call to this function, @leaf MUST be the leaf that
4391             was directly after the string leaf in question (e.g. if our target
4392             string is `line.leaves[i]` then the first call to this method must
4393             be `line.leaves[i + 1]`).
4394             * On the next call to this function, the leaf parameter passed in
4395             MUST be the leaf directly following @leaf.
4396
4397         Returns:
4398             True iff @leaf is apart of the string's trailer.
4399         """
4400         # We ignore empty LPAR or RPAR leaves.
4401         if is_empty_par(leaf):
4402             return True
4403
4404         next_token = leaf.type
4405         if next_token == token.LPAR:
4406             self._unmatched_lpars += 1
4407
4408         current_state = self._state
4409
4410         # The LPAR parser state is a special case. We will return True until we
4411         # find the matching RPAR token.
4412         if current_state == self.LPAR:
4413             if next_token == token.RPAR:
4414                 self._unmatched_lpars -= 1
4415                 if self._unmatched_lpars == 0:
4416                     self._state = self.RPAR
4417         # Otherwise, we use a lookup table to determine the next state.
4418         else:
4419             # If the lookup table matches the current state to the next
4420             # token, we use the lookup table.
4421             if (current_state, next_token) in self._goto:
4422                 self._state = self._goto[current_state, next_token]
4423             else:
4424                 # Otherwise, we check if a the current state was assigned a
4425                 # default.
4426                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4427                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4428                 # If no default has been assigned, then this parser has a logic
4429                 # error.
4430                 else:
4431                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4432
4433             if self._state == self.DONE:
4434                 return False
4435
4436         return True
4437
4438
4439 def TErr(err_msg: str) -> Err[CannotTransform]:
4440     """(T)ransform Err
4441
4442     Convenience function used when working with the TResult type.
4443     """
4444     cant_transform = CannotTransform(err_msg)
4445     return Err(cant_transform)
4446
4447
4448 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4449     """
4450     Returns:
4451         True iff one of the comments in @comment_list is a pragma used by one
4452         of the more common static analysis tools for python (e.g. mypy, flake8,
4453         pylint).
4454     """
4455     for comment in comment_list:
4456         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4457             return True
4458
4459     return False
4460
4461
4462 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4463     """
4464     Factory for a convenience function that is used to orphan @string_leaf
4465     and then insert multiple new leaves into the same part of the node
4466     structure that @string_leaf had originally occupied.
4467
4468     Examples:
4469         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4470         string_leaf.parent`. Assume the node `N` has the following
4471         original structure:
4472
4473         Node(
4474             expr_stmt, [
4475                 Leaf(NAME, 'x'),
4476                 Leaf(EQUAL, '='),
4477                 Leaf(STRING, '"foo"'),
4478             ]
4479         )
4480
4481         We then run the code snippet shown below.
4482         ```
4483         insert_str_child = insert_str_child_factory(string_leaf)
4484
4485         lpar = Leaf(token.LPAR, '(')
4486         insert_str_child(lpar)
4487
4488         bar = Leaf(token.STRING, '"bar"')
4489         insert_str_child(bar)
4490
4491         rpar = Leaf(token.RPAR, ')')
4492         insert_str_child(rpar)
4493         ```
4494
4495         After which point, it follows that `string_leaf.parent is None` and
4496         the node `N` now has the following structure:
4497
4498         Node(
4499             expr_stmt, [
4500                 Leaf(NAME, 'x'),
4501                 Leaf(EQUAL, '='),
4502                 Leaf(LPAR, '('),
4503                 Leaf(STRING, '"bar"'),
4504                 Leaf(RPAR, ')'),
4505             ]
4506         )
4507     """
4508     string_parent = string_leaf.parent
4509     string_child_idx = string_leaf.remove()
4510
4511     def insert_str_child(child: LN) -> None:
4512         nonlocal string_child_idx
4513
4514         assert string_parent is not None
4515         assert string_child_idx is not None
4516
4517         string_parent.insert_child(string_child_idx, child)
4518         string_child_idx += 1
4519
4520     return insert_str_child
4521
4522
4523 def has_triple_quotes(string: str) -> bool:
4524     """
4525     Returns:
4526         True iff @string starts with three quotation characters.
4527     """
4528     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4529     return raw_string[:3] in {'"""', "'''"}
4530
4531
4532 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4533     """
4534     Returns:
4535         @node.parent.type, if @node is not None and has a parent.
4536             OR
4537         None, otherwise.
4538     """
4539     if node is None or node.parent is None:
4540         return None
4541
4542     return node.parent.type
4543
4544
4545 def is_empty_par(leaf: Leaf) -> bool:
4546     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4547
4548
4549 def is_empty_lpar(leaf: Leaf) -> bool:
4550     return leaf.type == token.LPAR and leaf.value == ""
4551
4552
4553 def is_empty_rpar(leaf: Leaf) -> bool:
4554     return leaf.type == token.RPAR and leaf.value == ""
4555
4556
4557 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4558     """
4559     Examples:
4560         ```
4561         my_list = [1, 2, 3]
4562
4563         is_valid_index = is_valid_index_factory(my_list)
4564
4565         assert is_valid_index(0)
4566         assert is_valid_index(2)
4567
4568         assert not is_valid_index(3)
4569         assert not is_valid_index(-1)
4570         ```
4571     """
4572
4573     def is_valid_index(idx: int) -> bool:
4574         """
4575         Returns:
4576             True iff @idx is positive AND seq[@idx] does NOT raise an
4577             IndexError.
4578         """
4579         return 0 <= idx < len(seq)
4580
4581     return is_valid_index
4582
4583
4584 def line_to_string(line: Line) -> str:
4585     """Returns the string representation of @line.
4586
4587     WARNING: This is known to be computationally expensive.
4588     """
4589     return str(line).strip("\n")
4590
4591
4592 def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
4593     """
4594     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4595     underlying Node structure where appropriate.
4596
4597     All of the leaves in @leaves are duplicated. The duplicates are then
4598     appended to @new_line and used to replace their originals in the underlying
4599     Node structure. Any comments attached to the old leaves are reattached to
4600     the new leaves.
4601
4602     Pre-conditions:
4603         set(@leaves) is a subset of set(@old_line.leaves).
4604     """
4605     for old_leaf in leaves:
4606         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4607         replace_child(old_leaf, new_leaf)
4608         new_line.append(new_leaf)
4609
4610         for comment_leaf in old_line.comments_after(old_leaf):
4611             new_line.append(comment_leaf, preformatted=True)
4612
4613
4614 def replace_child(old_child: LN, new_child: LN) -> None:
4615     """
4616     Side Effects:
4617         * If @old_child.parent is set, replace @old_child with @new_child in
4618         @old_child's underlying Node structure.
4619             OR
4620         * Otherwise, this function does nothing.
4621     """
4622     parent = old_child.parent
4623     if not parent:
4624         return
4625
4626     child_idx = old_child.remove()
4627     if child_idx is not None:
4628         parent.insert_child(child_idx, new_child)
4629
4630
4631 def get_string_prefix(string: str) -> str:
4632     """
4633     Pre-conditions:
4634         * assert_is_leaf_string(@string)
4635
4636     Returns:
4637         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4638     """
4639     assert_is_leaf_string(string)
4640
4641     prefix = ""
4642     prefix_idx = 0
4643     while string[prefix_idx] in STRING_PREFIX_CHARS:
4644         prefix += string[prefix_idx].lower()
4645         prefix_idx += 1
4646
4647     return prefix
4648
4649
4650 def assert_is_leaf_string(string: str) -> None:
4651     """
4652     Checks the pre-condition that @string has the format that you would expect
4653     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4654     token.STRING`. A more precise description of the pre-conditions that are
4655     checked are listed below.
4656
4657     Pre-conditions:
4658         * @string starts with either ', ", <prefix>', or <prefix>" where
4659         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4660         * @string ends with a quote character (' or ").
4661
4662     Raises:
4663         AssertionError(...) if the pre-conditions listed above are not
4664         satisfied.
4665     """
4666     dquote_idx = string.find('"')
4667     squote_idx = string.find("'")
4668     if -1 in [dquote_idx, squote_idx]:
4669         quote_idx = max(dquote_idx, squote_idx)
4670     else:
4671         quote_idx = min(squote_idx, dquote_idx)
4672
4673     assert (
4674         0 <= quote_idx < len(string) - 1
4675     ), f"{string!r} is missing a starting quote character (' or \")."
4676     assert string[-1] in (
4677         "'",
4678         '"',
4679     ), f"{string!r} is missing an ending quote character (' or \")."
4680     assert set(string[:quote_idx]).issubset(
4681         set(STRING_PREFIX_CHARS)
4682     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4683
4684
4685 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4686     """Split line into many lines, starting with the first matching bracket pair.
4687
4688     Note: this usually looks weird, only use this for function definitions.
4689     Prefer RHS otherwise.  This is why this function is not symmetrical with
4690     :func:`right_hand_split` which also handles optional parentheses.
4691     """
4692     tail_leaves: List[Leaf] = []
4693     body_leaves: List[Leaf] = []
4694     head_leaves: List[Leaf] = []
4695     current_leaves = head_leaves
4696     matching_bracket: Optional[Leaf] = None
4697     for leaf in line.leaves:
4698         if (
4699             current_leaves is body_leaves
4700             and leaf.type in CLOSING_BRACKETS
4701             and leaf.opening_bracket is matching_bracket
4702         ):
4703             current_leaves = tail_leaves if body_leaves else head_leaves
4704         current_leaves.append(leaf)
4705         if current_leaves is head_leaves:
4706             if leaf.type in OPENING_BRACKETS:
4707                 matching_bracket = leaf
4708                 current_leaves = body_leaves
4709     if not matching_bracket:
4710         raise CannotSplit("No brackets found")
4711
4712     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4713     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4714     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4715     bracket_split_succeeded_or_raise(head, body, tail)
4716     for result in (head, body, tail):
4717         if result:
4718             yield result
4719
4720
4721 def right_hand_split(
4722     line: Line,
4723     line_length: int,
4724     features: Collection[Feature] = (),
4725     omit: Collection[LeafID] = (),
4726 ) -> Iterator[Line]:
4727     """Split line into many lines, starting with the last matching bracket pair.
4728
4729     If the split was by optional parentheses, attempt splitting without them, too.
4730     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4731     this split.
4732
4733     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4734     """
4735     tail_leaves: List[Leaf] = []
4736     body_leaves: List[Leaf] = []
4737     head_leaves: List[Leaf] = []
4738     current_leaves = tail_leaves
4739     opening_bracket: Optional[Leaf] = None
4740     closing_bracket: Optional[Leaf] = None
4741     for leaf in reversed(line.leaves):
4742         if current_leaves is body_leaves:
4743             if leaf is opening_bracket:
4744                 current_leaves = head_leaves if body_leaves else tail_leaves
4745         current_leaves.append(leaf)
4746         if current_leaves is tail_leaves:
4747             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4748                 opening_bracket = leaf.opening_bracket
4749                 closing_bracket = leaf
4750                 current_leaves = body_leaves
4751     if not (opening_bracket and closing_bracket and head_leaves):
4752         # If there is no opening or closing_bracket that means the split failed and
4753         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4754         # the matching `opening_bracket` wasn't available on `line` anymore.
4755         raise CannotSplit("No brackets found")
4756
4757     tail_leaves.reverse()
4758     body_leaves.reverse()
4759     head_leaves.reverse()
4760     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4761     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4762     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4763     bracket_split_succeeded_or_raise(head, body, tail)
4764     if (
4765         # the body shouldn't be exploded
4766         not body.should_explode
4767         # the opening bracket is an optional paren
4768         and opening_bracket.type == token.LPAR
4769         and not opening_bracket.value
4770         # the closing bracket is an optional paren
4771         and closing_bracket.type == token.RPAR
4772         and not closing_bracket.value
4773         # it's not an import (optional parens are the only thing we can split on
4774         # in this case; attempting a split without them is a waste of time)
4775         and not line.is_import
4776         # there are no standalone comments in the body
4777         and not body.contains_standalone_comments(0)
4778         # and we can actually remove the parens
4779         and can_omit_invisible_parens(body, line_length)
4780     ):
4781         omit = {id(closing_bracket), *omit}
4782         try:
4783             yield from right_hand_split(line, line_length, features=features, omit=omit)
4784             return
4785
4786         except CannotSplit:
4787             if not (
4788                 can_be_split(body)
4789                 or is_line_short_enough(body, line_length=line_length)
4790             ):
4791                 raise CannotSplit(
4792                     "Splitting failed, body is still too long and can't be split."
4793                 )
4794
4795             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4796                 raise CannotSplit(
4797                     "The current optional pair of parentheses is bound to fail to"
4798                     " satisfy the splitting algorithm because the head or the tail"
4799                     " contains multiline strings which by definition never fit one"
4800                     " line."
4801                 )
4802
4803     ensure_visible(opening_bracket)
4804     ensure_visible(closing_bracket)
4805     for result in (head, body, tail):
4806         if result:
4807             yield result
4808
4809
4810 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4811     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4812
4813     Do nothing otherwise.
4814
4815     A left- or right-hand split is based on a pair of brackets. Content before
4816     (and including) the opening bracket is left on one line, content inside the
4817     brackets is put on a separate line, and finally content starting with and
4818     following the closing bracket is put on a separate line.
4819
4820     Those are called `head`, `body`, and `tail`, respectively. If the split
4821     produced the same line (all content in `head`) or ended up with an empty `body`
4822     and the `tail` is just the closing bracket, then it's considered failed.
4823     """
4824     tail_len = len(str(tail).strip())
4825     if not body:
4826         if tail_len == 0:
4827             raise CannotSplit("Splitting brackets produced the same line")
4828
4829         elif tail_len < 3:
4830             raise CannotSplit(
4831                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4832                 " not worth it"
4833             )
4834
4835
4836 def bracket_split_build_line(
4837     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4838 ) -> Line:
4839     """Return a new line with given `leaves` and respective comments from `original`.
4840
4841     If `is_body` is True, the result line is one-indented inside brackets and as such
4842     has its first leaf's prefix normalized and a trailing comma added when expected.
4843     """
4844     result = Line(depth=original.depth)
4845     if is_body:
4846         result.inside_brackets = True
4847         result.depth += 1
4848         if leaves:
4849             # Since body is a new indent level, remove spurious leading whitespace.
4850             normalize_prefix(leaves[0], inside_brackets=True)
4851             # Ensure a trailing comma for imports and standalone function arguments, but
4852             # be careful not to add one after any comments or within type annotations.
4853             no_commas = (
4854                 original.is_def
4855                 and opening_bracket.value == "("
4856                 and not any(leaf.type == token.COMMA for leaf in leaves)
4857             )
4858
4859             if original.is_import or no_commas:
4860                 for i in range(len(leaves) - 1, -1, -1):
4861                     if leaves[i].type == STANDALONE_COMMENT:
4862                         continue
4863
4864                     if leaves[i].type != token.COMMA:
4865                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
4866                     break
4867
4868     # Populate the line
4869     for leaf in leaves:
4870         result.append(leaf, preformatted=True)
4871         for comment_after in original.comments_after(leaf):
4872             result.append(comment_after, preformatted=True)
4873     if is_body:
4874         result.should_explode = should_explode(result, opening_bracket)
4875     return result
4876
4877
4878 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4879     """Normalize prefix of the first leaf in every line returned by `split_func`.
4880
4881     This is a decorator over relevant split functions.
4882     """
4883
4884     @wraps(split_func)
4885     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4886         for line in split_func(line, features):
4887             normalize_prefix(line.leaves[0], inside_brackets=True)
4888             yield line
4889
4890     return split_wrapper
4891
4892
4893 @dont_increase_indentation
4894 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4895     """Split according to delimiters of the highest priority.
4896
4897     If the appropriate Features are given, the split will add trailing commas
4898     also in function signatures and calls that contain `*` and `**`.
4899     """
4900     try:
4901         last_leaf = line.leaves[-1]
4902     except IndexError:
4903         raise CannotSplit("Line empty")
4904
4905     bt = line.bracket_tracker
4906     try:
4907         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4908     except ValueError:
4909         raise CannotSplit("No delimiters found")
4910
4911     if delimiter_priority == DOT_PRIORITY:
4912         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4913             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4914
4915     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4916     lowest_depth = sys.maxsize
4917     trailing_comma_safe = True
4918
4919     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4920         """Append `leaf` to current line or to new line if appending impossible."""
4921         nonlocal current_line
4922         try:
4923             current_line.append_safe(leaf, preformatted=True)
4924         except ValueError:
4925             yield current_line
4926
4927             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4928             current_line.append(leaf)
4929
4930     for leaf in line.leaves:
4931         yield from append_to_line(leaf)
4932
4933         for comment_after in line.comments_after(leaf):
4934             yield from append_to_line(comment_after)
4935
4936         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4937         if leaf.bracket_depth == lowest_depth:
4938             if is_vararg(leaf, within={syms.typedargslist}):
4939                 trailing_comma_safe = (
4940                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4941                 )
4942             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4943                 trailing_comma_safe = (
4944                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4945                 )
4946
4947         leaf_priority = bt.delimiters.get(id(leaf))
4948         if leaf_priority == delimiter_priority:
4949             yield current_line
4950
4951             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4952     if current_line:
4953         if (
4954             trailing_comma_safe
4955             and delimiter_priority == COMMA_PRIORITY
4956             and current_line.leaves[-1].type != token.COMMA
4957             and current_line.leaves[-1].type != STANDALONE_COMMENT
4958         ):
4959             current_line.append(Leaf(token.COMMA, ","))
4960         yield current_line
4961
4962
4963 @dont_increase_indentation
4964 def standalone_comment_split(
4965     line: Line, features: Collection[Feature] = ()
4966 ) -> Iterator[Line]:
4967     """Split standalone comments from the rest of the line."""
4968     if not line.contains_standalone_comments(0):
4969         raise CannotSplit("Line does not have any standalone comments")
4970
4971     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4972
4973     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4974         """Append `leaf` to current line or to new line if appending impossible."""
4975         nonlocal current_line
4976         try:
4977             current_line.append_safe(leaf, preformatted=True)
4978         except ValueError:
4979             yield current_line
4980
4981             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4982             current_line.append(leaf)
4983
4984     for leaf in line.leaves:
4985         yield from append_to_line(leaf)
4986
4987         for comment_after in line.comments_after(leaf):
4988             yield from append_to_line(comment_after)
4989
4990     if current_line:
4991         yield current_line
4992
4993
4994 def is_import(leaf: Leaf) -> bool:
4995     """Return True if the given leaf starts an import statement."""
4996     p = leaf.parent
4997     t = leaf.type
4998     v = leaf.value
4999     return bool(
5000         t == token.NAME
5001         and (
5002             (v == "import" and p and p.type == syms.import_name)
5003             or (v == "from" and p and p.type == syms.import_from)
5004         )
5005     )
5006
5007
5008 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5009     """Return True if the given leaf is a special comment.
5010     Only returns true for type comments for now."""
5011     t = leaf.type
5012     v = leaf.value
5013     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5014
5015
5016 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5017     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5018     else.
5019
5020     Note: don't use backslashes for formatting or you'll lose your voting rights.
5021     """
5022     if not inside_brackets:
5023         spl = leaf.prefix.split("#")
5024         if "\\" not in spl[0]:
5025             nl_count = spl[-1].count("\n")
5026             if len(spl) > 1:
5027                 nl_count -= 1
5028             leaf.prefix = "\n" * nl_count
5029             return
5030
5031     leaf.prefix = ""
5032
5033
5034 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5035     """Make all string prefixes lowercase.
5036
5037     If remove_u_prefix is given, also removes any u prefix from the string.
5038
5039     Note: Mutates its argument.
5040     """
5041     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5042     assert match is not None, f"failed to match string {leaf.value!r}"
5043     orig_prefix = match.group(1)
5044     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5045     if remove_u_prefix:
5046         new_prefix = new_prefix.replace("u", "")
5047     leaf.value = f"{new_prefix}{match.group(2)}"
5048
5049
5050 def normalize_string_quotes(leaf: Leaf) -> None:
5051     """Prefer double quotes but only if it doesn't cause more escaping.
5052
5053     Adds or removes backslashes as appropriate. Doesn't parse and fix
5054     strings nested in f-strings (yet).
5055
5056     Note: Mutates its argument.
5057     """
5058     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5059     if value[:3] == '"""':
5060         return
5061
5062     elif value[:3] == "'''":
5063         orig_quote = "'''"
5064         new_quote = '"""'
5065     elif value[0] == '"':
5066         orig_quote = '"'
5067         new_quote = "'"
5068     else:
5069         orig_quote = "'"
5070         new_quote = '"'
5071     first_quote_pos = leaf.value.find(orig_quote)
5072     if first_quote_pos == -1:
5073         return  # There's an internal error
5074
5075     prefix = leaf.value[:first_quote_pos]
5076     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5077     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5078     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5079     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5080     if "r" in prefix.casefold():
5081         if unescaped_new_quote.search(body):
5082             # There's at least one unescaped new_quote in this raw string
5083             # so converting is impossible
5084             return
5085
5086         # Do not introduce or remove backslashes in raw strings
5087         new_body = body
5088     else:
5089         # remove unnecessary escapes
5090         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5091         if body != new_body:
5092             # Consider the string without unnecessary escapes as the original
5093             body = new_body
5094             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5095         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5096         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5097     if "f" in prefix.casefold():
5098         matches = re.findall(
5099             r"""
5100             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5101                 ([^{].*?)  # contents of the brackets except if begins with {{
5102             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5103             """,
5104             new_body,
5105             re.VERBOSE,
5106         )
5107         for m in matches:
5108             if "\\" in str(m):
5109                 # Do not introduce backslashes in interpolated expressions
5110                 return
5111
5112     if new_quote == '"""' and new_body[-1:] == '"':
5113         # edge case:
5114         new_body = new_body[:-1] + '\\"'
5115     orig_escape_count = body.count("\\")
5116     new_escape_count = new_body.count("\\")
5117     if new_escape_count > orig_escape_count:
5118         return  # Do not introduce more escaping
5119
5120     if new_escape_count == orig_escape_count and orig_quote == '"':
5121         return  # Prefer double quotes
5122
5123     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5124
5125
5126 def normalize_numeric_literal(leaf: Leaf) -> None:
5127     """Normalizes numeric (float, int, and complex) literals.
5128
5129     All letters used in the representation are normalized to lowercase (except
5130     in Python 2 long literals).
5131     """
5132     text = leaf.value.lower()
5133     if text.startswith(("0o", "0b")):
5134         # Leave octal and binary literals alone.
5135         pass
5136     elif text.startswith("0x"):
5137         # Change hex literals to upper case.
5138         before, after = text[:2], text[2:]
5139         text = f"{before}{after.upper()}"
5140     elif "e" in text:
5141         before, after = text.split("e")
5142         sign = ""
5143         if after.startswith("-"):
5144             after = after[1:]
5145             sign = "-"
5146         elif after.startswith("+"):
5147             after = after[1:]
5148         before = format_float_or_int_string(before)
5149         text = f"{before}e{sign}{after}"
5150     elif text.endswith(("j", "l")):
5151         number = text[:-1]
5152         suffix = text[-1]
5153         # Capitalize in "2L" because "l" looks too similar to "1".
5154         if suffix == "l":
5155             suffix = "L"
5156         text = f"{format_float_or_int_string(number)}{suffix}"
5157     else:
5158         text = format_float_or_int_string(text)
5159     leaf.value = text
5160
5161
5162 def format_float_or_int_string(text: str) -> str:
5163     """Formats a float string like "1.0"."""
5164     if "." not in text:
5165         return text
5166
5167     before, after = text.split(".")
5168     return f"{before or 0}.{after or 0}"
5169
5170
5171 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5172     """Make existing optional parentheses invisible or create new ones.
5173
5174     `parens_after` is a set of string leaf values immediately after which parens
5175     should be put.
5176
5177     Standardizes on visible parentheses for single-element tuples, and keeps
5178     existing visible parentheses for other tuples and generator expressions.
5179     """
5180     for pc in list_comments(node.prefix, is_endmarker=False):
5181         if pc.value in FMT_OFF:
5182             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5183             return
5184     check_lpar = False
5185     for index, child in enumerate(list(node.children)):
5186         # Fixes a bug where invisible parens are not properly stripped from
5187         # assignment statements that contain type annotations.
5188         if isinstance(child, Node) and child.type == syms.annassign:
5189             normalize_invisible_parens(child, parens_after=parens_after)
5190
5191         # Add parentheses around long tuple unpacking in assignments.
5192         if (
5193             index == 0
5194             and isinstance(child, Node)
5195             and child.type == syms.testlist_star_expr
5196         ):
5197             check_lpar = True
5198
5199         if check_lpar:
5200             if is_walrus_assignment(child):
5201                 continue
5202
5203             if child.type == syms.atom:
5204                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5205                     wrap_in_parentheses(node, child, visible=False)
5206             elif is_one_tuple(child):
5207                 wrap_in_parentheses(node, child, visible=True)
5208             elif node.type == syms.import_from:
5209                 # "import from" nodes store parentheses directly as part of
5210                 # the statement
5211                 if child.type == token.LPAR:
5212                     # make parentheses invisible
5213                     child.value = ""  # type: ignore
5214                     node.children[-1].value = ""  # type: ignore
5215                 elif child.type != token.STAR:
5216                     # insert invisible parentheses
5217                     node.insert_child(index, Leaf(token.LPAR, ""))
5218                     node.append_child(Leaf(token.RPAR, ""))
5219                 break
5220
5221             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5222                 wrap_in_parentheses(node, child, visible=False)
5223
5224         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5225
5226
5227 def normalize_fmt_off(node: Node) -> None:
5228     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5229     try_again = True
5230     while try_again:
5231         try_again = convert_one_fmt_off_pair(node)
5232
5233
5234 def convert_one_fmt_off_pair(node: Node) -> bool:
5235     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5236
5237     Returns True if a pair was converted.
5238     """
5239     for leaf in node.leaves():
5240         previous_consumed = 0
5241         for comment in list_comments(leaf.prefix, is_endmarker=False):
5242             if comment.value in FMT_OFF:
5243                 # We only want standalone comments. If there's no previous leaf or
5244                 # the previous leaf is indentation, it's a standalone comment in
5245                 # disguise.
5246                 if comment.type != STANDALONE_COMMENT:
5247                     prev = preceding_leaf(leaf)
5248                     if prev and prev.type not in WHITESPACE:
5249                         continue
5250
5251                 ignored_nodes = list(generate_ignored_nodes(leaf))
5252                 if not ignored_nodes:
5253                     continue
5254
5255                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5256                 parent = first.parent
5257                 prefix = first.prefix
5258                 first.prefix = prefix[comment.consumed :]
5259                 hidden_value = (
5260                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5261                 )
5262                 if hidden_value.endswith("\n"):
5263                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5264                     # leaf (possibly followed by a DEDENT).
5265                     hidden_value = hidden_value[:-1]
5266                 first_idx: Optional[int] = None
5267                 for ignored in ignored_nodes:
5268                     index = ignored.remove()
5269                     if first_idx is None:
5270                         first_idx = index
5271                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5272                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5273                 parent.insert_child(
5274                     first_idx,
5275                     Leaf(
5276                         STANDALONE_COMMENT,
5277                         hidden_value,
5278                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5279                     ),
5280                 )
5281                 return True
5282
5283             previous_consumed = comment.consumed
5284
5285     return False
5286
5287
5288 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5289     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5290
5291     Stops at the end of the block.
5292     """
5293     container: Optional[LN] = container_of(leaf)
5294     while container is not None and container.type != token.ENDMARKER:
5295         if is_fmt_on(container):
5296             return
5297
5298         # fix for fmt: on in children
5299         if contains_fmt_on_at_column(container, leaf.column):
5300             for child in container.children:
5301                 if contains_fmt_on_at_column(child, leaf.column):
5302                     return
5303                 yield child
5304         else:
5305             yield container
5306             container = container.next_sibling
5307
5308
5309 def is_fmt_on(container: LN) -> bool:
5310     """Determine whether formatting is switched on within a container.
5311     Determined by whether the last `# fmt:` comment is `on` or `off`.
5312     """
5313     fmt_on = False
5314     for comment in list_comments(container.prefix, is_endmarker=False):
5315         if comment.value in FMT_ON:
5316             fmt_on = True
5317         elif comment.value in FMT_OFF:
5318             fmt_on = False
5319     return fmt_on
5320
5321
5322 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5323     """Determine if children at a given column have formatting switched on."""
5324     for child in container.children:
5325         if (
5326             isinstance(child, Node)
5327             and first_leaf_column(child) == column
5328             or isinstance(child, Leaf)
5329             and child.column == column
5330         ):
5331             if is_fmt_on(child):
5332                 return True
5333
5334     return False
5335
5336
5337 def first_leaf_column(node: Node) -> Optional[int]:
5338     """Returns the column of the first leaf child of a node."""
5339     for child in node.children:
5340         if isinstance(child, Leaf):
5341             return child.column
5342     return None
5343
5344
5345 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5346     """If it's safe, make the parens in the atom `node` invisible, recursively.
5347     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5348     as they are redundant.
5349
5350     Returns whether the node should itself be wrapped in invisible parentheses.
5351
5352     """
5353     if (
5354         node.type != syms.atom
5355         or is_empty_tuple(node)
5356         or is_one_tuple(node)
5357         or (is_yield(node) and parent.type != syms.expr_stmt)
5358         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5359     ):
5360         return False
5361
5362     first = node.children[0]
5363     last = node.children[-1]
5364     if first.type == token.LPAR and last.type == token.RPAR:
5365         middle = node.children[1]
5366         # make parentheses invisible
5367         first.value = ""  # type: ignore
5368         last.value = ""  # type: ignore
5369         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5370
5371         if is_atom_with_invisible_parens(middle):
5372             # Strip the invisible parens from `middle` by replacing
5373             # it with the child in-between the invisible parens
5374             middle.replace(middle.children[1])
5375
5376         return False
5377
5378     return True
5379
5380
5381 def is_atom_with_invisible_parens(node: LN) -> bool:
5382     """Given a `LN`, determines whether it's an atom `node` with invisible
5383     parens. Useful in dedupe-ing and normalizing parens.
5384     """
5385     if isinstance(node, Leaf) or node.type != syms.atom:
5386         return False
5387
5388     first, last = node.children[0], node.children[-1]
5389     return (
5390         isinstance(first, Leaf)
5391         and first.type == token.LPAR
5392         and first.value == ""
5393         and isinstance(last, Leaf)
5394         and last.type == token.RPAR
5395         and last.value == ""
5396     )
5397
5398
5399 def is_empty_tuple(node: LN) -> bool:
5400     """Return True if `node` holds an empty tuple."""
5401     return (
5402         node.type == syms.atom
5403         and len(node.children) == 2
5404         and node.children[0].type == token.LPAR
5405         and node.children[1].type == token.RPAR
5406     )
5407
5408
5409 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5410     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5411
5412     Parenthesis can be optional. Returns None otherwise"""
5413     if len(node.children) != 3:
5414         return None
5415
5416     lpar, wrapped, rpar = node.children
5417     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5418         return None
5419
5420     return wrapped
5421
5422
5423 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5424     """Wrap `child` in parentheses.
5425
5426     This replaces `child` with an atom holding the parentheses and the old
5427     child.  That requires moving the prefix.
5428
5429     If `visible` is False, the leaves will be valueless (and thus invisible).
5430     """
5431     lpar = Leaf(token.LPAR, "(" if visible else "")
5432     rpar = Leaf(token.RPAR, ")" if visible else "")
5433     prefix = child.prefix
5434     child.prefix = ""
5435     index = child.remove() or 0
5436     new_child = Node(syms.atom, [lpar, child, rpar])
5437     new_child.prefix = prefix
5438     parent.insert_child(index, new_child)
5439
5440
5441 def is_one_tuple(node: LN) -> bool:
5442     """Return True if `node` holds a tuple with one element, with or without parens."""
5443     if node.type == syms.atom:
5444         gexp = unwrap_singleton_parenthesis(node)
5445         if gexp is None or gexp.type != syms.testlist_gexp:
5446             return False
5447
5448         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5449
5450     return (
5451         node.type in IMPLICIT_TUPLE
5452         and len(node.children) == 2
5453         and node.children[1].type == token.COMMA
5454     )
5455
5456
5457 def is_walrus_assignment(node: LN) -> bool:
5458     """Return True iff `node` is of the shape ( test := test )"""
5459     inner = unwrap_singleton_parenthesis(node)
5460     return inner is not None and inner.type == syms.namedexpr_test
5461
5462
5463 def is_yield(node: LN) -> bool:
5464     """Return True if `node` holds a `yield` or `yield from` expression."""
5465     if node.type == syms.yield_expr:
5466         return True
5467
5468     if node.type == token.NAME and node.value == "yield":  # type: ignore
5469         return True
5470
5471     if node.type != syms.atom:
5472         return False
5473
5474     if len(node.children) != 3:
5475         return False
5476
5477     lpar, expr, rpar = node.children
5478     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5479         return is_yield(expr)
5480
5481     return False
5482
5483
5484 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5485     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5486
5487     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5488     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5489     extended iterable unpacking (PEP 3132) and additional unpacking
5490     generalizations (PEP 448).
5491     """
5492     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5493         return False
5494
5495     p = leaf.parent
5496     if p.type == syms.star_expr:
5497         # Star expressions are also used as assignment targets in extended
5498         # iterable unpacking (PEP 3132).  See what its parent is instead.
5499         if not p.parent:
5500             return False
5501
5502         p = p.parent
5503
5504     return p.type in within
5505
5506
5507 def is_multiline_string(leaf: Leaf) -> bool:
5508     """Return True if `leaf` is a multiline string that actually spans many lines."""
5509     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5510
5511
5512 def is_stub_suite(node: Node) -> bool:
5513     """Return True if `node` is a suite with a stub body."""
5514     if (
5515         len(node.children) != 4
5516         or node.children[0].type != token.NEWLINE
5517         or node.children[1].type != token.INDENT
5518         or node.children[3].type != token.DEDENT
5519     ):
5520         return False
5521
5522     return is_stub_body(node.children[2])
5523
5524
5525 def is_stub_body(node: LN) -> bool:
5526     """Return True if `node` is a simple statement containing an ellipsis."""
5527     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5528         return False
5529
5530     if len(node.children) != 2:
5531         return False
5532
5533     child = node.children[0]
5534     return (
5535         child.type == syms.atom
5536         and len(child.children) == 3
5537         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5538     )
5539
5540
5541 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5542     """Return maximum delimiter priority inside `node`.
5543
5544     This is specific to atoms with contents contained in a pair of parentheses.
5545     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5546     """
5547     if node.type != syms.atom:
5548         return 0
5549
5550     first = node.children[0]
5551     last = node.children[-1]
5552     if not (first.type == token.LPAR and last.type == token.RPAR):
5553         return 0
5554
5555     bt = BracketTracker()
5556     for c in node.children[1:-1]:
5557         if isinstance(c, Leaf):
5558             bt.mark(c)
5559         else:
5560             for leaf in c.leaves():
5561                 bt.mark(leaf)
5562     try:
5563         return bt.max_delimiter_priority()
5564
5565     except ValueError:
5566         return 0
5567
5568
5569 def ensure_visible(leaf: Leaf) -> None:
5570     """Make sure parentheses are visible.
5571
5572     They could be invisible as part of some statements (see
5573     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5574     """
5575     if leaf.type == token.LPAR:
5576         leaf.value = "("
5577     elif leaf.type == token.RPAR:
5578         leaf.value = ")"
5579
5580
5581 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
5582     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
5583
5584     if not (
5585         opening_bracket.parent
5586         and opening_bracket.parent.type in {syms.atom, syms.import_from}
5587         and opening_bracket.value in "[{("
5588     ):
5589         return False
5590
5591     try:
5592         last_leaf = line.leaves[-1]
5593         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
5594         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5595     except (IndexError, ValueError):
5596         return False
5597
5598     return max_priority == COMMA_PRIORITY
5599
5600
5601 def get_features_used(node: Node) -> Set[Feature]:
5602     """Return a set of (relatively) new Python features used in this file.
5603
5604     Currently looking for:
5605     - f-strings;
5606     - underscores in numeric literals;
5607     - trailing commas after * or ** in function signatures and calls;
5608     - positional only arguments in function signatures and lambdas;
5609     """
5610     features: Set[Feature] = set()
5611     for n in node.pre_order():
5612         if n.type == token.STRING:
5613             value_head = n.value[:2]  # type: ignore
5614             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5615                 features.add(Feature.F_STRINGS)
5616
5617         elif n.type == token.NUMBER:
5618             if "_" in n.value:  # type: ignore
5619                 features.add(Feature.NUMERIC_UNDERSCORES)
5620
5621         elif n.type == token.SLASH:
5622             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5623                 features.add(Feature.POS_ONLY_ARGUMENTS)
5624
5625         elif n.type == token.COLONEQUAL:
5626             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5627
5628         elif (
5629             n.type in {syms.typedargslist, syms.arglist}
5630             and n.children
5631             and n.children[-1].type == token.COMMA
5632         ):
5633             if n.type == syms.typedargslist:
5634                 feature = Feature.TRAILING_COMMA_IN_DEF
5635             else:
5636                 feature = Feature.TRAILING_COMMA_IN_CALL
5637
5638             for ch in n.children:
5639                 if ch.type in STARS:
5640                     features.add(feature)
5641
5642                 if ch.type == syms.argument:
5643                     for argch in ch.children:
5644                         if argch.type in STARS:
5645                             features.add(feature)
5646
5647     return features
5648
5649
5650 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5651     """Detect the version to target based on the nodes used."""
5652     features = get_features_used(node)
5653     return {
5654         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5655     }
5656
5657
5658 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5659     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5660
5661     Brackets can be omitted if the entire trailer up to and including
5662     a preceding closing bracket fits in one line.
5663
5664     Yielded sets are cumulative (contain results of previous yields, too).  First
5665     set is empty.
5666     """
5667
5668     omit: Set[LeafID] = set()
5669     yield omit
5670
5671     length = 4 * line.depth
5672     opening_bracket: Optional[Leaf] = None
5673     closing_bracket: Optional[Leaf] = None
5674     inner_brackets: Set[LeafID] = set()
5675     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5676         length += leaf_length
5677         if length > line_length:
5678             break
5679
5680         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5681         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5682             break
5683
5684         if opening_bracket:
5685             if leaf is opening_bracket:
5686                 opening_bracket = None
5687             elif leaf.type in CLOSING_BRACKETS:
5688                 inner_brackets.add(id(leaf))
5689         elif leaf.type in CLOSING_BRACKETS:
5690             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
5691                 # Empty brackets would fail a split so treat them as "inner"
5692                 # brackets (e.g. only add them to the `omit` set if another
5693                 # pair of brackets was good enough.
5694                 inner_brackets.add(id(leaf))
5695                 continue
5696
5697             if closing_bracket:
5698                 omit.add(id(closing_bracket))
5699                 omit.update(inner_brackets)
5700                 inner_brackets.clear()
5701                 yield omit
5702
5703             if leaf.value:
5704                 opening_bracket = leaf.opening_bracket
5705                 closing_bracket = leaf
5706
5707
5708 def get_future_imports(node: Node) -> Set[str]:
5709     """Return a set of __future__ imports in the file."""
5710     imports: Set[str] = set()
5711
5712     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5713         for child in children:
5714             if isinstance(child, Leaf):
5715                 if child.type == token.NAME:
5716                     yield child.value
5717
5718             elif child.type == syms.import_as_name:
5719                 orig_name = child.children[0]
5720                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5721                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5722                 yield orig_name.value
5723
5724             elif child.type == syms.import_as_names:
5725                 yield from get_imports_from_children(child.children)
5726
5727             else:
5728                 raise AssertionError("Invalid syntax parsing imports")
5729
5730     for child in node.children:
5731         if child.type != syms.simple_stmt:
5732             break
5733
5734         first_child = child.children[0]
5735         if isinstance(first_child, Leaf):
5736             # Continue looking if we see a docstring; otherwise stop.
5737             if (
5738                 len(child.children) == 2
5739                 and first_child.type == token.STRING
5740                 and child.children[1].type == token.NEWLINE
5741             ):
5742                 continue
5743
5744             break
5745
5746         elif first_child.type == syms.import_from:
5747             module_name = first_child.children[1]
5748             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5749                 break
5750
5751             imports |= set(get_imports_from_children(first_child.children[3:]))
5752         else:
5753             break
5754
5755     return imports
5756
5757
5758 @lru_cache()
5759 def get_gitignore(root: Path) -> PathSpec:
5760     """ Return a PathSpec matching gitignore content if present."""
5761     gitignore = root / ".gitignore"
5762     lines: List[str] = []
5763     if gitignore.is_file():
5764         with gitignore.open() as gf:
5765             lines = gf.readlines()
5766     return PathSpec.from_lines("gitwildmatch", lines)
5767
5768
5769 def normalize_path_maybe_ignore(
5770     path: Path, root: Path, report: "Report"
5771 ) -> Optional[str]:
5772     """Normalize `path`. May return `None` if `path` was ignored.
5773
5774     `report` is where "path ignored" output goes.
5775     """
5776     try:
5777         normalized_path = path.resolve().relative_to(root).as_posix()
5778     except OSError as e:
5779         report.path_ignored(path, f"cannot be read because {e}")
5780         return None
5781
5782     except ValueError:
5783         if path.is_symlink():
5784             report.path_ignored(path, f"is a symbolic link that points outside {root}")
5785             return None
5786
5787         raise
5788
5789     return normalized_path
5790
5791
5792 def gen_python_files(
5793     paths: Iterable[Path],
5794     root: Path,
5795     include: Optional[Pattern[str]],
5796     exclude: Pattern[str],
5797     force_exclude: Optional[Pattern[str]],
5798     report: "Report",
5799     gitignore: PathSpec,
5800 ) -> Iterator[Path]:
5801     """Generate all files under `path` whose paths are not excluded by the
5802     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
5803
5804     Symbolic links pointing outside of the `root` directory are ignored.
5805
5806     `report` is where output about exclusions goes.
5807     """
5808     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5809     for child in paths:
5810         normalized_path = normalize_path_maybe_ignore(child, root, report)
5811         if normalized_path is None:
5812             continue
5813
5814         # First ignore files matching .gitignore
5815         if gitignore.match_file(normalized_path):
5816             report.path_ignored(child, "matches the .gitignore file content")
5817             continue
5818
5819         # Then ignore with `--exclude` and `--force-exclude` options.
5820         normalized_path = "/" + normalized_path
5821         if child.is_dir():
5822             normalized_path += "/"
5823
5824         exclude_match = exclude.search(normalized_path) if exclude else None
5825         if exclude_match and exclude_match.group(0):
5826             report.path_ignored(child, "matches the --exclude regular expression")
5827             continue
5828
5829         force_exclude_match = (
5830             force_exclude.search(normalized_path) if force_exclude else None
5831         )
5832         if force_exclude_match and force_exclude_match.group(0):
5833             report.path_ignored(child, "matches the --force-exclude regular expression")
5834             continue
5835
5836         if child.is_dir():
5837             yield from gen_python_files(
5838                 child.iterdir(),
5839                 root,
5840                 include,
5841                 exclude,
5842                 force_exclude,
5843                 report,
5844                 gitignore,
5845             )
5846
5847         elif child.is_file():
5848             include_match = include.search(normalized_path) if include else True
5849             if include_match:
5850                 yield child
5851
5852
5853 @lru_cache()
5854 def find_project_root(srcs: Iterable[str]) -> Path:
5855     """Return a directory containing .git, .hg, or pyproject.toml.
5856
5857     That directory will be a common parent of all files and directories
5858     passed in `srcs`.
5859
5860     If no directory in the tree contains a marker that would specify it's the
5861     project root, the root of the file system is returned.
5862     """
5863     if not srcs:
5864         return Path("/").resolve()
5865
5866     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
5867
5868     # A list of lists of parents for each 'src'. 'src' is included as a
5869     # "parent" of itself if it is a directory
5870     src_parents = [
5871         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
5872     ]
5873
5874     common_base = max(
5875         set.intersection(*(set(parents) for parents in src_parents)),
5876         key=lambda path: path.parts,
5877     )
5878
5879     for directory in (common_base, *common_base.parents):
5880         if (directory / ".git").exists():
5881             return directory
5882
5883         if (directory / ".hg").is_dir():
5884             return directory
5885
5886         if (directory / "pyproject.toml").is_file():
5887             return directory
5888
5889     return directory
5890
5891
5892 @dataclass
5893 class Report:
5894     """Provides a reformatting counter. Can be rendered with `str(report)`."""
5895
5896     check: bool = False
5897     diff: bool = False
5898     quiet: bool = False
5899     verbose: bool = False
5900     change_count: int = 0
5901     same_count: int = 0
5902     failure_count: int = 0
5903
5904     def done(self, src: Path, changed: Changed) -> None:
5905         """Increment the counter for successful reformatting. Write out a message."""
5906         if changed is Changed.YES:
5907             reformatted = "would reformat" if self.check or self.diff else "reformatted"
5908             if self.verbose or not self.quiet:
5909                 out(f"{reformatted} {src}")
5910             self.change_count += 1
5911         else:
5912             if self.verbose:
5913                 if changed is Changed.NO:
5914                     msg = f"{src} already well formatted, good job."
5915                 else:
5916                     msg = f"{src} wasn't modified on disk since last run."
5917                 out(msg, bold=False)
5918             self.same_count += 1
5919
5920     def failed(self, src: Path, message: str) -> None:
5921         """Increment the counter for failed reformatting. Write out a message."""
5922         err(f"error: cannot format {src}: {message}")
5923         self.failure_count += 1
5924
5925     def path_ignored(self, path: Path, message: str) -> None:
5926         if self.verbose:
5927             out(f"{path} ignored: {message}", bold=False)
5928
5929     @property
5930     def return_code(self) -> int:
5931         """Return the exit code that the app should use.
5932
5933         This considers the current state of changed files and failures:
5934         - if there were any failures, return 123;
5935         - if any files were changed and --check is being used, return 1;
5936         - otherwise return 0.
5937         """
5938         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
5939         # 126 we have special return codes reserved by the shell.
5940         if self.failure_count:
5941             return 123
5942
5943         elif self.change_count and self.check:
5944             return 1
5945
5946         return 0
5947
5948     def __str__(self) -> str:
5949         """Render a color report of the current state.
5950
5951         Use `click.unstyle` to remove colors.
5952         """
5953         if self.check or self.diff:
5954             reformatted = "would be reformatted"
5955             unchanged = "would be left unchanged"
5956             failed = "would fail to reformat"
5957         else:
5958             reformatted = "reformatted"
5959             unchanged = "left unchanged"
5960             failed = "failed to reformat"
5961         report = []
5962         if self.change_count:
5963             s = "s" if self.change_count > 1 else ""
5964             report.append(
5965                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
5966             )
5967         if self.same_count:
5968             s = "s" if self.same_count > 1 else ""
5969             report.append(f"{self.same_count} file{s} {unchanged}")
5970         if self.failure_count:
5971             s = "s" if self.failure_count > 1 else ""
5972             report.append(
5973                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
5974             )
5975         return ", ".join(report) + "."
5976
5977
5978 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
5979     filename = "<unknown>"
5980     if sys.version_info >= (3, 8):
5981         # TODO: support Python 4+ ;)
5982         for minor_version in range(sys.version_info[1], 4, -1):
5983             try:
5984                 return ast.parse(src, filename, feature_version=(3, minor_version))
5985             except SyntaxError:
5986                 continue
5987     else:
5988         for feature_version in (7, 6):
5989             try:
5990                 return ast3.parse(src, filename, feature_version=feature_version)
5991             except SyntaxError:
5992                 continue
5993
5994     return ast27.parse(src)
5995
5996
5997 def _fixup_ast_constants(
5998     node: Union[ast.AST, ast3.AST, ast27.AST]
5999 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6000     """Map ast nodes deprecated in 3.8 to Constant."""
6001     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6002         return ast.Constant(value=node.s)
6003
6004     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6005         return ast.Constant(value=node.n)
6006
6007     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6008         return ast.Constant(value=node.value)
6009
6010     return node
6011
6012
6013 def _stringify_ast(
6014     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6015 ) -> Iterator[str]:
6016     """Simple visitor generating strings to compare ASTs by content."""
6017
6018     node = _fixup_ast_constants(node)
6019
6020     yield f"{'  ' * depth}{node.__class__.__name__}("
6021
6022     for field in sorted(node._fields):  # noqa: F402
6023         # TypeIgnore has only one field 'lineno' which breaks this comparison
6024         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6025         if sys.version_info >= (3, 8):
6026             type_ignore_classes += (ast.TypeIgnore,)
6027         if isinstance(node, type_ignore_classes):
6028             break
6029
6030         try:
6031             value = getattr(node, field)
6032         except AttributeError:
6033             continue
6034
6035         yield f"{'  ' * (depth+1)}{field}="
6036
6037         if isinstance(value, list):
6038             for item in value:
6039                 # Ignore nested tuples within del statements, because we may insert
6040                 # parentheses and they change the AST.
6041                 if (
6042                     field == "targets"
6043                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6044                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6045                 ):
6046                     for item in item.elts:
6047                         yield from _stringify_ast(item, depth + 2)
6048
6049                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6050                     yield from _stringify_ast(item, depth + 2)
6051
6052         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6053             yield from _stringify_ast(value, depth + 2)
6054
6055         else:
6056             # Constant strings may be indented across newlines, if they are
6057             # docstrings; fold spaces after newlines when comparing. Similarly,
6058             # trailing and leading space may be removed.
6059             if (
6060                 isinstance(node, ast.Constant)
6061                 and field == "value"
6062                 and isinstance(value, str)
6063             ):
6064                 normalized = re.sub(r" *\n[ \t]+", "\n ", value).strip()
6065             else:
6066                 normalized = value
6067             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6068
6069     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6070
6071
6072 def assert_equivalent(src: str, dst: str) -> None:
6073     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6074     try:
6075         src_ast = parse_ast(src)
6076     except Exception as exc:
6077         raise AssertionError(
6078             "cannot use --safe with this file; failed to parse source file.  AST"
6079             f" error message: {exc}"
6080         )
6081
6082     try:
6083         dst_ast = parse_ast(dst)
6084     except Exception as exc:
6085         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6086         raise AssertionError(
6087             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6088             " on https://github.com/psf/black/issues.  This invalid output might be"
6089             f" helpful: {log}"
6090         ) from None
6091
6092     src_ast_str = "\n".join(_stringify_ast(src_ast))
6093     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6094     if src_ast_str != dst_ast_str:
6095         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6096         raise AssertionError(
6097             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6098             " source.  Please report a bug on https://github.com/psf/black/issues. "
6099             f" This diff might be helpful: {log}"
6100         ) from None
6101
6102
6103 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6104     """Raise AssertionError if `dst` reformats differently the second time."""
6105     newdst = format_str(dst, mode=mode)
6106     if dst != newdst:
6107         log = dump_to_file(
6108             diff(src, dst, "source", "first pass"),
6109             diff(dst, newdst, "first pass", "second pass"),
6110         )
6111         raise AssertionError(
6112             "INTERNAL ERROR: Black produced different code on the second pass of the"
6113             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6114             f"  This diff might be helpful: {log}"
6115         ) from None
6116
6117
6118 @mypyc_attr(patchable=True)
6119 def dump_to_file(*output: str) -> str:
6120     """Dump `output` to a temporary file. Return path to the file."""
6121     with tempfile.NamedTemporaryFile(
6122         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6123     ) as f:
6124         for lines in output:
6125             f.write(lines)
6126             if lines and lines[-1] != "\n":
6127                 f.write("\n")
6128     return f.name
6129
6130
6131 @contextmanager
6132 def nullcontext() -> Iterator[None]:
6133     """Return an empty context manager.
6134
6135     To be used like `nullcontext` in Python 3.7.
6136     """
6137     yield
6138
6139
6140 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6141     """Return a unified diff string between strings `a` and `b`."""
6142     import difflib
6143
6144     a_lines = [line + "\n" for line in a.splitlines()]
6145     b_lines = [line + "\n" for line in b.splitlines()]
6146     return "".join(
6147         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6148     )
6149
6150
6151 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6152     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6153     err("Aborted!")
6154     for task in tasks:
6155         task.cancel()
6156
6157
6158 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6159     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6160     try:
6161         if sys.version_info[:2] >= (3, 7):
6162             all_tasks = asyncio.all_tasks
6163         else:
6164             all_tasks = asyncio.Task.all_tasks
6165         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6166         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6167         if not to_cancel:
6168             return
6169
6170         for task in to_cancel:
6171             task.cancel()
6172         loop.run_until_complete(
6173             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6174         )
6175     finally:
6176         # `concurrent.futures.Future` objects cannot be cancelled once they
6177         # are already running. There might be some when the `shutdown()` happened.
6178         # Silence their logger's spew about the event loop being closed.
6179         cf_logger = logging.getLogger("concurrent.futures")
6180         cf_logger.setLevel(logging.CRITICAL)
6181         loop.close()
6182
6183
6184 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6185     """Replace `regex` with `replacement` twice on `original`.
6186
6187     This is used by string normalization to perform replaces on
6188     overlapping matches.
6189     """
6190     return regex.sub(replacement, regex.sub(replacement, original))
6191
6192
6193 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6194     """Compile a regular expression string in `regex`.
6195
6196     If it contains newlines, use verbose mode.
6197     """
6198     if "\n" in regex:
6199         regex = "(?x)" + regex
6200     compiled: Pattern[str] = re.compile(regex)
6201     return compiled
6202
6203
6204 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6205     """Like `reversed(enumerate(sequence))` if that were possible."""
6206     index = len(sequence) - 1
6207     for element in reversed(sequence):
6208         yield (index, element)
6209         index -= 1
6210
6211
6212 def enumerate_with_length(
6213     line: Line, reversed: bool = False
6214 ) -> Iterator[Tuple[Index, Leaf, int]]:
6215     """Return an enumeration of leaves with their length.
6216
6217     Stops prematurely on multiline strings and standalone comments.
6218     """
6219     op = cast(
6220         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6221         enumerate_reversed if reversed else enumerate,
6222     )
6223     for index, leaf in op(line.leaves):
6224         length = len(leaf.prefix) + len(leaf.value)
6225         if "\n" in leaf.value:
6226             return  # Multiline strings, we can't continue.
6227
6228         for comment in line.comments_after(leaf):
6229             length += len(comment.value)
6230
6231         yield index, leaf, length
6232
6233
6234 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6235     """Return True if `line` is no longer than `line_length`.
6236
6237     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6238     """
6239     if not line_str:
6240         line_str = line_to_string(line)
6241     return (
6242         len(line_str) <= line_length
6243         and "\n" not in line_str  # multiline strings
6244         and not line.contains_standalone_comments()
6245     )
6246
6247
6248 def can_be_split(line: Line) -> bool:
6249     """Return False if the line cannot be split *for sure*.
6250
6251     This is not an exhaustive search but a cheap heuristic that we can use to
6252     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6253     in unnecessary parentheses).
6254     """
6255     leaves = line.leaves
6256     if len(leaves) < 2:
6257         return False
6258
6259     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6260         call_count = 0
6261         dot_count = 0
6262         next = leaves[-1]
6263         for leaf in leaves[-2::-1]:
6264             if leaf.type in OPENING_BRACKETS:
6265                 if next.type not in CLOSING_BRACKETS:
6266                     return False
6267
6268                 call_count += 1
6269             elif leaf.type == token.DOT:
6270                 dot_count += 1
6271             elif leaf.type == token.NAME:
6272                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6273                     return False
6274
6275             elif leaf.type not in CLOSING_BRACKETS:
6276                 return False
6277
6278             if dot_count > 1 and call_count > 1:
6279                 return False
6280
6281     return True
6282
6283
6284 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
6285     """Does `line` have a shape safe to reformat without optional parens around it?
6286
6287     Returns True for only a subset of potentially nice looking formattings but
6288     the point is to not return false positives that end up producing lines that
6289     are too long.
6290     """
6291     bt = line.bracket_tracker
6292     if not bt.delimiters:
6293         # Without delimiters the optional parentheses are useless.
6294         return True
6295
6296     max_priority = bt.max_delimiter_priority()
6297     if bt.delimiter_count_with_priority(max_priority) > 1:
6298         # With more than one delimiter of a kind the optional parentheses read better.
6299         return False
6300
6301     if max_priority == DOT_PRIORITY:
6302         # A single stranded method call doesn't require optional parentheses.
6303         return True
6304
6305     assert len(line.leaves) >= 2, "Stranded delimiter"
6306
6307     first = line.leaves[0]
6308     second = line.leaves[1]
6309     penultimate = line.leaves[-2]
6310     last = line.leaves[-1]
6311
6312     # With a single delimiter, omit if the expression starts or ends with
6313     # a bracket.
6314     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6315         remainder = False
6316         length = 4 * line.depth
6317         for _index, leaf, leaf_length in enumerate_with_length(line):
6318             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6319                 remainder = True
6320             if remainder:
6321                 length += leaf_length
6322                 if length > line_length:
6323                     break
6324
6325                 if leaf.type in OPENING_BRACKETS:
6326                     # There are brackets we can further split on.
6327                     remainder = False
6328
6329         else:
6330             # checked the entire string and line length wasn't exceeded
6331             if len(line.leaves) == _index + 1:
6332                 return True
6333
6334         # Note: we are not returning False here because a line might have *both*
6335         # a leading opening bracket and a trailing closing bracket.  If the
6336         # opening bracket doesn't match our rule, maybe the closing will.
6337
6338     if (
6339         last.type == token.RPAR
6340         or last.type == token.RBRACE
6341         or (
6342             # don't use indexing for omitting optional parentheses;
6343             # it looks weird
6344             last.type == token.RSQB
6345             and last.parent
6346             and last.parent.type != syms.trailer
6347         )
6348     ):
6349         if penultimate.type in OPENING_BRACKETS:
6350             # Empty brackets don't help.
6351             return False
6352
6353         if is_multiline_string(first):
6354             # Additional wrapping of a multiline string in this situation is
6355             # unnecessary.
6356             return True
6357
6358         length = 4 * line.depth
6359         seen_other_brackets = False
6360         for _index, leaf, leaf_length in enumerate_with_length(line):
6361             length += leaf_length
6362             if leaf is last.opening_bracket:
6363                 if seen_other_brackets or length <= line_length:
6364                     return True
6365
6366             elif leaf.type in OPENING_BRACKETS:
6367                 # There are brackets we can further split on.
6368                 seen_other_brackets = True
6369
6370     return False
6371
6372
6373 def get_cache_file(mode: Mode) -> Path:
6374     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6375
6376
6377 def read_cache(mode: Mode) -> Cache:
6378     """Read the cache if it exists and is well formed.
6379
6380     If it is not well formed, the call to write_cache later should resolve the issue.
6381     """
6382     cache_file = get_cache_file(mode)
6383     if not cache_file.exists():
6384         return {}
6385
6386     with cache_file.open("rb") as fobj:
6387         try:
6388             cache: Cache = pickle.load(fobj)
6389         except (pickle.UnpicklingError, ValueError):
6390             return {}
6391
6392     return cache
6393
6394
6395 def get_cache_info(path: Path) -> CacheInfo:
6396     """Return the information used to check if a file is already formatted or not."""
6397     stat = path.stat()
6398     return stat.st_mtime, stat.st_size
6399
6400
6401 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6402     """Split an iterable of paths in `sources` into two sets.
6403
6404     The first contains paths of files that modified on disk or are not in the
6405     cache. The other contains paths to non-modified files.
6406     """
6407     todo, done = set(), set()
6408     for src in sources:
6409         src = src.resolve()
6410         if cache.get(src) != get_cache_info(src):
6411             todo.add(src)
6412         else:
6413             done.add(src)
6414     return todo, done
6415
6416
6417 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6418     """Update the cache file."""
6419     cache_file = get_cache_file(mode)
6420     try:
6421         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6422         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6423         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6424             pickle.dump(new_cache, f, protocol=4)
6425         os.replace(f.name, cache_file)
6426     except OSError:
6427         pass
6428
6429
6430 def patch_click() -> None:
6431     """Make Click not crash.
6432
6433     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6434     default which restricts paths that it can access during the lifetime of the
6435     application.  Click refuses to work in this scenario by raising a RuntimeError.
6436
6437     In case of Black the likelihood that non-ASCII characters are going to be used in
6438     file paths is minimal since it's Python source code.  Moreover, this crash was
6439     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6440     """
6441     try:
6442         from click import core
6443         from click import _unicodefun  # type: ignore
6444     except ModuleNotFoundError:
6445         return
6446
6447     for module in (core, _unicodefun):
6448         if hasattr(module, "_verify_python3_env"):
6449             module._verify_python3_env = lambda: None
6450
6451
6452 def patched_main() -> None:
6453     freeze_support()
6454     patch_click()
6455     main()
6456
6457
6458 def fix_docstring(docstring: str, prefix: str) -> str:
6459     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6460     if not docstring:
6461         return ""
6462     # Convert tabs to spaces (following the normal Python rules)
6463     # and split into a list of lines:
6464     lines = docstring.expandtabs().splitlines()
6465     # Determine minimum indentation (first line doesn't count):
6466     indent = sys.maxsize
6467     for line in lines[1:]:
6468         stripped = line.lstrip()
6469         if stripped:
6470             indent = min(indent, len(line) - len(stripped))
6471     # Remove indentation (first line is special):
6472     trimmed = [lines[0].strip()]
6473     if indent < sys.maxsize:
6474         last_line_idx = len(lines) - 2
6475         for i, line in enumerate(lines[1:]):
6476             stripped_line = line[indent:].rstrip()
6477             if stripped_line or i == last_line_idx:
6478                 trimmed.append(prefix + stripped_line)
6479             else:
6480                 trimmed.append("")
6481     # Return a single string:
6482     return "\n".join(trimmed)
6483
6484
6485 if __name__ == "__main__":
6486     patched_main()