]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Primer update config - enable pytest (#1626)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from typing_extensions import Final
46 from mypy_extensions import mypyc_attr
47
48 from appdirs import user_cache_dir
49 from dataclasses import dataclass, field, replace
50 import click
51 import toml
52 from typed_ast import ast3, ast27
53 from pathspec import PathSpec
54
55 # lib2to3 fork
56 from blib2to3.pytree import Node, Leaf, type_repr
57 from blib2to3 import pygram, pytree
58 from blib2to3.pgen2 import driver, token
59 from blib2to3.pgen2.grammar import Grammar
60 from blib2to3.pgen2.parse import ParseError
61
62 from _black_version import version as __version__
63
64 if TYPE_CHECKING:
65     import colorama  # noqa: F401
66
67 DEFAULT_LINE_LENGTH = 88
68 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
69 DEFAULT_INCLUDES = r"\.pyi?$"
70 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
71
72 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
73
74
75 # types
76 FileContent = str
77 Encoding = str
78 NewLine = str
79 Depth = int
80 NodeType = int
81 ParserState = int
82 LeafID = int
83 StringID = int
84 Priority = int
85 Index = int
86 LN = Union[Leaf, Node]
87 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
88 Timestamp = float
89 FileSize = int
90 CacheInfo = Tuple[Timestamp, FileSize]
91 Cache = Dict[Path, CacheInfo]
92 out = partial(click.secho, bold=True, err=True)
93 err = partial(click.secho, fg="red", err=True)
94
95 pygram.initialize(CACHE_DIR)
96 syms = pygram.python_symbols
97
98
99 class NothingChanged(UserWarning):
100     """Raised when reformatted code is the same as source."""
101
102
103 class CannotTransform(Exception):
104     """Base class for errors raised by Transformers."""
105
106
107 class CannotSplit(CannotTransform):
108     """A readable split that fits the allotted line length is impossible."""
109
110
111 class InvalidInput(ValueError):
112     """Raised when input source code fails all parse attempts."""
113
114
115 T = TypeVar("T")
116 E = TypeVar("E", bound=Exception)
117
118
119 class Ok(Generic[T]):
120     def __init__(self, value: T) -> None:
121         self._value = value
122
123     def ok(self) -> T:
124         return self._value
125
126
127 class Err(Generic[E]):
128     def __init__(self, e: E) -> None:
129         self._e = e
130
131     def err(self) -> E:
132         return self._e
133
134
135 # The 'Result' return type is used to implement an error-handling model heavily
136 # influenced by that used by the Rust programming language
137 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
138 Result = Union[Ok[T], Err[E]]
139 TResult = Result[T, CannotTransform]  # (T)ransform Result
140 TMatchResult = TResult[Index]
141
142
143 class WriteBack(Enum):
144     NO = 0
145     YES = 1
146     DIFF = 2
147     CHECK = 3
148     COLOR_DIFF = 4
149
150     @classmethod
151     def from_configuration(
152         cls, *, check: bool, diff: bool, color: bool = False
153     ) -> "WriteBack":
154         if check and not diff:
155             return cls.CHECK
156
157         if diff and color:
158             return cls.COLOR_DIFF
159
160         return cls.DIFF if diff else cls.YES
161
162
163 class Changed(Enum):
164     NO = 0
165     CACHED = 1
166     YES = 2
167
168
169 class TargetVersion(Enum):
170     PY27 = 2
171     PY33 = 3
172     PY34 = 4
173     PY35 = 5
174     PY36 = 6
175     PY37 = 7
176     PY38 = 8
177
178     def is_python2(self) -> bool:
179         return self is TargetVersion.PY27
180
181
182 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
183
184
185 class Feature(Enum):
186     # All string literals are unicode
187     UNICODE_LITERALS = 1
188     F_STRINGS = 2
189     NUMERIC_UNDERSCORES = 3
190     TRAILING_COMMA_IN_CALL = 4
191     TRAILING_COMMA_IN_DEF = 5
192     # The following two feature-flags are mutually exclusive, and exactly one should be
193     # set for every version of python.
194     ASYNC_IDENTIFIERS = 6
195     ASYNC_KEYWORDS = 7
196     ASSIGNMENT_EXPRESSIONS = 8
197     POS_ONLY_ARGUMENTS = 9
198     FORCE_OPTIONAL_PARENTHESES = 50
199
200
201 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
202     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
203     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
204     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
205     TargetVersion.PY35: {
206         Feature.UNICODE_LITERALS,
207         Feature.TRAILING_COMMA_IN_CALL,
208         Feature.ASYNC_IDENTIFIERS,
209     },
210     TargetVersion.PY36: {
211         Feature.UNICODE_LITERALS,
212         Feature.F_STRINGS,
213         Feature.NUMERIC_UNDERSCORES,
214         Feature.TRAILING_COMMA_IN_CALL,
215         Feature.TRAILING_COMMA_IN_DEF,
216         Feature.ASYNC_IDENTIFIERS,
217     },
218     TargetVersion.PY37: {
219         Feature.UNICODE_LITERALS,
220         Feature.F_STRINGS,
221         Feature.NUMERIC_UNDERSCORES,
222         Feature.TRAILING_COMMA_IN_CALL,
223         Feature.TRAILING_COMMA_IN_DEF,
224         Feature.ASYNC_KEYWORDS,
225     },
226     TargetVersion.PY38: {
227         Feature.UNICODE_LITERALS,
228         Feature.F_STRINGS,
229         Feature.NUMERIC_UNDERSCORES,
230         Feature.TRAILING_COMMA_IN_CALL,
231         Feature.TRAILING_COMMA_IN_DEF,
232         Feature.ASYNC_KEYWORDS,
233         Feature.ASSIGNMENT_EXPRESSIONS,
234         Feature.POS_ONLY_ARGUMENTS,
235     },
236 }
237
238
239 @dataclass
240 class Mode:
241     target_versions: Set[TargetVersion] = field(default_factory=set)
242     line_length: int = DEFAULT_LINE_LENGTH
243     string_normalization: bool = True
244     experimental_string_processing: bool = False
245     is_pyi: bool = False
246
247     def get_cache_key(self) -> str:
248         if self.target_versions:
249             version_str = ",".join(
250                 str(version.value)
251                 for version in sorted(self.target_versions, key=lambda v: v.value)
252             )
253         else:
254             version_str = "-"
255         parts = [
256             version_str,
257             str(self.line_length),
258             str(int(self.string_normalization)),
259             str(int(self.is_pyi)),
260         ]
261         return ".".join(parts)
262
263
264 # Legacy name, left for integrations.
265 FileMode = Mode
266
267
268 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
269     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
270
271
272 def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
273     """Find the absolute filepath to a pyproject.toml if it exists"""
274     path_project_root = find_project_root(path_search_start)
275     path_pyproject_toml = path_project_root / "pyproject.toml"
276     return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
277
278
279 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
280     """Parse a pyproject toml file, pulling out relevant parts for Black
281
282     If parsing fails, will raise a toml.TomlDecodeError
283     """
284     pyproject_toml = toml.load(path_config)
285     config = pyproject_toml.get("tool", {}).get("black", {})
286     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
287
288
289 def read_pyproject_toml(
290     ctx: click.Context, param: click.Parameter, value: Optional[str]
291 ) -> Optional[str]:
292     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
293
294     Returns the path to a successfully found and read configuration file, None
295     otherwise.
296     """
297     if not value:
298         value = find_pyproject_toml(ctx.params.get("src", ()))
299         if value is None:
300             return None
301
302     try:
303         config = parse_pyproject_toml(value)
304     except (toml.TomlDecodeError, OSError) as e:
305         raise click.FileError(
306             filename=value, hint=f"Error reading configuration file: {e}"
307         )
308
309     if not config:
310         return None
311     else:
312         # Sanitize the values to be Click friendly. For more information please see:
313         # https://github.com/psf/black/issues/1458
314         # https://github.com/pallets/click/issues/1567
315         config = {
316             k: str(v) if not isinstance(v, (list, dict)) else v
317             for k, v in config.items()
318         }
319
320     target_version = config.get("target_version")
321     if target_version is not None and not isinstance(target_version, list):
322         raise click.BadOptionUsage(
323             "target-version", "Config key target-version must be a list"
324         )
325
326     default_map: Dict[str, Any] = {}
327     if ctx.default_map:
328         default_map.update(ctx.default_map)
329     default_map.update(config)
330
331     ctx.default_map = default_map
332     return value
333
334
335 def target_version_option_callback(
336     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
337 ) -> List[TargetVersion]:
338     """Compute the target versions from a --target-version flag.
339
340     This is its own function because mypy couldn't infer the type correctly
341     when it was a lambda, causing mypyc trouble.
342     """
343     return [TargetVersion[val.upper()] for val in v]
344
345
346 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
347 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
348 @click.option(
349     "-l",
350     "--line-length",
351     type=int,
352     default=DEFAULT_LINE_LENGTH,
353     help="How many characters per line to allow.",
354     show_default=True,
355 )
356 @click.option(
357     "-t",
358     "--target-version",
359     type=click.Choice([v.name.lower() for v in TargetVersion]),
360     callback=target_version_option_callback,
361     multiple=True,
362     help=(
363         "Python versions that should be supported by Black's output. [default: per-file"
364         " auto-detection]"
365     ),
366 )
367 @click.option(
368     "--pyi",
369     is_flag=True,
370     help=(
371         "Format all input files like typing stubs regardless of file extension (useful"
372         " when piping source on standard input)."
373     ),
374 )
375 @click.option(
376     "-S",
377     "--skip-string-normalization",
378     is_flag=True,
379     help="Don't normalize string quotes or prefixes.",
380 )
381 @click.option(
382     "--experimental-string-processing",
383     is_flag=True,
384     hidden=True,
385     help=(
386         "Experimental option that performs more normalization on string literals."
387         " Currently disabled because it leads to some crashes."
388     ),
389 )
390 @click.option(
391     "--check",
392     is_flag=True,
393     help=(
394         "Don't write the files back, just return the status.  Return code 0 means"
395         " nothing would change.  Return code 1 means some files would be reformatted."
396         " Return code 123 means there was an internal error."
397     ),
398 )
399 @click.option(
400     "--diff",
401     is_flag=True,
402     help="Don't write the files back, just output a diff for each file on stdout.",
403 )
404 @click.option(
405     "--color/--no-color",
406     is_flag=True,
407     help="Show colored diff. Only applies when `--diff` is given.",
408 )
409 @click.option(
410     "--fast/--safe",
411     is_flag=True,
412     help="If --fast given, skip temporary sanity checks. [default: --safe]",
413 )
414 @click.option(
415     "--include",
416     type=str,
417     default=DEFAULT_INCLUDES,
418     help=(
419         "A regular expression that matches files and directories that should be"
420         " included on recursive searches.  An empty value means all files are included"
421         " regardless of the name.  Use forward slashes for directories on all platforms"
422         " (Windows, too).  Exclusions are calculated first, inclusions later."
423     ),
424     show_default=True,
425 )
426 @click.option(
427     "--exclude",
428     type=str,
429     default=DEFAULT_EXCLUDES,
430     help=(
431         "A regular expression that matches files and directories that should be"
432         " excluded on recursive searches.  An empty value means no paths are excluded."
433         " Use forward slashes for directories on all platforms (Windows, too). "
434         " Exclusions are calculated first, inclusions later."
435     ),
436     show_default=True,
437 )
438 @click.option(
439     "--force-exclude",
440     type=str,
441     help=(
442         "Like --exclude, but files and directories matching this regex will be "
443         "excluded even when they are passed explicitly as arguments"
444     ),
445 )
446 @click.option(
447     "-q",
448     "--quiet",
449     is_flag=True,
450     help=(
451         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
452         " those with 2>/dev/null."
453     ),
454 )
455 @click.option(
456     "-v",
457     "--verbose",
458     is_flag=True,
459     help=(
460         "Also emit messages to stderr about files that were not changed or were ignored"
461         " due to --exclude=."
462     ),
463 )
464 @click.version_option(version=__version__)
465 @click.argument(
466     "src",
467     nargs=-1,
468     type=click.Path(
469         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
470     ),
471     is_eager=True,
472 )
473 @click.option(
474     "--config",
475     type=click.Path(
476         exists=True,
477         file_okay=True,
478         dir_okay=False,
479         readable=True,
480         allow_dash=False,
481         path_type=str,
482     ),
483     is_eager=True,
484     callback=read_pyproject_toml,
485     help="Read configuration from FILE path.",
486 )
487 @click.pass_context
488 def main(
489     ctx: click.Context,
490     code: Optional[str],
491     line_length: int,
492     target_version: List[TargetVersion],
493     check: bool,
494     diff: bool,
495     color: bool,
496     fast: bool,
497     pyi: bool,
498     skip_string_normalization: bool,
499     experimental_string_processing: bool,
500     quiet: bool,
501     verbose: bool,
502     include: str,
503     exclude: str,
504     force_exclude: Optional[str],
505     src: Tuple[str, ...],
506     config: Optional[str],
507 ) -> None:
508     """The uncompromising code formatter."""
509     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
510     if target_version:
511         versions = set(target_version)
512     else:
513         # We'll autodetect later.
514         versions = set()
515     mode = Mode(
516         target_versions=versions,
517         line_length=line_length,
518         is_pyi=pyi,
519         string_normalization=not skip_string_normalization,
520         experimental_string_processing=experimental_string_processing,
521     )
522     if config and verbose:
523         out(f"Using configuration from {config}.", bold=False, fg="blue")
524     if code is not None:
525         print(format_str(code, mode=mode))
526         ctx.exit(0)
527     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
528     sources = get_sources(
529         ctx=ctx,
530         src=src,
531         quiet=quiet,
532         verbose=verbose,
533         include=include,
534         exclude=exclude,
535         force_exclude=force_exclude,
536         report=report,
537     )
538
539     path_empty(
540         sources,
541         "No Python files are present to be formatted. Nothing to do 😴",
542         quiet,
543         verbose,
544         ctx,
545     )
546
547     if len(sources) == 1:
548         reformat_one(
549             src=sources.pop(),
550             fast=fast,
551             write_back=write_back,
552             mode=mode,
553             report=report,
554         )
555     else:
556         reformat_many(
557             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
558         )
559
560     if verbose or not quiet:
561         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
562         click.secho(str(report), err=True)
563     ctx.exit(report.return_code)
564
565
566 def get_sources(
567     *,
568     ctx: click.Context,
569     src: Tuple[str, ...],
570     quiet: bool,
571     verbose: bool,
572     include: str,
573     exclude: str,
574     force_exclude: Optional[str],
575     report: "Report",
576 ) -> Set[Path]:
577     """Compute the set of files to be formatted."""
578     try:
579         include_regex = re_compile_maybe_verbose(include)
580     except re.error:
581         err(f"Invalid regular expression for include given: {include!r}")
582         ctx.exit(2)
583     try:
584         exclude_regex = re_compile_maybe_verbose(exclude)
585     except re.error:
586         err(f"Invalid regular expression for exclude given: {exclude!r}")
587         ctx.exit(2)
588     try:
589         force_exclude_regex = (
590             re_compile_maybe_verbose(force_exclude) if force_exclude else None
591         )
592     except re.error:
593         err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
594         ctx.exit(2)
595
596     root = find_project_root(src)
597     sources: Set[Path] = set()
598     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
599     gitignore = get_gitignore(root)
600
601     for s in src:
602         p = Path(s)
603         if p.is_dir():
604             sources.update(
605                 gen_python_files(
606                     p.iterdir(),
607                     root,
608                     include_regex,
609                     exclude_regex,
610                     force_exclude_regex,
611                     report,
612                     gitignore,
613                 )
614             )
615         elif s == "-":
616             sources.add(p)
617         elif p.is_file():
618             normalized_path = normalize_path_maybe_ignore(p, root, report)
619             if normalized_path is None:
620                 continue
621
622             normalized_path = "/" + normalized_path
623             # Hard-exclude any files that matches the `--force-exclude` regex.
624             if force_exclude_regex:
625                 force_exclude_match = force_exclude_regex.search(normalized_path)
626             else:
627                 force_exclude_match = None
628             if force_exclude_match and force_exclude_match.group(0):
629                 report.path_ignored(p, "matches the --force-exclude regular expression")
630                 continue
631
632             sources.add(p)
633         else:
634             err(f"invalid path: {s}")
635     return sources
636
637
638 def path_empty(
639     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
640 ) -> None:
641     """
642     Exit if there is no `src` provided for formatting
643     """
644     if len(src) == 0:
645         if verbose or not quiet:
646             out(msg)
647             ctx.exit(0)
648
649
650 def reformat_one(
651     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
652 ) -> None:
653     """Reformat a single file under `src` without spawning child processes.
654
655     `fast`, `write_back`, and `mode` options are passed to
656     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
657     """
658     try:
659         changed = Changed.NO
660         if not src.is_file() and str(src) == "-":
661             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
662                 changed = Changed.YES
663         else:
664             cache: Cache = {}
665             if write_back != WriteBack.DIFF:
666                 cache = read_cache(mode)
667                 res_src = src.resolve()
668                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
669                     changed = Changed.CACHED
670             if changed is not Changed.CACHED and format_file_in_place(
671                 src, fast=fast, write_back=write_back, mode=mode
672             ):
673                 changed = Changed.YES
674             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
675                 write_back is WriteBack.CHECK and changed is Changed.NO
676             ):
677                 write_cache(cache, [src], mode)
678         report.done(src, changed)
679     except Exception as exc:
680         if report.verbose:
681             traceback.print_exc()
682         report.failed(src, str(exc))
683
684
685 def reformat_many(
686     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
687 ) -> None:
688     """Reformat multiple files using a ProcessPoolExecutor."""
689     executor: Executor
690     loop = asyncio.get_event_loop()
691     worker_count = os.cpu_count()
692     if sys.platform == "win32":
693         # Work around https://bugs.python.org/issue26903
694         worker_count = min(worker_count, 61)
695     try:
696         executor = ProcessPoolExecutor(max_workers=worker_count)
697     except (ImportError, OSError):
698         # we arrive here if the underlying system does not support multi-processing
699         # like in AWS Lambda or Termux, in which case we gracefully fallback to
700         # a ThreadPollExecutor with just a single worker (more workers would not do us
701         # any good due to the Global Interpreter Lock)
702         executor = ThreadPoolExecutor(max_workers=1)
703
704     try:
705         loop.run_until_complete(
706             schedule_formatting(
707                 sources=sources,
708                 fast=fast,
709                 write_back=write_back,
710                 mode=mode,
711                 report=report,
712                 loop=loop,
713                 executor=executor,
714             )
715         )
716     finally:
717         shutdown(loop)
718         if executor is not None:
719             executor.shutdown()
720
721
722 async def schedule_formatting(
723     sources: Set[Path],
724     fast: bool,
725     write_back: WriteBack,
726     mode: Mode,
727     report: "Report",
728     loop: asyncio.AbstractEventLoop,
729     executor: Executor,
730 ) -> None:
731     """Run formatting of `sources` in parallel using the provided `executor`.
732
733     (Use ProcessPoolExecutors for actual parallelism.)
734
735     `write_back`, `fast`, and `mode` options are passed to
736     :func:`format_file_in_place`.
737     """
738     cache: Cache = {}
739     if write_back != WriteBack.DIFF:
740         cache = read_cache(mode)
741         sources, cached = filter_cached(cache, sources)
742         for src in sorted(cached):
743             report.done(src, Changed.CACHED)
744     if not sources:
745         return
746
747     cancelled = []
748     sources_to_cache = []
749     lock = None
750     if write_back == WriteBack.DIFF:
751         # For diff output, we need locks to ensure we don't interleave output
752         # from different processes.
753         manager = Manager()
754         lock = manager.Lock()
755     tasks = {
756         asyncio.ensure_future(
757             loop.run_in_executor(
758                 executor, format_file_in_place, src, fast, mode, write_back, lock
759             )
760         ): src
761         for src in sorted(sources)
762     }
763     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
764     try:
765         loop.add_signal_handler(signal.SIGINT, cancel, pending)
766         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
767     except NotImplementedError:
768         # There are no good alternatives for these on Windows.
769         pass
770     while pending:
771         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
772         for task in done:
773             src = tasks.pop(task)
774             if task.cancelled():
775                 cancelled.append(task)
776             elif task.exception():
777                 report.failed(src, str(task.exception()))
778             else:
779                 changed = Changed.YES if task.result() else Changed.NO
780                 # If the file was written back or was successfully checked as
781                 # well-formatted, store this information in the cache.
782                 if write_back is WriteBack.YES or (
783                     write_back is WriteBack.CHECK and changed is Changed.NO
784                 ):
785                     sources_to_cache.append(src)
786                 report.done(src, changed)
787     if cancelled:
788         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
789     if sources_to_cache:
790         write_cache(cache, sources_to_cache, mode)
791
792
793 def format_file_in_place(
794     src: Path,
795     fast: bool,
796     mode: Mode,
797     write_back: WriteBack = WriteBack.NO,
798     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
799 ) -> bool:
800     """Format file under `src` path. Return True if changed.
801
802     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
803     code to the file.
804     `mode` and `fast` options are passed to :func:`format_file_contents`.
805     """
806     if src.suffix == ".pyi":
807         mode = replace(mode, is_pyi=True)
808
809     then = datetime.utcfromtimestamp(src.stat().st_mtime)
810     with open(src, "rb") as buf:
811         src_contents, encoding, newline = decode_bytes(buf.read())
812     try:
813         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
814     except NothingChanged:
815         return False
816
817     if write_back == WriteBack.YES:
818         with open(src, "w", encoding=encoding, newline=newline) as f:
819             f.write(dst_contents)
820     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
821         now = datetime.utcnow()
822         src_name = f"{src}\t{then} +0000"
823         dst_name = f"{src}\t{now} +0000"
824         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
825
826         if write_back == write_back.COLOR_DIFF:
827             diff_contents = color_diff(diff_contents)
828
829         with lock or nullcontext():
830             f = io.TextIOWrapper(
831                 sys.stdout.buffer,
832                 encoding=encoding,
833                 newline=newline,
834                 write_through=True,
835             )
836             f = wrap_stream_for_windows(f)
837             f.write(diff_contents)
838             f.detach()
839
840     return True
841
842
843 def color_diff(contents: str) -> str:
844     """Inject the ANSI color codes to the diff."""
845     lines = contents.split("\n")
846     for i, line in enumerate(lines):
847         if line.startswith("+++") or line.startswith("---"):
848             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
849         if line.startswith("@@"):
850             line = "\033[36m" + line + "\033[0m"  # cyan, reset
851         if line.startswith("+"):
852             line = "\033[32m" + line + "\033[0m"  # green, reset
853         elif line.startswith("-"):
854             line = "\033[31m" + line + "\033[0m"  # red, reset
855         lines[i] = line
856     return "\n".join(lines)
857
858
859 def wrap_stream_for_windows(
860     f: io.TextIOWrapper,
861 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
862     """
863     Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
864
865     If `colorama` is not found, then no change is made. If `colorama` does
866     exist, then it handles the logic to determine whether or not to change
867     things.
868     """
869     try:
870         from colorama import initialise
871
872         # We set `strip=False` so that we can don't have to modify
873         # test_express_diff_with_color.
874         f = initialise.wrap_stream(
875             f, convert=None, strip=False, autoreset=False, wrap=True
876         )
877
878         # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
879         # which does not have a `detach()` method. So we fake one.
880         f.detach = lambda *args, **kwargs: None  # type: ignore
881     except ImportError:
882         pass
883
884     return f
885
886
887 def format_stdin_to_stdout(
888     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
889 ) -> bool:
890     """Format file on stdin. Return True if changed.
891
892     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
893     write a diff to stdout. The `mode` argument is passed to
894     :func:`format_file_contents`.
895     """
896     then = datetime.utcnow()
897     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
898     dst = src
899     try:
900         dst = format_file_contents(src, fast=fast, mode=mode)
901         return True
902
903     except NothingChanged:
904         return False
905
906     finally:
907         f = io.TextIOWrapper(
908             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
909         )
910         if write_back == WriteBack.YES:
911             f.write(dst)
912         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
913             now = datetime.utcnow()
914             src_name = f"STDIN\t{then} +0000"
915             dst_name = f"STDOUT\t{now} +0000"
916             d = diff(src, dst, src_name, dst_name)
917             if write_back == WriteBack.COLOR_DIFF:
918                 d = color_diff(d)
919                 f = wrap_stream_for_windows(f)
920             f.write(d)
921         f.detach()
922
923
924 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
925     """Reformat contents a file and return new contents.
926
927     If `fast` is False, additionally confirm that the reformatted code is
928     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
929     `mode` is passed to :func:`format_str`.
930     """
931     if src_contents.strip() == "":
932         raise NothingChanged
933
934     dst_contents = format_str(src_contents, mode=mode)
935     if src_contents == dst_contents:
936         raise NothingChanged
937
938     if not fast:
939         assert_equivalent(src_contents, dst_contents)
940         assert_stable(src_contents, dst_contents, mode=mode)
941     return dst_contents
942
943
944 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
945     """Reformat a string and return new contents.
946
947     `mode` determines formatting options, such as how many characters per line are
948     allowed.  Example:
949
950     >>> import black
951     >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
952     def f(arg: str = "") -> None:
953         ...
954
955     A more complex example:
956
957     >>> print(
958     ...   black.format_str(
959     ...     "def f(arg:str='')->None: hey",
960     ...     mode=black.Mode(
961     ...       target_versions={black.TargetVersion.PY36},
962     ...       line_length=10,
963     ...       string_normalization=False,
964     ...       is_pyi=False,
965     ...     ),
966     ...   ),
967     ... )
968     def f(
969         arg: str = '',
970     ) -> None:
971         hey
972
973     """
974     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
975     dst_contents = []
976     future_imports = get_future_imports(src_node)
977     if mode.target_versions:
978         versions = mode.target_versions
979     else:
980         versions = detect_target_versions(src_node)
981     normalize_fmt_off(src_node)
982     lines = LineGenerator(
983         remove_u_prefix="unicode_literals" in future_imports
984         or supports_feature(versions, Feature.UNICODE_LITERALS),
985         is_pyi=mode.is_pyi,
986         normalize_strings=mode.string_normalization,
987     )
988     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
989     empty_line = Line()
990     after = 0
991     split_line_features = {
992         feature
993         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
994         if supports_feature(versions, feature)
995     }
996     for current_line in lines.visit(src_node):
997         dst_contents.append(str(empty_line) * after)
998         before, after = elt.maybe_empty_lines(current_line)
999         dst_contents.append(str(empty_line) * before)
1000         for line in transform_line(
1001             current_line, mode=mode, features=split_line_features
1002         ):
1003             dst_contents.append(str(line))
1004     return "".join(dst_contents)
1005
1006
1007 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1008     """Return a tuple of (decoded_contents, encoding, newline).
1009
1010     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1011     universal newlines (i.e. only contains LF).
1012     """
1013     srcbuf = io.BytesIO(src)
1014     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1015     if not lines:
1016         return "", encoding, "\n"
1017
1018     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1019     srcbuf.seek(0)
1020     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1021         return tiow.read(), encoding, newline
1022
1023
1024 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1025     if not target_versions:
1026         # No target_version specified, so try all grammars.
1027         return [
1028             # Python 3.7+
1029             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1030             # Python 3.0-3.6
1031             pygram.python_grammar_no_print_statement_no_exec_statement,
1032             # Python 2.7 with future print_function import
1033             pygram.python_grammar_no_print_statement,
1034             # Python 2.7
1035             pygram.python_grammar,
1036         ]
1037
1038     if all(version.is_python2() for version in target_versions):
1039         # Python 2-only code, so try Python 2 grammars.
1040         return [
1041             # Python 2.7 with future print_function import
1042             pygram.python_grammar_no_print_statement,
1043             # Python 2.7
1044             pygram.python_grammar,
1045         ]
1046
1047     # Python 3-compatible code, so only try Python 3 grammar.
1048     grammars = []
1049     # If we have to parse both, try to parse async as a keyword first
1050     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1051         # Python 3.7+
1052         grammars.append(
1053             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1054         )
1055     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1056         # Python 3.0-3.6
1057         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1058     # At least one of the above branches must have been taken, because every Python
1059     # version has exactly one of the two 'ASYNC_*' flags
1060     return grammars
1061
1062
1063 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1064     """Given a string with source, return the lib2to3 Node."""
1065     if src_txt[-1:] != "\n":
1066         src_txt += "\n"
1067
1068     for grammar in get_grammars(set(target_versions)):
1069         drv = driver.Driver(grammar, pytree.convert)
1070         try:
1071             result = drv.parse_string(src_txt, True)
1072             break
1073
1074         except ParseError as pe:
1075             lineno, column = pe.context[1]
1076             lines = src_txt.splitlines()
1077             try:
1078                 faulty_line = lines[lineno - 1]
1079             except IndexError:
1080                 faulty_line = "<line number missing in source>"
1081             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1082     else:
1083         raise exc from None
1084
1085     if isinstance(result, Leaf):
1086         result = Node(syms.file_input, [result])
1087     return result
1088
1089
1090 def lib2to3_unparse(node: Node) -> str:
1091     """Given a lib2to3 node, return its string representation."""
1092     code = str(node)
1093     return code
1094
1095
1096 class Visitor(Generic[T]):
1097     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1098
1099     def visit(self, node: LN) -> Iterator[T]:
1100         """Main method to visit `node` and its children.
1101
1102         It tries to find a `visit_*()` method for the given `node.type`, like
1103         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1104         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1105         instead.
1106
1107         Then yields objects of type `T` from the selected visitor.
1108         """
1109         if node.type < 256:
1110             name = token.tok_name[node.type]
1111         else:
1112             name = str(type_repr(node.type))
1113         # We explicitly branch on whether a visitor exists (instead of
1114         # using self.visit_default as the default arg to getattr) in order
1115         # to save needing to create a bound method object and so mypyc can
1116         # generate a native call to visit_default.
1117         visitf = getattr(self, f"visit_{name}", None)
1118         if visitf:
1119             yield from visitf(node)
1120         else:
1121             yield from self.visit_default(node)
1122
1123     def visit_default(self, node: LN) -> Iterator[T]:
1124         """Default `visit_*()` implementation. Recurses to children of `node`."""
1125         if isinstance(node, Node):
1126             for child in node.children:
1127                 yield from self.visit(child)
1128
1129
1130 @dataclass
1131 class DebugVisitor(Visitor[T]):
1132     tree_depth: int = 0
1133
1134     def visit_default(self, node: LN) -> Iterator[T]:
1135         indent = " " * (2 * self.tree_depth)
1136         if isinstance(node, Node):
1137             _type = type_repr(node.type)
1138             out(f"{indent}{_type}", fg="yellow")
1139             self.tree_depth += 1
1140             for child in node.children:
1141                 yield from self.visit(child)
1142
1143             self.tree_depth -= 1
1144             out(f"{indent}/{_type}", fg="yellow", bold=False)
1145         else:
1146             _type = token.tok_name.get(node.type, str(node.type))
1147             out(f"{indent}{_type}", fg="blue", nl=False)
1148             if node.prefix:
1149                 # We don't have to handle prefixes for `Node` objects since
1150                 # that delegates to the first child anyway.
1151                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1152             out(f" {node.value!r}", fg="blue", bold=False)
1153
1154     @classmethod
1155     def show(cls, code: Union[str, Leaf, Node]) -> None:
1156         """Pretty-print the lib2to3 AST of a given string of `code`.
1157
1158         Convenience method for debugging.
1159         """
1160         v: DebugVisitor[None] = DebugVisitor()
1161         if isinstance(code, str):
1162             code = lib2to3_parse(code)
1163         list(v.visit(code))
1164
1165
1166 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1167 STATEMENT: Final = {
1168     syms.if_stmt,
1169     syms.while_stmt,
1170     syms.for_stmt,
1171     syms.try_stmt,
1172     syms.except_clause,
1173     syms.with_stmt,
1174     syms.funcdef,
1175     syms.classdef,
1176 }
1177 STANDALONE_COMMENT: Final = 153
1178 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1179 LOGIC_OPERATORS: Final = {"and", "or"}
1180 COMPARATORS: Final = {
1181     token.LESS,
1182     token.GREATER,
1183     token.EQEQUAL,
1184     token.NOTEQUAL,
1185     token.LESSEQUAL,
1186     token.GREATEREQUAL,
1187 }
1188 MATH_OPERATORS: Final = {
1189     token.VBAR,
1190     token.CIRCUMFLEX,
1191     token.AMPER,
1192     token.LEFTSHIFT,
1193     token.RIGHTSHIFT,
1194     token.PLUS,
1195     token.MINUS,
1196     token.STAR,
1197     token.SLASH,
1198     token.DOUBLESLASH,
1199     token.PERCENT,
1200     token.AT,
1201     token.TILDE,
1202     token.DOUBLESTAR,
1203 }
1204 STARS: Final = {token.STAR, token.DOUBLESTAR}
1205 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1206 VARARGS_PARENTS: Final = {
1207     syms.arglist,
1208     syms.argument,  # double star in arglist
1209     syms.trailer,  # single argument to call
1210     syms.typedargslist,
1211     syms.varargslist,  # lambdas
1212 }
1213 UNPACKING_PARENTS: Final = {
1214     syms.atom,  # single element of a list or set literal
1215     syms.dictsetmaker,
1216     syms.listmaker,
1217     syms.testlist_gexp,
1218     syms.testlist_star_expr,
1219 }
1220 TEST_DESCENDANTS: Final = {
1221     syms.test,
1222     syms.lambdef,
1223     syms.or_test,
1224     syms.and_test,
1225     syms.not_test,
1226     syms.comparison,
1227     syms.star_expr,
1228     syms.expr,
1229     syms.xor_expr,
1230     syms.and_expr,
1231     syms.shift_expr,
1232     syms.arith_expr,
1233     syms.trailer,
1234     syms.term,
1235     syms.power,
1236 }
1237 ASSIGNMENTS: Final = {
1238     "=",
1239     "+=",
1240     "-=",
1241     "*=",
1242     "@=",
1243     "/=",
1244     "%=",
1245     "&=",
1246     "|=",
1247     "^=",
1248     "<<=",
1249     ">>=",
1250     "**=",
1251     "//=",
1252 }
1253 COMPREHENSION_PRIORITY: Final = 20
1254 COMMA_PRIORITY: Final = 18
1255 TERNARY_PRIORITY: Final = 16
1256 LOGIC_PRIORITY: Final = 14
1257 STRING_PRIORITY: Final = 12
1258 COMPARATOR_PRIORITY: Final = 10
1259 MATH_PRIORITIES: Final = {
1260     token.VBAR: 9,
1261     token.CIRCUMFLEX: 8,
1262     token.AMPER: 7,
1263     token.LEFTSHIFT: 6,
1264     token.RIGHTSHIFT: 6,
1265     token.PLUS: 5,
1266     token.MINUS: 5,
1267     token.STAR: 4,
1268     token.SLASH: 4,
1269     token.DOUBLESLASH: 4,
1270     token.PERCENT: 4,
1271     token.AT: 4,
1272     token.TILDE: 3,
1273     token.DOUBLESTAR: 2,
1274 }
1275 DOT_PRIORITY: Final = 1
1276
1277
1278 @dataclass
1279 class BracketTracker:
1280     """Keeps track of brackets on a line."""
1281
1282     depth: int = 0
1283     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1284     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1285     previous: Optional[Leaf] = None
1286     _for_loop_depths: List[int] = field(default_factory=list)
1287     _lambda_argument_depths: List[int] = field(default_factory=list)
1288     invisible: List[Leaf] = field(default_factory=list)
1289
1290     def mark(self, leaf: Leaf) -> None:
1291         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1292
1293         All leaves receive an int `bracket_depth` field that stores how deep
1294         within brackets a given leaf is. 0 means there are no enclosing brackets
1295         that started on this line.
1296
1297         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1298         field that it forms a pair with. This is a one-directional link to
1299         avoid reference cycles.
1300
1301         If a leaf is a delimiter (a token on which Black can split the line if
1302         needed) and it's on depth 0, its `id()` is stored in the tracker's
1303         `delimiters` field.
1304         """
1305         if leaf.type == token.COMMENT:
1306             return
1307
1308         self.maybe_decrement_after_for_loop_variable(leaf)
1309         self.maybe_decrement_after_lambda_arguments(leaf)
1310         if leaf.type in CLOSING_BRACKETS:
1311             self.depth -= 1
1312             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1313             leaf.opening_bracket = opening_bracket
1314             if not leaf.value:
1315                 self.invisible.append(leaf)
1316         leaf.bracket_depth = self.depth
1317         if self.depth == 0:
1318             delim = is_split_before_delimiter(leaf, self.previous)
1319             if delim and self.previous is not None:
1320                 self.delimiters[id(self.previous)] = delim
1321             else:
1322                 delim = is_split_after_delimiter(leaf, self.previous)
1323                 if delim:
1324                     self.delimiters[id(leaf)] = delim
1325         if leaf.type in OPENING_BRACKETS:
1326             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1327             self.depth += 1
1328             if not leaf.value:
1329                 self.invisible.append(leaf)
1330         self.previous = leaf
1331         self.maybe_increment_lambda_arguments(leaf)
1332         self.maybe_increment_for_loop_variable(leaf)
1333
1334     def any_open_brackets(self) -> bool:
1335         """Return True if there is an yet unmatched open bracket on the line."""
1336         return bool(self.bracket_match)
1337
1338     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1339         """Return the highest priority of a delimiter found on the line.
1340
1341         Values are consistent with what `is_split_*_delimiter()` return.
1342         Raises ValueError on no delimiters.
1343         """
1344         return max(v for k, v in self.delimiters.items() if k not in exclude)
1345
1346     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1347         """Return the number of delimiters with the given `priority`.
1348
1349         If no `priority` is passed, defaults to max priority on the line.
1350         """
1351         if not self.delimiters:
1352             return 0
1353
1354         priority = priority or self.max_delimiter_priority()
1355         return sum(1 for p in self.delimiters.values() if p == priority)
1356
1357     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1358         """In a for loop, or comprehension, the variables are often unpacks.
1359
1360         To avoid splitting on the comma in this situation, increase the depth of
1361         tokens between `for` and `in`.
1362         """
1363         if leaf.type == token.NAME and leaf.value == "for":
1364             self.depth += 1
1365             self._for_loop_depths.append(self.depth)
1366             return True
1367
1368         return False
1369
1370     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1371         """See `maybe_increment_for_loop_variable` above for explanation."""
1372         if (
1373             self._for_loop_depths
1374             and self._for_loop_depths[-1] == self.depth
1375             and leaf.type == token.NAME
1376             and leaf.value == "in"
1377         ):
1378             self.depth -= 1
1379             self._for_loop_depths.pop()
1380             return True
1381
1382         return False
1383
1384     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1385         """In a lambda expression, there might be more than one argument.
1386
1387         To avoid splitting on the comma in this situation, increase the depth of
1388         tokens between `lambda` and `:`.
1389         """
1390         if leaf.type == token.NAME and leaf.value == "lambda":
1391             self.depth += 1
1392             self._lambda_argument_depths.append(self.depth)
1393             return True
1394
1395         return False
1396
1397     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1398         """See `maybe_increment_lambda_arguments` above for explanation."""
1399         if (
1400             self._lambda_argument_depths
1401             and self._lambda_argument_depths[-1] == self.depth
1402             and leaf.type == token.COLON
1403         ):
1404             self.depth -= 1
1405             self._lambda_argument_depths.pop()
1406             return True
1407
1408         return False
1409
1410     def get_open_lsqb(self) -> Optional[Leaf]:
1411         """Return the most recent opening square bracket (if any)."""
1412         return self.bracket_match.get((self.depth - 1, token.RSQB))
1413
1414
1415 @dataclass
1416 class Line:
1417     """Holds leaves and comments. Can be printed with `str(line)`."""
1418
1419     depth: int = 0
1420     leaves: List[Leaf] = field(default_factory=list)
1421     # keys ordered like `leaves`
1422     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1423     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1424     inside_brackets: bool = False
1425     should_explode: bool = False
1426
1427     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1428         """Add a new `leaf` to the end of the line.
1429
1430         Unless `preformatted` is True, the `leaf` will receive a new consistent
1431         whitespace prefix and metadata applied by :class:`BracketTracker`.
1432         Trailing commas are maybe removed, unpacked for loop variables are
1433         demoted from being delimiters.
1434
1435         Inline comments are put aside.
1436         """
1437         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1438         if not has_value:
1439             return
1440
1441         if token.COLON == leaf.type and self.is_class_paren_empty:
1442             del self.leaves[-2:]
1443         if self.leaves and not preformatted:
1444             # Note: at this point leaf.prefix should be empty except for
1445             # imports, for which we only preserve newlines.
1446             leaf.prefix += whitespace(
1447                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1448             )
1449         if self.inside_brackets or not preformatted:
1450             self.bracket_tracker.mark(leaf)
1451             if self.maybe_should_explode(leaf):
1452                 self.should_explode = True
1453         if not self.append_comment(leaf):
1454             self.leaves.append(leaf)
1455
1456     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1457         """Like :func:`append()` but disallow invalid standalone comment structure.
1458
1459         Raises ValueError when any `leaf` is appended after a standalone comment
1460         or when a standalone comment is not the first leaf on the line.
1461         """
1462         if self.bracket_tracker.depth == 0:
1463             if self.is_comment:
1464                 raise ValueError("cannot append to standalone comments")
1465
1466             if self.leaves and leaf.type == STANDALONE_COMMENT:
1467                 raise ValueError(
1468                     "cannot append standalone comments to a populated line"
1469                 )
1470
1471         self.append(leaf, preformatted=preformatted)
1472
1473     @property
1474     def is_comment(self) -> bool:
1475         """Is this line a standalone comment?"""
1476         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1477
1478     @property
1479     def is_decorator(self) -> bool:
1480         """Is this line a decorator?"""
1481         return bool(self) and self.leaves[0].type == token.AT
1482
1483     @property
1484     def is_import(self) -> bool:
1485         """Is this an import line?"""
1486         return bool(self) and is_import(self.leaves[0])
1487
1488     @property
1489     def is_class(self) -> bool:
1490         """Is this line a class definition?"""
1491         return (
1492             bool(self)
1493             and self.leaves[0].type == token.NAME
1494             and self.leaves[0].value == "class"
1495         )
1496
1497     @property
1498     def is_stub_class(self) -> bool:
1499         """Is this line a class definition with a body consisting only of "..."?"""
1500         return self.is_class and self.leaves[-3:] == [
1501             Leaf(token.DOT, ".") for _ in range(3)
1502         ]
1503
1504     @property
1505     def is_def(self) -> bool:
1506         """Is this a function definition? (Also returns True for async defs.)"""
1507         try:
1508             first_leaf = self.leaves[0]
1509         except IndexError:
1510             return False
1511
1512         try:
1513             second_leaf: Optional[Leaf] = self.leaves[1]
1514         except IndexError:
1515             second_leaf = None
1516         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1517             first_leaf.type == token.ASYNC
1518             and second_leaf is not None
1519             and second_leaf.type == token.NAME
1520             and second_leaf.value == "def"
1521         )
1522
1523     @property
1524     def is_class_paren_empty(self) -> bool:
1525         """Is this a class with no base classes but using parentheses?
1526
1527         Those are unnecessary and should be removed.
1528         """
1529         return (
1530             bool(self)
1531             and len(self.leaves) == 4
1532             and self.is_class
1533             and self.leaves[2].type == token.LPAR
1534             and self.leaves[2].value == "("
1535             and self.leaves[3].type == token.RPAR
1536             and self.leaves[3].value == ")"
1537         )
1538
1539     @property
1540     def is_triple_quoted_string(self) -> bool:
1541         """Is the line a triple quoted string?"""
1542         return (
1543             bool(self)
1544             and self.leaves[0].type == token.STRING
1545             and self.leaves[0].value.startswith(('"""', "'''"))
1546         )
1547
1548     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1549         """If so, needs to be split before emitting."""
1550         for leaf in self.leaves:
1551             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1552                 return True
1553
1554         return False
1555
1556     def contains_uncollapsable_type_comments(self) -> bool:
1557         ignored_ids = set()
1558         try:
1559             last_leaf = self.leaves[-1]
1560             ignored_ids.add(id(last_leaf))
1561             if last_leaf.type == token.COMMA or (
1562                 last_leaf.type == token.RPAR and not last_leaf.value
1563             ):
1564                 # When trailing commas or optional parens are inserted by Black for
1565                 # consistency, comments after the previous last element are not moved
1566                 # (they don't have to, rendering will still be correct).  So we ignore
1567                 # trailing commas and invisible.
1568                 last_leaf = self.leaves[-2]
1569                 ignored_ids.add(id(last_leaf))
1570         except IndexError:
1571             return False
1572
1573         # A type comment is uncollapsable if it is attached to a leaf
1574         # that isn't at the end of the line (since that could cause it
1575         # to get associated to a different argument) or if there are
1576         # comments before it (since that could cause it to get hidden
1577         # behind a comment.
1578         comment_seen = False
1579         for leaf_id, comments in self.comments.items():
1580             for comment in comments:
1581                 if is_type_comment(comment):
1582                     if comment_seen or (
1583                         not is_type_comment(comment, " ignore")
1584                         and leaf_id not in ignored_ids
1585                     ):
1586                         return True
1587
1588                 comment_seen = True
1589
1590         return False
1591
1592     def contains_unsplittable_type_ignore(self) -> bool:
1593         if not self.leaves:
1594             return False
1595
1596         # If a 'type: ignore' is attached to the end of a line, we
1597         # can't split the line, because we can't know which of the
1598         # subexpressions the ignore was meant to apply to.
1599         #
1600         # We only want this to apply to actual physical lines from the
1601         # original source, though: we don't want the presence of a
1602         # 'type: ignore' at the end of a multiline expression to
1603         # justify pushing it all onto one line. Thus we
1604         # (unfortunately) need to check the actual source lines and
1605         # only report an unsplittable 'type: ignore' if this line was
1606         # one line in the original code.
1607
1608         # Grab the first and last line numbers, skipping generated leaves
1609         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1610         last_line = next(
1611             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1612         )
1613
1614         if first_line == last_line:
1615             # We look at the last two leaves since a comma or an
1616             # invisible paren could have been added at the end of the
1617             # line.
1618             for node in self.leaves[-2:]:
1619                 for comment in self.comments.get(id(node), []):
1620                     if is_type_comment(comment, " ignore"):
1621                         return True
1622
1623         return False
1624
1625     def contains_multiline_strings(self) -> bool:
1626         return any(is_multiline_string(leaf) for leaf in self.leaves)
1627
1628     def maybe_should_explode(self, closing: Leaf) -> bool:
1629         """Return True if this line should explode (always be split), that is when:
1630         - there's a pre-existing trailing comma here; and
1631         - it's not a one-tuple.
1632         """
1633         if not (
1634             closing.type in CLOSING_BRACKETS
1635             and self.leaves
1636             and self.leaves[-1].type == token.COMMA
1637             and not self.leaves[-1].was_checked  # pre-existing
1638         ):
1639             return False
1640
1641         if closing.type in {token.RBRACE, token.RSQB}:
1642             return True
1643
1644         if self.is_import:
1645             return True
1646
1647         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1648             return True
1649
1650         return False
1651
1652     def append_comment(self, comment: Leaf) -> bool:
1653         """Add an inline or standalone comment to the line."""
1654         if (
1655             comment.type == STANDALONE_COMMENT
1656             and self.bracket_tracker.any_open_brackets()
1657         ):
1658             comment.prefix = ""
1659             return False
1660
1661         if comment.type != token.COMMENT:
1662             return False
1663
1664         if not self.leaves:
1665             comment.type = STANDALONE_COMMENT
1666             comment.prefix = ""
1667             return False
1668
1669         last_leaf = self.leaves[-1]
1670         if (
1671             last_leaf.type == token.RPAR
1672             and not last_leaf.value
1673             and last_leaf.parent
1674             and len(list(last_leaf.parent.leaves())) <= 3
1675             and not is_type_comment(comment)
1676         ):
1677             # Comments on an optional parens wrapping a single leaf should belong to
1678             # the wrapped node except if it's a type comment. Pinning the comment like
1679             # this avoids unstable formatting caused by comment migration.
1680             if len(self.leaves) < 2:
1681                 comment.type = STANDALONE_COMMENT
1682                 comment.prefix = ""
1683                 return False
1684
1685             last_leaf = self.leaves[-2]
1686         self.comments.setdefault(id(last_leaf), []).append(comment)
1687         return True
1688
1689     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1690         """Generate comments that should appear directly after `leaf`."""
1691         return self.comments.get(id(leaf), [])
1692
1693     def remove_trailing_comma(self) -> None:
1694         """Remove the trailing comma and moves the comments attached to it."""
1695         trailing_comma = self.leaves.pop()
1696         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1697         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1698             trailing_comma_comments
1699         )
1700
1701     def is_complex_subscript(self, leaf: Leaf) -> bool:
1702         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1703         open_lsqb = self.bracket_tracker.get_open_lsqb()
1704         if open_lsqb is None:
1705             return False
1706
1707         subscript_start = open_lsqb.next_sibling
1708
1709         if isinstance(subscript_start, Node):
1710             if subscript_start.type == syms.listmaker:
1711                 return False
1712
1713             if subscript_start.type == syms.subscriptlist:
1714                 subscript_start = child_towards(subscript_start, leaf)
1715         return subscript_start is not None and any(
1716             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1717         )
1718
1719     def clone(self) -> "Line":
1720         return Line(
1721             depth=self.depth,
1722             inside_brackets=self.inside_brackets,
1723             should_explode=self.should_explode,
1724         )
1725
1726     def __str__(self) -> str:
1727         """Render the line."""
1728         if not self:
1729             return "\n"
1730
1731         indent = "    " * self.depth
1732         leaves = iter(self.leaves)
1733         first = next(leaves)
1734         res = f"{first.prefix}{indent}{first.value}"
1735         for leaf in leaves:
1736             res += str(leaf)
1737         for comment in itertools.chain.from_iterable(self.comments.values()):
1738             res += str(comment)
1739
1740         return res + "\n"
1741
1742     def __bool__(self) -> bool:
1743         """Return True if the line has leaves or comments."""
1744         return bool(self.leaves or self.comments)
1745
1746
1747 @dataclass
1748 class EmptyLineTracker:
1749     """Provides a stateful method that returns the number of potential extra
1750     empty lines needed before and after the currently processed line.
1751
1752     Note: this tracker works on lines that haven't been split yet.  It assumes
1753     the prefix of the first leaf consists of optional newlines.  Those newlines
1754     are consumed by `maybe_empty_lines()` and included in the computation.
1755     """
1756
1757     is_pyi: bool = False
1758     previous_line: Optional[Line] = None
1759     previous_after: int = 0
1760     previous_defs: List[int] = field(default_factory=list)
1761
1762     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1763         """Return the number of extra empty lines before and after the `current_line`.
1764
1765         This is for separating `def`, `async def` and `class` with extra empty
1766         lines (two on module-level).
1767         """
1768         before, after = self._maybe_empty_lines(current_line)
1769         before = (
1770             # Black should not insert empty lines at the beginning
1771             # of the file
1772             0
1773             if self.previous_line is None
1774             else before - self.previous_after
1775         )
1776         self.previous_after = after
1777         self.previous_line = current_line
1778         return before, after
1779
1780     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1781         max_allowed = 1
1782         if current_line.depth == 0:
1783             max_allowed = 1 if self.is_pyi else 2
1784         if current_line.leaves:
1785             # Consume the first leaf's extra newlines.
1786             first_leaf = current_line.leaves[0]
1787             before = first_leaf.prefix.count("\n")
1788             before = min(before, max_allowed)
1789             first_leaf.prefix = ""
1790         else:
1791             before = 0
1792         depth = current_line.depth
1793         while self.previous_defs and self.previous_defs[-1] >= depth:
1794             self.previous_defs.pop()
1795             if self.is_pyi:
1796                 before = 0 if depth else 1
1797             else:
1798                 before = 1 if depth else 2
1799         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1800             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1801
1802         if (
1803             self.previous_line
1804             and self.previous_line.is_import
1805             and not current_line.is_import
1806             and depth == self.previous_line.depth
1807         ):
1808             return (before or 1), 0
1809
1810         if (
1811             self.previous_line
1812             and self.previous_line.is_class
1813             and current_line.is_triple_quoted_string
1814         ):
1815             return before, 1
1816
1817         return before, 0
1818
1819     def _maybe_empty_lines_for_class_or_def(
1820         self, current_line: Line, before: int
1821     ) -> Tuple[int, int]:
1822         if not current_line.is_decorator:
1823             self.previous_defs.append(current_line.depth)
1824         if self.previous_line is None:
1825             # Don't insert empty lines before the first line in the file.
1826             return 0, 0
1827
1828         if self.previous_line.is_decorator:
1829             return 0, 0
1830
1831         if self.previous_line.depth < current_line.depth and (
1832             self.previous_line.is_class or self.previous_line.is_def
1833         ):
1834             return 0, 0
1835
1836         if (
1837             self.previous_line.is_comment
1838             and self.previous_line.depth == current_line.depth
1839             and before == 0
1840         ):
1841             return 0, 0
1842
1843         if self.is_pyi:
1844             if self.previous_line.depth > current_line.depth:
1845                 newlines = 1
1846             elif current_line.is_class or self.previous_line.is_class:
1847                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1848                     # No blank line between classes with an empty body
1849                     newlines = 0
1850                 else:
1851                     newlines = 1
1852             elif current_line.is_def and not self.previous_line.is_def:
1853                 # Blank line between a block of functions and a block of non-functions
1854                 newlines = 1
1855             else:
1856                 newlines = 0
1857         else:
1858             newlines = 2
1859         if current_line.depth and newlines:
1860             newlines -= 1
1861         return newlines, 0
1862
1863
1864 @dataclass
1865 class LineGenerator(Visitor[Line]):
1866     """Generates reformatted Line objects.  Empty lines are not emitted.
1867
1868     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1869     in ways that will no longer stringify to valid Python code on the tree.
1870     """
1871
1872     is_pyi: bool = False
1873     normalize_strings: bool = True
1874     current_line: Line = field(default_factory=Line)
1875     remove_u_prefix: bool = False
1876
1877     def line(self, indent: int = 0) -> Iterator[Line]:
1878         """Generate a line.
1879
1880         If the line is empty, only emit if it makes sense.
1881         If the line is too long, split it first and then generate.
1882
1883         If any lines were generated, set up a new current_line.
1884         """
1885         if not self.current_line:
1886             self.current_line.depth += indent
1887             return  # Line is empty, don't emit. Creating a new one unnecessary.
1888
1889         complete_line = self.current_line
1890         self.current_line = Line(depth=complete_line.depth + indent)
1891         yield complete_line
1892
1893     def visit_default(self, node: LN) -> Iterator[Line]:
1894         """Default `visit_*()` implementation. Recurses to children of `node`."""
1895         if isinstance(node, Leaf):
1896             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1897             for comment in generate_comments(node):
1898                 if any_open_brackets:
1899                     # any comment within brackets is subject to splitting
1900                     self.current_line.append(comment)
1901                 elif comment.type == token.COMMENT:
1902                     # regular trailing comment
1903                     self.current_line.append(comment)
1904                     yield from self.line()
1905
1906                 else:
1907                     # regular standalone comment
1908                     yield from self.line()
1909
1910                     self.current_line.append(comment)
1911                     yield from self.line()
1912
1913             normalize_prefix(node, inside_brackets=any_open_brackets)
1914             if self.normalize_strings and node.type == token.STRING:
1915                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1916                 normalize_string_quotes(node)
1917             if node.type == token.NUMBER:
1918                 normalize_numeric_literal(node)
1919             if node.type not in WHITESPACE:
1920                 self.current_line.append(node)
1921         yield from super().visit_default(node)
1922
1923     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1924         """Increase indentation level, maybe yield a line."""
1925         # In blib2to3 INDENT never holds comments.
1926         yield from self.line(+1)
1927         yield from self.visit_default(node)
1928
1929     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1930         """Decrease indentation level, maybe yield a line."""
1931         # The current line might still wait for trailing comments.  At DEDENT time
1932         # there won't be any (they would be prefixes on the preceding NEWLINE).
1933         # Emit the line then.
1934         yield from self.line()
1935
1936         # While DEDENT has no value, its prefix may contain standalone comments
1937         # that belong to the current indentation level.  Get 'em.
1938         yield from self.visit_default(node)
1939
1940         # Finally, emit the dedent.
1941         yield from self.line(-1)
1942
1943     def visit_stmt(
1944         self, node: Node, keywords: Set[str], parens: Set[str]
1945     ) -> Iterator[Line]:
1946         """Visit a statement.
1947
1948         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1949         `def`, `with`, `class`, `assert` and assignments.
1950
1951         The relevant Python language `keywords` for a given statement will be
1952         NAME leaves within it. This methods puts those on a separate line.
1953
1954         `parens` holds a set of string leaf values immediately after which
1955         invisible parens should be put.
1956         """
1957         normalize_invisible_parens(node, parens_after=parens)
1958         for child in node.children:
1959             if child.type == token.NAME and child.value in keywords:  # type: ignore
1960                 yield from self.line()
1961
1962             yield from self.visit(child)
1963
1964     def visit_suite(self, node: Node) -> Iterator[Line]:
1965         """Visit a suite."""
1966         if self.is_pyi and is_stub_suite(node):
1967             yield from self.visit(node.children[2])
1968         else:
1969             yield from self.visit_default(node)
1970
1971     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1972         """Visit a statement without nested statements."""
1973         is_suite_like = node.parent and node.parent.type in STATEMENT
1974         if is_suite_like:
1975             if self.is_pyi and is_stub_body(node):
1976                 yield from self.visit_default(node)
1977             else:
1978                 yield from self.line(+1)
1979                 yield from self.visit_default(node)
1980                 yield from self.line(-1)
1981
1982         else:
1983             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1984                 yield from self.line()
1985             yield from self.visit_default(node)
1986
1987     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1988         """Visit `async def`, `async for`, `async with`."""
1989         yield from self.line()
1990
1991         children = iter(node.children)
1992         for child in children:
1993             yield from self.visit(child)
1994
1995             if child.type == token.ASYNC:
1996                 break
1997
1998         internal_stmt = next(children)
1999         for child in internal_stmt.children:
2000             yield from self.visit(child)
2001
2002     def visit_decorators(self, node: Node) -> Iterator[Line]:
2003         """Visit decorators."""
2004         for child in node.children:
2005             yield from self.line()
2006             yield from self.visit(child)
2007
2008     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2009         """Remove a semicolon and put the other statement on a separate line."""
2010         yield from self.line()
2011
2012     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2013         """End of file. Process outstanding comments and end with a newline."""
2014         yield from self.visit_default(leaf)
2015         yield from self.line()
2016
2017     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2018         if not self.current_line.bracket_tracker.any_open_brackets():
2019             yield from self.line()
2020         yield from self.visit_default(leaf)
2021
2022     def visit_factor(self, node: Node) -> Iterator[Line]:
2023         """Force parentheses between a unary op and a binary power:
2024
2025         -2 ** 8 -> -(2 ** 8)
2026         """
2027         _operator, operand = node.children
2028         if (
2029             operand.type == syms.power
2030             and len(operand.children) == 3
2031             and operand.children[1].type == token.DOUBLESTAR
2032         ):
2033             lpar = Leaf(token.LPAR, "(")
2034             rpar = Leaf(token.RPAR, ")")
2035             index = operand.remove() or 0
2036             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2037         yield from self.visit_default(node)
2038
2039     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2040         if is_docstring(leaf) and "\\\n" not in leaf.value:
2041             # We're ignoring docstrings with backslash newline escapes because changing
2042             # indentation of those changes the AST representation of the code.
2043             prefix = get_string_prefix(leaf.value)
2044             lead_len = len(prefix) + 3
2045             tail_len = -3
2046             indent = " " * 4 * self.current_line.depth
2047             docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
2048             if docstring:
2049                 if leaf.value[lead_len - 1] == docstring[0]:
2050                     docstring = " " + docstring
2051                 if leaf.value[tail_len + 1] == docstring[-1]:
2052                     docstring = docstring + " "
2053             leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
2054             normalize_string_quotes(leaf)
2055
2056         yield from self.visit_default(leaf)
2057
2058     def __post_init__(self) -> None:
2059         """You are in a twisty little maze of passages."""
2060         v = self.visit_stmt
2061         Ø: Set[str] = set()
2062         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2063         self.visit_if_stmt = partial(
2064             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2065         )
2066         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2067         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2068         self.visit_try_stmt = partial(
2069             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2070         )
2071         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2072         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2073         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2074         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2075         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2076         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2077         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2078         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2079         self.visit_async_funcdef = self.visit_async_stmt
2080         self.visit_decorated = self.visit_decorators
2081
2082
2083 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2084 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2085 OPENING_BRACKETS = set(BRACKET.keys())
2086 CLOSING_BRACKETS = set(BRACKET.values())
2087 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2088 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2089
2090
2091 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2092     """Return whitespace prefix if needed for the given `leaf`.
2093
2094     `complex_subscript` signals whether the given leaf is part of a subscription
2095     which has non-trivial arguments, like arithmetic expressions or function calls.
2096     """
2097     NO = ""
2098     SPACE = " "
2099     DOUBLESPACE = "  "
2100     t = leaf.type
2101     p = leaf.parent
2102     v = leaf.value
2103     if t in ALWAYS_NO_SPACE:
2104         return NO
2105
2106     if t == token.COMMENT:
2107         return DOUBLESPACE
2108
2109     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2110     if t == token.COLON and p.type not in {
2111         syms.subscript,
2112         syms.subscriptlist,
2113         syms.sliceop,
2114     }:
2115         return NO
2116
2117     prev = leaf.prev_sibling
2118     if not prev:
2119         prevp = preceding_leaf(p)
2120         if not prevp or prevp.type in OPENING_BRACKETS:
2121             return NO
2122
2123         if t == token.COLON:
2124             if prevp.type == token.COLON:
2125                 return NO
2126
2127             elif prevp.type != token.COMMA and not complex_subscript:
2128                 return NO
2129
2130             return SPACE
2131
2132         if prevp.type == token.EQUAL:
2133             if prevp.parent:
2134                 if prevp.parent.type in {
2135                     syms.arglist,
2136                     syms.argument,
2137                     syms.parameters,
2138                     syms.varargslist,
2139                 }:
2140                     return NO
2141
2142                 elif prevp.parent.type == syms.typedargslist:
2143                     # A bit hacky: if the equal sign has whitespace, it means we
2144                     # previously found it's a typed argument.  So, we're using
2145                     # that, too.
2146                     return prevp.prefix
2147
2148         elif prevp.type in VARARGS_SPECIALS:
2149             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2150                 return NO
2151
2152         elif prevp.type == token.COLON:
2153             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2154                 return SPACE if complex_subscript else NO
2155
2156         elif (
2157             prevp.parent
2158             and prevp.parent.type == syms.factor
2159             and prevp.type in MATH_OPERATORS
2160         ):
2161             return NO
2162
2163         elif (
2164             prevp.type == token.RIGHTSHIFT
2165             and prevp.parent
2166             and prevp.parent.type == syms.shift_expr
2167             and prevp.prev_sibling
2168             and prevp.prev_sibling.type == token.NAME
2169             and prevp.prev_sibling.value == "print"  # type: ignore
2170         ):
2171             # Python 2 print chevron
2172             return NO
2173
2174     elif prev.type in OPENING_BRACKETS:
2175         return NO
2176
2177     if p.type in {syms.parameters, syms.arglist}:
2178         # untyped function signatures or calls
2179         if not prev or prev.type != token.COMMA:
2180             return NO
2181
2182     elif p.type == syms.varargslist:
2183         # lambdas
2184         if prev and prev.type != token.COMMA:
2185             return NO
2186
2187     elif p.type == syms.typedargslist:
2188         # typed function signatures
2189         if not prev:
2190             return NO
2191
2192         if t == token.EQUAL:
2193             if prev.type != syms.tname:
2194                 return NO
2195
2196         elif prev.type == token.EQUAL:
2197             # A bit hacky: if the equal sign has whitespace, it means we
2198             # previously found it's a typed argument.  So, we're using that, too.
2199             return prev.prefix
2200
2201         elif prev.type != token.COMMA:
2202             return NO
2203
2204     elif p.type == syms.tname:
2205         # type names
2206         if not prev:
2207             prevp = preceding_leaf(p)
2208             if not prevp or prevp.type != token.COMMA:
2209                 return NO
2210
2211     elif p.type == syms.trailer:
2212         # attributes and calls
2213         if t == token.LPAR or t == token.RPAR:
2214             return NO
2215
2216         if not prev:
2217             if t == token.DOT:
2218                 prevp = preceding_leaf(p)
2219                 if not prevp or prevp.type != token.NUMBER:
2220                     return NO
2221
2222             elif t == token.LSQB:
2223                 return NO
2224
2225         elif prev.type != token.COMMA:
2226             return NO
2227
2228     elif p.type == syms.argument:
2229         # single argument
2230         if t == token.EQUAL:
2231             return NO
2232
2233         if not prev:
2234             prevp = preceding_leaf(p)
2235             if not prevp or prevp.type == token.LPAR:
2236                 return NO
2237
2238         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2239             return NO
2240
2241     elif p.type == syms.decorator:
2242         # decorators
2243         return NO
2244
2245     elif p.type == syms.dotted_name:
2246         if prev:
2247             return NO
2248
2249         prevp = preceding_leaf(p)
2250         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2251             return NO
2252
2253     elif p.type == syms.classdef:
2254         if t == token.LPAR:
2255             return NO
2256
2257         if prev and prev.type == token.LPAR:
2258             return NO
2259
2260     elif p.type in {syms.subscript, syms.sliceop}:
2261         # indexing
2262         if not prev:
2263             assert p.parent is not None, "subscripts are always parented"
2264             if p.parent.type == syms.subscriptlist:
2265                 return SPACE
2266
2267             return NO
2268
2269         elif not complex_subscript:
2270             return NO
2271
2272     elif p.type == syms.atom:
2273         if prev and t == token.DOT:
2274             # dots, but not the first one.
2275             return NO
2276
2277     elif p.type == syms.dictsetmaker:
2278         # dict unpacking
2279         if prev and prev.type == token.DOUBLESTAR:
2280             return NO
2281
2282     elif p.type in {syms.factor, syms.star_expr}:
2283         # unary ops
2284         if not prev:
2285             prevp = preceding_leaf(p)
2286             if not prevp or prevp.type in OPENING_BRACKETS:
2287                 return NO
2288
2289             prevp_parent = prevp.parent
2290             assert prevp_parent is not None
2291             if prevp.type == token.COLON and prevp_parent.type in {
2292                 syms.subscript,
2293                 syms.sliceop,
2294             }:
2295                 return NO
2296
2297             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2298                 return NO
2299
2300         elif t in {token.NAME, token.NUMBER, token.STRING}:
2301             return NO
2302
2303     elif p.type == syms.import_from:
2304         if t == token.DOT:
2305             if prev and prev.type == token.DOT:
2306                 return NO
2307
2308         elif t == token.NAME:
2309             if v == "import":
2310                 return SPACE
2311
2312             if prev and prev.type == token.DOT:
2313                 return NO
2314
2315     elif p.type == syms.sliceop:
2316         return NO
2317
2318     return SPACE
2319
2320
2321 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2322     """Return the first leaf that precedes `node`, if any."""
2323     while node:
2324         res = node.prev_sibling
2325         if res:
2326             if isinstance(res, Leaf):
2327                 return res
2328
2329             try:
2330                 return list(res.leaves())[-1]
2331
2332             except IndexError:
2333                 return None
2334
2335         node = node.parent
2336     return None
2337
2338
2339 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2340     """Return if the `node` and its previous siblings match types against the provided
2341     list of tokens; the provided `node`has its type matched against the last element in
2342     the list.  `None` can be used as the first element to declare that the start of the
2343     list is anchored at the start of its parent's children."""
2344     if not tokens:
2345         return True
2346     if tokens[-1] is None:
2347         return node is None
2348     if not node:
2349         return False
2350     if node.type != tokens[-1]:
2351         return False
2352     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2353
2354
2355 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2356     """Return the child of `ancestor` that contains `descendant`."""
2357     node: Optional[LN] = descendant
2358     while node and node.parent != ancestor:
2359         node = node.parent
2360     return node
2361
2362
2363 def container_of(leaf: Leaf) -> LN:
2364     """Return `leaf` or one of its ancestors that is the topmost container of it.
2365
2366     By "container" we mean a node where `leaf` is the very first child.
2367     """
2368     same_prefix = leaf.prefix
2369     container: LN = leaf
2370     while container:
2371         parent = container.parent
2372         if parent is None:
2373             break
2374
2375         if parent.children[0].prefix != same_prefix:
2376             break
2377
2378         if parent.type == syms.file_input:
2379             break
2380
2381         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2382             break
2383
2384         container = parent
2385     return container
2386
2387
2388 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2389     """Return the priority of the `leaf` delimiter, given a line break after it.
2390
2391     The delimiter priorities returned here are from those delimiters that would
2392     cause a line break after themselves.
2393
2394     Higher numbers are higher priority.
2395     """
2396     if leaf.type == token.COMMA:
2397         return COMMA_PRIORITY
2398
2399     return 0
2400
2401
2402 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2403     """Return the priority of the `leaf` delimiter, given a line break before it.
2404
2405     The delimiter priorities returned here are from those delimiters that would
2406     cause a line break before themselves.
2407
2408     Higher numbers are higher priority.
2409     """
2410     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2411         # * and ** might also be MATH_OPERATORS but in this case they are not.
2412         # Don't treat them as a delimiter.
2413         return 0
2414
2415     if (
2416         leaf.type == token.DOT
2417         and leaf.parent
2418         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2419         and (previous is None or previous.type in CLOSING_BRACKETS)
2420     ):
2421         return DOT_PRIORITY
2422
2423     if (
2424         leaf.type in MATH_OPERATORS
2425         and leaf.parent
2426         and leaf.parent.type not in {syms.factor, syms.star_expr}
2427     ):
2428         return MATH_PRIORITIES[leaf.type]
2429
2430     if leaf.type in COMPARATORS:
2431         return COMPARATOR_PRIORITY
2432
2433     if (
2434         leaf.type == token.STRING
2435         and previous is not None
2436         and previous.type == token.STRING
2437     ):
2438         return STRING_PRIORITY
2439
2440     if leaf.type not in {token.NAME, token.ASYNC}:
2441         return 0
2442
2443     if (
2444         leaf.value == "for"
2445         and leaf.parent
2446         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2447         or leaf.type == token.ASYNC
2448     ):
2449         if (
2450             not isinstance(leaf.prev_sibling, Leaf)
2451             or leaf.prev_sibling.value != "async"
2452         ):
2453             return COMPREHENSION_PRIORITY
2454
2455     if (
2456         leaf.value == "if"
2457         and leaf.parent
2458         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2459     ):
2460         return COMPREHENSION_PRIORITY
2461
2462     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2463         return TERNARY_PRIORITY
2464
2465     if leaf.value == "is":
2466         return COMPARATOR_PRIORITY
2467
2468     if (
2469         leaf.value == "in"
2470         and leaf.parent
2471         and leaf.parent.type in {syms.comp_op, syms.comparison}
2472         and not (
2473             previous is not None
2474             and previous.type == token.NAME
2475             and previous.value == "not"
2476         )
2477     ):
2478         return COMPARATOR_PRIORITY
2479
2480     if (
2481         leaf.value == "not"
2482         and leaf.parent
2483         and leaf.parent.type == syms.comp_op
2484         and not (
2485             previous is not None
2486             and previous.type == token.NAME
2487             and previous.value == "is"
2488         )
2489     ):
2490         return COMPARATOR_PRIORITY
2491
2492     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2493         return LOGIC_PRIORITY
2494
2495     return 0
2496
2497
2498 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2499 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2500
2501
2502 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2503     """Clean the prefix of the `leaf` and generate comments from it, if any.
2504
2505     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2506     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2507     move because it does away with modifying the grammar to include all the
2508     possible places in which comments can be placed.
2509
2510     The sad consequence for us though is that comments don't "belong" anywhere.
2511     This is why this function generates simple parentless Leaf objects for
2512     comments.  We simply don't know what the correct parent should be.
2513
2514     No matter though, we can live without this.  We really only need to
2515     differentiate between inline and standalone comments.  The latter don't
2516     share the line with any code.
2517
2518     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2519     are emitted with a fake STANDALONE_COMMENT token identifier.
2520     """
2521     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2522         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2523
2524
2525 @dataclass
2526 class ProtoComment:
2527     """Describes a piece of syntax that is a comment.
2528
2529     It's not a :class:`blib2to3.pytree.Leaf` so that:
2530
2531     * it can be cached (`Leaf` objects should not be reused more than once as
2532       they store their lineno, column, prefix, and parent information);
2533     * `newlines` and `consumed` fields are kept separate from the `value`. This
2534       simplifies handling of special marker comments like ``# fmt: off/on``.
2535     """
2536
2537     type: int  # token.COMMENT or STANDALONE_COMMENT
2538     value: str  # content of the comment
2539     newlines: int  # how many newlines before the comment
2540     consumed: int  # how many characters of the original leaf's prefix did we consume
2541
2542
2543 @lru_cache(maxsize=4096)
2544 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2545     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2546     result: List[ProtoComment] = []
2547     if not prefix or "#" not in prefix:
2548         return result
2549
2550     consumed = 0
2551     nlines = 0
2552     ignored_lines = 0
2553     for index, line in enumerate(prefix.split("\n")):
2554         consumed += len(line) + 1  # adding the length of the split '\n'
2555         line = line.lstrip()
2556         if not line:
2557             nlines += 1
2558         if not line.startswith("#"):
2559             # Escaped newlines outside of a comment are not really newlines at
2560             # all. We treat a single-line comment following an escaped newline
2561             # as a simple trailing comment.
2562             if line.endswith("\\"):
2563                 ignored_lines += 1
2564             continue
2565
2566         if index == ignored_lines and not is_endmarker:
2567             comment_type = token.COMMENT  # simple trailing comment
2568         else:
2569             comment_type = STANDALONE_COMMENT
2570         comment = make_comment(line)
2571         result.append(
2572             ProtoComment(
2573                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2574             )
2575         )
2576         nlines = 0
2577     return result
2578
2579
2580 def make_comment(content: str) -> str:
2581     """Return a consistently formatted comment from the given `content` string.
2582
2583     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2584     space between the hash sign and the content.
2585
2586     If `content` didn't start with a hash sign, one is provided.
2587     """
2588     content = content.rstrip()
2589     if not content:
2590         return "#"
2591
2592     if content[0] == "#":
2593         content = content[1:]
2594     if content and content[0] not in " !:#'%":
2595         content = " " + content
2596     return "#" + content
2597
2598
2599 def transform_line(
2600     line: Line, mode: Mode, features: Collection[Feature] = ()
2601 ) -> Iterator[Line]:
2602     """Transform a `line`, potentially splitting it into many lines.
2603
2604     They should fit in the allotted `line_length` but might not be able to.
2605
2606     `features` are syntactical features that may be used in the output.
2607     """
2608     if line.is_comment:
2609         yield line
2610         return
2611
2612     line_str = line_to_string(line)
2613
2614     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2615         """Initialize StringTransformer"""
2616         return ST(mode.line_length, mode.string_normalization)
2617
2618     string_merge = init_st(StringMerger)
2619     string_paren_strip = init_st(StringParenStripper)
2620     string_split = init_st(StringSplitter)
2621     string_paren_wrap = init_st(StringParenWrapper)
2622
2623     transformers: List[Transformer]
2624     if (
2625         not line.contains_uncollapsable_type_comments()
2626         and not line.should_explode
2627         and (
2628             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2629             or line.contains_unsplittable_type_ignore()
2630         )
2631         and not (line.inside_brackets and line.contains_standalone_comments())
2632     ):
2633         # Only apply basic string preprocessing, since lines shouldn't be split here.
2634         if mode.experimental_string_processing:
2635             transformers = [string_merge, string_paren_strip]
2636         else:
2637             transformers = []
2638     elif line.is_def:
2639         transformers = [left_hand_split]
2640     else:
2641
2642         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2643             """Wraps calls to `right_hand_split`.
2644
2645             The calls increasingly `omit` right-hand trailers (bracket pairs with
2646             content), meaning the trailers get glued together to split on another
2647             bracket pair instead.
2648             """
2649             for omit in generate_trailers_to_omit(line, mode.line_length):
2650                 lines = list(
2651                     right_hand_split(line, mode.line_length, features, omit=omit)
2652                 )
2653                 # Note: this check is only able to figure out if the first line of the
2654                 # *current* transformation fits in the line length.  This is true only
2655                 # for simple cases.  All others require running more transforms via
2656                 # `transform_line()`.  This check doesn't know if those would succeed.
2657                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2658                     yield from lines
2659                     return
2660
2661             # All splits failed, best effort split with no omits.
2662             # This mostly happens to multiline strings that are by definition
2663             # reported as not fitting a single line, as well as lines that contain
2664             # pre-existing trailing commas (those have to be exploded).
2665             yield from right_hand_split(
2666                 line, line_length=mode.line_length, features=features
2667             )
2668
2669         if mode.experimental_string_processing:
2670             if line.inside_brackets:
2671                 transformers = [
2672                     string_merge,
2673                     string_paren_strip,
2674                     delimiter_split,
2675                     standalone_comment_split,
2676                     string_split,
2677                     string_paren_wrap,
2678                     rhs,
2679                 ]
2680             else:
2681                 transformers = [
2682                     string_merge,
2683                     string_paren_strip,
2684                     string_split,
2685                     string_paren_wrap,
2686                     rhs,
2687                 ]
2688         else:
2689             if line.inside_brackets:
2690                 transformers = [delimiter_split, standalone_comment_split, rhs]
2691             else:
2692                 transformers = [rhs]
2693
2694     for transform in transformers:
2695         # We are accumulating lines in `result` because we might want to abort
2696         # mission and return the original line in the end, or attempt a different
2697         # split altogether.
2698         try:
2699             result = run_transformer(line, transform, mode, features, line_str=line_str)
2700         except CannotTransform:
2701             continue
2702         else:
2703             yield from result
2704             break
2705
2706     else:
2707         yield line
2708
2709
2710 @dataclass  # type: ignore
2711 class StringTransformer(ABC):
2712     """
2713     An implementation of the Transformer protocol that relies on its
2714     subclasses overriding the template methods `do_match(...)` and
2715     `do_transform(...)`.
2716
2717     This Transformer works exclusively on strings (for example, by merging
2718     or splitting them).
2719
2720     The following sections can be found among the docstrings of each concrete
2721     StringTransformer subclass.
2722
2723     Requirements:
2724         Which requirements must be met of the given Line for this
2725         StringTransformer to be applied?
2726
2727     Transformations:
2728         If the given Line meets all of the above requirements, which string
2729         transformations can you expect to be applied to it by this
2730         StringTransformer?
2731
2732     Collaborations:
2733         What contractual agreements does this StringTransformer have with other
2734         StringTransfomers? Such collaborations should be eliminated/minimized
2735         as much as possible.
2736     """
2737
2738     line_length: int
2739     normalize_strings: bool
2740     __name__ = "StringTransformer"
2741
2742     @abstractmethod
2743     def do_match(self, line: Line) -> TMatchResult:
2744         """
2745         Returns:
2746             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2747             string, if a match was able to be made.
2748                 OR
2749             * Err(CannotTransform), if a match was not able to be made.
2750         """
2751
2752     @abstractmethod
2753     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2754         """
2755         Yields:
2756             * Ok(new_line) where new_line is the new transformed line.
2757                 OR
2758             * Err(CannotTransform) if the transformation failed for some reason. The
2759             `do_match(...)` template method should usually be used to reject
2760             the form of the given Line, but in some cases it is difficult to
2761             know whether or not a Line meets the StringTransformer's
2762             requirements until the transformation is already midway.
2763
2764         Side Effects:
2765             This method should NOT mutate @line directly, but it MAY mutate the
2766             Line's underlying Node structure. (WARNING: If the underlying Node
2767             structure IS altered, then this method should NOT be allowed to
2768             yield an CannotTransform after that point.)
2769         """
2770
2771     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2772         """
2773         StringTransformer instances have a call signature that mirrors that of
2774         the Transformer type.
2775
2776         Raises:
2777             CannotTransform(...) if the concrete StringTransformer class is unable
2778             to transform @line.
2779         """
2780         # Optimization to avoid calling `self.do_match(...)` when the line does
2781         # not contain any string.
2782         if not any(leaf.type == token.STRING for leaf in line.leaves):
2783             raise CannotTransform("There are no strings in this line.")
2784
2785         match_result = self.do_match(line)
2786
2787         if isinstance(match_result, Err):
2788             cant_transform = match_result.err()
2789             raise CannotTransform(
2790                 f"The string transformer {self.__class__.__name__} does not recognize"
2791                 " this line as one that it can transform."
2792             ) from cant_transform
2793
2794         string_idx = match_result.ok()
2795
2796         for line_result in self.do_transform(line, string_idx):
2797             if isinstance(line_result, Err):
2798                 cant_transform = line_result.err()
2799                 raise CannotTransform(
2800                     "StringTransformer failed while attempting to transform string."
2801                 ) from cant_transform
2802             line = line_result.ok()
2803             yield line
2804
2805
2806 @dataclass
2807 class CustomSplit:
2808     """A custom (i.e. manual) string split.
2809
2810     A single CustomSplit instance represents a single substring.
2811
2812     Examples:
2813         Consider the following string:
2814         ```
2815         "Hi there friend."
2816         " This is a custom"
2817         f" string {split}."
2818         ```
2819
2820         This string will correspond to the following three CustomSplit instances:
2821         ```
2822         CustomSplit(False, 16)
2823         CustomSplit(False, 17)
2824         CustomSplit(True, 16)
2825         ```
2826     """
2827
2828     has_prefix: bool
2829     break_idx: int
2830
2831
2832 class CustomSplitMapMixin:
2833     """
2834     This mixin class is used to map merged strings to a sequence of
2835     CustomSplits, which will then be used to re-split the strings iff none of
2836     the resultant substrings go over the configured max line length.
2837     """
2838
2839     _Key = Tuple[StringID, str]
2840     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
2841
2842     @staticmethod
2843     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
2844         """
2845         Returns:
2846             A unique identifier that is used internally to map @string to a
2847             group of custom splits.
2848         """
2849         return (id(string), string)
2850
2851     def add_custom_splits(
2852         self, string: str, custom_splits: Iterable[CustomSplit]
2853     ) -> None:
2854         """Custom Split Map Setter Method
2855
2856         Side Effects:
2857             Adds a mapping from @string to the custom splits @custom_splits.
2858         """
2859         key = self._get_key(string)
2860         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
2861
2862     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
2863         """Custom Split Map Getter Method
2864
2865         Returns:
2866             * A list of the custom splits that are mapped to @string, if any
2867             exist.
2868                 OR
2869             * [], otherwise.
2870
2871         Side Effects:
2872             Deletes the mapping between @string and its associated custom
2873             splits (which are returned to the caller).
2874         """
2875         key = self._get_key(string)
2876
2877         custom_splits = self._CUSTOM_SPLIT_MAP[key]
2878         del self._CUSTOM_SPLIT_MAP[key]
2879
2880         return list(custom_splits)
2881
2882     def has_custom_splits(self, string: str) -> bool:
2883         """
2884         Returns:
2885             True iff @string is associated with a set of custom splits.
2886         """
2887         key = self._get_key(string)
2888         return key in self._CUSTOM_SPLIT_MAP
2889
2890
2891 class StringMerger(CustomSplitMapMixin, StringTransformer):
2892     """StringTransformer that merges strings together.
2893
2894     Requirements:
2895         (A) The line contains adjacent strings such that at most one substring
2896         has inline comments AND none of those inline comments are pragmas AND
2897         the set of all substring prefixes is either of length 1 or equal to
2898         {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
2899         with 'r').
2900             OR
2901         (B) The line contains a string which uses line continuation backslashes.
2902
2903     Transformations:
2904         Depending on which of the two requirements above where met, either:
2905
2906         (A) The string group associated with the target string is merged.
2907             OR
2908         (B) All line-continuation backslashes are removed from the target string.
2909
2910     Collaborations:
2911         StringMerger provides custom split information to StringSplitter.
2912     """
2913
2914     def do_match(self, line: Line) -> TMatchResult:
2915         LL = line.leaves
2916
2917         is_valid_index = is_valid_index_factory(LL)
2918
2919         for (i, leaf) in enumerate(LL):
2920             if (
2921                 leaf.type == token.STRING
2922                 and is_valid_index(i + 1)
2923                 and LL[i + 1].type == token.STRING
2924             ):
2925                 return Ok(i)
2926
2927             if leaf.type == token.STRING and "\\\n" in leaf.value:
2928                 return Ok(i)
2929
2930         return TErr("This line has no strings that need merging.")
2931
2932     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2933         new_line = line
2934         rblc_result = self.__remove_backslash_line_continuation_chars(
2935             new_line, string_idx
2936         )
2937         if isinstance(rblc_result, Ok):
2938             new_line = rblc_result.ok()
2939
2940         msg_result = self.__merge_string_group(new_line, string_idx)
2941         if isinstance(msg_result, Ok):
2942             new_line = msg_result.ok()
2943
2944         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
2945             msg_cant_transform = msg_result.err()
2946             rblc_cant_transform = rblc_result.err()
2947             cant_transform = CannotTransform(
2948                 "StringMerger failed to merge any strings in this line."
2949             )
2950
2951             # Chain the errors together using `__cause__`.
2952             msg_cant_transform.__cause__ = rblc_cant_transform
2953             cant_transform.__cause__ = msg_cant_transform
2954
2955             yield Err(cant_transform)
2956         else:
2957             yield Ok(new_line)
2958
2959     @staticmethod
2960     def __remove_backslash_line_continuation_chars(
2961         line: Line, string_idx: int
2962     ) -> TResult[Line]:
2963         """
2964         Merge strings that were split across multiple lines using
2965         line-continuation backslashes.
2966
2967         Returns:
2968             Ok(new_line), if @line contains backslash line-continuation
2969             characters.
2970                 OR
2971             Err(CannotTransform), otherwise.
2972         """
2973         LL = line.leaves
2974
2975         string_leaf = LL[string_idx]
2976         if not (
2977             string_leaf.type == token.STRING
2978             and "\\\n" in string_leaf.value
2979             and not has_triple_quotes(string_leaf.value)
2980         ):
2981             return TErr(
2982                 f"String leaf {string_leaf} does not contain any backslash line"
2983                 " continuation characters."
2984             )
2985
2986         new_line = line.clone()
2987         new_line.comments = line.comments.copy()
2988         append_leaves(new_line, line, LL)
2989
2990         new_string_leaf = new_line.leaves[string_idx]
2991         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
2992
2993         return Ok(new_line)
2994
2995     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
2996         """
2997         Merges string group (i.e. set of adjacent strings) where the first
2998         string in the group is `line.leaves[string_idx]`.
2999
3000         Returns:
3001             Ok(new_line), if ALL of the validation checks found in
3002             __validate_msg(...) pass.
3003                 OR
3004             Err(CannotTransform), otherwise.
3005         """
3006         LL = line.leaves
3007
3008         is_valid_index = is_valid_index_factory(LL)
3009
3010         vresult = self.__validate_msg(line, string_idx)
3011         if isinstance(vresult, Err):
3012             return vresult
3013
3014         # If the string group is wrapped inside an Atom node, we must make sure
3015         # to later replace that Atom with our new (merged) string leaf.
3016         atom_node = LL[string_idx].parent
3017
3018         # We will place BREAK_MARK in between every two substrings that we
3019         # merge. We will then later go through our final result and use the
3020         # various instances of BREAK_MARK we find to add the right values to
3021         # the custom split map.
3022         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3023
3024         QUOTE = LL[string_idx].value[-1]
3025
3026         def make_naked(string: str, string_prefix: str) -> str:
3027             """Strip @string (i.e. make it a "naked" string)
3028
3029             Pre-conditions:
3030                 * assert_is_leaf_string(@string)
3031
3032             Returns:
3033                 A string that is identical to @string except that
3034                 @string_prefix has been stripped, the surrounding QUOTE
3035                 characters have been removed, and any remaining QUOTE
3036                 characters have been escaped.
3037             """
3038             assert_is_leaf_string(string)
3039
3040             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3041             naked_string = string[len(string_prefix) + 1 : -1]
3042             naked_string = re.sub(
3043                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3044             )
3045             return naked_string
3046
3047         # Holds the CustomSplit objects that will later be added to the custom
3048         # split map.
3049         custom_splits = []
3050
3051         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3052         prefix_tracker = []
3053
3054         # Sets the 'prefix' variable. This is the prefix that the final merged
3055         # string will have.
3056         next_str_idx = string_idx
3057         prefix = ""
3058         while (
3059             not prefix
3060             and is_valid_index(next_str_idx)
3061             and LL[next_str_idx].type == token.STRING
3062         ):
3063             prefix = get_string_prefix(LL[next_str_idx].value)
3064             next_str_idx += 1
3065
3066         # The next loop merges the string group. The final string will be
3067         # contained in 'S'.
3068         #
3069         # The following convenience variables are used:
3070         #
3071         #   S: string
3072         #   NS: naked string
3073         #   SS: next string
3074         #   NSS: naked next string
3075         S = ""
3076         NS = ""
3077         num_of_strings = 0
3078         next_str_idx = string_idx
3079         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3080             num_of_strings += 1
3081
3082             SS = LL[next_str_idx].value
3083             next_prefix = get_string_prefix(SS)
3084
3085             # If this is an f-string group but this substring is not prefixed
3086             # with 'f'...
3087             if "f" in prefix and "f" not in next_prefix:
3088                 # Then we must escape any braces contained in this substring.
3089                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3090
3091             NSS = make_naked(SS, next_prefix)
3092
3093             has_prefix = bool(next_prefix)
3094             prefix_tracker.append(has_prefix)
3095
3096             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3097             NS = make_naked(S, prefix)
3098
3099             next_str_idx += 1
3100
3101         S_leaf = Leaf(token.STRING, S)
3102         if self.normalize_strings:
3103             normalize_string_quotes(S_leaf)
3104
3105         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3106         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3107         for has_prefix in prefix_tracker:
3108             mark_idx = temp_string.find(BREAK_MARK)
3109             assert (
3110                 mark_idx >= 0
3111             ), "Logic error while filling the custom string breakpoint cache."
3112
3113             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3114             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3115             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3116
3117         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3118
3119         if atom_node is not None:
3120             replace_child(atom_node, string_leaf)
3121
3122         # Build the final line ('new_line') that this method will later return.
3123         new_line = line.clone()
3124         for (i, leaf) in enumerate(LL):
3125             if i == string_idx:
3126                 new_line.append(string_leaf)
3127
3128             if string_idx <= i < string_idx + num_of_strings:
3129                 for comment_leaf in line.comments_after(LL[i]):
3130                     new_line.append(comment_leaf, preformatted=True)
3131                 continue
3132
3133             append_leaves(new_line, line, [leaf])
3134
3135         self.add_custom_splits(string_leaf.value, custom_splits)
3136         return Ok(new_line)
3137
3138     @staticmethod
3139     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3140         """Validate (M)erge (S)tring (G)roup
3141
3142         Transform-time string validation logic for __merge_string_group(...).
3143
3144         Returns:
3145             * Ok(None), if ALL validation checks (listed below) pass.
3146                 OR
3147             * Err(CannotTransform), if any of the following are true:
3148                 - The target string is not in a string group (i.e. it has no
3149                   adjacent strings).
3150                 - The string group has more than one inline comment.
3151                 - The string group has an inline comment that appears to be a pragma.
3152                 - The set of all string prefixes in the string group is of
3153                   length greater than one and is not equal to {"", "f"}.
3154                 - The string group consists of raw strings.
3155         """
3156         num_of_inline_string_comments = 0
3157         set_of_prefixes = set()
3158         num_of_strings = 0
3159         for leaf in line.leaves[string_idx:]:
3160             if leaf.type != token.STRING:
3161                 # If the string group is trailed by a comma, we count the
3162                 # comments trailing the comma to be one of the string group's
3163                 # comments.
3164                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3165                     num_of_inline_string_comments += 1
3166                 break
3167
3168             if has_triple_quotes(leaf.value):
3169                 return TErr("StringMerger does NOT merge multiline strings.")
3170
3171             num_of_strings += 1
3172             prefix = get_string_prefix(leaf.value)
3173             if "r" in prefix:
3174                 return TErr("StringMerger does NOT merge raw strings.")
3175
3176             set_of_prefixes.add(prefix)
3177
3178             if id(leaf) in line.comments:
3179                 num_of_inline_string_comments += 1
3180                 if contains_pragma_comment(line.comments[id(leaf)]):
3181                     return TErr("Cannot merge strings which have pragma comments.")
3182
3183         if num_of_strings < 2:
3184             return TErr(
3185                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3186             )
3187
3188         if num_of_inline_string_comments > 1:
3189             return TErr(
3190                 f"Too many inline string comments ({num_of_inline_string_comments})."
3191             )
3192
3193         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3194             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3195
3196         return Ok(None)
3197
3198
3199 class StringParenStripper(StringTransformer):
3200     """StringTransformer that strips surrounding parentheses from strings.
3201
3202     Requirements:
3203         The line contains a string which is surrounded by parentheses and:
3204             - The target string is NOT the only argument to a function call).
3205             - If the target string contains a PERCENT, the brackets are not
3206               preceeded or followed by an operator with higher precedence than
3207               PERCENT.
3208
3209     Transformations:
3210         The parentheses mentioned in the 'Requirements' section are stripped.
3211
3212     Collaborations:
3213         StringParenStripper has its own inherent usefulness, but it is also
3214         relied on to clean up the parentheses created by StringParenWrapper (in
3215         the event that they are no longer needed).
3216     """
3217
3218     def do_match(self, line: Line) -> TMatchResult:
3219         LL = line.leaves
3220
3221         is_valid_index = is_valid_index_factory(LL)
3222
3223         for (idx, leaf) in enumerate(LL):
3224             # Should be a string...
3225             if leaf.type != token.STRING:
3226                 continue
3227
3228             # Should be preceded by a non-empty LPAR...
3229             if (
3230                 not is_valid_index(idx - 1)
3231                 or LL[idx - 1].type != token.LPAR
3232                 or is_empty_lpar(LL[idx - 1])
3233             ):
3234                 continue
3235
3236             # That LPAR should NOT be preceded by a function name or a closing
3237             # bracket (which could be a function which returns a function or a
3238             # list/dictionary that contains a function)...
3239             if is_valid_index(idx - 2) and (
3240                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3241             ):
3242                 continue
3243
3244             string_idx = idx
3245
3246             # Skip the string trailer, if one exists.
3247             string_parser = StringParser()
3248             next_idx = string_parser.parse(LL, string_idx)
3249
3250             # if the leaves in the parsed string include a PERCENT, we need to
3251             # make sure the initial LPAR is NOT preceded by an operator with
3252             # higher or equal precedence to PERCENT
3253             if is_valid_index(idx - 2):
3254                 # mypy can't quite follow unless we name this
3255                 before_lpar = LL[idx - 2]
3256                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3257                     (
3258                         before_lpar.type
3259                         in {
3260                             token.STAR,
3261                             token.AT,
3262                             token.SLASH,
3263                             token.DOUBLESLASH,
3264                             token.PERCENT,
3265                             token.TILDE,
3266                             token.DOUBLESTAR,
3267                             token.AWAIT,
3268                             token.LSQB,
3269                             token.LPAR,
3270                         }
3271                     )
3272                     or (
3273                         # only unary PLUS/MINUS
3274                         before_lpar.parent
3275                         and before_lpar.parent.type == syms.factor
3276                         and (before_lpar.type in {token.PLUS, token.MINUS})
3277                     )
3278                 ):
3279                     continue
3280
3281             # Should be followed by a non-empty RPAR...
3282             if (
3283                 is_valid_index(next_idx)
3284                 and LL[next_idx].type == token.RPAR
3285                 and not is_empty_rpar(LL[next_idx])
3286             ):
3287                 # That RPAR should NOT be followed by anything with higher
3288                 # precedence than PERCENT
3289                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3290                     token.DOUBLESTAR,
3291                     token.LSQB,
3292                     token.LPAR,
3293                     token.DOT,
3294                 }:
3295                     continue
3296
3297                 return Ok(string_idx)
3298
3299         return TErr("This line has no strings wrapped in parens.")
3300
3301     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3302         LL = line.leaves
3303
3304         string_parser = StringParser()
3305         rpar_idx = string_parser.parse(LL, string_idx)
3306
3307         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3308             if line.comments_after(leaf):
3309                 yield TErr(
3310                     "Will not strip parentheses which have comments attached to them."
3311                 )
3312
3313         new_line = line.clone()
3314         new_line.comments = line.comments.copy()
3315         append_leaves(new_line, line, LL[: string_idx - 1])
3316
3317         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3318         LL[string_idx - 1].remove()
3319         replace_child(LL[string_idx], string_leaf)
3320         new_line.append(string_leaf)
3321
3322         append_leaves(
3323             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3324         )
3325
3326         LL[rpar_idx].remove()
3327
3328         yield Ok(new_line)
3329
3330
3331 class BaseStringSplitter(StringTransformer):
3332     """
3333     Abstract class for StringTransformers which transform a Line's strings by splitting
3334     them or placing them on their own lines where necessary to avoid going over
3335     the configured line length.
3336
3337     Requirements:
3338         * The target string value is responsible for the line going over the
3339         line length limit. It follows that after all of black's other line
3340         split methods have been exhausted, this line (or one of the resulting
3341         lines after all line splits are performed) would still be over the
3342         line_length limit unless we split this string.
3343             AND
3344         * The target string is NOT a "pointless" string (i.e. a string that has
3345         no parent or siblings).
3346             AND
3347         * The target string is not followed by an inline comment that appears
3348         to be a pragma.
3349             AND
3350         * The target string is not a multiline (i.e. triple-quote) string.
3351     """
3352
3353     @abstractmethod
3354     def do_splitter_match(self, line: Line) -> TMatchResult:
3355         """
3356         BaseStringSplitter asks its clients to override this method instead of
3357         `StringTransformer.do_match(...)`.
3358
3359         Follows the same protocol as `StringTransformer.do_match(...)`.
3360
3361         Refer to `help(StringTransformer.do_match)` for more information.
3362         """
3363
3364     def do_match(self, line: Line) -> TMatchResult:
3365         match_result = self.do_splitter_match(line)
3366         if isinstance(match_result, Err):
3367             return match_result
3368
3369         string_idx = match_result.ok()
3370         vresult = self.__validate(line, string_idx)
3371         if isinstance(vresult, Err):
3372             return vresult
3373
3374         return match_result
3375
3376     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3377         """
3378         Checks that @line meets all of the requirements listed in this classes'
3379         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3380         description of those requirements.
3381
3382         Returns:
3383             * Ok(None), if ALL of the requirements are met.
3384                 OR
3385             * Err(CannotTransform), if ANY of the requirements are NOT met.
3386         """
3387         LL = line.leaves
3388
3389         string_leaf = LL[string_idx]
3390
3391         max_string_length = self.__get_max_string_length(line, string_idx)
3392         if len(string_leaf.value) <= max_string_length:
3393             return TErr(
3394                 "The string itself is not what is causing this line to be too long."
3395             )
3396
3397         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3398             token.STRING,
3399             token.NEWLINE,
3400         ]:
3401             return TErr(
3402                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3403                 " no parent)."
3404             )
3405
3406         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3407             line.comments[id(line.leaves[string_idx])]
3408         ):
3409             return TErr(
3410                 "Line appears to end with an inline pragma comment. Splitting the line"
3411                 " could modify the pragma's behavior."
3412             )
3413
3414         if has_triple_quotes(string_leaf.value):
3415             return TErr("We cannot split multiline strings.")
3416
3417         return Ok(None)
3418
3419     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3420         """
3421         Calculates the max string length used when attempting to determine
3422         whether or not the target string is responsible for causing the line to
3423         go over the line length limit.
3424
3425         WARNING: This method is tightly coupled to both StringSplitter and
3426         (especially) StringParenWrapper. There is probably a better way to
3427         accomplish what is being done here.
3428
3429         Returns:
3430             max_string_length: such that `line.leaves[string_idx].value >
3431             max_string_length` implies that the target string IS responsible
3432             for causing this line to exceed the line length limit.
3433         """
3434         LL = line.leaves
3435
3436         is_valid_index = is_valid_index_factory(LL)
3437
3438         # We use the shorthand "WMA4" in comments to abbreviate "We must
3439         # account for". When giving examples, we use STRING to mean some/any
3440         # valid string.
3441         #
3442         # Finally, we use the following convenience variables:
3443         #
3444         #   P:  The leaf that is before the target string leaf.
3445         #   N:  The leaf that is after the target string leaf.
3446         #   NN: The leaf that is after N.
3447
3448         # WMA4 the whitespace at the beginning of the line.
3449         offset = line.depth * 4
3450
3451         if is_valid_index(string_idx - 1):
3452             p_idx = string_idx - 1
3453             if (
3454                 LL[string_idx - 1].type == token.LPAR
3455                 and LL[string_idx - 1].value == ""
3456                 and string_idx >= 2
3457             ):
3458                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3459                 p_idx -= 1
3460
3461             P = LL[p_idx]
3462             if P.type == token.PLUS:
3463                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3464                 offset += 2
3465
3466             if P.type == token.COMMA:
3467                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3468                 offset += 3
3469
3470             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3471                 # This conditional branch is meant to handle dictionary keys,
3472                 # variable assignments, 'return STRING' statement lines, and
3473                 # 'else STRING' ternary expression lines.
3474
3475                 # WMA4 a single space.
3476                 offset += 1
3477
3478                 # WMA4 the lengths of any leaves that came before that space.
3479                 for leaf in LL[: p_idx + 1]:
3480                     offset += len(str(leaf))
3481
3482         if is_valid_index(string_idx + 1):
3483             N = LL[string_idx + 1]
3484             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3485                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3486                 N = LL[string_idx + 2]
3487
3488             if N.type == token.COMMA:
3489                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3490                 offset += 1
3491
3492             if is_valid_index(string_idx + 2):
3493                 NN = LL[string_idx + 2]
3494
3495                 if N.type == token.DOT and NN.type == token.NAME:
3496                     # This conditional branch is meant to handle method calls invoked
3497                     # off of a string literal up to and including the LPAR character.
3498
3499                     # WMA4 the '.' character.
3500                     offset += 1
3501
3502                     if (
3503                         is_valid_index(string_idx + 3)
3504                         and LL[string_idx + 3].type == token.LPAR
3505                     ):
3506                         # WMA4 the left parenthesis character.
3507                         offset += 1
3508
3509                     # WMA4 the length of the method's name.
3510                     offset += len(NN.value)
3511
3512         has_comments = False
3513         for comment_leaf in line.comments_after(LL[string_idx]):
3514             if not has_comments:
3515                 has_comments = True
3516                 # WMA4 two spaces before the '#' character.
3517                 offset += 2
3518
3519             # WMA4 the length of the inline comment.
3520             offset += len(comment_leaf.value)
3521
3522         max_string_length = self.line_length - offset
3523         return max_string_length
3524
3525
3526 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3527     """
3528     StringTransformer that splits "atom" strings (i.e. strings which exist on
3529     lines by themselves).
3530
3531     Requirements:
3532         * The line consists ONLY of a single string (with the exception of a
3533         '+' symbol which MAY exist at the start of the line), MAYBE a string
3534         trailer, and MAYBE a trailing comma.
3535             AND
3536         * All of the requirements listed in BaseStringSplitter's docstring.
3537
3538     Transformations:
3539         The string mentioned in the 'Requirements' section is split into as
3540         many substrings as necessary to adhere to the configured line length.
3541
3542         In the final set of substrings, no substring should be smaller than
3543         MIN_SUBSTR_SIZE characters.
3544
3545         The string will ONLY be split on spaces (i.e. each new substring should
3546         start with a space).
3547
3548         If the string is an f-string, it will NOT be split in the middle of an
3549         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3550         else bar()} is an f-expression).
3551
3552         If the string that is being split has an associated set of custom split
3553         records and those custom splits will NOT result in any line going over
3554         the configured line length, those custom splits are used. Otherwise the
3555         string is split as late as possible (from left-to-right) while still
3556         adhering to the transformation rules listed above.
3557
3558     Collaborations:
3559         StringSplitter relies on StringMerger to construct the appropriate
3560         CustomSplit objects and add them to the custom split map.
3561     """
3562
3563     MIN_SUBSTR_SIZE = 6
3564     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3565     RE_FEXPR = r"""
3566     (?<!\{)\{
3567         (?:
3568             [^\{\}]
3569             | \{\{
3570             | \}\}
3571         )+?
3572     (?<!\})(?:\}\})*\}(?!\})
3573     """
3574
3575     def do_splitter_match(self, line: Line) -> TMatchResult:
3576         LL = line.leaves
3577
3578         is_valid_index = is_valid_index_factory(LL)
3579
3580         idx = 0
3581
3582         # The first leaf MAY be a '+' symbol...
3583         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3584             idx += 1
3585
3586         # The next/first leaf MAY be an empty LPAR...
3587         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3588             idx += 1
3589
3590         # The next/first leaf MUST be a string...
3591         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3592             return TErr("Line does not start with a string.")
3593
3594         string_idx = idx
3595
3596         # Skip the string trailer, if one exists.
3597         string_parser = StringParser()
3598         idx = string_parser.parse(LL, string_idx)
3599
3600         # That string MAY be followed by an empty RPAR...
3601         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3602             idx += 1
3603
3604         # That string / empty RPAR leaf MAY be followed by a comma...
3605         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3606             idx += 1
3607
3608         # But no more leaves are allowed...
3609         if is_valid_index(idx):
3610             return TErr("This line does not end with a string.")
3611
3612         return Ok(string_idx)
3613
3614     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3615         LL = line.leaves
3616
3617         QUOTE = LL[string_idx].value[-1]
3618
3619         is_valid_index = is_valid_index_factory(LL)
3620         insert_str_child = insert_str_child_factory(LL[string_idx])
3621
3622         prefix = get_string_prefix(LL[string_idx].value)
3623
3624         # We MAY choose to drop the 'f' prefix from substrings that don't
3625         # contain any f-expressions, but ONLY if the original f-string
3626         # contains at least one f-expression. Otherwise, we will alter the AST
3627         # of the program.
3628         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3629             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3630         )
3631
3632         first_string_line = True
3633         starts_with_plus = LL[0].type == token.PLUS
3634
3635         def line_needs_plus() -> bool:
3636             return first_string_line and starts_with_plus
3637
3638         def maybe_append_plus(new_line: Line) -> None:
3639             """
3640             Side Effects:
3641                 If @line starts with a plus and this is the first line we are
3642                 constructing, this function appends a PLUS leaf to @new_line
3643                 and replaces the old PLUS leaf in the node structure. Otherwise
3644                 this function does nothing.
3645             """
3646             if line_needs_plus():
3647                 plus_leaf = Leaf(token.PLUS, "+")
3648                 replace_child(LL[0], plus_leaf)
3649                 new_line.append(plus_leaf)
3650
3651         ends_with_comma = (
3652             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3653         )
3654
3655         def max_last_string() -> int:
3656             """
3657             Returns:
3658                 The max allowed length of the string value used for the last
3659                 line we will construct.
3660             """
3661             result = self.line_length
3662             result -= line.depth * 4
3663             result -= 1 if ends_with_comma else 0
3664             result -= 2 if line_needs_plus() else 0
3665             return result
3666
3667         # --- Calculate Max Break Index (for string value)
3668         # We start with the line length limit
3669         max_break_idx = self.line_length
3670         # The last index of a string of length N is N-1.
3671         max_break_idx -= 1
3672         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3673         max_break_idx -= line.depth * 4
3674         if max_break_idx < 0:
3675             yield TErr(
3676                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3677                 f" {line.depth}"
3678             )
3679             return
3680
3681         # Check if StringMerger registered any custom splits.
3682         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3683         # We use them ONLY if none of them would produce lines that exceed the
3684         # line limit.
3685         use_custom_breakpoints = bool(
3686             custom_splits
3687             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3688         )
3689
3690         # Temporary storage for the remaining chunk of the string line that
3691         # can't fit onto the line currently being constructed.
3692         rest_value = LL[string_idx].value
3693
3694         def more_splits_should_be_made() -> bool:
3695             """
3696             Returns:
3697                 True iff `rest_value` (the remaining string value from the last
3698                 split), should be split again.
3699             """
3700             if use_custom_breakpoints:
3701                 return len(custom_splits) > 1
3702             else:
3703                 return len(rest_value) > max_last_string()
3704
3705         string_line_results: List[Ok[Line]] = []
3706         while more_splits_should_be_made():
3707             if use_custom_breakpoints:
3708                 # Custom User Split (manual)
3709                 csplit = custom_splits.pop(0)
3710                 break_idx = csplit.break_idx
3711             else:
3712                 # Algorithmic Split (automatic)
3713                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3714                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3715                 if maybe_break_idx is None:
3716                     # If we are unable to algorithmically determine a good split
3717                     # and this string has custom splits registered to it, we
3718                     # fall back to using them--which means we have to start
3719                     # over from the beginning.
3720                     if custom_splits:
3721                         rest_value = LL[string_idx].value
3722                         string_line_results = []
3723                         first_string_line = True
3724                         use_custom_breakpoints = True
3725                         continue
3726
3727                     # Otherwise, we stop splitting here.
3728                     break
3729
3730                 break_idx = maybe_break_idx
3731
3732             # --- Construct `next_value`
3733             next_value = rest_value[:break_idx] + QUOTE
3734             if (
3735                 # Are we allowed to try to drop a pointless 'f' prefix?
3736                 drop_pointless_f_prefix
3737                 # If we are, will we be successful?
3738                 and next_value != self.__normalize_f_string(next_value, prefix)
3739             ):
3740                 # If the current custom split did NOT originally use a prefix,
3741                 # then `csplit.break_idx` will be off by one after removing
3742                 # the 'f' prefix.
3743                 break_idx = (
3744                     break_idx + 1
3745                     if use_custom_breakpoints and not csplit.has_prefix
3746                     else break_idx
3747                 )
3748                 next_value = rest_value[:break_idx] + QUOTE
3749                 next_value = self.__normalize_f_string(next_value, prefix)
3750
3751             # --- Construct `next_leaf`
3752             next_leaf = Leaf(token.STRING, next_value)
3753             insert_str_child(next_leaf)
3754             self.__maybe_normalize_string_quotes(next_leaf)
3755
3756             # --- Construct `next_line`
3757             next_line = line.clone()
3758             maybe_append_plus(next_line)
3759             next_line.append(next_leaf)
3760             string_line_results.append(Ok(next_line))
3761
3762             rest_value = prefix + QUOTE + rest_value[break_idx:]
3763             first_string_line = False
3764
3765         yield from string_line_results
3766
3767         if drop_pointless_f_prefix:
3768             rest_value = self.__normalize_f_string(rest_value, prefix)
3769
3770         rest_leaf = Leaf(token.STRING, rest_value)
3771         insert_str_child(rest_leaf)
3772
3773         # NOTE: I could not find a test case that verifies that the following
3774         # line is actually necessary, but it seems to be. Otherwise we risk
3775         # not normalizing the last substring, right?
3776         self.__maybe_normalize_string_quotes(rest_leaf)
3777
3778         last_line = line.clone()
3779         maybe_append_plus(last_line)
3780
3781         # If there are any leaves to the right of the target string...
3782         if is_valid_index(string_idx + 1):
3783             # We use `temp_value` here to determine how long the last line
3784             # would be if we were to append all the leaves to the right of the
3785             # target string to the last string line.
3786             temp_value = rest_value
3787             for leaf in LL[string_idx + 1 :]:
3788                 temp_value += str(leaf)
3789                 if leaf.type == token.LPAR:
3790                     break
3791
3792             # Try to fit them all on the same line with the last substring...
3793             if (
3794                 len(temp_value) <= max_last_string()
3795                 or LL[string_idx + 1].type == token.COMMA
3796             ):
3797                 last_line.append(rest_leaf)
3798                 append_leaves(last_line, line, LL[string_idx + 1 :])
3799                 yield Ok(last_line)
3800             # Otherwise, place the last substring on one line and everything
3801             # else on a line below that...
3802             else:
3803                 last_line.append(rest_leaf)
3804                 yield Ok(last_line)
3805
3806                 non_string_line = line.clone()
3807                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
3808                 yield Ok(non_string_line)
3809         # Else the target string was the last leaf...
3810         else:
3811             last_line.append(rest_leaf)
3812             last_line.comments = line.comments.copy()
3813             yield Ok(last_line)
3814
3815     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
3816         """
3817         This method contains the algorithm that StringSplitter uses to
3818         determine which character to split each string at.
3819
3820         Args:
3821             @string: The substring that we are attempting to split.
3822             @max_break_idx: The ideal break index. We will return this value if it
3823             meets all the necessary conditions. In the likely event that it
3824             doesn't we will try to find the closest index BELOW @max_break_idx
3825             that does. If that fails, we will expand our search by also
3826             considering all valid indices ABOVE @max_break_idx.
3827
3828         Pre-Conditions:
3829             * assert_is_leaf_string(@string)
3830             * 0 <= @max_break_idx < len(@string)
3831
3832         Returns:
3833             break_idx, if an index is able to be found that meets all of the
3834             conditions listed in the 'Transformations' section of this classes'
3835             docstring.
3836                 OR
3837             None, otherwise.
3838         """
3839         is_valid_index = is_valid_index_factory(string)
3840
3841         assert is_valid_index(max_break_idx)
3842         assert_is_leaf_string(string)
3843
3844         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
3845
3846         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
3847             """
3848             Yields:
3849                 All ranges of @string which, if @string were to be split there,
3850                 would result in the splitting of an f-expression (which is NOT
3851                 allowed).
3852             """
3853             nonlocal _fexpr_slices
3854
3855             if _fexpr_slices is None:
3856                 _fexpr_slices = []
3857                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
3858                     _fexpr_slices.append(match.span())
3859
3860             yield from _fexpr_slices
3861
3862         is_fstring = "f" in get_string_prefix(string)
3863
3864         def breaks_fstring_expression(i: Index) -> bool:
3865             """
3866             Returns:
3867                 True iff returning @i would result in the splitting of an
3868                 f-expression (which is NOT allowed).
3869             """
3870             if not is_fstring:
3871                 return False
3872
3873             for (start, end) in fexpr_slices():
3874                 if start <= i < end:
3875                     return True
3876
3877             return False
3878
3879         def passes_all_checks(i: Index) -> bool:
3880             """
3881             Returns:
3882                 True iff ALL of the conditions listed in the 'Transformations'
3883                 section of this classes' docstring would be be met by returning @i.
3884             """
3885             is_space = string[i] == " "
3886             is_big_enough = (
3887                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
3888                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
3889             )
3890             return is_space and is_big_enough and not breaks_fstring_expression(i)
3891
3892         # First, we check all indices BELOW @max_break_idx.
3893         break_idx = max_break_idx
3894         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
3895             break_idx -= 1
3896
3897         if not passes_all_checks(break_idx):
3898             # If that fails, we check all indices ABOVE @max_break_idx.
3899             #
3900             # If we are able to find a valid index here, the next line is going
3901             # to be longer than the specified line length, but it's probably
3902             # better than doing nothing at all.
3903             break_idx = max_break_idx + 1
3904             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
3905                 break_idx += 1
3906
3907             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
3908                 return None
3909
3910         return break_idx
3911
3912     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
3913         if self.normalize_strings:
3914             normalize_string_quotes(leaf)
3915
3916     def __normalize_f_string(self, string: str, prefix: str) -> str:
3917         """
3918         Pre-Conditions:
3919             * assert_is_leaf_string(@string)
3920
3921         Returns:
3922             * If @string is an f-string that contains no f-expressions, we
3923             return a string identical to @string except that the 'f' prefix
3924             has been stripped and all double braces (i.e. '{{' or '}}') have
3925             been normalized (i.e. turned into '{' or '}').
3926                 OR
3927             * Otherwise, we return @string.
3928         """
3929         assert_is_leaf_string(string)
3930
3931         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
3932             new_prefix = prefix.replace("f", "")
3933
3934             temp = string[len(prefix) :]
3935             temp = re.sub(r"\{\{", "{", temp)
3936             temp = re.sub(r"\}\}", "}", temp)
3937             new_string = temp
3938
3939             return f"{new_prefix}{new_string}"
3940         else:
3941             return string
3942
3943
3944 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
3945     """
3946     StringTransformer that splits non-"atom" strings (i.e. strings that do not
3947     exist on lines by themselves).
3948
3949     Requirements:
3950         All of the requirements listed in BaseStringSplitter's docstring in
3951         addition to the requirements listed below:
3952
3953         * The line is a return/yield statement, which returns/yields a string.
3954             OR
3955         * The line is part of a ternary expression (e.g. `x = y if cond else
3956         z`) such that the line starts with `else <string>`, where <string> is
3957         some string.
3958             OR
3959         * The line is an assert statement, which ends with a string.
3960             OR
3961         * The line is an assignment statement (e.g. `x = <string>` or `x +=
3962         <string>`) such that the variable is being assigned the value of some
3963         string.
3964             OR
3965         * The line is a dictionary key assignment where some valid key is being
3966         assigned the value of some string.
3967
3968     Transformations:
3969         The chosen string is wrapped in parentheses and then split at the LPAR.
3970
3971         We then have one line which ends with an LPAR and another line that
3972         starts with the chosen string. The latter line is then split again at
3973         the RPAR. This results in the RPAR (and possibly a trailing comma)
3974         being placed on its own line.
3975
3976         NOTE: If any leaves exist to the right of the chosen string (except
3977         for a trailing comma, which would be placed after the RPAR), those
3978         leaves are placed inside the parentheses.  In effect, the chosen
3979         string is not necessarily being "wrapped" by parentheses. We can,
3980         however, count on the LPAR being placed directly before the chosen
3981         string.
3982
3983         In other words, StringParenWrapper creates "atom" strings. These
3984         can then be split again by StringSplitter, if necessary.
3985
3986     Collaborations:
3987         In the event that a string line split by StringParenWrapper is
3988         changed such that it no longer needs to be given its own line,
3989         StringParenWrapper relies on StringParenStripper to clean up the
3990         parentheses it created.
3991     """
3992
3993     def do_splitter_match(self, line: Line) -> TMatchResult:
3994         LL = line.leaves
3995
3996         string_idx = None
3997         string_idx = string_idx or self._return_match(LL)
3998         string_idx = string_idx or self._else_match(LL)
3999         string_idx = string_idx or self._assert_match(LL)
4000         string_idx = string_idx or self._assign_match(LL)
4001         string_idx = string_idx or self._dict_match(LL)
4002
4003         if string_idx is not None:
4004             string_value = line.leaves[string_idx].value
4005             # If the string has no spaces...
4006             if " " not in string_value:
4007                 # And will still violate the line length limit when split...
4008                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4009                 if len(string_value) > max_string_length:
4010                     # And has no associated custom splits...
4011                     if not self.has_custom_splits(string_value):
4012                         # Then we should NOT put this string on its own line.
4013                         return TErr(
4014                             "We do not wrap long strings in parentheses when the"
4015                             " resultant line would still be over the specified line"
4016                             " length and can't be split further by StringSplitter."
4017                         )
4018             return Ok(string_idx)
4019
4020         return TErr("This line does not contain any non-atomic strings.")
4021
4022     @staticmethod
4023     def _return_match(LL: List[Leaf]) -> Optional[int]:
4024         """
4025         Returns:
4026             string_idx such that @LL[string_idx] is equal to our target (i.e.
4027             matched) string, if this line matches the return/yield statement
4028             requirements listed in the 'Requirements' section of this classes'
4029             docstring.
4030                 OR
4031             None, otherwise.
4032         """
4033         # If this line is apart of a return/yield statement and the first leaf
4034         # contains either the "return" or "yield" keywords...
4035         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4036             0
4037         ].value in ["return", "yield"]:
4038             is_valid_index = is_valid_index_factory(LL)
4039
4040             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4041             # The next visible leaf MUST contain a string...
4042             if is_valid_index(idx) and LL[idx].type == token.STRING:
4043                 return idx
4044
4045         return None
4046
4047     @staticmethod
4048     def _else_match(LL: List[Leaf]) -> Optional[int]:
4049         """
4050         Returns:
4051             string_idx such that @LL[string_idx] is equal to our target (i.e.
4052             matched) string, if this line matches the ternary expression
4053             requirements listed in the 'Requirements' section of this classes'
4054             docstring.
4055                 OR
4056             None, otherwise.
4057         """
4058         # If this line is apart of a ternary expression and the first leaf
4059         # contains the "else" keyword...
4060         if (
4061             parent_type(LL[0]) == syms.test
4062             and LL[0].type == token.NAME
4063             and LL[0].value == "else"
4064         ):
4065             is_valid_index = is_valid_index_factory(LL)
4066
4067             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4068             # The next visible leaf MUST contain a string...
4069             if is_valid_index(idx) and LL[idx].type == token.STRING:
4070                 return idx
4071
4072         return None
4073
4074     @staticmethod
4075     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4076         """
4077         Returns:
4078             string_idx such that @LL[string_idx] is equal to our target (i.e.
4079             matched) string, if this line matches the assert statement
4080             requirements listed in the 'Requirements' section of this classes'
4081             docstring.
4082                 OR
4083             None, otherwise.
4084         """
4085         # If this line is apart of an assert statement and the first leaf
4086         # contains the "assert" keyword...
4087         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4088             is_valid_index = is_valid_index_factory(LL)
4089
4090             for (i, leaf) in enumerate(LL):
4091                 # We MUST find a comma...
4092                 if leaf.type == token.COMMA:
4093                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4094
4095                     # That comma MUST be followed by a string...
4096                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4097                         string_idx = idx
4098
4099                         # Skip the string trailer, if one exists.
4100                         string_parser = StringParser()
4101                         idx = string_parser.parse(LL, string_idx)
4102
4103                         # But no more leaves are allowed...
4104                         if not is_valid_index(idx):
4105                             return string_idx
4106
4107         return None
4108
4109     @staticmethod
4110     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4111         """
4112         Returns:
4113             string_idx such that @LL[string_idx] is equal to our target (i.e.
4114             matched) string, if this line matches the assignment statement
4115             requirements listed in the 'Requirements' section of this classes'
4116             docstring.
4117                 OR
4118             None, otherwise.
4119         """
4120         # If this line is apart of an expression statement or is a function
4121         # argument AND the first leaf contains a variable name...
4122         if (
4123             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4124             and LL[0].type == token.NAME
4125         ):
4126             is_valid_index = is_valid_index_factory(LL)
4127
4128             for (i, leaf) in enumerate(LL):
4129                 # We MUST find either an '=' or '+=' symbol...
4130                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4131                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4132
4133                     # That symbol MUST be followed by a string...
4134                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4135                         string_idx = idx
4136
4137                         # Skip the string trailer, if one exists.
4138                         string_parser = StringParser()
4139                         idx = string_parser.parse(LL, string_idx)
4140
4141                         # The next leaf MAY be a comma iff this line is apart
4142                         # of a function argument...
4143                         if (
4144                             parent_type(LL[0]) == syms.argument
4145                             and is_valid_index(idx)
4146                             and LL[idx].type == token.COMMA
4147                         ):
4148                             idx += 1
4149
4150                         # But no more leaves are allowed...
4151                         if not is_valid_index(idx):
4152                             return string_idx
4153
4154         return None
4155
4156     @staticmethod
4157     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4158         """
4159         Returns:
4160             string_idx such that @LL[string_idx] is equal to our target (i.e.
4161             matched) string, if this line matches the dictionary key assignment
4162             statement requirements listed in the 'Requirements' section of this
4163             classes' docstring.
4164                 OR
4165             None, otherwise.
4166         """
4167         # If this line is apart of a dictionary key assignment...
4168         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4169             is_valid_index = is_valid_index_factory(LL)
4170
4171             for (i, leaf) in enumerate(LL):
4172                 # We MUST find a colon...
4173                 if leaf.type == token.COLON:
4174                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4175
4176                     # That colon MUST be followed by a string...
4177                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4178                         string_idx = idx
4179
4180                         # Skip the string trailer, if one exists.
4181                         string_parser = StringParser()
4182                         idx = string_parser.parse(LL, string_idx)
4183
4184                         # That string MAY be followed by a comma...
4185                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4186                             idx += 1
4187
4188                         # But no more leaves are allowed...
4189                         if not is_valid_index(idx):
4190                             return string_idx
4191
4192         return None
4193
4194     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4195         LL = line.leaves
4196
4197         is_valid_index = is_valid_index_factory(LL)
4198         insert_str_child = insert_str_child_factory(LL[string_idx])
4199
4200         comma_idx = len(LL) - 1
4201         ends_with_comma = False
4202         if LL[comma_idx].type == token.COMMA:
4203             ends_with_comma = True
4204
4205         leaves_to_steal_comments_from = [LL[string_idx]]
4206         if ends_with_comma:
4207             leaves_to_steal_comments_from.append(LL[comma_idx])
4208
4209         # --- First Line
4210         first_line = line.clone()
4211         left_leaves = LL[:string_idx]
4212
4213         # We have to remember to account for (possibly invisible) LPAR and RPAR
4214         # leaves that already wrapped the target string. If these leaves do
4215         # exist, we will replace them with our own LPAR and RPAR leaves.
4216         old_parens_exist = False
4217         if left_leaves and left_leaves[-1].type == token.LPAR:
4218             old_parens_exist = True
4219             leaves_to_steal_comments_from.append(left_leaves[-1])
4220             left_leaves.pop()
4221
4222         append_leaves(first_line, line, left_leaves)
4223
4224         lpar_leaf = Leaf(token.LPAR, "(")
4225         if old_parens_exist:
4226             replace_child(LL[string_idx - 1], lpar_leaf)
4227         else:
4228             insert_str_child(lpar_leaf)
4229         first_line.append(lpar_leaf)
4230
4231         # We throw inline comments that were originally to the right of the
4232         # target string to the top line. They will now be shown to the right of
4233         # the LPAR.
4234         for leaf in leaves_to_steal_comments_from:
4235             for comment_leaf in line.comments_after(leaf):
4236                 first_line.append(comment_leaf, preformatted=True)
4237
4238         yield Ok(first_line)
4239
4240         # --- Middle (String) Line
4241         # We only need to yield one (possibly too long) string line, since the
4242         # `StringSplitter` will break it down further if necessary.
4243         string_value = LL[string_idx].value
4244         string_line = Line(
4245             depth=line.depth + 1,
4246             inside_brackets=True,
4247             should_explode=line.should_explode,
4248         )
4249         string_leaf = Leaf(token.STRING, string_value)
4250         insert_str_child(string_leaf)
4251         string_line.append(string_leaf)
4252
4253         old_rpar_leaf = None
4254         if is_valid_index(string_idx + 1):
4255             right_leaves = LL[string_idx + 1 :]
4256             if ends_with_comma:
4257                 right_leaves.pop()
4258
4259             if old_parens_exist:
4260                 assert (
4261                     right_leaves and right_leaves[-1].type == token.RPAR
4262                 ), "Apparently, old parentheses do NOT exist?!"
4263                 old_rpar_leaf = right_leaves.pop()
4264
4265             append_leaves(string_line, line, right_leaves)
4266
4267         yield Ok(string_line)
4268
4269         # --- Last Line
4270         last_line = line.clone()
4271         last_line.bracket_tracker = first_line.bracket_tracker
4272
4273         new_rpar_leaf = Leaf(token.RPAR, ")")
4274         if old_rpar_leaf is not None:
4275             replace_child(old_rpar_leaf, new_rpar_leaf)
4276         else:
4277             insert_str_child(new_rpar_leaf)
4278         last_line.append(new_rpar_leaf)
4279
4280         # If the target string ended with a comma, we place this comma to the
4281         # right of the RPAR on the last line.
4282         if ends_with_comma:
4283             comma_leaf = Leaf(token.COMMA, ",")
4284             replace_child(LL[comma_idx], comma_leaf)
4285             last_line.append(comma_leaf)
4286
4287         yield Ok(last_line)
4288
4289
4290 class StringParser:
4291     """
4292     A state machine that aids in parsing a string's "trailer", which can be
4293     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4294     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4295     varY)`).
4296
4297     NOTE: A new StringParser object MUST be instantiated for each string
4298     trailer we need to parse.
4299
4300     Examples:
4301         We shall assume that `line` equals the `Line` object that corresponds
4302         to the following line of python code:
4303         ```
4304         x = "Some {}.".format("String") + some_other_string
4305         ```
4306
4307         Furthermore, we will assume that `string_idx` is some index such that:
4308         ```
4309         assert line.leaves[string_idx].value == "Some {}."
4310         ```
4311
4312         The following code snippet then holds:
4313         ```
4314         string_parser = StringParser()
4315         idx = string_parser.parse(line.leaves, string_idx)
4316         assert line.leaves[idx].type == token.PLUS
4317         ```
4318     """
4319
4320     DEFAULT_TOKEN = -1
4321
4322     # String Parser States
4323     START = 1
4324     DOT = 2
4325     NAME = 3
4326     PERCENT = 4
4327     SINGLE_FMT_ARG = 5
4328     LPAR = 6
4329     RPAR = 7
4330     DONE = 8
4331
4332     # Lookup Table for Next State
4333     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4334         # A string trailer may start with '.' OR '%'.
4335         (START, token.DOT): DOT,
4336         (START, token.PERCENT): PERCENT,
4337         (START, DEFAULT_TOKEN): DONE,
4338         # A '.' MUST be followed by an attribute or method name.
4339         (DOT, token.NAME): NAME,
4340         # A method name MUST be followed by an '(', whereas an attribute name
4341         # is the last symbol in the string trailer.
4342         (NAME, token.LPAR): LPAR,
4343         (NAME, DEFAULT_TOKEN): DONE,
4344         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4345         # string or variable name).
4346         (PERCENT, token.LPAR): LPAR,
4347         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4348         # If a '%' symbol is followed by a single argument, that argument is
4349         # the last leaf in the string trailer.
4350         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4351         # If present, a ')' symbol is the last symbol in a string trailer.
4352         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4353         # since they are treated as a special case by the parsing logic in this
4354         # classes' implementation.)
4355         (RPAR, DEFAULT_TOKEN): DONE,
4356     }
4357
4358     def __init__(self) -> None:
4359         self._state = self.START
4360         self._unmatched_lpars = 0
4361
4362     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4363         """
4364         Pre-conditions:
4365             * @leaves[@string_idx].type == token.STRING
4366
4367         Returns:
4368             The index directly after the last leaf which is apart of the string
4369             trailer, if a "trailer" exists.
4370                 OR
4371             @string_idx + 1, if no string "trailer" exists.
4372         """
4373         assert leaves[string_idx].type == token.STRING
4374
4375         idx = string_idx + 1
4376         while idx < len(leaves) and self._next_state(leaves[idx]):
4377             idx += 1
4378         return idx
4379
4380     def _next_state(self, leaf: Leaf) -> bool:
4381         """
4382         Pre-conditions:
4383             * On the first call to this function, @leaf MUST be the leaf that
4384             was directly after the string leaf in question (e.g. if our target
4385             string is `line.leaves[i]` then the first call to this method must
4386             be `line.leaves[i + 1]`).
4387             * On the next call to this function, the leaf parameter passed in
4388             MUST be the leaf directly following @leaf.
4389
4390         Returns:
4391             True iff @leaf is apart of the string's trailer.
4392         """
4393         # We ignore empty LPAR or RPAR leaves.
4394         if is_empty_par(leaf):
4395             return True
4396
4397         next_token = leaf.type
4398         if next_token == token.LPAR:
4399             self._unmatched_lpars += 1
4400
4401         current_state = self._state
4402
4403         # The LPAR parser state is a special case. We will return True until we
4404         # find the matching RPAR token.
4405         if current_state == self.LPAR:
4406             if next_token == token.RPAR:
4407                 self._unmatched_lpars -= 1
4408                 if self._unmatched_lpars == 0:
4409                     self._state = self.RPAR
4410         # Otherwise, we use a lookup table to determine the next state.
4411         else:
4412             # If the lookup table matches the current state to the next
4413             # token, we use the lookup table.
4414             if (current_state, next_token) in self._goto:
4415                 self._state = self._goto[current_state, next_token]
4416             else:
4417                 # Otherwise, we check if a the current state was assigned a
4418                 # default.
4419                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4420                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4421                 # If no default has been assigned, then this parser has a logic
4422                 # error.
4423                 else:
4424                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4425
4426             if self._state == self.DONE:
4427                 return False
4428
4429         return True
4430
4431
4432 def TErr(err_msg: str) -> Err[CannotTransform]:
4433     """(T)ransform Err
4434
4435     Convenience function used when working with the TResult type.
4436     """
4437     cant_transform = CannotTransform(err_msg)
4438     return Err(cant_transform)
4439
4440
4441 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4442     """
4443     Returns:
4444         True iff one of the comments in @comment_list is a pragma used by one
4445         of the more common static analysis tools for python (e.g. mypy, flake8,
4446         pylint).
4447     """
4448     for comment in comment_list:
4449         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4450             return True
4451
4452     return False
4453
4454
4455 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4456     """
4457     Factory for a convenience function that is used to orphan @string_leaf
4458     and then insert multiple new leaves into the same part of the node
4459     structure that @string_leaf had originally occupied.
4460
4461     Examples:
4462         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4463         string_leaf.parent`. Assume the node `N` has the following
4464         original structure:
4465
4466         Node(
4467             expr_stmt, [
4468                 Leaf(NAME, 'x'),
4469                 Leaf(EQUAL, '='),
4470                 Leaf(STRING, '"foo"'),
4471             ]
4472         )
4473
4474         We then run the code snippet shown below.
4475         ```
4476         insert_str_child = insert_str_child_factory(string_leaf)
4477
4478         lpar = Leaf(token.LPAR, '(')
4479         insert_str_child(lpar)
4480
4481         bar = Leaf(token.STRING, '"bar"')
4482         insert_str_child(bar)
4483
4484         rpar = Leaf(token.RPAR, ')')
4485         insert_str_child(rpar)
4486         ```
4487
4488         After which point, it follows that `string_leaf.parent is None` and
4489         the node `N` now has the following structure:
4490
4491         Node(
4492             expr_stmt, [
4493                 Leaf(NAME, 'x'),
4494                 Leaf(EQUAL, '='),
4495                 Leaf(LPAR, '('),
4496                 Leaf(STRING, '"bar"'),
4497                 Leaf(RPAR, ')'),
4498             ]
4499         )
4500     """
4501     string_parent = string_leaf.parent
4502     string_child_idx = string_leaf.remove()
4503
4504     def insert_str_child(child: LN) -> None:
4505         nonlocal string_child_idx
4506
4507         assert string_parent is not None
4508         assert string_child_idx is not None
4509
4510         string_parent.insert_child(string_child_idx, child)
4511         string_child_idx += 1
4512
4513     return insert_str_child
4514
4515
4516 def has_triple_quotes(string: str) -> bool:
4517     """
4518     Returns:
4519         True iff @string starts with three quotation characters.
4520     """
4521     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4522     return raw_string[:3] in {'"""', "'''"}
4523
4524
4525 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4526     """
4527     Returns:
4528         @node.parent.type, if @node is not None and has a parent.
4529             OR
4530         None, otherwise.
4531     """
4532     if node is None or node.parent is None:
4533         return None
4534
4535     return node.parent.type
4536
4537
4538 def is_empty_par(leaf: Leaf) -> bool:
4539     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4540
4541
4542 def is_empty_lpar(leaf: Leaf) -> bool:
4543     return leaf.type == token.LPAR and leaf.value == ""
4544
4545
4546 def is_empty_rpar(leaf: Leaf) -> bool:
4547     return leaf.type == token.RPAR and leaf.value == ""
4548
4549
4550 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4551     """
4552     Examples:
4553         ```
4554         my_list = [1, 2, 3]
4555
4556         is_valid_index = is_valid_index_factory(my_list)
4557
4558         assert is_valid_index(0)
4559         assert is_valid_index(2)
4560
4561         assert not is_valid_index(3)
4562         assert not is_valid_index(-1)
4563         ```
4564     """
4565
4566     def is_valid_index(idx: int) -> bool:
4567         """
4568         Returns:
4569             True iff @idx is positive AND seq[@idx] does NOT raise an
4570             IndexError.
4571         """
4572         return 0 <= idx < len(seq)
4573
4574     return is_valid_index
4575
4576
4577 def line_to_string(line: Line) -> str:
4578     """Returns the string representation of @line.
4579
4580     WARNING: This is known to be computationally expensive.
4581     """
4582     return str(line).strip("\n")
4583
4584
4585 def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
4586     """
4587     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4588     underlying Node structure where appropriate.
4589
4590     All of the leaves in @leaves are duplicated. The duplicates are then
4591     appended to @new_line and used to replace their originals in the underlying
4592     Node structure. Any comments attached to the old leaves are reattached to
4593     the new leaves.
4594
4595     Pre-conditions:
4596         set(@leaves) is a subset of set(@old_line.leaves).
4597     """
4598     for old_leaf in leaves:
4599         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4600         replace_child(old_leaf, new_leaf)
4601         new_line.append(new_leaf)
4602
4603         for comment_leaf in old_line.comments_after(old_leaf):
4604             new_line.append(comment_leaf, preformatted=True)
4605
4606
4607 def replace_child(old_child: LN, new_child: LN) -> None:
4608     """
4609     Side Effects:
4610         * If @old_child.parent is set, replace @old_child with @new_child in
4611         @old_child's underlying Node structure.
4612             OR
4613         * Otherwise, this function does nothing.
4614     """
4615     parent = old_child.parent
4616     if not parent:
4617         return
4618
4619     child_idx = old_child.remove()
4620     if child_idx is not None:
4621         parent.insert_child(child_idx, new_child)
4622
4623
4624 def get_string_prefix(string: str) -> str:
4625     """
4626     Pre-conditions:
4627         * assert_is_leaf_string(@string)
4628
4629     Returns:
4630         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4631     """
4632     assert_is_leaf_string(string)
4633
4634     prefix = ""
4635     prefix_idx = 0
4636     while string[prefix_idx] in STRING_PREFIX_CHARS:
4637         prefix += string[prefix_idx].lower()
4638         prefix_idx += 1
4639
4640     return prefix
4641
4642
4643 def assert_is_leaf_string(string: str) -> None:
4644     """
4645     Checks the pre-condition that @string has the format that you would expect
4646     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4647     token.STRING`. A more precise description of the pre-conditions that are
4648     checked are listed below.
4649
4650     Pre-conditions:
4651         * @string starts with either ', ", <prefix>', or <prefix>" where
4652         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4653         * @string ends with a quote character (' or ").
4654
4655     Raises:
4656         AssertionError(...) if the pre-conditions listed above are not
4657         satisfied.
4658     """
4659     dquote_idx = string.find('"')
4660     squote_idx = string.find("'")
4661     if -1 in [dquote_idx, squote_idx]:
4662         quote_idx = max(dquote_idx, squote_idx)
4663     else:
4664         quote_idx = min(squote_idx, dquote_idx)
4665
4666     assert (
4667         0 <= quote_idx < len(string) - 1
4668     ), f"{string!r} is missing a starting quote character (' or \")."
4669     assert string[-1] in (
4670         "'",
4671         '"',
4672     ), f"{string!r} is missing an ending quote character (' or \")."
4673     assert set(string[:quote_idx]).issubset(
4674         set(STRING_PREFIX_CHARS)
4675     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4676
4677
4678 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4679     """Split line into many lines, starting with the first matching bracket pair.
4680
4681     Note: this usually looks weird, only use this for function definitions.
4682     Prefer RHS otherwise.  This is why this function is not symmetrical with
4683     :func:`right_hand_split` which also handles optional parentheses.
4684     """
4685     tail_leaves: List[Leaf] = []
4686     body_leaves: List[Leaf] = []
4687     head_leaves: List[Leaf] = []
4688     current_leaves = head_leaves
4689     matching_bracket: Optional[Leaf] = None
4690     for leaf in line.leaves:
4691         if (
4692             current_leaves is body_leaves
4693             and leaf.type in CLOSING_BRACKETS
4694             and leaf.opening_bracket is matching_bracket
4695         ):
4696             current_leaves = tail_leaves if body_leaves else head_leaves
4697         current_leaves.append(leaf)
4698         if current_leaves is head_leaves:
4699             if leaf.type in OPENING_BRACKETS:
4700                 matching_bracket = leaf
4701                 current_leaves = body_leaves
4702     if not matching_bracket:
4703         raise CannotSplit("No brackets found")
4704
4705     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4706     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4707     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4708     bracket_split_succeeded_or_raise(head, body, tail)
4709     for result in (head, body, tail):
4710         if result:
4711             yield result
4712
4713
4714 def right_hand_split(
4715     line: Line,
4716     line_length: int,
4717     features: Collection[Feature] = (),
4718     omit: Collection[LeafID] = (),
4719 ) -> Iterator[Line]:
4720     """Split line into many lines, starting with the last matching bracket pair.
4721
4722     If the split was by optional parentheses, attempt splitting without them, too.
4723     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4724     this split.
4725
4726     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4727     """
4728     tail_leaves: List[Leaf] = []
4729     body_leaves: List[Leaf] = []
4730     head_leaves: List[Leaf] = []
4731     current_leaves = tail_leaves
4732     opening_bracket: Optional[Leaf] = None
4733     closing_bracket: Optional[Leaf] = None
4734     for leaf in reversed(line.leaves):
4735         if current_leaves is body_leaves:
4736             if leaf is opening_bracket:
4737                 current_leaves = head_leaves if body_leaves else tail_leaves
4738         current_leaves.append(leaf)
4739         if current_leaves is tail_leaves:
4740             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4741                 opening_bracket = leaf.opening_bracket
4742                 closing_bracket = leaf
4743                 current_leaves = body_leaves
4744     if not (opening_bracket and closing_bracket and head_leaves):
4745         # If there is no opening or closing_bracket that means the split failed and
4746         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4747         # the matching `opening_bracket` wasn't available on `line` anymore.
4748         raise CannotSplit("No brackets found")
4749
4750     tail_leaves.reverse()
4751     body_leaves.reverse()
4752     head_leaves.reverse()
4753     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4754     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4755     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4756     bracket_split_succeeded_or_raise(head, body, tail)
4757     if (
4758         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4759         # the opening bracket is an optional paren
4760         and opening_bracket.type == token.LPAR
4761         and not opening_bracket.value
4762         # the closing bracket is an optional paren
4763         and closing_bracket.type == token.RPAR
4764         and not closing_bracket.value
4765         # it's not an import (optional parens are the only thing we can split on
4766         # in this case; attempting a split without them is a waste of time)
4767         and not line.is_import
4768         # there are no standalone comments in the body
4769         and not body.contains_standalone_comments(0)
4770         # and we can actually remove the parens
4771         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4772     ):
4773         omit = {id(closing_bracket), *omit}
4774         try:
4775             yield from right_hand_split(line, line_length, features=features, omit=omit)
4776             return
4777
4778         except CannotSplit:
4779             if not (
4780                 can_be_split(body)
4781                 or is_line_short_enough(body, line_length=line_length)
4782             ):
4783                 raise CannotSplit(
4784                     "Splitting failed, body is still too long and can't be split."
4785                 )
4786
4787             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
4788                 raise CannotSplit(
4789                     "The current optional pair of parentheses is bound to fail to"
4790                     " satisfy the splitting algorithm because the head or the tail"
4791                     " contains multiline strings which by definition never fit one"
4792                     " line."
4793                 )
4794
4795     ensure_visible(opening_bracket)
4796     ensure_visible(closing_bracket)
4797     for result in (head, body, tail):
4798         if result:
4799             yield result
4800
4801
4802 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
4803     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
4804
4805     Do nothing otherwise.
4806
4807     A left- or right-hand split is based on a pair of brackets. Content before
4808     (and including) the opening bracket is left on one line, content inside the
4809     brackets is put on a separate line, and finally content starting with and
4810     following the closing bracket is put on a separate line.
4811
4812     Those are called `head`, `body`, and `tail`, respectively. If the split
4813     produced the same line (all content in `head`) or ended up with an empty `body`
4814     and the `tail` is just the closing bracket, then it's considered failed.
4815     """
4816     tail_len = len(str(tail).strip())
4817     if not body:
4818         if tail_len == 0:
4819             raise CannotSplit("Splitting brackets produced the same line")
4820
4821         elif tail_len < 3:
4822             raise CannotSplit(
4823                 f"Splitting brackets on an empty body to save {tail_len} characters is"
4824                 " not worth it"
4825             )
4826
4827
4828 def bracket_split_build_line(
4829     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
4830 ) -> Line:
4831     """Return a new line with given `leaves` and respective comments from `original`.
4832
4833     If `is_body` is True, the result line is one-indented inside brackets and as such
4834     has its first leaf's prefix normalized and a trailing comma added when expected.
4835     """
4836     result = Line(depth=original.depth)
4837     if is_body:
4838         result.inside_brackets = True
4839         result.depth += 1
4840         if leaves:
4841             # Since body is a new indent level, remove spurious leading whitespace.
4842             normalize_prefix(leaves[0], inside_brackets=True)
4843             # Ensure a trailing comma for imports and standalone function arguments, but
4844             # be careful not to add one after any comments or within type annotations.
4845             no_commas = (
4846                 original.is_def
4847                 and opening_bracket.value == "("
4848                 and not any(leaf.type == token.COMMA for leaf in leaves)
4849             )
4850
4851             if original.is_import or no_commas:
4852                 for i in range(len(leaves) - 1, -1, -1):
4853                     if leaves[i].type == STANDALONE_COMMENT:
4854                         continue
4855
4856                     if leaves[i].type != token.COMMA:
4857                         new_comma = Leaf(token.COMMA, ",")
4858                         new_comma.was_checked = True
4859                         leaves.insert(i + 1, new_comma)
4860                     break
4861
4862     # Populate the line
4863     for leaf in leaves:
4864         result.append(leaf, preformatted=True)
4865         for comment_after in original.comments_after(leaf):
4866             result.append(comment_after, preformatted=True)
4867     if is_body and should_split_body_explode(result, opening_bracket):
4868         result.should_explode = True
4869     return result
4870
4871
4872 def dont_increase_indentation(split_func: Transformer) -> Transformer:
4873     """Normalize prefix of the first leaf in every line returned by `split_func`.
4874
4875     This is a decorator over relevant split functions.
4876     """
4877
4878     @wraps(split_func)
4879     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4880         for line in split_func(line, features):
4881             normalize_prefix(line.leaves[0], inside_brackets=True)
4882             yield line
4883
4884     return split_wrapper
4885
4886
4887 @dont_increase_indentation
4888 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
4889     """Split according to delimiters of the highest priority.
4890
4891     If the appropriate Features are given, the split will add trailing commas
4892     also in function signatures and calls that contain `*` and `**`.
4893     """
4894     try:
4895         last_leaf = line.leaves[-1]
4896     except IndexError:
4897         raise CannotSplit("Line empty")
4898
4899     bt = line.bracket_tracker
4900     try:
4901         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
4902     except ValueError:
4903         raise CannotSplit("No delimiters found")
4904
4905     if delimiter_priority == DOT_PRIORITY:
4906         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
4907             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
4908
4909     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4910     lowest_depth = sys.maxsize
4911     trailing_comma_safe = True
4912
4913     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4914         """Append `leaf` to current line or to new line if appending impossible."""
4915         nonlocal current_line
4916         try:
4917             current_line.append_safe(leaf, preformatted=True)
4918         except ValueError:
4919             yield current_line
4920
4921             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4922             current_line.append(leaf)
4923
4924     for leaf in line.leaves:
4925         yield from append_to_line(leaf)
4926
4927         for comment_after in line.comments_after(leaf):
4928             yield from append_to_line(comment_after)
4929
4930         lowest_depth = min(lowest_depth, leaf.bracket_depth)
4931         if leaf.bracket_depth == lowest_depth:
4932             if is_vararg(leaf, within={syms.typedargslist}):
4933                 trailing_comma_safe = (
4934                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
4935                 )
4936             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
4937                 trailing_comma_safe = (
4938                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
4939                 )
4940
4941         leaf_priority = bt.delimiters.get(id(leaf))
4942         if leaf_priority == delimiter_priority:
4943             yield current_line
4944
4945             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4946     if current_line:
4947         if (
4948             trailing_comma_safe
4949             and delimiter_priority == COMMA_PRIORITY
4950             and current_line.leaves[-1].type != token.COMMA
4951             and current_line.leaves[-1].type != STANDALONE_COMMENT
4952         ):
4953             new_comma = Leaf(token.COMMA, ",")
4954             new_comma.was_checked = True
4955             current_line.append(new_comma)
4956         yield current_line
4957
4958
4959 @dont_increase_indentation
4960 def standalone_comment_split(
4961     line: Line, features: Collection[Feature] = ()
4962 ) -> Iterator[Line]:
4963     """Split standalone comments from the rest of the line."""
4964     if not line.contains_standalone_comments(0):
4965         raise CannotSplit("Line does not have any standalone comments")
4966
4967     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4968
4969     def append_to_line(leaf: Leaf) -> Iterator[Line]:
4970         """Append `leaf` to current line or to new line if appending impossible."""
4971         nonlocal current_line
4972         try:
4973             current_line.append_safe(leaf, preformatted=True)
4974         except ValueError:
4975             yield current_line
4976
4977             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
4978             current_line.append(leaf)
4979
4980     for leaf in line.leaves:
4981         yield from append_to_line(leaf)
4982
4983         for comment_after in line.comments_after(leaf):
4984             yield from append_to_line(comment_after)
4985
4986     if current_line:
4987         yield current_line
4988
4989
4990 def is_import(leaf: Leaf) -> bool:
4991     """Return True if the given leaf starts an import statement."""
4992     p = leaf.parent
4993     t = leaf.type
4994     v = leaf.value
4995     return bool(
4996         t == token.NAME
4997         and (
4998             (v == "import" and p and p.type == syms.import_name)
4999             or (v == "from" and p and p.type == syms.import_from)
5000         )
5001     )
5002
5003
5004 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5005     """Return True if the given leaf is a special comment.
5006     Only returns true for type comments for now."""
5007     t = leaf.type
5008     v = leaf.value
5009     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5010
5011
5012 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5013     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5014     else.
5015
5016     Note: don't use backslashes for formatting or you'll lose your voting rights.
5017     """
5018     if not inside_brackets:
5019         spl = leaf.prefix.split("#")
5020         if "\\" not in spl[0]:
5021             nl_count = spl[-1].count("\n")
5022             if len(spl) > 1:
5023                 nl_count -= 1
5024             leaf.prefix = "\n" * nl_count
5025             return
5026
5027     leaf.prefix = ""
5028
5029
5030 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5031     """Make all string prefixes lowercase.
5032
5033     If remove_u_prefix is given, also removes any u prefix from the string.
5034
5035     Note: Mutates its argument.
5036     """
5037     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5038     assert match is not None, f"failed to match string {leaf.value!r}"
5039     orig_prefix = match.group(1)
5040     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5041     if remove_u_prefix:
5042         new_prefix = new_prefix.replace("u", "")
5043     leaf.value = f"{new_prefix}{match.group(2)}"
5044
5045
5046 def normalize_string_quotes(leaf: Leaf) -> None:
5047     """Prefer double quotes but only if it doesn't cause more escaping.
5048
5049     Adds or removes backslashes as appropriate. Doesn't parse and fix
5050     strings nested in f-strings (yet).
5051
5052     Note: Mutates its argument.
5053     """
5054     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5055     if value[:3] == '"""':
5056         return
5057
5058     elif value[:3] == "'''":
5059         orig_quote = "'''"
5060         new_quote = '"""'
5061     elif value[0] == '"':
5062         orig_quote = '"'
5063         new_quote = "'"
5064     else:
5065         orig_quote = "'"
5066         new_quote = '"'
5067     first_quote_pos = leaf.value.find(orig_quote)
5068     if first_quote_pos == -1:
5069         return  # There's an internal error
5070
5071     prefix = leaf.value[:first_quote_pos]
5072     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5073     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5074     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5075     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5076     if "r" in prefix.casefold():
5077         if unescaped_new_quote.search(body):
5078             # There's at least one unescaped new_quote in this raw string
5079             # so converting is impossible
5080             return
5081
5082         # Do not introduce or remove backslashes in raw strings
5083         new_body = body
5084     else:
5085         # remove unnecessary escapes
5086         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5087         if body != new_body:
5088             # Consider the string without unnecessary escapes as the original
5089             body = new_body
5090             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5091         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5092         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5093     if "f" in prefix.casefold():
5094         matches = re.findall(
5095             r"""
5096             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5097                 ([^{].*?)  # contents of the brackets except if begins with {{
5098             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5099             """,
5100             new_body,
5101             re.VERBOSE,
5102         )
5103         for m in matches:
5104             if "\\" in str(m):
5105                 # Do not introduce backslashes in interpolated expressions
5106                 return
5107
5108     if new_quote == '"""' and new_body[-1:] == '"':
5109         # edge case:
5110         new_body = new_body[:-1] + '\\"'
5111     orig_escape_count = body.count("\\")
5112     new_escape_count = new_body.count("\\")
5113     if new_escape_count > orig_escape_count:
5114         return  # Do not introduce more escaping
5115
5116     if new_escape_count == orig_escape_count and orig_quote == '"':
5117         return  # Prefer double quotes
5118
5119     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5120
5121
5122 def normalize_numeric_literal(leaf: Leaf) -> None:
5123     """Normalizes numeric (float, int, and complex) literals.
5124
5125     All letters used in the representation are normalized to lowercase (except
5126     in Python 2 long literals).
5127     """
5128     text = leaf.value.lower()
5129     if text.startswith(("0o", "0b")):
5130         # Leave octal and binary literals alone.
5131         pass
5132     elif text.startswith("0x"):
5133         # Change hex literals to upper case.
5134         before, after = text[:2], text[2:]
5135         text = f"{before}{after.upper()}"
5136     elif "e" in text:
5137         before, after = text.split("e")
5138         sign = ""
5139         if after.startswith("-"):
5140             after = after[1:]
5141             sign = "-"
5142         elif after.startswith("+"):
5143             after = after[1:]
5144         before = format_float_or_int_string(before)
5145         text = f"{before}e{sign}{after}"
5146     elif text.endswith(("j", "l")):
5147         number = text[:-1]
5148         suffix = text[-1]
5149         # Capitalize in "2L" because "l" looks too similar to "1".
5150         if suffix == "l":
5151             suffix = "L"
5152         text = f"{format_float_or_int_string(number)}{suffix}"
5153     else:
5154         text = format_float_or_int_string(text)
5155     leaf.value = text
5156
5157
5158 def format_float_or_int_string(text: str) -> str:
5159     """Formats a float string like "1.0"."""
5160     if "." not in text:
5161         return text
5162
5163     before, after = text.split(".")
5164     return f"{before or 0}.{after or 0}"
5165
5166
5167 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5168     """Make existing optional parentheses invisible or create new ones.
5169
5170     `parens_after` is a set of string leaf values immediately after which parens
5171     should be put.
5172
5173     Standardizes on visible parentheses for single-element tuples, and keeps
5174     existing visible parentheses for other tuples and generator expressions.
5175     """
5176     for pc in list_comments(node.prefix, is_endmarker=False):
5177         if pc.value in FMT_OFF:
5178             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5179             return
5180     check_lpar = False
5181     for index, child in enumerate(list(node.children)):
5182         # Fixes a bug where invisible parens are not properly stripped from
5183         # assignment statements that contain type annotations.
5184         if isinstance(child, Node) and child.type == syms.annassign:
5185             normalize_invisible_parens(child, parens_after=parens_after)
5186
5187         # Add parentheses around long tuple unpacking in assignments.
5188         if (
5189             index == 0
5190             and isinstance(child, Node)
5191             and child.type == syms.testlist_star_expr
5192         ):
5193             check_lpar = True
5194
5195         if check_lpar:
5196             if is_walrus_assignment(child):
5197                 continue
5198
5199             if child.type == syms.atom:
5200                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5201                     wrap_in_parentheses(node, child, visible=False)
5202             elif is_one_tuple(child):
5203                 wrap_in_parentheses(node, child, visible=True)
5204             elif node.type == syms.import_from:
5205                 # "import from" nodes store parentheses directly as part of
5206                 # the statement
5207                 if child.type == token.LPAR:
5208                     # make parentheses invisible
5209                     child.value = ""  # type: ignore
5210                     node.children[-1].value = ""  # type: ignore
5211                 elif child.type != token.STAR:
5212                     # insert invisible parentheses
5213                     node.insert_child(index, Leaf(token.LPAR, ""))
5214                     node.append_child(Leaf(token.RPAR, ""))
5215                 break
5216
5217             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5218                 wrap_in_parentheses(node, child, visible=False)
5219
5220         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5221
5222
5223 def normalize_fmt_off(node: Node) -> None:
5224     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5225     try_again = True
5226     while try_again:
5227         try_again = convert_one_fmt_off_pair(node)
5228
5229
5230 def convert_one_fmt_off_pair(node: Node) -> bool:
5231     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5232
5233     Returns True if a pair was converted.
5234     """
5235     for leaf in node.leaves():
5236         previous_consumed = 0
5237         for comment in list_comments(leaf.prefix, is_endmarker=False):
5238             if comment.value in FMT_OFF:
5239                 # We only want standalone comments. If there's no previous leaf or
5240                 # the previous leaf is indentation, it's a standalone comment in
5241                 # disguise.
5242                 if comment.type != STANDALONE_COMMENT:
5243                     prev = preceding_leaf(leaf)
5244                     if prev and prev.type not in WHITESPACE:
5245                         continue
5246
5247                 ignored_nodes = list(generate_ignored_nodes(leaf))
5248                 if not ignored_nodes:
5249                     continue
5250
5251                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5252                 parent = first.parent
5253                 prefix = first.prefix
5254                 first.prefix = prefix[comment.consumed :]
5255                 hidden_value = (
5256                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
5257                 )
5258                 if hidden_value.endswith("\n"):
5259                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
5260                     # leaf (possibly followed by a DEDENT).
5261                     hidden_value = hidden_value[:-1]
5262                 first_idx: Optional[int] = None
5263                 for ignored in ignored_nodes:
5264                     index = ignored.remove()
5265                     if first_idx is None:
5266                         first_idx = index
5267                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5268                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5269                 parent.insert_child(
5270                     first_idx,
5271                     Leaf(
5272                         STANDALONE_COMMENT,
5273                         hidden_value,
5274                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5275                     ),
5276                 )
5277                 return True
5278
5279             previous_consumed = comment.consumed
5280
5281     return False
5282
5283
5284 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
5285     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5286
5287     Stops at the end of the block.
5288     """
5289     container: Optional[LN] = container_of(leaf)
5290     while container is not None and container.type != token.ENDMARKER:
5291         if is_fmt_on(container):
5292             return
5293
5294         # fix for fmt: on in children
5295         if contains_fmt_on_at_column(container, leaf.column):
5296             for child in container.children:
5297                 if contains_fmt_on_at_column(child, leaf.column):
5298                     return
5299                 yield child
5300         else:
5301             yield container
5302             container = container.next_sibling
5303
5304
5305 def is_fmt_on(container: LN) -> bool:
5306     """Determine whether formatting is switched on within a container.
5307     Determined by whether the last `# fmt:` comment is `on` or `off`.
5308     """
5309     fmt_on = False
5310     for comment in list_comments(container.prefix, is_endmarker=False):
5311         if comment.value in FMT_ON:
5312             fmt_on = True
5313         elif comment.value in FMT_OFF:
5314             fmt_on = False
5315     return fmt_on
5316
5317
5318 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5319     """Determine if children at a given column have formatting switched on."""
5320     for child in container.children:
5321         if (
5322             isinstance(child, Node)
5323             and first_leaf_column(child) == column
5324             or isinstance(child, Leaf)
5325             and child.column == column
5326         ):
5327             if is_fmt_on(child):
5328                 return True
5329
5330     return False
5331
5332
5333 def first_leaf_column(node: Node) -> Optional[int]:
5334     """Returns the column of the first leaf child of a node."""
5335     for child in node.children:
5336         if isinstance(child, Leaf):
5337             return child.column
5338     return None
5339
5340
5341 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5342     """If it's safe, make the parens in the atom `node` invisible, recursively.
5343     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5344     as they are redundant.
5345
5346     Returns whether the node should itself be wrapped in invisible parentheses.
5347
5348     """
5349     if (
5350         node.type != syms.atom
5351         or is_empty_tuple(node)
5352         or is_one_tuple(node)
5353         or (is_yield(node) and parent.type != syms.expr_stmt)
5354         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5355     ):
5356         return False
5357
5358     first = node.children[0]
5359     last = node.children[-1]
5360     if first.type == token.LPAR and last.type == token.RPAR:
5361         middle = node.children[1]
5362         # make parentheses invisible
5363         first.value = ""  # type: ignore
5364         last.value = ""  # type: ignore
5365         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5366
5367         if is_atom_with_invisible_parens(middle):
5368             # Strip the invisible parens from `middle` by replacing
5369             # it with the child in-between the invisible parens
5370             middle.replace(middle.children[1])
5371
5372         return False
5373
5374     return True
5375
5376
5377 def is_atom_with_invisible_parens(node: LN) -> bool:
5378     """Given a `LN`, determines whether it's an atom `node` with invisible
5379     parens. Useful in dedupe-ing and normalizing parens.
5380     """
5381     if isinstance(node, Leaf) or node.type != syms.atom:
5382         return False
5383
5384     first, last = node.children[0], node.children[-1]
5385     return (
5386         isinstance(first, Leaf)
5387         and first.type == token.LPAR
5388         and first.value == ""
5389         and isinstance(last, Leaf)
5390         and last.type == token.RPAR
5391         and last.value == ""
5392     )
5393
5394
5395 def is_empty_tuple(node: LN) -> bool:
5396     """Return True if `node` holds an empty tuple."""
5397     return (
5398         node.type == syms.atom
5399         and len(node.children) == 2
5400         and node.children[0].type == token.LPAR
5401         and node.children[1].type == token.RPAR
5402     )
5403
5404
5405 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5406     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5407
5408     Parenthesis can be optional. Returns None otherwise"""
5409     if len(node.children) != 3:
5410         return None
5411
5412     lpar, wrapped, rpar = node.children
5413     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5414         return None
5415
5416     return wrapped
5417
5418
5419 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5420     """Wrap `child` in parentheses.
5421
5422     This replaces `child` with an atom holding the parentheses and the old
5423     child.  That requires moving the prefix.
5424
5425     If `visible` is False, the leaves will be valueless (and thus invisible).
5426     """
5427     lpar = Leaf(token.LPAR, "(" if visible else "")
5428     rpar = Leaf(token.RPAR, ")" if visible else "")
5429     prefix = child.prefix
5430     child.prefix = ""
5431     index = child.remove() or 0
5432     new_child = Node(syms.atom, [lpar, child, rpar])
5433     new_child.prefix = prefix
5434     parent.insert_child(index, new_child)
5435
5436
5437 def is_one_tuple(node: LN) -> bool:
5438     """Return True if `node` holds a tuple with one element, with or without parens."""
5439     if node.type == syms.atom:
5440         gexp = unwrap_singleton_parenthesis(node)
5441         if gexp is None or gexp.type != syms.testlist_gexp:
5442             return False
5443
5444         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5445
5446     return (
5447         node.type in IMPLICIT_TUPLE
5448         and len(node.children) == 2
5449         and node.children[1].type == token.COMMA
5450     )
5451
5452
5453 def is_walrus_assignment(node: LN) -> bool:
5454     """Return True iff `node` is of the shape ( test := test )"""
5455     inner = unwrap_singleton_parenthesis(node)
5456     return inner is not None and inner.type == syms.namedexpr_test
5457
5458
5459 def is_yield(node: LN) -> bool:
5460     """Return True if `node` holds a `yield` or `yield from` expression."""
5461     if node.type == syms.yield_expr:
5462         return True
5463
5464     if node.type == token.NAME and node.value == "yield":  # type: ignore
5465         return True
5466
5467     if node.type != syms.atom:
5468         return False
5469
5470     if len(node.children) != 3:
5471         return False
5472
5473     lpar, expr, rpar = node.children
5474     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5475         return is_yield(expr)
5476
5477     return False
5478
5479
5480 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5481     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5482
5483     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5484     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5485     extended iterable unpacking (PEP 3132) and additional unpacking
5486     generalizations (PEP 448).
5487     """
5488     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5489         return False
5490
5491     p = leaf.parent
5492     if p.type == syms.star_expr:
5493         # Star expressions are also used as assignment targets in extended
5494         # iterable unpacking (PEP 3132).  See what its parent is instead.
5495         if not p.parent:
5496             return False
5497
5498         p = p.parent
5499
5500     return p.type in within
5501
5502
5503 def is_multiline_string(leaf: Leaf) -> bool:
5504     """Return True if `leaf` is a multiline string that actually spans many lines."""
5505     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5506
5507
5508 def is_stub_suite(node: Node) -> bool:
5509     """Return True if `node` is a suite with a stub body."""
5510     if (
5511         len(node.children) != 4
5512         or node.children[0].type != token.NEWLINE
5513         or node.children[1].type != token.INDENT
5514         or node.children[3].type != token.DEDENT
5515     ):
5516         return False
5517
5518     return is_stub_body(node.children[2])
5519
5520
5521 def is_stub_body(node: LN) -> bool:
5522     """Return True if `node` is a simple statement containing an ellipsis."""
5523     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5524         return False
5525
5526     if len(node.children) != 2:
5527         return False
5528
5529     child = node.children[0]
5530     return (
5531         child.type == syms.atom
5532         and len(child.children) == 3
5533         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5534     )
5535
5536
5537 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5538     """Return maximum delimiter priority inside `node`.
5539
5540     This is specific to atoms with contents contained in a pair of parentheses.
5541     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5542     """
5543     if node.type != syms.atom:
5544         return 0
5545
5546     first = node.children[0]
5547     last = node.children[-1]
5548     if not (first.type == token.LPAR and last.type == token.RPAR):
5549         return 0
5550
5551     bt = BracketTracker()
5552     for c in node.children[1:-1]:
5553         if isinstance(c, Leaf):
5554             bt.mark(c)
5555         else:
5556             for leaf in c.leaves():
5557                 bt.mark(leaf)
5558     try:
5559         return bt.max_delimiter_priority()
5560
5561     except ValueError:
5562         return 0
5563
5564
5565 def ensure_visible(leaf: Leaf) -> None:
5566     """Make sure parentheses are visible.
5567
5568     They could be invisible as part of some statements (see
5569     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5570     """
5571     if leaf.type == token.LPAR:
5572         leaf.value = "("
5573     elif leaf.type == token.RPAR:
5574         leaf.value = ")"
5575
5576
5577 def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
5578     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5579
5580     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5581         return False
5582
5583     # We're essentially checking if the body is delimited by commas and there's more
5584     # than one of them (we're excluding the trailing comma and if the delimiter priority
5585     # is still commas, that means there's more).
5586     exclude = set()
5587     pre_existing_trailing_comma = False
5588     try:
5589         last_leaf = line.leaves[-1]
5590         if last_leaf.type == token.COMMA:
5591             pre_existing_trailing_comma = not last_leaf.was_checked
5592             exclude.add(id(last_leaf))
5593         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5594     except (IndexError, ValueError):
5595         return False
5596
5597     return max_priority == COMMA_PRIORITY and (
5598         # always explode imports
5599         opening_bracket.parent.type in {syms.atom, syms.import_from}
5600         or pre_existing_trailing_comma
5601     )
5602
5603
5604 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5605     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5606     if opening.type != token.LPAR and closing.type != token.RPAR:
5607         return False
5608
5609     depth = closing.bracket_depth + 1
5610     for _opening_index, leaf in enumerate(leaves):
5611         if leaf is opening:
5612             break
5613
5614     else:
5615         raise LookupError("Opening paren not found in `leaves`")
5616
5617     commas = 0
5618     _opening_index += 1
5619     for leaf in leaves[_opening_index:]:
5620         if leaf is closing:
5621             break
5622
5623         bracket_depth = leaf.bracket_depth
5624         if bracket_depth == depth and leaf.type == token.COMMA:
5625             commas += 1
5626             if leaf.parent and leaf.parent.type in {
5627                 syms.arglist,
5628                 syms.typedargslist,
5629             }:
5630                 commas += 1
5631                 break
5632
5633     return commas < 2
5634
5635
5636 def get_features_used(node: Node) -> Set[Feature]:
5637     """Return a set of (relatively) new Python features used in this file.
5638
5639     Currently looking for:
5640     - f-strings;
5641     - underscores in numeric literals;
5642     - trailing commas after * or ** in function signatures and calls;
5643     - positional only arguments in function signatures and lambdas;
5644     """
5645     features: Set[Feature] = set()
5646     for n in node.pre_order():
5647         if n.type == token.STRING:
5648             value_head = n.value[:2]  # type: ignore
5649             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5650                 features.add(Feature.F_STRINGS)
5651
5652         elif n.type == token.NUMBER:
5653             if "_" in n.value:  # type: ignore
5654                 features.add(Feature.NUMERIC_UNDERSCORES)
5655
5656         elif n.type == token.SLASH:
5657             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5658                 features.add(Feature.POS_ONLY_ARGUMENTS)
5659
5660         elif n.type == token.COLONEQUAL:
5661             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5662
5663         elif (
5664             n.type in {syms.typedargslist, syms.arglist}
5665             and n.children
5666             and n.children[-1].type == token.COMMA
5667         ):
5668             if n.type == syms.typedargslist:
5669                 feature = Feature.TRAILING_COMMA_IN_DEF
5670             else:
5671                 feature = Feature.TRAILING_COMMA_IN_CALL
5672
5673             for ch in n.children:
5674                 if ch.type in STARS:
5675                     features.add(feature)
5676
5677                 if ch.type == syms.argument:
5678                     for argch in ch.children:
5679                         if argch.type in STARS:
5680                             features.add(feature)
5681
5682     return features
5683
5684
5685 def detect_target_versions(node: Node) -> Set[TargetVersion]:
5686     """Detect the version to target based on the nodes used."""
5687     features = get_features_used(node)
5688     return {
5689         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
5690     }
5691
5692
5693 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
5694     """Generate sets of closing bracket IDs that should be omitted in a RHS.
5695
5696     Brackets can be omitted if the entire trailer up to and including
5697     a preceding closing bracket fits in one line.
5698
5699     Yielded sets are cumulative (contain results of previous yields, too).  First
5700     set is empty, unless the line should explode, in which case bracket pairs until
5701     the one that needs to explode are omitted.
5702     """
5703
5704     omit: Set[LeafID] = set()
5705     if not line.should_explode:
5706         yield omit
5707
5708     length = 4 * line.depth
5709     opening_bracket: Optional[Leaf] = None
5710     closing_bracket: Optional[Leaf] = None
5711     inner_brackets: Set[LeafID] = set()
5712     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
5713         length += leaf_length
5714         if length > line_length:
5715             break
5716
5717         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
5718         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
5719             break
5720
5721         if opening_bracket:
5722             if leaf is opening_bracket:
5723                 opening_bracket = None
5724             elif leaf.type in CLOSING_BRACKETS:
5725                 prev = line.leaves[index - 1] if index > 0 else None
5726                 if (
5727                     line.should_explode
5728                     and prev
5729                     and prev.type == token.COMMA
5730                     and not prev.was_checked
5731                     and not is_one_tuple_between(
5732                         leaf.opening_bracket, leaf, line.leaves
5733                     )
5734                 ):
5735                     # Never omit bracket pairs with pre-existing trailing commas.
5736                     # We need to explode on those.
5737                     break
5738
5739                 inner_brackets.add(id(leaf))
5740         elif leaf.type in CLOSING_BRACKETS:
5741             prev = line.leaves[index - 1] if index > 0 else None
5742             if prev and prev.type in OPENING_BRACKETS:
5743                 # Empty brackets would fail a split so treat them as "inner"
5744                 # brackets (e.g. only add them to the `omit` set if another
5745                 # pair of brackets was good enough.
5746                 inner_brackets.add(id(leaf))
5747                 continue
5748
5749             if closing_bracket:
5750                 omit.add(id(closing_bracket))
5751                 omit.update(inner_brackets)
5752                 inner_brackets.clear()
5753                 yield omit
5754
5755             if (
5756                 line.should_explode
5757                 and prev
5758                 and prev.type == token.COMMA
5759                 and not prev.was_checked
5760                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
5761             ):
5762                 # Never omit bracket pairs with pre-existing trailing commas.
5763                 # We need to explode on those.
5764                 break
5765
5766             if leaf.value:
5767                 opening_bracket = leaf.opening_bracket
5768                 closing_bracket = leaf
5769
5770
5771 def get_future_imports(node: Node) -> Set[str]:
5772     """Return a set of __future__ imports in the file."""
5773     imports: Set[str] = set()
5774
5775     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
5776         for child in children:
5777             if isinstance(child, Leaf):
5778                 if child.type == token.NAME:
5779                     yield child.value
5780
5781             elif child.type == syms.import_as_name:
5782                 orig_name = child.children[0]
5783                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
5784                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
5785                 yield orig_name.value
5786
5787             elif child.type == syms.import_as_names:
5788                 yield from get_imports_from_children(child.children)
5789
5790             else:
5791                 raise AssertionError("Invalid syntax parsing imports")
5792
5793     for child in node.children:
5794         if child.type != syms.simple_stmt:
5795             break
5796
5797         first_child = child.children[0]
5798         if isinstance(first_child, Leaf):
5799             # Continue looking if we see a docstring; otherwise stop.
5800             if (
5801                 len(child.children) == 2
5802                 and first_child.type == token.STRING
5803                 and child.children[1].type == token.NEWLINE
5804             ):
5805                 continue
5806
5807             break
5808
5809         elif first_child.type == syms.import_from:
5810             module_name = first_child.children[1]
5811             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
5812                 break
5813
5814             imports |= set(get_imports_from_children(first_child.children[3:]))
5815         else:
5816             break
5817
5818     return imports
5819
5820
5821 @lru_cache()
5822 def get_gitignore(root: Path) -> PathSpec:
5823     """ Return a PathSpec matching gitignore content if present."""
5824     gitignore = root / ".gitignore"
5825     lines: List[str] = []
5826     if gitignore.is_file():
5827         with gitignore.open() as gf:
5828             lines = gf.readlines()
5829     return PathSpec.from_lines("gitwildmatch", lines)
5830
5831
5832 def normalize_path_maybe_ignore(
5833     path: Path, root: Path, report: "Report"
5834 ) -> Optional[str]:
5835     """Normalize `path`. May return `None` if `path` was ignored.
5836
5837     `report` is where "path ignored" output goes.
5838     """
5839     try:
5840         normalized_path = path.resolve().relative_to(root).as_posix()
5841     except OSError as e:
5842         report.path_ignored(path, f"cannot be read because {e}")
5843         return None
5844
5845     except ValueError:
5846         if path.is_symlink():
5847             report.path_ignored(path, f"is a symbolic link that points outside {root}")
5848             return None
5849
5850         raise
5851
5852     return normalized_path
5853
5854
5855 def gen_python_files(
5856     paths: Iterable[Path],
5857     root: Path,
5858     include: Optional[Pattern[str]],
5859     exclude: Pattern[str],
5860     force_exclude: Optional[Pattern[str]],
5861     report: "Report",
5862     gitignore: PathSpec,
5863 ) -> Iterator[Path]:
5864     """Generate all files under `path` whose paths are not excluded by the
5865     `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
5866
5867     Symbolic links pointing outside of the `root` directory are ignored.
5868
5869     `report` is where output about exclusions goes.
5870     """
5871     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
5872     for child in paths:
5873         normalized_path = normalize_path_maybe_ignore(child, root, report)
5874         if normalized_path is None:
5875             continue
5876
5877         # First ignore files matching .gitignore
5878         if gitignore.match_file(normalized_path):
5879             report.path_ignored(child, "matches the .gitignore file content")
5880             continue
5881
5882         # Then ignore with `--exclude` and `--force-exclude` options.
5883         normalized_path = "/" + normalized_path
5884         if child.is_dir():
5885             normalized_path += "/"
5886
5887         exclude_match = exclude.search(normalized_path) if exclude else None
5888         if exclude_match and exclude_match.group(0):
5889             report.path_ignored(child, "matches the --exclude regular expression")
5890             continue
5891
5892         force_exclude_match = (
5893             force_exclude.search(normalized_path) if force_exclude else None
5894         )
5895         if force_exclude_match and force_exclude_match.group(0):
5896             report.path_ignored(child, "matches the --force-exclude regular expression")
5897             continue
5898
5899         if child.is_dir():
5900             yield from gen_python_files(
5901                 child.iterdir(),
5902                 root,
5903                 include,
5904                 exclude,
5905                 force_exclude,
5906                 report,
5907                 gitignore,
5908             )
5909
5910         elif child.is_file():
5911             include_match = include.search(normalized_path) if include else True
5912             if include_match:
5913                 yield child
5914
5915
5916 @lru_cache()
5917 def find_project_root(srcs: Iterable[str]) -> Path:
5918     """Return a directory containing .git, .hg, or pyproject.toml.
5919
5920     That directory will be a common parent of all files and directories
5921     passed in `srcs`.
5922
5923     If no directory in the tree contains a marker that would specify it's the
5924     project root, the root of the file system is returned.
5925     """
5926     if not srcs:
5927         return Path("/").resolve()
5928
5929     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
5930
5931     # A list of lists of parents for each 'src'. 'src' is included as a
5932     # "parent" of itself if it is a directory
5933     src_parents = [
5934         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
5935     ]
5936
5937     common_base = max(
5938         set.intersection(*(set(parents) for parents in src_parents)),
5939         key=lambda path: path.parts,
5940     )
5941
5942     for directory in (common_base, *common_base.parents):
5943         if (directory / ".git").exists():
5944             return directory
5945
5946         if (directory / ".hg").is_dir():
5947             return directory
5948
5949         if (directory / "pyproject.toml").is_file():
5950             return directory
5951
5952     return directory
5953
5954
5955 @dataclass
5956 class Report:
5957     """Provides a reformatting counter. Can be rendered with `str(report)`."""
5958
5959     check: bool = False
5960     diff: bool = False
5961     quiet: bool = False
5962     verbose: bool = False
5963     change_count: int = 0
5964     same_count: int = 0
5965     failure_count: int = 0
5966
5967     def done(self, src: Path, changed: Changed) -> None:
5968         """Increment the counter for successful reformatting. Write out a message."""
5969         if changed is Changed.YES:
5970             reformatted = "would reformat" if self.check or self.diff else "reformatted"
5971             if self.verbose or not self.quiet:
5972                 out(f"{reformatted} {src}")
5973             self.change_count += 1
5974         else:
5975             if self.verbose:
5976                 if changed is Changed.NO:
5977                     msg = f"{src} already well formatted, good job."
5978                 else:
5979                     msg = f"{src} wasn't modified on disk since last run."
5980                 out(msg, bold=False)
5981             self.same_count += 1
5982
5983     def failed(self, src: Path, message: str) -> None:
5984         """Increment the counter for failed reformatting. Write out a message."""
5985         err(f"error: cannot format {src}: {message}")
5986         self.failure_count += 1
5987
5988     def path_ignored(self, path: Path, message: str) -> None:
5989         if self.verbose:
5990             out(f"{path} ignored: {message}", bold=False)
5991
5992     @property
5993     def return_code(self) -> int:
5994         """Return the exit code that the app should use.
5995
5996         This considers the current state of changed files and failures:
5997         - if there were any failures, return 123;
5998         - if any files were changed and --check is being used, return 1;
5999         - otherwise return 0.
6000         """
6001         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6002         # 126 we have special return codes reserved by the shell.
6003         if self.failure_count:
6004             return 123
6005
6006         elif self.change_count and self.check:
6007             return 1
6008
6009         return 0
6010
6011     def __str__(self) -> str:
6012         """Render a color report of the current state.
6013
6014         Use `click.unstyle` to remove colors.
6015         """
6016         if self.check or self.diff:
6017             reformatted = "would be reformatted"
6018             unchanged = "would be left unchanged"
6019             failed = "would fail to reformat"
6020         else:
6021             reformatted = "reformatted"
6022             unchanged = "left unchanged"
6023             failed = "failed to reformat"
6024         report = []
6025         if self.change_count:
6026             s = "s" if self.change_count > 1 else ""
6027             report.append(
6028                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6029             )
6030         if self.same_count:
6031             s = "s" if self.same_count > 1 else ""
6032             report.append(f"{self.same_count} file{s} {unchanged}")
6033         if self.failure_count:
6034             s = "s" if self.failure_count > 1 else ""
6035             report.append(
6036                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6037             )
6038         return ", ".join(report) + "."
6039
6040
6041 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6042     filename = "<unknown>"
6043     if sys.version_info >= (3, 8):
6044         # TODO: support Python 4+ ;)
6045         for minor_version in range(sys.version_info[1], 4, -1):
6046             try:
6047                 return ast.parse(src, filename, feature_version=(3, minor_version))
6048             except SyntaxError:
6049                 continue
6050     else:
6051         for feature_version in (7, 6):
6052             try:
6053                 return ast3.parse(src, filename, feature_version=feature_version)
6054             except SyntaxError:
6055                 continue
6056
6057     return ast27.parse(src)
6058
6059
6060 def _fixup_ast_constants(
6061     node: Union[ast.AST, ast3.AST, ast27.AST]
6062 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6063     """Map ast nodes deprecated in 3.8 to Constant."""
6064     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6065         return ast.Constant(value=node.s)
6066
6067     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6068         return ast.Constant(value=node.n)
6069
6070     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6071         return ast.Constant(value=node.value)
6072
6073     return node
6074
6075
6076 def _stringify_ast(
6077     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6078 ) -> Iterator[str]:
6079     """Simple visitor generating strings to compare ASTs by content."""
6080
6081     node = _fixup_ast_constants(node)
6082
6083     yield f"{'  ' * depth}{node.__class__.__name__}("
6084
6085     for field in sorted(node._fields):  # noqa: F402
6086         # TypeIgnore has only one field 'lineno' which breaks this comparison
6087         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6088         if sys.version_info >= (3, 8):
6089             type_ignore_classes += (ast.TypeIgnore,)
6090         if isinstance(node, type_ignore_classes):
6091             break
6092
6093         try:
6094             value = getattr(node, field)
6095         except AttributeError:
6096             continue
6097
6098         yield f"{'  ' * (depth+1)}{field}="
6099
6100         if isinstance(value, list):
6101             for item in value:
6102                 # Ignore nested tuples within del statements, because we may insert
6103                 # parentheses and they change the AST.
6104                 if (
6105                     field == "targets"
6106                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6107                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6108                 ):
6109                     for item in item.elts:
6110                         yield from _stringify_ast(item, depth + 2)
6111
6112                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6113                     yield from _stringify_ast(item, depth + 2)
6114
6115         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6116             yield from _stringify_ast(value, depth + 2)
6117
6118         else:
6119             # Constant strings may be indented across newlines, if they are
6120             # docstrings; fold spaces after newlines when comparing. Similarly,
6121             # trailing and leading space may be removed.
6122             if (
6123                 isinstance(node, ast.Constant)
6124                 and field == "value"
6125                 and isinstance(value, str)
6126             ):
6127                 normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
6128             else:
6129                 normalized = value
6130             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6131
6132     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6133
6134
6135 def assert_equivalent(src: str, dst: str) -> None:
6136     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6137     try:
6138         src_ast = parse_ast(src)
6139     except Exception as exc:
6140         raise AssertionError(
6141             "cannot use --safe with this file; failed to parse source file.  AST"
6142             f" error message: {exc}"
6143         )
6144
6145     try:
6146         dst_ast = parse_ast(dst)
6147     except Exception as exc:
6148         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6149         raise AssertionError(
6150             f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
6151             " on https://github.com/psf/black/issues.  This invalid output might be"
6152             f" helpful: {log}"
6153         ) from None
6154
6155     src_ast_str = "\n".join(_stringify_ast(src_ast))
6156     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6157     if src_ast_str != dst_ast_str:
6158         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6159         raise AssertionError(
6160             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6161             " source.  Please report a bug on https://github.com/psf/black/issues. "
6162             f" This diff might be helpful: {log}"
6163         ) from None
6164
6165
6166 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6167     """Raise AssertionError if `dst` reformats differently the second time."""
6168     newdst = format_str(dst, mode=mode)
6169     if dst != newdst:
6170         log = dump_to_file(
6171             diff(src, dst, "source", "first pass"),
6172             diff(dst, newdst, "first pass", "second pass"),
6173         )
6174         raise AssertionError(
6175             "INTERNAL ERROR: Black produced different code on the second pass of the"
6176             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6177             f"  This diff might be helpful: {log}"
6178         ) from None
6179
6180
6181 @mypyc_attr(patchable=True)
6182 def dump_to_file(*output: str) -> str:
6183     """Dump `output` to a temporary file. Return path to the file."""
6184     with tempfile.NamedTemporaryFile(
6185         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6186     ) as f:
6187         for lines in output:
6188             f.write(lines)
6189             if lines and lines[-1] != "\n":
6190                 f.write("\n")
6191     return f.name
6192
6193
6194 @contextmanager
6195 def nullcontext() -> Iterator[None]:
6196     """Return an empty context manager.
6197
6198     To be used like `nullcontext` in Python 3.7.
6199     """
6200     yield
6201
6202
6203 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6204     """Return a unified diff string between strings `a` and `b`."""
6205     import difflib
6206
6207     a_lines = [line + "\n" for line in a.splitlines()]
6208     b_lines = [line + "\n" for line in b.splitlines()]
6209     return "".join(
6210         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
6211     )
6212
6213
6214 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6215     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6216     err("Aborted!")
6217     for task in tasks:
6218         task.cancel()
6219
6220
6221 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6222     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6223     try:
6224         if sys.version_info[:2] >= (3, 7):
6225             all_tasks = asyncio.all_tasks
6226         else:
6227             all_tasks = asyncio.Task.all_tasks
6228         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6229         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6230         if not to_cancel:
6231             return
6232
6233         for task in to_cancel:
6234             task.cancel()
6235         loop.run_until_complete(
6236             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6237         )
6238     finally:
6239         # `concurrent.futures.Future` objects cannot be cancelled once they
6240         # are already running. There might be some when the `shutdown()` happened.
6241         # Silence their logger's spew about the event loop being closed.
6242         cf_logger = logging.getLogger("concurrent.futures")
6243         cf_logger.setLevel(logging.CRITICAL)
6244         loop.close()
6245
6246
6247 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6248     """Replace `regex` with `replacement` twice on `original`.
6249
6250     This is used by string normalization to perform replaces on
6251     overlapping matches.
6252     """
6253     return regex.sub(replacement, regex.sub(replacement, original))
6254
6255
6256 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6257     """Compile a regular expression string in `regex`.
6258
6259     If it contains newlines, use verbose mode.
6260     """
6261     if "\n" in regex:
6262         regex = "(?x)" + regex
6263     compiled: Pattern[str] = re.compile(regex)
6264     return compiled
6265
6266
6267 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6268     """Like `reversed(enumerate(sequence))` if that were possible."""
6269     index = len(sequence) - 1
6270     for element in reversed(sequence):
6271         yield (index, element)
6272         index -= 1
6273
6274
6275 def enumerate_with_length(
6276     line: Line, reversed: bool = False
6277 ) -> Iterator[Tuple[Index, Leaf, int]]:
6278     """Return an enumeration of leaves with their length.
6279
6280     Stops prematurely on multiline strings and standalone comments.
6281     """
6282     op = cast(
6283         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6284         enumerate_reversed if reversed else enumerate,
6285     )
6286     for index, leaf in op(line.leaves):
6287         length = len(leaf.prefix) + len(leaf.value)
6288         if "\n" in leaf.value:
6289             return  # Multiline strings, we can't continue.
6290
6291         for comment in line.comments_after(leaf):
6292             length += len(comment.value)
6293
6294         yield index, leaf, length
6295
6296
6297 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6298     """Return True if `line` is no longer than `line_length`.
6299
6300     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6301     """
6302     if not line_str:
6303         line_str = line_to_string(line)
6304     return (
6305         len(line_str) <= line_length
6306         and "\n" not in line_str  # multiline strings
6307         and not line.contains_standalone_comments()
6308     )
6309
6310
6311 def can_be_split(line: Line) -> bool:
6312     """Return False if the line cannot be split *for sure*.
6313
6314     This is not an exhaustive search but a cheap heuristic that we can use to
6315     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6316     in unnecessary parentheses).
6317     """
6318     leaves = line.leaves
6319     if len(leaves) < 2:
6320         return False
6321
6322     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6323         call_count = 0
6324         dot_count = 0
6325         next = leaves[-1]
6326         for leaf in leaves[-2::-1]:
6327             if leaf.type in OPENING_BRACKETS:
6328                 if next.type not in CLOSING_BRACKETS:
6329                     return False
6330
6331                 call_count += 1
6332             elif leaf.type == token.DOT:
6333                 dot_count += 1
6334             elif leaf.type == token.NAME:
6335                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6336                     return False
6337
6338             elif leaf.type not in CLOSING_BRACKETS:
6339                 return False
6340
6341             if dot_count > 1 and call_count > 1:
6342                 return False
6343
6344     return True
6345
6346
6347 def can_omit_invisible_parens(
6348     line: Line,
6349     line_length: int,
6350     omit_on_explode: Collection[LeafID] = (),
6351 ) -> bool:
6352     """Does `line` have a shape safe to reformat without optional parens around it?
6353
6354     Returns True for only a subset of potentially nice looking formattings but
6355     the point is to not return false positives that end up producing lines that
6356     are too long.
6357     """
6358     bt = line.bracket_tracker
6359     if not bt.delimiters:
6360         # Without delimiters the optional parentheses are useless.
6361         return True
6362
6363     max_priority = bt.max_delimiter_priority()
6364     if bt.delimiter_count_with_priority(max_priority) > 1:
6365         # With more than one delimiter of a kind the optional parentheses read better.
6366         return False
6367
6368     if max_priority == DOT_PRIORITY:
6369         # A single stranded method call doesn't require optional parentheses.
6370         return True
6371
6372     assert len(line.leaves) >= 2, "Stranded delimiter"
6373
6374     # With a single delimiter, omit if the expression starts or ends with
6375     # a bracket.
6376     first = line.leaves[0]
6377     second = line.leaves[1]
6378     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6379         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6380             return True
6381
6382         # Note: we are not returning False here because a line might have *both*
6383         # a leading opening bracket and a trailing closing bracket.  If the
6384         # opening bracket doesn't match our rule, maybe the closing will.
6385
6386     penultimate = line.leaves[-2]
6387     last = line.leaves[-1]
6388     if line.should_explode:
6389         try:
6390             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6391         except LookupError:
6392             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6393             return False
6394
6395     if (
6396         last.type == token.RPAR
6397         or last.type == token.RBRACE
6398         or (
6399             # don't use indexing for omitting optional parentheses;
6400             # it looks weird
6401             last.type == token.RSQB
6402             and last.parent
6403             and last.parent.type != syms.trailer
6404         )
6405     ):
6406         if penultimate.type in OPENING_BRACKETS:
6407             # Empty brackets don't help.
6408             return False
6409
6410         if is_multiline_string(first):
6411             # Additional wrapping of a multiline string in this situation is
6412             # unnecessary.
6413             return True
6414
6415         if (
6416             line.should_explode
6417             and penultimate.type == token.COMMA
6418             and not penultimate.was_checked
6419         ):
6420             # The rightmost non-omitted bracket pair is the one we want to explode on.
6421             return True
6422
6423         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6424             return True
6425
6426     return False
6427
6428
6429 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6430     """See `can_omit_invisible_parens`."""
6431     remainder = False
6432     length = 4 * line.depth
6433     _index = -1
6434     for _index, leaf, leaf_length in enumerate_with_length(line):
6435         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6436             remainder = True
6437         if remainder:
6438             length += leaf_length
6439             if length > line_length:
6440                 break
6441
6442             if leaf.type in OPENING_BRACKETS:
6443                 # There are brackets we can further split on.
6444                 remainder = False
6445
6446     else:
6447         # checked the entire string and line length wasn't exceeded
6448         if len(line.leaves) == _index + 1:
6449             return True
6450
6451     return False
6452
6453
6454 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6455     """See `can_omit_invisible_parens`."""
6456     length = 4 * line.depth
6457     seen_other_brackets = False
6458     for _index, leaf, leaf_length in enumerate_with_length(line):
6459         length += leaf_length
6460         if leaf is last.opening_bracket:
6461             if seen_other_brackets or length <= line_length:
6462                 return True
6463
6464         elif leaf.type in OPENING_BRACKETS:
6465             # There are brackets we can further split on.
6466             seen_other_brackets = True
6467
6468     return False
6469
6470
6471 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6472     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6473     stop_after = None
6474     last = None
6475     for leaf in reversed(leaves):
6476         if stop_after:
6477             if leaf is stop_after:
6478                 stop_after = None
6479             continue
6480
6481         if last:
6482             return leaf, last
6483
6484         if id(leaf) in omit:
6485             stop_after = leaf.opening_bracket
6486         else:
6487             last = leaf
6488     else:
6489         raise LookupError("Last two leaves were also skipped")
6490
6491
6492 def run_transformer(
6493     line: Line,
6494     transform: Transformer,
6495     mode: Mode,
6496     features: Collection[Feature],
6497     *,
6498     line_str: str = "",
6499 ) -> List[Line]:
6500     if not line_str:
6501         line_str = line_to_string(line)
6502     result: List[Line] = []
6503     for transformed_line in transform(line, features):
6504         if str(transformed_line).strip("\n") == line_str:
6505             raise CannotTransform("Line transformer returned an unchanged result")
6506
6507         result.extend(transform_line(transformed_line, mode=mode, features=features))
6508
6509     if not (
6510         transform.__name__ == "rhs"
6511         and line.bracket_tracker.invisible
6512         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6513         and not line.contains_multiline_strings()
6514         and not result[0].contains_uncollapsable_type_comments()
6515         and not result[0].contains_unsplittable_type_ignore()
6516         and not is_line_short_enough(result[0], line_length=mode.line_length)
6517     ):
6518         return result
6519
6520     line_copy = line.clone()
6521     append_leaves(line_copy, line, line.leaves)
6522     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6523     second_opinion = run_transformer(
6524         line_copy, transform, mode, features_fop, line_str=line_str
6525     )
6526     if all(
6527         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6528     ):
6529         result = second_opinion
6530     return result
6531
6532
6533 def get_cache_file(mode: Mode) -> Path:
6534     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6535
6536
6537 def read_cache(mode: Mode) -> Cache:
6538     """Read the cache if it exists and is well formed.
6539
6540     If it is not well formed, the call to write_cache later should resolve the issue.
6541     """
6542     cache_file = get_cache_file(mode)
6543     if not cache_file.exists():
6544         return {}
6545
6546     with cache_file.open("rb") as fobj:
6547         try:
6548             cache: Cache = pickle.load(fobj)
6549         except (pickle.UnpicklingError, ValueError):
6550             return {}
6551
6552     return cache
6553
6554
6555 def get_cache_info(path: Path) -> CacheInfo:
6556     """Return the information used to check if a file is already formatted or not."""
6557     stat = path.stat()
6558     return stat.st_mtime, stat.st_size
6559
6560
6561 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6562     """Split an iterable of paths in `sources` into two sets.
6563
6564     The first contains paths of files that modified on disk or are not in the
6565     cache. The other contains paths to non-modified files.
6566     """
6567     todo, done = set(), set()
6568     for src in sources:
6569         src = src.resolve()
6570         if cache.get(src) != get_cache_info(src):
6571             todo.add(src)
6572         else:
6573             done.add(src)
6574     return todo, done
6575
6576
6577 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6578     """Update the cache file."""
6579     cache_file = get_cache_file(mode)
6580     try:
6581         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6582         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
6583         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6584             pickle.dump(new_cache, f, protocol=4)
6585         os.replace(f.name, cache_file)
6586     except OSError:
6587         pass
6588
6589
6590 def patch_click() -> None:
6591     """Make Click not crash.
6592
6593     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6594     default which restricts paths that it can access during the lifetime of the
6595     application.  Click refuses to work in this scenario by raising a RuntimeError.
6596
6597     In case of Black the likelihood that non-ASCII characters are going to be used in
6598     file paths is minimal since it's Python source code.  Moreover, this crash was
6599     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6600     """
6601     try:
6602         from click import core
6603         from click import _unicodefun  # type: ignore
6604     except ModuleNotFoundError:
6605         return
6606
6607     for module in (core, _unicodefun):
6608         if hasattr(module, "_verify_python3_env"):
6609             module._verify_python3_env = lambda: None
6610
6611
6612 def patched_main() -> None:
6613     freeze_support()
6614     patch_click()
6615     main()
6616
6617
6618 def is_docstring(leaf: Leaf) -> bool:
6619     if not is_multiline_string(leaf):
6620         # For the purposes of docstring re-indentation, we don't need to do anything
6621         # with single-line docstrings.
6622         return False
6623
6624     if prev_siblings_are(
6625         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
6626     ):
6627         return True
6628
6629     # Multiline docstring on the same line as the `def`.
6630     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
6631         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
6632         # grammar. We're safe to return True without further checks.
6633         return True
6634
6635     return False
6636
6637
6638 def fix_docstring(docstring: str, prefix: str) -> str:
6639     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
6640     if not docstring:
6641         return ""
6642     # Convert tabs to spaces (following the normal Python rules)
6643     # and split into a list of lines:
6644     lines = docstring.expandtabs().splitlines()
6645     # Determine minimum indentation (first line doesn't count):
6646     indent = sys.maxsize
6647     for line in lines[1:]:
6648         stripped = line.lstrip()
6649         if stripped:
6650             indent = min(indent, len(line) - len(stripped))
6651     # Remove indentation (first line is special):
6652     trimmed = [lines[0].strip()]
6653     if indent < sys.maxsize:
6654         last_line_idx = len(lines) - 2
6655         for i, line in enumerate(lines[1:]):
6656             stripped_line = line[indent:].rstrip()
6657             if stripped_line or i == last_line_idx:
6658                 trimmed.append(prefix + stripped_line)
6659             else:
6660                 trimmed.append("")
6661     return "\n".join(trimmed)
6662
6663
6664 if __name__ == "__main__":
6665     patched_main()