]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump pathspec to >= 0.8.1 to solve invalid .gitignore exclusion handling (#2084)
[etc/vim.git] / src / black / __init__.py
1 import ast
2 import asyncio
3 from abc import ABC, abstractmethod
4 from collections import defaultdict
5 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
6 from contextlib import contextmanager
7 from datetime import datetime
8 from enum import Enum
9 from functools import lru_cache, partial, wraps
10 import io
11 import itertools
12 import logging
13 from multiprocessing import Manager, freeze_support
14 import os
15 from pathlib import Path
16 import pickle
17 import regex as re
18 import signal
19 import sys
20 import tempfile
21 import tokenize
22 import traceback
23 from typing import (
24     Any,
25     Callable,
26     Collection,
27     Dict,
28     Generator,
29     Generic,
30     Iterable,
31     Iterator,
32     List,
33     Optional,
34     Pattern,
35     Sequence,
36     Set,
37     Sized,
38     Tuple,
39     Type,
40     TypeVar,
41     Union,
42     cast,
43     TYPE_CHECKING,
44 )
45 from mypy_extensions import mypyc_attr
46
47 from appdirs import user_cache_dir
48 from dataclasses import dataclass, field, replace
49 import click
50 import toml
51
52 try:
53     from typed_ast import ast3, ast27
54 except ImportError:
55     if sys.version_info < (3, 8):
56         print(
57             "The typed_ast package is not installed.\n"
58             "You can install it with `python3 -m pip install typed-ast`.",
59             file=sys.stderr,
60         )
61         sys.exit(1)
62     else:
63         ast3 = ast27 = ast
64
65 from pathspec import PathSpec
66
67 # lib2to3 fork
68 from blib2to3.pytree import Node, Leaf, type_repr
69 from blib2to3 import pygram, pytree
70 from blib2to3.pgen2 import driver, token
71 from blib2to3.pgen2.grammar import Grammar
72 from blib2to3.pgen2.parse import ParseError
73
74 from _black_version import version as __version__
75
76 if sys.version_info < (3, 8):
77     from typing_extensions import Final
78 else:
79     from typing import Final
80
81 if TYPE_CHECKING:
82     import colorama  # noqa: F401
83
84 DEFAULT_LINE_LENGTH = 88
85 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
86 DEFAULT_INCLUDES = r"\.pyi?$"
87 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
88 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
89
90 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
91
92
93 # types
94 FileContent = str
95 Encoding = str
96 NewLine = str
97 Depth = int
98 NodeType = int
99 ParserState = int
100 LeafID = int
101 StringID = int
102 Priority = int
103 Index = int
104 LN = Union[Leaf, Node]
105 Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
106 Timestamp = float
107 FileSize = int
108 CacheInfo = Tuple[Timestamp, FileSize]
109 Cache = Dict[str, CacheInfo]
110 out = partial(click.secho, bold=True, err=True)
111 err = partial(click.secho, fg="red", err=True)
112
113 pygram.initialize(CACHE_DIR)
114 syms = pygram.python_symbols
115
116
117 class NothingChanged(UserWarning):
118     """Raised when reformatted code is the same as source."""
119
120
121 class CannotTransform(Exception):
122     """Base class for errors raised by Transformers."""
123
124
125 class CannotSplit(CannotTransform):
126     """A readable split that fits the allotted line length is impossible."""
127
128
129 class InvalidInput(ValueError):
130     """Raised when input source code fails all parse attempts."""
131
132
133 class BracketMatchError(KeyError):
134     """Raised when an opening bracket is unable to be matched to a closing bracket."""
135
136
137 T = TypeVar("T")
138 E = TypeVar("E", bound=Exception)
139
140
141 class Ok(Generic[T]):
142     def __init__(self, value: T) -> None:
143         self._value = value
144
145     def ok(self) -> T:
146         return self._value
147
148
149 class Err(Generic[E]):
150     def __init__(self, e: E) -> None:
151         self._e = e
152
153     def err(self) -> E:
154         return self._e
155
156
157 # The 'Result' return type is used to implement an error-handling model heavily
158 # influenced by that used by the Rust programming language
159 # (see https://doc.rust-lang.org/book/ch09-00-error-handling.html).
160 Result = Union[Ok[T], Err[E]]
161 TResult = Result[T, CannotTransform]  # (T)ransform Result
162 TMatchResult = TResult[Index]
163
164
165 class WriteBack(Enum):
166     NO = 0
167     YES = 1
168     DIFF = 2
169     CHECK = 3
170     COLOR_DIFF = 4
171
172     @classmethod
173     def from_configuration(
174         cls, *, check: bool, diff: bool, color: bool = False
175     ) -> "WriteBack":
176         if check and not diff:
177             return cls.CHECK
178
179         if diff and color:
180             return cls.COLOR_DIFF
181
182         return cls.DIFF if diff else cls.YES
183
184
185 class Changed(Enum):
186     NO = 0
187     CACHED = 1
188     YES = 2
189
190
191 class TargetVersion(Enum):
192     PY27 = 2
193     PY33 = 3
194     PY34 = 4
195     PY35 = 5
196     PY36 = 6
197     PY37 = 7
198     PY38 = 8
199     PY39 = 9
200
201     def is_python2(self) -> bool:
202         return self is TargetVersion.PY27
203
204
205 class Feature(Enum):
206     # All string literals are unicode
207     UNICODE_LITERALS = 1
208     F_STRINGS = 2
209     NUMERIC_UNDERSCORES = 3
210     TRAILING_COMMA_IN_CALL = 4
211     TRAILING_COMMA_IN_DEF = 5
212     # The following two feature-flags are mutually exclusive, and exactly one should be
213     # set for every version of python.
214     ASYNC_IDENTIFIERS = 6
215     ASYNC_KEYWORDS = 7
216     ASSIGNMENT_EXPRESSIONS = 8
217     POS_ONLY_ARGUMENTS = 9
218     RELAXED_DECORATORS = 10
219     FORCE_OPTIONAL_PARENTHESES = 50
220
221
222 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
223     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
224     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
225     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
226     TargetVersion.PY35: {
227         Feature.UNICODE_LITERALS,
228         Feature.TRAILING_COMMA_IN_CALL,
229         Feature.ASYNC_IDENTIFIERS,
230     },
231     TargetVersion.PY36: {
232         Feature.UNICODE_LITERALS,
233         Feature.F_STRINGS,
234         Feature.NUMERIC_UNDERSCORES,
235         Feature.TRAILING_COMMA_IN_CALL,
236         Feature.TRAILING_COMMA_IN_DEF,
237         Feature.ASYNC_IDENTIFIERS,
238     },
239     TargetVersion.PY37: {
240         Feature.UNICODE_LITERALS,
241         Feature.F_STRINGS,
242         Feature.NUMERIC_UNDERSCORES,
243         Feature.TRAILING_COMMA_IN_CALL,
244         Feature.TRAILING_COMMA_IN_DEF,
245         Feature.ASYNC_KEYWORDS,
246     },
247     TargetVersion.PY38: {
248         Feature.UNICODE_LITERALS,
249         Feature.F_STRINGS,
250         Feature.NUMERIC_UNDERSCORES,
251         Feature.TRAILING_COMMA_IN_CALL,
252         Feature.TRAILING_COMMA_IN_DEF,
253         Feature.ASYNC_KEYWORDS,
254         Feature.ASSIGNMENT_EXPRESSIONS,
255         Feature.POS_ONLY_ARGUMENTS,
256     },
257     TargetVersion.PY39: {
258         Feature.UNICODE_LITERALS,
259         Feature.F_STRINGS,
260         Feature.NUMERIC_UNDERSCORES,
261         Feature.TRAILING_COMMA_IN_CALL,
262         Feature.TRAILING_COMMA_IN_DEF,
263         Feature.ASYNC_KEYWORDS,
264         Feature.ASSIGNMENT_EXPRESSIONS,
265         Feature.RELAXED_DECORATORS,
266         Feature.POS_ONLY_ARGUMENTS,
267     },
268 }
269
270
271 @dataclass
272 class Mode:
273     target_versions: Set[TargetVersion] = field(default_factory=set)
274     line_length: int = DEFAULT_LINE_LENGTH
275     string_normalization: bool = True
276     is_pyi: bool = False
277     magic_trailing_comma: bool = True
278     experimental_string_processing: bool = False
279
280     def get_cache_key(self) -> str:
281         if self.target_versions:
282             version_str = ",".join(
283                 str(version.value)
284                 for version in sorted(self.target_versions, key=lambda v: v.value)
285             )
286         else:
287             version_str = "-"
288         parts = [
289             version_str,
290             str(self.line_length),
291             str(int(self.string_normalization)),
292             str(int(self.is_pyi)),
293             str(int(self.magic_trailing_comma)),
294             str(int(self.experimental_string_processing)),
295         ]
296         return ".".join(parts)
297
298
299 # Legacy name, left for integrations.
300 FileMode = Mode
301
302
303 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
304     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
305
306
307 def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
308     """Find the absolute filepath to a pyproject.toml if it exists"""
309     path_project_root = find_project_root(path_search_start)
310     path_pyproject_toml = path_project_root / "pyproject.toml"
311     if path_pyproject_toml.is_file():
312         return str(path_pyproject_toml)
313
314     path_user_pyproject_toml = find_user_pyproject_toml()
315     return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
316
317
318 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
319     """Parse a pyproject toml file, pulling out relevant parts for Black
320
321     If parsing fails, will raise a toml.TomlDecodeError
322     """
323     pyproject_toml = toml.load(path_config)
324     config = pyproject_toml.get("tool", {}).get("black", {})
325     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
326
327
328 def read_pyproject_toml(
329     ctx: click.Context, param: click.Parameter, value: Optional[str]
330 ) -> Optional[str]:
331     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
332
333     Returns the path to a successfully found and read configuration file, None
334     otherwise.
335     """
336     if not value:
337         value = find_pyproject_toml(ctx.params.get("src", ()))
338         if value is None:
339             return None
340
341     try:
342         config = parse_pyproject_toml(value)
343     except (toml.TomlDecodeError, OSError) as e:
344         raise click.FileError(
345             filename=value, hint=f"Error reading configuration file: {e}"
346         )
347
348     if not config:
349         return None
350     else:
351         # Sanitize the values to be Click friendly. For more information please see:
352         # https://github.com/psf/black/issues/1458
353         # https://github.com/pallets/click/issues/1567
354         config = {
355             k: str(v) if not isinstance(v, (list, dict)) else v
356             for k, v in config.items()
357         }
358
359     target_version = config.get("target_version")
360     if target_version is not None and not isinstance(target_version, list):
361         raise click.BadOptionUsage(
362             "target-version", "Config key target-version must be a list"
363         )
364
365     default_map: Dict[str, Any] = {}
366     if ctx.default_map:
367         default_map.update(ctx.default_map)
368     default_map.update(config)
369
370     ctx.default_map = default_map
371     return value
372
373
374 def target_version_option_callback(
375     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
376 ) -> List[TargetVersion]:
377     """Compute the target versions from a --target-version flag.
378
379     This is its own function because mypy couldn't infer the type correctly
380     when it was a lambda, causing mypyc trouble.
381     """
382     return [TargetVersion[val.upper()] for val in v]
383
384
385 def validate_regex(
386     ctx: click.Context,
387     param: click.Parameter,
388     value: Optional[str],
389 ) -> Optional[Pattern]:
390     try:
391         return re_compile_maybe_verbose(value) if value is not None else None
392     except re.error:
393         raise click.BadParameter("Not a valid regular expression")
394
395
396 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
397 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
398 @click.option(
399     "-l",
400     "--line-length",
401     type=int,
402     default=DEFAULT_LINE_LENGTH,
403     help="How many characters per line to allow.",
404     show_default=True,
405 )
406 @click.option(
407     "-t",
408     "--target-version",
409     type=click.Choice([v.name.lower() for v in TargetVersion]),
410     callback=target_version_option_callback,
411     multiple=True,
412     help=(
413         "Python versions that should be supported by Black's output. [default: per-file"
414         " auto-detection]"
415     ),
416 )
417 @click.option(
418     "--pyi",
419     is_flag=True,
420     help=(
421         "Format all input files like typing stubs regardless of file extension (useful"
422         " when piping source on standard input)."
423     ),
424 )
425 @click.option(
426     "-S",
427     "--skip-string-normalization",
428     is_flag=True,
429     help="Don't normalize string quotes or prefixes.",
430 )
431 @click.option(
432     "-C",
433     "--skip-magic-trailing-comma",
434     is_flag=True,
435     help="Don't use trailing commas as a reason to split lines.",
436 )
437 @click.option(
438     "--experimental-string-processing",
439     is_flag=True,
440     hidden=True,
441     help=(
442         "Experimental option that performs more normalization on string literals."
443         " Currently disabled because it leads to some crashes."
444     ),
445 )
446 @click.option(
447     "--check",
448     is_flag=True,
449     help=(
450         "Don't write the files back, just return the status. Return code 0 means"
451         " nothing would change. Return code 1 means some files would be reformatted."
452         " Return code 123 means there was an internal error."
453     ),
454 )
455 @click.option(
456     "--diff",
457     is_flag=True,
458     help="Don't write the files back, just output a diff for each file on stdout.",
459 )
460 @click.option(
461     "--color/--no-color",
462     is_flag=True,
463     help="Show colored diff. Only applies when `--diff` is given.",
464 )
465 @click.option(
466     "--fast/--safe",
467     is_flag=True,
468     help="If --fast given, skip temporary sanity checks. [default: --safe]",
469 )
470 @click.option(
471     "--include",
472     type=str,
473     default=DEFAULT_INCLUDES,
474     callback=validate_regex,
475     help=(
476         "A regular expression that matches files and directories that should be"
477         " included on recursive searches. An empty value means all files are included"
478         " regardless of the name. Use forward slashes for directories on all platforms"
479         " (Windows, too). Exclusions are calculated first, inclusions later."
480     ),
481     show_default=True,
482 )
483 @click.option(
484     "--exclude",
485     type=str,
486     default=DEFAULT_EXCLUDES,
487     callback=validate_regex,
488     help=(
489         "A regular expression that matches files and directories that should be"
490         " excluded on recursive searches. An empty value means no paths are excluded."
491         " Use forward slashes for directories on all platforms (Windows, too)."
492         " Exclusions are calculated first, inclusions later."
493     ),
494     show_default=True,
495 )
496 @click.option(
497     "--extend-exclude",
498     type=str,
499     callback=validate_regex,
500     help=(
501         "Like --exclude, but adds additional files and directories on top of the"
502         " excluded ones. (Useful if you simply want to add to the default)"
503     ),
504 )
505 @click.option(
506     "--force-exclude",
507     type=str,
508     callback=validate_regex,
509     help=(
510         "Like --exclude, but files and directories matching this regex will be "
511         "excluded even when they are passed explicitly as arguments."
512     ),
513 )
514 @click.option(
515     "--stdin-filename",
516     type=str,
517     help=(
518         "The name of the file when passing it through stdin. Useful to make "
519         "sure Black will respect --force-exclude option on some "
520         "editors that rely on using stdin."
521     ),
522 )
523 @click.option(
524     "-q",
525     "--quiet",
526     is_flag=True,
527     help=(
528         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
529         " those with 2>/dev/null."
530     ),
531 )
532 @click.option(
533     "-v",
534     "--verbose",
535     is_flag=True,
536     help=(
537         "Also emit messages to stderr about files that were not changed or were ignored"
538         " due to exclusion patterns."
539     ),
540 )
541 @click.version_option(version=__version__)
542 @click.argument(
543     "src",
544     nargs=-1,
545     type=click.Path(
546         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
547     ),
548     is_eager=True,
549 )
550 @click.option(
551     "--config",
552     type=click.Path(
553         exists=True,
554         file_okay=True,
555         dir_okay=False,
556         readable=True,
557         allow_dash=False,
558         path_type=str,
559     ),
560     is_eager=True,
561     callback=read_pyproject_toml,
562     help="Read configuration from FILE path.",
563 )
564 @click.pass_context
565 def main(
566     ctx: click.Context,
567     code: Optional[str],
568     line_length: int,
569     target_version: List[TargetVersion],
570     check: bool,
571     diff: bool,
572     color: bool,
573     fast: bool,
574     pyi: bool,
575     skip_string_normalization: bool,
576     skip_magic_trailing_comma: bool,
577     experimental_string_processing: bool,
578     quiet: bool,
579     verbose: bool,
580     include: Pattern,
581     exclude: Pattern,
582     extend_exclude: Optional[Pattern],
583     force_exclude: Optional[Pattern],
584     stdin_filename: Optional[str],
585     src: Tuple[str, ...],
586     config: Optional[str],
587 ) -> None:
588     """The uncompromising code formatter."""
589     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
590     if target_version:
591         versions = set(target_version)
592     else:
593         # We'll autodetect later.
594         versions = set()
595     mode = Mode(
596         target_versions=versions,
597         line_length=line_length,
598         is_pyi=pyi,
599         string_normalization=not skip_string_normalization,
600         magic_trailing_comma=not skip_magic_trailing_comma,
601         experimental_string_processing=experimental_string_processing,
602     )
603     if config and verbose:
604         out(f"Using configuration from {config}.", bold=False, fg="blue")
605     if code is not None:
606         print(format_str(code, mode=mode))
607         ctx.exit(0)
608     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
609     sources = get_sources(
610         ctx=ctx,
611         src=src,
612         quiet=quiet,
613         verbose=verbose,
614         include=include,
615         exclude=exclude,
616         extend_exclude=extend_exclude,
617         force_exclude=force_exclude,
618         report=report,
619         stdin_filename=stdin_filename,
620     )
621
622     path_empty(
623         sources,
624         "No Python files are present to be formatted. Nothing to do 😴",
625         quiet,
626         verbose,
627         ctx,
628     )
629
630     if len(sources) == 1:
631         reformat_one(
632             src=sources.pop(),
633             fast=fast,
634             write_back=write_back,
635             mode=mode,
636             report=report,
637         )
638     else:
639         reformat_many(
640             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
641         )
642
643     if verbose or not quiet:
644         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
645         click.secho(str(report), err=True)
646     ctx.exit(report.return_code)
647
648
649 def get_sources(
650     *,
651     ctx: click.Context,
652     src: Tuple[str, ...],
653     quiet: bool,
654     verbose: bool,
655     include: Pattern[str],
656     exclude: Pattern[str],
657     extend_exclude: Optional[Pattern[str]],
658     force_exclude: Optional[Pattern[str]],
659     report: "Report",
660     stdin_filename: Optional[str],
661 ) -> Set[Path]:
662     """Compute the set of files to be formatted."""
663
664     root = find_project_root(src)
665     sources: Set[Path] = set()
666     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
667     gitignore = get_gitignore(root)
668
669     for s in src:
670         if s == "-" and stdin_filename:
671             p = Path(stdin_filename)
672             is_stdin = True
673         else:
674             p = Path(s)
675             is_stdin = False
676
677         if is_stdin or p.is_file():
678             normalized_path = normalize_path_maybe_ignore(p, root, report)
679             if normalized_path is None:
680                 continue
681
682             normalized_path = "/" + normalized_path
683             # Hard-exclude any files that matches the `--force-exclude` regex.
684             if force_exclude:
685                 force_exclude_match = force_exclude.search(normalized_path)
686             else:
687                 force_exclude_match = None
688             if force_exclude_match and force_exclude_match.group(0):
689                 report.path_ignored(p, "matches the --force-exclude regular expression")
690                 continue
691
692             if is_stdin:
693                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
694
695             sources.add(p)
696         elif p.is_dir():
697             sources.update(
698                 gen_python_files(
699                     p.iterdir(),
700                     root,
701                     include,
702                     exclude,
703                     extend_exclude,
704                     force_exclude,
705                     report,
706                     gitignore,
707                 )
708             )
709         elif s == "-":
710             sources.add(p)
711         else:
712             err(f"invalid path: {s}")
713     return sources
714
715
716 def path_empty(
717     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
718 ) -> None:
719     """
720     Exit if there is no `src` provided for formatting
721     """
722     if not src and (verbose or not quiet):
723         out(msg)
724         ctx.exit(0)
725
726
727 def reformat_one(
728     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
729 ) -> None:
730     """Reformat a single file under `src` without spawning child processes.
731
732     `fast`, `write_back`, and `mode` options are passed to
733     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
734     """
735     try:
736         changed = Changed.NO
737
738         if str(src) == "-":
739             is_stdin = True
740         elif str(src).startswith(STDIN_PLACEHOLDER):
741             is_stdin = True
742             # Use the original name again in case we want to print something
743             # to the user
744             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
745         else:
746             is_stdin = False
747
748         if is_stdin:
749             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
750                 changed = Changed.YES
751         else:
752             cache: Cache = {}
753             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
754                 cache = read_cache(mode)
755                 res_src = src.resolve()
756                 res_src_s = str(res_src)
757                 if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
758                     changed = Changed.CACHED
759             if changed is not Changed.CACHED and format_file_in_place(
760                 src, fast=fast, write_back=write_back, mode=mode
761             ):
762                 changed = Changed.YES
763             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
764                 write_back is WriteBack.CHECK and changed is Changed.NO
765             ):
766                 write_cache(cache, [src], mode)
767         report.done(src, changed)
768     except Exception as exc:
769         if report.verbose:
770             traceback.print_exc()
771         report.failed(src, str(exc))
772
773
774 def reformat_many(
775     sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
776 ) -> None:
777     """Reformat multiple files using a ProcessPoolExecutor."""
778     executor: Executor
779     loop = asyncio.get_event_loop()
780     worker_count = os.cpu_count()
781     if sys.platform == "win32":
782         # Work around https://bugs.python.org/issue26903
783         worker_count = min(worker_count, 60)
784     try:
785         executor = ProcessPoolExecutor(max_workers=worker_count)
786     except (ImportError, OSError):
787         # we arrive here if the underlying system does not support multi-processing
788         # like in AWS Lambda or Termux, in which case we gracefully fallback to
789         # a ThreadPoolExecutor with just a single worker (more workers would not do us
790         # any good due to the Global Interpreter Lock)
791         executor = ThreadPoolExecutor(max_workers=1)
792
793     try:
794         loop.run_until_complete(
795             schedule_formatting(
796                 sources=sources,
797                 fast=fast,
798                 write_back=write_back,
799                 mode=mode,
800                 report=report,
801                 loop=loop,
802                 executor=executor,
803             )
804         )
805     finally:
806         shutdown(loop)
807         if executor is not None:
808             executor.shutdown()
809
810
811 async def schedule_formatting(
812     sources: Set[Path],
813     fast: bool,
814     write_back: WriteBack,
815     mode: Mode,
816     report: "Report",
817     loop: asyncio.AbstractEventLoop,
818     executor: Executor,
819 ) -> None:
820     """Run formatting of `sources` in parallel using the provided `executor`.
821
822     (Use ProcessPoolExecutors for actual parallelism.)
823
824     `write_back`, `fast`, and `mode` options are passed to
825     :func:`format_file_in_place`.
826     """
827     cache: Cache = {}
828     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
829         cache = read_cache(mode)
830         sources, cached = filter_cached(cache, sources)
831         for src in sorted(cached):
832             report.done(src, Changed.CACHED)
833     if not sources:
834         return
835
836     cancelled = []
837     sources_to_cache = []
838     lock = None
839     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
840         # For diff output, we need locks to ensure we don't interleave output
841         # from different processes.
842         manager = Manager()
843         lock = manager.Lock()
844     tasks = {
845         asyncio.ensure_future(
846             loop.run_in_executor(
847                 executor, format_file_in_place, src, fast, mode, write_back, lock
848             )
849         ): src
850         for src in sorted(sources)
851     }
852     pending = tasks.keys()
853     try:
854         loop.add_signal_handler(signal.SIGINT, cancel, pending)
855         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
856     except NotImplementedError:
857         # There are no good alternatives for these on Windows.
858         pass
859     while pending:
860         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
861         for task in done:
862             src = tasks.pop(task)
863             if task.cancelled():
864                 cancelled.append(task)
865             elif task.exception():
866                 report.failed(src, str(task.exception()))
867             else:
868                 changed = Changed.YES if task.result() else Changed.NO
869                 # If the file was written back or was successfully checked as
870                 # well-formatted, store this information in the cache.
871                 if write_back is WriteBack.YES or (
872                     write_back is WriteBack.CHECK and changed is Changed.NO
873                 ):
874                     sources_to_cache.append(src)
875                 report.done(src, changed)
876     if cancelled:
877         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
878     if sources_to_cache:
879         write_cache(cache, sources_to_cache, mode)
880
881
882 def format_file_in_place(
883     src: Path,
884     fast: bool,
885     mode: Mode,
886     write_back: WriteBack = WriteBack.NO,
887     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
888 ) -> bool:
889     """Format file under `src` path. Return True if changed.
890
891     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
892     code to the file.
893     `mode` and `fast` options are passed to :func:`format_file_contents`.
894     """
895     if src.suffix == ".pyi":
896         mode = replace(mode, is_pyi=True)
897
898     then = datetime.utcfromtimestamp(src.stat().st_mtime)
899     with open(src, "rb") as buf:
900         src_contents, encoding, newline = decode_bytes(buf.read())
901     try:
902         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
903     except NothingChanged:
904         return False
905
906     if write_back == WriteBack.YES:
907         with open(src, "w", encoding=encoding, newline=newline) as f:
908             f.write(dst_contents)
909     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
910         now = datetime.utcnow()
911         src_name = f"{src}\t{then} +0000"
912         dst_name = f"{src}\t{now} +0000"
913         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
914
915         if write_back == WriteBack.COLOR_DIFF:
916             diff_contents = color_diff(diff_contents)
917
918         with lock or nullcontext():
919             f = io.TextIOWrapper(
920                 sys.stdout.buffer,
921                 encoding=encoding,
922                 newline=newline,
923                 write_through=True,
924             )
925             f = wrap_stream_for_windows(f)
926             f.write(diff_contents)
927             f.detach()
928
929     return True
930
931
932 def color_diff(contents: str) -> str:
933     """Inject the ANSI color codes to the diff."""
934     lines = contents.split("\n")
935     for i, line in enumerate(lines):
936         if line.startswith("+++") or line.startswith("---"):
937             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
938         elif line.startswith("@@"):
939             line = "\033[36m" + line + "\033[0m"  # cyan, reset
940         elif line.startswith("+"):
941             line = "\033[32m" + line + "\033[0m"  # green, reset
942         elif line.startswith("-"):
943             line = "\033[31m" + line + "\033[0m"  # red, reset
944         lines[i] = line
945     return "\n".join(lines)
946
947
948 def wrap_stream_for_windows(
949     f: io.TextIOWrapper,
950 ) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
951     """
952     Wrap stream with colorama's wrap_stream so colors are shown on Windows.
953
954     If `colorama` is unavailable, the original stream is returned unmodified.
955     Otherwise, the `wrap_stream()` function determines whether the stream needs
956     to be wrapped for a Windows environment and will accordingly either return
957     an `AnsiToWin32` wrapper or the original stream.
958     """
959     try:
960         from colorama.initialise import wrap_stream
961     except ImportError:
962         return f
963     else:
964         # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
965         return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
966
967
968 def format_stdin_to_stdout(
969     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
970 ) -> bool:
971     """Format file on stdin. Return True if changed.
972
973     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
974     write a diff to stdout. The `mode` argument is passed to
975     :func:`format_file_contents`.
976     """
977     then = datetime.utcnow()
978     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
979     dst = src
980     try:
981         dst = format_file_contents(src, fast=fast, mode=mode)
982         return True
983
984     except NothingChanged:
985         return False
986
987     finally:
988         f = io.TextIOWrapper(
989             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
990         )
991         if write_back == WriteBack.YES:
992             f.write(dst)
993         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
994             now = datetime.utcnow()
995             src_name = f"STDIN\t{then} +0000"
996             dst_name = f"STDOUT\t{now} +0000"
997             d = diff(src, dst, src_name, dst_name)
998             if write_back == WriteBack.COLOR_DIFF:
999                 d = color_diff(d)
1000                 f = wrap_stream_for_windows(f)
1001             f.write(d)
1002         f.detach()
1003
1004
1005 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
1006     """Reformat contents of a file and return new contents.
1007
1008     If `fast` is False, additionally confirm that the reformatted code is
1009     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
1010     `mode` is passed to :func:`format_str`.
1011     """
1012     if not src_contents.strip():
1013         raise NothingChanged
1014
1015     dst_contents = format_str(src_contents, mode=mode)
1016     if src_contents == dst_contents:
1017         raise NothingChanged
1018
1019     if not fast:
1020         assert_equivalent(src_contents, dst_contents)
1021
1022         # Forced second pass to work around optional trailing commas (becoming
1023         # forced trailing commas on pass 2) interacting differently with optional
1024         # parentheses.  Admittedly ugly.
1025         dst_contents_pass2 = format_str(dst_contents, mode=mode)
1026         if dst_contents != dst_contents_pass2:
1027             dst_contents = dst_contents_pass2
1028             assert_equivalent(src_contents, dst_contents, pass_num=2)
1029             assert_stable(src_contents, dst_contents, mode=mode)
1030         # Note: no need to explicitly call `assert_stable` if `dst_contents` was
1031         # the same as `dst_contents_pass2`.
1032     return dst_contents
1033
1034
1035 def format_str(src_contents: str, *, mode: Mode) -> FileContent:
1036     """Reformat a string and return new contents.
1037
1038     `mode` determines formatting options, such as how many characters per line are
1039     allowed.  Example:
1040
1041     >>> import black
1042     >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
1043     def f(arg: str = "") -> None:
1044         ...
1045
1046     A more complex example:
1047
1048     >>> print(
1049     ...   black.format_str(
1050     ...     "def f(arg:str='')->None: hey",
1051     ...     mode=black.Mode(
1052     ...       target_versions={black.TargetVersion.PY36},
1053     ...       line_length=10,
1054     ...       string_normalization=False,
1055     ...       is_pyi=False,
1056     ...     ),
1057     ...   ),
1058     ... )
1059     def f(
1060         arg: str = '',
1061     ) -> None:
1062         hey
1063
1064     """
1065     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
1066     dst_contents = []
1067     future_imports = get_future_imports(src_node)
1068     if mode.target_versions:
1069         versions = mode.target_versions
1070     else:
1071         versions = detect_target_versions(src_node)
1072     normalize_fmt_off(src_node)
1073     lines = LineGenerator(
1074         mode=mode,
1075         remove_u_prefix="unicode_literals" in future_imports
1076         or supports_feature(versions, Feature.UNICODE_LITERALS),
1077     )
1078     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1079     empty_line = Line(mode=mode)
1080     after = 0
1081     split_line_features = {
1082         feature
1083         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1084         if supports_feature(versions, feature)
1085     }
1086     for current_line in lines.visit(src_node):
1087         dst_contents.append(str(empty_line) * after)
1088         before, after = elt.maybe_empty_lines(current_line)
1089         dst_contents.append(str(empty_line) * before)
1090         for line in transform_line(
1091             current_line, mode=mode, features=split_line_features
1092         ):
1093             dst_contents.append(str(line))
1094     return "".join(dst_contents)
1095
1096
1097 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1098     """Return a tuple of (decoded_contents, encoding, newline).
1099
1100     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1101     universal newlines (i.e. only contains LF).
1102     """
1103     srcbuf = io.BytesIO(src)
1104     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1105     if not lines:
1106         return "", encoding, "\n"
1107
1108     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1109     srcbuf.seek(0)
1110     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1111         return tiow.read(), encoding, newline
1112
1113
1114 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
1115     if not target_versions:
1116         # No target_version specified, so try all grammars.
1117         return [
1118             # Python 3.7+
1119             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
1120             # Python 3.0-3.6
1121             pygram.python_grammar_no_print_statement_no_exec_statement,
1122             # Python 2.7 with future print_function import
1123             pygram.python_grammar_no_print_statement,
1124             # Python 2.7
1125             pygram.python_grammar,
1126         ]
1127
1128     if all(version.is_python2() for version in target_versions):
1129         # Python 2-only code, so try Python 2 grammars.
1130         return [
1131             # Python 2.7 with future print_function import
1132             pygram.python_grammar_no_print_statement,
1133             # Python 2.7
1134             pygram.python_grammar,
1135         ]
1136
1137     # Python 3-compatible code, so only try Python 3 grammar.
1138     grammars = []
1139     # If we have to parse both, try to parse async as a keyword first
1140     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
1141         # Python 3.7+
1142         grammars.append(
1143             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
1144         )
1145     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
1146         # Python 3.0-3.6
1147         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
1148     # At least one of the above branches must have been taken, because every Python
1149     # version has exactly one of the two 'ASYNC_*' flags
1150     return grammars
1151
1152
1153 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
1154     """Given a string with source, return the lib2to3 Node."""
1155     if not src_txt.endswith("\n"):
1156         src_txt += "\n"
1157
1158     for grammar in get_grammars(set(target_versions)):
1159         drv = driver.Driver(grammar, pytree.convert)
1160         try:
1161             result = drv.parse_string(src_txt, True)
1162             break
1163
1164         except ParseError as pe:
1165             lineno, column = pe.context[1]
1166             lines = src_txt.splitlines()
1167             try:
1168                 faulty_line = lines[lineno - 1]
1169             except IndexError:
1170                 faulty_line = "<line number missing in source>"
1171             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
1172     else:
1173         raise exc from None
1174
1175     if isinstance(result, Leaf):
1176         result = Node(syms.file_input, [result])
1177     return result
1178
1179
1180 def lib2to3_unparse(node: Node) -> str:
1181     """Given a lib2to3 node, return its string representation."""
1182     code = str(node)
1183     return code
1184
1185
1186 class Visitor(Generic[T]):
1187     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
1188
1189     def visit(self, node: LN) -> Iterator[T]:
1190         """Main method to visit `node` and its children.
1191
1192         It tries to find a `visit_*()` method for the given `node.type`, like
1193         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
1194         If no dedicated `visit_*()` method is found, chooses `visit_default()`
1195         instead.
1196
1197         Then yields objects of type `T` from the selected visitor.
1198         """
1199         if node.type < 256:
1200             name = token.tok_name[node.type]
1201         else:
1202             name = str(type_repr(node.type))
1203         # We explicitly branch on whether a visitor exists (instead of
1204         # using self.visit_default as the default arg to getattr) in order
1205         # to save needing to create a bound method object and so mypyc can
1206         # generate a native call to visit_default.
1207         visitf = getattr(self, f"visit_{name}", None)
1208         if visitf:
1209             yield from visitf(node)
1210         else:
1211             yield from self.visit_default(node)
1212
1213     def visit_default(self, node: LN) -> Iterator[T]:
1214         """Default `visit_*()` implementation. Recurses to children of `node`."""
1215         if isinstance(node, Node):
1216             for child in node.children:
1217                 yield from self.visit(child)
1218
1219
1220 @dataclass
1221 class DebugVisitor(Visitor[T]):
1222     tree_depth: int = 0
1223
1224     def visit_default(self, node: LN) -> Iterator[T]:
1225         indent = " " * (2 * self.tree_depth)
1226         if isinstance(node, Node):
1227             _type = type_repr(node.type)
1228             out(f"{indent}{_type}", fg="yellow")
1229             self.tree_depth += 1
1230             for child in node.children:
1231                 yield from self.visit(child)
1232
1233             self.tree_depth -= 1
1234             out(f"{indent}/{_type}", fg="yellow", bold=False)
1235         else:
1236             _type = token.tok_name.get(node.type, str(node.type))
1237             out(f"{indent}{_type}", fg="blue", nl=False)
1238             if node.prefix:
1239                 # We don't have to handle prefixes for `Node` objects since
1240                 # that delegates to the first child anyway.
1241                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
1242             out(f" {node.value!r}", fg="blue", bold=False)
1243
1244     @classmethod
1245     def show(cls, code: Union[str, Leaf, Node]) -> None:
1246         """Pretty-print the lib2to3 AST of a given string of `code`.
1247
1248         Convenience method for debugging.
1249         """
1250         v: DebugVisitor[None] = DebugVisitor()
1251         if isinstance(code, str):
1252             code = lib2to3_parse(code)
1253         list(v.visit(code))
1254
1255
1256 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
1257 STATEMENT: Final = {
1258     syms.if_stmt,
1259     syms.while_stmt,
1260     syms.for_stmt,
1261     syms.try_stmt,
1262     syms.except_clause,
1263     syms.with_stmt,
1264     syms.funcdef,
1265     syms.classdef,
1266 }
1267 STANDALONE_COMMENT: Final = 153
1268 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
1269 LOGIC_OPERATORS: Final = {"and", "or"}
1270 COMPARATORS: Final = {
1271     token.LESS,
1272     token.GREATER,
1273     token.EQEQUAL,
1274     token.NOTEQUAL,
1275     token.LESSEQUAL,
1276     token.GREATEREQUAL,
1277 }
1278 MATH_OPERATORS: Final = {
1279     token.VBAR,
1280     token.CIRCUMFLEX,
1281     token.AMPER,
1282     token.LEFTSHIFT,
1283     token.RIGHTSHIFT,
1284     token.PLUS,
1285     token.MINUS,
1286     token.STAR,
1287     token.SLASH,
1288     token.DOUBLESLASH,
1289     token.PERCENT,
1290     token.AT,
1291     token.TILDE,
1292     token.DOUBLESTAR,
1293 }
1294 STARS: Final = {token.STAR, token.DOUBLESTAR}
1295 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
1296 VARARGS_PARENTS: Final = {
1297     syms.arglist,
1298     syms.argument,  # double star in arglist
1299     syms.trailer,  # single argument to call
1300     syms.typedargslist,
1301     syms.varargslist,  # lambdas
1302 }
1303 UNPACKING_PARENTS: Final = {
1304     syms.atom,  # single element of a list or set literal
1305     syms.dictsetmaker,
1306     syms.listmaker,
1307     syms.testlist_gexp,
1308     syms.testlist_star_expr,
1309 }
1310 TEST_DESCENDANTS: Final = {
1311     syms.test,
1312     syms.lambdef,
1313     syms.or_test,
1314     syms.and_test,
1315     syms.not_test,
1316     syms.comparison,
1317     syms.star_expr,
1318     syms.expr,
1319     syms.xor_expr,
1320     syms.and_expr,
1321     syms.shift_expr,
1322     syms.arith_expr,
1323     syms.trailer,
1324     syms.term,
1325     syms.power,
1326 }
1327 ASSIGNMENTS: Final = {
1328     "=",
1329     "+=",
1330     "-=",
1331     "*=",
1332     "@=",
1333     "/=",
1334     "%=",
1335     "&=",
1336     "|=",
1337     "^=",
1338     "<<=",
1339     ">>=",
1340     "**=",
1341     "//=",
1342 }
1343 COMPREHENSION_PRIORITY: Final = 20
1344 COMMA_PRIORITY: Final = 18
1345 TERNARY_PRIORITY: Final = 16
1346 LOGIC_PRIORITY: Final = 14
1347 STRING_PRIORITY: Final = 12
1348 COMPARATOR_PRIORITY: Final = 10
1349 MATH_PRIORITIES: Final = {
1350     token.VBAR: 9,
1351     token.CIRCUMFLEX: 8,
1352     token.AMPER: 7,
1353     token.LEFTSHIFT: 6,
1354     token.RIGHTSHIFT: 6,
1355     token.PLUS: 5,
1356     token.MINUS: 5,
1357     token.STAR: 4,
1358     token.SLASH: 4,
1359     token.DOUBLESLASH: 4,
1360     token.PERCENT: 4,
1361     token.AT: 4,
1362     token.TILDE: 3,
1363     token.DOUBLESTAR: 2,
1364 }
1365 DOT_PRIORITY: Final = 1
1366
1367
1368 @dataclass
1369 class BracketTracker:
1370     """Keeps track of brackets on a line."""
1371
1372     depth: int = 0
1373     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1374     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1375     previous: Optional[Leaf] = None
1376     _for_loop_depths: List[int] = field(default_factory=list)
1377     _lambda_argument_depths: List[int] = field(default_factory=list)
1378     invisible: List[Leaf] = field(default_factory=list)
1379
1380     def mark(self, leaf: Leaf) -> None:
1381         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1382
1383         All leaves receive an int `bracket_depth` field that stores how deep
1384         within brackets a given leaf is. 0 means there are no enclosing brackets
1385         that started on this line.
1386
1387         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1388         field that it forms a pair with. This is a one-directional link to
1389         avoid reference cycles.
1390
1391         If a leaf is a delimiter (a token on which Black can split the line if
1392         needed) and it's on depth 0, its `id()` is stored in the tracker's
1393         `delimiters` field.
1394         """
1395         if leaf.type == token.COMMENT:
1396             return
1397
1398         self.maybe_decrement_after_for_loop_variable(leaf)
1399         self.maybe_decrement_after_lambda_arguments(leaf)
1400         if leaf.type in CLOSING_BRACKETS:
1401             self.depth -= 1
1402             try:
1403                 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1404             except KeyError as e:
1405                 raise BracketMatchError(
1406                     "Unable to match a closing bracket to the following opening"
1407                     f" bracket: {leaf}"
1408                 ) from e
1409             leaf.opening_bracket = opening_bracket
1410             if not leaf.value:
1411                 self.invisible.append(leaf)
1412         leaf.bracket_depth = self.depth
1413         if self.depth == 0:
1414             delim = is_split_before_delimiter(leaf, self.previous)
1415             if delim and self.previous is not None:
1416                 self.delimiters[id(self.previous)] = delim
1417             else:
1418                 delim = is_split_after_delimiter(leaf, self.previous)
1419                 if delim:
1420                     self.delimiters[id(leaf)] = delim
1421         if leaf.type in OPENING_BRACKETS:
1422             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1423             self.depth += 1
1424             if not leaf.value:
1425                 self.invisible.append(leaf)
1426         self.previous = leaf
1427         self.maybe_increment_lambda_arguments(leaf)
1428         self.maybe_increment_for_loop_variable(leaf)
1429
1430     def any_open_brackets(self) -> bool:
1431         """Return True if there is an yet unmatched open bracket on the line."""
1432         return bool(self.bracket_match)
1433
1434     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1435         """Return the highest priority of a delimiter found on the line.
1436
1437         Values are consistent with what `is_split_*_delimiter()` return.
1438         Raises ValueError on no delimiters.
1439         """
1440         return max(v for k, v in self.delimiters.items() if k not in exclude)
1441
1442     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1443         """Return the number of delimiters with the given `priority`.
1444
1445         If no `priority` is passed, defaults to max priority on the line.
1446         """
1447         if not self.delimiters:
1448             return 0
1449
1450         priority = priority or self.max_delimiter_priority()
1451         return sum(1 for p in self.delimiters.values() if p == priority)
1452
1453     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1454         """In a for loop, or comprehension, the variables are often unpacks.
1455
1456         To avoid splitting on the comma in this situation, increase the depth of
1457         tokens between `for` and `in`.
1458         """
1459         if leaf.type == token.NAME and leaf.value == "for":
1460             self.depth += 1
1461             self._for_loop_depths.append(self.depth)
1462             return True
1463
1464         return False
1465
1466     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1467         """See `maybe_increment_for_loop_variable` above for explanation."""
1468         if (
1469             self._for_loop_depths
1470             and self._for_loop_depths[-1] == self.depth
1471             and leaf.type == token.NAME
1472             and leaf.value == "in"
1473         ):
1474             self.depth -= 1
1475             self._for_loop_depths.pop()
1476             return True
1477
1478         return False
1479
1480     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1481         """In a lambda expression, there might be more than one argument.
1482
1483         To avoid splitting on the comma in this situation, increase the depth of
1484         tokens between `lambda` and `:`.
1485         """
1486         if leaf.type == token.NAME and leaf.value == "lambda":
1487             self.depth += 1
1488             self._lambda_argument_depths.append(self.depth)
1489             return True
1490
1491         return False
1492
1493     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1494         """See `maybe_increment_lambda_arguments` above for explanation."""
1495         if (
1496             self._lambda_argument_depths
1497             and self._lambda_argument_depths[-1] == self.depth
1498             and leaf.type == token.COLON
1499         ):
1500             self.depth -= 1
1501             self._lambda_argument_depths.pop()
1502             return True
1503
1504         return False
1505
1506     def get_open_lsqb(self) -> Optional[Leaf]:
1507         """Return the most recent opening square bracket (if any)."""
1508         return self.bracket_match.get((self.depth - 1, token.RSQB))
1509
1510
1511 @dataclass
1512 class Line:
1513     """Holds leaves and comments. Can be printed with `str(line)`."""
1514
1515     mode: Mode
1516     depth: int = 0
1517     leaves: List[Leaf] = field(default_factory=list)
1518     # keys ordered like `leaves`
1519     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1520     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1521     inside_brackets: bool = False
1522     should_split_rhs: bool = False
1523     magic_trailing_comma: Optional[Leaf] = None
1524
1525     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1526         """Add a new `leaf` to the end of the line.
1527
1528         Unless `preformatted` is True, the `leaf` will receive a new consistent
1529         whitespace prefix and metadata applied by :class:`BracketTracker`.
1530         Trailing commas are maybe removed, unpacked for loop variables are
1531         demoted from being delimiters.
1532
1533         Inline comments are put aside.
1534         """
1535         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1536         if not has_value:
1537             return
1538
1539         if token.COLON == leaf.type and self.is_class_paren_empty:
1540             del self.leaves[-2:]
1541         if self.leaves and not preformatted:
1542             # Note: at this point leaf.prefix should be empty except for
1543             # imports, for which we only preserve newlines.
1544             leaf.prefix += whitespace(
1545                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1546             )
1547         if self.inside_brackets or not preformatted:
1548             self.bracket_tracker.mark(leaf)
1549             if self.mode.magic_trailing_comma:
1550                 if self.has_magic_trailing_comma(leaf):
1551                     self.magic_trailing_comma = leaf
1552             elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
1553                 self.remove_trailing_comma()
1554         if not self.append_comment(leaf):
1555             self.leaves.append(leaf)
1556
1557     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1558         """Like :func:`append()` but disallow invalid standalone comment structure.
1559
1560         Raises ValueError when any `leaf` is appended after a standalone comment
1561         or when a standalone comment is not the first leaf on the line.
1562         """
1563         if self.bracket_tracker.depth == 0:
1564             if self.is_comment:
1565                 raise ValueError("cannot append to standalone comments")
1566
1567             if self.leaves and leaf.type == STANDALONE_COMMENT:
1568                 raise ValueError(
1569                     "cannot append standalone comments to a populated line"
1570                 )
1571
1572         self.append(leaf, preformatted=preformatted)
1573
1574     @property
1575     def is_comment(self) -> bool:
1576         """Is this line a standalone comment?"""
1577         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1578
1579     @property
1580     def is_decorator(self) -> bool:
1581         """Is this line a decorator?"""
1582         return bool(self) and self.leaves[0].type == token.AT
1583
1584     @property
1585     def is_import(self) -> bool:
1586         """Is this an import line?"""
1587         return bool(self) and is_import(self.leaves[0])
1588
1589     @property
1590     def is_class(self) -> bool:
1591         """Is this line a class definition?"""
1592         return (
1593             bool(self)
1594             and self.leaves[0].type == token.NAME
1595             and self.leaves[0].value == "class"
1596         )
1597
1598     @property
1599     def is_stub_class(self) -> bool:
1600         """Is this line a class definition with a body consisting only of "..."?"""
1601         return self.is_class and self.leaves[-3:] == [
1602             Leaf(token.DOT, ".") for _ in range(3)
1603         ]
1604
1605     @property
1606     def is_def(self) -> bool:
1607         """Is this a function definition? (Also returns True for async defs.)"""
1608         try:
1609             first_leaf = self.leaves[0]
1610         except IndexError:
1611             return False
1612
1613         try:
1614             second_leaf: Optional[Leaf] = self.leaves[1]
1615         except IndexError:
1616             second_leaf = None
1617         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1618             first_leaf.type == token.ASYNC
1619             and second_leaf is not None
1620             and second_leaf.type == token.NAME
1621             and second_leaf.value == "def"
1622         )
1623
1624     @property
1625     def is_class_paren_empty(self) -> bool:
1626         """Is this a class with no base classes but using parentheses?
1627
1628         Those are unnecessary and should be removed.
1629         """
1630         return (
1631             bool(self)
1632             and len(self.leaves) == 4
1633             and self.is_class
1634             and self.leaves[2].type == token.LPAR
1635             and self.leaves[2].value == "("
1636             and self.leaves[3].type == token.RPAR
1637             and self.leaves[3].value == ")"
1638         )
1639
1640     @property
1641     def is_triple_quoted_string(self) -> bool:
1642         """Is the line a triple quoted string?"""
1643         return (
1644             bool(self)
1645             and self.leaves[0].type == token.STRING
1646             and self.leaves[0].value.startswith(('"""', "'''"))
1647         )
1648
1649     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1650         """If so, needs to be split before emitting."""
1651         for leaf in self.leaves:
1652             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1653                 return True
1654
1655         return False
1656
1657     def contains_uncollapsable_type_comments(self) -> bool:
1658         ignored_ids = set()
1659         try:
1660             last_leaf = self.leaves[-1]
1661             ignored_ids.add(id(last_leaf))
1662             if last_leaf.type == token.COMMA or (
1663                 last_leaf.type == token.RPAR and not last_leaf.value
1664             ):
1665                 # When trailing commas or optional parens are inserted by Black for
1666                 # consistency, comments after the previous last element are not moved
1667                 # (they don't have to, rendering will still be correct).  So we ignore
1668                 # trailing commas and invisible.
1669                 last_leaf = self.leaves[-2]
1670                 ignored_ids.add(id(last_leaf))
1671         except IndexError:
1672             return False
1673
1674         # A type comment is uncollapsable if it is attached to a leaf
1675         # that isn't at the end of the line (since that could cause it
1676         # to get associated to a different argument) or if there are
1677         # comments before it (since that could cause it to get hidden
1678         # behind a comment.
1679         comment_seen = False
1680         for leaf_id, comments in self.comments.items():
1681             for comment in comments:
1682                 if is_type_comment(comment):
1683                     if comment_seen or (
1684                         not is_type_comment(comment, " ignore")
1685                         and leaf_id not in ignored_ids
1686                     ):
1687                         return True
1688
1689                 comment_seen = True
1690
1691         return False
1692
1693     def contains_unsplittable_type_ignore(self) -> bool:
1694         if not self.leaves:
1695             return False
1696
1697         # If a 'type: ignore' is attached to the end of a line, we
1698         # can't split the line, because we can't know which of the
1699         # subexpressions the ignore was meant to apply to.
1700         #
1701         # We only want this to apply to actual physical lines from the
1702         # original source, though: we don't want the presence of a
1703         # 'type: ignore' at the end of a multiline expression to
1704         # justify pushing it all onto one line. Thus we
1705         # (unfortunately) need to check the actual source lines and
1706         # only report an unsplittable 'type: ignore' if this line was
1707         # one line in the original code.
1708
1709         # Grab the first and last line numbers, skipping generated leaves
1710         first_line = next((leaf.lineno for leaf in self.leaves if leaf.lineno != 0), 0)
1711         last_line = next(
1712             (leaf.lineno for leaf in reversed(self.leaves) if leaf.lineno != 0), 0
1713         )
1714
1715         if first_line == last_line:
1716             # We look at the last two leaves since a comma or an
1717             # invisible paren could have been added at the end of the
1718             # line.
1719             for node in self.leaves[-2:]:
1720                 for comment in self.comments.get(id(node), []):
1721                     if is_type_comment(comment, " ignore"):
1722                         return True
1723
1724         return False
1725
1726     def contains_multiline_strings(self) -> bool:
1727         return any(is_multiline_string(leaf) for leaf in self.leaves)
1728
1729     def has_magic_trailing_comma(
1730         self, closing: Leaf, ensure_removable: bool = False
1731     ) -> bool:
1732         """Return True if we have a magic trailing comma, that is when:
1733         - there's a trailing comma here
1734         - it's not a one-tuple
1735         Additionally, if ensure_removable:
1736         - it's not from square bracket indexing
1737         """
1738         if not (
1739             closing.type in CLOSING_BRACKETS
1740             and self.leaves
1741             and self.leaves[-1].type == token.COMMA
1742         ):
1743             return False
1744
1745         if closing.type == token.RBRACE:
1746             return True
1747
1748         if closing.type == token.RSQB:
1749             if not ensure_removable:
1750                 return True
1751             comma = self.leaves[-1]
1752             return bool(comma.parent and comma.parent.type == syms.listmaker)
1753
1754         if self.is_import:
1755             return True
1756
1757         if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves):
1758             return True
1759
1760         return False
1761
1762     def append_comment(self, comment: Leaf) -> bool:
1763         """Add an inline or standalone comment to the line."""
1764         if (
1765             comment.type == STANDALONE_COMMENT
1766             and self.bracket_tracker.any_open_brackets()
1767         ):
1768             comment.prefix = ""
1769             return False
1770
1771         if comment.type != token.COMMENT:
1772             return False
1773
1774         if not self.leaves:
1775             comment.type = STANDALONE_COMMENT
1776             comment.prefix = ""
1777             return False
1778
1779         last_leaf = self.leaves[-1]
1780         if (
1781             last_leaf.type == token.RPAR
1782             and not last_leaf.value
1783             and last_leaf.parent
1784             and len(list(last_leaf.parent.leaves())) <= 3
1785             and not is_type_comment(comment)
1786         ):
1787             # Comments on an optional parens wrapping a single leaf should belong to
1788             # the wrapped node except if it's a type comment. Pinning the comment like
1789             # this avoids unstable formatting caused by comment migration.
1790             if len(self.leaves) < 2:
1791                 comment.type = STANDALONE_COMMENT
1792                 comment.prefix = ""
1793                 return False
1794
1795             last_leaf = self.leaves[-2]
1796         self.comments.setdefault(id(last_leaf), []).append(comment)
1797         return True
1798
1799     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1800         """Generate comments that should appear directly after `leaf`."""
1801         return self.comments.get(id(leaf), [])
1802
1803     def remove_trailing_comma(self) -> None:
1804         """Remove the trailing comma and moves the comments attached to it."""
1805         trailing_comma = self.leaves.pop()
1806         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1807         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1808             trailing_comma_comments
1809         )
1810
1811     def is_complex_subscript(self, leaf: Leaf) -> bool:
1812         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1813         open_lsqb = self.bracket_tracker.get_open_lsqb()
1814         if open_lsqb is None:
1815             return False
1816
1817         subscript_start = open_lsqb.next_sibling
1818
1819         if isinstance(subscript_start, Node):
1820             if subscript_start.type == syms.listmaker:
1821                 return False
1822
1823             if subscript_start.type == syms.subscriptlist:
1824                 subscript_start = child_towards(subscript_start, leaf)
1825         return subscript_start is not None and any(
1826             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1827         )
1828
1829     def clone(self) -> "Line":
1830         return Line(
1831             mode=self.mode,
1832             depth=self.depth,
1833             inside_brackets=self.inside_brackets,
1834             should_split_rhs=self.should_split_rhs,
1835             magic_trailing_comma=self.magic_trailing_comma,
1836         )
1837
1838     def __str__(self) -> str:
1839         """Render the line."""
1840         if not self:
1841             return "\n"
1842
1843         indent = "    " * self.depth
1844         leaves = iter(self.leaves)
1845         first = next(leaves)
1846         res = f"{first.prefix}{indent}{first.value}"
1847         for leaf in leaves:
1848             res += str(leaf)
1849         for comment in itertools.chain.from_iterable(self.comments.values()):
1850             res += str(comment)
1851
1852         return res + "\n"
1853
1854     def __bool__(self) -> bool:
1855         """Return True if the line has leaves or comments."""
1856         return bool(self.leaves or self.comments)
1857
1858
1859 @dataclass
1860 class EmptyLineTracker:
1861     """Provides a stateful method that returns the number of potential extra
1862     empty lines needed before and after the currently processed line.
1863
1864     Note: this tracker works on lines that haven't been split yet.  It assumes
1865     the prefix of the first leaf consists of optional newlines.  Those newlines
1866     are consumed by `maybe_empty_lines()` and included in the computation.
1867     """
1868
1869     is_pyi: bool = False
1870     previous_line: Optional[Line] = None
1871     previous_after: int = 0
1872     previous_defs: List[int] = field(default_factory=list)
1873
1874     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1875         """Return the number of extra empty lines before and after the `current_line`.
1876
1877         This is for separating `def`, `async def` and `class` with extra empty
1878         lines (two on module-level).
1879         """
1880         before, after = self._maybe_empty_lines(current_line)
1881         before = (
1882             # Black should not insert empty lines at the beginning
1883             # of the file
1884             0
1885             if self.previous_line is None
1886             else before - self.previous_after
1887         )
1888         self.previous_after = after
1889         self.previous_line = current_line
1890         return before, after
1891
1892     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1893         max_allowed = 1
1894         if current_line.depth == 0:
1895             max_allowed = 1 if self.is_pyi else 2
1896         if current_line.leaves:
1897             # Consume the first leaf's extra newlines.
1898             first_leaf = current_line.leaves[0]
1899             before = first_leaf.prefix.count("\n")
1900             before = min(before, max_allowed)
1901             first_leaf.prefix = ""
1902         else:
1903             before = 0
1904         depth = current_line.depth
1905         while self.previous_defs and self.previous_defs[-1] >= depth:
1906             self.previous_defs.pop()
1907             if self.is_pyi:
1908                 before = 0 if depth else 1
1909             else:
1910                 before = 1 if depth else 2
1911         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1912             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1913
1914         if (
1915             self.previous_line
1916             and self.previous_line.is_import
1917             and not current_line.is_import
1918             and depth == self.previous_line.depth
1919         ):
1920             return (before or 1), 0
1921
1922         if (
1923             self.previous_line
1924             and self.previous_line.is_class
1925             and current_line.is_triple_quoted_string
1926         ):
1927             return before, 1
1928
1929         return before, 0
1930
1931     def _maybe_empty_lines_for_class_or_def(
1932         self, current_line: Line, before: int
1933     ) -> Tuple[int, int]:
1934         if not current_line.is_decorator:
1935             self.previous_defs.append(current_line.depth)
1936         if self.previous_line is None:
1937             # Don't insert empty lines before the first line in the file.
1938             return 0, 0
1939
1940         if self.previous_line.is_decorator:
1941             if self.is_pyi and current_line.is_stub_class:
1942                 # Insert an empty line after a decorated stub class
1943                 return 0, 1
1944
1945             return 0, 0
1946
1947         if self.previous_line.depth < current_line.depth and (
1948             self.previous_line.is_class or self.previous_line.is_def
1949         ):
1950             return 0, 0
1951
1952         if (
1953             self.previous_line.is_comment
1954             and self.previous_line.depth == current_line.depth
1955             and before == 0
1956         ):
1957             return 0, 0
1958
1959         if self.is_pyi:
1960             if self.previous_line.depth > current_line.depth:
1961                 newlines = 1
1962             elif current_line.is_class or self.previous_line.is_class:
1963                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1964                     # No blank line between classes with an empty body
1965                     newlines = 0
1966                 else:
1967                     newlines = 1
1968             elif (
1969                 current_line.is_def or current_line.is_decorator
1970             ) and not self.previous_line.is_def:
1971                 # Blank line between a block of functions (maybe with preceding
1972                 # decorators) and a block of non-functions
1973                 newlines = 1
1974             else:
1975                 newlines = 0
1976         else:
1977             newlines = 2
1978         if current_line.depth and newlines:
1979             newlines -= 1
1980         return newlines, 0
1981
1982
1983 @dataclass
1984 class LineGenerator(Visitor[Line]):
1985     """Generates reformatted Line objects.  Empty lines are not emitted.
1986
1987     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1988     in ways that will no longer stringify to valid Python code on the tree.
1989     """
1990
1991     mode: Mode
1992     remove_u_prefix: bool = False
1993     current_line: Line = field(init=False)
1994
1995     def line(self, indent: int = 0) -> Iterator[Line]:
1996         """Generate a line.
1997
1998         If the line is empty, only emit if it makes sense.
1999         If the line is too long, split it first and then generate.
2000
2001         If any lines were generated, set up a new current_line.
2002         """
2003         if not self.current_line:
2004             self.current_line.depth += indent
2005             return  # Line is empty, don't emit. Creating a new one unnecessary.
2006
2007         complete_line = self.current_line
2008         self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
2009         yield complete_line
2010
2011     def visit_default(self, node: LN) -> Iterator[Line]:
2012         """Default `visit_*()` implementation. Recurses to children of `node`."""
2013         if isinstance(node, Leaf):
2014             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
2015             for comment in generate_comments(node):
2016                 if any_open_brackets:
2017                     # any comment within brackets is subject to splitting
2018                     self.current_line.append(comment)
2019                 elif comment.type == token.COMMENT:
2020                     # regular trailing comment
2021                     self.current_line.append(comment)
2022                     yield from self.line()
2023
2024                 else:
2025                     # regular standalone comment
2026                     yield from self.line()
2027
2028                     self.current_line.append(comment)
2029                     yield from self.line()
2030
2031             normalize_prefix(node, inside_brackets=any_open_brackets)
2032             if self.mode.string_normalization and node.type == token.STRING:
2033                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
2034                 normalize_string_quotes(node)
2035             if node.type == token.NUMBER:
2036                 normalize_numeric_literal(node)
2037             if node.type not in WHITESPACE:
2038                 self.current_line.append(node)
2039         yield from super().visit_default(node)
2040
2041     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
2042         """Increase indentation level, maybe yield a line."""
2043         # In blib2to3 INDENT never holds comments.
2044         yield from self.line(+1)
2045         yield from self.visit_default(node)
2046
2047     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
2048         """Decrease indentation level, maybe yield a line."""
2049         # The current line might still wait for trailing comments.  At DEDENT time
2050         # there won't be any (they would be prefixes on the preceding NEWLINE).
2051         # Emit the line then.
2052         yield from self.line()
2053
2054         # While DEDENT has no value, its prefix may contain standalone comments
2055         # that belong to the current indentation level.  Get 'em.
2056         yield from self.visit_default(node)
2057
2058         # Finally, emit the dedent.
2059         yield from self.line(-1)
2060
2061     def visit_stmt(
2062         self, node: Node, keywords: Set[str], parens: Set[str]
2063     ) -> Iterator[Line]:
2064         """Visit a statement.
2065
2066         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
2067         `def`, `with`, `class`, `assert` and assignments.
2068
2069         The relevant Python language `keywords` for a given statement will be
2070         NAME leaves within it. This methods puts those on a separate line.
2071
2072         `parens` holds a set of string leaf values immediately after which
2073         invisible parens should be put.
2074         """
2075         normalize_invisible_parens(node, parens_after=parens)
2076         for child in node.children:
2077             if child.type == token.NAME and child.value in keywords:  # type: ignore
2078                 yield from self.line()
2079
2080             yield from self.visit(child)
2081
2082     def visit_suite(self, node: Node) -> Iterator[Line]:
2083         """Visit a suite."""
2084         if self.mode.is_pyi and is_stub_suite(node):
2085             yield from self.visit(node.children[2])
2086         else:
2087             yield from self.visit_default(node)
2088
2089     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
2090         """Visit a statement without nested statements."""
2091         if first_child_is_arith(node):
2092             wrap_in_parentheses(node, node.children[0], visible=False)
2093         is_suite_like = node.parent and node.parent.type in STATEMENT
2094         if is_suite_like:
2095             if self.mode.is_pyi and is_stub_body(node):
2096                 yield from self.visit_default(node)
2097             else:
2098                 yield from self.line(+1)
2099                 yield from self.visit_default(node)
2100                 yield from self.line(-1)
2101
2102         else:
2103             if (
2104                 not self.mode.is_pyi
2105                 or not node.parent
2106                 or not is_stub_suite(node.parent)
2107             ):
2108                 yield from self.line()
2109             yield from self.visit_default(node)
2110
2111     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
2112         """Visit `async def`, `async for`, `async with`."""
2113         yield from self.line()
2114
2115         children = iter(node.children)
2116         for child in children:
2117             yield from self.visit(child)
2118
2119             if child.type == token.ASYNC:
2120                 break
2121
2122         internal_stmt = next(children)
2123         for child in internal_stmt.children:
2124             yield from self.visit(child)
2125
2126     def visit_decorators(self, node: Node) -> Iterator[Line]:
2127         """Visit decorators."""
2128         for child in node.children:
2129             yield from self.line()
2130             yield from self.visit(child)
2131
2132     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
2133         """Remove a semicolon and put the other statement on a separate line."""
2134         yield from self.line()
2135
2136     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
2137         """End of file. Process outstanding comments and end with a newline."""
2138         yield from self.visit_default(leaf)
2139         yield from self.line()
2140
2141     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
2142         if not self.current_line.bracket_tracker.any_open_brackets():
2143             yield from self.line()
2144         yield from self.visit_default(leaf)
2145
2146     def visit_factor(self, node: Node) -> Iterator[Line]:
2147         """Force parentheses between a unary op and a binary power:
2148
2149         -2 ** 8 -> -(2 ** 8)
2150         """
2151         _operator, operand = node.children
2152         if (
2153             operand.type == syms.power
2154             and len(operand.children) == 3
2155             and operand.children[1].type == token.DOUBLESTAR
2156         ):
2157             lpar = Leaf(token.LPAR, "(")
2158             rpar = Leaf(token.RPAR, ")")
2159             index = operand.remove() or 0
2160             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
2161         yield from self.visit_default(node)
2162
2163     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
2164         if is_docstring(leaf) and "\\\n" not in leaf.value:
2165             # We're ignoring docstrings with backslash newline escapes because changing
2166             # indentation of those changes the AST representation of the code.
2167             prefix = get_string_prefix(leaf.value)
2168             docstring = leaf.value[len(prefix) :]  # Remove the prefix
2169             quote_char = docstring[0]
2170             # A natural way to remove the outer quotes is to do:
2171             #   docstring = docstring.strip(quote_char)
2172             # but that breaks on """""x""" (which is '""x').
2173             # So we actually need to remove the first character and the next two
2174             # characters but only if they are the same as the first.
2175             quote_len = 1 if docstring[1] != quote_char else 3
2176             docstring = docstring[quote_len:-quote_len]
2177
2178             if is_multiline_string(leaf):
2179                 indent = " " * 4 * self.current_line.depth
2180                 docstring = fix_docstring(docstring, indent)
2181             else:
2182                 docstring = docstring.strip()
2183
2184             if docstring:
2185                 # Add some padding if the docstring starts / ends with a quote mark.
2186                 if docstring[0] == quote_char:
2187                     docstring = " " + docstring
2188                 if docstring[-1] == quote_char:
2189                     docstring += " "
2190                 if docstring[-1] == "\\":
2191                     backslash_count = len(docstring) - len(docstring.rstrip("\\"))
2192                     if backslash_count % 2:
2193                         # Odd number of tailing backslashes, add some padding to
2194                         # avoid escaping the closing string quote.
2195                         docstring += " "
2196             else:
2197                 # Add some padding if the docstring is empty.
2198                 docstring = " "
2199
2200             # We could enforce triple quotes at this point.
2201             quote = quote_char * quote_len
2202             leaf.value = prefix + quote + docstring + quote
2203
2204         yield from self.visit_default(leaf)
2205
2206     def __post_init__(self) -> None:
2207         """You are in a twisty little maze of passages."""
2208         self.current_line = Line(mode=self.mode)
2209
2210         v = self.visit_stmt
2211         Ø: Set[str] = set()
2212         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
2213         self.visit_if_stmt = partial(
2214             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
2215         )
2216         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
2217         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
2218         self.visit_try_stmt = partial(
2219             v, keywords={"try", "except", "else", "finally"}, parens=Ø
2220         )
2221         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
2222         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
2223         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
2224         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
2225         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
2226         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
2227         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
2228         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
2229         self.visit_async_funcdef = self.visit_async_stmt
2230         self.visit_decorated = self.visit_decorators
2231
2232
2233 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
2234 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
2235 OPENING_BRACKETS = set(BRACKET.keys())
2236 CLOSING_BRACKETS = set(BRACKET.values())
2237 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
2238 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
2239
2240
2241 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
2242     """Return whitespace prefix if needed for the given `leaf`.
2243
2244     `complex_subscript` signals whether the given leaf is part of a subscription
2245     which has non-trivial arguments, like arithmetic expressions or function calls.
2246     """
2247     NO = ""
2248     SPACE = " "
2249     DOUBLESPACE = "  "
2250     t = leaf.type
2251     p = leaf.parent
2252     v = leaf.value
2253     if t in ALWAYS_NO_SPACE:
2254         return NO
2255
2256     if t == token.COMMENT:
2257         return DOUBLESPACE
2258
2259     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
2260     if t == token.COLON and p.type not in {
2261         syms.subscript,
2262         syms.subscriptlist,
2263         syms.sliceop,
2264     }:
2265         return NO
2266
2267     prev = leaf.prev_sibling
2268     if not prev:
2269         prevp = preceding_leaf(p)
2270         if not prevp or prevp.type in OPENING_BRACKETS:
2271             return NO
2272
2273         if t == token.COLON:
2274             if prevp.type == token.COLON:
2275                 return NO
2276
2277             elif prevp.type != token.COMMA and not complex_subscript:
2278                 return NO
2279
2280             return SPACE
2281
2282         if prevp.type == token.EQUAL:
2283             if prevp.parent:
2284                 if prevp.parent.type in {
2285                     syms.arglist,
2286                     syms.argument,
2287                     syms.parameters,
2288                     syms.varargslist,
2289                 }:
2290                     return NO
2291
2292                 elif prevp.parent.type == syms.typedargslist:
2293                     # A bit hacky: if the equal sign has whitespace, it means we
2294                     # previously found it's a typed argument.  So, we're using
2295                     # that, too.
2296                     return prevp.prefix
2297
2298         elif prevp.type in VARARGS_SPECIALS:
2299             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2300                 return NO
2301
2302         elif prevp.type == token.COLON:
2303             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
2304                 return SPACE if complex_subscript else NO
2305
2306         elif (
2307             prevp.parent
2308             and prevp.parent.type == syms.factor
2309             and prevp.type in MATH_OPERATORS
2310         ):
2311             return NO
2312
2313         elif (
2314             prevp.type == token.RIGHTSHIFT
2315             and prevp.parent
2316             and prevp.parent.type == syms.shift_expr
2317             and prevp.prev_sibling
2318             and prevp.prev_sibling.type == token.NAME
2319             and prevp.prev_sibling.value == "print"  # type: ignore
2320         ):
2321             # Python 2 print chevron
2322             return NO
2323         elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
2324             # no space in decorators
2325             return NO
2326
2327     elif prev.type in OPENING_BRACKETS:
2328         return NO
2329
2330     if p.type in {syms.parameters, syms.arglist}:
2331         # untyped function signatures or calls
2332         if not prev or prev.type != token.COMMA:
2333             return NO
2334
2335     elif p.type == syms.varargslist:
2336         # lambdas
2337         if prev and prev.type != token.COMMA:
2338             return NO
2339
2340     elif p.type == syms.typedargslist:
2341         # typed function signatures
2342         if not prev:
2343             return NO
2344
2345         if t == token.EQUAL:
2346             if prev.type != syms.tname:
2347                 return NO
2348
2349         elif prev.type == token.EQUAL:
2350             # A bit hacky: if the equal sign has whitespace, it means we
2351             # previously found it's a typed argument.  So, we're using that, too.
2352             return prev.prefix
2353
2354         elif prev.type != token.COMMA:
2355             return NO
2356
2357     elif p.type == syms.tname:
2358         # type names
2359         if not prev:
2360             prevp = preceding_leaf(p)
2361             if not prevp or prevp.type != token.COMMA:
2362                 return NO
2363
2364     elif p.type == syms.trailer:
2365         # attributes and calls
2366         if t == token.LPAR or t == token.RPAR:
2367             return NO
2368
2369         if not prev:
2370             if t == token.DOT:
2371                 prevp = preceding_leaf(p)
2372                 if not prevp or prevp.type != token.NUMBER:
2373                     return NO
2374
2375             elif t == token.LSQB:
2376                 return NO
2377
2378         elif prev.type != token.COMMA:
2379             return NO
2380
2381     elif p.type == syms.argument:
2382         # single argument
2383         if t == token.EQUAL:
2384             return NO
2385
2386         if not prev:
2387             prevp = preceding_leaf(p)
2388             if not prevp or prevp.type == token.LPAR:
2389                 return NO
2390
2391         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2392             return NO
2393
2394     elif p.type == syms.decorator:
2395         # decorators
2396         return NO
2397
2398     elif p.type == syms.dotted_name:
2399         if prev:
2400             return NO
2401
2402         prevp = preceding_leaf(p)
2403         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2404             return NO
2405
2406     elif p.type == syms.classdef:
2407         if t == token.LPAR:
2408             return NO
2409
2410         if prev and prev.type == token.LPAR:
2411             return NO
2412
2413     elif p.type in {syms.subscript, syms.sliceop}:
2414         # indexing
2415         if not prev:
2416             assert p.parent is not None, "subscripts are always parented"
2417             if p.parent.type == syms.subscriptlist:
2418                 return SPACE
2419
2420             return NO
2421
2422         elif not complex_subscript:
2423             return NO
2424
2425     elif p.type == syms.atom:
2426         if prev and t == token.DOT:
2427             # dots, but not the first one.
2428             return NO
2429
2430     elif p.type == syms.dictsetmaker:
2431         # dict unpacking
2432         if prev and prev.type == token.DOUBLESTAR:
2433             return NO
2434
2435     elif p.type in {syms.factor, syms.star_expr}:
2436         # unary ops
2437         if not prev:
2438             prevp = preceding_leaf(p)
2439             if not prevp or prevp.type in OPENING_BRACKETS:
2440                 return NO
2441
2442             prevp_parent = prevp.parent
2443             assert prevp_parent is not None
2444             if prevp.type == token.COLON and prevp_parent.type in {
2445                 syms.subscript,
2446                 syms.sliceop,
2447             }:
2448                 return NO
2449
2450             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2451                 return NO
2452
2453         elif t in {token.NAME, token.NUMBER, token.STRING}:
2454             return NO
2455
2456     elif p.type == syms.import_from:
2457         if t == token.DOT:
2458             if prev and prev.type == token.DOT:
2459                 return NO
2460
2461         elif t == token.NAME:
2462             if v == "import":
2463                 return SPACE
2464
2465             if prev and prev.type == token.DOT:
2466                 return NO
2467
2468     elif p.type == syms.sliceop:
2469         return NO
2470
2471     return SPACE
2472
2473
2474 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2475     """Return the first leaf that precedes `node`, if any."""
2476     while node:
2477         res = node.prev_sibling
2478         if res:
2479             if isinstance(res, Leaf):
2480                 return res
2481
2482             try:
2483                 return list(res.leaves())[-1]
2484
2485             except IndexError:
2486                 return None
2487
2488         node = node.parent
2489     return None
2490
2491
2492 def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
2493     """Return if the `node` and its previous siblings match types against the provided
2494     list of tokens; the provided `node`has its type matched against the last element in
2495     the list.  `None` can be used as the first element to declare that the start of the
2496     list is anchored at the start of its parent's children."""
2497     if not tokens:
2498         return True
2499     if tokens[-1] is None:
2500         return node is None
2501     if not node:
2502         return False
2503     if node.type != tokens[-1]:
2504         return False
2505     return prev_siblings_are(node.prev_sibling, tokens[:-1])
2506
2507
2508 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2509     """Return the child of `ancestor` that contains `descendant`."""
2510     node: Optional[LN] = descendant
2511     while node and node.parent != ancestor:
2512         node = node.parent
2513     return node
2514
2515
2516 def container_of(leaf: Leaf) -> LN:
2517     """Return `leaf` or one of its ancestors that is the topmost container of it.
2518
2519     By "container" we mean a node where `leaf` is the very first child.
2520     """
2521     same_prefix = leaf.prefix
2522     container: LN = leaf
2523     while container:
2524         parent = container.parent
2525         if parent is None:
2526             break
2527
2528         if parent.children[0].prefix != same_prefix:
2529             break
2530
2531         if parent.type == syms.file_input:
2532             break
2533
2534         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2535             break
2536
2537         container = parent
2538     return container
2539
2540
2541 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2542     """Return the priority of the `leaf` delimiter, given a line break after it.
2543
2544     The delimiter priorities returned here are from those delimiters that would
2545     cause a line break after themselves.
2546
2547     Higher numbers are higher priority.
2548     """
2549     if leaf.type == token.COMMA:
2550         return COMMA_PRIORITY
2551
2552     return 0
2553
2554
2555 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2556     """Return the priority of the `leaf` delimiter, given a line break before it.
2557
2558     The delimiter priorities returned here are from those delimiters that would
2559     cause a line break before themselves.
2560
2561     Higher numbers are higher priority.
2562     """
2563     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2564         # * and ** might also be MATH_OPERATORS but in this case they are not.
2565         # Don't treat them as a delimiter.
2566         return 0
2567
2568     if (
2569         leaf.type == token.DOT
2570         and leaf.parent
2571         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2572         and (previous is None or previous.type in CLOSING_BRACKETS)
2573     ):
2574         return DOT_PRIORITY
2575
2576     if (
2577         leaf.type in MATH_OPERATORS
2578         and leaf.parent
2579         and leaf.parent.type not in {syms.factor, syms.star_expr}
2580     ):
2581         return MATH_PRIORITIES[leaf.type]
2582
2583     if leaf.type in COMPARATORS:
2584         return COMPARATOR_PRIORITY
2585
2586     if (
2587         leaf.type == token.STRING
2588         and previous is not None
2589         and previous.type == token.STRING
2590     ):
2591         return STRING_PRIORITY
2592
2593     if leaf.type not in {token.NAME, token.ASYNC}:
2594         return 0
2595
2596     if (
2597         leaf.value == "for"
2598         and leaf.parent
2599         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2600         or leaf.type == token.ASYNC
2601     ):
2602         if (
2603             not isinstance(leaf.prev_sibling, Leaf)
2604             or leaf.prev_sibling.value != "async"
2605         ):
2606             return COMPREHENSION_PRIORITY
2607
2608     if (
2609         leaf.value == "if"
2610         and leaf.parent
2611         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2612     ):
2613         return COMPREHENSION_PRIORITY
2614
2615     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2616         return TERNARY_PRIORITY
2617
2618     if leaf.value == "is":
2619         return COMPARATOR_PRIORITY
2620
2621     if (
2622         leaf.value == "in"
2623         and leaf.parent
2624         and leaf.parent.type in {syms.comp_op, syms.comparison}
2625         and not (
2626             previous is not None
2627             and previous.type == token.NAME
2628             and previous.value == "not"
2629         )
2630     ):
2631         return COMPARATOR_PRIORITY
2632
2633     if (
2634         leaf.value == "not"
2635         and leaf.parent
2636         and leaf.parent.type == syms.comp_op
2637         and not (
2638             previous is not None
2639             and previous.type == token.NAME
2640             and previous.value == "is"
2641         )
2642     ):
2643         return COMPARATOR_PRIORITY
2644
2645     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2646         return LOGIC_PRIORITY
2647
2648     return 0
2649
2650
2651 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2652 FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
2653 FMT_PASS = {*FMT_OFF, *FMT_SKIP}
2654 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2655
2656
2657 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2658     """Clean the prefix of the `leaf` and generate comments from it, if any.
2659
2660     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2661     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2662     move because it does away with modifying the grammar to include all the
2663     possible places in which comments can be placed.
2664
2665     The sad consequence for us though is that comments don't "belong" anywhere.
2666     This is why this function generates simple parentless Leaf objects for
2667     comments.  We simply don't know what the correct parent should be.
2668
2669     No matter though, we can live without this.  We really only need to
2670     differentiate between inline and standalone comments.  The latter don't
2671     share the line with any code.
2672
2673     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2674     are emitted with a fake STANDALONE_COMMENT token identifier.
2675     """
2676     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2677         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2678
2679
2680 @dataclass
2681 class ProtoComment:
2682     """Describes a piece of syntax that is a comment.
2683
2684     It's not a :class:`blib2to3.pytree.Leaf` so that:
2685
2686     * it can be cached (`Leaf` objects should not be reused more than once as
2687       they store their lineno, column, prefix, and parent information);
2688     * `newlines` and `consumed` fields are kept separate from the `value`. This
2689       simplifies handling of special marker comments like ``# fmt: off/on``.
2690     """
2691
2692     type: int  # token.COMMENT or STANDALONE_COMMENT
2693     value: str  # content of the comment
2694     newlines: int  # how many newlines before the comment
2695     consumed: int  # how many characters of the original leaf's prefix did we consume
2696
2697
2698 @lru_cache(maxsize=4096)
2699 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2700     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2701     result: List[ProtoComment] = []
2702     if not prefix or "#" not in prefix:
2703         return result
2704
2705     consumed = 0
2706     nlines = 0
2707     ignored_lines = 0
2708     for index, line in enumerate(re.split("\r?\n", prefix)):
2709         consumed += len(line) + 1  # adding the length of the split '\n'
2710         line = line.lstrip()
2711         if not line:
2712             nlines += 1
2713         if not line.startswith("#"):
2714             # Escaped newlines outside of a comment are not really newlines at
2715             # all. We treat a single-line comment following an escaped newline
2716             # as a simple trailing comment.
2717             if line.endswith("\\"):
2718                 ignored_lines += 1
2719             continue
2720
2721         if index == ignored_lines and not is_endmarker:
2722             comment_type = token.COMMENT  # simple trailing comment
2723         else:
2724             comment_type = STANDALONE_COMMENT
2725         comment = make_comment(line)
2726         result.append(
2727             ProtoComment(
2728                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2729             )
2730         )
2731         nlines = 0
2732     return result
2733
2734
2735 def make_comment(content: str) -> str:
2736     """Return a consistently formatted comment from the given `content` string.
2737
2738     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2739     space between the hash sign and the content.
2740
2741     If `content` didn't start with a hash sign, one is provided.
2742     """
2743     content = content.rstrip()
2744     if not content:
2745         return "#"
2746
2747     if content[0] == "#":
2748         content = content[1:]
2749     NON_BREAKING_SPACE = " "
2750     if (
2751         content
2752         and content[0] == NON_BREAKING_SPACE
2753         and not content.lstrip().startswith("type:")
2754     ):
2755         content = " " + content[1:]  # Replace NBSP by a simple space
2756     if content and content[0] not in " !:#'%":
2757         content = " " + content
2758     return "#" + content
2759
2760
2761 def transform_line(
2762     line: Line, mode: Mode, features: Collection[Feature] = ()
2763 ) -> Iterator[Line]:
2764     """Transform a `line`, potentially splitting it into many lines.
2765
2766     They should fit in the allotted `line_length` but might not be able to.
2767
2768     `features` are syntactical features that may be used in the output.
2769     """
2770     if line.is_comment:
2771         yield line
2772         return
2773
2774     line_str = line_to_string(line)
2775
2776     def init_st(ST: Type[StringTransformer]) -> StringTransformer:
2777         """Initialize StringTransformer"""
2778         return ST(mode.line_length, mode.string_normalization)
2779
2780     string_merge = init_st(StringMerger)
2781     string_paren_strip = init_st(StringParenStripper)
2782     string_split = init_st(StringSplitter)
2783     string_paren_wrap = init_st(StringParenWrapper)
2784
2785     transformers: List[Transformer]
2786     if (
2787         not line.contains_uncollapsable_type_comments()
2788         and not line.should_split_rhs
2789         and not line.magic_trailing_comma
2790         and (
2791             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
2792             or line.contains_unsplittable_type_ignore()
2793         )
2794         and not (line.inside_brackets and line.contains_standalone_comments())
2795     ):
2796         # Only apply basic string preprocessing, since lines shouldn't be split here.
2797         if mode.experimental_string_processing:
2798             transformers = [string_merge, string_paren_strip]
2799         else:
2800             transformers = []
2801     elif line.is_def:
2802         transformers = [left_hand_split]
2803     else:
2804
2805         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2806             """Wraps calls to `right_hand_split`.
2807
2808             The calls increasingly `omit` right-hand trailers (bracket pairs with
2809             content), meaning the trailers get glued together to split on another
2810             bracket pair instead.
2811             """
2812             for omit in generate_trailers_to_omit(line, mode.line_length):
2813                 lines = list(
2814                     right_hand_split(line, mode.line_length, features, omit=omit)
2815                 )
2816                 # Note: this check is only able to figure out if the first line of the
2817                 # *current* transformation fits in the line length.  This is true only
2818                 # for simple cases.  All others require running more transforms via
2819                 # `transform_line()`.  This check doesn't know if those would succeed.
2820                 if is_line_short_enough(lines[0], line_length=mode.line_length):
2821                     yield from lines
2822                     return
2823
2824             # All splits failed, best effort split with no omits.
2825             # This mostly happens to multiline strings that are by definition
2826             # reported as not fitting a single line, as well as lines that contain
2827             # trailing commas (those have to be exploded).
2828             yield from right_hand_split(
2829                 line, line_length=mode.line_length, features=features
2830             )
2831
2832         if mode.experimental_string_processing:
2833             if line.inside_brackets:
2834                 transformers = [
2835                     string_merge,
2836                     string_paren_strip,
2837                     string_split,
2838                     delimiter_split,
2839                     standalone_comment_split,
2840                     string_paren_wrap,
2841                     rhs,
2842                 ]
2843             else:
2844                 transformers = [
2845                     string_merge,
2846                     string_paren_strip,
2847                     string_split,
2848                     string_paren_wrap,
2849                     rhs,
2850                 ]
2851         else:
2852             if line.inside_brackets:
2853                 transformers = [delimiter_split, standalone_comment_split, rhs]
2854             else:
2855                 transformers = [rhs]
2856
2857     for transform in transformers:
2858         # We are accumulating lines in `result` because we might want to abort
2859         # mission and return the original line in the end, or attempt a different
2860         # split altogether.
2861         try:
2862             result = run_transformer(line, transform, mode, features, line_str=line_str)
2863         except CannotTransform:
2864             continue
2865         else:
2866             yield from result
2867             break
2868
2869     else:
2870         yield line
2871
2872
2873 @dataclass  # type: ignore
2874 class StringTransformer(ABC):
2875     """
2876     An implementation of the Transformer protocol that relies on its
2877     subclasses overriding the template methods `do_match(...)` and
2878     `do_transform(...)`.
2879
2880     This Transformer works exclusively on strings (for example, by merging
2881     or splitting them).
2882
2883     The following sections can be found among the docstrings of each concrete
2884     StringTransformer subclass.
2885
2886     Requirements:
2887         Which requirements must be met of the given Line for this
2888         StringTransformer to be applied?
2889
2890     Transformations:
2891         If the given Line meets all of the above requirements, which string
2892         transformations can you expect to be applied to it by this
2893         StringTransformer?
2894
2895     Collaborations:
2896         What contractual agreements does this StringTransformer have with other
2897         StringTransfomers? Such collaborations should be eliminated/minimized
2898         as much as possible.
2899     """
2900
2901     line_length: int
2902     normalize_strings: bool
2903     __name__ = "StringTransformer"
2904
2905     @abstractmethod
2906     def do_match(self, line: Line) -> TMatchResult:
2907         """
2908         Returns:
2909             * Ok(string_idx) such that `line.leaves[string_idx]` is our target
2910             string, if a match was able to be made.
2911                 OR
2912             * Err(CannotTransform), if a match was not able to be made.
2913         """
2914
2915     @abstractmethod
2916     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
2917         """
2918         Yields:
2919             * Ok(new_line) where new_line is the new transformed line.
2920                 OR
2921             * Err(CannotTransform) if the transformation failed for some reason. The
2922             `do_match(...)` template method should usually be used to reject
2923             the form of the given Line, but in some cases it is difficult to
2924             know whether or not a Line meets the StringTransformer's
2925             requirements until the transformation is already midway.
2926
2927         Side Effects:
2928             This method should NOT mutate @line directly, but it MAY mutate the
2929             Line's underlying Node structure. (WARNING: If the underlying Node
2930             structure IS altered, then this method should NOT be allowed to
2931             yield an CannotTransform after that point.)
2932         """
2933
2934     def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]:
2935         """
2936         StringTransformer instances have a call signature that mirrors that of
2937         the Transformer type.
2938
2939         Raises:
2940             CannotTransform(...) if the concrete StringTransformer class is unable
2941             to transform @line.
2942         """
2943         # Optimization to avoid calling `self.do_match(...)` when the line does
2944         # not contain any string.
2945         if not any(leaf.type == token.STRING for leaf in line.leaves):
2946             raise CannotTransform("There are no strings in this line.")
2947
2948         match_result = self.do_match(line)
2949
2950         if isinstance(match_result, Err):
2951             cant_transform = match_result.err()
2952             raise CannotTransform(
2953                 f"The string transformer {self.__class__.__name__} does not recognize"
2954                 " this line as one that it can transform."
2955             ) from cant_transform
2956
2957         string_idx = match_result.ok()
2958
2959         for line_result in self.do_transform(line, string_idx):
2960             if isinstance(line_result, Err):
2961                 cant_transform = line_result.err()
2962                 raise CannotTransform(
2963                     "StringTransformer failed while attempting to transform string."
2964                 ) from cant_transform
2965             line = line_result.ok()
2966             yield line
2967
2968
2969 @dataclass
2970 class CustomSplit:
2971     """A custom (i.e. manual) string split.
2972
2973     A single CustomSplit instance represents a single substring.
2974
2975     Examples:
2976         Consider the following string:
2977         ```
2978         "Hi there friend."
2979         " This is a custom"
2980         f" string {split}."
2981         ```
2982
2983         This string will correspond to the following three CustomSplit instances:
2984         ```
2985         CustomSplit(False, 16)
2986         CustomSplit(False, 17)
2987         CustomSplit(True, 16)
2988         ```
2989     """
2990
2991     has_prefix: bool
2992     break_idx: int
2993
2994
2995 class CustomSplitMapMixin:
2996     """
2997     This mixin class is used to map merged strings to a sequence of
2998     CustomSplits, which will then be used to re-split the strings iff none of
2999     the resultant substrings go over the configured max line length.
3000     """
3001
3002     _Key = Tuple[StringID, str]
3003     _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
3004
3005     @staticmethod
3006     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
3007         """
3008         Returns:
3009             A unique identifier that is used internally to map @string to a
3010             group of custom splits.
3011         """
3012         return (id(string), string)
3013
3014     def add_custom_splits(
3015         self, string: str, custom_splits: Iterable[CustomSplit]
3016     ) -> None:
3017         """Custom Split Map Setter Method
3018
3019         Side Effects:
3020             Adds a mapping from @string to the custom splits @custom_splits.
3021         """
3022         key = self._get_key(string)
3023         self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)
3024
3025     def pop_custom_splits(self, string: str) -> List[CustomSplit]:
3026         """Custom Split Map Getter Method
3027
3028         Returns:
3029             * A list of the custom splits that are mapped to @string, if any
3030             exist.
3031                 OR
3032             * [], otherwise.
3033
3034         Side Effects:
3035             Deletes the mapping between @string and its associated custom
3036             splits (which are returned to the caller).
3037         """
3038         key = self._get_key(string)
3039
3040         custom_splits = self._CUSTOM_SPLIT_MAP[key]
3041         del self._CUSTOM_SPLIT_MAP[key]
3042
3043         return list(custom_splits)
3044
3045     def has_custom_splits(self, string: str) -> bool:
3046         """
3047         Returns:
3048             True iff @string is associated with a set of custom splits.
3049         """
3050         key = self._get_key(string)
3051         return key in self._CUSTOM_SPLIT_MAP
3052
3053
3054 class StringMerger(CustomSplitMapMixin, StringTransformer):
3055     """StringTransformer that merges strings together.
3056
3057     Requirements:
3058         (A) The line contains adjacent strings such that ALL of the validation checks
3059         listed in StringMerger.__validate_msg(...)'s docstring pass.
3060             OR
3061         (B) The line contains a string which uses line continuation backslashes.
3062
3063     Transformations:
3064         Depending on which of the two requirements above where met, either:
3065
3066         (A) The string group associated with the target string is merged.
3067             OR
3068         (B) All line-continuation backslashes are removed from the target string.
3069
3070     Collaborations:
3071         StringMerger provides custom split information to StringSplitter.
3072     """
3073
3074     def do_match(self, line: Line) -> TMatchResult:
3075         LL = line.leaves
3076
3077         is_valid_index = is_valid_index_factory(LL)
3078
3079         for (i, leaf) in enumerate(LL):
3080             if (
3081                 leaf.type == token.STRING
3082                 and is_valid_index(i + 1)
3083                 and LL[i + 1].type == token.STRING
3084             ):
3085                 return Ok(i)
3086
3087             if leaf.type == token.STRING and "\\\n" in leaf.value:
3088                 return Ok(i)
3089
3090         return TErr("This line has no strings that need merging.")
3091
3092     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3093         new_line = line
3094         rblc_result = self.__remove_backslash_line_continuation_chars(
3095             new_line, string_idx
3096         )
3097         if isinstance(rblc_result, Ok):
3098             new_line = rblc_result.ok()
3099
3100         msg_result = self.__merge_string_group(new_line, string_idx)
3101         if isinstance(msg_result, Ok):
3102             new_line = msg_result.ok()
3103
3104         if isinstance(rblc_result, Err) and isinstance(msg_result, Err):
3105             msg_cant_transform = msg_result.err()
3106             rblc_cant_transform = rblc_result.err()
3107             cant_transform = CannotTransform(
3108                 "StringMerger failed to merge any strings in this line."
3109             )
3110
3111             # Chain the errors together using `__cause__`.
3112             msg_cant_transform.__cause__ = rblc_cant_transform
3113             cant_transform.__cause__ = msg_cant_transform
3114
3115             yield Err(cant_transform)
3116         else:
3117             yield Ok(new_line)
3118
3119     @staticmethod
3120     def __remove_backslash_line_continuation_chars(
3121         line: Line, string_idx: int
3122     ) -> TResult[Line]:
3123         """
3124         Merge strings that were split across multiple lines using
3125         line-continuation backslashes.
3126
3127         Returns:
3128             Ok(new_line), if @line contains backslash line-continuation
3129             characters.
3130                 OR
3131             Err(CannotTransform), otherwise.
3132         """
3133         LL = line.leaves
3134
3135         string_leaf = LL[string_idx]
3136         if not (
3137             string_leaf.type == token.STRING
3138             and "\\\n" in string_leaf.value
3139             and not has_triple_quotes(string_leaf.value)
3140         ):
3141             return TErr(
3142                 f"String leaf {string_leaf} does not contain any backslash line"
3143                 " continuation characters."
3144             )
3145
3146         new_line = line.clone()
3147         new_line.comments = line.comments.copy()
3148         append_leaves(new_line, line, LL)
3149
3150         new_string_leaf = new_line.leaves[string_idx]
3151         new_string_leaf.value = new_string_leaf.value.replace("\\\n", "")
3152
3153         return Ok(new_line)
3154
3155     def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]:
3156         """
3157         Merges string group (i.e. set of adjacent strings) where the first
3158         string in the group is `line.leaves[string_idx]`.
3159
3160         Returns:
3161             Ok(new_line), if ALL of the validation checks found in
3162             __validate_msg(...) pass.
3163                 OR
3164             Err(CannotTransform), otherwise.
3165         """
3166         LL = line.leaves
3167
3168         is_valid_index = is_valid_index_factory(LL)
3169
3170         vresult = self.__validate_msg(line, string_idx)
3171         if isinstance(vresult, Err):
3172             return vresult
3173
3174         # If the string group is wrapped inside an Atom node, we must make sure
3175         # to later replace that Atom with our new (merged) string leaf.
3176         atom_node = LL[string_idx].parent
3177
3178         # We will place BREAK_MARK in between every two substrings that we
3179         # merge. We will then later go through our final result and use the
3180         # various instances of BREAK_MARK we find to add the right values to
3181         # the custom split map.
3182         BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@"
3183
3184         QUOTE = LL[string_idx].value[-1]
3185
3186         def make_naked(string: str, string_prefix: str) -> str:
3187             """Strip @string (i.e. make it a "naked" string)
3188
3189             Pre-conditions:
3190                 * assert_is_leaf_string(@string)
3191
3192             Returns:
3193                 A string that is identical to @string except that
3194                 @string_prefix has been stripped, the surrounding QUOTE
3195                 characters have been removed, and any remaining QUOTE
3196                 characters have been escaped.
3197             """
3198             assert_is_leaf_string(string)
3199
3200             RE_EVEN_BACKSLASHES = r"(?:(?<!\\)(?:\\\\)*)"
3201             naked_string = string[len(string_prefix) + 1 : -1]
3202             naked_string = re.sub(
3203                 "(" + RE_EVEN_BACKSLASHES + ")" + QUOTE, r"\1\\" + QUOTE, naked_string
3204             )
3205             return naked_string
3206
3207         # Holds the CustomSplit objects that will later be added to the custom
3208         # split map.
3209         custom_splits = []
3210
3211         # Temporary storage for the 'has_prefix' part of the CustomSplit objects.
3212         prefix_tracker = []
3213
3214         # Sets the 'prefix' variable. This is the prefix that the final merged
3215         # string will have.
3216         next_str_idx = string_idx
3217         prefix = ""
3218         while (
3219             not prefix
3220             and is_valid_index(next_str_idx)
3221             and LL[next_str_idx].type == token.STRING
3222         ):
3223             prefix = get_string_prefix(LL[next_str_idx].value)
3224             next_str_idx += 1
3225
3226         # The next loop merges the string group. The final string will be
3227         # contained in 'S'.
3228         #
3229         # The following convenience variables are used:
3230         #
3231         #   S: string
3232         #   NS: naked string
3233         #   SS: next string
3234         #   NSS: naked next string
3235         S = ""
3236         NS = ""
3237         num_of_strings = 0
3238         next_str_idx = string_idx
3239         while is_valid_index(next_str_idx) and LL[next_str_idx].type == token.STRING:
3240             num_of_strings += 1
3241
3242             SS = LL[next_str_idx].value
3243             next_prefix = get_string_prefix(SS)
3244
3245             # If this is an f-string group but this substring is not prefixed
3246             # with 'f'...
3247             if "f" in prefix and "f" not in next_prefix:
3248                 # Then we must escape any braces contained in this substring.
3249                 SS = re.subf(r"(\{|\})", "{1}{1}", SS)
3250
3251             NSS = make_naked(SS, next_prefix)
3252
3253             has_prefix = bool(next_prefix)
3254             prefix_tracker.append(has_prefix)
3255
3256             S = prefix + QUOTE + NS + NSS + BREAK_MARK + QUOTE
3257             NS = make_naked(S, prefix)
3258
3259             next_str_idx += 1
3260
3261         S_leaf = Leaf(token.STRING, S)
3262         if self.normalize_strings:
3263             normalize_string_quotes(S_leaf)
3264
3265         # Fill the 'custom_splits' list with the appropriate CustomSplit objects.
3266         temp_string = S_leaf.value[len(prefix) + 1 : -1]
3267         for has_prefix in prefix_tracker:
3268             mark_idx = temp_string.find(BREAK_MARK)
3269             assert (
3270                 mark_idx >= 0
3271             ), "Logic error while filling the custom string breakpoint cache."
3272
3273             temp_string = temp_string[mark_idx + len(BREAK_MARK) :]
3274             breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1
3275             custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))
3276
3277         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
3278
3279         if atom_node is not None:
3280             replace_child(atom_node, string_leaf)
3281
3282         # Build the final line ('new_line') that this method will later return.
3283         new_line = line.clone()
3284         for (i, leaf) in enumerate(LL):
3285             if i == string_idx:
3286                 new_line.append(string_leaf)
3287
3288             if string_idx <= i < string_idx + num_of_strings:
3289                 for comment_leaf in line.comments_after(LL[i]):
3290                     new_line.append(comment_leaf, preformatted=True)
3291                 continue
3292
3293             append_leaves(new_line, line, [leaf])
3294
3295         self.add_custom_splits(string_leaf.value, custom_splits)
3296         return Ok(new_line)
3297
3298     @staticmethod
3299     def __validate_msg(line: Line, string_idx: int) -> TResult[None]:
3300         """Validate (M)erge (S)tring (G)roup
3301
3302         Transform-time string validation logic for __merge_string_group(...).
3303
3304         Returns:
3305             * Ok(None), if ALL validation checks (listed below) pass.
3306                 OR
3307             * Err(CannotTransform), if any of the following are true:
3308                 - The target string group does not contain ANY stand-alone comments.
3309                 - The target string is not in a string group (i.e. it has no
3310                   adjacent strings).
3311                 - The string group has more than one inline comment.
3312                 - The string group has an inline comment that appears to be a pragma.
3313                 - The set of all string prefixes in the string group is of
3314                   length greater than one and is not equal to {"", "f"}.
3315                 - The string group consists of raw strings.
3316         """
3317         # We first check for "inner" stand-alone comments (i.e. stand-alone
3318         # comments that have a string leaf before them AND after them).
3319         for inc in [1, -1]:
3320             i = string_idx
3321             found_sa_comment = False
3322             is_valid_index = is_valid_index_factory(line.leaves)
3323             while is_valid_index(i) and line.leaves[i].type in [
3324                 token.STRING,
3325                 STANDALONE_COMMENT,
3326             ]:
3327                 if line.leaves[i].type == STANDALONE_COMMENT:
3328                     found_sa_comment = True
3329                 elif found_sa_comment:
3330                     return TErr(
3331                         "StringMerger does NOT merge string groups which contain "
3332                         "stand-alone comments."
3333                     )
3334
3335                 i += inc
3336
3337         num_of_inline_string_comments = 0
3338         set_of_prefixes = set()
3339         num_of_strings = 0
3340         for leaf in line.leaves[string_idx:]:
3341             if leaf.type != token.STRING:
3342                 # If the string group is trailed by a comma, we count the
3343                 # comments trailing the comma to be one of the string group's
3344                 # comments.
3345                 if leaf.type == token.COMMA and id(leaf) in line.comments:
3346                     num_of_inline_string_comments += 1
3347                 break
3348
3349             if has_triple_quotes(leaf.value):
3350                 return TErr("StringMerger does NOT merge multiline strings.")
3351
3352             num_of_strings += 1
3353             prefix = get_string_prefix(leaf.value)
3354             if "r" in prefix:
3355                 return TErr("StringMerger does NOT merge raw strings.")
3356
3357             set_of_prefixes.add(prefix)
3358
3359             if id(leaf) in line.comments:
3360                 num_of_inline_string_comments += 1
3361                 if contains_pragma_comment(line.comments[id(leaf)]):
3362                     return TErr("Cannot merge strings which have pragma comments.")
3363
3364         if num_of_strings < 2:
3365             return TErr(
3366                 f"Not enough strings to merge (num_of_strings={num_of_strings})."
3367             )
3368
3369         if num_of_inline_string_comments > 1:
3370             return TErr(
3371                 f"Too many inline string comments ({num_of_inline_string_comments})."
3372             )
3373
3374         if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}:
3375             return TErr(f"Too many different prefixes ({set_of_prefixes}).")
3376
3377         return Ok(None)
3378
3379
3380 class StringParenStripper(StringTransformer):
3381     """StringTransformer that strips surrounding parentheses from strings.
3382
3383     Requirements:
3384         The line contains a string which is surrounded by parentheses and:
3385             - The target string is NOT the only argument to a function call.
3386             - The target string is NOT a "pointless" string.
3387             - If the target string contains a PERCENT, the brackets are not
3388               preceeded or followed by an operator with higher precedence than
3389               PERCENT.
3390
3391     Transformations:
3392         The parentheses mentioned in the 'Requirements' section are stripped.
3393
3394     Collaborations:
3395         StringParenStripper has its own inherent usefulness, but it is also
3396         relied on to clean up the parentheses created by StringParenWrapper (in
3397         the event that they are no longer needed).
3398     """
3399
3400     def do_match(self, line: Line) -> TMatchResult:
3401         LL = line.leaves
3402
3403         is_valid_index = is_valid_index_factory(LL)
3404
3405         for (idx, leaf) in enumerate(LL):
3406             # Should be a string...
3407             if leaf.type != token.STRING:
3408                 continue
3409
3410             # If this is a "pointless" string...
3411             if (
3412                 leaf.parent
3413                 and leaf.parent.parent
3414                 and leaf.parent.parent.type == syms.simple_stmt
3415             ):
3416                 continue
3417
3418             # Should be preceded by a non-empty LPAR...
3419             if (
3420                 not is_valid_index(idx - 1)
3421                 or LL[idx - 1].type != token.LPAR
3422                 or is_empty_lpar(LL[idx - 1])
3423             ):
3424                 continue
3425
3426             # That LPAR should NOT be preceded by a function name or a closing
3427             # bracket (which could be a function which returns a function or a
3428             # list/dictionary that contains a function)...
3429             if is_valid_index(idx - 2) and (
3430                 LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS
3431             ):
3432                 continue
3433
3434             string_idx = idx
3435
3436             # Skip the string trailer, if one exists.
3437             string_parser = StringParser()
3438             next_idx = string_parser.parse(LL, string_idx)
3439
3440             # if the leaves in the parsed string include a PERCENT, we need to
3441             # make sure the initial LPAR is NOT preceded by an operator with
3442             # higher or equal precedence to PERCENT
3443             if is_valid_index(idx - 2):
3444                 # mypy can't quite follow unless we name this
3445                 before_lpar = LL[idx - 2]
3446                 if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (
3447                     (
3448                         before_lpar.type
3449                         in {
3450                             token.STAR,
3451                             token.AT,
3452                             token.SLASH,
3453                             token.DOUBLESLASH,
3454                             token.PERCENT,
3455                             token.TILDE,
3456                             token.DOUBLESTAR,
3457                             token.AWAIT,
3458                             token.LSQB,
3459                             token.LPAR,
3460                         }
3461                     )
3462                     or (
3463                         # only unary PLUS/MINUS
3464                         before_lpar.parent
3465                         and before_lpar.parent.type == syms.factor
3466                         and (before_lpar.type in {token.PLUS, token.MINUS})
3467                     )
3468                 ):
3469                     continue
3470
3471             # Should be followed by a non-empty RPAR...
3472             if (
3473                 is_valid_index(next_idx)
3474                 and LL[next_idx].type == token.RPAR
3475                 and not is_empty_rpar(LL[next_idx])
3476             ):
3477                 # That RPAR should NOT be followed by anything with higher
3478                 # precedence than PERCENT
3479                 if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {
3480                     token.DOUBLESTAR,
3481                     token.LSQB,
3482                     token.LPAR,
3483                     token.DOT,
3484                 }:
3485                     continue
3486
3487                 return Ok(string_idx)
3488
3489         return TErr("This line has no strings wrapped in parens.")
3490
3491     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3492         LL = line.leaves
3493
3494         string_parser = StringParser()
3495         rpar_idx = string_parser.parse(LL, string_idx)
3496
3497         for leaf in (LL[string_idx - 1], LL[rpar_idx]):
3498             if line.comments_after(leaf):
3499                 yield TErr(
3500                     "Will not strip parentheses which have comments attached to them."
3501                 )
3502                 return
3503
3504         new_line = line.clone()
3505         new_line.comments = line.comments.copy()
3506         try:
3507             append_leaves(new_line, line, LL[: string_idx - 1])
3508         except BracketMatchError:
3509             # HACK: I believe there is currently a bug somewhere in
3510             # right_hand_split() that is causing brackets to not be tracked
3511             # properly by a shared BracketTracker.
3512             append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
3513
3514         string_leaf = Leaf(token.STRING, LL[string_idx].value)
3515         LL[string_idx - 1].remove()
3516         replace_child(LL[string_idx], string_leaf)
3517         new_line.append(string_leaf)
3518
3519         append_leaves(
3520             new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :]
3521         )
3522
3523         LL[rpar_idx].remove()
3524
3525         yield Ok(new_line)
3526
3527
3528 class BaseStringSplitter(StringTransformer):
3529     """
3530     Abstract class for StringTransformers which transform a Line's strings by splitting
3531     them or placing them on their own lines where necessary to avoid going over
3532     the configured line length.
3533
3534     Requirements:
3535         * The target string value is responsible for the line going over the
3536         line length limit. It follows that after all of black's other line
3537         split methods have been exhausted, this line (or one of the resulting
3538         lines after all line splits are performed) would still be over the
3539         line_length limit unless we split this string.
3540             AND
3541         * The target string is NOT a "pointless" string (i.e. a string that has
3542         no parent or siblings).
3543             AND
3544         * The target string is not followed by an inline comment that appears
3545         to be a pragma.
3546             AND
3547         * The target string is not a multiline (i.e. triple-quote) string.
3548     """
3549
3550     @abstractmethod
3551     def do_splitter_match(self, line: Line) -> TMatchResult:
3552         """
3553         BaseStringSplitter asks its clients to override this method instead of
3554         `StringTransformer.do_match(...)`.
3555
3556         Follows the same protocol as `StringTransformer.do_match(...)`.
3557
3558         Refer to `help(StringTransformer.do_match)` for more information.
3559         """
3560
3561     def do_match(self, line: Line) -> TMatchResult:
3562         match_result = self.do_splitter_match(line)
3563         if isinstance(match_result, Err):
3564             return match_result
3565
3566         string_idx = match_result.ok()
3567         vresult = self.__validate(line, string_idx)
3568         if isinstance(vresult, Err):
3569             return vresult
3570
3571         return match_result
3572
3573     def __validate(self, line: Line, string_idx: int) -> TResult[None]:
3574         """
3575         Checks that @line meets all of the requirements listed in this classes'
3576         docstring. Refer to `help(BaseStringSplitter)` for a detailed
3577         description of those requirements.
3578
3579         Returns:
3580             * Ok(None), if ALL of the requirements are met.
3581                 OR
3582             * Err(CannotTransform), if ANY of the requirements are NOT met.
3583         """
3584         LL = line.leaves
3585
3586         string_leaf = LL[string_idx]
3587
3588         max_string_length = self.__get_max_string_length(line, string_idx)
3589         if len(string_leaf.value) <= max_string_length:
3590             return TErr(
3591                 "The string itself is not what is causing this line to be too long."
3592             )
3593
3594         if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [
3595             token.STRING,
3596             token.NEWLINE,
3597         ]:
3598             return TErr(
3599                 f"This string ({string_leaf.value}) appears to be pointless (i.e. has"
3600                 " no parent)."
3601             )
3602
3603         if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(
3604             line.comments[id(line.leaves[string_idx])]
3605         ):
3606             return TErr(
3607                 "Line appears to end with an inline pragma comment. Splitting the line"
3608                 " could modify the pragma's behavior."
3609             )
3610
3611         if has_triple_quotes(string_leaf.value):
3612             return TErr("We cannot split multiline strings.")
3613
3614         return Ok(None)
3615
3616     def __get_max_string_length(self, line: Line, string_idx: int) -> int:
3617         """
3618         Calculates the max string length used when attempting to determine
3619         whether or not the target string is responsible for causing the line to
3620         go over the line length limit.
3621
3622         WARNING: This method is tightly coupled to both StringSplitter and
3623         (especially) StringParenWrapper. There is probably a better way to
3624         accomplish what is being done here.
3625
3626         Returns:
3627             max_string_length: such that `line.leaves[string_idx].value >
3628             max_string_length` implies that the target string IS responsible
3629             for causing this line to exceed the line length limit.
3630         """
3631         LL = line.leaves
3632
3633         is_valid_index = is_valid_index_factory(LL)
3634
3635         # We use the shorthand "WMA4" in comments to abbreviate "We must
3636         # account for". When giving examples, we use STRING to mean some/any
3637         # valid string.
3638         #
3639         # Finally, we use the following convenience variables:
3640         #
3641         #   P:  The leaf that is before the target string leaf.
3642         #   N:  The leaf that is after the target string leaf.
3643         #   NN: The leaf that is after N.
3644
3645         # WMA4 the whitespace at the beginning of the line.
3646         offset = line.depth * 4
3647
3648         if is_valid_index(string_idx - 1):
3649             p_idx = string_idx - 1
3650             if (
3651                 LL[string_idx - 1].type == token.LPAR
3652                 and LL[string_idx - 1].value == ""
3653                 and string_idx >= 2
3654             ):
3655                 # If the previous leaf is an empty LPAR placeholder, we should skip it.
3656                 p_idx -= 1
3657
3658             P = LL[p_idx]
3659             if P.type == token.PLUS:
3660                 # WMA4 a space and a '+' character (e.g. `+ STRING`).
3661                 offset += 2
3662
3663             if P.type == token.COMMA:
3664                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
3665                 offset += 3
3666
3667             if P.type in [token.COLON, token.EQUAL, token.NAME]:
3668                 # This conditional branch is meant to handle dictionary keys,
3669                 # variable assignments, 'return STRING' statement lines, and
3670                 # 'else STRING' ternary expression lines.
3671
3672                 # WMA4 a single space.
3673                 offset += 1
3674
3675                 # WMA4 the lengths of any leaves that came before that space,
3676                 # but after any closing bracket before that space.
3677                 for leaf in reversed(LL[: p_idx + 1]):
3678                     offset += len(str(leaf))
3679                     if leaf.type in CLOSING_BRACKETS:
3680                         break
3681
3682         if is_valid_index(string_idx + 1):
3683             N = LL[string_idx + 1]
3684             if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2:
3685                 # If the next leaf is an empty RPAR placeholder, we should skip it.
3686                 N = LL[string_idx + 2]
3687
3688             if N.type == token.COMMA:
3689                 # WMA4 a single comma at the end of the string (e.g `STRING,`).
3690                 offset += 1
3691
3692             if is_valid_index(string_idx + 2):
3693                 NN = LL[string_idx + 2]
3694
3695                 if N.type == token.DOT and NN.type == token.NAME:
3696                     # This conditional branch is meant to handle method calls invoked
3697                     # off of a string literal up to and including the LPAR character.
3698
3699                     # WMA4 the '.' character.
3700                     offset += 1
3701
3702                     if (
3703                         is_valid_index(string_idx + 3)
3704                         and LL[string_idx + 3].type == token.LPAR
3705                     ):
3706                         # WMA4 the left parenthesis character.
3707                         offset += 1
3708
3709                     # WMA4 the length of the method's name.
3710                     offset += len(NN.value)
3711
3712         has_comments = False
3713         for comment_leaf in line.comments_after(LL[string_idx]):
3714             if not has_comments:
3715                 has_comments = True
3716                 # WMA4 two spaces before the '#' character.
3717                 offset += 2
3718
3719             # WMA4 the length of the inline comment.
3720             offset += len(comment_leaf.value)
3721
3722         max_string_length = self.line_length - offset
3723         return max_string_length
3724
3725
3726 class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
3727     """
3728     StringTransformer that splits "atom" strings (i.e. strings which exist on
3729     lines by themselves).
3730
3731     Requirements:
3732         * The line consists ONLY of a single string (with the exception of a
3733         '+' symbol which MAY exist at the start of the line), MAYBE a string
3734         trailer, and MAYBE a trailing comma.
3735             AND
3736         * All of the requirements listed in BaseStringSplitter's docstring.
3737
3738     Transformations:
3739         The string mentioned in the 'Requirements' section is split into as
3740         many substrings as necessary to adhere to the configured line length.
3741
3742         In the final set of substrings, no substring should be smaller than
3743         MIN_SUBSTR_SIZE characters.
3744
3745         The string will ONLY be split on spaces (i.e. each new substring should
3746         start with a space). Note that the string will NOT be split on a space
3747         which is escaped with a backslash.
3748
3749         If the string is an f-string, it will NOT be split in the middle of an
3750         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
3751         else bar()} is an f-expression).
3752
3753         If the string that is being split has an associated set of custom split
3754         records and those custom splits will NOT result in any line going over
3755         the configured line length, those custom splits are used. Otherwise the
3756         string is split as late as possible (from left-to-right) while still
3757         adhering to the transformation rules listed above.
3758
3759     Collaborations:
3760         StringSplitter relies on StringMerger to construct the appropriate
3761         CustomSplit objects and add them to the custom split map.
3762     """
3763
3764     MIN_SUBSTR_SIZE = 6
3765     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
3766     RE_FEXPR = r"""
3767     (?<!\{) (?:\{\{)* \{ (?!\{)
3768         (?:
3769             [^\{\}]
3770             | \{\{
3771             | \}\}
3772             | (?R)
3773         )+?
3774     (?<!\}) \} (?:\}\})* (?!\})
3775     """
3776
3777     def do_splitter_match(self, line: Line) -> TMatchResult:
3778         LL = line.leaves
3779
3780         is_valid_index = is_valid_index_factory(LL)
3781
3782         idx = 0
3783
3784         # The first leaf MAY be a '+' symbol...
3785         if is_valid_index(idx) and LL[idx].type == token.PLUS:
3786             idx += 1
3787
3788         # The next/first leaf MAY be an empty LPAR...
3789         if is_valid_index(idx) and is_empty_lpar(LL[idx]):
3790             idx += 1
3791
3792         # The next/first leaf MUST be a string...
3793         if not is_valid_index(idx) or LL[idx].type != token.STRING:
3794             return TErr("Line does not start with a string.")
3795
3796         string_idx = idx
3797
3798         # Skip the string trailer, if one exists.
3799         string_parser = StringParser()
3800         idx = string_parser.parse(LL, string_idx)
3801
3802         # That string MAY be followed by an empty RPAR...
3803         if is_valid_index(idx) and is_empty_rpar(LL[idx]):
3804             idx += 1
3805
3806         # That string / empty RPAR leaf MAY be followed by a comma...
3807         if is_valid_index(idx) and LL[idx].type == token.COMMA:
3808             idx += 1
3809
3810         # But no more leaves are allowed...
3811         if is_valid_index(idx):
3812             return TErr("This line does not end with a string.")
3813
3814         return Ok(string_idx)
3815
3816     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
3817         LL = line.leaves
3818
3819         QUOTE = LL[string_idx].value[-1]
3820
3821         is_valid_index = is_valid_index_factory(LL)
3822         insert_str_child = insert_str_child_factory(LL[string_idx])
3823
3824         prefix = get_string_prefix(LL[string_idx].value)
3825
3826         # We MAY choose to drop the 'f' prefix from substrings that don't
3827         # contain any f-expressions, but ONLY if the original f-string
3828         # contains at least one f-expression. Otherwise, we will alter the AST
3829         # of the program.
3830         drop_pointless_f_prefix = ("f" in prefix) and re.search(
3831             self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
3832         )
3833
3834         first_string_line = True
3835         starts_with_plus = LL[0].type == token.PLUS
3836
3837         def line_needs_plus() -> bool:
3838             return first_string_line and starts_with_plus
3839
3840         def maybe_append_plus(new_line: Line) -> None:
3841             """
3842             Side Effects:
3843                 If @line starts with a plus and this is the first line we are
3844                 constructing, this function appends a PLUS leaf to @new_line
3845                 and replaces the old PLUS leaf in the node structure. Otherwise
3846                 this function does nothing.
3847             """
3848             if line_needs_plus():
3849                 plus_leaf = Leaf(token.PLUS, "+")
3850                 replace_child(LL[0], plus_leaf)
3851                 new_line.append(plus_leaf)
3852
3853         ends_with_comma = (
3854             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
3855         )
3856
3857         def max_last_string() -> int:
3858             """
3859             Returns:
3860                 The max allowed length of the string value used for the last
3861                 line we will construct.
3862             """
3863             result = self.line_length
3864             result -= line.depth * 4
3865             result -= 1 if ends_with_comma else 0
3866             result -= 2 if line_needs_plus() else 0
3867             return result
3868
3869         # --- Calculate Max Break Index (for string value)
3870         # We start with the line length limit
3871         max_break_idx = self.line_length
3872         # The last index of a string of length N is N-1.
3873         max_break_idx -= 1
3874         # Leading whitespace is not present in the string value (e.g. Leaf.value).
3875         max_break_idx -= line.depth * 4
3876         if max_break_idx < 0:
3877             yield TErr(
3878                 f"Unable to split {LL[string_idx].value} at such high of a line depth:"
3879                 f" {line.depth}"
3880             )
3881             return
3882
3883         # Check if StringMerger registered any custom splits.
3884         custom_splits = self.pop_custom_splits(LL[string_idx].value)
3885         # We use them ONLY if none of them would produce lines that exceed the
3886         # line limit.
3887         use_custom_breakpoints = bool(
3888             custom_splits
3889             and all(csplit.break_idx <= max_break_idx for csplit in custom_splits)
3890         )
3891
3892         # Temporary storage for the remaining chunk of the string line that
3893         # can't fit onto the line currently being constructed.
3894         rest_value = LL[string_idx].value
3895
3896         def more_splits_should_be_made() -> bool:
3897             """
3898             Returns:
3899                 True iff `rest_value` (the remaining string value from the last
3900                 split), should be split again.
3901             """
3902             if use_custom_breakpoints:
3903                 return len(custom_splits) > 1
3904             else:
3905                 return len(rest_value) > max_last_string()
3906
3907         string_line_results: List[Ok[Line]] = []
3908         while more_splits_should_be_made():
3909             if use_custom_breakpoints:
3910                 # Custom User Split (manual)
3911                 csplit = custom_splits.pop(0)
3912                 break_idx = csplit.break_idx
3913             else:
3914                 # Algorithmic Split (automatic)
3915                 max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
3916                 maybe_break_idx = self.__get_break_idx(rest_value, max_bidx)
3917                 if maybe_break_idx is None:
3918                     # If we are unable to algorithmically determine a good split
3919                     # and this string has custom splits registered to it, we
3920                     # fall back to using them--which means we have to start
3921                     # over from the beginning.
3922                     if custom_splits:
3923                         rest_value = LL[string_idx].value
3924                         string_line_results = []
3925                         first_string_line = True
3926                         use_custom_breakpoints = True
3927                         continue
3928
3929                     # Otherwise, we stop splitting here.
3930                     break
3931
3932                 break_idx = maybe_break_idx
3933
3934             # --- Construct `next_value`
3935             next_value = rest_value[:break_idx] + QUOTE
3936             if (
3937                 # Are we allowed to try to drop a pointless 'f' prefix?
3938                 drop_pointless_f_prefix
3939                 # If we are, will we be successful?
3940                 and next_value != self.__normalize_f_string(next_value, prefix)
3941             ):
3942                 # If the current custom split did NOT originally use a prefix,
3943                 # then `csplit.break_idx` will be off by one after removing
3944                 # the 'f' prefix.
3945                 break_idx = (
3946                     break_idx + 1
3947                     if use_custom_breakpoints and not csplit.has_prefix
3948                     else break_idx
3949                 )
3950                 next_value = rest_value[:break_idx] + QUOTE
3951                 next_value = self.__normalize_f_string(next_value, prefix)
3952
3953             # --- Construct `next_leaf`
3954             next_leaf = Leaf(token.STRING, next_value)
3955             insert_str_child(next_leaf)
3956             self.__maybe_normalize_string_quotes(next_leaf)
3957
3958             # --- Construct `next_line`
3959             next_line = line.clone()
3960             maybe_append_plus(next_line)
3961             next_line.append(next_leaf)
3962             string_line_results.append(Ok(next_line))
3963
3964             rest_value = prefix + QUOTE + rest_value[break_idx:]
3965             first_string_line = False
3966
3967         yield from string_line_results
3968
3969         if drop_pointless_f_prefix:
3970             rest_value = self.__normalize_f_string(rest_value, prefix)
3971
3972         rest_leaf = Leaf(token.STRING, rest_value)
3973         insert_str_child(rest_leaf)
3974
3975         # NOTE: I could not find a test case that verifies that the following
3976         # line is actually necessary, but it seems to be. Otherwise we risk
3977         # not normalizing the last substring, right?
3978         self.__maybe_normalize_string_quotes(rest_leaf)
3979
3980         last_line = line.clone()
3981         maybe_append_plus(last_line)
3982
3983         # If there are any leaves to the right of the target string...
3984         if is_valid_index(string_idx + 1):
3985             # We use `temp_value` here to determine how long the last line
3986             # would be if we were to append all the leaves to the right of the
3987             # target string to the last string line.
3988             temp_value = rest_value
3989             for leaf in LL[string_idx + 1 :]:
3990                 temp_value += str(leaf)
3991                 if leaf.type == token.LPAR:
3992                     break
3993
3994             # Try to fit them all on the same line with the last substring...
3995             if (
3996                 len(temp_value) <= max_last_string()
3997                 or LL[string_idx + 1].type == token.COMMA
3998             ):
3999                 last_line.append(rest_leaf)
4000                 append_leaves(last_line, line, LL[string_idx + 1 :])
4001                 yield Ok(last_line)
4002             # Otherwise, place the last substring on one line and everything
4003             # else on a line below that...
4004             else:
4005                 last_line.append(rest_leaf)
4006                 yield Ok(last_line)
4007
4008                 non_string_line = line.clone()
4009                 append_leaves(non_string_line, line, LL[string_idx + 1 :])
4010                 yield Ok(non_string_line)
4011         # Else the target string was the last leaf...
4012         else:
4013             last_line.append(rest_leaf)
4014             last_line.comments = line.comments.copy()
4015             yield Ok(last_line)
4016
4017     def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
4018         """
4019         This method contains the algorithm that StringSplitter uses to
4020         determine which character to split each string at.
4021
4022         Args:
4023             @string: The substring that we are attempting to split.
4024             @max_break_idx: The ideal break index. We will return this value if it
4025             meets all the necessary conditions. In the likely event that it
4026             doesn't we will try to find the closest index BELOW @max_break_idx
4027             that does. If that fails, we will expand our search by also
4028             considering all valid indices ABOVE @max_break_idx.
4029
4030         Pre-Conditions:
4031             * assert_is_leaf_string(@string)
4032             * 0 <= @max_break_idx < len(@string)
4033
4034         Returns:
4035             break_idx, if an index is able to be found that meets all of the
4036             conditions listed in the 'Transformations' section of this classes'
4037             docstring.
4038                 OR
4039             None, otherwise.
4040         """
4041         is_valid_index = is_valid_index_factory(string)
4042
4043         assert is_valid_index(max_break_idx)
4044         assert_is_leaf_string(string)
4045
4046         _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
4047
4048         def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
4049             """
4050             Yields:
4051                 All ranges of @string which, if @string were to be split there,
4052                 would result in the splitting of an f-expression (which is NOT
4053                 allowed).
4054             """
4055             nonlocal _fexpr_slices
4056
4057             if _fexpr_slices is None:
4058                 _fexpr_slices = []
4059                 for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
4060                     _fexpr_slices.append(match.span())
4061
4062             yield from _fexpr_slices
4063
4064         is_fstring = "f" in get_string_prefix(string)
4065
4066         def breaks_fstring_expression(i: Index) -> bool:
4067             """
4068             Returns:
4069                 True iff returning @i would result in the splitting of an
4070                 f-expression (which is NOT allowed).
4071             """
4072             if not is_fstring:
4073                 return False
4074
4075             for (start, end) in fexpr_slices():
4076                 if start <= i < end:
4077                     return True
4078
4079             return False
4080
4081         def passes_all_checks(i: Index) -> bool:
4082             """
4083             Returns:
4084                 True iff ALL of the conditions listed in the 'Transformations'
4085                 section of this classes' docstring would be be met by returning @i.
4086             """
4087             is_space = string[i] == " "
4088
4089             is_not_escaped = True
4090             j = i - 1
4091             while is_valid_index(j) and string[j] == "\\":
4092                 is_not_escaped = not is_not_escaped
4093                 j -= 1
4094
4095             is_big_enough = (
4096                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
4097                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
4098             )
4099             return (
4100                 is_space
4101                 and is_not_escaped
4102                 and is_big_enough
4103                 and not breaks_fstring_expression(i)
4104             )
4105
4106         # First, we check all indices BELOW @max_break_idx.
4107         break_idx = max_break_idx
4108         while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):
4109             break_idx -= 1
4110
4111         if not passes_all_checks(break_idx):
4112             # If that fails, we check all indices ABOVE @max_break_idx.
4113             #
4114             # If we are able to find a valid index here, the next line is going
4115             # to be longer than the specified line length, but it's probably
4116             # better than doing nothing at all.
4117             break_idx = max_break_idx + 1
4118             while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):
4119                 break_idx += 1
4120
4121             if not is_valid_index(break_idx) or not passes_all_checks(break_idx):
4122                 return None
4123
4124         return break_idx
4125
4126     def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None:
4127         if self.normalize_strings:
4128             normalize_string_quotes(leaf)
4129
4130     def __normalize_f_string(self, string: str, prefix: str) -> str:
4131         """
4132         Pre-Conditions:
4133             * assert_is_leaf_string(@string)
4134
4135         Returns:
4136             * If @string is an f-string that contains no f-expressions, we
4137             return a string identical to @string except that the 'f' prefix
4138             has been stripped and all double braces (i.e. '{{' or '}}') have
4139             been normalized (i.e. turned into '{' or '}').
4140                 OR
4141             * Otherwise, we return @string.
4142         """
4143         assert_is_leaf_string(string)
4144
4145         if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
4146             new_prefix = prefix.replace("f", "")
4147
4148             temp = string[len(prefix) :]
4149             temp = re.sub(r"\{\{", "{", temp)
4150             temp = re.sub(r"\}\}", "}", temp)
4151             new_string = temp
4152
4153             return f"{new_prefix}{new_string}"
4154         else:
4155             return string
4156
4157
4158 class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
4159     """
4160     StringTransformer that splits non-"atom" strings (i.e. strings that do not
4161     exist on lines by themselves).
4162
4163     Requirements:
4164         All of the requirements listed in BaseStringSplitter's docstring in
4165         addition to the requirements listed below:
4166
4167         * The line is a return/yield statement, which returns/yields a string.
4168             OR
4169         * The line is part of a ternary expression (e.g. `x = y if cond else
4170         z`) such that the line starts with `else <string>`, where <string> is
4171         some string.
4172             OR
4173         * The line is an assert statement, which ends with a string.
4174             OR
4175         * The line is an assignment statement (e.g. `x = <string>` or `x +=
4176         <string>`) such that the variable is being assigned the value of some
4177         string.
4178             OR
4179         * The line is a dictionary key assignment where some valid key is being
4180         assigned the value of some string.
4181
4182     Transformations:
4183         The chosen string is wrapped in parentheses and then split at the LPAR.
4184
4185         We then have one line which ends with an LPAR and another line that
4186         starts with the chosen string. The latter line is then split again at
4187         the RPAR. This results in the RPAR (and possibly a trailing comma)
4188         being placed on its own line.
4189
4190         NOTE: If any leaves exist to the right of the chosen string (except
4191         for a trailing comma, which would be placed after the RPAR), those
4192         leaves are placed inside the parentheses.  In effect, the chosen
4193         string is not necessarily being "wrapped" by parentheses. We can,
4194         however, count on the LPAR being placed directly before the chosen
4195         string.
4196
4197         In other words, StringParenWrapper creates "atom" strings. These
4198         can then be split again by StringSplitter, if necessary.
4199
4200     Collaborations:
4201         In the event that a string line split by StringParenWrapper is
4202         changed such that it no longer needs to be given its own line,
4203         StringParenWrapper relies on StringParenStripper to clean up the
4204         parentheses it created.
4205     """
4206
4207     def do_splitter_match(self, line: Line) -> TMatchResult:
4208         LL = line.leaves
4209
4210         string_idx = (
4211             self._return_match(LL)
4212             or self._else_match(LL)
4213             or self._assert_match(LL)
4214             or self._assign_match(LL)
4215             or self._dict_match(LL)
4216         )
4217
4218         if string_idx is not None:
4219             string_value = line.leaves[string_idx].value
4220             # If the string has no spaces...
4221             if " " not in string_value:
4222                 # And will still violate the line length limit when split...
4223                 max_string_length = self.line_length - ((line.depth + 1) * 4)
4224                 if len(string_value) > max_string_length:
4225                     # And has no associated custom splits...
4226                     if not self.has_custom_splits(string_value):
4227                         # Then we should NOT put this string on its own line.
4228                         return TErr(
4229                             "We do not wrap long strings in parentheses when the"
4230                             " resultant line would still be over the specified line"
4231                             " length and can't be split further by StringSplitter."
4232                         )
4233             return Ok(string_idx)
4234
4235         return TErr("This line does not contain any non-atomic strings.")
4236
4237     @staticmethod
4238     def _return_match(LL: List[Leaf]) -> Optional[int]:
4239         """
4240         Returns:
4241             string_idx such that @LL[string_idx] is equal to our target (i.e.
4242             matched) string, if this line matches the return/yield statement
4243             requirements listed in the 'Requirements' section of this classes'
4244             docstring.
4245                 OR
4246             None, otherwise.
4247         """
4248         # If this line is apart of a return/yield statement and the first leaf
4249         # contains either the "return" or "yield" keywords...
4250         if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[
4251             0
4252         ].value in ["return", "yield"]:
4253             is_valid_index = is_valid_index_factory(LL)
4254
4255             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4256             # The next visible leaf MUST contain a string...
4257             if is_valid_index(idx) and LL[idx].type == token.STRING:
4258                 return idx
4259
4260         return None
4261
4262     @staticmethod
4263     def _else_match(LL: List[Leaf]) -> Optional[int]:
4264         """
4265         Returns:
4266             string_idx such that @LL[string_idx] is equal to our target (i.e.
4267             matched) string, if this line matches the ternary expression
4268             requirements listed in the 'Requirements' section of this classes'
4269             docstring.
4270                 OR
4271             None, otherwise.
4272         """
4273         # If this line is apart of a ternary expression and the first leaf
4274         # contains the "else" keyword...
4275         if (
4276             parent_type(LL[0]) == syms.test
4277             and LL[0].type == token.NAME
4278             and LL[0].value == "else"
4279         ):
4280             is_valid_index = is_valid_index_factory(LL)
4281
4282             idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1
4283             # The next visible leaf MUST contain a string...
4284             if is_valid_index(idx) and LL[idx].type == token.STRING:
4285                 return idx
4286
4287         return None
4288
4289     @staticmethod
4290     def _assert_match(LL: List[Leaf]) -> Optional[int]:
4291         """
4292         Returns:
4293             string_idx such that @LL[string_idx] is equal to our target (i.e.
4294             matched) string, if this line matches the assert statement
4295             requirements listed in the 'Requirements' section of this classes'
4296             docstring.
4297                 OR
4298             None, otherwise.
4299         """
4300         # If this line is apart of an assert statement and the first leaf
4301         # contains the "assert" keyword...
4302         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
4303             is_valid_index = is_valid_index_factory(LL)
4304
4305             for (i, leaf) in enumerate(LL):
4306                 # We MUST find a comma...
4307                 if leaf.type == token.COMMA:
4308                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4309
4310                     # That comma MUST be followed by a string...
4311                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4312                         string_idx = idx
4313
4314                         # Skip the string trailer, if one exists.
4315                         string_parser = StringParser()
4316                         idx = string_parser.parse(LL, string_idx)
4317
4318                         # But no more leaves are allowed...
4319                         if not is_valid_index(idx):
4320                             return string_idx
4321
4322         return None
4323
4324     @staticmethod
4325     def _assign_match(LL: List[Leaf]) -> Optional[int]:
4326         """
4327         Returns:
4328             string_idx such that @LL[string_idx] is equal to our target (i.e.
4329             matched) string, if this line matches the assignment statement
4330             requirements listed in the 'Requirements' section of this classes'
4331             docstring.
4332                 OR
4333             None, otherwise.
4334         """
4335         # If this line is apart of an expression statement or is a function
4336         # argument AND the first leaf contains a variable name...
4337         if (
4338             parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power]
4339             and LL[0].type == token.NAME
4340         ):
4341             is_valid_index = is_valid_index_factory(LL)
4342
4343             for (i, leaf) in enumerate(LL):
4344                 # We MUST find either an '=' or '+=' symbol...
4345                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
4346                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4347
4348                     # That symbol MUST be followed by a string...
4349                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4350                         string_idx = idx
4351
4352                         # Skip the string trailer, if one exists.
4353                         string_parser = StringParser()
4354                         idx = string_parser.parse(LL, string_idx)
4355
4356                         # The next leaf MAY be a comma iff this line is apart
4357                         # of a function argument...
4358                         if (
4359                             parent_type(LL[0]) == syms.argument
4360                             and is_valid_index(idx)
4361                             and LL[idx].type == token.COMMA
4362                         ):
4363                             idx += 1
4364
4365                         # But no more leaves are allowed...
4366                         if not is_valid_index(idx):
4367                             return string_idx
4368
4369         return None
4370
4371     @staticmethod
4372     def _dict_match(LL: List[Leaf]) -> Optional[int]:
4373         """
4374         Returns:
4375             string_idx such that @LL[string_idx] is equal to our target (i.e.
4376             matched) string, if this line matches the dictionary key assignment
4377             statement requirements listed in the 'Requirements' section of this
4378             classes' docstring.
4379                 OR
4380             None, otherwise.
4381         """
4382         # If this line is apart of a dictionary key assignment...
4383         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
4384             is_valid_index = is_valid_index_factory(LL)
4385
4386             for (i, leaf) in enumerate(LL):
4387                 # We MUST find a colon...
4388                 if leaf.type == token.COLON:
4389                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
4390
4391                     # That colon MUST be followed by a string...
4392                     if is_valid_index(idx) and LL[idx].type == token.STRING:
4393                         string_idx = idx
4394
4395                         # Skip the string trailer, if one exists.
4396                         string_parser = StringParser()
4397                         idx = string_parser.parse(LL, string_idx)
4398
4399                         # That string MAY be followed by a comma...
4400                         if is_valid_index(idx) and LL[idx].type == token.COMMA:
4401                             idx += 1
4402
4403                         # But no more leaves are allowed...
4404                         if not is_valid_index(idx):
4405                             return string_idx
4406
4407         return None
4408
4409     def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]:
4410         LL = line.leaves
4411
4412         is_valid_index = is_valid_index_factory(LL)
4413         insert_str_child = insert_str_child_factory(LL[string_idx])
4414
4415         comma_idx = -1
4416         ends_with_comma = False
4417         if LL[comma_idx].type == token.COMMA:
4418             ends_with_comma = True
4419
4420         leaves_to_steal_comments_from = [LL[string_idx]]
4421         if ends_with_comma:
4422             leaves_to_steal_comments_from.append(LL[comma_idx])
4423
4424         # --- First Line
4425         first_line = line.clone()
4426         left_leaves = LL[:string_idx]
4427
4428         # We have to remember to account for (possibly invisible) LPAR and RPAR
4429         # leaves that already wrapped the target string. If these leaves do
4430         # exist, we will replace them with our own LPAR and RPAR leaves.
4431         old_parens_exist = False
4432         if left_leaves and left_leaves[-1].type == token.LPAR:
4433             old_parens_exist = True
4434             leaves_to_steal_comments_from.append(left_leaves[-1])
4435             left_leaves.pop()
4436
4437         append_leaves(first_line, line, left_leaves)
4438
4439         lpar_leaf = Leaf(token.LPAR, "(")
4440         if old_parens_exist:
4441             replace_child(LL[string_idx - 1], lpar_leaf)
4442         else:
4443             insert_str_child(lpar_leaf)
4444         first_line.append(lpar_leaf)
4445
4446         # We throw inline comments that were originally to the right of the
4447         # target string to the top line. They will now be shown to the right of
4448         # the LPAR.
4449         for leaf in leaves_to_steal_comments_from:
4450             for comment_leaf in line.comments_after(leaf):
4451                 first_line.append(comment_leaf, preformatted=True)
4452
4453         yield Ok(first_line)
4454
4455         # --- Middle (String) Line
4456         # We only need to yield one (possibly too long) string line, since the
4457         # `StringSplitter` will break it down further if necessary.
4458         string_value = LL[string_idx].value
4459         string_line = Line(
4460             mode=line.mode,
4461             depth=line.depth + 1,
4462             inside_brackets=True,
4463             should_split_rhs=line.should_split_rhs,
4464             magic_trailing_comma=line.magic_trailing_comma,
4465         )
4466         string_leaf = Leaf(token.STRING, string_value)
4467         insert_str_child(string_leaf)
4468         string_line.append(string_leaf)
4469
4470         old_rpar_leaf = None
4471         if is_valid_index(string_idx + 1):
4472             right_leaves = LL[string_idx + 1 :]
4473             if ends_with_comma:
4474                 right_leaves.pop()
4475
4476             if old_parens_exist:
4477                 assert (
4478                     right_leaves and right_leaves[-1].type == token.RPAR
4479                 ), "Apparently, old parentheses do NOT exist?!"
4480                 old_rpar_leaf = right_leaves.pop()
4481
4482             append_leaves(string_line, line, right_leaves)
4483
4484         yield Ok(string_line)
4485
4486         # --- Last Line
4487         last_line = line.clone()
4488         last_line.bracket_tracker = first_line.bracket_tracker
4489
4490         new_rpar_leaf = Leaf(token.RPAR, ")")
4491         if old_rpar_leaf is not None:
4492             replace_child(old_rpar_leaf, new_rpar_leaf)
4493         else:
4494             insert_str_child(new_rpar_leaf)
4495         last_line.append(new_rpar_leaf)
4496
4497         # If the target string ended with a comma, we place this comma to the
4498         # right of the RPAR on the last line.
4499         if ends_with_comma:
4500             comma_leaf = Leaf(token.COMMA, ",")
4501             replace_child(LL[comma_idx], comma_leaf)
4502             last_line.append(comma_leaf)
4503
4504         yield Ok(last_line)
4505
4506
4507 class StringParser:
4508     """
4509     A state machine that aids in parsing a string's "trailer", which can be
4510     either non-existent, an old-style formatting sequence (e.g. `% varX` or `%
4511     (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX,
4512     varY)`).
4513
4514     NOTE: A new StringParser object MUST be instantiated for each string
4515     trailer we need to parse.
4516
4517     Examples:
4518         We shall assume that `line` equals the `Line` object that corresponds
4519         to the following line of python code:
4520         ```
4521         x = "Some {}.".format("String") + some_other_string
4522         ```
4523
4524         Furthermore, we will assume that `string_idx` is some index such that:
4525         ```
4526         assert line.leaves[string_idx].value == "Some {}."
4527         ```
4528
4529         The following code snippet then holds:
4530         ```
4531         string_parser = StringParser()
4532         idx = string_parser.parse(line.leaves, string_idx)
4533         assert line.leaves[idx].type == token.PLUS
4534         ```
4535     """
4536
4537     DEFAULT_TOKEN = -1
4538
4539     # String Parser States
4540     START = 1
4541     DOT = 2
4542     NAME = 3
4543     PERCENT = 4
4544     SINGLE_FMT_ARG = 5
4545     LPAR = 6
4546     RPAR = 7
4547     DONE = 8
4548
4549     # Lookup Table for Next State
4550     _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
4551         # A string trailer may start with '.' OR '%'.
4552         (START, token.DOT): DOT,
4553         (START, token.PERCENT): PERCENT,
4554         (START, DEFAULT_TOKEN): DONE,
4555         # A '.' MUST be followed by an attribute or method name.
4556         (DOT, token.NAME): NAME,
4557         # A method name MUST be followed by an '(', whereas an attribute name
4558         # is the last symbol in the string trailer.
4559         (NAME, token.LPAR): LPAR,
4560         (NAME, DEFAULT_TOKEN): DONE,
4561         # A '%' symbol can be followed by an '(' or a single argument (e.g. a
4562         # string or variable name).
4563         (PERCENT, token.LPAR): LPAR,
4564         (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG,
4565         # If a '%' symbol is followed by a single argument, that argument is
4566         # the last leaf in the string trailer.
4567         (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE,
4568         # If present, a ')' symbol is the last symbol in a string trailer.
4569         # (NOTE: LPARS and nested RPARS are not included in this lookup table,
4570         # since they are treated as a special case by the parsing logic in this
4571         # classes' implementation.)
4572         (RPAR, DEFAULT_TOKEN): DONE,
4573     }
4574
4575     def __init__(self) -> None:
4576         self._state = self.START
4577         self._unmatched_lpars = 0
4578
4579     def parse(self, leaves: List[Leaf], string_idx: int) -> int:
4580         """
4581         Pre-conditions:
4582             * @leaves[@string_idx].type == token.STRING
4583
4584         Returns:
4585             The index directly after the last leaf which is apart of the string
4586             trailer, if a "trailer" exists.
4587                 OR
4588             @string_idx + 1, if no string "trailer" exists.
4589         """
4590         assert leaves[string_idx].type == token.STRING
4591
4592         idx = string_idx + 1
4593         while idx < len(leaves) and self._next_state(leaves[idx]):
4594             idx += 1
4595         return idx
4596
4597     def _next_state(self, leaf: Leaf) -> bool:
4598         """
4599         Pre-conditions:
4600             * On the first call to this function, @leaf MUST be the leaf that
4601             was directly after the string leaf in question (e.g. if our target
4602             string is `line.leaves[i]` then the first call to this method must
4603             be `line.leaves[i + 1]`).
4604             * On the next call to this function, the leaf parameter passed in
4605             MUST be the leaf directly following @leaf.
4606
4607         Returns:
4608             True iff @leaf is apart of the string's trailer.
4609         """
4610         # We ignore empty LPAR or RPAR leaves.
4611         if is_empty_par(leaf):
4612             return True
4613
4614         next_token = leaf.type
4615         if next_token == token.LPAR:
4616             self._unmatched_lpars += 1
4617
4618         current_state = self._state
4619
4620         # The LPAR parser state is a special case. We will return True until we
4621         # find the matching RPAR token.
4622         if current_state == self.LPAR:
4623             if next_token == token.RPAR:
4624                 self._unmatched_lpars -= 1
4625                 if self._unmatched_lpars == 0:
4626                     self._state = self.RPAR
4627         # Otherwise, we use a lookup table to determine the next state.
4628         else:
4629             # If the lookup table matches the current state to the next
4630             # token, we use the lookup table.
4631             if (current_state, next_token) in self._goto:
4632                 self._state = self._goto[current_state, next_token]
4633             else:
4634                 # Otherwise, we check if a the current state was assigned a
4635                 # default.
4636                 if (current_state, self.DEFAULT_TOKEN) in self._goto:
4637                     self._state = self._goto[current_state, self.DEFAULT_TOKEN]
4638                 # If no default has been assigned, then this parser has a logic
4639                 # error.
4640                 else:
4641                     raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
4642
4643             if self._state == self.DONE:
4644                 return False
4645
4646         return True
4647
4648
4649 def TErr(err_msg: str) -> Err[CannotTransform]:
4650     """(T)ransform Err
4651
4652     Convenience function used when working with the TResult type.
4653     """
4654     cant_transform = CannotTransform(err_msg)
4655     return Err(cant_transform)
4656
4657
4658 def contains_pragma_comment(comment_list: List[Leaf]) -> bool:
4659     """
4660     Returns:
4661         True iff one of the comments in @comment_list is a pragma used by one
4662         of the more common static analysis tools for python (e.g. mypy, flake8,
4663         pylint).
4664     """
4665     for comment in comment_list:
4666         if comment.value.startswith(("# type:", "# noqa", "# pylint:")):
4667             return True
4668
4669     return False
4670
4671
4672 def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]:
4673     """
4674     Factory for a convenience function that is used to orphan @string_leaf
4675     and then insert multiple new leaves into the same part of the node
4676     structure that @string_leaf had originally occupied.
4677
4678     Examples:
4679         Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N =
4680         string_leaf.parent`. Assume the node `N` has the following
4681         original structure:
4682
4683         Node(
4684             expr_stmt, [
4685                 Leaf(NAME, 'x'),
4686                 Leaf(EQUAL, '='),
4687                 Leaf(STRING, '"foo"'),
4688             ]
4689         )
4690
4691         We then run the code snippet shown below.
4692         ```
4693         insert_str_child = insert_str_child_factory(string_leaf)
4694
4695         lpar = Leaf(token.LPAR, '(')
4696         insert_str_child(lpar)
4697
4698         bar = Leaf(token.STRING, '"bar"')
4699         insert_str_child(bar)
4700
4701         rpar = Leaf(token.RPAR, ')')
4702         insert_str_child(rpar)
4703         ```
4704
4705         After which point, it follows that `string_leaf.parent is None` and
4706         the node `N` now has the following structure:
4707
4708         Node(
4709             expr_stmt, [
4710                 Leaf(NAME, 'x'),
4711                 Leaf(EQUAL, '='),
4712                 Leaf(LPAR, '('),
4713                 Leaf(STRING, '"bar"'),
4714                 Leaf(RPAR, ')'),
4715             ]
4716         )
4717     """
4718     string_parent = string_leaf.parent
4719     string_child_idx = string_leaf.remove()
4720
4721     def insert_str_child(child: LN) -> None:
4722         nonlocal string_child_idx
4723
4724         assert string_parent is not None
4725         assert string_child_idx is not None
4726
4727         string_parent.insert_child(string_child_idx, child)
4728         string_child_idx += 1
4729
4730     return insert_str_child
4731
4732
4733 def has_triple_quotes(string: str) -> bool:
4734     """
4735     Returns:
4736         True iff @string starts with three quotation characters.
4737     """
4738     raw_string = string.lstrip(STRING_PREFIX_CHARS)
4739     return raw_string[:3] in {'"""', "'''"}
4740
4741
4742 def parent_type(node: Optional[LN]) -> Optional[NodeType]:
4743     """
4744     Returns:
4745         @node.parent.type, if @node is not None and has a parent.
4746             OR
4747         None, otherwise.
4748     """
4749     if node is None or node.parent is None:
4750         return None
4751
4752     return node.parent.type
4753
4754
4755 def is_empty_par(leaf: Leaf) -> bool:
4756     return is_empty_lpar(leaf) or is_empty_rpar(leaf)
4757
4758
4759 def is_empty_lpar(leaf: Leaf) -> bool:
4760     return leaf.type == token.LPAR and leaf.value == ""
4761
4762
4763 def is_empty_rpar(leaf: Leaf) -> bool:
4764     return leaf.type == token.RPAR and leaf.value == ""
4765
4766
4767 def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]:
4768     """
4769     Examples:
4770         ```
4771         my_list = [1, 2, 3]
4772
4773         is_valid_index = is_valid_index_factory(my_list)
4774
4775         assert is_valid_index(0)
4776         assert is_valid_index(2)
4777
4778         assert not is_valid_index(3)
4779         assert not is_valid_index(-1)
4780         ```
4781     """
4782
4783     def is_valid_index(idx: int) -> bool:
4784         """
4785         Returns:
4786             True iff @idx is positive AND seq[@idx] does NOT raise an
4787             IndexError.
4788         """
4789         return 0 <= idx < len(seq)
4790
4791     return is_valid_index
4792
4793
4794 def line_to_string(line: Line) -> str:
4795     """Returns the string representation of @line.
4796
4797     WARNING: This is known to be computationally expensive.
4798     """
4799     return str(line).strip("\n")
4800
4801
4802 def append_leaves(
4803     new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
4804 ) -> None:
4805     """
4806     Append leaves (taken from @old_line) to @new_line, making sure to fix the
4807     underlying Node structure where appropriate.
4808
4809     All of the leaves in @leaves are duplicated. The duplicates are then
4810     appended to @new_line and used to replace their originals in the underlying
4811     Node structure. Any comments attached to the old leaves are reattached to
4812     the new leaves.
4813
4814     Pre-conditions:
4815         set(@leaves) is a subset of set(@old_line.leaves).
4816     """
4817     for old_leaf in leaves:
4818         new_leaf = Leaf(old_leaf.type, old_leaf.value)
4819         replace_child(old_leaf, new_leaf)
4820         new_line.append(new_leaf, preformatted=preformatted)
4821
4822         for comment_leaf in old_line.comments_after(old_leaf):
4823             new_line.append(comment_leaf, preformatted=True)
4824
4825
4826 def replace_child(old_child: LN, new_child: LN) -> None:
4827     """
4828     Side Effects:
4829         * If @old_child.parent is set, replace @old_child with @new_child in
4830         @old_child's underlying Node structure.
4831             OR
4832         * Otherwise, this function does nothing.
4833     """
4834     parent = old_child.parent
4835     if not parent:
4836         return
4837
4838     child_idx = old_child.remove()
4839     if child_idx is not None:
4840         parent.insert_child(child_idx, new_child)
4841
4842
4843 def get_string_prefix(string: str) -> str:
4844     """
4845     Pre-conditions:
4846         * assert_is_leaf_string(@string)
4847
4848     Returns:
4849         @string's prefix (e.g. '', 'r', 'f', or 'rf').
4850     """
4851     assert_is_leaf_string(string)
4852
4853     prefix = ""
4854     prefix_idx = 0
4855     while string[prefix_idx] in STRING_PREFIX_CHARS:
4856         prefix += string[prefix_idx].lower()
4857         prefix_idx += 1
4858
4859     return prefix
4860
4861
4862 def assert_is_leaf_string(string: str) -> None:
4863     """
4864     Checks the pre-condition that @string has the format that you would expect
4865     of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==
4866     token.STRING`. A more precise description of the pre-conditions that are
4867     checked are listed below.
4868
4869     Pre-conditions:
4870         * @string starts with either ', ", <prefix>', or <prefix>" where
4871         `set(<prefix>)` is some subset of `set(STRING_PREFIX_CHARS)`.
4872         * @string ends with a quote character (' or ").
4873
4874     Raises:
4875         AssertionError(...) if the pre-conditions listed above are not
4876         satisfied.
4877     """
4878     dquote_idx = string.find('"')
4879     squote_idx = string.find("'")
4880     if -1 in [dquote_idx, squote_idx]:
4881         quote_idx = max(dquote_idx, squote_idx)
4882     else:
4883         quote_idx = min(squote_idx, dquote_idx)
4884
4885     assert (
4886         0 <= quote_idx < len(string) - 1
4887     ), f"{string!r} is missing a starting quote character (' or \")."
4888     assert string[-1] in (
4889         "'",
4890         '"',
4891     ), f"{string!r} is missing an ending quote character (' or \")."
4892     assert set(string[:quote_idx]).issubset(
4893         set(STRING_PREFIX_CHARS)
4894     ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}."
4895
4896
4897 def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
4898     """Split line into many lines, starting with the first matching bracket pair.
4899
4900     Note: this usually looks weird, only use this for function definitions.
4901     Prefer RHS otherwise.  This is why this function is not symmetrical with
4902     :func:`right_hand_split` which also handles optional parentheses.
4903     """
4904     tail_leaves: List[Leaf] = []
4905     body_leaves: List[Leaf] = []
4906     head_leaves: List[Leaf] = []
4907     current_leaves = head_leaves
4908     matching_bracket: Optional[Leaf] = None
4909     for leaf in line.leaves:
4910         if (
4911             current_leaves is body_leaves
4912             and leaf.type in CLOSING_BRACKETS
4913             and leaf.opening_bracket is matching_bracket
4914         ):
4915             current_leaves = tail_leaves if body_leaves else head_leaves
4916         current_leaves.append(leaf)
4917         if current_leaves is head_leaves:
4918             if leaf.type in OPENING_BRACKETS:
4919                 matching_bracket = leaf
4920                 current_leaves = body_leaves
4921     if not matching_bracket:
4922         raise CannotSplit("No brackets found")
4923
4924     head = bracket_split_build_line(head_leaves, line, matching_bracket)
4925     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
4926     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
4927     bracket_split_succeeded_or_raise(head, body, tail)
4928     for result in (head, body, tail):
4929         if result:
4930             yield result
4931
4932
4933 def right_hand_split(
4934     line: Line,
4935     line_length: int,
4936     features: Collection[Feature] = (),
4937     omit: Collection[LeafID] = (),
4938 ) -> Iterator[Line]:
4939     """Split line into many lines, starting with the last matching bracket pair.
4940
4941     If the split was by optional parentheses, attempt splitting without them, too.
4942     `omit` is a collection of closing bracket IDs that shouldn't be considered for
4943     this split.
4944
4945     Note: running this function modifies `bracket_depth` on the leaves of `line`.
4946     """
4947     tail_leaves: List[Leaf] = []
4948     body_leaves: List[Leaf] = []
4949     head_leaves: List[Leaf] = []
4950     current_leaves = tail_leaves
4951     opening_bracket: Optional[Leaf] = None
4952     closing_bracket: Optional[Leaf] = None
4953     for leaf in reversed(line.leaves):
4954         if current_leaves is body_leaves:
4955             if leaf is opening_bracket:
4956                 current_leaves = head_leaves if body_leaves else tail_leaves
4957         current_leaves.append(leaf)
4958         if current_leaves is tail_leaves:
4959             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
4960                 opening_bracket = leaf.opening_bracket
4961                 closing_bracket = leaf
4962                 current_leaves = body_leaves
4963     if not (opening_bracket and closing_bracket and head_leaves):
4964         # If there is no opening or closing_bracket that means the split failed and
4965         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
4966         # the matching `opening_bracket` wasn't available on `line` anymore.
4967         raise CannotSplit("No brackets found")
4968
4969     tail_leaves.reverse()
4970     body_leaves.reverse()
4971     head_leaves.reverse()
4972     head = bracket_split_build_line(head_leaves, line, opening_bracket)
4973     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
4974     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
4975     bracket_split_succeeded_or_raise(head, body, tail)
4976     if (
4977         Feature.FORCE_OPTIONAL_PARENTHESES not in features
4978         # the opening bracket is an optional paren
4979         and opening_bracket.type == token.LPAR
4980         and not opening_bracket.value
4981         # the closing bracket is an optional paren
4982         and closing_bracket.type == token.RPAR
4983         and not closing_bracket.value
4984         # it's not an import (optional parens are the only thing we can split on
4985         # in this case; attempting a split without them is a waste of time)
4986         and not line.is_import
4987         # there are no standalone comments in the body
4988         and not body.contains_standalone_comments(0)
4989         # and we can actually remove the parens
4990         and can_omit_invisible_parens(body, line_length, omit_on_explode=omit)
4991     ):
4992         omit = {id(closing_bracket), *omit}
4993         try:
4994             yield from right_hand_split(line, line_length, features=features, omit=omit)
4995             return
4996
4997         except CannotSplit:
4998             if not (
4999                 can_be_split(body)
5000                 or is_line_short_enough(body, line_length=line_length)
5001             ):
5002                 raise CannotSplit(
5003                     "Splitting failed, body is still too long and can't be split."
5004                 )
5005
5006             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
5007                 raise CannotSplit(
5008                     "The current optional pair of parentheses is bound to fail to"
5009                     " satisfy the splitting algorithm because the head or the tail"
5010                     " contains multiline strings which by definition never fit one"
5011                     " line."
5012                 )
5013
5014     ensure_visible(opening_bracket)
5015     ensure_visible(closing_bracket)
5016     for result in (head, body, tail):
5017         if result:
5018             yield result
5019
5020
5021 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
5022     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
5023
5024     Do nothing otherwise.
5025
5026     A left- or right-hand split is based on a pair of brackets. Content before
5027     (and including) the opening bracket is left on one line, content inside the
5028     brackets is put on a separate line, and finally content starting with and
5029     following the closing bracket is put on a separate line.
5030
5031     Those are called `head`, `body`, and `tail`, respectively. If the split
5032     produced the same line (all content in `head`) or ended up with an empty `body`
5033     and the `tail` is just the closing bracket, then it's considered failed.
5034     """
5035     tail_len = len(str(tail).strip())
5036     if not body:
5037         if tail_len == 0:
5038             raise CannotSplit("Splitting brackets produced the same line")
5039
5040         elif tail_len < 3:
5041             raise CannotSplit(
5042                 f"Splitting brackets on an empty body to save {tail_len} characters is"
5043                 " not worth it"
5044             )
5045
5046
5047 def bracket_split_build_line(
5048     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
5049 ) -> Line:
5050     """Return a new line with given `leaves` and respective comments from `original`.
5051
5052     If `is_body` is True, the result line is one-indented inside brackets and as such
5053     has its first leaf's prefix normalized and a trailing comma added when expected.
5054     """
5055     result = Line(mode=original.mode, depth=original.depth)
5056     if is_body:
5057         result.inside_brackets = True
5058         result.depth += 1
5059         if leaves:
5060             # Since body is a new indent level, remove spurious leading whitespace.
5061             normalize_prefix(leaves[0], inside_brackets=True)
5062             # Ensure a trailing comma for imports and standalone function arguments, but
5063             # be careful not to add one after any comments or within type annotations.
5064             no_commas = (
5065                 original.is_def
5066                 and opening_bracket.value == "("
5067                 and not any(leaf.type == token.COMMA for leaf in leaves)
5068             )
5069
5070             if original.is_import or no_commas:
5071                 for i in range(len(leaves) - 1, -1, -1):
5072                     if leaves[i].type == STANDALONE_COMMENT:
5073                         continue
5074
5075                     if leaves[i].type != token.COMMA:
5076                         new_comma = Leaf(token.COMMA, ",")
5077                         leaves.insert(i + 1, new_comma)
5078                     break
5079
5080     # Populate the line
5081     for leaf in leaves:
5082         result.append(leaf, preformatted=True)
5083         for comment_after in original.comments_after(leaf):
5084             result.append(comment_after, preformatted=True)
5085     if is_body and should_split_line(result, opening_bracket):
5086         result.should_split_rhs = True
5087     return result
5088
5089
5090 def dont_increase_indentation(split_func: Transformer) -> Transformer:
5091     """Normalize prefix of the first leaf in every line returned by `split_func`.
5092
5093     This is a decorator over relevant split functions.
5094     """
5095
5096     @wraps(split_func)
5097     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5098         for line in split_func(line, features):
5099             normalize_prefix(line.leaves[0], inside_brackets=True)
5100             yield line
5101
5102     return split_wrapper
5103
5104
5105 @dont_increase_indentation
5106 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
5107     """Split according to delimiters of the highest priority.
5108
5109     If the appropriate Features are given, the split will add trailing commas
5110     also in function signatures and calls that contain `*` and `**`.
5111     """
5112     try:
5113         last_leaf = line.leaves[-1]
5114     except IndexError:
5115         raise CannotSplit("Line empty")
5116
5117     bt = line.bracket_tracker
5118     try:
5119         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
5120     except ValueError:
5121         raise CannotSplit("No delimiters found")
5122
5123     if delimiter_priority == DOT_PRIORITY:
5124         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
5125             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
5126
5127     current_line = Line(
5128         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5129     )
5130     lowest_depth = sys.maxsize
5131     trailing_comma_safe = True
5132
5133     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5134         """Append `leaf` to current line or to new line if appending impossible."""
5135         nonlocal current_line
5136         try:
5137             current_line.append_safe(leaf, preformatted=True)
5138         except ValueError:
5139             yield current_line
5140
5141             current_line = Line(
5142                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5143             )
5144             current_line.append(leaf)
5145
5146     for leaf in line.leaves:
5147         yield from append_to_line(leaf)
5148
5149         for comment_after in line.comments_after(leaf):
5150             yield from append_to_line(comment_after)
5151
5152         lowest_depth = min(lowest_depth, leaf.bracket_depth)
5153         if leaf.bracket_depth == lowest_depth:
5154             if is_vararg(leaf, within={syms.typedargslist}):
5155                 trailing_comma_safe = (
5156                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
5157                 )
5158             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
5159                 trailing_comma_safe = (
5160                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
5161                 )
5162
5163         leaf_priority = bt.delimiters.get(id(leaf))
5164         if leaf_priority == delimiter_priority:
5165             yield current_line
5166
5167             current_line = Line(
5168                 mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5169             )
5170     if current_line:
5171         if (
5172             trailing_comma_safe
5173             and delimiter_priority == COMMA_PRIORITY
5174             and current_line.leaves[-1].type != token.COMMA
5175             and current_line.leaves[-1].type != STANDALONE_COMMENT
5176         ):
5177             new_comma = Leaf(token.COMMA, ",")
5178             current_line.append(new_comma)
5179         yield current_line
5180
5181
5182 @dont_increase_indentation
5183 def standalone_comment_split(
5184     line: Line, features: Collection[Feature] = ()
5185 ) -> Iterator[Line]:
5186     """Split standalone comments from the rest of the line."""
5187     if not line.contains_standalone_comments(0):
5188         raise CannotSplit("Line does not have any standalone comments")
5189
5190     current_line = Line(
5191         mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5192     )
5193
5194     def append_to_line(leaf: Leaf) -> Iterator[Line]:
5195         """Append `leaf` to current line or to new line if appending impossible."""
5196         nonlocal current_line
5197         try:
5198             current_line.append_safe(leaf, preformatted=True)
5199         except ValueError:
5200             yield current_line
5201
5202             current_line = Line(
5203                 line.mode, depth=line.depth, inside_brackets=line.inside_brackets
5204             )
5205             current_line.append(leaf)
5206
5207     for leaf in line.leaves:
5208         yield from append_to_line(leaf)
5209
5210         for comment_after in line.comments_after(leaf):
5211             yield from append_to_line(comment_after)
5212
5213     if current_line:
5214         yield current_line
5215
5216
5217 def is_import(leaf: Leaf) -> bool:
5218     """Return True if the given leaf starts an import statement."""
5219     p = leaf.parent
5220     t = leaf.type
5221     v = leaf.value
5222     return bool(
5223         t == token.NAME
5224         and (
5225             (v == "import" and p and p.type == syms.import_name)
5226             or (v == "from" and p and p.type == syms.import_from)
5227         )
5228     )
5229
5230
5231 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
5232     """Return True if the given leaf is a special comment.
5233     Only returns true for type comments for now."""
5234     t = leaf.type
5235     v = leaf.value
5236     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
5237
5238
5239 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
5240     """Leave existing extra newlines if not `inside_brackets`. Remove everything
5241     else.
5242
5243     Note: don't use backslashes for formatting or you'll lose your voting rights.
5244     """
5245     if not inside_brackets:
5246         spl = leaf.prefix.split("#")
5247         if "\\" not in spl[0]:
5248             nl_count = spl[-1].count("\n")
5249             if len(spl) > 1:
5250                 nl_count -= 1
5251             leaf.prefix = "\n" * nl_count
5252             return
5253
5254     leaf.prefix = ""
5255
5256
5257 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
5258     """Make all string prefixes lowercase.
5259
5260     If remove_u_prefix is given, also removes any u prefix from the string.
5261
5262     Note: Mutates its argument.
5263     """
5264     match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL)
5265     assert match is not None, f"failed to match string {leaf.value!r}"
5266     orig_prefix = match.group(1)
5267     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
5268     if remove_u_prefix:
5269         new_prefix = new_prefix.replace("u", "")
5270     leaf.value = f"{new_prefix}{match.group(2)}"
5271
5272
5273 def normalize_string_quotes(leaf: Leaf) -> None:
5274     """Prefer double quotes but only if it doesn't cause more escaping.
5275
5276     Adds or removes backslashes as appropriate. Doesn't parse and fix
5277     strings nested in f-strings (yet).
5278
5279     Note: Mutates its argument.
5280     """
5281     value = leaf.value.lstrip(STRING_PREFIX_CHARS)
5282     if value[:3] == '"""':
5283         return
5284
5285     elif value[:3] == "'''":
5286         orig_quote = "'''"
5287         new_quote = '"""'
5288     elif value[0] == '"':
5289         orig_quote = '"'
5290         new_quote = "'"
5291     else:
5292         orig_quote = "'"
5293         new_quote = '"'
5294     first_quote_pos = leaf.value.find(orig_quote)
5295     if first_quote_pos == -1:
5296         return  # There's an internal error
5297
5298     prefix = leaf.value[:first_quote_pos]
5299     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
5300     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
5301     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
5302     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
5303     if "r" in prefix.casefold():
5304         if unescaped_new_quote.search(body):
5305             # There's at least one unescaped new_quote in this raw string
5306             # so converting is impossible
5307             return
5308
5309         # Do not introduce or remove backslashes in raw strings
5310         new_body = body
5311     else:
5312         # remove unnecessary escapes
5313         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
5314         if body != new_body:
5315             # Consider the string without unnecessary escapes as the original
5316             body = new_body
5317             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
5318         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
5319         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
5320     if "f" in prefix.casefold():
5321         matches = re.findall(
5322             r"""
5323             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
5324                 ([^{].*?)  # contents of the brackets except if begins with {{
5325             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
5326             """,
5327             new_body,
5328             re.VERBOSE,
5329         )
5330         for m in matches:
5331             if "\\" in str(m):
5332                 # Do not introduce backslashes in interpolated expressions
5333                 return
5334
5335     if new_quote == '"""' and new_body[-1:] == '"':
5336         # edge case:
5337         new_body = new_body[:-1] + '\\"'
5338     orig_escape_count = body.count("\\")
5339     new_escape_count = new_body.count("\\")
5340     if new_escape_count > orig_escape_count:
5341         return  # Do not introduce more escaping
5342
5343     if new_escape_count == orig_escape_count and orig_quote == '"':
5344         return  # Prefer double quotes
5345
5346     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
5347
5348
5349 def normalize_numeric_literal(leaf: Leaf) -> None:
5350     """Normalizes numeric (float, int, and complex) literals.
5351
5352     All letters used in the representation are normalized to lowercase (except
5353     in Python 2 long literals).
5354     """
5355     text = leaf.value.lower()
5356     if text.startswith(("0o", "0b")):
5357         # Leave octal and binary literals alone.
5358         pass
5359     elif text.startswith("0x"):
5360         text = format_hex(text)
5361     elif "e" in text:
5362         text = format_scientific_notation(text)
5363     elif text.endswith(("j", "l")):
5364         text = format_long_or_complex_number(text)
5365     else:
5366         text = format_float_or_int_string(text)
5367     leaf.value = text
5368
5369
5370 def format_hex(text: str) -> str:
5371     """
5372     Formats a hexadecimal string like "0x12B3"
5373     """
5374     before, after = text[:2], text[2:]
5375     return f"{before}{after.upper()}"
5376
5377
5378 def format_scientific_notation(text: str) -> str:
5379     """Formats a numeric string utilizing scentific notation"""
5380     before, after = text.split("e")
5381     sign = ""
5382     if after.startswith("-"):
5383         after = after[1:]
5384         sign = "-"
5385     elif after.startswith("+"):
5386         after = after[1:]
5387     before = format_float_or_int_string(before)
5388     return f"{before}e{sign}{after}"
5389
5390
5391 def format_long_or_complex_number(text: str) -> str:
5392     """Formats a long or complex string like `10L` or `10j`"""
5393     number = text[:-1]
5394     suffix = text[-1]
5395     # Capitalize in "2L" because "l" looks too similar to "1".
5396     if suffix == "l":
5397         suffix = "L"
5398     return f"{format_float_or_int_string(number)}{suffix}"
5399
5400
5401 def format_float_or_int_string(text: str) -> str:
5402     """Formats a float string like "1.0"."""
5403     if "." not in text:
5404         return text
5405
5406     before, after = text.split(".")
5407     return f"{before or 0}.{after or 0}"
5408
5409
5410 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
5411     """Make existing optional parentheses invisible or create new ones.
5412
5413     `parens_after` is a set of string leaf values immediately after which parens
5414     should be put.
5415
5416     Standardizes on visible parentheses for single-element tuples, and keeps
5417     existing visible parentheses for other tuples and generator expressions.
5418     """
5419     for pc in list_comments(node.prefix, is_endmarker=False):
5420         if pc.value in FMT_OFF:
5421             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
5422             return
5423     check_lpar = False
5424     for index, child in enumerate(list(node.children)):
5425         # Fixes a bug where invisible parens are not properly stripped from
5426         # assignment statements that contain type annotations.
5427         if isinstance(child, Node) and child.type == syms.annassign:
5428             normalize_invisible_parens(child, parens_after=parens_after)
5429
5430         # Add parentheses around long tuple unpacking in assignments.
5431         if (
5432             index == 0
5433             and isinstance(child, Node)
5434             and child.type == syms.testlist_star_expr
5435         ):
5436             check_lpar = True
5437
5438         if check_lpar:
5439             if child.type == syms.atom:
5440                 if maybe_make_parens_invisible_in_atom(child, parent=node):
5441                     wrap_in_parentheses(node, child, visible=False)
5442             elif is_one_tuple(child):
5443                 wrap_in_parentheses(node, child, visible=True)
5444             elif node.type == syms.import_from:
5445                 # "import from" nodes store parentheses directly as part of
5446                 # the statement
5447                 if child.type == token.LPAR:
5448                     # make parentheses invisible
5449                     child.value = ""  # type: ignore
5450                     node.children[-1].value = ""  # type: ignore
5451                 elif child.type != token.STAR:
5452                     # insert invisible parentheses
5453                     node.insert_child(index, Leaf(token.LPAR, ""))
5454                     node.append_child(Leaf(token.RPAR, ""))
5455                 break
5456
5457             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
5458                 wrap_in_parentheses(node, child, visible=False)
5459
5460         check_lpar = isinstance(child, Leaf) and child.value in parens_after
5461
5462
5463 def normalize_fmt_off(node: Node) -> None:
5464     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
5465     try_again = True
5466     while try_again:
5467         try_again = convert_one_fmt_off_pair(node)
5468
5469
5470 def convert_one_fmt_off_pair(node: Node) -> bool:
5471     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
5472
5473     Returns True if a pair was converted.
5474     """
5475     for leaf in node.leaves():
5476         previous_consumed = 0
5477         for comment in list_comments(leaf.prefix, is_endmarker=False):
5478             if comment.value not in FMT_PASS:
5479                 previous_consumed = comment.consumed
5480                 continue
5481             # We only want standalone comments. If there's no previous leaf or
5482             # the previous leaf is indentation, it's a standalone comment in
5483             # disguise.
5484             if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
5485                 prev = preceding_leaf(leaf)
5486                 if prev:
5487                     if comment.value in FMT_OFF and prev.type not in WHITESPACE:
5488                         continue
5489                     if comment.value in FMT_SKIP and prev.type in WHITESPACE:
5490                         continue
5491
5492             ignored_nodes = list(generate_ignored_nodes(leaf, comment))
5493             if not ignored_nodes:
5494                 continue
5495
5496             first = ignored_nodes[0]  # Can be a container node with the `leaf`.
5497             parent = first.parent
5498             prefix = first.prefix
5499             first.prefix = prefix[comment.consumed :]
5500             hidden_value = "".join(str(n) for n in ignored_nodes)
5501             if comment.value in FMT_OFF:
5502                 hidden_value = comment.value + "\n" + hidden_value
5503             if comment.value in FMT_SKIP:
5504                 hidden_value += "  " + comment.value
5505             if hidden_value.endswith("\n"):
5506                 # That happens when one of the `ignored_nodes` ended with a NEWLINE
5507                 # leaf (possibly followed by a DEDENT).
5508                 hidden_value = hidden_value[:-1]
5509             first_idx: Optional[int] = None
5510             for ignored in ignored_nodes:
5511                 index = ignored.remove()
5512                 if first_idx is None:
5513                     first_idx = index
5514             assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
5515             assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
5516             parent.insert_child(
5517                 first_idx,
5518                 Leaf(
5519                     STANDALONE_COMMENT,
5520                     hidden_value,
5521                     prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
5522                 ),
5523             )
5524             return True
5525
5526     return False
5527
5528
5529 def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
5530     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
5531
5532     If comment is skip, returns leaf only.
5533     Stops at the end of the block.
5534     """
5535     container: Optional[LN] = container_of(leaf)
5536     if comment.value in FMT_SKIP:
5537         prev_sibling = leaf.prev_sibling
5538         if comment.value in leaf.prefix and prev_sibling is not None:
5539             leaf.prefix = leaf.prefix.replace(comment.value, "")
5540             siblings = [prev_sibling]
5541             while (
5542                 "\n" not in prev_sibling.prefix
5543                 and prev_sibling.prev_sibling is not None
5544             ):
5545                 prev_sibling = prev_sibling.prev_sibling
5546                 siblings.insert(0, prev_sibling)
5547             for sibling in siblings:
5548                 yield sibling
5549         elif leaf.parent is not None:
5550             yield leaf.parent
5551         return
5552     while container is not None and container.type != token.ENDMARKER:
5553         if is_fmt_on(container):
5554             return
5555
5556         # fix for fmt: on in children
5557         if contains_fmt_on_at_column(container, leaf.column):
5558             for child in container.children:
5559                 if contains_fmt_on_at_column(child, leaf.column):
5560                     return
5561                 yield child
5562         else:
5563             yield container
5564             container = container.next_sibling
5565
5566
5567 def is_fmt_on(container: LN) -> bool:
5568     """Determine whether formatting is switched on within a container.
5569     Determined by whether the last `# fmt:` comment is `on` or `off`.
5570     """
5571     fmt_on = False
5572     for comment in list_comments(container.prefix, is_endmarker=False):
5573         if comment.value in FMT_ON:
5574             fmt_on = True
5575         elif comment.value in FMT_OFF:
5576             fmt_on = False
5577     return fmt_on
5578
5579
5580 def contains_fmt_on_at_column(container: LN, column: int) -> bool:
5581     """Determine if children at a given column have formatting switched on."""
5582     for child in container.children:
5583         if (
5584             isinstance(child, Node)
5585             and first_leaf_column(child) == column
5586             or isinstance(child, Leaf)
5587             and child.column == column
5588         ):
5589             if is_fmt_on(child):
5590                 return True
5591
5592     return False
5593
5594
5595 def first_leaf_column(node: Node) -> Optional[int]:
5596     """Returns the column of the first leaf child of a node."""
5597     for child in node.children:
5598         if isinstance(child, Leaf):
5599             return child.column
5600     return None
5601
5602
5603 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
5604     """If it's safe, make the parens in the atom `node` invisible, recursively.
5605     Additionally, remove repeated, adjacent invisible parens from the atom `node`
5606     as they are redundant.
5607
5608     Returns whether the node should itself be wrapped in invisible parentheses.
5609
5610     """
5611
5612     if (
5613         node.type != syms.atom
5614         or is_empty_tuple(node)
5615         or is_one_tuple(node)
5616         or (is_yield(node) and parent.type != syms.expr_stmt)
5617         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
5618     ):
5619         return False
5620
5621     if is_walrus_assignment(node):
5622         if parent.type in [
5623             syms.annassign,
5624             syms.expr_stmt,
5625             syms.assert_stmt,
5626             syms.return_stmt,
5627         ]:
5628             return False
5629
5630     first = node.children[0]
5631     last = node.children[-1]
5632     if first.type == token.LPAR and last.type == token.RPAR:
5633         middle = node.children[1]
5634         # make parentheses invisible
5635         first.value = ""  # type: ignore
5636         last.value = ""  # type: ignore
5637         maybe_make_parens_invisible_in_atom(middle, parent=parent)
5638
5639         if is_atom_with_invisible_parens(middle):
5640             # Strip the invisible parens from `middle` by replacing
5641             # it with the child in-between the invisible parens
5642             middle.replace(middle.children[1])
5643
5644         return False
5645
5646     return True
5647
5648
5649 def is_atom_with_invisible_parens(node: LN) -> bool:
5650     """Given a `LN`, determines whether it's an atom `node` with invisible
5651     parens. Useful in dedupe-ing and normalizing parens.
5652     """
5653     if isinstance(node, Leaf) or node.type != syms.atom:
5654         return False
5655
5656     first, last = node.children[0], node.children[-1]
5657     return (
5658         isinstance(first, Leaf)
5659         and first.type == token.LPAR
5660         and first.value == ""
5661         and isinstance(last, Leaf)
5662         and last.type == token.RPAR
5663         and last.value == ""
5664     )
5665
5666
5667 def is_empty_tuple(node: LN) -> bool:
5668     """Return True if `node` holds an empty tuple."""
5669     return (
5670         node.type == syms.atom
5671         and len(node.children) == 2
5672         and node.children[0].type == token.LPAR
5673         and node.children[1].type == token.RPAR
5674     )
5675
5676
5677 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
5678     """Returns `wrapped` if `node` is of the shape ( wrapped ).
5679
5680     Parenthesis can be optional. Returns None otherwise"""
5681     if len(node.children) != 3:
5682         return None
5683
5684     lpar, wrapped, rpar = node.children
5685     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
5686         return None
5687
5688     return wrapped
5689
5690
5691 def first_child_is_arith(node: Node) -> bool:
5692     """Whether first child is an arithmetic or a binary arithmetic expression"""
5693     expr_types = {
5694         syms.arith_expr,
5695         syms.shift_expr,
5696         syms.xor_expr,
5697         syms.and_expr,
5698     }
5699     return bool(node.children and node.children[0].type in expr_types)
5700
5701
5702 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
5703     """Wrap `child` in parentheses.
5704
5705     This replaces `child` with an atom holding the parentheses and the old
5706     child.  That requires moving the prefix.
5707
5708     If `visible` is False, the leaves will be valueless (and thus invisible).
5709     """
5710     lpar = Leaf(token.LPAR, "(" if visible else "")
5711     rpar = Leaf(token.RPAR, ")" if visible else "")
5712     prefix = child.prefix
5713     child.prefix = ""
5714     index = child.remove() or 0
5715     new_child = Node(syms.atom, [lpar, child, rpar])
5716     new_child.prefix = prefix
5717     parent.insert_child(index, new_child)
5718
5719
5720 def is_one_tuple(node: LN) -> bool:
5721     """Return True if `node` holds a tuple with one element, with or without parens."""
5722     if node.type == syms.atom:
5723         gexp = unwrap_singleton_parenthesis(node)
5724         if gexp is None or gexp.type != syms.testlist_gexp:
5725             return False
5726
5727         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
5728
5729     return (
5730         node.type in IMPLICIT_TUPLE
5731         and len(node.children) == 2
5732         and node.children[1].type == token.COMMA
5733     )
5734
5735
5736 def is_walrus_assignment(node: LN) -> bool:
5737     """Return True iff `node` is of the shape ( test := test )"""
5738     inner = unwrap_singleton_parenthesis(node)
5739     return inner is not None and inner.type == syms.namedexpr_test
5740
5741
5742 def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
5743     """Return True iff `node` is a trailer valid in a simple decorator"""
5744     return node.type == syms.trailer and (
5745         (
5746             len(node.children) == 2
5747             and node.children[0].type == token.DOT
5748             and node.children[1].type == token.NAME
5749         )
5750         # last trailer can be arguments
5751         or (
5752             last
5753             and len(node.children) == 3
5754             and node.children[0].type == token.LPAR
5755             # and node.children[1].type == syms.argument
5756             and node.children[2].type == token.RPAR
5757         )
5758     )
5759
5760
5761 def is_simple_decorator_expression(node: LN) -> bool:
5762     """Return True iff `node` could be a 'dotted name' decorator
5763
5764     This function takes the node of the 'namedexpr_test' of the new decorator
5765     grammar and test if it would be valid under the old decorator grammar.
5766
5767     The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
5768     The new grammar is : decorator: @ namedexpr_test NEWLINE
5769     """
5770     if node.type == token.NAME:
5771         return True
5772     if node.type == syms.power:
5773         if node.children:
5774             return (
5775                 node.children[0].type == token.NAME
5776                 and all(map(is_simple_decorator_trailer, node.children[1:-1]))
5777                 and (
5778                     len(node.children) < 2
5779                     or is_simple_decorator_trailer(node.children[-1], last=True)
5780                 )
5781             )
5782     return False
5783
5784
5785 def is_yield(node: LN) -> bool:
5786     """Return True if `node` holds a `yield` or `yield from` expression."""
5787     if node.type == syms.yield_expr:
5788         return True
5789
5790     if node.type == token.NAME and node.value == "yield":  # type: ignore
5791         return True
5792
5793     if node.type != syms.atom:
5794         return False
5795
5796     if len(node.children) != 3:
5797         return False
5798
5799     lpar, expr, rpar = node.children
5800     if lpar.type == token.LPAR and rpar.type == token.RPAR:
5801         return is_yield(expr)
5802
5803     return False
5804
5805
5806 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
5807     """Return True if `leaf` is a star or double star in a vararg or kwarg.
5808
5809     If `within` includes VARARGS_PARENTS, this applies to function signatures.
5810     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
5811     extended iterable unpacking (PEP 3132) and additional unpacking
5812     generalizations (PEP 448).
5813     """
5814     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
5815         return False
5816
5817     p = leaf.parent
5818     if p.type == syms.star_expr:
5819         # Star expressions are also used as assignment targets in extended
5820         # iterable unpacking (PEP 3132).  See what its parent is instead.
5821         if not p.parent:
5822             return False
5823
5824         p = p.parent
5825
5826     return p.type in within
5827
5828
5829 def is_multiline_string(leaf: Leaf) -> bool:
5830     """Return True if `leaf` is a multiline string that actually spans many lines."""
5831     return has_triple_quotes(leaf.value) and "\n" in leaf.value
5832
5833
5834 def is_stub_suite(node: Node) -> bool:
5835     """Return True if `node` is a suite with a stub body."""
5836     if (
5837         len(node.children) != 4
5838         or node.children[0].type != token.NEWLINE
5839         or node.children[1].type != token.INDENT
5840         or node.children[3].type != token.DEDENT
5841     ):
5842         return False
5843
5844     return is_stub_body(node.children[2])
5845
5846
5847 def is_stub_body(node: LN) -> bool:
5848     """Return True if `node` is a simple statement containing an ellipsis."""
5849     if not isinstance(node, Node) or node.type != syms.simple_stmt:
5850         return False
5851
5852     if len(node.children) != 2:
5853         return False
5854
5855     child = node.children[0]
5856     return (
5857         child.type == syms.atom
5858         and len(child.children) == 3
5859         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
5860     )
5861
5862
5863 def max_delimiter_priority_in_atom(node: LN) -> Priority:
5864     """Return maximum delimiter priority inside `node`.
5865
5866     This is specific to atoms with contents contained in a pair of parentheses.
5867     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
5868     """
5869     if node.type != syms.atom:
5870         return 0
5871
5872     first = node.children[0]
5873     last = node.children[-1]
5874     if not (first.type == token.LPAR and last.type == token.RPAR):
5875         return 0
5876
5877     bt = BracketTracker()
5878     for c in node.children[1:-1]:
5879         if isinstance(c, Leaf):
5880             bt.mark(c)
5881         else:
5882             for leaf in c.leaves():
5883                 bt.mark(leaf)
5884     try:
5885         return bt.max_delimiter_priority()
5886
5887     except ValueError:
5888         return 0
5889
5890
5891 def ensure_visible(leaf: Leaf) -> None:
5892     """Make sure parentheses are visible.
5893
5894     They could be invisible as part of some statements (see
5895     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
5896     """
5897     if leaf.type == token.LPAR:
5898         leaf.value = "("
5899     elif leaf.type == token.RPAR:
5900         leaf.value = ")"
5901
5902
5903 def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
5904     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
5905
5906     if not (opening_bracket.parent and opening_bracket.value in "[{("):
5907         return False
5908
5909     # We're essentially checking if the body is delimited by commas and there's more
5910     # than one of them (we're excluding the trailing comma and if the delimiter priority
5911     # is still commas, that means there's more).
5912     exclude = set()
5913     trailing_comma = False
5914     try:
5915         last_leaf = line.leaves[-1]
5916         if last_leaf.type == token.COMMA:
5917             trailing_comma = True
5918             exclude.add(id(last_leaf))
5919         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
5920     except (IndexError, ValueError):
5921         return False
5922
5923     return max_priority == COMMA_PRIORITY and (
5924         (line.mode.magic_trailing_comma and trailing_comma)
5925         # always explode imports
5926         or opening_bracket.parent.type in {syms.atom, syms.import_from}
5927     )
5928
5929
5930 def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool:
5931     """Return True if content between `opening` and `closing` looks like a one-tuple."""
5932     if opening.type != token.LPAR and closing.type != token.RPAR:
5933         return False
5934
5935     depth = closing.bracket_depth + 1
5936     for _opening_index, leaf in enumerate(leaves):
5937         if leaf is opening:
5938             break
5939
5940     else:
5941         raise LookupError("Opening paren not found in `leaves`")
5942
5943     commas = 0
5944     _opening_index += 1
5945     for leaf in leaves[_opening_index:]:
5946         if leaf is closing:
5947             break
5948
5949         bracket_depth = leaf.bracket_depth
5950         if bracket_depth == depth and leaf.type == token.COMMA:
5951             commas += 1
5952             if leaf.parent and leaf.parent.type in {
5953                 syms.arglist,
5954                 syms.typedargslist,
5955             }:
5956                 commas += 1
5957                 break
5958
5959     return commas < 2
5960
5961
5962 def get_features_used(node: Node) -> Set[Feature]:
5963     """Return a set of (relatively) new Python features used in this file.
5964
5965     Currently looking for:
5966     - f-strings;
5967     - underscores in numeric literals;
5968     - trailing commas after * or ** in function signatures and calls;
5969     - positional only arguments in function signatures and lambdas;
5970     - assignment expression;
5971     - relaxed decorator syntax;
5972     """
5973     features: Set[Feature] = set()
5974     for n in node.pre_order():
5975         if n.type == token.STRING:
5976             value_head = n.value[:2]  # type: ignore
5977             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
5978                 features.add(Feature.F_STRINGS)
5979
5980         elif n.type == token.NUMBER:
5981             if "_" in n.value:  # type: ignore
5982                 features.add(Feature.NUMERIC_UNDERSCORES)
5983
5984         elif n.type == token.SLASH:
5985             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
5986                 features.add(Feature.POS_ONLY_ARGUMENTS)
5987
5988         elif n.type == token.COLONEQUAL:
5989             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
5990
5991         elif n.type == syms.decorator:
5992             if len(n.children) > 1 and not is_simple_decorator_expression(
5993                 n.children[1]
5994             ):
5995                 features.add(Feature.RELAXED_DECORATORS)
5996
5997         elif (
5998             n.type in {syms.typedargslist, syms.arglist}
5999             and n.children
6000             and n.children[-1].type == token.COMMA
6001         ):
6002             if n.type == syms.typedargslist:
6003                 feature = Feature.TRAILING_COMMA_IN_DEF
6004             else:
6005                 feature = Feature.TRAILING_COMMA_IN_CALL
6006
6007             for ch in n.children:
6008                 if ch.type in STARS:
6009                     features.add(feature)
6010
6011                 if ch.type == syms.argument:
6012                     for argch in ch.children:
6013                         if argch.type in STARS:
6014                             features.add(feature)
6015
6016     return features
6017
6018
6019 def detect_target_versions(node: Node) -> Set[TargetVersion]:
6020     """Detect the version to target based on the nodes used."""
6021     features = get_features_used(node)
6022     return {
6023         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
6024     }
6025
6026
6027 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
6028     """Generate sets of closing bracket IDs that should be omitted in a RHS.
6029
6030     Brackets can be omitted if the entire trailer up to and including
6031     a preceding closing bracket fits in one line.
6032
6033     Yielded sets are cumulative (contain results of previous yields, too).  First
6034     set is empty, unless the line should explode, in which case bracket pairs until
6035     the one that needs to explode are omitted.
6036     """
6037
6038     omit: Set[LeafID] = set()
6039     if not line.magic_trailing_comma:
6040         yield omit
6041
6042     length = 4 * line.depth
6043     opening_bracket: Optional[Leaf] = None
6044     closing_bracket: Optional[Leaf] = None
6045     inner_brackets: Set[LeafID] = set()
6046     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
6047         length += leaf_length
6048         if length > line_length:
6049             break
6050
6051         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
6052         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
6053             break
6054
6055         if opening_bracket:
6056             if leaf is opening_bracket:
6057                 opening_bracket = None
6058             elif leaf.type in CLOSING_BRACKETS:
6059                 prev = line.leaves[index - 1] if index > 0 else None
6060                 if (
6061                     prev
6062                     and prev.type == token.COMMA
6063                     and not is_one_tuple_between(
6064                         leaf.opening_bracket, leaf, line.leaves
6065                     )
6066                 ):
6067                     # Never omit bracket pairs with trailing commas.
6068                     # We need to explode on those.
6069                     break
6070
6071                 inner_brackets.add(id(leaf))
6072         elif leaf.type in CLOSING_BRACKETS:
6073             prev = line.leaves[index - 1] if index > 0 else None
6074             if prev and prev.type in OPENING_BRACKETS:
6075                 # Empty brackets would fail a split so treat them as "inner"
6076                 # brackets (e.g. only add them to the `omit` set if another
6077                 # pair of brackets was good enough.
6078                 inner_brackets.add(id(leaf))
6079                 continue
6080
6081             if closing_bracket:
6082                 omit.add(id(closing_bracket))
6083                 omit.update(inner_brackets)
6084                 inner_brackets.clear()
6085                 yield omit
6086
6087             if (
6088                 prev
6089                 and prev.type == token.COMMA
6090                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
6091             ):
6092                 # Never omit bracket pairs with trailing commas.
6093                 # We need to explode on those.
6094                 break
6095
6096             if leaf.value:
6097                 opening_bracket = leaf.opening_bracket
6098                 closing_bracket = leaf
6099
6100
6101 def get_future_imports(node: Node) -> Set[str]:
6102     """Return a set of __future__ imports in the file."""
6103     imports: Set[str] = set()
6104
6105     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
6106         for child in children:
6107             if isinstance(child, Leaf):
6108                 if child.type == token.NAME:
6109                     yield child.value
6110
6111             elif child.type == syms.import_as_name:
6112                 orig_name = child.children[0]
6113                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
6114                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
6115                 yield orig_name.value
6116
6117             elif child.type == syms.import_as_names:
6118                 yield from get_imports_from_children(child.children)
6119
6120             else:
6121                 raise AssertionError("Invalid syntax parsing imports")
6122
6123     for child in node.children:
6124         if child.type != syms.simple_stmt:
6125             break
6126
6127         first_child = child.children[0]
6128         if isinstance(first_child, Leaf):
6129             # Continue looking if we see a docstring; otherwise stop.
6130             if (
6131                 len(child.children) == 2
6132                 and first_child.type == token.STRING
6133                 and child.children[1].type == token.NEWLINE
6134             ):
6135                 continue
6136
6137             break
6138
6139         elif first_child.type == syms.import_from:
6140             module_name = first_child.children[1]
6141             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
6142                 break
6143
6144             imports |= set(get_imports_from_children(first_child.children[3:]))
6145         else:
6146             break
6147
6148     return imports
6149
6150
6151 @lru_cache()
6152 def get_gitignore(root: Path) -> PathSpec:
6153     """Return a PathSpec matching gitignore content if present."""
6154     gitignore = root / ".gitignore"
6155     lines: List[str] = []
6156     if gitignore.is_file():
6157         with gitignore.open() as gf:
6158             lines = gf.readlines()
6159     return PathSpec.from_lines("gitwildmatch", lines)
6160
6161
6162 def normalize_path_maybe_ignore(
6163     path: Path, root: Path, report: "Report"
6164 ) -> Optional[str]:
6165     """Normalize `path`. May return `None` if `path` was ignored.
6166
6167     `report` is where "path ignored" output goes.
6168     """
6169     try:
6170         abspath = path if path.is_absolute() else Path.cwd() / path
6171         normalized_path = abspath.resolve().relative_to(root).as_posix()
6172     except OSError as e:
6173         report.path_ignored(path, f"cannot be read because {e}")
6174         return None
6175
6176     except ValueError:
6177         if path.is_symlink():
6178             report.path_ignored(path, f"is a symbolic link that points outside {root}")
6179             return None
6180
6181         raise
6182
6183     return normalized_path
6184
6185
6186 def path_is_excluded(
6187     normalized_path: str,
6188     pattern: Optional[Pattern[str]],
6189 ) -> bool:
6190     match = pattern.search(normalized_path) if pattern else None
6191     return bool(match and match.group(0))
6192
6193
6194 def gen_python_files(
6195     paths: Iterable[Path],
6196     root: Path,
6197     include: Optional[Pattern[str]],
6198     exclude: Pattern[str],
6199     extend_exclude: Optional[Pattern[str]],
6200     force_exclude: Optional[Pattern[str]],
6201     report: "Report",
6202     gitignore: PathSpec,
6203 ) -> Iterator[Path]:
6204     """Generate all files under `path` whose paths are not excluded by the
6205     `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
6206     but are included by the `include` regex.
6207
6208     Symbolic links pointing outside of the `root` directory are ignored.
6209
6210     `report` is where output about exclusions goes.
6211     """
6212     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
6213     for child in paths:
6214         normalized_path = normalize_path_maybe_ignore(child, root, report)
6215         if normalized_path is None:
6216             continue
6217
6218         # First ignore files matching .gitignore
6219         if gitignore.match_file(normalized_path):
6220             report.path_ignored(child, "matches the .gitignore file content")
6221             continue
6222
6223         # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
6224         normalized_path = "/" + normalized_path
6225         if child.is_dir():
6226             normalized_path += "/"
6227
6228         if path_is_excluded(normalized_path, exclude):
6229             report.path_ignored(child, "matches the --exclude regular expression")
6230             continue
6231
6232         if path_is_excluded(normalized_path, extend_exclude):
6233             report.path_ignored(
6234                 child, "matches the --extend-exclude regular expression"
6235             )
6236             continue
6237
6238         if path_is_excluded(normalized_path, force_exclude):
6239             report.path_ignored(child, "matches the --force-exclude regular expression")
6240             continue
6241
6242         if child.is_dir():
6243             yield from gen_python_files(
6244                 child.iterdir(),
6245                 root,
6246                 include,
6247                 exclude,
6248                 extend_exclude,
6249                 force_exclude,
6250                 report,
6251                 gitignore,
6252             )
6253
6254         elif child.is_file():
6255             include_match = include.search(normalized_path) if include else True
6256             if include_match:
6257                 yield child
6258
6259
6260 @lru_cache()
6261 def find_project_root(srcs: Tuple[str, ...]) -> Path:
6262     """Return a directory containing .git, .hg, or pyproject.toml.
6263
6264     That directory will be a common parent of all files and directories
6265     passed in `srcs`.
6266
6267     If no directory in the tree contains a marker that would specify it's the
6268     project root, the root of the file system is returned.
6269     """
6270     if not srcs:
6271         return Path("/").resolve()
6272
6273     path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]
6274
6275     # A list of lists of parents for each 'src'. 'src' is included as a
6276     # "parent" of itself if it is a directory
6277     src_parents = [
6278         list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
6279     ]
6280
6281     common_base = max(
6282         set.intersection(*(set(parents) for parents in src_parents)),
6283         key=lambda path: path.parts,
6284     )
6285
6286     for directory in (common_base, *common_base.parents):
6287         if (directory / ".git").exists():
6288             return directory
6289
6290         if (directory / ".hg").is_dir():
6291             return directory
6292
6293         if (directory / "pyproject.toml").is_file():
6294             return directory
6295
6296     return directory
6297
6298
6299 @lru_cache()
6300 def find_user_pyproject_toml() -> Path:
6301     r"""Return the path to the top-level user configuration for black.
6302
6303     This looks for ~\.black on Windows and ~/.config/black on Linux and other
6304     Unix systems.
6305     """
6306     if sys.platform == "win32":
6307         # Windows
6308         user_config_path = Path.home() / ".black"
6309     else:
6310         config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config")
6311         user_config_path = Path(config_root).expanduser() / "black"
6312     return user_config_path.resolve()
6313
6314
6315 @dataclass
6316 class Report:
6317     """Provides a reformatting counter. Can be rendered with `str(report)`."""
6318
6319     check: bool = False
6320     diff: bool = False
6321     quiet: bool = False
6322     verbose: bool = False
6323     change_count: int = 0
6324     same_count: int = 0
6325     failure_count: int = 0
6326
6327     def done(self, src: Path, changed: Changed) -> None:
6328         """Increment the counter for successful reformatting. Write out a message."""
6329         if changed is Changed.YES:
6330             reformatted = "would reformat" if self.check or self.diff else "reformatted"
6331             if self.verbose or not self.quiet:
6332                 out(f"{reformatted} {src}")
6333             self.change_count += 1
6334         else:
6335             if self.verbose:
6336                 if changed is Changed.NO:
6337                     msg = f"{src} already well formatted, good job."
6338                 else:
6339                     msg = f"{src} wasn't modified on disk since last run."
6340                 out(msg, bold=False)
6341             self.same_count += 1
6342
6343     def failed(self, src: Path, message: str) -> None:
6344         """Increment the counter for failed reformatting. Write out a message."""
6345         err(f"error: cannot format {src}: {message}")
6346         self.failure_count += 1
6347
6348     def path_ignored(self, path: Path, message: str) -> None:
6349         if self.verbose:
6350             out(f"{path} ignored: {message}", bold=False)
6351
6352     @property
6353     def return_code(self) -> int:
6354         """Return the exit code that the app should use.
6355
6356         This considers the current state of changed files and failures:
6357         - if there were any failures, return 123;
6358         - if any files were changed and --check is being used, return 1;
6359         - otherwise return 0.
6360         """
6361         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
6362         # 126 we have special return codes reserved by the shell.
6363         if self.failure_count:
6364             return 123
6365
6366         elif self.change_count and self.check:
6367             return 1
6368
6369         return 0
6370
6371     def __str__(self) -> str:
6372         """Render a color report of the current state.
6373
6374         Use `click.unstyle` to remove colors.
6375         """
6376         if self.check or self.diff:
6377             reformatted = "would be reformatted"
6378             unchanged = "would be left unchanged"
6379             failed = "would fail to reformat"
6380         else:
6381             reformatted = "reformatted"
6382             unchanged = "left unchanged"
6383             failed = "failed to reformat"
6384         report = []
6385         if self.change_count:
6386             s = "s" if self.change_count > 1 else ""
6387             report.append(
6388                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
6389             )
6390         if self.same_count:
6391             s = "s" if self.same_count > 1 else ""
6392             report.append(f"{self.same_count} file{s} {unchanged}")
6393         if self.failure_count:
6394             s = "s" if self.failure_count > 1 else ""
6395             report.append(
6396                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
6397             )
6398         return ", ".join(report) + "."
6399
6400
6401 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
6402     filename = "<unknown>"
6403     if sys.version_info >= (3, 8):
6404         # TODO: support Python 4+ ;)
6405         for minor_version in range(sys.version_info[1], 4, -1):
6406             try:
6407                 return ast.parse(src, filename, feature_version=(3, minor_version))
6408             except SyntaxError:
6409                 continue
6410     else:
6411         for feature_version in (7, 6):
6412             try:
6413                 return ast3.parse(src, filename, feature_version=feature_version)
6414             except SyntaxError:
6415                 continue
6416     if ast27.__name__ == "ast":
6417         raise SyntaxError(
6418             "The requested source code has invalid Python 3 syntax.\n"
6419             "If you are trying to format Python 2 files please reinstall Black"
6420             " with the 'python2' extra: `python3 -m pip install black[python2]`."
6421         )
6422     return ast27.parse(src)
6423
6424
6425 def _fixup_ast_constants(
6426     node: Union[ast.AST, ast3.AST, ast27.AST]
6427 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
6428     """Map ast nodes deprecated in 3.8 to Constant."""
6429     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
6430         return ast.Constant(value=node.s)
6431
6432     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
6433         return ast.Constant(value=node.n)
6434
6435     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
6436         return ast.Constant(value=node.value)
6437
6438     return node
6439
6440
6441 def _stringify_ast(
6442     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
6443 ) -> Iterator[str]:
6444     """Simple visitor generating strings to compare ASTs by content."""
6445
6446     node = _fixup_ast_constants(node)
6447
6448     yield f"{'  ' * depth}{node.__class__.__name__}("
6449
6450     for field in sorted(node._fields):  # noqa: F402
6451         # TypeIgnore has only one field 'lineno' which breaks this comparison
6452         type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
6453         if sys.version_info >= (3, 8):
6454             type_ignore_classes += (ast.TypeIgnore,)
6455         if isinstance(node, type_ignore_classes):
6456             break
6457
6458         try:
6459             value = getattr(node, field)
6460         except AttributeError:
6461             continue
6462
6463         yield f"{'  ' * (depth+1)}{field}="
6464
6465         if isinstance(value, list):
6466             for item in value:
6467                 # Ignore nested tuples within del statements, because we may insert
6468                 # parentheses and they change the AST.
6469                 if (
6470                     field == "targets"
6471                     and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
6472                     and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
6473                 ):
6474                     for item in item.elts:
6475                         yield from _stringify_ast(item, depth + 2)
6476
6477                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
6478                     yield from _stringify_ast(item, depth + 2)
6479
6480         elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
6481             yield from _stringify_ast(value, depth + 2)
6482
6483         else:
6484             # Constant strings may be indented across newlines, if they are
6485             # docstrings; fold spaces after newlines when comparing. Similarly,
6486             # trailing and leading space may be removed.
6487             # Note that when formatting Python 2 code, at least with Windows
6488             # line-endings, docstrings can end up here as bytes instead of
6489             # str so make sure that we handle both cases.
6490             if (
6491                 isinstance(node, ast.Constant)
6492                 and field == "value"
6493                 and isinstance(value, (str, bytes))
6494             ):
6495                 lineend = "\n" if isinstance(value, str) else b"\n"
6496                 # To normalize, we strip any leading and trailing space from
6497                 # each line...
6498                 stripped = [line.strip() for line in value.splitlines()]
6499                 normalized = lineend.join(stripped)  # type: ignore[attr-defined]
6500                 # ...and remove any blank lines at the beginning and end of
6501                 # the whole string
6502                 normalized = normalized.strip()
6503             else:
6504                 normalized = value
6505             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
6506
6507     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
6508
6509
6510 def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
6511     """Raise AssertionError if `src` and `dst` aren't equivalent."""
6512     try:
6513         src_ast = parse_ast(src)
6514     except Exception as exc:
6515         raise AssertionError(
6516             "cannot use --safe with this file; failed to parse source file.  AST"
6517             f" error message: {exc}"
6518         )
6519
6520     try:
6521         dst_ast = parse_ast(dst)
6522     except Exception as exc:
6523         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
6524         raise AssertionError(
6525             f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
6526             "Please report a bug on https://github.com/psf/black/issues.  "
6527             f"This invalid output might be helpful: {log}"
6528         ) from None
6529
6530     src_ast_str = "\n".join(_stringify_ast(src_ast))
6531     dst_ast_str = "\n".join(_stringify_ast(dst_ast))
6532     if src_ast_str != dst_ast_str:
6533         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
6534         raise AssertionError(
6535             "INTERNAL ERROR: Black produced code that is not equivalent to the"
6536             f" source on pass {pass_num}.  Please report a bug on "
6537             f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
6538         ) from None
6539
6540
6541 def assert_stable(src: str, dst: str, mode: Mode) -> None:
6542     """Raise AssertionError if `dst` reformats differently the second time."""
6543     newdst = format_str(dst, mode=mode)
6544     if dst != newdst:
6545         log = dump_to_file(
6546             str(mode),
6547             diff(src, dst, "source", "first pass"),
6548             diff(dst, newdst, "first pass", "second pass"),
6549         )
6550         raise AssertionError(
6551             "INTERNAL ERROR: Black produced different code on the second pass of the"
6552             " formatter.  Please report a bug on https://github.com/psf/black/issues."
6553             f"  This diff might be helpful: {log}"
6554         ) from None
6555
6556
6557 @mypyc_attr(patchable=True)
6558 def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
6559     """Dump `output` to a temporary file. Return path to the file."""
6560     with tempfile.NamedTemporaryFile(
6561         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
6562     ) as f:
6563         for lines in output:
6564             f.write(lines)
6565             if ensure_final_newline and lines and lines[-1] != "\n":
6566                 f.write("\n")
6567     return f.name
6568
6569
6570 @contextmanager
6571 def nullcontext() -> Iterator[None]:
6572     """Return an empty context manager.
6573
6574     To be used like `nullcontext` in Python 3.7.
6575     """
6576     yield
6577
6578
6579 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
6580     """Return a unified diff string between strings `a` and `b`."""
6581     import difflib
6582
6583     a_lines = [line for line in a.splitlines(keepends=True)]
6584     b_lines = [line for line in b.splitlines(keepends=True)]
6585     diff_lines = []
6586     for line in difflib.unified_diff(
6587         a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
6588     ):
6589         # Work around https://bugs.python.org/issue2142
6590         # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
6591         if line[-1] == "\n":
6592             diff_lines.append(line)
6593         else:
6594             diff_lines.append(line + "\n")
6595             diff_lines.append("\\ No newline at end of file\n")
6596     return "".join(diff_lines)
6597
6598
6599 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
6600     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
6601     err("Aborted!")
6602     for task in tasks:
6603         task.cancel()
6604
6605
6606 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
6607     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
6608     try:
6609         if sys.version_info[:2] >= (3, 7):
6610             all_tasks = asyncio.all_tasks
6611         else:
6612             all_tasks = asyncio.Task.all_tasks
6613         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
6614         to_cancel = [task for task in all_tasks(loop) if not task.done()]
6615         if not to_cancel:
6616             return
6617
6618         for task in to_cancel:
6619             task.cancel()
6620         loop.run_until_complete(
6621             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
6622         )
6623     finally:
6624         # `concurrent.futures.Future` objects cannot be cancelled once they
6625         # are already running. There might be some when the `shutdown()` happened.
6626         # Silence their logger's spew about the event loop being closed.
6627         cf_logger = logging.getLogger("concurrent.futures")
6628         cf_logger.setLevel(logging.CRITICAL)
6629         loop.close()
6630
6631
6632 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
6633     """Replace `regex` with `replacement` twice on `original`.
6634
6635     This is used by string normalization to perform replaces on
6636     overlapping matches.
6637     """
6638     return regex.sub(replacement, regex.sub(replacement, original))
6639
6640
6641 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
6642     """Compile a regular expression string in `regex`.
6643
6644     If it contains newlines, use verbose mode.
6645     """
6646     if "\n" in regex:
6647         regex = "(?x)" + regex
6648     compiled: Pattern[str] = re.compile(regex)
6649     return compiled
6650
6651
6652 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
6653     """Like `reversed(enumerate(sequence))` if that were possible."""
6654     index = len(sequence) - 1
6655     for element in reversed(sequence):
6656         yield (index, element)
6657         index -= 1
6658
6659
6660 def enumerate_with_length(
6661     line: Line, reversed: bool = False
6662 ) -> Iterator[Tuple[Index, Leaf, int]]:
6663     """Return an enumeration of leaves with their length.
6664
6665     Stops prematurely on multiline strings and standalone comments.
6666     """
6667     op = cast(
6668         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
6669         enumerate_reversed if reversed else enumerate,
6670     )
6671     for index, leaf in op(line.leaves):
6672         length = len(leaf.prefix) + len(leaf.value)
6673         if "\n" in leaf.value:
6674             return  # Multiline strings, we can't continue.
6675
6676         for comment in line.comments_after(leaf):
6677             length += len(comment.value)
6678
6679         yield index, leaf, length
6680
6681
6682 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
6683     """Return True if `line` is no longer than `line_length`.
6684
6685     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
6686     """
6687     if not line_str:
6688         line_str = line_to_string(line)
6689     return (
6690         len(line_str) <= line_length
6691         and "\n" not in line_str  # multiline strings
6692         and not line.contains_standalone_comments()
6693     )
6694
6695
6696 def can_be_split(line: Line) -> bool:
6697     """Return False if the line cannot be split *for sure*.
6698
6699     This is not an exhaustive search but a cheap heuristic that we can use to
6700     avoid some unfortunate formattings (mostly around wrapping unsplittable code
6701     in unnecessary parentheses).
6702     """
6703     leaves = line.leaves
6704     if len(leaves) < 2:
6705         return False
6706
6707     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
6708         call_count = 0
6709         dot_count = 0
6710         next = leaves[-1]
6711         for leaf in leaves[-2::-1]:
6712             if leaf.type in OPENING_BRACKETS:
6713                 if next.type not in CLOSING_BRACKETS:
6714                     return False
6715
6716                 call_count += 1
6717             elif leaf.type == token.DOT:
6718                 dot_count += 1
6719             elif leaf.type == token.NAME:
6720                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
6721                     return False
6722
6723             elif leaf.type not in CLOSING_BRACKETS:
6724                 return False
6725
6726             if dot_count > 1 and call_count > 1:
6727                 return False
6728
6729     return True
6730
6731
6732 def can_omit_invisible_parens(
6733     line: Line,
6734     line_length: int,
6735     omit_on_explode: Collection[LeafID] = (),
6736 ) -> bool:
6737     """Does `line` have a shape safe to reformat without optional parens around it?
6738
6739     Returns True for only a subset of potentially nice looking formattings but
6740     the point is to not return false positives that end up producing lines that
6741     are too long.
6742     """
6743     bt = line.bracket_tracker
6744     if not bt.delimiters:
6745         # Without delimiters the optional parentheses are useless.
6746         return True
6747
6748     max_priority = bt.max_delimiter_priority()
6749     if bt.delimiter_count_with_priority(max_priority) > 1:
6750         # With more than one delimiter of a kind the optional parentheses read better.
6751         return False
6752
6753     if max_priority == DOT_PRIORITY:
6754         # A single stranded method call doesn't require optional parentheses.
6755         return True
6756
6757     assert len(line.leaves) >= 2, "Stranded delimiter"
6758
6759     # With a single delimiter, omit if the expression starts or ends with
6760     # a bracket.
6761     first = line.leaves[0]
6762     second = line.leaves[1]
6763     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
6764         if _can_omit_opening_paren(line, first=first, line_length=line_length):
6765             return True
6766
6767         # Note: we are not returning False here because a line might have *both*
6768         # a leading opening bracket and a trailing closing bracket.  If the
6769         # opening bracket doesn't match our rule, maybe the closing will.
6770
6771     penultimate = line.leaves[-2]
6772     last = line.leaves[-1]
6773     if line.magic_trailing_comma:
6774         try:
6775             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
6776         except LookupError:
6777             # Turns out we'd omit everything.  We cannot skip the optional parentheses.
6778             return False
6779
6780     if (
6781         last.type == token.RPAR
6782         or last.type == token.RBRACE
6783         or (
6784             # don't use indexing for omitting optional parentheses;
6785             # it looks weird
6786             last.type == token.RSQB
6787             and last.parent
6788             and last.parent.type != syms.trailer
6789         )
6790     ):
6791         if penultimate.type in OPENING_BRACKETS:
6792             # Empty brackets don't help.
6793             return False
6794
6795         if is_multiline_string(first):
6796             # Additional wrapping of a multiline string in this situation is
6797             # unnecessary.
6798             return True
6799
6800         if line.magic_trailing_comma and penultimate.type == token.COMMA:
6801             # The rightmost non-omitted bracket pair is the one we want to explode on.
6802             return True
6803
6804         if _can_omit_closing_paren(line, last=last, line_length=line_length):
6805             return True
6806
6807     return False
6808
6809
6810 def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool:
6811     """See `can_omit_invisible_parens`."""
6812     remainder = False
6813     length = 4 * line.depth
6814     _index = -1
6815     for _index, leaf, leaf_length in enumerate_with_length(line):
6816         if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
6817             remainder = True
6818         if remainder:
6819             length += leaf_length
6820             if length > line_length:
6821                 break
6822
6823             if leaf.type in OPENING_BRACKETS:
6824                 # There are brackets we can further split on.
6825                 remainder = False
6826
6827     else:
6828         # checked the entire string and line length wasn't exceeded
6829         if len(line.leaves) == _index + 1:
6830             return True
6831
6832     return False
6833
6834
6835 def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool:
6836     """See `can_omit_invisible_parens`."""
6837     length = 4 * line.depth
6838     seen_other_brackets = False
6839     for _index, leaf, leaf_length in enumerate_with_length(line):
6840         length += leaf_length
6841         if leaf is last.opening_bracket:
6842             if seen_other_brackets or length <= line_length:
6843                 return True
6844
6845         elif leaf.type in OPENING_BRACKETS:
6846             # There are brackets we can further split on.
6847             seen_other_brackets = True
6848
6849     return False
6850
6851
6852 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
6853     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
6854     stop_after = None
6855     last = None
6856     for leaf in reversed(leaves):
6857         if stop_after:
6858             if leaf is stop_after:
6859                 stop_after = None
6860             continue
6861
6862         if last:
6863             return leaf, last
6864
6865         if id(leaf) in omit:
6866             stop_after = leaf.opening_bracket
6867         else:
6868             last = leaf
6869     else:
6870         raise LookupError("Last two leaves were also skipped")
6871
6872
6873 def run_transformer(
6874     line: Line,
6875     transform: Transformer,
6876     mode: Mode,
6877     features: Collection[Feature],
6878     *,
6879     line_str: str = "",
6880 ) -> List[Line]:
6881     if not line_str:
6882         line_str = line_to_string(line)
6883     result: List[Line] = []
6884     for transformed_line in transform(line, features):
6885         if str(transformed_line).strip("\n") == line_str:
6886             raise CannotTransform("Line transformer returned an unchanged result")
6887
6888         result.extend(transform_line(transformed_line, mode=mode, features=features))
6889
6890     if not (
6891         transform.__name__ == "rhs"
6892         and line.bracket_tracker.invisible
6893         and not any(bracket.value for bracket in line.bracket_tracker.invisible)
6894         and not line.contains_multiline_strings()
6895         and not result[0].contains_uncollapsable_type_comments()
6896         and not result[0].contains_unsplittable_type_ignore()
6897         and not is_line_short_enough(result[0], line_length=mode.line_length)
6898     ):
6899         return result
6900
6901     line_copy = line.clone()
6902     append_leaves(line_copy, line, line.leaves)
6903     features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES}
6904     second_opinion = run_transformer(
6905         line_copy, transform, mode, features_fop, line_str=line_str
6906     )
6907     if all(
6908         is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
6909     ):
6910         result = second_opinion
6911     return result
6912
6913
6914 def get_cache_file(mode: Mode) -> Path:
6915     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
6916
6917
6918 def read_cache(mode: Mode) -> Cache:
6919     """Read the cache if it exists and is well formed.
6920
6921     If it is not well formed, the call to write_cache later should resolve the issue.
6922     """
6923     cache_file = get_cache_file(mode)
6924     if not cache_file.exists():
6925         return {}
6926
6927     with cache_file.open("rb") as fobj:
6928         try:
6929             cache: Cache = pickle.load(fobj)
6930         except (pickle.UnpicklingError, ValueError):
6931             return {}
6932
6933     return cache
6934
6935
6936 def get_cache_info(path: Path) -> CacheInfo:
6937     """Return the information used to check if a file is already formatted or not."""
6938     stat = path.stat()
6939     return stat.st_mtime, stat.st_size
6940
6941
6942 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
6943     """Split an iterable of paths in `sources` into two sets.
6944
6945     The first contains paths of files that modified on disk or are not in the
6946     cache. The other contains paths to non-modified files.
6947     """
6948     todo, done = set(), set()
6949     for src in sources:
6950         res_src = src.resolve()
6951         if cache.get(str(res_src)) != get_cache_info(res_src):
6952             todo.add(src)
6953         else:
6954             done.add(src)
6955     return todo, done
6956
6957
6958 def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
6959     """Update the cache file."""
6960     cache_file = get_cache_file(mode)
6961     try:
6962         CACHE_DIR.mkdir(parents=True, exist_ok=True)
6963         new_cache = {
6964             **cache,
6965             **{str(src.resolve()): get_cache_info(src) for src in sources},
6966         }
6967         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
6968             pickle.dump(new_cache, f, protocol=4)
6969         os.replace(f.name, cache_file)
6970     except OSError:
6971         pass
6972
6973
6974 def patch_click() -> None:
6975     """Make Click not crash.
6976
6977     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
6978     default which restricts paths that it can access during the lifetime of the
6979     application.  Click refuses to work in this scenario by raising a RuntimeError.
6980
6981     In case of Black the likelihood that non-ASCII characters are going to be used in
6982     file paths is minimal since it's Python source code.  Moreover, this crash was
6983     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
6984     """
6985     try:
6986         from click import core
6987         from click import _unicodefun  # type: ignore
6988     except ModuleNotFoundError:
6989         return
6990
6991     for module in (core, _unicodefun):
6992         if hasattr(module, "_verify_python3_env"):
6993             module._verify_python3_env = lambda: None
6994
6995
6996 def patched_main() -> None:
6997     freeze_support()
6998     patch_click()
6999     main()
7000
7001
7002 def is_docstring(leaf: Leaf) -> bool:
7003     if prev_siblings_are(
7004         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
7005     ):
7006         return True
7007
7008     # Multiline docstring on the same line as the `def`.
7009     if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
7010         # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
7011         # grammar. We're safe to return True without further checks.
7012         return True
7013
7014     return False
7015
7016
7017 def lines_with_leading_tabs_expanded(s: str) -> List[str]:
7018     """
7019     Splits string into lines and expands only leading tabs (following the normal
7020     Python rules)
7021     """
7022     lines = []
7023     for line in s.splitlines():
7024         # Find the index of the first non-whitespace character after a string of
7025         # whitespace that includes at least one tab
7026         match = re.match(r"\s*\t+\s*(\S)", line)
7027         if match:
7028             first_non_whitespace_idx = match.start(1)
7029
7030             lines.append(
7031                 line[:first_non_whitespace_idx].expandtabs()
7032                 + line[first_non_whitespace_idx:]
7033             )
7034         else:
7035             lines.append(line)
7036     return lines
7037
7038
7039 def fix_docstring(docstring: str, prefix: str) -> str:
7040     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
7041     if not docstring:
7042         return ""
7043     lines = lines_with_leading_tabs_expanded(docstring)
7044     # Determine minimum indentation (first line doesn't count):
7045     indent = sys.maxsize
7046     for line in lines[1:]:
7047         stripped = line.lstrip()
7048         if stripped:
7049             indent = min(indent, len(line) - len(stripped))
7050     # Remove indentation (first line is special):
7051     trimmed = [lines[0].strip()]
7052     if indent < sys.maxsize:
7053         last_line_idx = len(lines) - 2
7054         for i, line in enumerate(lines[1:]):
7055             stripped_line = line[indent:].rstrip()
7056             if stripped_line or i == last_line_idx:
7057                 trimmed.append(prefix + stripped_line)
7058             else:
7059                 trimmed.append("")
7060     return "\n".join(trimmed)
7061
7062
7063 if __name__ == "__main__":
7064     patched_main()